diff options
Diffstat (limited to 'silx/sx/_plot.py')
-rw-r--r-- | silx/sx/_plot.py | 373 |
1 files changed, 215 insertions, 158 deletions
diff --git a/silx/sx/_plot.py b/silx/sx/_plot.py index dfc24d9..d434fec 100644 --- a/silx/sx/_plot.py +++ b/silx/sx/_plot.py @@ -27,21 +27,23 @@ __authors__ = ["T. Vincent"] __license__ = "MIT" -__date__ = "26/02/2018" +__date__ = "28/06/2018" import collections import logging -import time import weakref import numpy from ..utils.weakref import WeakList from ..gui import qt -from ..gui.plot import Plot1D, Plot2D, PlotWidget -from ..gui.plot.Colors import COLORDICT -from ..gui.plot.Colormap import Colormap +from ..gui.plot import Plot1D, Plot2D, ScatterView +from ..gui.colors import COLORDICT +from ..gui.colors import Colormap +from ..gui.plot.tools import roi +from ..gui.plot.items import roi as roi_items +from ..gui.plot.tools.toolbars import InteractiveModeToolBar from silx.third_party import six @@ -111,6 +113,8 @@ def plot(*args, **kwargs): :param str title: The title of the Plot widget (default: None) :param str xlabel: The label of the X axis (default: None) :param str ylabel: The label of the Y axis (default: None) + :return: The widget plotting the curve(s) + :rtype: silx.gui.plot.Plot1D """ plt = Plot1D() if 'title' in kwargs: @@ -202,7 +206,7 @@ def plot(*args, **kwargs): def imshow(data=None, cmap=None, norm=Colormap.LINEAR, vmin=None, vmax=None, aspect=False, - origin=(0., 0.), scale=(1., 1.), + origin='upper', scale=(1., 1.), title='', xlabel='X', ylabel='Y'): """ Plot an image in a :class:`~silx.gui.plot.PlotWindow.Plot2D` widget. @@ -215,6 +219,12 @@ def imshow(data=None, cmap=None, norm=Colormap.LINEAR, >>> data = numpy.random.random(1024 * 1024).reshape(1024, 1024) >>> plt = sx.imshow(data, title='Random data') + By default, the image origin is displayed in the upper left + corner of the plot. To invert the Y axis, and place the image origin + in the lower left corner of the plot, use the *origin* parameter: + + >>> plt = sx.imshow(data, origin='lower') + This function supports a subset of `matplotlib.pyplot.imshow <http://matplotlib.org/api/pyplot_api.html#matplotlib.pyplot.imshow>`_ arguments. @@ -227,14 +237,18 @@ def imshow(data=None, cmap=None, norm=Colormap.LINEAR, :param float vmin: The value to use for the min of the colormap :param float vmax: The value to use for the max of the colormap :param bool aspect: True to keep aspect ratio (Default: False) - :param origin: (ox, oy) The coordinates of the image origin in the plot - :type origin: 2-tuple of floats + :param origin: Either image origin as the Y axis orientation: + 'upper' (default) or 'lower' + or the coordinates (ox, oy) of the image origin in the plot. + :type origin: str or 2-tuple of floats :param scale: (sx, sy) The scale of the image in the plot (i.e., the size of the image's pixel in plot coordinates) :type scale: 2-tuple of floats :param str title: The title of the Plot widget :param str xlabel: The label of the X axis :param str ylabel: The label of the Y axis + :return: The widget plotting the image + :rtype: silx.gui.plot.Plot2D """ plt = Plot2D() plt.setGraphTitle(title) @@ -260,6 +274,11 @@ def imshow(data=None, cmap=None, norm=Colormap.LINEAR, _logger.warning( 'imshow: Unhandled aspect argument: %s', str(aspect)) + # Handle matplotlib-like origin + if origin in ('upper', 'lower'): + plt.setYAxisInverted(origin == 'upper') + origin = 0., 0. # Set origin to the definition of silx + if data is not None: data = numpy.array(data, copy=True) @@ -274,6 +293,90 @@ def imshow(data=None, cmap=None, norm=Colormap.LINEAR, return plt +def scatter(x=None, y=None, value=None, size=None, + marker='o', + cmap=None, norm=Colormap.LINEAR, + vmin=None, vmax=None): + """ + Plot scattered data in a :class:`~silx.gui.plot.ScatterView` widget. + + How to use: + + >>> from silx import sx + >>> import numpy + + >>> x = numpy.random.random(100) + >>> y = numpy.random.random(100) + >>> values = numpy.random.random(100) + >>> plt = sx.scatter(x, y, values, cmap='viridis') + + Supported symbols: + + - 'o' circle + - '.' point + - ',' pixel + - '+' cross + - 'x' x-cross + - 'd' diamond + - 's' square + + This function supports a subset of `matplotlib.pyplot.scatter + <http://matplotlib.org/api/pyplot_api.html#matplotlib.pyplot.scatter>`_ + arguments. + + :param numpy.ndarray x: 1D array-like of x coordinates + :param numpy.ndarray y: 1D array-like of y coordinates + :param numpy.ndarray value: 1D array-like of data values + :param float size: Size^2 of the markers + :param str marker: Symbol used to represent the points (default: 'o') + :param str cmap: The name of the colormap to use for the plot. + :param str norm: The normalization of the colormap: + 'linear' (default) or 'log' + :param float vmin: The value to use for the min of the colormap + :param float vmax: The value to use for the max of the colormap + :return: The widget plotting the scatter plot + :rtype: silx.gui.plot.ScatterView.ScatterView + """ + plt = ScatterView() + + # Update default colormap with input parameters + colormap = plt.getPlotWidget().getDefaultColormap() + if cmap is not None: + colormap.setName(cmap) + assert norm in Colormap.NORMALIZATIONS + colormap.setNormalization(norm) + colormap.setVMin(vmin) + colormap.setVMax(vmax) + plt.getPlotWidget().setDefaultColormap(colormap) + + if x is not None and y is not None: # Add a scatter plot + x = numpy.array(x, copy=True).reshape(-1) + y = numpy.array(y, copy=True).reshape(-1) + assert len(x) == len(y) + + if value is None: + value = numpy.ones(len(x), dtype=numpy.float32) + + elif isinstance(value, collections.Iterable): + value = numpy.array(value, copy=True).reshape(-1) + assert len(x) == len(value) + + else: + value = numpy.ones(len(x), dtype=numpy.float) * value + + plt.setData(x, y, value) + item = plt.getScatterItem() + item.setSymbol(marker) + if size is not None: + item.setSymbolSize(numpy.sqrt(size)) + + plt.resetZoom() + + plt.show() + _plots.insert(0, plt.getPlotWidget()) + return plt + + class _GInputResult(tuple): """Object storing :func:`ginput` result @@ -328,175 +431,128 @@ class _GInputResult(tuple): return self._data -class _GInputHandler(qt.QEventLoop): +class _GInputHandler(roi.InteractiveRegionOfInterestManager): """Implements :func:`ginput` :param PlotWidget plot: - :param int n: - :param float timeout: + :param int n: Max number of points to request + :param float timeout: Timeout in seconds """ def __init__(self, plot, n, timeout): - super(_GInputHandler, self).__init__() - - if not isinstance(plot, PlotWidget): - raise ValueError('plot is not a PlotWidget: %s', plot) + super(_GInputHandler, self).__init__(plot) - self._plot = plot self._timeout = timeout - self._markersAndResult = [] - self._totalPoints = n - self._endTime = 0. + self.__selections = collections.OrderedDict() - def eventFilter(self, obj, event): - """Event filter for plot hide event""" - if event.type() == qt.QEvent.Hide: - self.quit() + window = plot.window() # Retrieve window containing PlotWidget + statusBar = window.statusBar() + self.sigMessageChanged.connect(statusBar.showMessage) + self.setMaxRois(n) + self.setValidationMode(self.ValidationMode.AUTO_ENTER) + self.sigRoiAdded.connect(self.__added) + self.sigRoiAboutToBeRemoved.connect(self.__removed) - elif event.type() == qt.QEvent.KeyPress: - if event.key() in (qt.Qt.Key_Delete, qt.Qt.Key_Backspace) or ( - event.key() == qt.Qt.Key_Z and event.modifiers() & qt.Qt.ControlModifier): - if len(self._markersAndResult) > 0: - legend, _ = self._markersAndResult.pop() - self._plot.remove(legend=legend, kind='marker') + def exec_(self): + """Request user inputs - self._updateStatusBar() - return True # Stop further handling of those keys + :return: List of selection points information + """ + plot = self.parent() + if plot is None: + return - elif event.key() == qt.Qt.Key_Return: - self.quit() - return True # Stop further handling of those keys + window = plot.window() # Retrieve window containing PlotWidget - return super(_GInputHandler, self).eventFilter(obj, event) + # Add ROI point interactive mode action + for toolbar in window.findChildren(qt.QToolBar): + if isinstance(toolbar, InteractiveModeToolBar): + break + else: # Add a toolbar + toolbar = qt.QToolBar() + window.addToolBar(toolbar) + toolbar.addAction(self.getInteractionModeAction(roi_items.PointROI)) - def exec_(self): - """Run blocking ginput handler + super(_GInputHandler, self).exec_(roiClass=roi_items.PointROI, timeout=self._timeout) - :returns: List of selected points - """ - # Bootstrap - self._plot.setInteractiveMode(mode='zoom') - self._handleInteractiveModeChanged(None) - self._plot.sigInteractiveModeChanged.connect( - self._handleInteractiveModeChanged) - - self._plot.installEventFilter(self) + if isinstance(toolbar, InteractiveModeToolBar): + toolbar.removeAction(self.getInteractionModeAction(roi_items.PointROI)) + else: + toolbar.setParent(None) - # Run - if self._timeout: - timeoutTimer = qt.QTimer() - timeoutTimer.timeout.connect(self._updateStatusBar) - timeoutTimer.start(1000) + return tuple(self.__selections.values()) - self._endTime = time.time() + self._timeout - self._updateStatusBar() + def __updateSelection(self, roi): + """Perform picking and update selection list - returnCode = super(_GInputHandler, self).exec_() + :param RegionOfInterest roi: + """ - timeoutTimer.stop() - else: - returnCode = super(_GInputHandler, self).exec_() - - # Clean-up - self._plot.removeEventFilter(self) - - self._plot.sigInteractiveModeChanged.disconnect( - self._handleInteractiveModeChanged) - - currentMode = self._plot.getInteractiveMode() - if currentMode['mode'] == 'zoom': # Stop handling mouse click - self._plot.sigPlotSignal.disconnect(self._handleSelect) - - self._plot.statusBar().clearMessage() - - points = tuple(result for _, result in self._markersAndResult) - - for legend, _ in self._markersAndResult: - self._plot.remove(legend=legend, kind='marker') - self._markersAndResult = [] - - return points if returnCode == 0 else () - - def _updateStatusBar(self): - """Update status bar message""" - msg = 'ginput: %d/%d input points' % (len(self._markersAndResult), - self._totalPoints) - if self._timeout: - remaining = self._endTime - time.time() - if remaining < 0: - self.quit() - return - msg += ', %d seconds remaining' % max(1, int(remaining)) - - currentMode = self._plot.getInteractiveMode() - if currentMode['mode'] != 'zoom': - msg += ' (Use zoom mode to add/remove points)' - - self._plot.statusBar().showMessage(msg) - - def _handleSelect(self, event): - """Handle mouse events""" - if event['event'] == 'mouseClicked' and event['button'] == 'left': - x, y = event['x'], event['y'] - xPixel, yPixel = event['xpixel'], event['ypixel'] - - # Add marker - legend = "sx.ginput %d" % len(self._markersAndResult) - self._plot.addMarker( - x, y, - legend=legend, - text='%d' % len(self._markersAndResult), - color='red', - draggable=False) - - # Pick item at selected position - picked = self._plot._pickImageOrCurve(xPixel, yPixel) - - if picked is None: - result = _GInputResult((x, y), - item=None, - indices=numpy.array((), dtype=int), - data=None) - - elif picked[0] == 'curve': - curve = picked[1] - indices = picked[2] - xData = curve.getXData(copy=False)[indices] - yData = curve.getYData(copy=False)[indices] - result = _GInputResult((x, y), - item=curve, - indices=indices, - data=numpy.array((xData, yData)).T) - - elif picked[0] == 'image': - image = picked[1] - # Get corresponding coordinate in image - origin = image.getOrigin() - scale = image.getScale() - column = int((x - origin[0]) / float(scale[0])) - row = int((y - origin[1]) / float(scale[1])) - data = image.getData(copy=False)[row, column] - result = _GInputResult((x, y), - item=image, - indices=(row, column), - data=data) - - self._markersAndResult.append((legend, result)) - self._updateStatusBar() - if len(self._markersAndResult) == self._totalPoints: - self.quit() - - def _handleInteractiveModeChanged(self, source): - """Handle change of interactive mode in the plot - - :param source: Objects that triggered the mode change + plot = self.parent() + if plot is None: + return # No plot, abort + + if not isinstance(roi, roi_items.PointROI): + # Only handle points + raise RuntimeError("Unexpected item") + + x, y = roi.getPosition() + xPixel, yPixel = plot.dataToPixel(x, y, axis='left', check=False) + + # Pick item at selected position + picked = plot._pickImageOrCurve(xPixel, yPixel) + + if picked is None: + result = _GInputResult((x, y), + item=None, + indices=numpy.array((), dtype=int), + data=None) + + elif picked[0] == 'curve': + curve = picked[1] + indices = picked[2] + xData = curve.getXData(copy=False)[indices] + yData = curve.getYData(copy=False)[indices] + result = _GInputResult((x, y), + item=curve, + indices=indices, + data=numpy.array((xData, yData)).T) + + elif picked[0] == 'image': + image = picked[1] + # Get corresponding coordinate in image + origin = image.getOrigin() + scale = image.getScale() + column = int((x - origin[0]) / float(scale[0])) + row = int((y - origin[1]) / float(scale[1])) + data = image.getData(copy=False)[row, column] + result = _GInputResult((x, y), + item=image, + indices=(row, column), + data=data) + + self.__selections[roi] = result + + def __added(self, roi): + """Handle new ROI added + + :param RegionOfInterest roi: """ - mode = self._plot.getInteractiveMode() - if mode['mode'] == 'zoom': # Handle click events - self._plot.sigPlotSignal.connect(self._handleSelect) - else: # Do not handle click event - self._plot.sigPlotSignal.disconnect(self._handleSelect) - self._updateStatusBar() + if isinstance(roi, roi_items.PointROI): + # Only handle points + roi.setLabel('%d' % len(self.__selections)) + self.__updateSelection(roi) + roi.sigRegionChanged.connect(self.__regionChanged) + + def __removed(self, roi): + """Handle ROI removed""" + if self.__selections.pop(roi, None) is not None: + roi.sigRegionChanged.disconnect(self.__regionChanged) + + def __regionChanged(self): + """Handle update of a ROI""" + roi = self.sender() + self.__updateSelection(roi) def ginput(n=1, timeout=30, plot=None): @@ -537,7 +593,7 @@ def ginput(n=1, timeout=30, plot=None): if widget.isVisible(): plot = widget break - else: # If no plot widgets are visible, take most recent one + else: # If no plot widget is visible, take the most recent one try: plot = _plots[0] except IndexError: @@ -549,6 +605,7 @@ def ginput(n=1, timeout=30, plot=None): _logger.warning('No plot available to perform ginput, create one') plot = Plot1D() plot.show() + _plots.insert(0, plot) plot.raise_() # So window becomes the top level one |