diff options
Diffstat (limited to 'silx/sx/_plot.py')
-rw-r--r-- | silx/sx/_plot.py | 623 |
1 files changed, 0 insertions, 623 deletions
diff --git a/silx/sx/_plot.py b/silx/sx/_plot.py deleted file mode 100644 index 5746492..0000000 --- a/silx/sx/_plot.py +++ /dev/null @@ -1,623 +0,0 @@ -# coding: utf-8 -# /*########################################################################## -# -# Copyright (c) 2016-2020 European Synchrotron Radiation Facility -# -# Permission is hereby granted, free of charge, to any person obtaining a copy -# of this software and associated documentation files (the "Software"), to deal -# in the Software without restriction, including without limitation the rights -# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -# copies of the Software, and to permit persons to whom the Software is -# furnished to do so, subject to the following conditions: -# -# The above copyright notice and this permission notice shall be included in -# all copies or substantial portions of the Software. -# -# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN -# THE SOFTWARE. -# -# ###########################################################################*/ -"""This module adds convenient functions to use plot widgets from the console. -""" - -__authors__ = ["T. Vincent"] -__license__ = "MIT" -__date__ = "06/11/2018" - - -import collections -try: - from collections import abc -except ImportError: # Python2 support - import collections as abc -import logging -import weakref - -import numpy -import six - -from ..utils.weakref import WeakList -from ..gui import qt -from ..gui.plot import Plot1D, Plot2D, ScatterView -from ..gui.plot import items -from ..gui import colors -from ..gui.plot.tools import roi -from ..gui.plot.items import roi as roi_items -from ..gui.plot.tools.toolbars import InteractiveModeToolBar - -_logger = logging.getLogger(__name__) - -_plots = WeakList() -"""List of widgets created through plot and imshow""" - - -def plot(*args, **kwargs): - """ - Plot curves in a :class:`~silx.gui.plot.PlotWindow.Plot1D` widget. - - How to use: - - >>> from silx import sx - >>> import numpy - - Plot a single curve given some values: - - >>> values = numpy.random.random(100) - >>> plot_1curve = sx.plot(values, title='Random data') - - Plot a single curve given the x and y values: - - >>> angles = numpy.linspace(0, numpy.pi, 100) - >>> sin_a = numpy.sin(angles) - >>> plot_sinus = sx.plot(angles, sin_a, xlabel='angle (radian)', ylabel='sin(a)') - - Plot many curves by giving a 2D array, provided xn, yn arrays: - - >>> plot_curves = sx.plot(x0, y0, x1, y1, x2, y2, ...) - - Plot curve with style giving a style string: - - >>> plot_styled = sx.plot(x0, y0, 'ro-', x1, y1, 'b.') - - Supported symbols: - - - 'o' circle - - '.' point - - ',' pixel - - '+' cross - - 'x' x-cross - - 'd' diamond - - 's' square - - Supported types of line: - - - ' ' no line - - '-' solid line - - '--' dashed line - - '-.' dash-dot line - - ':' dotted line - - If provided, the names arguments color, linestyle, linewidth and marker - override any style provided to a curve. - - This function supports a subset of `matplotlib.pyplot.plot - <http://matplotlib.org/api/pyplot_api.html#matplotlib.pyplot.plot>`_ - arguments. - - :param str color: Color to use for all curves (default: None) - :param str linestyle: Type of line to use for all curves (default: None) - :param float linewidth: With of all the curves (default: 1) - :param str marker: Symbol to use for all the curves (default: None) - :param str title: The title of the Plot widget (default: None) - :param str xlabel: The label of the X axis (default: None) - :param str ylabel: The label of the Y axis (default: None) - :return: The widget plotting the curve(s) - :rtype: silx.gui.plot.Plot1D - """ - plt = Plot1D() - if 'title' in kwargs: - plt.setGraphTitle(kwargs['title']) - if 'xlabel' in kwargs: - plt.getXAxis().setLabel(kwargs['xlabel']) - if 'ylabel' in kwargs: - plt.getYAxis().setLabel(kwargs['ylabel']) - - color = kwargs.get('color') - linestyle = kwargs.get('linestyle') - linewidth = kwargs.get('linewidth') - marker = kwargs.get('marker') - - # Parse args and store curves as (x, y, style string) - args = list(args) - curves = [] - while args: - first_arg = args.pop(0) # Process an arg - - if len(args) == 0: - # Last curve defined as (y,) - curves.append((numpy.arange(len(first_arg)), first_arg, None)) - else: - second_arg = args.pop(0) - if isinstance(second_arg, six.string_types): - # curve defined as (y, style) - y = first_arg - style = second_arg - curves.append((numpy.arange(len(y)), y, style)) - else: # second_arg must be an array-like - x = first_arg - y = second_arg - if len(args) >= 1 and isinstance(args[0], six.string_types): - # Curve defined as (x, y, style) - style = args.pop(0) - curves.append((x, y, style)) - else: - # Curve defined as (x, y) - curves.append((x, y, None)) - - for index, curve in enumerate(curves): - x, y, style = curve - - # Default style - curve_symbol, curve_linestyle, curve_color = None, None, None - - # Parse style - if style: - # Handle color first - possible_colors = [c for c in colors.COLORDICT if style.startswith(c)] - if possible_colors: # Take the longest string matching a color name - curve_color = possible_colors[0] - for c in possible_colors[1:]: - if len(c) > len(curve_color): - curve_color = c - style = style[len(curve_color):] - - if style: - # Run twice to handle inversion symbol/linestyle - for _i in range(2): - # Handle linestyle - for line in (' ', '--', '-', '-.', ':'): - if style.endswith(line): - curve_linestyle = line - style = style[:-len(line)] - break - - # Handle symbol - for curve_marker in ('o', '.', ',', '+', 'x', 'd', 's'): - if style.endswith(curve_marker): - curve_symbol = style[-1] - style = style[:-1] - break - - # As in matplotlib, marker, linestyle and color override other style - plt.addCurve(x, y, - legend=('curve_%d' % index), - symbol=marker or curve_symbol, - linestyle=linestyle or curve_linestyle, - linewidth=linewidth, - color=color or curve_color) - - plt.show() - _plots.insert(0, plt) - return plt - - -def imshow(data=None, cmap=None, norm=colors.Colormap.LINEAR, - vmin=None, vmax=None, - aspect=False, - origin='upper', scale=(1., 1.), - title='', xlabel='X', ylabel='Y'): - """ - Plot an image in a :class:`~silx.gui.plot.PlotWindow.Plot2D` widget. - - How to use: - - >>> from silx import sx - >>> import numpy - - >>> data = numpy.random.random(1024 * 1024).reshape(1024, 1024) - >>> plt = sx.imshow(data, title='Random data') - - By default, the image origin is displayed in the upper left - corner of the plot. To invert the Y axis, and place the image origin - in the lower left corner of the plot, use the *origin* parameter: - - >>> plt = sx.imshow(data, origin='lower') - - This function supports a subset of `matplotlib.pyplot.imshow - <http://matplotlib.org/api/pyplot_api.html#matplotlib.pyplot.imshow>`_ - arguments. - - :param data: data to plot as an image - :type data: numpy.ndarray-like with 2 dimensions - :param str cmap: The name of the colormap to use for the plot. It also - supports a numpy array containing a RGB LUT, or a `colors.Colormap` - instance. - :param str norm: The normalization of the colormap: - 'linear' (default) or 'log' - :param float vmin: The value to use for the min of the colormap - :param float vmax: The value to use for the max of the colormap - :param bool aspect: True to keep aspect ratio (Default: False) - :param origin: Either image origin as the Y axis orientation: - 'upper' (default) or 'lower' - or the coordinates (ox, oy) of the image origin in the plot. - :type origin: str or 2-tuple of floats - :param scale: (sx, sy) The scale of the image in the plot - (i.e., the size of the image's pixel in plot coordinates) - :type scale: 2-tuple of floats - :param str title: The title of the Plot widget - :param str xlabel: The label of the X axis - :param str ylabel: The label of the Y axis - :return: The widget plotting the image - :rtype: silx.gui.plot.Plot2D - """ - plt = Plot2D() - plt.setGraphTitle(title) - plt.getXAxis().setLabel(xlabel) - plt.getYAxis().setLabel(ylabel) - - # Update default colormap with input parameters - colormap = plt.getDefaultColormap() - if isinstance(cmap, colors.Colormap): - colormap = cmap - plt.setDefaultColormap(colormap) - elif isinstance(cmap, numpy.ndarray): - colormap.setColors(cmap) - elif cmap is not None: - colormap.setName(cmap) - assert norm in colors.Colormap.NORMALIZATIONS - colormap.setNormalization(norm) - colormap.setVMin(vmin) - colormap.setVMax(vmax) - - # Handle aspect - if aspect in (None, False, 'auto', 'normal'): - plt.setKeepDataAspectRatio(False) - elif aspect in (True, 'equal') or aspect == 1: - plt.setKeepDataAspectRatio(True) - else: - _logger.warning( - 'imshow: Unhandled aspect argument: %s', str(aspect)) - - # Handle matplotlib-like origin - if origin in ('upper', 'lower'): - plt.setYAxisInverted(origin == 'upper') - origin = 0., 0. # Set origin to the definition of silx - - if data is not None: - data = numpy.array(data, copy=True) - - assert data.ndim in (2, 3) # data or RGB(A) - if data.ndim == 3: - assert data.shape[-1] in (3, 4) # RGB(A) image - - plt.addImage(data, origin=origin, scale=scale) - - plt.show() - _plots.insert(0, plt) - return plt - - -def scatter(x=None, y=None, value=None, size=None, - marker=None, - cmap=None, norm=colors.Colormap.LINEAR, - vmin=None, vmax=None): - """ - Plot scattered data in a :class:`~silx.gui.plot.ScatterView` widget. - - How to use: - - >>> from silx import sx - >>> import numpy - - >>> x = numpy.random.random(100) - >>> y = numpy.random.random(100) - >>> values = numpy.random.random(100) - >>> plt = sx.scatter(x, y, values, cmap='viridis') - - Supported symbols: - - - 'o' circle - - '.' point - - ',' pixel - - '+' cross - - 'x' x-cross - - 'd' diamond - - 's' square - - This function supports a subset of `matplotlib.pyplot.scatter - <http://matplotlib.org/api/pyplot_api.html#matplotlib.pyplot.scatter>`_ - arguments. - - :param numpy.ndarray x: 1D array-like of x coordinates - :param numpy.ndarray y: 1D array-like of y coordinates - :param numpy.ndarray value: 1D array-like of data values - :param float size: Size^2 of the markers - :param str marker: Symbol used to represent the points - :param str cmap: The name of the colormap to use for the plot - :param str norm: The normalization of the colormap: - 'linear' (default) or 'log' - :param float vmin: The value to use for the min of the colormap - :param float vmax: The value to use for the max of the colormap - :return: The widget plotting the scatter plot - :rtype: silx.gui.plot.ScatterView.ScatterView - """ - plt = ScatterView() - - # Update default colormap with input parameters - colormap = plt.getPlotWidget().getDefaultColormap() - if cmap is not None: - colormap.setName(cmap) - assert norm in colors.Colormap.NORMALIZATIONS - colormap.setNormalization(norm) - colormap.setVMin(vmin) - colormap.setVMax(vmax) - plt.getPlotWidget().setDefaultColormap(colormap) - - if x is not None and y is not None: # Add a scatter plot - x = numpy.array(x, copy=True).reshape(-1) - y = numpy.array(y, copy=True).reshape(-1) - assert len(x) == len(y) - - if value is None: - value = numpy.ones(len(x), dtype=numpy.float32) - - elif isinstance(value, abc.Iterable): - value = numpy.array(value, copy=True).reshape(-1) - assert len(x) == len(value) - - else: - value = numpy.ones(len(x), dtype=numpy.float64) * value - - plt.setData(x, y, value) - item = plt.getScatterItem() - if marker is not None: - item.setSymbol(marker) - if size is not None: - item.setSymbolSize(numpy.sqrt(size)) - - plt.resetZoom() - - plt.show() - _plots.insert(0, plt.getPlotWidget()) - return plt - - -class _GInputResult(tuple): - """Object storing :func:`ginput` result - - :param position: Selected point coordinates in the plot (x, y) - :param Item item: Plot item under the selected position - :param indices: Selected indices in the data of the item. - For a curve it is a list of indices, for an image it is (row, column) - :param data: Value of data at selected indices. - For a curve it is an array of values, for an image it is a single value - """ - - def __new__(cls, position, item, indices, data): - return super(_GInputResult, cls).__new__(cls, position) - - def __init__(self, position, item, indices, data): - self._itemRef = weakref.ref(item) if item is not None else None - self._indices = numpy.array(indices, copy=True) - if isinstance(data, abc.Iterable): - self._data = numpy.array(data, copy=True) - else: - self._data = data - - def getItem(self): - """Returns the item at the selected position if any. - - :return: plot item under the selected postion. - It is None if there was no item at that position or if - it is no more in the plot. - :rtype: silx.gui.plot.items.Item""" - return None if self._itemRef is None else self._itemRef() - - def getIndices(self): - """Returns indices in data array at the select position - - :return: 1D array of indices for curve and (row, column) for images - :rtype: numpy.ndarray - """ - return numpy.array(self._indices, copy=True) - - def getData(self): - """Returns data value at the selected position. - - For curves, an array of (x, y) values close to the point is returned. - For images, either a single value or a RGB(A) array is returned. - - :return: 2D array of (x, y) data values for curves (Nx2), - a single value for data images and RGB(A) array for images. - """ - if isinstance(self._data, numpy.ndarray): - return numpy.array(self._data, copy=True) - else: - return self._data - - -class _GInputHandler(roi.InteractiveRegionOfInterestManager): - """Implements :func:`ginput` - - :param PlotWidget plot: - :param int n: Max number of points to request - :param float timeout: Timeout in seconds - """ - - def __init__(self, plot, n, timeout): - super(_GInputHandler, self).__init__(plot) - - self._timeout = timeout - self.__selections = collections.OrderedDict() - - window = plot.window() # Retrieve window containing PlotWidget - statusBar = window.statusBar() - self.sigMessageChanged.connect(statusBar.showMessage) - self.setMaxRois(n) - self.setValidationMode(self.ValidationMode.AUTO_ENTER) - self.sigRoiAdded.connect(self.__added) - self.sigRoiAboutToBeRemoved.connect(self.__removed) - - def exec_(self): - """Request user inputs - - :return: List of selection points information - """ - plot = self.parent() - if plot is None: - return - - window = plot.window() # Retrieve window containing PlotWidget - - # Add ROI point interactive mode action - for toolbar in window.findChildren(qt.QToolBar): - if isinstance(toolbar, InteractiveModeToolBar): - break - else: # Add a toolbar - toolbar = qt.QToolBar() - window.addToolBar(toolbar) - toolbar.addAction(self.getInteractionModeAction(roi_items.PointROI)) - - super(_GInputHandler, self).exec_(roiClass=roi_items.PointROI, timeout=self._timeout) - - if isinstance(toolbar, InteractiveModeToolBar): - toolbar.removeAction(self.getInteractionModeAction(roi_items.PointROI)) - else: - toolbar.setParent(None) - - return tuple(self.__selections.values()) - - def __updateSelection(self, roi): - """Perform picking and update selection list - - :param RegionOfInterest roi: - """ - plot = self.parent() - if plot is None: - return # No plot, abort - - if not isinstance(roi, roi_items.PointROI): - # Only handle points - raise RuntimeError("Unexpected item") - - x, y = roi.getPosition() - xPixel, yPixel = plot.dataToPixel(x, y, axis='left', check=False) - - # Pick item at selected position - pickingResult = plot._pickTopMost( - xPixel, yPixel, - lambda item: isinstance(item, (items.ImageBase, items.Curve))) - - if pickingResult is None: - result = _GInputResult((x, y), - item=None, - indices=numpy.array((), dtype=int), - data=None) - else: - item = pickingResult.getItem() - indices = pickingResult.getIndices(copy=True) - - if isinstance(item, items.Curve): - xData = item.getXData(copy=False)[indices] - yData = item.getYData(copy=False)[indices] - result = _GInputResult((x, y), - item=item, - indices=indices, - data=numpy.array((xData, yData)).T) - - elif isinstance(item, items.ImageBase): - row, column = indices[0][0], indices[1][0] - data = item.getData(copy=False)[row, column] - result = _GInputResult((x, y), - item=item, - indices=(row, column), - data=data) - - self.__selections[roi] = result - - def __added(self, roi): - """Handle new ROI added - - :param RegionOfInterest roi: - """ - if isinstance(roi, roi_items.PointROI): - # Only handle points - roi.setName('%d' % len(self.__selections)) - self.__updateSelection(roi) - roi.sigRegionChanged.connect(self.__regionChanged) - - def __removed(self, roi): - """Handle ROI removed""" - if self.__selections.pop(roi, None) is not None: - roi.sigRegionChanged.disconnect(self.__regionChanged) - - def __regionChanged(self): - """Handle update of a ROI""" - roi = self.sender() - self.__updateSelection(roi) - - -def ginput(n=1, timeout=30, plot=None): - """Get input points on a plot. - - If no plot is provided, it uses a plot widget created with - either :func:`silx.sx.plot` or :func:`silx.sx.imshow`. - - How to use: - - >>> from silx import sx - - >>> sx.imshow(image) # Plot the image - >>> sx.ginput(1) # Request selection on the image plot - ((0.598, 1.234)) - - How to get more information about the selected positions: - - >>> positions = sx.ginput(1) - - >>> positions[0].getData() # Returns value(s) at selected position - - >>> positions[0].getIndices() # Returns data indices at selected position - - >>> positions[0].getItem() # Returns plot item at selected position - - :param int n: Number of points the user need to select - :param float timeout: Timeout in seconds before ginput returns - event if selection is not completed - :param silx.gui.plot.PlotWidget.PlotWidget plot: An optional PlotWidget - from which to get input - :return: List of clicked points coordinates (x, y) in plot - :raise ValueError: If provided plot is not a PlotWidget - """ - if plot is None: - # Select most recent visible plot widget - for widget in _plots: - if widget.isVisible(): - plot = widget - break - else: # If no plot widget is visible, take the most recent one - try: - plot = _plots[0] - except IndexError: - pass - else: - plot.show() - - if plot is None: - _logger.warning('No plot available to perform ginput, create one') - plot = Plot1D() - plot.show() - _plots.insert(0, plot) - - plot.raise_() # So window becomes the top level one - - _logger.info('Performing ginput with plot widget %s', str(plot)) - handler = _GInputHandler(plot, n, timeout) - points = handler.exec_() - - return points |