summaryrefslogtreecommitdiff
path: root/src/silx/sx/_plot.py
diff options
context:
space:
mode:
Diffstat (limited to 'src/silx/sx/_plot.py')
-rw-r--r--src/silx/sx/_plot.py625
1 files changed, 625 insertions, 0 deletions
diff --git a/src/silx/sx/_plot.py b/src/silx/sx/_plot.py
new file mode 100644
index 0000000..b44c042
--- /dev/null
+++ b/src/silx/sx/_plot.py
@@ -0,0 +1,625 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2016-2021 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""This module adds convenient functions to use plot widgets from the console.
+"""
+
+__authors__ = ["T. Vincent"]
+__license__ = "MIT"
+__date__ = "06/11/2018"
+
+
+import collections
+try:
+ from collections import abc
+except ImportError: # Python2 support
+ import collections as abc
+import logging
+import weakref
+
+import numpy
+
+from ..utils.weakref import WeakList
+from ..gui import qt
+from ..gui.plot import Plot1D, Plot2D, ScatterView
+from ..gui.plot import items
+from ..gui import colors
+from ..gui.plot.tools import roi
+from ..gui.plot.items import roi as roi_items
+from ..gui.plot.tools.toolbars import InteractiveModeToolBar
+
+_logger = logging.getLogger(__name__)
+
+_plots = WeakList()
+"""List of widgets created through plot and imshow"""
+
+
+def plot(*args, **kwargs):
+ """
+ Plot curves in a :class:`~silx.gui.plot.PlotWindow.Plot1D` widget.
+
+ How to use:
+
+ >>> from silx import sx
+ >>> import numpy
+
+ Plot a single curve given some values:
+
+ >>> values = numpy.random.random(100)
+ >>> plot_1curve = sx.plot(values, title='Random data')
+
+ Plot a single curve given the x and y values:
+
+ >>> angles = numpy.linspace(0, numpy.pi, 100)
+ >>> sin_a = numpy.sin(angles)
+ >>> plot_sinus = sx.plot(angles, sin_a, xlabel='angle (radian)', ylabel='sin(a)')
+
+ Plot many curves by giving a 2D array, provided xn, yn arrays:
+
+ >>> plot_curves = sx.plot(x0, y0, x1, y1, x2, y2, ...)
+
+ Plot curve with style giving a style string:
+
+ >>> plot_styled = sx.plot(x0, y0, 'ro-', x1, y1, 'b.')
+
+ Supported symbols:
+
+ - 'o' circle
+ - '.' point
+ - ',' pixel
+ - '+' cross
+ - 'x' x-cross
+ - 'd' diamond
+ - 's' square
+
+ Supported types of line:
+
+ - ' ' no line
+ - '-' solid line
+ - '--' dashed line
+ - '-.' dash-dot line
+ - ':' dotted line
+
+ If provided, the names arguments color, linestyle, linewidth and marker
+ override any style provided to a curve.
+
+ This function supports a subset of `matplotlib.pyplot.plot
+ <http://matplotlib.org/api/pyplot_api.html#matplotlib.pyplot.plot>`_
+ arguments.
+
+ :param str color: Color to use for all curves (default: None)
+ :param str linestyle: Type of line to use for all curves (default: None)
+ :param float linewidth: With of all the curves (default: 1)
+ :param str marker: Symbol to use for all the curves (default: None)
+ :param str title: The title of the Plot widget (default: None)
+ :param str xlabel: The label of the X axis (default: None)
+ :param str ylabel: The label of the Y axis (default: None)
+ :return: The widget plotting the curve(s)
+ :rtype: silx.gui.plot.Plot1D
+ """
+ plt = Plot1D()
+ if 'title' in kwargs:
+ plt.setGraphTitle(kwargs['title'])
+ if 'xlabel' in kwargs:
+ plt.getXAxis().setLabel(kwargs['xlabel'])
+ if 'ylabel' in kwargs:
+ plt.getYAxis().setLabel(kwargs['ylabel'])
+
+ color = kwargs.get('color')
+ linestyle = kwargs.get('linestyle')
+ linewidth = kwargs.get('linewidth')
+ marker = kwargs.get('marker')
+
+ # Parse args and store curves as (x, y, style string)
+ args = list(args)
+ curves = []
+ while args:
+ first_arg = args.pop(0) # Process an arg
+
+ if len(args) == 0:
+ # Last curve defined as (y,)
+ curves.append((numpy.arange(len(first_arg)), first_arg, None))
+ else:
+ second_arg = args.pop(0)
+ if isinstance(second_arg, str):
+ # curve defined as (y, style)
+ y = first_arg
+ style = second_arg
+ curves.append((numpy.arange(len(y)), y, style))
+ else: # second_arg must be an array-like
+ x = first_arg
+ y = second_arg
+ if len(args) >= 1 and isinstance(args[0], str):
+ # Curve defined as (x, y, style)
+ style = args.pop(0)
+ curves.append((x, y, style))
+ else:
+ # Curve defined as (x, y)
+ curves.append((x, y, None))
+
+ for index, curve in enumerate(curves):
+ x, y, style = curve
+
+ # Default style
+ curve_symbol, curve_linestyle, curve_color = None, None, None
+
+ # Parse style
+ if style:
+ # Handle color first
+ possible_colors = [c for c in colors.COLORDICT if style.startswith(c)]
+ if possible_colors: # Take the longest string matching a color name
+ curve_color = possible_colors[0]
+ for c in possible_colors[1:]:
+ if len(c) > len(curve_color):
+ curve_color = c
+ style = style[len(curve_color):]
+
+ if style:
+ # Run twice to handle inversion symbol/linestyle
+ for _i in range(2):
+ # Handle linestyle
+ for line in (' ', '--', '-', '-.', ':'):
+ if style.endswith(line):
+ curve_linestyle = line
+ style = style[:-len(line)]
+ break
+
+ # Handle symbol
+ for curve_marker in ('o', '.', ',', '+', 'x', 'd', 's'):
+ if style.endswith(curve_marker):
+ curve_symbol = style[-1]
+ style = style[:-1]
+ break
+
+ # As in matplotlib, marker, linestyle and color override other style
+ plt.addCurve(x, y,
+ legend=('curve_%d' % index),
+ symbol=marker or curve_symbol,
+ linestyle=linestyle or curve_linestyle,
+ linewidth=linewidth,
+ color=color or curve_color)
+
+ plt.show()
+ _plots.insert(0, plt)
+ return plt
+
+
+def imshow(data=None, cmap=None, norm=colors.Colormap.LINEAR,
+ vmin=None, vmax=None,
+ aspect=False,
+ origin='upper', scale=(1., 1.),
+ title='', xlabel='X', ylabel='Y'):
+ """
+ Plot an image in a :class:`~silx.gui.plot.PlotWindow.Plot2D` widget.
+
+ How to use:
+
+ >>> from silx import sx
+ >>> import numpy
+
+ >>> data = numpy.random.random(1024 * 1024).reshape(1024, 1024)
+ >>> plt = sx.imshow(data, title='Random data')
+
+ By default, the image origin is displayed in the upper left
+ corner of the plot. To invert the Y axis, and place the image origin
+ in the lower left corner of the plot, use the *origin* parameter:
+
+ >>> plt = sx.imshow(data, origin='lower')
+
+ This function supports a subset of `matplotlib.pyplot.imshow
+ <http://matplotlib.org/api/pyplot_api.html#matplotlib.pyplot.imshow>`_
+ arguments.
+
+ :param data: data to plot as an image
+ :type data: numpy.ndarray-like with 2 dimensions
+ :param str cmap: The name of the colormap to use for the plot. It also
+ supports a numpy array containing a RGB LUT, or a `colors.Colormap`
+ instance.
+ :param str norm: The normalization of the colormap:
+ 'linear' (default) or 'log'
+ :param float vmin: The value to use for the min of the colormap
+ :param float vmax: The value to use for the max of the colormap
+ :param bool aspect: True to keep aspect ratio (Default: False)
+ :param origin: Either image origin as the Y axis orientation:
+ 'upper' (default) or 'lower'
+ or the coordinates (ox, oy) of the image origin in the plot.
+ :type origin: str or 2-tuple of floats
+ :param scale: (sx, sy) The scale of the image in the plot
+ (i.e., the size of the image's pixel in plot coordinates)
+ :type scale: 2-tuple of floats
+ :param str title: The title of the Plot widget
+ :param str xlabel: The label of the X axis
+ :param str ylabel: The label of the Y axis
+ :return: The widget plotting the image
+ :rtype: silx.gui.plot.Plot2D
+ """
+ plt = Plot2D()
+ plt.setGraphTitle(title)
+ plt.getXAxis().setLabel(xlabel)
+ plt.getYAxis().setLabel(ylabel)
+
+ # Update default colormap with input parameters
+ colormap = plt.getDefaultColormap()
+ if isinstance(cmap, colors.Colormap):
+ colormap = cmap
+ plt.setDefaultColormap(colormap)
+ elif isinstance(cmap, numpy.ndarray):
+ colormap.setColors(cmap)
+ elif cmap is not None:
+ colormap.setName(cmap)
+ assert norm in colors.Colormap.NORMALIZATIONS
+ colormap.setNormalization(norm)
+ colormap.setVMin(vmin)
+ colormap.setVMax(vmax)
+
+ # Handle aspect
+ if aspect in (None, False, 'auto', 'normal'):
+ plt.setKeepDataAspectRatio(False)
+ elif aspect in (True, 'equal') or aspect == 1:
+ plt.setKeepDataAspectRatio(True)
+ else:
+ _logger.warning(
+ 'imshow: Unhandled aspect argument: %s', str(aspect))
+
+ # Handle matplotlib-like origin
+ if origin in ('upper', 'lower'):
+ plt.setYAxisInverted(origin == 'upper')
+ origin = 0., 0. # Set origin to the definition of silx
+
+ if data is not None:
+ data = numpy.array(data, copy=True)
+
+ assert data.ndim in (2, 3) # data or RGB(A)
+ if data.ndim == 3:
+ assert data.shape[-1] in (3, 4) # RGB(A) image
+
+ plt.addImage(data, origin=origin, scale=scale)
+
+ plt.show()
+ _plots.insert(0, plt)
+ return plt
+
+
+def scatter(x=None, y=None, value=None, size=None,
+ marker=None,
+ cmap=None, norm=colors.Colormap.LINEAR,
+ vmin=None, vmax=None):
+ """
+ Plot scattered data in a :class:`~silx.gui.plot.ScatterView` widget.
+
+ How to use:
+
+ >>> from silx import sx
+ >>> import numpy
+
+ >>> x = numpy.random.random(100)
+ >>> y = numpy.random.random(100)
+ >>> values = numpy.random.random(100)
+ >>> plt = sx.scatter(x, y, values, cmap='viridis')
+
+ Supported symbols:
+
+ - 'o' circle
+ - '.' point
+ - ',' pixel
+ - '+' cross
+ - 'x' x-cross
+ - 'd' diamond
+ - 's' square
+
+ This function supports a subset of `matplotlib.pyplot.scatter
+ <http://matplotlib.org/api/pyplot_api.html#matplotlib.pyplot.scatter>`_
+ arguments.
+
+ :param numpy.ndarray x: 1D array-like of x coordinates
+ :param numpy.ndarray y: 1D array-like of y coordinates
+ :param numpy.ndarray value: 1D array-like of data values
+ :param float size: Size^2 of the markers
+ :param str marker: Symbol used to represent the points
+ :param str cmap: The name of the colormap to use for the plot
+ :param str norm: The normalization of the colormap:
+ 'linear' (default) or 'log'
+ :param float vmin: The value to use for the min of the colormap
+ :param float vmax: The value to use for the max of the colormap
+ :return: The widget plotting the scatter plot
+ :rtype: silx.gui.plot.ScatterView.ScatterView
+ """
+ plt = ScatterView()
+
+ # Update default colormap with input parameters
+ colormap = plt.getPlotWidget().getDefaultColormap()
+ if cmap is not None:
+ colormap.setName(cmap)
+ assert norm in colors.Colormap.NORMALIZATIONS
+ colormap.setNormalization(norm)
+ colormap.setVMin(vmin)
+ colormap.setVMax(vmax)
+ plt.getPlotWidget().setDefaultColormap(colormap)
+
+ if x is not None and y is not None: # Add a scatter plot
+ x = numpy.array(x, copy=True).reshape(-1)
+ y = numpy.array(y, copy=True).reshape(-1)
+ assert len(x) == len(y)
+
+ if value is None:
+ value = numpy.ones(len(x), dtype=numpy.float32)
+
+ elif isinstance(value, abc.Iterable):
+ value = numpy.array(value, copy=True).reshape(-1)
+ assert len(x) == len(value)
+
+ else:
+ value = numpy.ones(len(x), dtype=numpy.float64) * value
+
+ plt.setData(x, y, value)
+ item = plt.getScatterItem()
+ if marker is not None:
+ item.setSymbol(marker)
+ if size is not None:
+ item.setSymbolSize(numpy.sqrt(size))
+
+ plt.resetZoom()
+
+ plt.show()
+ _plots.insert(0, plt.getPlotWidget())
+ return plt
+
+
+class _GInputResult(tuple):
+ """Object storing :func:`ginput` result
+
+ :param position: Selected point coordinates in the plot (x, y)
+ :param Item item: Plot item under the selected position
+ :param indices: Selected indices in the data of the item.
+ For a curve it is a list of indices, for an image it is (row, column)
+ :param data: Value of data at selected indices.
+ For a curve it is an array of values, for an image it is a single value
+ """
+
+ def __new__(cls, position, item, indices, data):
+ return super(_GInputResult, cls).__new__(cls, position)
+
+ def __init__(self, position, item, indices, data):
+ self._itemRef = weakref.ref(item) if item is not None else None
+ self._indices = numpy.array(indices, copy=True)
+ if isinstance(data, abc.Iterable):
+ self._data = numpy.array(data, copy=True)
+ else:
+ self._data = data
+
+ def getItem(self):
+ """Returns the item at the selected position if any.
+
+ :return: plot item under the selected postion.
+ It is None if there was no item at that position or if
+ it is no more in the plot.
+ :rtype: silx.gui.plot.items.Item"""
+ return None if self._itemRef is None else self._itemRef()
+
+ def getIndices(self):
+ """Returns indices in data array at the select position
+
+ :return: 1D array of indices for curve and (row, column) for images
+ :rtype: numpy.ndarray
+ """
+ return numpy.array(self._indices, copy=True)
+
+ def getData(self):
+ """Returns data value at the selected position.
+
+ For curves, an array of (x, y) values close to the point is returned.
+ For images, either a single value or a RGB(A) array is returned.
+
+ :return: 2D array of (x, y) data values for curves (Nx2),
+ a single value for data images and RGB(A) array for images.
+ """
+ if isinstance(self._data, numpy.ndarray):
+ return numpy.array(self._data, copy=True)
+ else:
+ return self._data
+
+
+class _GInputHandler(roi.InteractiveRegionOfInterestManager):
+ """Implements :func:`ginput`
+
+ :param PlotWidget plot:
+ :param int n: Max number of points to request
+ :param float timeout: Timeout in seconds
+ """
+
+ def __init__(self, plot, n, timeout):
+ super(_GInputHandler, self).__init__(plot)
+
+ self._timeout = timeout
+ self.__selections = collections.OrderedDict()
+
+ window = plot.window() # Retrieve window containing PlotWidget
+ statusBar = window.statusBar()
+ self.sigMessageChanged.connect(statusBar.showMessage)
+ self.setMaxRois(n)
+ self.setValidationMode(self.ValidationMode.AUTO_ENTER)
+ self.sigRoiAdded.connect(self.__added)
+ self.sigRoiAboutToBeRemoved.connect(self.__removed)
+
+ def exec(self):
+ """Request user inputs
+
+ :return: List of selection points information
+ """
+ plot = self.parent()
+ if plot is None:
+ return
+
+ window = plot.window() # Retrieve window containing PlotWidget
+
+ # Add ROI point interactive mode action
+ for toolbar in window.findChildren(qt.QToolBar):
+ if isinstance(toolbar, InteractiveModeToolBar):
+ break
+ else: # Add a toolbar
+ toolbar = qt.QToolBar()
+ window.addToolBar(toolbar)
+ toolbar.addAction(self.getInteractionModeAction(roi_items.PointROI))
+
+ super(_GInputHandler, self).exec(roiClass=roi_items.PointROI, timeout=self._timeout)
+
+ if isinstance(toolbar, InteractiveModeToolBar):
+ toolbar.removeAction(self.getInteractionModeAction(roi_items.PointROI))
+ else:
+ toolbar.setParent(None)
+
+ return tuple(self.__selections.values())
+
+ def exec_(self): # Qt5-like compatibility
+ return self.exec()
+
+ def __updateSelection(self, roi):
+ """Perform picking and update selection list
+
+ :param RegionOfInterest roi:
+ """
+ plot = self.parent()
+ if plot is None:
+ return # No plot, abort
+
+ if not isinstance(roi, roi_items.PointROI):
+ # Only handle points
+ raise RuntimeError("Unexpected item")
+
+ x, y = roi.getPosition()
+ xPixel, yPixel = plot.dataToPixel(x, y, axis='left', check=False)
+
+ # Pick item at selected position
+ pickingResult = plot._pickTopMost(
+ xPixel, yPixel,
+ lambda item: isinstance(item, (items.ImageBase, items.Curve)))
+
+ if pickingResult is None:
+ result = _GInputResult((x, y),
+ item=None,
+ indices=numpy.array((), dtype=int),
+ data=None)
+ else:
+ item = pickingResult.getItem()
+ indices = pickingResult.getIndices(copy=True)
+
+ if isinstance(item, items.Curve):
+ xData = item.getXData(copy=False)[indices]
+ yData = item.getYData(copy=False)[indices]
+ result = _GInputResult((x, y),
+ item=item,
+ indices=indices,
+ data=numpy.array((xData, yData)).T)
+
+ elif isinstance(item, items.ImageBase):
+ row, column = indices[0][0], indices[1][0]
+ data = item.getData(copy=False)[row, column]
+ result = _GInputResult((x, y),
+ item=item,
+ indices=(row, column),
+ data=data)
+
+ self.__selections[roi] = result
+
+ def __added(self, roi):
+ """Handle new ROI added
+
+ :param RegionOfInterest roi:
+ """
+ if isinstance(roi, roi_items.PointROI):
+ # Only handle points
+ roi.setName('%d' % len(self.__selections))
+ self.__updateSelection(roi)
+ roi.sigRegionChanged.connect(self.__regionChanged)
+
+ def __removed(self, roi):
+ """Handle ROI removed"""
+ if self.__selections.pop(roi, None) is not None:
+ roi.sigRegionChanged.disconnect(self.__regionChanged)
+
+ def __regionChanged(self):
+ """Handle update of a ROI"""
+ roi = self.sender()
+ self.__updateSelection(roi)
+
+
+def ginput(n=1, timeout=30, plot=None):
+ """Get input points on a plot.
+
+ If no plot is provided, it uses a plot widget created with
+ either :func:`silx.sx.plot` or :func:`silx.sx.imshow`.
+
+ How to use:
+
+ >>> from silx import sx
+
+ >>> sx.imshow(image) # Plot the image
+ >>> sx.ginput(1) # Request selection on the image plot
+ ((0.598, 1.234))
+
+ How to get more information about the selected positions:
+
+ >>> positions = sx.ginput(1)
+
+ >>> positions[0].getData() # Returns value(s) at selected position
+
+ >>> positions[0].getIndices() # Returns data indices at selected position
+
+ >>> positions[0].getItem() # Returns plot item at selected position
+
+ :param int n: Number of points the user need to select
+ :param float timeout: Timeout in seconds before ginput returns
+ event if selection is not completed
+ :param silx.gui.plot.PlotWidget.PlotWidget plot: An optional PlotWidget
+ from which to get input
+ :return: List of clicked points coordinates (x, y) in plot
+ :raise ValueError: If provided plot is not a PlotWidget
+ """
+ if plot is None:
+ # Select most recent visible plot widget
+ for widget in _plots:
+ if widget.isVisible():
+ plot = widget
+ break
+ else: # If no plot widget is visible, take the most recent one
+ try:
+ plot = _plots[0]
+ except IndexError:
+ pass
+ else:
+ plot.show()
+
+ if plot is None:
+ _logger.warning('No plot available to perform ginput, create one')
+ plot = Plot1D()
+ plot.show()
+ _plots.insert(0, plot)
+
+ plot.raise_() # So window becomes the top level one
+
+ _logger.info('Performing ginput with plot widget %s', str(plot))
+ handler = _GInputHandler(plot, n, timeout)
+ points = handler.exec()
+
+ return points