summaryrefslogtreecommitdiff
path: root/silx/sx/_plot.py
diff options
context:
space:
mode:
Diffstat (limited to 'silx/sx/_plot.py')
-rw-r--r--silx/sx/_plot.py373
1 files changed, 215 insertions, 158 deletions
diff --git a/silx/sx/_plot.py b/silx/sx/_plot.py
index dfc24d9..d434fec 100644
--- a/silx/sx/_plot.py
+++ b/silx/sx/_plot.py
@@ -27,21 +27,23 @@
__authors__ = ["T. Vincent"]
__license__ = "MIT"
-__date__ = "26/02/2018"
+__date__ = "28/06/2018"
import collections
import logging
-import time
import weakref
import numpy
from ..utils.weakref import WeakList
from ..gui import qt
-from ..gui.plot import Plot1D, Plot2D, PlotWidget
-from ..gui.plot.Colors import COLORDICT
-from ..gui.plot.Colormap import Colormap
+from ..gui.plot import Plot1D, Plot2D, ScatterView
+from ..gui.colors import COLORDICT
+from ..gui.colors import Colormap
+from ..gui.plot.tools import roi
+from ..gui.plot.items import roi as roi_items
+from ..gui.plot.tools.toolbars import InteractiveModeToolBar
from silx.third_party import six
@@ -111,6 +113,8 @@ def plot(*args, **kwargs):
:param str title: The title of the Plot widget (default: None)
:param str xlabel: The label of the X axis (default: None)
:param str ylabel: The label of the Y axis (default: None)
+ :return: The widget plotting the curve(s)
+ :rtype: silx.gui.plot.Plot1D
"""
plt = Plot1D()
if 'title' in kwargs:
@@ -202,7 +206,7 @@ def plot(*args, **kwargs):
def imshow(data=None, cmap=None, norm=Colormap.LINEAR,
vmin=None, vmax=None,
aspect=False,
- origin=(0., 0.), scale=(1., 1.),
+ origin='upper', scale=(1., 1.),
title='', xlabel='X', ylabel='Y'):
"""
Plot an image in a :class:`~silx.gui.plot.PlotWindow.Plot2D` widget.
@@ -215,6 +219,12 @@ def imshow(data=None, cmap=None, norm=Colormap.LINEAR,
>>> data = numpy.random.random(1024 * 1024).reshape(1024, 1024)
>>> plt = sx.imshow(data, title='Random data')
+ By default, the image origin is displayed in the upper left
+ corner of the plot. To invert the Y axis, and place the image origin
+ in the lower left corner of the plot, use the *origin* parameter:
+
+ >>> plt = sx.imshow(data, origin='lower')
+
This function supports a subset of `matplotlib.pyplot.imshow
<http://matplotlib.org/api/pyplot_api.html#matplotlib.pyplot.imshow>`_
arguments.
@@ -227,14 +237,18 @@ def imshow(data=None, cmap=None, norm=Colormap.LINEAR,
:param float vmin: The value to use for the min of the colormap
:param float vmax: The value to use for the max of the colormap
:param bool aspect: True to keep aspect ratio (Default: False)
- :param origin: (ox, oy) The coordinates of the image origin in the plot
- :type origin: 2-tuple of floats
+ :param origin: Either image origin as the Y axis orientation:
+ 'upper' (default) or 'lower'
+ or the coordinates (ox, oy) of the image origin in the plot.
+ :type origin: str or 2-tuple of floats
:param scale: (sx, sy) The scale of the image in the plot
(i.e., the size of the image's pixel in plot coordinates)
:type scale: 2-tuple of floats
:param str title: The title of the Plot widget
:param str xlabel: The label of the X axis
:param str ylabel: The label of the Y axis
+ :return: The widget plotting the image
+ :rtype: silx.gui.plot.Plot2D
"""
plt = Plot2D()
plt.setGraphTitle(title)
@@ -260,6 +274,11 @@ def imshow(data=None, cmap=None, norm=Colormap.LINEAR,
_logger.warning(
'imshow: Unhandled aspect argument: %s', str(aspect))
+ # Handle matplotlib-like origin
+ if origin in ('upper', 'lower'):
+ plt.setYAxisInverted(origin == 'upper')
+ origin = 0., 0. # Set origin to the definition of silx
+
if data is not None:
data = numpy.array(data, copy=True)
@@ -274,6 +293,90 @@ def imshow(data=None, cmap=None, norm=Colormap.LINEAR,
return plt
+def scatter(x=None, y=None, value=None, size=None,
+ marker='o',
+ cmap=None, norm=Colormap.LINEAR,
+ vmin=None, vmax=None):
+ """
+ Plot scattered data in a :class:`~silx.gui.plot.ScatterView` widget.
+
+ How to use:
+
+ >>> from silx import sx
+ >>> import numpy
+
+ >>> x = numpy.random.random(100)
+ >>> y = numpy.random.random(100)
+ >>> values = numpy.random.random(100)
+ >>> plt = sx.scatter(x, y, values, cmap='viridis')
+
+ Supported symbols:
+
+ - 'o' circle
+ - '.' point
+ - ',' pixel
+ - '+' cross
+ - 'x' x-cross
+ - 'd' diamond
+ - 's' square
+
+ This function supports a subset of `matplotlib.pyplot.scatter
+ <http://matplotlib.org/api/pyplot_api.html#matplotlib.pyplot.scatter>`_
+ arguments.
+
+ :param numpy.ndarray x: 1D array-like of x coordinates
+ :param numpy.ndarray y: 1D array-like of y coordinates
+ :param numpy.ndarray value: 1D array-like of data values
+ :param float size: Size^2 of the markers
+ :param str marker: Symbol used to represent the points (default: 'o')
+ :param str cmap: The name of the colormap to use for the plot.
+ :param str norm: The normalization of the colormap:
+ 'linear' (default) or 'log'
+ :param float vmin: The value to use for the min of the colormap
+ :param float vmax: The value to use for the max of the colormap
+ :return: The widget plotting the scatter plot
+ :rtype: silx.gui.plot.ScatterView.ScatterView
+ """
+ plt = ScatterView()
+
+ # Update default colormap with input parameters
+ colormap = plt.getPlotWidget().getDefaultColormap()
+ if cmap is not None:
+ colormap.setName(cmap)
+ assert norm in Colormap.NORMALIZATIONS
+ colormap.setNormalization(norm)
+ colormap.setVMin(vmin)
+ colormap.setVMax(vmax)
+ plt.getPlotWidget().setDefaultColormap(colormap)
+
+ if x is not None and y is not None: # Add a scatter plot
+ x = numpy.array(x, copy=True).reshape(-1)
+ y = numpy.array(y, copy=True).reshape(-1)
+ assert len(x) == len(y)
+
+ if value is None:
+ value = numpy.ones(len(x), dtype=numpy.float32)
+
+ elif isinstance(value, collections.Iterable):
+ value = numpy.array(value, copy=True).reshape(-1)
+ assert len(x) == len(value)
+
+ else:
+ value = numpy.ones(len(x), dtype=numpy.float) * value
+
+ plt.setData(x, y, value)
+ item = plt.getScatterItem()
+ item.setSymbol(marker)
+ if size is not None:
+ item.setSymbolSize(numpy.sqrt(size))
+
+ plt.resetZoom()
+
+ plt.show()
+ _plots.insert(0, plt.getPlotWidget())
+ return plt
+
+
class _GInputResult(tuple):
"""Object storing :func:`ginput` result
@@ -328,175 +431,128 @@ class _GInputResult(tuple):
return self._data
-class _GInputHandler(qt.QEventLoop):
+class _GInputHandler(roi.InteractiveRegionOfInterestManager):
"""Implements :func:`ginput`
:param PlotWidget plot:
- :param int n:
- :param float timeout:
+ :param int n: Max number of points to request
+ :param float timeout: Timeout in seconds
"""
def __init__(self, plot, n, timeout):
- super(_GInputHandler, self).__init__()
-
- if not isinstance(plot, PlotWidget):
- raise ValueError('plot is not a PlotWidget: %s', plot)
+ super(_GInputHandler, self).__init__(plot)
- self._plot = plot
self._timeout = timeout
- self._markersAndResult = []
- self._totalPoints = n
- self._endTime = 0.
+ self.__selections = collections.OrderedDict()
- def eventFilter(self, obj, event):
- """Event filter for plot hide event"""
- if event.type() == qt.QEvent.Hide:
- self.quit()
+ window = plot.window() # Retrieve window containing PlotWidget
+ statusBar = window.statusBar()
+ self.sigMessageChanged.connect(statusBar.showMessage)
+ self.setMaxRois(n)
+ self.setValidationMode(self.ValidationMode.AUTO_ENTER)
+ self.sigRoiAdded.connect(self.__added)
+ self.sigRoiAboutToBeRemoved.connect(self.__removed)
- elif event.type() == qt.QEvent.KeyPress:
- if event.key() in (qt.Qt.Key_Delete, qt.Qt.Key_Backspace) or (
- event.key() == qt.Qt.Key_Z and event.modifiers() & qt.Qt.ControlModifier):
- if len(self._markersAndResult) > 0:
- legend, _ = self._markersAndResult.pop()
- self._plot.remove(legend=legend, kind='marker')
+ def exec_(self):
+ """Request user inputs
- self._updateStatusBar()
- return True # Stop further handling of those keys
+ :return: List of selection points information
+ """
+ plot = self.parent()
+ if plot is None:
+ return
- elif event.key() == qt.Qt.Key_Return:
- self.quit()
- return True # Stop further handling of those keys
+ window = plot.window() # Retrieve window containing PlotWidget
- return super(_GInputHandler, self).eventFilter(obj, event)
+ # Add ROI point interactive mode action
+ for toolbar in window.findChildren(qt.QToolBar):
+ if isinstance(toolbar, InteractiveModeToolBar):
+ break
+ else: # Add a toolbar
+ toolbar = qt.QToolBar()
+ window.addToolBar(toolbar)
+ toolbar.addAction(self.getInteractionModeAction(roi_items.PointROI))
- def exec_(self):
- """Run blocking ginput handler
+ super(_GInputHandler, self).exec_(roiClass=roi_items.PointROI, timeout=self._timeout)
- :returns: List of selected points
- """
- # Bootstrap
- self._plot.setInteractiveMode(mode='zoom')
- self._handleInteractiveModeChanged(None)
- self._plot.sigInteractiveModeChanged.connect(
- self._handleInteractiveModeChanged)
-
- self._plot.installEventFilter(self)
+ if isinstance(toolbar, InteractiveModeToolBar):
+ toolbar.removeAction(self.getInteractionModeAction(roi_items.PointROI))
+ else:
+ toolbar.setParent(None)
- # Run
- if self._timeout:
- timeoutTimer = qt.QTimer()
- timeoutTimer.timeout.connect(self._updateStatusBar)
- timeoutTimer.start(1000)
+ return tuple(self.__selections.values())
- self._endTime = time.time() + self._timeout
- self._updateStatusBar()
+ def __updateSelection(self, roi):
+ """Perform picking and update selection list
- returnCode = super(_GInputHandler, self).exec_()
+ :param RegionOfInterest roi:
+ """
- timeoutTimer.stop()
- else:
- returnCode = super(_GInputHandler, self).exec_()
-
- # Clean-up
- self._plot.removeEventFilter(self)
-
- self._plot.sigInteractiveModeChanged.disconnect(
- self._handleInteractiveModeChanged)
-
- currentMode = self._plot.getInteractiveMode()
- if currentMode['mode'] == 'zoom': # Stop handling mouse click
- self._plot.sigPlotSignal.disconnect(self._handleSelect)
-
- self._plot.statusBar().clearMessage()
-
- points = tuple(result for _, result in self._markersAndResult)
-
- for legend, _ in self._markersAndResult:
- self._plot.remove(legend=legend, kind='marker')
- self._markersAndResult = []
-
- return points if returnCode == 0 else ()
-
- def _updateStatusBar(self):
- """Update status bar message"""
- msg = 'ginput: %d/%d input points' % (len(self._markersAndResult),
- self._totalPoints)
- if self._timeout:
- remaining = self._endTime - time.time()
- if remaining < 0:
- self.quit()
- return
- msg += ', %d seconds remaining' % max(1, int(remaining))
-
- currentMode = self._plot.getInteractiveMode()
- if currentMode['mode'] != 'zoom':
- msg += ' (Use zoom mode to add/remove points)'
-
- self._plot.statusBar().showMessage(msg)
-
- def _handleSelect(self, event):
- """Handle mouse events"""
- if event['event'] == 'mouseClicked' and event['button'] == 'left':
- x, y = event['x'], event['y']
- xPixel, yPixel = event['xpixel'], event['ypixel']
-
- # Add marker
- legend = "sx.ginput %d" % len(self._markersAndResult)
- self._plot.addMarker(
- x, y,
- legend=legend,
- text='%d' % len(self._markersAndResult),
- color='red',
- draggable=False)
-
- # Pick item at selected position
- picked = self._plot._pickImageOrCurve(xPixel, yPixel)
-
- if picked is None:
- result = _GInputResult((x, y),
- item=None,
- indices=numpy.array((), dtype=int),
- data=None)
-
- elif picked[0] == 'curve':
- curve = picked[1]
- indices = picked[2]
- xData = curve.getXData(copy=False)[indices]
- yData = curve.getYData(copy=False)[indices]
- result = _GInputResult((x, y),
- item=curve,
- indices=indices,
- data=numpy.array((xData, yData)).T)
-
- elif picked[0] == 'image':
- image = picked[1]
- # Get corresponding coordinate in image
- origin = image.getOrigin()
- scale = image.getScale()
- column = int((x - origin[0]) / float(scale[0]))
- row = int((y - origin[1]) / float(scale[1]))
- data = image.getData(copy=False)[row, column]
- result = _GInputResult((x, y),
- item=image,
- indices=(row, column),
- data=data)
-
- self._markersAndResult.append((legend, result))
- self._updateStatusBar()
- if len(self._markersAndResult) == self._totalPoints:
- self.quit()
-
- def _handleInteractiveModeChanged(self, source):
- """Handle change of interactive mode in the plot
-
- :param source: Objects that triggered the mode change
+ plot = self.parent()
+ if plot is None:
+ return # No plot, abort
+
+ if not isinstance(roi, roi_items.PointROI):
+ # Only handle points
+ raise RuntimeError("Unexpected item")
+
+ x, y = roi.getPosition()
+ xPixel, yPixel = plot.dataToPixel(x, y, axis='left', check=False)
+
+ # Pick item at selected position
+ picked = plot._pickImageOrCurve(xPixel, yPixel)
+
+ if picked is None:
+ result = _GInputResult((x, y),
+ item=None,
+ indices=numpy.array((), dtype=int),
+ data=None)
+
+ elif picked[0] == 'curve':
+ curve = picked[1]
+ indices = picked[2]
+ xData = curve.getXData(copy=False)[indices]
+ yData = curve.getYData(copy=False)[indices]
+ result = _GInputResult((x, y),
+ item=curve,
+ indices=indices,
+ data=numpy.array((xData, yData)).T)
+
+ elif picked[0] == 'image':
+ image = picked[1]
+ # Get corresponding coordinate in image
+ origin = image.getOrigin()
+ scale = image.getScale()
+ column = int((x - origin[0]) / float(scale[0]))
+ row = int((y - origin[1]) / float(scale[1]))
+ data = image.getData(copy=False)[row, column]
+ result = _GInputResult((x, y),
+ item=image,
+ indices=(row, column),
+ data=data)
+
+ self.__selections[roi] = result
+
+ def __added(self, roi):
+ """Handle new ROI added
+
+ :param RegionOfInterest roi:
"""
- mode = self._plot.getInteractiveMode()
- if mode['mode'] == 'zoom': # Handle click events
- self._plot.sigPlotSignal.connect(self._handleSelect)
- else: # Do not handle click event
- self._plot.sigPlotSignal.disconnect(self._handleSelect)
- self._updateStatusBar()
+ if isinstance(roi, roi_items.PointROI):
+ # Only handle points
+ roi.setLabel('%d' % len(self.__selections))
+ self.__updateSelection(roi)
+ roi.sigRegionChanged.connect(self.__regionChanged)
+
+ def __removed(self, roi):
+ """Handle ROI removed"""
+ if self.__selections.pop(roi, None) is not None:
+ roi.sigRegionChanged.disconnect(self.__regionChanged)
+
+ def __regionChanged(self):
+ """Handle update of a ROI"""
+ roi = self.sender()
+ self.__updateSelection(roi)
def ginput(n=1, timeout=30, plot=None):
@@ -537,7 +593,7 @@ def ginput(n=1, timeout=30, plot=None):
if widget.isVisible():
plot = widget
break
- else: # If no plot widgets are visible, take most recent one
+ else: # If no plot widget is visible, take the most recent one
try:
plot = _plots[0]
except IndexError:
@@ -549,6 +605,7 @@ def ginput(n=1, timeout=30, plot=None):
_logger.warning('No plot available to perform ginput, create one')
plot = Plot1D()
plot.show()
+ _plots.insert(0, plot)
plot.raise_() # So window becomes the top level one