diff options
Diffstat (limited to 'silx/gui/plot')
43 files changed, 2443 insertions, 985 deletions
diff --git a/silx/gui/plot/CompareImages.py b/silx/gui/plot/CompareImages.py index f7c4899..3875be4 100644 --- a/silx/gui/plot/CompareImages.py +++ b/silx/gui/plot/CompareImages.py @@ -1,7 +1,7 @@ # coding: utf-8 # /*########################################################################## # -# Copyright (c) 2018 European Synchrotron Radiation Facility +# Copyright (c) 2018-2019 European Synchrotron Radiation Facility # # Permission is hereby granted, free of charge, to any person obtaining a copy # of this software and associated documentation files (the "Software"), to deal @@ -48,7 +48,11 @@ _logger = logging.getLogger(__name__) from silx.opencl import ocl if ocl is not None: - from silx.opencl import sift + try: + from silx.opencl import sift + except ImportError: + # sift module is not available (e.g., in official Debian packages) + sift = None else: # No OpenCL device or no pyopencl sift = None @@ -62,6 +66,7 @@ class VisualizationMode(enum.Enum): HORIZONTAL_LINE = 'hline' COMPOSITE_RED_BLUE_GRAY = "rbgchannel" COMPOSITE_RED_BLUE_GRAY_NEG = "rbgnegchannel" + COMPOSITE_A_MINUS_B = "aminusb" @enum.unique @@ -161,6 +166,16 @@ class CompareImagesToolBar(qt.QToolBar): self.__ycChannelModeAction = action self.__visualizationGroup.addAction(action) + icon = icons.getQIcon("compare-mode-a-minus-b") + action = qt.QAction(icon, "Raw A minus B compare mode", self) + action.setIconVisibleInMenu(True) + action.setCheckable(True) + action.setShortcut(qt.QKeySequence(qt.Qt.Key_W)) + action.setProperty("mode", VisualizationMode.COMPOSITE_A_MINUS_B) + menu.addAction(action) + self.__ycChannelModeAction = action + self.__visualizationGroup.addAction(action) + menu = qt.QMenu(self) self.__alignmentAction = qt.QAction(self) self.__alignmentAction.setMenu(menu) @@ -539,6 +554,11 @@ class CompareImages(qt.QMainWindow): def __init__(self, parent=None, backend=None): qt.QMainWindow.__init__(self, parent) + self._resetZoomActive = True + self._colormap = Colormap() + """Colormap shared by all modes, except the compose images (rgb image)""" + self._colormapKeyPoints = Colormap('spring') + """Colormap used for sift keypoints""" if parent is None: self.setWindowTitle('Compare images') @@ -553,6 +573,7 @@ class CompareImages(qt.QMainWindow): self.__previousSeparatorPosition = None self.__plot = plot.PlotWidget(parent=self, backend=backend) + self.__plot.setDefaultColormap(self._colormap) self.__plot.getXAxis().setLabel('Columns') self.__plot.getYAxis().setLabel('Rows') if silx.config.DEFAULT_PLOT_IMAGE_Y_AXIS_ORIENTATION == 'downward': @@ -630,6 +651,14 @@ class CompareImages(qt.QMainWindow): """ return self.__plot + def getColormap(self): + """ + + :return: colormap used for compare image + :rtype: silx.gui.colors.Colormap + """ + return self._colormap + def getRawPixelData(self, x, y): """Return the raw pixel of each image data from axes positions. @@ -835,7 +864,8 @@ class CompareImages(qt.QMainWindow): self.__raw1 = image1 self.__raw2 = image2 self.__updateData() - self.__plot.resetZoom() + if self.isAutoResetZoom(): + self.__plot.resetZoom() def setImage1(self, image1): """Set image1 to be compared. @@ -850,7 +880,8 @@ class CompareImages(qt.QMainWindow): """ self.__raw1 = image1 self.__updateData() - self.__plot.resetZoom() + if self.isAutoResetZoom(): + self.__plot.resetZoom() def setImage2(self, image2): """Set image2 to be compared. @@ -865,7 +896,8 @@ class CompareImages(qt.QMainWindow): """ self.__raw2 = image2 self.__updateData() - self.__plot.resetZoom() + if self.isAutoResetZoom(): + self.__plot.resetZoom() def __updateKeyPoints(self): """Update the displayed keypoints using cached keypoints. @@ -878,11 +910,11 @@ class CompareImages(qt.QMainWindow): y=data[1], z=1, value=data[2], - legend="keypoints", - colormap=Colormap("spring")) + colormap=self._colormapKeyPoints, + legend="keypoints") def __updateData(self): - """Compute aligned image when the alignement mode changes. + """Compute aligned image when the alignment mode changes. This function cache input images which are used when vertical/horizontal separators moves. @@ -943,6 +975,9 @@ class CompareImages(qt.QMainWindow): elif mode == VisualizationMode.COMPOSITE_RED_BLUE_GRAY: data1 = self.__composeImage(data1, data2, mode) data2 = numpy.empty((0, 0)) + elif mode == VisualizationMode.COMPOSITE_A_MINUS_B: + data1 = self.__composeImage(data1, data2, mode) + data2 = numpy.empty((0, 0)) elif mode == VisualizationMode.ONLY_A: data2 = numpy.empty((0, 0)) elif mode == VisualizationMode.ONLY_B: @@ -977,7 +1012,8 @@ class CompareImages(qt.QMainWindow): else: vmin = min(self.__data1.min(), self.__data2.min()) vmax = max(self.__data1.max(), self.__data2.max()) - colormap = Colormap(vmin=vmin, vmax=vmax) + colormap = self.getColormap() + colormap.setVRange(vmin=vmin, vmax=vmax) self.__image1.setColormap(colormap) self.__image2.setColormap(colormap) @@ -1025,6 +1061,13 @@ class CompareImages(qt.QMainWindow): :rtype: numpy.ndarray """ assert(data1.shape[0:2] == data2.shape[0:2]) + if mode == VisualizationMode.COMPOSITE_A_MINUS_B: + # TODO: this calculation has no interest of generating a 'composed' + # rgb image, this could be moved in an other function or doc + # should be modified + _type = data1.dtype + result = data1.astype(numpy.float64) - data2.astype(numpy.float64) + return result mode1 = self.__getImageMode(data1) if mode1 in ["rgb", "rgba"]: intensity1 = self.__luminosityImage(data1) @@ -1188,3 +1231,19 @@ class CompareImages(qt.QMainWindow): data2 = result["result"] self.__transformation = self.__toAffineTransformation(result) return data1, data2 + + def setAutoResetZoom(self, activate=True): + """ + + :param bool activate: True if we want to activate the automatic + plot reset zoom when setting images. + """ + self._resetZoomActive = activate + + def isAutoResetZoom(self): + """ + + :return: True if the automatic call to resetzoom is activated + :rtype: bool + """ + return self._resetZoomActive diff --git a/silx/gui/plot/ComplexImageView.py b/silx/gui/plot/ComplexImageView.py index 2523cde..c8470ab 100644 --- a/silx/gui/plot/ComplexImageView.py +++ b/silx/gui/plot/ComplexImageView.py @@ -39,6 +39,7 @@ import logging import collections import numpy +from ...utils.deprecation import deprecated from .. import qt, icons from .PlotWindow import Plot2D from . import items @@ -170,16 +171,16 @@ class _ComplexDataToolButton(qt.QToolButton): """ _MODES = collections.OrderedDict([ - (ImageComplexData.Mode.ABSOLUTE, ('math-amplitude', 'Amplitude')), - (ImageComplexData.Mode.SQUARE_AMPLITUDE, + (ImageComplexData.ComplexMode.ABSOLUTE, ('math-amplitude', 'Amplitude')), + (ImageComplexData.ComplexMode.SQUARE_AMPLITUDE, ('math-square-amplitude', 'Square amplitude')), - (ImageComplexData.Mode.PHASE, ('math-phase', 'Phase')), - (ImageComplexData.Mode.REAL, ('math-real', 'Real part')), - (ImageComplexData.Mode.IMAGINARY, + (ImageComplexData.ComplexMode.PHASE, ('math-phase', 'Phase')), + (ImageComplexData.ComplexMode.REAL, ('math-real', 'Real part')), + (ImageComplexData.ComplexMode.IMAGINARY, ('math-imaginary', 'Imaginary part')), - (ImageComplexData.Mode.AMPLITUDE_PHASE, + (ImageComplexData.ComplexMode.AMPLITUDE_PHASE, ('math-phase-color', 'Amplitude and Phase')), - (ImageComplexData.Mode.LOG10_AMPLITUDE_PHASE, + (ImageComplexData.ComplexMode.LOG10_AMPLITUDE_PHASE, ('math-phase-color-log', 'Log10(Amp.) and Phase')) ]) @@ -208,7 +209,7 @@ class _ComplexDataToolButton(qt.QToolButton): self.setPopupMode(qt.QToolButton.InstantPopup) - self._modeChanged(self._plot2DComplex.getVisualizationMode()) + self._modeChanged(self._plot2DComplex.getComplexMode()) self._plot2DComplex.sigVisualizationModeChanged.connect( self._modeChanged) @@ -217,7 +218,8 @@ class _ComplexDataToolButton(qt.QToolButton): icon, text = self._MODES[mode] self.setIcon(icons.getQIcon(icon)) self.setToolTip('Display the ' + text.lower()) - self._rangeDialogAction.setEnabled(mode == ImageComplexData.Mode.LOG10_AMPLITUDE_PHASE) + self._rangeDialogAction.setEnabled( + mode == ImageComplexData.ComplexMode.LOG10_AMPLITUDE_PHASE) def _triggered(self, action): """Handle triggering of menu actions""" @@ -244,8 +246,8 @@ class _ComplexDataToolButton(qt.QToolButton): else: # update mode mode = action.data() - if isinstance(mode, ImageComplexData.Mode): - self._plot2DComplex.setVisualizationMode(mode) + if isinstance(mode, ImageComplexData.ComplexMode): + self._plot2DComplex.setComplexMode(mode) def _rangeChanged(self, range_): """Handle updates of range in the dialog""" @@ -258,8 +260,8 @@ class ComplexImageView(qt.QWidget): :param parent: See :class:`QMainWindow` """ - Mode = ImageComplexData.Mode - """Also expose the modes inside the class""" + ComplexMode = ImageComplexData.ComplexMode + """Complex Modes enumeration""" sigDataChanged = qt.Signal() """Signal emitted when data has changed.""" @@ -301,7 +303,7 @@ class ComplexImageView(qt.QWidget): if event is items.ItemChangedType.DATA: self.sigDataChanged.emit() elif event is items.ItemChangedType.VISUALIZATION_MODE: - mode = self.getVisualizationMode() + mode = self.getComplexMode() self.sigVisualizationModeChanged.emit(mode) def getPlot(self): @@ -344,15 +346,34 @@ class ComplexImageView(qt.QWidget): False to return internal data (do not modify!) :rtype: numpy.ndarray of float with 2 dims or RGBA image (uint8). """ - mode = self.getVisualizationMode() - if mode in (self.Mode.AMPLITUDE_PHASE, - self.Mode.LOG10_AMPLITUDE_PHASE): + mode = self.getComplexMode() + if mode in (self.ComplexMode.AMPLITUDE_PHASE, + self.ComplexMode.LOG10_AMPLITUDE_PHASE): return self._plotImage.getRgbaImageData(copy=copy) else: return self._plotImage.getData(copy=copy) + # Backward compatibility + + Mode = ComplexMode + + @classmethod + @deprecated(replacement='supportedComplexModes', since_version='0.11.0') + def getSupportedVisualizationModes(cls): + return cls.supportedComplexModes() + + @deprecated(replacement='setComplexMode', since_version='0.11.0') + def setVisualizationMode(self, mode): + return self.setComplexMode(mode) + + @deprecated(replacement='getComplexMode', since_version='0.11.0') + def getVisualizationMode(self): + return self.getComplexMode() + + # Image item proxy + @staticmethod - def getSupportedVisualizationModes(): + def supportedComplexModes(): """Returns the supported visualization modes. Supported visualization modes are: @@ -365,31 +386,33 @@ class ComplexImageView(qt.QWidget): - log10_amplitude_phase: Color-coded phase with log10(amplitude) as alpha. - :rtype: List[Mode] + :rtype: List[ComplexMode] """ - return tuple(ImageComplexData.Mode) + return ImageComplexData.supportedComplexModes() - def setVisualizationMode(self, mode): + def setComplexMode(self, mode): """Set the mode of visualization of the complex data. - See :meth:`getSupportedVisualizationModes` for the list of + See :meth:`supportedComplexModes` for the list of supported modes. How-to change visualization mode:: widget = ComplexImageView() - widget.setVisualizationMode(ComplexImageView.Mode.PHASE) + widget.setComplexMode(ComplexImageView.ComplexMode.PHASE) + # or + widget.setComplexMode('phase') - :param Mode mode: The mode to use. + :param Unions[ComplexMode,str] mode: The mode to use. """ - self._plotImage.setVisualizationMode(mode) + self._plotImage.setComplexMode(mode) - def getVisualizationMode(self): + def getComplexMode(self): """Get the current visualization mode of the complex data. - :rtype: Mode + :rtype: ComplexMode """ - return self._plotImage.getVisualizationMode() + return self._plotImage.getComplexMode() def _setAmplitudeRangeInfo(self, max_=None, delta=2): """Set the amplitude range to display for 'log10_amplitude_phase' mode. @@ -407,8 +430,6 @@ class ComplexImageView(qt.QWidget): :rtype: 2-tuple""" return self._plotImage._getAmplitudeRangeInfo() - # Image item proxy - def setColormap(self, colormap, mode=None): """Set the colormap to use for amplitude, phase, real or imaginary. @@ -416,14 +437,14 @@ class ComplexImageView(qt.QWidget): amplitude and phase. :param ~silx.gui.colors.Colormap colormap: The colormap - :param Mode mode: If specified, set the colormap of this specific mode + :param ComplexMode mode: If specified, set the colormap of this specific mode """ self._plotImage.setColormap(colormap, mode) def getColormap(self, mode=None): """Returns the colormap used to display the data. - :param Mode mode: If specified, set the colormap of this specific mode + :param ComplexMode mode: If specified, set the colormap of this specific mode :rtype: ~silx.gui.colors.Colormap """ return self._plotImage.getColormap(mode=mode) diff --git a/silx/gui/plot/CurvesROIWidget.py b/silx/gui/plot/CurvesROIWidget.py index b426a23..050b344 100644 --- a/silx/gui/plot/CurvesROIWidget.py +++ b/silx/gui/plot/CurvesROIWidget.py @@ -75,14 +75,19 @@ class CurvesROIWidget(qt.QWidget): """ sigROISignal = qt.Signal(object) - """Deprecated signal for backward compatibility with silx < 0.7. - Prefer connecting directly to :attr:`CurvesRoiWidget.sigRoiSignal` - """ def __init__(self, parent=None, name=None, plot=None): super(CurvesROIWidget, self).__init__(parent) if name is not None: self.setWindowTitle(name) + self.__lastSigROISignal = None + """Store the last value emitted for the sigRoiSignal. In the case the + active curve change we need to add this extra step in order to make + sure we won't send twice the sigROISignal. + This come from the fact sigROISignal is connected to the + activeROIChanged signal which is emitted when raw and net counts + values are changing but are not embed in the sigROISignal. + """ assert plot is not None self._plotRef = weakref.ref(plot) self._showAllMarkers = False @@ -91,12 +96,12 @@ class CurvesROIWidget(qt.QWidget): layout = qt.QVBoxLayout(self) layout.setContentsMargins(0, 0, 0, 0) layout.setSpacing(0) - ############## + self.headerLabel = qt.QLabel(self) self.headerLabel.setAlignment(qt.Qt.AlignHCenter) self.setHeader() layout.addWidget(self.headerLabel) - ############## + widgetAllCheckbox = qt.QWidget(parent=self) self._showAllCheckBox = qt.QCheckBox("show all ROI", parent=widgetAllCheckbox) @@ -106,14 +111,13 @@ class CurvesROIWidget(qt.QWidget): widgetAllCheckbox.layout().addWidget(spacer) widgetAllCheckbox.layout().addWidget(self._showAllCheckBox) layout.addWidget(widgetAllCheckbox) - ############## + self.roiTable = ROITable(self, plot=plot) rheight = self.roiTable.horizontalHeader().sizeHint().height() self.roiTable.setMinimumHeight(4 * rheight) layout.addWidget(self.roiTable) self._roiFileDir = qt.QDir.home().absolutePath() self._showAllCheckBox.toggled.connect(self.roiTable.showAllMarkers) - ################# hbox = qt.QWidget(self) hboxlayout = qt.QHBoxLayout(hbox) @@ -216,7 +220,6 @@ class CurvesROIWidget(qt.QWidget): i += 1 newroi = "newroi %d" % i return newroi - roi = ROI(name=getNextRoiName()) if roi.getName() == "ICR": @@ -231,7 +234,6 @@ class CurvesROIWidget(qt.QWidget): fromdata, dummy0, todata, dummy1 = self._getAllLimits() roi.setFrom(fromdata) roi.setTo(todata) - self.roiTable.addRoi(roi) # back compatibility pymca roi signals @@ -257,7 +259,9 @@ class CurvesROIWidget(qt.QWidget): def _reset(self): """Reset button clicked handler""" self.roiTable.clear() + old = self.blockSignals(True) # avoid several sigROISignal emission self._add() + self.blockSignals(old) # back compatibility pymca roi signals ddict = {} @@ -402,7 +406,9 @@ class CurvesROIWidget(qt.QWidget): if visible: # if no ROI existing yet, add the default one if self.roiTable.rowCount() is 0: + old = self.blockSignals(True) # avoid several sigROISignal emission self._add() + self.blockSignals(old) self.calculateRois() def fillFromROIDict(self, *args, **kwargs): @@ -416,7 +422,10 @@ class CurvesROIWidget(qt.QWidget): ddict['current'] = self.roiTable.activeRoi.getName() else: ddict['current'] = None - self.sigROISignal.emit(ddict) + + if self.__lastSigROISignal != ddict: + self.__lastSigROISignal = ddict + self.sigROISignal.emit(ddict) @property def currentRoi(self): @@ -563,8 +572,11 @@ class ROITable(TableWidget): # backward compatibility since 0.10.0 if isinstance(rois, dict): for roiName, roi in rois.items(): - roi['name'] = roiName - _roi = ROI._fromDict(roi) + if isinstance(roi, ROI): + _roi = roi + else: + roi['name'] = roiName + _roi = ROI._fromDict(roi) self.addRoi(_roi) else: for roi in rois: @@ -688,12 +700,14 @@ class ROITable(TableWidget): activeItems = self.selectedItems() if len(activeItems) is 0: return + old = self.blockSignals(True) # avoid several emission of sigROISignal roiToRm = set() for item in activeItems: row = item.row() itemID = self.item(row, self.COLUMNS_INDEX['ID']) roiToRm.add(self._roiDict[int(itemID.text())]) [self.removeROI(roi) for roi in roiToRm] + self.blockSignals(old) self.setActiveRoi(None) def removeROI(self, roi): @@ -730,7 +744,10 @@ class ROITable(TableWidget): else: assert isinstance(roi, ROI) if roi and roi.getID() in self._roiToItems.keys(): + # avoid several call back to setActiveROI + old = self.blockSignals(True) self.selectRow(self._roiToItems[roi.getID()].row()) + self.blockSignals(old) self._markersHandler.setActiveRoi(roi) self.activeROIChanged.emit() @@ -931,9 +948,12 @@ class ROITable(TableWidget): if ddict['event'] == 'markerMoved': label = ddict['label'] roiID = self._markersHandler.getRoiID(markerID=label) - if roiID: + if roiID is not None: + # avoid several emission of sigROISignal + old = self.blockSignals(True) self._markersHandler.changePosition(markerID=label, x=ddict['x']) + self.blockSignals(old) self._updateRoiInfo(roiID) def showEvent(self, event): diff --git a/silx/gui/plot/PlotToolButtons.py b/silx/gui/plot/PlotToolButtons.py index bf6b8ce..cd1a43f 100644 --- a/silx/gui/plot/PlotToolButtons.py +++ b/silx/gui/plot/PlotToolButtons.py @@ -47,7 +47,7 @@ from .. import icons from .. import qt from ... import config -from .items import SymbolMixIn +from .items import SymbolMixIn, Scatter _logger = logging.getLogger(__name__) @@ -352,23 +352,22 @@ class ProfileToolButton(PlotToolButton): self._profileDimensionChanged(2) -class SymbolToolButton(PlotToolButton): - """A tool button with a drop-down menu to control symbol size and marker. + +class _SymbolToolButtonBase(PlotToolButton): + """Base class for PlotToolButton setting marker and size. :param parent: See QWidget :param plot: The `~silx.gui.plot.PlotWidget` to control """ def __init__(self, parent=None, plot=None): - super(SymbolToolButton, self).__init__(parent=parent, plot=plot) + super(_SymbolToolButtonBase, self).__init__(parent=parent, plot=plot) - self.setToolTip('Set symbol size and marker') - self.setIcon(icons.getQIcon('plot-symbols')) - - menu = qt.QMenu(self) - - # Size slider + def _addSizeSliderToMenu(self, menu): + """Add a slider to set size to the given menu + :param QMenu menu: + """ slider = qt.QSlider(qt.Qt.Horizontal) slider.setRange(1, 20) slider.setValue(config.DEFAULT_PLOT_SYMBOL_SIZE) @@ -378,10 +377,11 @@ class SymbolToolButton(PlotToolButton): widgetAction.setDefaultWidget(slider) menu.addAction(widgetAction) - menu.addSeparator() - - # Marker actions + def _addSymbolsToMenu(self, menu): + """Add symbols to the given menu + :param QMenu menu: + """ for marker, name in zip(SymbolMixIn.getSupportedSymbols(), SymbolMixIn.getSupportedSymbolNames()): action = qt.QAction(name, menu) @@ -390,9 +390,6 @@ class SymbolToolButton(PlotToolButton): functools.partial(self._markerChanged, marker)) menu.addAction(action) - self.setMenu(menu) - self.setPopupMode(qt.QToolButton.InstantPopup) - def _sizeChanged(self, value): """Manage slider value changed @@ -418,3 +415,78 @@ class SymbolToolButton(PlotToolButton): for item in plot._getItems(withhidden=True): if isinstance(item, SymbolMixIn): item.setSymbol(marker) + + +class SymbolToolButton(_SymbolToolButtonBase): + """A tool button with a drop-down menu to control symbol size and marker. + + :param parent: See QWidget + :param plot: The `~silx.gui.plot.PlotWidget` to control + """ + + def __init__(self, parent=None, plot=None): + super(SymbolToolButton, self).__init__(parent=parent, plot=plot) + + self.setToolTip('Set symbol size and marker') + self.setIcon(icons.getQIcon('plot-symbols')) + + menu = qt.QMenu(self) + self._addSizeSliderToMenu(menu) + menu.addSeparator() + self._addSymbolsToMenu(menu) + + self.setMenu(menu) + self.setPopupMode(qt.QToolButton.InstantPopup) + + +class ScatterVisualizationToolButton(_SymbolToolButtonBase): + """QToolButton to select the visualization mode of scatter plot + + :param parent: See QWidget + :param plot: The `~silx.gui.plot.PlotWidget` to control + """ + + def __init__(self, parent=None, plot=None): + super(ScatterVisualizationToolButton, self).__init__( + parent=parent, plot=plot) + + self.setToolTip( + 'Set scatter visualization mode, symbol marker and size') + self.setIcon(icons.getQIcon('eye')) + + menu = qt.QMenu(self) + + # Add visualization modes + + for mode in Scatter.supportedVisualizations(): + name = mode.value.capitalize() + action = qt.QAction(name, menu) + action.setCheckable(False) + action.triggered.connect( + functools.partial(self._visualizationChanged, mode)) + menu.addAction(action) + + menu.addSeparator() + + submenu = menu.addMenu(icons.getQIcon('plot-symbols'), "Symbol") + self._addSymbolsToMenu(submenu) + + submenu = menu.addMenu(icons.getQIcon('plot-symbols'), "Symbol Size") + self._addSizeSliderToMenu(submenu) + + self.setMenu(menu) + self.setPopupMode(qt.QToolButton.InstantPopup) + + def _visualizationChanged(self, mode): + """Handle change of visualization mode. + + :param ScatterVisualizationMixIn.Visualization mode: + The visualization mode to use for scatter + """ + plot = self.plot() + if plot is None: + return + + for item in plot._getItems(withhidden=True): + if isinstance(item, Scatter): + item.setVisualization(mode) diff --git a/silx/gui/plot/PlotWidget.py b/silx/gui/plot/PlotWidget.py index cfe39fa..9b9b4d2 100644 --- a/silx/gui/plot/PlotWidget.py +++ b/silx/gui/plot/PlotWidget.py @@ -33,12 +33,20 @@ __authors__ = ["V.A. Sole", "T. Vincent"] __license__ = "MIT" __date__ = "21/12/2018" +import logging + +_logger = logging.getLogger(__name__) + from collections import OrderedDict, namedtuple +try: + from collections import abc +except ImportError: # Python2 support + import collections as abc from contextlib import contextmanager import datetime as dt import itertools -import logging +import warnings import numpy @@ -46,8 +54,11 @@ import silx from silx.utils.weakref import WeakMethodProxy from silx.utils.property import classproperty from silx.utils.deprecation import deprecated -# Import matplotlib backend here to init matplotlib our way -from .backends.BackendMatplotlib import BackendMatplotlibQt +try: + # Import matplotlib now to init matplotlib our way + from . import matplotlib +except ImportError: + _logger.debug("matplotlib not available") from ..colors import Colormap from .. import colors @@ -64,7 +75,6 @@ from .. import qt from ._utils.panzoom import ViewConstraints from ...gui.plot._utils.dtime_ticklayout import timestamp -_logger = logging.getLogger(__name__) _COLORDICT = colors.COLORDICT @@ -287,33 +297,68 @@ class PlotWidget(qt.QMainWindow): self._foregroundColorsUpdated() self._backgroundColorsUpdated() + def __getBackendClass(self, backend): + """Returns backend class corresponding to backend. + + If multiple backends are provided, the first available one is used. + + :param Union[str,BackendBase,Iterable] backend: + The name of the backend or its class or an iterable of those. + :rtype: BackendBase + :raise ValueError: In case the backend is not supported + :raise RuntimeError: If a backend is not available + """ + if callable(backend): + return backend + + elif isinstance(backend, str): + backend = backend.lower() + if backend in ('matplotlib', 'mpl'): + try: + from .backends.BackendMatplotlib import \ + BackendMatplotlibQt as backendClass + except ImportError: + _logger.debug("Backtrace", exc_info=True) + raise ImportError("matplotlib backend is not available") + + elif backend in ('gl', 'opengl'): + try: + from .backends.BackendOpenGL import \ + BackendOpenGL as backendClass + except ImportError: + _logger.debug("Backtrace", exc_info=True) + raise ImportError("OpenGL backend is not available") + + elif backend == 'none': + from .backends.BackendBase import BackendBase as backendClass + + else: + raise ValueError("Backend not supported %s" % backend) + + return backendClass + + elif isinstance(backend, abc.Iterable): + for b in backend: + try: + return self.__getBackendClass(b) + except ImportError: + pass + else: # No backend was found + raise ValueError("No supported backend was found") + + raise ValueError("Backend not supported %s" % str(backend)) + def _setBackend(self, backend): - """Setup a new backend""" + """Setup a new backend + + :param backend: Either a str defining the backend to use + """ assert(self._backend is None) if backend is None: backend = silx.config.DEFAULT_PLOT_BACKEND - if hasattr(backend, "__call__"): - backend = backend(self, self) - - elif hasattr(backend, "lower"): - lowerCaseString = backend.lower() - if lowerCaseString in ("matplotlib", "mpl"): - backendClass = BackendMatplotlibQt - elif lowerCaseString in ('gl', 'opengl'): - from .backends.BackendOpenGL import BackendOpenGL - backendClass = BackendOpenGL - elif lowerCaseString == 'none': - from .backends.BackendBase import BackendBase as backendClass - else: - raise ValueError("Backend not supported %s" % backend) - backend = backendClass(self, self) - - else: - raise ValueError("Backend not supported %s" % str(backend)) - - self._backend = backend + self._backend = self.__getBackendClass(backend)(self, self) # TODO: Can be removed for silx 0.10 @staticmethod @@ -456,7 +501,7 @@ class PlotWidget(qt.QMainWindow): return qt.QColor.fromRgbF(*self._dataBackgroundColor) def setDataBackgroundColor(self, color): - """Set the background color of this widget. + """Set the background color of the plot area. Set to None or an invalid QColor to use the background color. @@ -499,16 +544,25 @@ class PlotWidget(qt.QMainWindow): if item.isVisible(): bounds = item.getBounds() if bounds is not None: - xMin = numpy.nanmin([xMin, bounds[0]]) - xMax = numpy.nanmax([xMax, bounds[1]]) + with warnings.catch_warnings(): + warnings.simplefilter('ignore', category=RuntimeWarning) + # Ignore All-NaN slice encountered + xMin = numpy.nanmin([xMin, bounds[0]]) + xMax = numpy.nanmax([xMax, bounds[1]]) # Take care of right axis if (isinstance(item, items.YAxisMixIn) and item.getYAxis() == 'right'): - yMinRight = numpy.nanmin([yMinRight, bounds[2]]) - yMaxRight = numpy.nanmax([yMaxRight, bounds[3]]) + with warnings.catch_warnings(): + warnings.simplefilter('ignore', category=RuntimeWarning) + # Ignore All-NaN slice encountered + yMinRight = numpy.nanmin([yMinRight, bounds[2]]) + yMaxRight = numpy.nanmax([yMaxRight, bounds[3]]) else: - yMinLeft = numpy.nanmin([yMinLeft, bounds[2]]) - yMaxLeft = numpy.nanmax([yMaxLeft, bounds[3]]) + with warnings.catch_warnings(): + warnings.simplefilter('ignore', category=RuntimeWarning) + # Ignore All-NaN slice encountered + yMinLeft = numpy.nanmin([yMinLeft, bounds[2]]) + yMaxLeft = numpy.nanmax([yMaxLeft, bounds[3]]) def lGetRange(x, y): return None if numpy.isnan(x) and numpy.isnan(y) else (x, y) @@ -2665,9 +2719,11 @@ class PlotWidget(qt.QMainWindow): xmin, xmax = (1., 100.) if ranges.x is None else ranges.x ymin, ymax = (1., 100.) if ranges.y is None else ranges.y if ranges.yright is None: - ymin2, ymax2 = None, None + ymin2, ymax2 = ymin, ymax else: ymin2, ymax2 = ranges.yright + if ranges.y is None: + ymin, ymax = ranges.yright # Add margins around data inside the plot area newLimits = list(_utils.addMarginsToLimits( diff --git a/silx/gui/plot/PlotWindow.py b/silx/gui/plot/PlotWindow.py index b44a512..a39430e 100644 --- a/silx/gui/plot/PlotWindow.py +++ b/silx/gui/plot/PlotWindow.py @@ -29,15 +29,19 @@ The :class:`PlotWindow` is a subclass of :class:`.PlotWidget`. __authors__ = ["V.A. Sole", "T. Vincent"] __license__ = "MIT" -__date__ = "21/12/2018" +__date__ = "12/04/2019" -import collections +try: + from collections import abc +except ImportError: # Python2 support + import collections as abc import logging import weakref import silx from silx.utils.weakref import WeakMethodProxy from silx.utils.deprecation import deprecated +from silx.utils.proxy import docstring from . import PlotWidget from . import actions @@ -128,53 +132,53 @@ class PlotWindow(PlotWidget): self.group.setExclusive(False) self.resetZoomAction = self.group.addAction( - actions.control.ResetZoomAction(self)) + actions.control.ResetZoomAction(self, parent=self)) self.resetZoomAction.setVisible(resetzoom) self.addAction(self.resetZoomAction) - self.zoomInAction = actions.control.ZoomInAction(self) + self.zoomInAction = actions.control.ZoomInAction(self, parent=self) self.addAction(self.zoomInAction) - self.zoomOutAction = actions.control.ZoomOutAction(self) + self.zoomOutAction = actions.control.ZoomOutAction(self, parent=self) self.addAction(self.zoomOutAction) self.xAxisAutoScaleAction = self.group.addAction( - actions.control.XAxisAutoScaleAction(self)) + actions.control.XAxisAutoScaleAction(self, parent=self)) self.xAxisAutoScaleAction.setVisible(autoScale) self.addAction(self.xAxisAutoScaleAction) self.yAxisAutoScaleAction = self.group.addAction( - actions.control.YAxisAutoScaleAction(self)) + actions.control.YAxisAutoScaleAction(self, parent=self)) self.yAxisAutoScaleAction.setVisible(autoScale) self.addAction(self.yAxisAutoScaleAction) self.xAxisLogarithmicAction = self.group.addAction( - actions.control.XAxisLogarithmicAction(self)) + actions.control.XAxisLogarithmicAction(self, parent=self)) self.xAxisLogarithmicAction.setVisible(logScale) self.addAction(self.xAxisLogarithmicAction) self.yAxisLogarithmicAction = self.group.addAction( - actions.control.YAxisLogarithmicAction(self)) + actions.control.YAxisLogarithmicAction(self, parent=self)) self.yAxisLogarithmicAction.setVisible(logScale) self.addAction(self.yAxisLogarithmicAction) self.gridAction = self.group.addAction( - actions.control.GridAction(self, gridMode='both')) + actions.control.GridAction(self, gridMode='both', parent=self)) self.gridAction.setVisible(grid) self.addAction(self.gridAction) self.curveStyleAction = self.group.addAction( - actions.control.CurveStyleAction(self)) + actions.control.CurveStyleAction(self, parent=self)) self.curveStyleAction.setVisible(curveStyle) self.addAction(self.curveStyleAction) self.colormapAction = self.group.addAction( - actions.control.ColormapAction(self)) + actions.control.ColormapAction(self, parent=self)) self.colormapAction.setVisible(colormap) self.addAction(self.colormapAction) self.colorbarAction = self.group.addAction( - actions_control.ColorBarAction(self, self)) + actions_control.ColorBarAction(self, parent=self)) self.colorbarAction.setVisible(False) self.addAction(self.colorbarAction) self._colorbar.setVisible(False) @@ -194,18 +198,18 @@ class PlotWindow(PlotWidget): self.getMaskAction().setVisible(mask) self._intensityHistoAction = self.group.addAction( - actions_histogram.PixelIntensitiesHistoAction(self)) + actions_histogram.PixelIntensitiesHistoAction(self, parent=self)) self._intensityHistoAction.setVisible(False) self._medianFilter2DAction = self.group.addAction( - actions_medfilt.MedianFilter2DAction(self)) + actions_medfilt.MedianFilter2DAction(self, parent=self)) self._medianFilter2DAction.setVisible(False) self._medianFilter1DAction = self.group.addAction( - actions_medfilt.MedianFilter1DAction(self)) + actions_medfilt.MedianFilter1DAction(self, parent=self)) self._medianFilter1DAction.setVisible(False) - self.fitAction = self.group.addAction(actions_fit.FitAction(self)) + self.fitAction = self.group.addAction(actions_fit.FitAction(self, parent=self)) self.fitAction.setVisible(fit) self.addAction(self.fitAction) @@ -250,7 +254,7 @@ class PlotWindow(PlotWidget): hbox.addWidget(self.controlButton) if position: # Add PositionInfo widget to the bottom of the plot - if isinstance(position, collections.Iterable): + if isinstance(position, abc.Iterable): # Use position as a set of converters converters = position else: @@ -278,7 +282,7 @@ class PlotWindow(PlotWidget): parent=self, plot=self) self.addToolBar(self._interactiveModeToolBar) - self._toolbar = self._createToolBar(title='Plot', parent=None) + self._toolbar = self._createToolBar(title='Plot', parent=self) self.addToolBar(self._toolbar) self._outputToolBar = tools.OutputToolBar(parent=self, plot=self) @@ -292,24 +296,21 @@ class PlotWindow(PlotWidget): for action in toolbar.actions(): self.addAction(action) + @docstring(PlotWidget) def setBackgroundColor(self, color): super(PlotWindow, self).setBackgroundColor(color) self._updateColorBarBackground() - setBackgroundColor.__doc__ = PlotWidget.setBackgroundColor.__doc__ - + @docstring(PlotWidget) def setDataBackgroundColor(self, color): super(PlotWindow, self).setDataBackgroundColor(color) self._updateColorBarBackground() - setDataBackgroundColor.__doc__ = PlotWidget.setDataBackgroundColor.__doc__ - + @docstring(PlotWidget) def setForegroundColor(self, color): super(PlotWindow, self).setForegroundColor(color) self._updateColorBarBackground() - setForegroundColor.__doc__ = PlotWidget.setForegroundColor.__doc__ - def _updateColorBarBackground(self): """Update the colorbar background according to the state of the plot""" if self._isAxesDisplayed(): @@ -824,7 +825,9 @@ class Plot2D(PlotWindow): posInfo = [ ('X', lambda x, y: x), ('Y', lambda x, y: y), - ('Data', WeakMethodProxy(self._getImageValue))] + ('Data', WeakMethodProxy(self._getImageValue)), + ('Dims', WeakMethodProxy(self._getImageDims)), + ] super(Plot2D, self).__init__(parent=parent, backend=backend, resetzoom=True, autoScale=False, @@ -924,6 +927,15 @@ class Plot2D(PlotWindow): return value, "Masked" return value + def _getImageDims(self, *args): + activeImage = self.getActiveImage() + if (activeImage is not None and + activeImage.getData(copy=False) is not None): + dims = activeImage.getData(copy=False).shape[1::-1] + return 'x'.join(str(dim) for dim in dims) + else: + return '-' + def getProfileToolbar(self): """Profile tools attached to this plot diff --git a/silx/gui/plot/Profile.py b/silx/gui/plot/Profile.py index 46e4523..e2aa5a7 100644 --- a/silx/gui/plot/Profile.py +++ b/silx/gui/plot/Profile.py @@ -28,7 +28,7 @@ and stacks of images""" __authors__ = ["V.A. Sole", "T. Vincent", "P. Knobel", "H. Payno"] __license__ = "MIT" -__date__ = "24/07/2018" +__date__ = "12/04/2019" import weakref @@ -419,39 +419,40 @@ class ProfileToolBar(qt.QToolBar): """ if self._profileWindow is None: - self._profileMainWindow = ProfileMainWindow(self) + backend = type(plot._backend) + self._profileMainWindow = ProfileMainWindow(self, backend=backend) # Actions self._browseAction = actions.mode.ZoomModeAction(self.plot, parent=self) self._browseAction.setVisible(False) - self.hLineAction = qt.QAction( - icons.getQIcon('shape-horizontal'), - 'Horizontal Profile Mode', None) + self.hLineAction = qt.QAction(icons.getQIcon('shape-horizontal'), + 'Horizontal Profile Mode', + self) self.hLineAction.setToolTip( 'Enables horizontal profile selection mode') self.hLineAction.setCheckable(True) self.hLineAction.toggled[bool].connect(self._hLineActionToggled) - self.vLineAction = qt.QAction( - icons.getQIcon('shape-vertical'), - 'Vertical Profile Mode', None) + self.vLineAction = qt.QAction(icons.getQIcon('shape-vertical'), + 'Vertical Profile Mode', + self) self.vLineAction.setToolTip( 'Enables vertical profile selection mode') self.vLineAction.setCheckable(True) self.vLineAction.toggled[bool].connect(self._vLineActionToggled) - self.lineAction = qt.QAction( - icons.getQIcon('shape-diagonal'), - 'Free Line Profile Mode', None) + self.lineAction = qt.QAction(icons.getQIcon('shape-diagonal'), + 'Free Line Profile Mode', + self) self.lineAction.setToolTip( 'Enables line profile selection mode') self.lineAction.setCheckable(True) self.lineAction.toggled[bool].connect(self._lineActionToggled) - self.clearAction = qt.QAction( - icons.getQIcon('profile-clear'), - 'Clear Profile', None) + self.clearAction = qt.QAction(icons.getQIcon('profile-clear'), + 'Clear Profile', + self) self.clearAction.setToolTip( 'Clear the profile Region of interest') self.clearAction.setCheckable(False) diff --git a/silx/gui/plot/ProfileMainWindow.py b/silx/gui/plot/ProfileMainWindow.py index caa076c..39830d8 100644 --- a/silx/gui/plot/ProfileMainWindow.py +++ b/silx/gui/plot/ProfileMainWindow.py @@ -35,8 +35,15 @@ __date__ = "21/02/2017" class ProfileMainWindow(qt.QMainWindow): """QMainWindow providing 2 plot widgets specialized in 1D and 2D plotting, with different toolbars. + Only one of the plots is visible at any given time. + + :param qt.QWidget parent: The parent of this widget or None (default). + :param Union[str,Class] backend: The backend to use, in: + 'matplotlib' (default), 'mpl', 'opengl', 'gl', 'none' + or a :class:`BackendBase.BackendBase` class """ + sigProfileDimensionsChanged = qt.Signal(int) """This signal is emitted when :meth:`setProfileDimensions` is called. It carries the number of dimensions for the profile data (1 or 2). @@ -51,13 +58,14 @@ class ProfileMainWindow(qt.QMainWindow): """Emitted when the method to compute the profile changed (for now can be sum or mean)""" - def __init__(self, parent=None): + def __init__(self, parent=None, backend=None): qt.QMainWindow.__init__(self, parent=parent) self.setWindowTitle('Profile window') # plots are created on demand, in self.setProfileDimensions() self._plot1D = None self._plot2D = None + self._backend = backend # by default, profile is assumed to be a 1D curve self._profileType = None self.setProfileType("1D") @@ -76,7 +84,7 @@ class ProfileMainWindow(qt.QMainWindow): if self._plot2D is not None: self._plot2D.setParent(None) # necessary to avoid widget destruction if self._plot1D is None: - self._plot1D = Plot1D() + self._plot1D = Plot1D(backend=self._backend) self._plot1D.setGraphYLabel('Profile') self._plot1D.setGraphXLabel('') self.setCentralWidget(self._plot1D) @@ -84,7 +92,7 @@ class ProfileMainWindow(qt.QMainWindow): if self._plot1D is not None: self._plot1D.setParent(None) # necessary to avoid widget destruction if self._plot2D is None: - self._plot2D = Plot2D() + self._plot2D = Plot2D(backend=self._backend) self.setCentralWidget(self._plot2D) else: raise ValueError("Profile type must be '1D' or '2D'") diff --git a/silx/gui/plot/ScatterView.py b/silx/gui/plot/ScatterView.py index 5fc66ef..1d015d4 100644 --- a/silx/gui/plot/ScatterView.py +++ b/silx/gui/plot/ScatterView.py @@ -47,6 +47,8 @@ from .ScatterMaskToolsWidget import ScatterMaskToolsWidget from ..widgets.BoxLayoutDockWidget import BoxLayoutDockWidget from .. import qt, icons +from ...utils.proxy import docstring +from ...utils.weakref import WeakMethodProxy _logger = logging.getLogger(__name__) @@ -92,10 +94,10 @@ class ScatterView(qt.QMainWindow): self.__pickingCache = None self._positionInfo = tools.PositionInfo( plot=plot, - converters=(('X', lambda x, y: x), - ('Y', lambda x, y: y), - ('Data', lambda x, y: self._getScatterValue(x, y)), - ('Index', lambda x, y: self._getScatterIndex(x, y)))) + converters=(('X', WeakMethodProxy(self._getPickedX)), + ('Y', WeakMethodProxy(self._getPickedY)), + ('Data', WeakMethodProxy(self._getPickedValue)), + ('Index', WeakMethodProxy(self._getPickedIndex)))) # Combine plot, position info and colorbar into central widget gridLayout = qt.QGridLayout() @@ -167,32 +169,52 @@ class ScatterView(qt.QMainWindow): dataIndex = indices[-1] self.__pickingCache = ( dataIndex, + item.getXData(copy=False)[dataIndex], + item.getYData(copy=False)[dataIndex], item.getValueData(copy=False)[dataIndex]) break return self.__pickingCache - def _getScatterValue(self, x, y): - """Get data value of top most scatter plot at position (x, y) + def _getPickedIndex(self, x, y): + """Get data index of top most scatter plot at position (x, y) :param float x: X position in plot coordinates :param float y: Y position in plot coordinates - :return: The data value at that point or '-' + :return: The data index at that point or '-' """ picking = self._pickScatterData(x, y) - return '-' if picking is None else picking[1] + return '-' if picking is None else picking[0] - def _getScatterIndex(self, x, y): - """Get data index of top most scatter plot at position (x, y) + def _getPickedX(self, x, y): + """Returns X position snapped to scatter plot when close enough + + :param float x: + :param float y: + :rtype: float + """ + picking = self._pickScatterData(x, y) + return x if picking is None else picking[1] + + def _getPickedY(self, x, y): + """Returns Y position snapped to scatter plot when close enough + + :param float x: + :param float y: + :rtype: float + """ + picking = self._pickScatterData(x, y) + return y if picking is None else picking[2] + + def _getPickedValue(self, x, y): + """Get data value of top most scatter plot at position (x, y) :param float x: X position in plot coordinates :param float y: Y position in plot coordinates - :return: The data index at that point or '-' + :return: The data value at that point or '-' """ picking = self._pickScatterData(x, y) - return '-' if picking is None else picking[0] - - _PICK_OFFSET = 3 # Offset in pixel used for picking + return '-' if picking is None else picking[3] def _mouseInPlotArea(self, x, y): """Clip mouse coordinates to plot area coordinates @@ -307,11 +329,10 @@ class ScatterView(qt.QMainWindow): self.getScatterItem().setData( x=x, y=y, value=value, xerror=xerror, yerror=yerror, alpha=alpha, copy=copy) + @docstring(items.Scatter) def getData(self, *args, **kwargs): return self.getScatterItem().getData(*args, **kwargs) - getData.__doc__ = items.Scatter.getData.__doc__ - def getScatterItem(self): """Returns the plot item displaying the scatter data. @@ -329,37 +350,30 @@ class ScatterView(qt.QMainWindow): # Convenient proxies + @docstring(PlotWidget) def getXAxis(self, *args, **kwargs): return self.getPlotWidget().getXAxis(*args, **kwargs) - getXAxis.__doc__ = PlotWidget.getXAxis.__doc__ - + @docstring(PlotWidget) def getYAxis(self, *args, **kwargs): return self.getPlotWidget().getYAxis(*args, **kwargs) - getYAxis.__doc__ = PlotWidget.getYAxis.__doc__ - + @docstring(PlotWidget) def setGraphTitle(self, *args, **kwargs): return self.getPlotWidget().setGraphTitle(*args, **kwargs) - setGraphTitle.__doc__ = PlotWidget.setGraphTitle.__doc__ - + @docstring(PlotWidget) def getGraphTitle(self, *args, **kwargs): return self.getPlotWidget().getGraphTitle(*args, **kwargs) - getGraphTitle.__doc__ = PlotWidget.getGraphTitle.__doc__ - + @docstring(PlotWidget) def resetZoom(self, *args, **kwargs): return self.getPlotWidget().resetZoom(*args, **kwargs) - resetZoom.__doc__ = PlotWidget.resetZoom.__doc__ - + @docstring(ScatterMaskToolsWidget) def getSelectionMask(self, *args, **kwargs): return self.getMaskToolsWidget().getSelectionMask(*args, **kwargs) - getSelectionMask.__doc__ = ScatterMaskToolsWidget.getSelectionMask.__doc__ - + @docstring(ScatterMaskToolsWidget) def setSelectionMask(self, *args, **kwargs): return self.getMaskToolsWidget().setSelectionMask(*args, **kwargs) - - setSelectionMask.__doc__ = ScatterMaskToolsWidget.setSelectionMask.__doc__ diff --git a/silx/gui/plot/StatsWidget.py b/silx/gui/plot/StatsWidget.py index 4ba4fab..5e2dc58 100644 --- a/silx/gui/plot/StatsWidget.py +++ b/silx/gui/plot/StatsWidget.py @@ -35,9 +35,11 @@ from collections import OrderedDict from contextlib import contextmanager import logging import weakref - +import functools import numpy - +import enum +from silx.utils.proxy import docstring +from silx.utils.enum import Enum as _Enum from silx.gui import qt from silx.gui import icons from silx.gui.plot import stats as statsmdl @@ -52,8 +54,15 @@ from . import items as plotitems _logger = logging.getLogger(__name__) +@enum.unique +class UpdateMode(_Enum): + AUTO = 'auto' + MANUAL = 'manual' + + # Helper class to handle specific calls to PlotWidget and SceneWidget + class _Wrapper(qt.QObject): """Base class for connection with PlotWidget and SceneWidget. @@ -319,10 +328,12 @@ class _StatsWidgetBase(object): """ Base class for all widgets which want to display statistics """ + def __init__(self, statsOnVisibleData, displayOnlyActItem): self._displayOnlyActItem = displayOnlyActItem self._statsOnVisibleData = statsOnVisibleData self._statsHandler = None + self._updateMode = UpdateMode.AUTO self.__default_skipped_events = ( ItemChangedType.ALPHA, @@ -503,6 +514,29 @@ class _StatsWidgetBase(object): """ return event in self.__default_skipped_events + def setUpdateMode(self, mode): + """Set the way to update the displayed statistics. + + :param mode: mode requested for update + :type mode: Union[str,UpdateMode] + """ + mode = UpdateMode.from_value(mode) + if mode != self._updateMode: + self._updateMode = mode + self._updateModeHasChanged() + + def getUpdateMode(self): + """Returns update mode (See :meth:`setUpdateMode`). + + :return: update mode + :rtype: UpdateMode + """ + return self._updateMode + + def _updateModeHasChanged(self): + """callback when the update mode has changed""" + pass + class StatsTable(_StatsWidgetBase, TableWidget): """ @@ -522,6 +556,9 @@ class StatsTable(_StatsWidgetBase, TableWidget): _LEGEND_HEADER_DATA = 'legend' _KIND_HEADER_DATA = 'kind' + sigUpdateModeChanged = qt.Signal(object) + """Signal emitted when the update mode changed""" + def __init__(self, parent=None, plot=None): TableWidget.__init__(self, parent) _StatsWidgetBase.__init__(self, statsOnVisibleData=False, @@ -606,6 +643,8 @@ class StatsTable(_StatsWidgetBase, TableWidget): def _updateItemObserve(self, *args): """Reload table depending on mode""" + if self.getUpdateMode() is UpdateMode.MANUAL: + return self._removeAllItems() # Get selected or all items from the plot @@ -678,11 +717,19 @@ class StatsTable(_StatsWidgetBase, TableWidget): :param event: """ + if self.getUpdateMode() is UpdateMode.MANUAL: + return if self._skipPlotItemChangedEvent(event) is True: return else: item = self.sender() self._updateStats(item) + # deal with stat items visibility + if event is ItemChangedType.VISIBLE: + if len(self._itemToTableItems(item).items()) > 0: + item_0 = list(self._itemToTableItems(item).values())[0] + row_index = item_0.row() + self.setRowHidden(row_index, not item.isVisible()) def _addItem(self, item): """Add a plot item to the table @@ -810,8 +857,13 @@ class StatsTable(_StatsWidgetBase, TableWidget): else: tableItem.setText(str(value)) - def _updateAllStats(self): - """Update stats for all rows in the table""" + def _updateAllStats(self, is_request=False): + """Update stats for all rows in the table + + :param bool is_request: True if come from a manual request + """ + if self.getUpdateMode() is UpdateMode.MANUAL and not is_request: + return with self._disableSorting(): for row in range(self.rowCount()): tableItem = self.item(row, 0) @@ -851,10 +903,103 @@ class StatsTable(_StatsWidgetBase, TableWidget): else: self.setSelectionMode(qt.QAbstractItemView.NoSelection) + def _updateModeHasChanged(self): + self.sigUpdateModeChanged.emit(self._updateMode) -class _OptionsWidget(qt.QToolBar): + +class UpdateModeWidget(qt.QWidget): + """Widget used to select the mode of update""" + sigUpdateModeChanged = qt.Signal(object) + """signal emitted when the mode for update changed""" + sigUpdateRequested = qt.Signal() + """signal emitted when an manual request for example is activate""" def __init__(self, parent=None): + qt.QWidget.__init__(self, parent) + self.setLayout(qt.QHBoxLayout()) + self._buttonGrp = qt.QButtonGroup(parent=self) + self._buttonGrp.setExclusive(True) + + spacer = qt.QSpacerItem(20, 20, + qt.QSizePolicy.Expanding, + qt.QSizePolicy.Minimum) + self.layout().addItem(spacer) + + self._autoRB = qt.QRadioButton('auto', parent=self) + self.layout().addWidget(self._autoRB) + self._buttonGrp.addButton(self._autoRB) + + self._manualRB = qt.QRadioButton('manual', parent=self) + self.layout().addWidget(self._manualRB) + self._buttonGrp.addButton(self._manualRB) + self._manualRB.setChecked(True) + + refresh_icon = icons.getQIcon('view-refresh') + self._updatePB = qt.QPushButton(refresh_icon, '', parent=self) + self.layout().addWidget(self._updatePB) + + # connect signal / SLOT + self._updatePB.clicked.connect(self._updateRequested) + self._manualRB.toggled.connect(self._manualButtonToggled) + self._autoRB.toggled.connect(self._autoButtonToggled) + + def _manualButtonToggled(self, checked): + if checked: + self.setUpdateMode(UpdateMode.MANUAL) + self.sigUpdateModeChanged.emit(self.getUpdateMode()) + + def _autoButtonToggled(self, checked): + if checked: + self.setUpdateMode(UpdateMode.AUTO) + self.sigUpdateModeChanged.emit(self.getUpdateMode()) + + def _updateRequested(self): + if self.getUpdateMode() is UpdateMode.MANUAL: + self.sigUpdateRequested.emit() + + def setUpdateMode(self, mode): + """Set the way to update the displayed statistics. + + :param mode: mode requested for update + :type mode: Union[str,UpdateMode] + """ + mode = UpdateMode.from_value(mode) + + if mode is UpdateMode.AUTO: + if not self._autoRB.isChecked(): + self._autoRB.setChecked(True) + elif mode is UpdateMode.MANUAL: + if not self._manualRB.isChecked(): + self._manualRB.setChecked(True) + else: + raise ValueError('mode', mode, 'is not recognized') + + def getUpdateMode(self): + """Returns update mode (See :meth:`setUpdateMode`). + + :return: the active update mode + :rtype: UpdateMode + """ + if self._manualRB.isChecked(): + return UpdateMode.MANUAL + elif self._autoRB.isChecked(): + return UpdateMode.AUTO + else: + raise RuntimeError("No mode selected") + + def showRadioButtons(self, show): + """show / hide the QRadioButtons + + :param bool show: if True make RadioButton visible + """ + self._autoRB.setVisible(show) + self._manualRB.setVisible(show) + + +class _OptionsWidget(qt.QToolBar): + + def __init__(self, parent=None, updateMode=None, displayOnlyActItem=False): + assert updateMode is not None qt.QToolBar.__init__(self, parent) self.setIconSize(qt.QSize(16, 16)) @@ -863,7 +1008,7 @@ class _OptionsWidget(qt.QToolBar): action.setText("Active items only") action.setToolTip("Display stats for active items only.") action.setCheckable(True) - action.setChecked(True) + action.setChecked(displayOnlyActItem) self.__displayActiveItems = action action = qt.QAction(self) @@ -909,9 +1054,26 @@ class _OptionsWidget(qt.QToolBar): self.dataRangeSelection.addAction(self.__useWholeData) self.dataRangeSelection.addAction(self.__useVisibleData) + self.__updateStatsAction = qt.QAction(self) + self.__updateStatsAction.setIcon(icons.getQIcon("view-refresh")) + self.__updateStatsAction.setText("update statistics") + self.__updateStatsAction.setToolTip("update statistics") + self.__updateStatsAction.setCheckable(False) + self._updateStatsSep = self.addSeparator() + self.addAction(self.__updateStatsAction) + + self._setUpdateMode(mode=updateMode) + + # expose API + self.sigUpdateStats = self.__updateStatsAction.triggered + def isActiveItemMode(self): return self.itemSelection.checkedAction() is self.__displayActiveItems + def setDisplayActiveItems(self, only_active): + self.__displayActiveItems.setChecked(only_active) + self.__displayWholeItems.setChecked(not only_active) + def isVisibleDataRangeMode(self): return self.dataRangeSelection.checkedAction() is self.__useVisibleData @@ -925,6 +1087,18 @@ class _OptionsWidget(qt.QToolBar): if not enabled: self.__useWholeData.setChecked(True) + def _setUpdateMode(self, mode): + self.__updateStatsAction.setVisible(mode == UpdateMode.MANUAL) + self._updateStatsSep.setVisible(mode == UpdateMode.MANUAL) + + def getUpdateStatsAction(self): + """ + + :return: the action for the automatic mode + :rtype: QAction + """ + return self.__updateStatsAction + class StatsWidget(qt.QWidget): """ @@ -954,19 +1128,26 @@ class StatsWidget(qt.QWidget): qt.QWidget.__init__(self, parent) self.setLayout(qt.QVBoxLayout()) self.layout().setContentsMargins(0, 0, 0, 0) - self._options = _OptionsWidget(parent=self) + self._options = _OptionsWidget(parent=self, updateMode=UpdateMode.MANUAL) self.layout().addWidget(self._options) self._statsTable = StatsTable(parent=self, plot=plot) + self._statsTable.setDisplayOnlyActiveItem(self._options.isActiveItemMode()) + self._options._setUpdateMode(mode=self._statsTable.getUpdateMode()) self.setStats(stats) self.layout().addWidget(self._statsTable) + old = self._statsTable.blockSignals(True) self._options.itemSelection.triggered.connect( self._optSelectionChanged) self._options.dataRangeSelection.triggered.connect( self._optDataRangeChanged) - self._optSelectionChanged() self._optDataRangeChanged() + self._statsTable.blockSignals(old) + + self._statsTable.sigUpdateModeChanged.connect(self._options._setUpdateMode) + callback = functools.partial(self._getStatsTable()._updateAllStats, is_request=True) + self._options.sigUpdateStats.connect(callback) def _getStatsTable(self): """Returns the :class:`StatsTable` used by this widget. @@ -993,33 +1174,40 @@ class StatsWidget(qt.QWidget): # Proxy methods + @docstring(StatsTable) def setStats(self, statsHandler): return self._getStatsTable().setStats(statsHandler=statsHandler) - setStats.__doc__ = StatsTable.setStats.__doc__ - + @docstring(StatsTable) def setPlot(self, plot): self._options.setVisibleDataRangeModeEnabled( plot is None or isinstance(plot, PlotWidget)) return self._getStatsTable().setPlot(plot=plot) - setPlot.__doc__ = StatsTable.setPlot.__doc__ - + @docstring(StatsTable) def getPlot(self): return self._getStatsTable().getPlot() - getPlot.__doc__ = StatsTable.getPlot.__doc__ - + @docstring(StatsTable) def setDisplayOnlyActiveItem(self, displayOnlyActItem): + old = self._options.blockSignals(True) + # update the options + self._options.setDisplayActiveItems(displayOnlyActItem) + self._options.blockSignals(old) return self._getStatsTable().setDisplayOnlyActiveItem( displayOnlyActItem=displayOnlyActItem) - setDisplayOnlyActiveItem.__doc__ = StatsTable.setDisplayOnlyActiveItem.__doc__ - + @docstring(StatsTable) def setStatsOnVisibleData(self, b): return self._getStatsTable().setStatsOnVisibleData(b=b) - setStatsOnVisibleData.__doc__ = StatsTable.setStatsOnVisibleData.__doc__ + @docstring(StatsTable) + def getUpdateMode(self): + return self._statsTable.getUpdateMode() + + @docstring(StatsTable) + def setUpdateMode(self, mode): + self._statsTable.setUpdateMode(mode) DEFAULT_STATS = StatsHandler(( @@ -1050,13 +1238,13 @@ class BasicStatsWidget(StatsWidget): from silx.gui.plot import Plot1D from silx.gui.plot.StatsWidget import BasicStatsWidget - + plot = Plot1D() x = range(100) y = x plot.addCurve(x, y, legend='curve_0') plot.setActiveCurve('curve_0') - + widget = BasicStatsWidget(plot=plot) widget.show() """ @@ -1067,9 +1255,9 @@ class BasicStatsWidget(StatsWidget): class _BaseLineStatsWidget(_StatsWidgetBase, qt.QWidget): """ - Widget made to display stats into a QLayout with for all stat a couple - (QLabel, QLineEdit) created. - The the layout can be defined prior of adding any statistic. + Widget made to display stats into a QLayout with couple (QLabel, QLineEdit) + created for each stats. + The layout can be defined prior of adding any statistic. :param QWidget parent: Qt parent :param Union[PlotWidget,SceneWidget] plot: @@ -1081,6 +1269,9 @@ class _BaseLineStatsWidget(_StatsWidgetBase, qt.QWidget): only visible ones. """ + sigUpdateModeChanged = qt.Signal(object) + """Signal emitted when the update mode changed""" + def __init__(self, parent=None, plot=None, kind='curve', stats=None, statsOnVisibleData=False): self._item_kind = kind @@ -1141,6 +1332,8 @@ class _BaseLineStatsWidget(_StatsWidgetBase, qt.QWidget): self._updateAllStats() def _activeItemChanged(self, kind, previous, current): + if self.getUpdateMode() is UpdateMode.MANUAL: + return if kind == self._item_kind: self._updateAllStats() @@ -1148,9 +1341,9 @@ class _BaseLineStatsWidget(_StatsWidgetBase, qt.QWidget): plot = self.getPlot() if plot is not None: _items = self._plotWrapper.getSelectedItems() + def kind_filter(_item): return self._plotWrapper.getKind(_item) == self.getKind() - items = list(filter(kind_filter, _items)) assert len(items) in (0, 1) if len(items) is 1: @@ -1186,8 +1379,11 @@ class _BaseLineStatsWidget(_StatsWidgetBase, qt.QWidget): self._statQlineEdit[statName].setText(statVal) def _updateItemObserve(self, *argv): + if self.getUpdateMode() is UpdateMode.MANUAL: + return assert self._displayOnlyActItem _items = self._plotWrapper.getSelectedItems() + def kind_filter(_item): return self._plotWrapper.getKind(_item) == self.getKind() items = list(filter(kind_filter, _items)) @@ -1208,22 +1404,11 @@ class _BaseLineStatsWidget(_StatsWidgetBase, qt.QWidget): def _plotCurrentChanged(selfself, current): raise NotImplementedError('Display only the active item') + def _updateModeHasChanged(self): + self.sigUpdateModeChanged.emit(self._updateMode) -class BasicLineStatsWidget(_BaseLineStatsWidget): - """ - Widget defining a simple set of :class:`Stat` to be displayed on a - :class:`LineStatsWidget`. - - :param QWidget parent: Qt parent - :param Union[PlotWidget,SceneWidget] plot: - The plot containing items on which we want statistics. - :param str kind: the kind of plotitems we want to display - :param StatsHandler stats: - Set the statistics to be displayed and how to format them using - :param bool statsOnVisibleData: compute statistics for the whole data or - only visible ones. - """ +class _BasicLineStatsWidget(_BaseLineStatsWidget): def __init__(self, parent=None, plot=None, kind='curve', stats=DEFAULT_STATS, statsOnVisibleData=False): _BaseLineStatsWidget.__init__(self, parent=parent, kind=kind, @@ -1246,38 +1431,84 @@ class BasicLineStatsWidget(_BaseLineStatsWidget): self.layout().addWidget(widget) + def _addOptionsWidget(self, widget): + self.layout().addWidget(widget) + -class BasicGridStatsWidget(_BaseLineStatsWidget): +class BasicLineStatsWidget(qt.QWidget): """ - pymca design like widget - + Widget defining a simple set of :class:`Stat` to be displayed on a + :class:`LineStatsWidget`. + :param QWidget parent: Qt parent :param Union[PlotWidget,SceneWidget] plot: The plot containing items on which we want statistics. + :param str kind: the kind of plotitems we want to display :param StatsHandler stats: Set the statistics to be displayed and how to format them using - :param str kind: the kind of plotitems we want to display :param bool statsOnVisibleData: compute statistics for the whole data or only visible ones. - :param int statsPerLine: number of statistic to be displayed per line + """ + def __init__(self, parent=None, plot=None, kind='curve', + stats=DEFAULT_STATS, statsOnVisibleData=False): + qt.QWidget.__init__(self, parent) + self.setLayout(qt.QHBoxLayout()) + self.layout().setSpacing(0) + self.layout().setContentsMargins(0, 0, 0, 0) + self._lineStatsWidget = _BasicLineStatsWidget(parent=self, plot=plot, + kind=kind, stats=stats, + statsOnVisibleData=statsOnVisibleData) + self.layout().addWidget(self._lineStatsWidget) + + self._options = UpdateModeWidget() + self._options.setUpdateMode(self._lineStatsWidget.getUpdateMode()) + self._options.showRadioButtons(False) + self.layout().addWidget(self._options) - .. snapshotqt:: img/BasicGridStatsWidget.png - :width: 600px - :align: center + # connect Signal ? SLOT + self._lineStatsWidget.sigUpdateModeChanged.connect(self._options.setUpdateMode) + self._options.sigUpdateModeChanged.connect(self._lineStatsWidget.setUpdateMode) + self._options.sigUpdateRequested.connect(self._lineStatsWidget._updateAllStats) - from silx.gui.plot import Plot1D - from silx.gui.plot.StatsWidget import BasicGridStatsWidget - - plot = Plot1D() - x = range(100) - y = x - plot.addCurve(x, y, legend='curve_0') - plot.setActiveCurve('curve_0') - - widget = BasicGridStatsWidget(plot=plot, kind='curve') - widget.show() - """ + def showControl(self, visible): + self._options.setVisible(visible) + + # Proxy methods + + @docstring(_BasicLineStatsWidget) + def setUpdateMode(self, mode): + self._lineStatsWidget.setUpdateMode(mode=mode) + + @docstring(_BasicLineStatsWidget) + def getUpdateMode(self): + return self._lineStatsWidget.getUpdateMode() + + @docstring(_BasicLineStatsWidget) + def setPlot(self, plot): + self._lineStatsWidget.setPlot(plot=plot) + @docstring(_BasicLineStatsWidget) + def setStats(self, statsHandler): + self._lineStatsWidget.setStats(statsHandler=statsHandler) + + @docstring(_BasicLineStatsWidget) + def setKind(self, kind): + self._lineStatsWidget.setKind(kind=kind) + + @docstring(_BasicLineStatsWidget) + def getKind(self): + return self._lineStatsWidget.getKind() + + @docstring(_BasicLineStatsWidget) + def setStatsOnVisibleData(self, b): + self._lineStatsWidget.setStatsOnVisibleData(b) + + @docstring(UpdateModeWidget) + def showRadioButtons(self, show): + self._options.showRadioButtons(show=show) + + +class _BasicGridStatsWidget(_BaseLineStatsWidget): def __init__(self, parent=None, plot=None, kind='curve', stats=DEFAULT_STATS, statsOnVisibleData=False, statsPerLine=4): @@ -1294,3 +1525,94 @@ class BasicGridStatsWidget(_BaseLineStatsWidget): def _createLayout(self): return qt.QGridLayout() + + +class BasicGridStatsWidget(qt.QWidget): + """ + pymca design like widget + + :param QWidget parent: Qt parent + :param Union[PlotWidget,SceneWidget] plot: + The plot containing items on which we want statistics. + :param StatsHandler stats: + Set the statistics to be displayed and how to format them using + :param str kind: the kind of plotitems we want to display + :param bool statsOnVisibleData: compute statistics for the whole data or + only visible ones. + :param int statsPerLine: number of statistic to be displayed per line + + .. snapshotqt:: img/_BasicGridStatsWidget.png + :width: 600px + :align: center + + from silx.gui.plot import Plot1D + from silx.gui.plot.StatsWidget import _BasicGridStatsWidget + + plot = Plot1D() + x = range(100) + y = x + plot.addCurve(x, y, legend='curve_0') + plot.setActiveCurve('curve_0') + + widget = _BasicGridStatsWidget(plot=plot, kind='curve') + widget.show() + """ + + def __init__(self, parent=None, plot=None, kind='curve', + stats=DEFAULT_STATS, statsOnVisibleData=False): + qt.QWidget.__init__(self, parent) + self.setLayout(qt.QVBoxLayout()) + self.layout().setSpacing(0) + self.layout().setContentsMargins(0, 0, 0, 0) + + self._options = UpdateModeWidget() + self._options.showRadioButtons(False) + self.layout().addWidget(self._options) + + self._lineStatsWidget = _BasicGridStatsWidget(parent=self, plot=plot, + kind=kind, stats=stats, + statsOnVisibleData=statsOnVisibleData) + self.layout().addWidget(self._lineStatsWidget) + + # tune options + self._options.setUpdateMode(self._lineStatsWidget.getUpdateMode()) + + # connect Signal ? SLOT + self._lineStatsWidget.sigUpdateModeChanged.connect(self._options.setUpdateMode) + self._options.sigUpdateModeChanged.connect(self._lineStatsWidget.setUpdateMode) + self._options.sigUpdateRequested.connect(self._lineStatsWidget._updateAllStats) + + def showControl(self, visible): + self._options.setVisible(visible) + + @docstring(_BasicGridStatsWidget) + def setUpdateMode(self, mode): + self._lineStatsWidget.setUpdateMode(mode=mode) + + @docstring(_BasicGridStatsWidget) + def getUpdateMode(self): + return self._lineStatsWidget.getUpdateMode() + + @docstring(_BasicGridStatsWidget) + def setPlot(self, plot): + self._lineStatsWidget.setPlot(plot=plot) + + @docstring(_BasicGridStatsWidget) + def setStats(self, statsHandler): + self._lineStatsWidget.setStats(statsHandler=statsHandler) + + @docstring(_BasicGridStatsWidget) + def setKind(self, kind): + self._lineStatsWidget.setKind(kind=kind) + + @docstring(_BasicGridStatsWidget) + def getKind(self): + return self._lineStatsWidget.getKind() + + @docstring(_BasicGridStatsWidget) + def setStatsOnVisibleData(self, b): + self._lineStatsWidget.setStatsOnVisibleData(b) + + @docstring(UpdateModeWidget) + def showRadioButtons(self, show): + self._options.showRadioButtons(show=show) diff --git a/silx/gui/plot/_BaseMaskToolsWidget.py b/silx/gui/plot/_BaseMaskToolsWidget.py index 0d11f17..d8e9fb5 100644 --- a/silx/gui/plot/_BaseMaskToolsWidget.py +++ b/silx/gui/plot/_BaseMaskToolsWidget.py @@ -29,7 +29,7 @@ from __future__ import division __authors__ = ["T. Vincent", "P. Knobel"] __license__ = "MIT" -__date__ = "15/02/2019" +__date__ = "12/04/2019" import os import weakref @@ -519,7 +519,7 @@ class BaseMaskToolsWidget(qt.QWidget): def _initTransparencyWidget(self): """ Init the mask transparency widget """ - transparencyWidget = qt.QWidget(self) + transparencyWidget = qt.QWidget(parent=self) grid = qt.QGridLayout() grid.setContentsMargins(0, 0, 0, 0) self.transparencySlider = qt.QSlider(qt.Qt.Horizontal, parent=transparencyWidget) @@ -619,8 +619,9 @@ class BaseMaskToolsWidget(qt.QWidget): self.addAction(self.browseAction) # Draw tools - self.rectAction = qt.QAction( - icons.getQIcon('shape-rectangle'), 'Rectangle selection', None) + self.rectAction = qt.QAction(icons.getQIcon('shape-rectangle'), + 'Rectangle selection', + self) self.rectAction.setToolTip( 'Rectangle selection tool: (Un)Mask a rectangular region <b>R</b>') self.rectAction.setShortcut(qt.QKeySequence(qt.Qt.Key_R)) @@ -628,8 +629,9 @@ class BaseMaskToolsWidget(qt.QWidget): self.rectAction.triggered.connect(self._activeRectMode) self.addAction(self.rectAction) - self.ellipseAction = qt.QAction( - icons.getQIcon('shape-ellipse'), 'Circle selection', None) + self.ellipseAction = qt.QAction(icons.getQIcon('shape-ellipse'), + 'Circle selection', + self) self.ellipseAction.setToolTip( 'Rectangle selection tool: (Un)Mask a circle region <b>R</b>') self.ellipseAction.setShortcut(qt.QKeySequence(qt.Qt.Key_R)) @@ -637,8 +639,9 @@ class BaseMaskToolsWidget(qt.QWidget): self.ellipseAction.triggered.connect(self._activeEllipseMode) self.addAction(self.ellipseAction) - self.polygonAction = qt.QAction( - icons.getQIcon('shape-polygon'), 'Polygon selection', None) + self.polygonAction = qt.QAction(icons.getQIcon('shape-polygon'), + 'Polygon selection', + self) self.polygonAction.setShortcut(qt.QKeySequence(qt.Qt.Key_S)) self.polygonAction.setToolTip( 'Polygon selection tool: (Un)Mask a polygonal region <b>S</b><br>' @@ -648,8 +651,9 @@ class BaseMaskToolsWidget(qt.QWidget): self.polygonAction.triggered.connect(self._activePolygonMode) self.addAction(self.polygonAction) - self.pencilAction = qt.QAction( - icons.getQIcon('draw-pencil'), 'Pencil tool', None) + self.pencilAction = qt.QAction(icons.getQIcon('draw-pencil'), + 'Pencil tool', + self) self.pencilAction.setShortcut(qt.QKeySequence(qt.Qt.Key_P)) self.pencilAction.setToolTip( 'Pencil tool: (Un)Mask using a pencil <b>P</b>') @@ -733,21 +737,24 @@ class BaseMaskToolsWidget(qt.QWidget): def _initThresholdGroupBox(self): """Init thresholding widgets""" - self.belowThresholdAction = qt.QAction( - icons.getQIcon('plot-roi-below'), 'Mask below threshold', None) + self.belowThresholdAction = qt.QAction(icons.getQIcon('plot-roi-below'), + 'Mask below threshold', + self) self.belowThresholdAction.setToolTip( 'Mask image where values are below given threshold') self.belowThresholdAction.setCheckable(True) self.belowThresholdAction.setChecked(True) - self.betweenThresholdAction = qt.QAction( - icons.getQIcon('plot-roi-between'), 'Mask within range', None) + self.betweenThresholdAction = qt.QAction(icons.getQIcon('plot-roi-between'), + 'Mask within range', + self) self.betweenThresholdAction.setToolTip( 'Mask image where values are within given range') self.betweenThresholdAction.setCheckable(True) - self.aboveThresholdAction = qt.QAction( - icons.getQIcon('plot-roi-above'), 'Mask above threshold', None) + self.aboveThresholdAction = qt.QAction(icons.getQIcon('plot-roi-above'), + 'Mask above threshold', + self) self.aboveThresholdAction.setToolTip( 'Mask image where values are above given threshold') self.aboveThresholdAction.setCheckable(True) @@ -760,8 +767,9 @@ class BaseMaskToolsWidget(qt.QWidget): self.thresholdActionGroup.triggered.connect( self._thresholdActionGroupTriggered) - self.loadColormapRangeAction = qt.QAction( - icons.getQIcon('view-refresh'), 'Set min-max from colormap', None) + self.loadColormapRangeAction = qt.QAction(icons.getQIcon('view-refresh'), + 'Set min-max from colormap', + self) self.loadColormapRangeAction.setToolTip( 'Set min and max values from current colormap range') self.loadColormapRangeAction.setCheckable(False) @@ -774,7 +782,7 @@ class BaseMaskToolsWidget(qt.QWidget): btn.setDefaultAction(action) widgets.append(btn) - spacer = qt.QWidget() + spacer = qt.QWidget(parent=self) spacer.setSizePolicy(qt.QSizePolicy.Expanding, qt.QSizePolicy.Preferred) widgets.append(spacer) @@ -1059,7 +1067,7 @@ class BaseMaskToolsWidget(qt.QWidget): self.maxLineLabel.setVisible(False) self.minLineEdit.setVisible(True) self.maxLineEdit.setVisible(False) - self.applyMaskBtn.setText("Mask bellow") + self.applyMaskBtn.setText("Mask below") elif triggeredAction is self.betweenThresholdAction: self.minLineLabel.setVisible(True) self.maxLineLabel.setVisible(True) diff --git a/silx/gui/plot/_utils/delaunay.py b/silx/gui/plot/_utils/delaunay.py new file mode 100644 index 0000000..49ad05f --- /dev/null +++ b/silx/gui/plot/_utils/delaunay.py @@ -0,0 +1,62 @@ +# coding: utf-8 +# /*########################################################################## +# +# Copyright (c) 2019 European Synchrotron Radiation Facility +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. +# +# ###########################################################################*/ +"""Wrapper over Delaunay implementation""" + +__authors__ = ["T. Vincent"] +__license__ = "MIT" +__date__ = "02/05/2019" + + +import logging +import sys + +import numpy + + +_logger = logging.getLogger(__name__) + + +def delaunay(x, y): + """Returns Delaunay instance for x, y points + + :param numpy.ndarray x: + :param numpy.ndarray y: + :rtype: Union[None,scipy.spatial.Delaunay] + """ + # Lazy-loading of Delaunay + try: + from scipy.spatial import Delaunay as _Delaunay + except ImportError: # Fallback using local Delaunay + from silx.third_party.scipy_spatial import Delaunay as _Delaunay + + points = numpy.array((x, y)).T + try: + delaunay = _Delaunay(points) + except (RuntimeError, ValueError): + _logger.error("Delaunay tesselation failed: %s", + sys.exc_info()[1]) + delaunay = None + + return delaunay diff --git a/silx/gui/plot/actions/control.py b/silx/gui/plot/actions/control.py index 2d01ef1..ec4a3de 100644 --- a/silx/gui/plot/actions/control.py +++ b/silx/gui/plot/actions/control.py @@ -35,6 +35,7 @@ The following QAction are available: - :class:`KeepAspectRatioAction` - :class:`PanWithArrowKeysAction` - :class:`ResetZoomAction` +- :class:`ShowAxisAction` - :class:`XAxisLogarithmicAction` - :class:`XAxisAutoScaleAction` - :class:`YAxisInvertedAction` @@ -43,7 +44,6 @@ The following QAction are available: - :class:`ZoomBackAction` - :class:`ZoomInAction` - :class:`ZoomOutAction` -- :class:'ShowAxisAction' """ from __future__ import division @@ -377,11 +377,11 @@ class ColormapAction(PlotAction): # Specific init for complex images colormap = image.getColormap() - mode = image.getVisualizationMode() - if mode in (items.ImageComplexData.Mode.AMPLITUDE_PHASE, - items.ImageComplexData.Mode.LOG10_AMPLITUDE_PHASE): + mode = image.getComplexMode() + if mode in (items.ImageComplexData.ComplexMode.AMPLITUDE_PHASE, + items.ImageComplexData.ComplexMode.LOG10_AMPLITUDE_PHASE): data = image.getData( - copy=False, mode=items.ImageComplexData.Mode.PHASE) + copy=False, mode=items.ImageComplexData.ComplexMode.PHASE) else: data = image.getData(copy=False) diff --git a/silx/gui/plot/backends/BackendBase.py b/silx/gui/plot/backends/BackendBase.py index 0514c85..af37543 100644 --- a/silx/gui/plot/backends/BackendBase.py +++ b/silx/gui/plot/backends/BackendBase.py @@ -170,6 +170,23 @@ class BackendBase(object): """ return legend + def addTriangles(self, x, y, triangles, legend, + color, z, selectable, alpha): + """Add a set of triangles. + + :param numpy.ndarray x: The data corresponding to the x axis + :param numpy.ndarray y: The data corresponding to the y axis + :param numpy.ndarray triangles: The indices to make triangles + as a (Ntriangle, 3) array + :param str legend: The legend to be associated to the curve + :param numpy.ndarray color: color(s) as (npoints, 4) array + :param int z: Layer on which to draw the cuve + :param bool selectable: indicate if the curve can be selected + :param float alpha: Opacity as a float in [0., 1.] + :returns: The triangles' unique identifier used by the backend + """ + return legend + def addItem(self, x, y, legend, shape, color, fill, overlay, z, linestyle, linewidth, linebgcolor): """Add an item (i.e. a shape) to the plot. diff --git a/silx/gui/plot/backends/BackendMatplotlib.py b/silx/gui/plot/backends/BackendMatplotlib.py index 726a839..7739329 100644 --- a/silx/gui/plot/backends/BackendMatplotlib.py +++ b/silx/gui/plot/backends/BackendMatplotlib.py @@ -54,7 +54,8 @@ from matplotlib.backend_bases import MouseEvent from matplotlib.lines import Line2D from matplotlib.collections import PathCollection, LineCollection from matplotlib.ticker import Formatter, ScalarFormatter, Locator - +from matplotlib.tri import Triangulation +from matplotlib.collections import TriMesh from . import BackendBase from .._utils import FLOAT32_MINPOS @@ -359,9 +360,12 @@ class BackendMatplotlib(BackendBase.BackendBase): else: errorbarColor = color - # On Debian 7 at least, Nx1 array yerr does not seems supported + # Nx1 error array deprecated in matplotlib >=3.1 (removed in 3.3) + if (isinstance(xerror, numpy.ndarray) and xerror.ndim == 2 and + xerror.shape[1] == 1): + xerror = numpy.ravel(xerror) if (isinstance(yerror, numpy.ndarray) and yerror.ndim == 2 and - yerror.shape[1] == 1 and len(x) != 1): + yerror.shape[1] == 1): yerror = numpy.ravel(yerror) errorbars = axes.errorbar(x, y, label=legend, @@ -477,6 +481,32 @@ class BackendMatplotlib(BackendBase.BackendBase): self.ax.add_artist(image) return image + def addTriangles(self, x, y, triangles, legend, + color, z, selectable, alpha): + for parameter in (x, y, triangles, legend, color, + z, selectable, alpha): + assert parameter is not None + + # 0 enables picking on filled triangle + picker = 0 if selectable else None + + color = numpy.array(color, copy=False) + assert color.ndim == 2 and len(color) == len(x) + + if color.dtype not in [numpy.float32, numpy.float]: + color = color.astype(numpy.float32) / 255. + + collection = TriMesh( + Triangulation(x, y, triangles), + label=legend, + alpha=alpha, + picker=picker, + zorder=z) + collection.set_color(color) + self.ax.add_collection(collection) + + return collection + def addItem(self, x, y, legend, shape, color, fill, overlay, z, linestyle, linewidth, linebgcolor): if (linebgcolor is not None and @@ -1100,6 +1130,22 @@ class BackendMatplotlibQt(FigureCanvasQTAgg, BackendMatplotlib): elif label.startswith('__IMAGE__'): self._picked.append({'kind': 'image', 'legend': label[9:]}) + elif isinstance(event.artist, TriMesh): + # Convert selected triangle to data point indices + triangulation = event.artist._triangulation + indices = triangulation.get_masked_triangles()[event.ind[0]] + + # Sort picked triangle points by distance to mouse + # from furthest to closest to put closest point last + # This is to be somewhat consistent with last scatter point + # being the top one. + dists = ((triangulation.x[indices] - event.mouseevent.xdata) ** 2 + + (triangulation.y[indices] - event.mouseevent.ydata) ** 2) + indices = indices[numpy.flip(numpy.argsort(dists))] + + self._picked.append({'kind': 'curve', 'legend': label, + 'indices': indices}) + else: # it's a curve, item have no picker for now if not isinstance(event.artist, (PathCollection, Line2D)): _logger.info('Unsupported artist, ignored') diff --git a/silx/gui/plot/backends/BackendOpenGL.py b/silx/gui/plot/backends/BackendOpenGL.py index e33d03c..0420aa9 100644 --- a/silx/gui/plot/backends/BackendOpenGL.py +++ b/silx/gui/plot/backends/BackendOpenGL.py @@ -31,8 +31,9 @@ __license__ = "MIT" __date__ = "21/12/2018" from collections import OrderedDict, namedtuple -from ctypes import c_void_p import logging +import warnings +import weakref import numpy @@ -44,7 +45,7 @@ from ... import qt from ..._glutils import gl from ... import _glutils as glu from .glutils import ( - GLLines2D, + GLLines2D, GLPlotTriangles, GLPlotCurve2D, GLPlotColormap, GLPlotRGBAImage, GLPlotFrame2D, mat4Ortho, mat4Identity, LEFT, RIGHT, BOTTOM, TOP, @@ -106,7 +107,7 @@ class PlotDataContent(object): This class is only meant to work with _OpenGLPlotCanvas. """ - _PRIMITIVE_TYPES = 'curve', 'image' + _PRIMITIVE_TYPES = 'curve', 'image', 'triangles' def __init__(self): self._primitives = OrderedDict() # For images and curves @@ -124,6 +125,8 @@ class PlotDataContent(object): primitiveType = 'curve' elif isinstance(primitive, (GLPlotColormap, GLPlotRGBAImage)): primitiveType = 'image' + elif isinstance(primitive, GLPlotTriangles): + primitiveType = 'triangles' else: raise RuntimeError('Unsupported object type: %s', primitive) @@ -304,16 +307,8 @@ _texFragShd = """ } """ - # BackendOpenGL ############################################################### -_current_context = None - - -def _getContext(): - assert _current_context is not None - return _current_context - class BackendOpenGL(BackendBase.BackendBase, glu.OpenGLWidget): """OpenGL-based Plot backend. @@ -348,7 +343,7 @@ class BackendOpenGL(BackendBase.BackendBase, glu.OpenGLWidget): _baseVertShd, _baseFragShd, attrib0='position') self._progTex = glu.Program( _texVertShd, _texFragShd, attrib0='position') - self._plotFBOs = {} + self._plotFBOs = weakref.WeakKeyDictionary() self._keepDataAspectRatio = False @@ -386,6 +381,8 @@ class BackendOpenGL(BackendBase.BackendBase, glu.OpenGLWidget): return qt.QSize(8 * 80, 6 * 80) # Mimic MatplotlibBackend def mousePressEvent(self, event): + if event.button() not in self._MOUSE_BTNS: + return super(BackendOpenGL, self).mousePressEvent(event) xPixel = event.x() * self.getDevicePixelRatio() yPixel = event.y() * self.getDevicePixelRatio() btn = self._MOUSE_BTNS[event.button()] @@ -411,6 +408,8 @@ class BackendOpenGL(BackendBase.BackendBase, glu.OpenGLWidget): event.accept() def mouseReleaseEvent(self, event): + if event.button() not in self._MOUSE_BTNS: + return super(BackendOpenGL, self).mouseReleaseEvent(event) xPixel = event.x() * self.getDevicePixelRatio() yPixel = event.y() * self.getDevicePixelRatio() @@ -462,15 +461,17 @@ class BackendOpenGL(BackendBase.BackendBase, glu.OpenGLWidget): self._renderOverlayGL() def _paintFBOGL(self): - context = glu.getGLContext() + context = glu.Context.getCurrent() plotFBOTex = self._plotFBOs.get(context) if (self._plot._getDirtyPlot() or self._plotFrame.isDirty or plotFBOTex is None): - self._plotVertices = numpy.array(((-1., -1., 0., 0.), - (1., -1., 1., 0.), - (-1., 1., 0., 1.), - (1., 1., 1., 1.)), - dtype=numpy.float32) + self._plotVertices = ( + # Vertex coordinates + numpy.array(((-1., -1.), (1., -1.), (-1., 1.), (1., 1.)), + dtype=numpy.float32), + # Texture coordinates + numpy.array(((0., 0.), (1., 0.), (0., 1.), (1., 1.)), + dtype=numpy.float32)) if plotFBOTex is None or \ plotFBOTex.shape[1] != self._plotFrame.size[0] or \ plotFBOTex.shape[0] != self._plotFrame.size[1]: @@ -502,53 +503,45 @@ class BackendOpenGL(BackendBase.BackendBase, glu.OpenGLWidget): gl.glUniformMatrix4fv(self._progTex.uniforms['matrix'], 1, gl.GL_TRUE, mat4Identity().astype(numpy.float32)) - stride = self._plotVertices.shape[-1] * self._plotVertices.itemsize gl.glEnableVertexAttribArray(self._progTex.attributes['position']) gl.glVertexAttribPointer(self._progTex.attributes['position'], 2, gl.GL_FLOAT, gl.GL_FALSE, - stride, self._plotVertices) + 0, + self._plotVertices[0]) - texCoordsPtr = c_void_p(self._plotVertices.ctypes.data + - 2 * self._plotVertices.itemsize) # Better way? gl.glEnableVertexAttribArray(self._progTex.attributes['texCoords']) gl.glVertexAttribPointer(self._progTex.attributes['texCoords'], 2, gl.GL_FLOAT, gl.GL_FALSE, - stride, texCoordsPtr) + 0, + self._plotVertices[1]) with plotFBOTex.texture: - gl.glDrawArrays(gl.GL_TRIANGLE_STRIP, 0, len(self._plotVertices)) + gl.glDrawArrays(gl.GL_TRIANGLE_STRIP, 0, len(self._plotVertices[0])) self._renderMarkersGL() self._renderOverlayGL() def paintGL(self): - global _current_context - _current_context = self.context() - - glu.setGLContextGetter(_getContext) - - # Release OpenGL resources - for item in self._glGarbageCollector: - item.discard() - self._glGarbageCollector = [] - - gl.glClearColor(*self._backgroundColor) - gl.glClear(gl.GL_COLOR_BUFFER_BIT | gl.GL_STENCIL_BUFFER_BIT) + with glu.Context.current(self.context()): + # Release OpenGL resources + for item in self._glGarbageCollector: + item.discard() + self._glGarbageCollector = [] - # Check if window is large enough - plotWidth, plotHeight = self.getPlotBoundsInPixels()[2:] - if plotWidth <= 2 or plotHeight <= 2: - return + gl.glClearColor(*self._backgroundColor) + gl.glClear(gl.GL_COLOR_BUFFER_BIT | gl.GL_STENCIL_BUFFER_BIT) - # self._paintDirectGL() - self._paintFBOGL() + # Check if window is large enough + plotWidth, plotHeight = self.getPlotBoundsInPixels()[2:] + if plotWidth <= 2 or plotHeight <= 2: + return - glu.setGLContextGetter() - _current_context = None + # self._paintDirectGL() + self._paintFBOGL() def _renderMarkersGL(self): if len(self._markers) == 0: @@ -892,7 +885,10 @@ class BackendOpenGL(BackendBase.BackendBase, glu.OpenGLWidget): xErrorMinus, xErrorPlus = xerror[0], xerror[1] else: xErrorMinus, xErrorPlus = xerror, xerror - xErrorMinus = logX - numpy.log10(x - xErrorMinus) + with warnings.catch_warnings(): + warnings.simplefilter('ignore', category=RuntimeWarning) + # Ignore divide by zero, invalid value encountered in log10 + xErrorMinus = logX - numpy.log10(x - xErrorMinus) xErrorPlus = numpy.log10(x + xErrorPlus) - logX xerror = numpy.array((xErrorMinus, xErrorPlus), dtype=numpy.float32) @@ -912,7 +908,10 @@ class BackendOpenGL(BackendBase.BackendBase, glu.OpenGLWidget): yErrorMinus, yErrorPlus = yerror[0], yerror[1] else: yErrorMinus, yErrorPlus = yerror, yerror - yErrorMinus = logY - numpy.log10(y - yErrorMinus) + with warnings.catch_warnings(): + warnings.simplefilter('ignore', category=RuntimeWarning) + # Ignore divide by zero, invalid value encountered in log10 + yErrorMinus = logY - numpy.log10(y - yErrorMinus) yErrorPlus = numpy.log10(y + yErrorPlus) - logY yerror = numpy.array((yErrorMinus, yErrorPlus), dtype=numpy.float32) @@ -1043,6 +1042,25 @@ class BackendOpenGL(BackendBase.BackendBase, glu.OpenGLWidget): return legend, 'image' + def addTriangles(self, x, y, triangles, legend, + color, z, selectable, alpha): + + # Handle axes log scale: convert data + if self._plotFrame.xAxis.isLog: + x = numpy.log10(x) + if self._plotFrame.yAxis.isLog: + y = numpy.log10(y) + + triangles = GLPlotTriangles(x, y, color, triangles, alpha) + triangles.info = { + 'legend': legend, + 'zOrder': z, + 'behaviors': set(['selectable']) if selectable else set(), + } + self._plotContent.add(triangles) + + return legend, 'triangles' + def addItem(self, x, y, legend, shape, color, fill, overlay, z, linestyle, linewidth, linebgcolor): # TODO handle overlay @@ -1132,10 +1150,10 @@ class BackendOpenGL(BackendBase.BackendBase, glu.OpenGLWidget): self._glGarbageCollector.append(curve) - elif kind == 'image': - image = self._plotContent.pop('image', legend) - if image is not None: - self._glGarbageCollector.append(image) + elif kind in ('image', 'triangles'): + item = self._plotContent.pop(kind, legend) + if item is not None: + self._glGarbageCollector.append(item) elif kind == 'marker': self._markers.pop(legend, False) @@ -1188,6 +1206,60 @@ class BackendOpenGL(BackendBase.BackendBase, glu.OpenGLWidget): self._plotFrame.size[1] - self._plotFrame.margins.bottom - 1) return xPlot, yPlot + def __pickCurves(self, item, x, y): + """Perform picking on a curve item. + + :param GLPlotCurve2D item: + :param float x: X position of the mouse in widget coordinates + :param float y: Y position of the mouse in widget coordinates + :return: List of indices of picked points + :rtype: List[int] + """ + offset = self._PICK_OFFSET + if item.marker is not None: + offset = max(item.markerSize / 2., offset) + if item.lineStyle is not None: + offset = max(item.lineWidth / 2., offset) + + yAxis = item.info['yAxis'] + + inAreaPos = self._mouseInPlotArea(x - offset, y - offset) + dataPos = self.pixelToData(inAreaPos[0], inAreaPos[1], + axis=yAxis, check=True) + if dataPos is None: + return [] + xPick0, yPick0 = dataPos + + inAreaPos = self._mouseInPlotArea(x + offset, y + offset) + dataPos = self.pixelToData(inAreaPos[0], inAreaPos[1], + axis=yAxis, check=True) + if dataPos is None: + return [] + xPick1, yPick1 = dataPos + + if xPick0 < xPick1: + xPickMin, xPickMax = xPick0, xPick1 + else: + xPickMin, xPickMax = xPick1, xPick0 + + if yPick0 < yPick1: + yPickMin, yPickMax = yPick0, yPick1 + else: + yPickMin, yPickMax = yPick1, yPick0 + + # Apply log scale if axis is log + if self._plotFrame.xAxis.isLog: + xPickMin = numpy.log10(xPickMin) + xPickMax = numpy.log10(xPickMax) + + if (yAxis == 'left' and self._plotFrame.yAxis.isLog) or ( + yAxis == 'right' and self._plotFrame.y2Axis.isLog): + yPickMin = numpy.log10(yPickMin) + yPickMax = numpy.log10(yPickMax) + + return item.pick(xPickMin, yPickMin, + xPickMax, yPickMax) + def pickItems(self, x, y, kinds): picked = [] @@ -1236,56 +1308,20 @@ class BackendOpenGL(BackendBase.BackendBase, glu.OpenGLWidget): picked.append(dict(kind='image', legend=item.info['legend'])) - elif 'curve' in kinds and isinstance(item, GLPlotCurve2D): - offset = self._PICK_OFFSET - if item.marker is not None: - offset = max(item.markerSize / 2., offset) - if item.lineStyle is not None: - offset = max(item.lineWidth / 2., offset) - - yAxis = item.info['yAxis'] - - inAreaPos = self._mouseInPlotArea(x - offset, y - offset) - dataPos = self.pixelToData(inAreaPos[0], inAreaPos[1], - axis=yAxis, check=True) - if dataPos is None: - continue - xPick0, yPick0 = dataPos - - inAreaPos = self._mouseInPlotArea(x + offset, y + offset) - dataPos = self.pixelToData(inAreaPos[0], inAreaPos[1], - axis=yAxis, check=True) - if dataPos is None: - continue - xPick1, yPick1 = dataPos - - if xPick0 < xPick1: - xPickMin, xPickMax = xPick0, xPick1 - else: - xPickMin, xPickMax = xPick1, xPick0 - - if yPick0 < yPick1: - yPickMin, yPickMax = yPick0, yPick1 - else: - yPickMin, yPickMax = yPick1, yPick0 - - # Apply log scale if axis is log - if self._plotFrame.xAxis.isLog: - xPickMin = numpy.log10(xPickMin) - xPickMax = numpy.log10(xPickMax) - - if (yAxis == 'left' and self._plotFrame.yAxis.isLog) or ( - yAxis == 'right' and self._plotFrame.y2Axis.isLog): - yPickMin = numpy.log10(yPickMin) - yPickMax = numpy.log10(yPickMax) - - pickedIndices = item.pick(xPickMin, yPickMin, - xPickMax, yPickMax) - if pickedIndices: - picked.append(dict(kind='curve', - legend=item.info['legend'], - indices=pickedIndices)) - + elif 'curve' in kinds: + if isinstance(item, GLPlotCurve2D): + pickedIndices = self.__pickCurves(item, x, y) + if pickedIndices: + picked.append(dict(kind='curve', + legend=item.info['legend'], + indices=pickedIndices)) + + elif isinstance(item, GLPlotTriangles): + pickedIndices = item.pick(*dataPos) + if pickedIndices: + picked.append(dict(kind='curve', + legend=item.info['legend'], + indices=pickedIndices)) return picked # Update curve diff --git a/silx/gui/plot/backends/glutils/GLPlotTriangles.py b/silx/gui/plot/backends/glutils/GLPlotTriangles.py new file mode 100644 index 0000000..c756749 --- /dev/null +++ b/silx/gui/plot/backends/glutils/GLPlotTriangles.py @@ -0,0 +1,193 @@ +# coding: utf-8 +# /*########################################################################## +# +# Copyright (c) 2019 European Synchrotron Radiation Facility +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. +# +# ############################################################################*/ +""" +This module provides a class to render a set of 2D triangles +""" + +__authors__ = ["T. Vincent"] +__license__ = "MIT" +__date__ = "03/04/2017" + + +import ctypes + +import numpy + +from .....math.combo import min_max +from .... import _glutils as glutils +from ...._glutils import gl + + +class GLPlotTriangles(object): + """Handle rendering of a set of colored triangles""" + + _PROGRAM = glutils.Program( + vertexShader=""" + #version 120 + + uniform mat4 matrix; + attribute float xPos; + attribute float yPos; + attribute vec4 color; + + varying vec4 vColor; + + void main(void) { + gl_Position = matrix * vec4(xPos, yPos, 0.0, 1.0); + vColor = color; + } + """, + fragmentShader=""" + #version 120 + + uniform float alpha; + varying vec4 vColor; + + void main(void) { + gl_FragColor = vColor; + gl_FragColor.a *= alpha; + } + """, + attrib0='xPos') + + def __init__(self, x, y, color, triangles, alpha=1.): + """ + + :param numpy.ndarray x: X coordinates of triangle corners + :param numpy.ndarray y: Y coordinates of triangle corners + :param numpy.ndarray color: color for each point + :param numpy.ndarray triangles: (N, 3) array of indices of triangles + :param float alpha: Opacity in [0, 1] + """ + # Check and convert input data + x = numpy.ravel(numpy.array(x, dtype=numpy.float32)) + y = numpy.ravel(numpy.array(y, dtype=numpy.float32)) + color = numpy.array(color, copy=False) + # Cast to uint32 + triangles = numpy.array(triangles, copy=False, dtype=numpy.uint32) + + assert x.size == y.size + assert x.size == len(color) + assert color.ndim == 2 and color.shape[1] in (3, 4) + if numpy.issubdtype(color.dtype, numpy.floating): + color = numpy.array(color, dtype=numpy.float32, copy=False) + elif numpy.issubdtype(color.dtype, numpy.integer): + color = numpy.array(color, dtype=numpy.uint8, copy=False) + else: + raise ValueError('Unsupported color type') + assert triangles.ndim == 2 and triangles.shape[1] == 3 + + self.__x_y_color = x, y, color + self.xMin, self.xMax = min_max(x, finite=True) + self.yMin, self.yMax = min_max(y, finite=True) + self.__triangles = triangles + self.__alpha = numpy.clip(float(alpha), 0., 1.) + self.__vbos = None + self.__indicesVbo = None + self.__picking_triangles = None + + def pick(self, x, y): + """Perform picking + + :param float x: X coordinates in plot data frame + :param float y: Y coordinates in plot data frame + :return: List of picked data point indices + :rtype: numpy.ndarray + """ + if (x < self.xMin or x > self.xMax or + y < self.yMin or y > self.yMax): + return () + xPts, yPts = self.__x_y_color[:2] + if self.__picking_triangles is None: + self.__picking_triangles = numpy.zeros( + self.__triangles.shape + (3,), dtype=numpy.float32) + self.__picking_triangles[:, :, 0] = xPts[self.__triangles] + self.__picking_triangles[:, :, 1] = yPts[self.__triangles] + + segment = numpy.array(((x, y, -1), (x, y, 1)), dtype=numpy.float32) + # Picked triangle indices + indices = glutils.segmentTrianglesIntersection( + segment, self.__picking_triangles)[0] + # Point indices + indices = numpy.unique(numpy.ravel(self.__triangles[indices])) + + # Sorted from furthest to closest point + dists = (xPts[indices] - x) ** 2 + (yPts[indices] - y) ** 2 + indices = indices[numpy.flip(numpy.argsort(dists))] + + return tuple(indices) + + def discard(self): + """Release resources on the GPU""" + if self.__vbos is not None: + self.__vbos[0].vbo.discard() + self.__vbos = None + self.__indicesVbo.discard() + self.__indicesVbo = None + + def prepare(self): + """Allocate resources on the GPU""" + if self.__vbos is None: + self.__vbos = glutils.vertexBuffer(self.__x_y_color) + # Normalization is need for color + self.__vbos[-1].normalization = True + + if self.__indicesVbo is None: + self.__indicesVbo = glutils.VertexBuffer( + numpy.ravel(self.__triangles), + usage=gl.GL_STATIC_DRAW, + target=gl.GL_ELEMENT_ARRAY_BUFFER) + + def render(self, matrix, isXLog, isYLog): + """Perform rendering + + :param numpy.ndarray matrix: 4x4 transform matrix to use + :param bool isXLog: + :param bool isYLog: + """ + self.prepare() + + if self.__vbos is None or self.__indicesVbo is None: + return # Nothing to display + + self._PROGRAM.use() + + gl.glUniformMatrix4fv(self._PROGRAM.uniforms['matrix'], + 1, + gl.GL_TRUE, + matrix.astype(numpy.float32)) + + gl.glUniform1f(self._PROGRAM.uniforms['alpha'], self.__alpha) + + for index, name in enumerate(('xPos', 'yPos', 'color')): + attr = self._PROGRAM.attributes[name] + gl.glEnableVertexAttribArray(attr) + self.__vbos[index].setVertexAttrib(attr) + + with self.__indicesVbo: + gl.glDrawElements(gl.GL_TRIANGLES, + self.__triangles.size, + glutils.numpyToGLType(self.__triangles.dtype), + ctypes.c_void_p(0)) diff --git a/silx/gui/plot/backends/glutils/GLText.py b/silx/gui/plot/backends/glutils/GLText.py index 3d262bc..725c12c 100644 --- a/silx/gui/plot/backends/glutils/GLText.py +++ b/silx/gui/plot/backends/glutils/GLText.py @@ -1,7 +1,7 @@ # coding: utf-8 # /*########################################################################## # -# Copyright (c) 2014-2018 European Synchrotron Radiation Facility +# Copyright (c) 2014-2019 European Synchrotron Radiation Facility # # Permission is hereby granted, free of charge, to any person obtaining a copy # of this software and associated documentation files (the "Software"), to deal @@ -33,9 +33,11 @@ __date__ = "03/04/2017" from collections import OrderedDict +import weakref + import numpy -from ...._glutils import font, gl, getGLContext, Program, Texture +from ...._glutils import font, gl, Context, Program, Texture from .GLSupport import mat4Translate @@ -128,7 +130,7 @@ class Text2D(object): attrib0='position') # Discard texture objects when removed from the cache - _textures = _Cache(callback=lambda key, value: value[0].discard()) + _textures = weakref.WeakKeyDictionary() """Cache already created textures""" _sizes = _Cache() @@ -159,15 +161,20 @@ class Text2D(object): self._rotate = numpy.radians(rotate) def _getTexture(self, text): - key = getGLContext(), text - - if key not in self._textures: + # Retrieve/initialize texture cache for current context + context = Context.getCurrent() + if context not in self._textures: + self._textures[context] = _Cache( + callback=lambda key, value: value[0].discard()) + textures = self._textures[context] + + if text not in textures: image, offset = font.rasterText(text, font.getDefaultFontFamily()) if text not in self._sizes: self._sizes[text] = image.shape[1], image.shape[0] - self._textures[key] = ( + textures[text] = ( Texture(gl.GL_RED, data=image, minFilter=gl.GL_NEAREST, @@ -176,7 +183,7 @@ class Text2D(object): gl.GL_CLAMP_TO_EDGE)), offset) - return self._textures[key] + return textures[text] @property def text(self): diff --git a/silx/gui/plot/backends/glutils/GLTexture.py b/silx/gui/plot/backends/glutils/GLTexture.py index 25dd9f1..118a36f 100644 --- a/silx/gui/plot/backends/glutils/GLTexture.py +++ b/silx/gui/plot/backends/glutils/GLTexture.py @@ -1,7 +1,7 @@ # coding: utf-8 # /*########################################################################## # -# Copyright (c) 2014-2017 European Synchrotron Radiation Facility +# Copyright (c) 2014-2019 European Synchrotron Radiation Facility # # Permission is hereby granted, free of charge, to any person obtaining a copy # of this software and associated documentation files (the "Software"), to deal @@ -163,7 +163,6 @@ class Image(object): data[yOrig:yOrig+hData, xOrig:xOrig+wData], format_, - shape=(hData, wData), texUnit=texUnit, minFilter=self._MIN_FILTER, magFilter=self._MAG_FILTER, diff --git a/silx/gui/plot/backends/glutils/__init__.py b/silx/gui/plot/backends/glutils/__init__.py index 771de39..d58c084 100644 --- a/silx/gui/plot/backends/glutils/__init__.py +++ b/silx/gui/plot/backends/glutils/__init__.py @@ -1,7 +1,7 @@ # coding: utf-8 # /*########################################################################## # -# Copyright (c) 2014-2017 European Synchrotron Radiation Facility +# Copyright (c) 2014-2019 European Synchrotron Radiation Facility # # Permission is hereby granted, free of charge, to any person obtaining a copy # of this software and associated documentation files (the "Software"), to deal @@ -39,6 +39,7 @@ _logger = logging.getLogger(__name__) from .GLPlotCurve import * # noqa from .GLPlotFrame import * # noqa from .GLPlotImage import * # noqa +from .GLPlotTriangles import GLPlotTriangles # noqa from .GLSupport import * # noqa from .GLText import * # noqa from .GLTexture import * # noqa diff --git a/silx/gui/plot/items/__init__.py b/silx/gui/plot/items/__init__.py index f829f78..f3a36db 100644 --- a/silx/gui/plot/items/__init__.py +++ b/silx/gui/plot/items/__init__.py @@ -1,7 +1,7 @@ # coding: utf-8 # /*########################################################################## # -# Copyright (c) 2017-2018 European Synchrotron Radiation Facility +# Copyright (c) 2017-2019 European Synchrotron Radiation Facility # # Permission is hereby granted, free of charge, to any person obtaining a copy # of this software and associated documentation files (the "Software"), to deal @@ -34,14 +34,15 @@ __date__ = "22/06/2017" from .core import (Item, LabelsMixIn, DraggableMixIn, ColormapMixIn, # noqa SymbolMixIn, ColorMixIn, YAxisMixIn, FillMixIn, # noqa - AlphaMixIn, LineMixIn, ItemChangedType) # noqa + AlphaMixIn, LineMixIn, ScatterVisualizationMixIn, # noqa + ComplexMixIn, ItemChangedType, PointsBase) # noqa from .complex import ImageComplexData # noqa from .curve import Curve, CurveStyle # noqa from .histogram import Histogram # noqa from .image import ImageBase, ImageData, ImageRgba, MaskImageData # noqa from .shape import Shape # noqa from .scatter import Scatter # noqa -from .marker import Marker, XMarker, YMarker # noqa +from .marker import MarkerBase, Marker, XMarker, YMarker # noqa from .axis import Axis, XAxis, YAxis, YRightAxis DATA_ITEMS = ImageComplexData, Curve, Histogram, ImageBase, Scatter diff --git a/silx/gui/plot/items/complex.py b/silx/gui/plot/items/complex.py index 7fffd77..3869a05 100644 --- a/silx/gui/plot/items/complex.py +++ b/silx/gui/plot/items/complex.py @@ -33,12 +33,13 @@ __date__ = "14/06/2018" import logging -import enum import numpy +from ....utils.proxy import docstring +from ....utils.deprecation import deprecated from ...colors import Colormap -from .core import ColormapMixIn, ItemChangedType +from .core import ColormapMixIn, ComplexMixIn, ItemChangedType from .image import ImageBase @@ -105,29 +106,19 @@ def _complex2rgbalin(phaseColormap, data, gamma=1.0, smax=None): return rgba -class ImageComplexData(ImageBase, ColormapMixIn): +class ImageComplexData(ImageBase, ColormapMixIn, ComplexMixIn): """Specific plot item to force colormap when using complex colormap. This is returning the specific colormap when displaying colored phase + amplitude. """ - class Mode(enum.Enum): - """Identify available display mode for complex""" - ABSOLUTE = 'absolute' - PHASE = 'phase' - REAL = 'real' - IMAGINARY = 'imaginary' - AMPLITUDE_PHASE = 'amplitude_phase' - LOG10_AMPLITUDE_PHASE = 'log10_amplitude_phase' - SQUARE_AMPLITUDE = 'square_amplitude' - def __init__(self): ImageBase.__init__(self) ColormapMixIn.__init__(self) + ComplexMixIn.__init__(self) self._data = numpy.zeros((0, 0), dtype=numpy.complex64) self._dataByModesCache = {} - self._mode = self.Mode.ABSOLUTE self._amplitudeRangeInfo = None, 2 # Use default from ColormapMixIn @@ -139,13 +130,13 @@ class ImageComplexData(ImageBase, ColormapMixIn): vmax=numpy.pi) self._colormaps = { # Default colormaps for all modes - self.Mode.ABSOLUTE: colormap, - self.Mode.PHASE: phaseColormap, - self.Mode.REAL: colormap, - self.Mode.IMAGINARY: colormap, - self.Mode.AMPLITUDE_PHASE: phaseColormap, - self.Mode.LOG10_AMPLITUDE_PHASE: phaseColormap, - self.Mode.SQUARE_AMPLITUDE: colormap, + self.ComplexMode.ABSOLUTE: colormap, + self.ComplexMode.PHASE: phaseColormap, + self.ComplexMode.REAL: colormap, + self.ComplexMode.IMAGINARY: colormap, + self.ComplexMode.AMPLITUDE_PHASE: phaseColormap, + self.ComplexMode.LOG10_AMPLITUDE_PHASE: phaseColormap, + self.ComplexMode.SQUARE_AMPLITUDE: colormap, } def _addBackendRenderer(self, backend): @@ -156,9 +147,9 @@ class ImageComplexData(ImageBase, ColormapMixIn): # Do not render with non linear scales return None - mode = self.getVisualizationMode() - if mode in (self.Mode.AMPLITUDE_PHASE, - self.Mode.LOG10_AMPLITUDE_PHASE): + mode = self.getComplexMode() + if mode in (self.ComplexMode.AMPLITUDE_PHASE, + self.ComplexMode.LOG10_AMPLITUDE_PHASE): # For those modes, compute RGBA image here colormap = None data = self.getRgbaImageData(copy=False) @@ -179,33 +170,21 @@ class ImageComplexData(ImageBase, ColormapMixIn): colormap=colormap, alpha=self.getAlpha()) - def setVisualizationMode(self, mode): - """Set the visualization mode to use. - - :param Mode mode: - """ - assert isinstance(mode, self.Mode) - assert mode in self._colormaps - - if mode != self._mode: - self._mode = mode - + @docstring(ComplexMixIn) + def setComplexMode(self, mode): + changed = super(ImageComplexData, self).setComplexMode(mode) + if changed: + # Backward compatibility self._updated(ItemChangedType.VISUALIZATION_MODE) # Send data updated as value returned by getData has changed self._updated(ItemChangedType.DATA) # Update ColormapMixIn colormap - colormap = self._colormaps[self._mode] + colormap = self._colormaps[self.getComplexMode()] if colormap is not super(ImageComplexData, self).getColormap(): super(ImageComplexData, self).setColormap(colormap) - - def getVisualizationMode(self): - """Returns the visualization mode in use. - - :rtype: Mode - """ - return self._mode + return changed def _setAmplitudeRangeInfo(self, max_=None, delta=2): """Set the amplitude range to display for 'log10_amplitude_phase' mode. @@ -228,15 +207,17 @@ class ImageComplexData(ImageBase, ColormapMixIn): """Set the colormap for this specific mode. :param ~silx.gui.colors.Colormap colormap: The colormap - :param Mode mode: + :param Union[ComplexMode,str] mode: If specified, set the colormap of this specific mode. Default: current mode. """ if mode is None: - mode = self.getVisualizationMode() + mode = self.getComplexMode() + else: + mode = self.ComplexMode.from_value(mode) self._colormaps[mode] = colormap - if mode is self.getVisualizationMode(): + if mode is self.getComplexMode(): super(ImageComplexData, self).setColormap(colormap) else: self._updated(ItemChangedType.COLORMAP) @@ -244,13 +225,15 @@ class ImageComplexData(ImageBase, ColormapMixIn): def getColormap(self, mode=None): """Get the colormap for the (current) mode. - :param Mode mode: + :param Union[ComplexMode,str] mode: If specified, get the colormap of this specific mode. Default: current mode. :rtype: ~silx.gui.colors.Colormap """ if mode is None: - mode = self.getVisualizationMode() + mode = self.getComplexMode() + else: + mode = self.ComplexMode.from_value(mode) return self._colormaps[mode] @@ -296,28 +279,30 @@ class ImageComplexData(ImageBase, ColormapMixIn): :param bool copy: True (Default) to get a copy, False to use internal representation (do not modify!) - :param Mode mode: + :param Union[ComplexMode,str] mode: If specified, get data corresponding to the mode. Default: Current mode. :rtype: numpy.ndarray of float """ if mode is None: - mode = self.getVisualizationMode() + mode = self.getComplexMode() + else: + mode = self.ComplexMode.from_value(mode) if mode not in self._dataByModesCache: # Compute data for mode and store it in cache complexData = self.getComplexData(copy=False) - if mode is self.Mode.PHASE: + if mode is self.ComplexMode.PHASE: data = numpy.angle(complexData) - elif mode is self.Mode.REAL: + elif mode is self.ComplexMode.REAL: data = numpy.real(complexData) - elif mode is self.Mode.IMAGINARY: + elif mode is self.ComplexMode.IMAGINARY: data = numpy.imag(complexData) - elif mode in (self.Mode.ABSOLUTE, - self.Mode.LOG10_AMPLITUDE_PHASE, - self.Mode.AMPLITUDE_PHASE): + elif mode in (self.ComplexMode.ABSOLUTE, + self.ComplexMode.LOG10_AMPLITUDE_PHASE, + self.ComplexMode.AMPLITUDE_PHASE): data = numpy.absolute(complexData) - elif mode is self.Mode.SQUARE_AMPLITUDE: + elif mode is self.ComplexMode.SQUARE_AMPLITUDE: data = numpy.absolute(complexData) ** 2 else: _logger.error( @@ -333,22 +318,36 @@ class ImageComplexData(ImageBase, ColormapMixIn): """Get the displayed RGB(A) image for (current) mode :param bool copy: Ignored for this class - :param Mode mode: + :param Union[ComplexMode,str] mode: If specified, get data corresponding to the mode. Default: Current mode. :rtype: numpy.ndarray of uint8 of shape (height, width, 4) """ if mode is None: - mode = self.getVisualizationMode() + mode = self.getComplexMode() + else: + mode = self.ComplexMode.from_value(mode) colormap = self.getColormap(mode=mode) - if mode is self.Mode.AMPLITUDE_PHASE: + if mode is self.ComplexMode.AMPLITUDE_PHASE: data = self.getComplexData(copy=False) return _complex2rgbalin(colormap, data) - elif mode is self.Mode.LOG10_AMPLITUDE_PHASE: + elif mode is self.ComplexMode.LOG10_AMPLITUDE_PHASE: data = self.getComplexData(copy=False) max_, delta = self._getAmplitudeRangeInfo() return _complex2rgbalog(colormap, data, dlogs=delta, smax=max_) else: data = self.getData(copy=False, mode=mode) return colormap.applyToData(data) + + # Backward compatibility + + Mode = ComplexMixIn.ComplexMode + + @deprecated(replacement='setComplexMode', since_version='0.11.0') + def setVisualizationMode(self, mode): + return self.setComplexMode(mode) + + @deprecated(replacement='getComplexMode', since_version='0.11.0') + def getVisualizationMode(self): + return self.getComplexMode() diff --git a/silx/gui/plot/items/core.py b/silx/gui/plot/items/core.py index bf3b719..e7342b0 100644 --- a/silx/gui/plot/items/core.py +++ b/silx/gui/plot/items/core.py @@ -30,6 +30,10 @@ __license__ = "MIT" __date__ = "29/01/2019" import collections +try: + from collections import abc +except ImportError: # Python2 support + import collections as abc from copy import deepcopy import logging import enum @@ -39,6 +43,7 @@ import weakref import numpy import six +from ....utils.enum import Enum as _Enum from ... import qt from ... import colors from ...colors import Colormap @@ -128,6 +133,9 @@ class ItemChangedType(enum.Enum): VISUALIZATION_MODE = 'visualizationModeChanged' """Item's visualization mode changed flag.""" + COMPLEX_MODE = 'complexModeChanged' + """Item's complex data visualization mode changed flag.""" + class Item(qt.QObject): """Description of an item of the plot""" @@ -404,6 +412,14 @@ class DraggableMixIn(ItemMixInBase): """ self._draggable = bool(draggable) + def drag(self, from_, to): + """Perform a drag of the item. + + :param List[float] from_: (x, y) previous position in data coordinates + :param List[float] to: (x, y) current position in data coordinates + """ + raise NotImplementedError("Must be implemented in subclass") + class ColormapMixIn(ItemMixInBase): """Mix-in class for items with colormap""" @@ -757,7 +773,164 @@ class AlphaMixIn(ItemMixInBase): self._updated(ItemChangedType.ALPHA) -class Points(Item, SymbolMixIn, AlphaMixIn): +class ComplexMixIn(ItemMixInBase): + """Mix-in class for complex data mode""" + + _SUPPORTED_COMPLEX_MODES = None + """Override to only support a subset of all ComplexMode""" + + class ComplexMode(_Enum): + """Identify available display mode for complex""" + ABSOLUTE = 'amplitude' + PHASE = 'phase' + REAL = 'real' + IMAGINARY = 'imaginary' + AMPLITUDE_PHASE = 'amplitude_phase' + LOG10_AMPLITUDE_PHASE = 'log10_amplitude_phase' + SQUARE_AMPLITUDE = 'square_amplitude' + + def __init__(self): + self.__complex_mode = self.ComplexMode.ABSOLUTE + + def getComplexMode(self): + """Returns the current complex visualization mode. + + :rtype: ComplexMode + """ + return self.__complex_mode + + def setComplexMode(self, mode): + """Set the complex visualization mode. + + :param ComplexMode mode: The visualization mode in: + 'real', 'imaginary', 'phase', 'amplitude' + :return: True if value was set, False if is was already set + :rtype: bool + """ + mode = self.ComplexMode.from_value(mode) + assert mode in self.supportedComplexModes() + + if mode != self.__complex_mode: + self.__complex_mode = mode + self._updated(ItemChangedType.COMPLEX_MODE) + return True + else: + return False + + def _convertComplexData(self, data, mode=None): + """Convert complex data to the specific mode. + + :param Union[ComplexMode,None] mode: + The kind of value to compute. + If None (the default), the current complex mode is used. + :return: The converted dataset + :rtype: Union[numpy.ndarray[float],None] + """ + if data is None: + return None + + if mode is None: + mode = self.getComplexMode() + + if mode is self.ComplexMode.REAL: + return numpy.real(data) + elif mode is self.ComplexMode.IMAGINARY: + return numpy.imag(data) + elif mode is self.ComplexMode.ABSOLUTE: + return numpy.absolute(data) + elif mode is self.ComplexMode.PHASE: + return numpy.angle(data) + elif mode is self.ComplexMode.SQUARE_AMPLITUDE: + return numpy.absolute(data) ** 2 + else: + raise ValueError('Unsupported conversion mode: %s', str(mode)) + + @classmethod + def supportedComplexModes(cls): + """Returns the list of supported complex visualization modes. + + See :class:`ComplexMode` and :meth:`setComplexMode`. + + :rtype: List[ComplexMode] + """ + if cls._SUPPORTED_COMPLEX_MODES is None: + return cls.ComplexMode.members() + else: + return cls._SUPPORTED_COMPLEX_MODES + + +class ScatterVisualizationMixIn(ItemMixInBase): + """Mix-in class for scatter plot visualization modes""" + + _SUPPORTED_SCATTER_VISUALIZATION = None + """Allows to override supported Visualizations""" + + @enum.unique + class Visualization(_Enum): + """Different modes of scatter plot visualizations""" + + POINTS = 'points' + """Display scatter plot as a point cloud""" + + LINES = 'lines' + """Display scatter plot as a wireframe. + + This is based on Delaunay triangulation + """ + + SOLID = 'solid' + """Display scatter plot as a set of filled triangles. + + This is based on Delaunay triangulation + """ + + def __init__(self): + self.__visualization = self.Visualization.POINTS + + @classmethod + def supportedVisualizations(cls): + """Returns the list of supported scatter visualization modes. + + See :meth:`setVisualization` + + :rtype: List[Visualization] + """ + if cls._SUPPORTED_SCATTER_VISUALIZATION is None: + return cls.Visualization.members() + else: + return cls._SUPPORTED_SCATTER_VISUALIZATION + + def setVisualization(self, mode): + """Set the scatter plot visualization mode to use. + + See :class:`Visualization` for all possible values, + and :meth:`supportedVisualizations` for supported ones. + + :param Union[str,Visualization] mode: + The visualization mode to use. + :return: True if value was set, False if is was already set + :rtype: bool + """ + mode = self.Visualization.from_value(mode) + assert mode in self.supportedVisualizations() + + if mode != self.__visualization: + self.__visualization = mode + + self._updated(ItemChangedType.VISUALIZATION_MODE) + return True + else: + return False + + def getVisualization(self): + """Returns the scatter plot visualization mode in use. + + :rtype: Visualization + """ + return self.__visualization + + +class PointsBase(Item, SymbolMixIn, AlphaMixIn): """Base class for :class:`Curve` and :class:`Scatter`""" # note: _logFilterData must be overloaded if you overload # getData to change its signature @@ -906,8 +1079,7 @@ class Points(Item, SymbolMixIn, AlphaMixIn): if (xPositive, yPositive) not in self._boundsCache: # use the getData class method because instance method can be # overloaded to return additional arrays - data = Points.getData(self, copy=False, - displayed=True) + data = PointsBase.getData(self, copy=False, displayed=True) if len(data) == 5: # hack to avoid duplicating caching mechanism in Scatter # (happens when cached data is used, caching done using @@ -916,12 +1088,15 @@ class Points(Item, SymbolMixIn, AlphaMixIn): else: x, y, _xerror, _yerror = data - self._boundsCache[(xPositive, yPositive)] = ( - numpy.nanmin(x), - numpy.nanmax(x), - numpy.nanmin(y), - numpy.nanmax(y) - ) + with warnings.catch_warnings(): + warnings.simplefilter('ignore', category=RuntimeWarning) + # Ignore All-NaN slice encountered + self._boundsCache[(xPositive, yPositive)] = ( + numpy.nanmin(x), + numpy.nanmax(x), + numpy.nanmin(y), + numpy.nanmax(y) + ) return self._boundsCache[(xPositive, yPositive)] def _getCachedData(self): @@ -1026,12 +1201,12 @@ class Points(Item, SymbolMixIn, AlphaMixIn): assert x.ndim == y.ndim == 1 if xerror is not None: - if isinstance(xerror, collections.Iterable): + if isinstance(xerror, abc.Iterable): xerror = numpy.array(xerror, copy=copy) else: xerror = float(xerror) if yerror is not None: - if isinstance(yerror, collections.Iterable): + if isinstance(yerror, abc.Iterable): yerror = numpy.array(yerror, copy=copy) else: yerror = float(yerror) diff --git a/silx/gui/plot/items/curve.py b/silx/gui/plot/items/curve.py index 79def55..439af33 100644 --- a/silx/gui/plot/items/curve.py +++ b/silx/gui/plot/items/curve.py @@ -37,7 +37,7 @@ import six from ....utils.deprecation import deprecated from ... import colors -from .core import (Points, LabelsMixIn, ColorMixIn, YAxisMixIn, +from .core import (PointsBase, LabelsMixIn, ColorMixIn, YAxisMixIn, FillMixIn, LineMixIn, SymbolMixIn, ItemChangedType) @@ -151,7 +151,7 @@ class CurveStyle(object): return False -class Curve(Points, ColorMixIn, YAxisMixIn, FillMixIn, LabelsMixIn, LineMixIn): +class Curve(PointsBase, ColorMixIn, YAxisMixIn, FillMixIn, LabelsMixIn, LineMixIn): """Description of a curve""" _DEFAULT_Z_LAYER = 1 @@ -170,7 +170,7 @@ class Curve(Points, ColorMixIn, YAxisMixIn, FillMixIn, LabelsMixIn, LineMixIn): """Default highlight style of the item""" def __init__(self): - Points.__init__(self) + PointsBase.__init__(self) ColorMixIn.__init__(self) YAxisMixIn.__init__(self) FillMixIn.__init__(self) diff --git a/silx/gui/plot/items/image.py b/silx/gui/plot/items/image.py index 99a916a..d74f4d3 100644 --- a/silx/gui/plot/items/image.py +++ b/silx/gui/plot/items/image.py @@ -1,7 +1,7 @@ # coding: utf-8 # /*########################################################################## # -# Copyright (c) 2017 European Synchrotron Radiation Facility +# Copyright (c) 2017-2019 European Synchrotron Radiation Facility # # Permission is hereby granted, free of charge, to any person obtaining a copy # of this software and associated documentation files (the "Software"), to deal @@ -31,11 +31,15 @@ __license__ = "MIT" __date__ = "20/10/2017" -from collections import Sequence +try: + from collections import abc +except ImportError: # Python2 support + import collections as abc import logging import numpy +from ....utils.proxy import docstring from .core import (Item, LabelsMixIn, DraggableMixIn, ColormapMixIn, AlphaMixIn, ItemChangedType) @@ -170,6 +174,12 @@ class ImageBase(Item, LabelsMixIn, DraggableMixIn, AlphaMixIn): else: return xmin, xmax, ymin, ymax + @docstring(DraggableMixIn) + def drag(self, from_, to): + origin = self.getOrigin() + self.setOrigin((origin[0] + to[0] - from_[0], + origin[1] + to[1] - from_[1])) + def getData(self, copy=True): """Returns the image data @@ -199,7 +209,7 @@ class ImageBase(Item, LabelsMixIn, DraggableMixIn, AlphaMixIn): :param origin: (ox, oy) Offset from origin :type origin: float or 2-tuple of float """ - if isinstance(origin, Sequence): + if isinstance(origin, abc.Sequence): origin = float(origin[0]), float(origin[1]) else: # single value origin origin = float(origin), float(origin) @@ -227,7 +237,7 @@ class ImageBase(Item, LabelsMixIn, DraggableMixIn, AlphaMixIn): :param scale: (sx, sy) Scale of the image :type scale: float or 2-tuple of float """ - if isinstance(scale, Sequence): + if isinstance(scale, abc.Sequence): scale = float(scale[0]), float(scale[1]) else: # single value scale scale = float(scale), float(scale) @@ -252,6 +262,7 @@ class ImageData(ImageBase, ColormapMixIn): ColormapMixIn.__init__(self) self._data = numpy.zeros((0, 0), dtype=numpy.float32) self._alternativeImage = None + self.__alpha = None def _addBackendRenderer(self, backend): """Update backend renderer""" @@ -261,8 +272,9 @@ class ImageData(ImageBase, ColormapMixIn): # Do not render with non linear scales return None - if self.getAlternativeImageData(copy=False) is not None: - dataToUse = self.getAlternativeImageData(copy=False) + if (self.getAlternativeImageData(copy=False) is not None or + self.getAlphaData(copy=False) is not None): + dataToUse = self.getRgbaImageData(copy=False) else: dataToUse = self.getData(copy=False) @@ -293,37 +305,56 @@ class ImageData(ImageBase, ColormapMixIn): def getRgbaImageData(self, copy=True): """Get the displayed RGB(A) image - :returns: numpy.ndarray of uint8 of shape (height, width, 4) + :returns: Array of uint8 of shape (height, width, 4) + :rtype: numpy.ndarray """ - if self._alternativeImage is not None: - return _convertImageToRgba32( - self.getAlternativeImageData(copy=False), copy=copy) + alternative = self.getAlternativeImageData(copy=False) + if alternative is not None: + return _convertImageToRgba32(alternative, copy=copy) else: # Apply colormap, in this case an new array is always returned colormap = self.getColormap() image = colormap.applyToData(self.getData(copy=False)) + alphaImage = self.getAlphaData(copy=False) + if alphaImage is not None: + # Apply transparency + image[:, :, 3] = image[:, :, 3] * alphaImage return image def getAlternativeImageData(self, copy=True): """Get the optional RGBA image that is displayed instead of the data - :param copy: True (Default) to get a copy, - False to use internal representation (do not modify!) - :returns: None or numpy.ndarray - :rtype: numpy.ndarray or None + :param bool copy: True (Default) to get a copy, + False to use internal representation (do not modify!) + :rtype: Union[None,numpy.ndarray] """ if self._alternativeImage is None: return None else: return numpy.array(self._alternativeImage, copy=copy) - def setData(self, data, alternative=None, copy=True): + def getAlphaData(self, copy=True): + """Get the optional transparency image applied on the data + + :param bool copy: True (Default) to get a copy, + False to use internal representation (do not modify!) + :rtype: Union[None,numpy.ndarray] + """ + if self.__alpha is None: + return None + else: + return numpy.array(self.__alpha, copy=copy) + + def setData(self, data, alternative=None, alpha=None, copy=True): """"Set the image data and optionally an alternative RGB(A) representation :param numpy.ndarray data: Data array with 2 dimensions (h, w) :param alternative: RGB(A) image to display instead of data, shape: (h, w, 3 or 4) - :type alternative: None or numpy.ndarray + :type alternative: Union[None,numpy.ndarray] + :param alpha: An array of transparency value in [0, 1] to use for + display with shape: (h, w) + :type alpha: Union[None,numpy.ndarray] :param bool copy: True (Default) to get a copy, False to use internal representation (do not modify!) """ @@ -346,6 +377,15 @@ class ImageData(ImageBase, ColormapMixIn): assert alternative.shape[:2] == data.shape[:2] self._alternativeImage = alternative + if alpha is not None: + alpha = numpy.array(alpha, copy=copy) + assert alpha.shape == data.shape + if alpha.dtype.kind != 'f': + alpha = alpha.astype(numpy.float32) + if numpy.any(numpy.logical_or(alpha < 0., alpha > 1.)): + alpha = numpy.clip(alpha, 0., 1.) + self.__alpha = alpha + # TODO hackish data range implementation if self.isVisible(): plot = self.getPlot() diff --git a/silx/gui/plot/items/marker.py b/silx/gui/plot/items/marker.py index 09767a5..80ca0b6 100644 --- a/silx/gui/plot/items/marker.py +++ b/silx/gui/plot/items/marker.py @@ -1,7 +1,7 @@ # coding: utf-8 # /*########################################################################## # -# Copyright (c) 2017-2018 European Synchrotron Radiation Facility +# Copyright (c) 2017-2019 European Synchrotron Radiation Facility # # Permission is hereby granted, free of charge, to any person obtaining a copy # of this software and associated documentation files (the "Software"), to deal @@ -32,6 +32,7 @@ __date__ = "06/03/2017" import logging +from ....utils.proxy import docstring from .core import (Item, DraggableMixIn, ColorMixIn, LineMixIn, SymbolMixIn, ItemChangedType) @@ -39,7 +40,7 @@ from .core import (Item, DraggableMixIn, ColorMixIn, LineMixIn, SymbolMixIn, _logger = logging.getLogger(__name__) -class _BaseMarker(Item, DraggableMixIn, ColorMixIn): +class MarkerBase(Item, DraggableMixIn, ColorMixIn): """Base class for markers""" _DEFAULT_COLOR = (0., 0., 0., 1.) @@ -75,6 +76,10 @@ class _BaseMarker(Item, DraggableMixIn, ColorMixIn): """Update backend renderer""" raise NotImplementedError() + @docstring(DraggableMixIn) + def drag(self, from_, to): + self.setPosition(to[0], to[1]) + def isOverlay(self): """Return true if marker is drawn as an overlay. @@ -166,14 +171,14 @@ class _BaseMarker(Item, DraggableMixIn, ColorMixIn): return args -class Marker(_BaseMarker, SymbolMixIn): +class Marker(MarkerBase, SymbolMixIn): """Description of a marker""" _DEFAULT_SYMBOL = '+' """Default symbol of the marker""" def __init__(self): - _BaseMarker.__init__(self) + MarkerBase.__init__(self) SymbolMixIn.__init__(self) self._x = 0. @@ -204,11 +209,11 @@ class Marker(_BaseMarker, SymbolMixIn): return x, self.getYPosition() -class _LineMarker(_BaseMarker, LineMixIn): +class _LineMarker(MarkerBase, LineMixIn): """Base class for line markers""" def __init__(self): - _BaseMarker.__init__(self) + MarkerBase.__init__(self) LineMixIn.__init__(self) def _addBackendRenderer(self, backend): diff --git a/silx/gui/plot/items/roi.py b/silx/gui/plot/items/roi.py index 0169439..65831be 100644 --- a/silx/gui/plot/items/roi.py +++ b/silx/gui/plot/items/roi.py @@ -73,6 +73,7 @@ class RegionOfInterest(qt.QObject): self._label = '' self._labelItem = None self._editable = False + self._visible = True def __del__(self): # Clean-up plot items @@ -176,6 +177,34 @@ class RegionOfInterest(qt.QObject): # This can be avoided once marker.setDraggable is public self._createPlotItems() + def isVisible(self): + """Returns whether the ROI is visible in the plot. + + .. note:: + This does not take into account whether or not the plot + widget itself is visible (unlike :meth:`QWidget.isVisible` which + checks the visibility of all its parent widgets up to the window) + + :rtype: bool + """ + return self._visible + + def setVisible(self, visible): + """Set whether the plot items associated with this ROI are + visible in the plot. + + :param bool visible: True to show the ROI in the plot, False to + hide it. + """ + visible = bool(visible) + if self._visible == visible: + return + self._visible = visible + if self._labelItem is not None: + self._labelItem.setVisible(visible) + for item in self._items + self._editAnchors: + item.setVisible(visible) + def _getControlPoints(self): """Returns the current ROI control points. @@ -292,12 +321,14 @@ class RegionOfInterest(qt.QObject): if self._labelItem is not None: self._labelItem._setLegend(legendPrefix + "label") plot._add(self._labelItem) + self._labelItem.setVisible(self.isVisible()) self._items = WeakList() plotItems = self._createShapeItems(controlPoints) for item in plotItems: item._setLegend(legendPrefix + str(itemIndex)) plot._add(item) + item.setVisible(self.isVisible()) self._items.append(item) itemIndex += 1 @@ -309,6 +340,7 @@ class RegionOfInterest(qt.QObject): for index, item in enumerate(plotItems): item._setLegend(legendPrefix + str(itemIndex)) item.setColor(color) + item.setVisible(self.isVisible()) plot._add(item) item.sigItemChanged.connect(functools.partial( self._controlPointAnchorChanged, index)) @@ -512,10 +544,10 @@ class LineROI(RegionOfInterest, items.LineMixIn): return controlPoints def setEndPoints(self, startPoint, endPoint): - """Set this line location using the endding points + """Set this line location using the ending points :param numpy.ndarray startPoint: Staring bounding point of the line - :param numpy.ndarray endPoint: Endding bounding point of the line + :param numpy.ndarray endPoint: Ending bounding point of the line """ assert(startPoint.shape == (2,) and endPoint.shape == (2,)) shapePoints = numpy.array([startPoint, endPoint]) @@ -1261,13 +1293,13 @@ class ArcROI(RegionOfInterest, items.LineMixIn): def getGeometry(self): """Returns a tuple containing the geometry of this ROI - It is a symetric fonction of :meth:`setGeometry`. + It is a symmetric function of :meth:`setGeometry`. If `startAngle` is smaller than `endAngle` the rotation is clockwise, else the rotation is anticlockwise. :rtype: Tuple[numpy.ndarray,float,float,float,float] - :raise ValueError: In case the ROI can't be representaed as section of + :raise ValueError: In case the ROI can't be represented as section of a circle """ geometry = self._getInternalGeometry() diff --git a/silx/gui/plot/items/scatter.py b/silx/gui/plot/items/scatter.py index 707dd3d..b2f087b 100644 --- a/silx/gui/plot/items/scatter.py +++ b/silx/gui/plot/items/scatter.py @@ -31,26 +31,79 @@ __date__ = "29/03/2017" import logging - +import threading import numpy -from .core import Points, ColormapMixIn +from collections import defaultdict +from concurrent.futures import ThreadPoolExecutor, CancelledError + +from ....utils.weakref import WeakList +from .._utils.delaunay import delaunay +from .core import PointsBase, ColormapMixIn, ScatterVisualizationMixIn +from .axis import Axis _logger = logging.getLogger(__name__) -class Scatter(Points, ColormapMixIn): +class _GreedyThreadPoolExecutor(ThreadPoolExecutor): + """:class:`ThreadPoolExecutor` with an extra :meth:`submit_greedy` method. + """ + + def __init__(self, *args, **kwargs): + super(_GreedyThreadPoolExecutor, self).__init__(*args, **kwargs) + self.__futures = defaultdict(WeakList) + self.__lock = threading.RLock() + + def submit_greedy(self, queue, fn, *args, **kwargs): + """Same as :meth:`submit` but cancel previous tasks in given queue. + + This means that when a new task is submitted for a given queue, + all other pending tasks of that queue are cancelled. + + :param queue: Identifier of the queue. This must be hashable. + :param callable fn: The callable to call with provided extra arguments + :return: Future corresponding to this task + :rtype: concurrent.futures.Future + """ + with self.__lock: + # Cancel previous tasks in given queue + for future in self.__futures.pop(queue, []): + if not future.done(): + future.cancel() + + future = super(_GreedyThreadPoolExecutor, self).submit( + fn, *args, **kwargs) + self.__futures[queue].append(future) + + return future + + +class Scatter(PointsBase, ColormapMixIn, ScatterVisualizationMixIn): """Description of a scatter""" _DEFAULT_SELECTABLE = True """Default selectable state for scatter plots""" + _SUPPORTED_SCATTER_VISUALIZATION = ( + ScatterVisualizationMixIn.Visualization.POINTS, + ScatterVisualizationMixIn.Visualization.SOLID) + """Overrides supported Visualizations""" + def __init__(self): - Points.__init__(self) + PointsBase.__init__(self) ColormapMixIn.__init__(self) + ScatterVisualizationMixIn.__init__(self) self._value = () self.__alpha = None + # Cache Delaunay triangulation future object + self.__delaunayFuture = None + # Cache interpolator future object + self.__interpolatorFuture = None + self.__executor = None + + # Cache triangles: x, y, indices + self.__cacheTriangles = None, None, None def _addBackendRenderer(self, backend): """Update backend renderer""" @@ -58,28 +111,154 @@ class Scatter(Points, ColormapMixIn): xFiltered, yFiltered, valueFiltered, xerror, yerror = self.getData( copy=False, displayed=True) + # Remove not finite numbers (this includes filtered out x, y <= 0) + mask = numpy.logical_and(numpy.isfinite(xFiltered), numpy.isfinite(yFiltered)) + xFiltered = xFiltered[mask] + yFiltered = yFiltered[mask] + if len(xFiltered) == 0: return None # No data to display, do not add renderer to backend + # Compute colors cmap = self.getColormap() rgbacolors = cmap.applyToData(self._value) if self.__alpha is not None: rgbacolors[:, -1] = (rgbacolors[:, -1] * self.__alpha).astype(numpy.uint8) - return backend.addCurve(xFiltered, yFiltered, self.getLegend(), - color=rgbacolors, - symbol=self.getSymbol(), - linewidth=0, - linestyle="", - yaxis='left', - xerror=xerror, - yerror=yerror, - z=self.getZValue(), - selectable=self.isSelectable(), - fill=False, - alpha=self.getAlpha(), - symbolsize=self.getSymbolSize()) + # Apply mask to colors + rgbacolors = rgbacolors[mask] + + if self.getVisualization() is self.Visualization.POINTS: + return backend.addCurve(xFiltered, yFiltered, self.getLegend(), + color=rgbacolors, + symbol=self.getSymbol(), + linewidth=0, + linestyle="", + yaxis='left', + xerror=xerror, + yerror=yerror, + z=self.getZValue(), + selectable=self.isSelectable(), + fill=False, + alpha=self.getAlpha(), + symbolsize=self.getSymbolSize()) + + else: # 'solid' + plot = self.getPlot() + if (plot is None or + plot.getXAxis().getScale() != Axis.LINEAR or + plot.getYAxis().getScale() != Axis.LINEAR): + # Solid visualization is not available with log scaled axes + return None + + triangulation = self._getDelaunay().result() + if triangulation is None: + return None + else: + triangles = triangulation.simplices.astype(numpy.int32) + return backend.addTriangles(xFiltered, + yFiltered, + triangles, + legend=self.getLegend(), + color=rgbacolors, + z=self.getZValue(), + selectable=self.isSelectable(), + alpha=self.getAlpha()) + + def __getExecutor(self): + """Returns async greedy executor + + :rtype: _GreedyThreadPoolExecutor + """ + if self.__executor is None: + self.__executor = _GreedyThreadPoolExecutor(max_workers=2) + return self.__executor + + def _getDelaunay(self): + """Returns a :class:`Future` which result is the Delaunay object. + + :rtype: concurrent.futures.Future + """ + if self.__delaunayFuture is None or self.__delaunayFuture.cancelled(): + # Need to init a new delaunay + x, y = self.getData(copy=False)[:2] + # Remove not finite points + mask = numpy.logical_and(numpy.isfinite(x), numpy.isfinite(y)) + + self.__delaunayFuture = self.__getExecutor().submit_greedy( + 'delaunay', delaunay, x[mask], y[mask]) + + return self.__delaunayFuture + + @staticmethod + def __initInterpolator(delaunayFuture, values): + """Returns an interpolator for the given data points + + :param concurrent.futures.Future delaunayFuture: + Future object which result is a Delaunay object + :param numpy.ndarray values: The data value of valid points. + :rtype: Union[callable,None] + """ + # Wait for Delaunay to complete + try: + triangulation = delaunayFuture.result() + except CancelledError: + triangulation = None + + if triangulation is None: + interpolator = None # Error case + else: + # Lazy-loading of interpolator + try: + from scipy.interpolate import LinearNDInterpolator + except ImportError: + LinearNDInterpolator = None + + if LinearNDInterpolator is not None: + interpolator = LinearNDInterpolator(triangulation, values) + + # First call takes a while, do it here + interpolator([(0., 0.)]) + + else: + # Fallback using matplotlib interpolator + import matplotlib.tri + + x, y = triangulation.points.T + tri = matplotlib.tri.Triangulation( + x, y, triangles=triangulation.simplices) + mplInterpolator = matplotlib.tri.LinearTriInterpolator( + tri, values) + + # Wrap interpolator to have same API as scipy's one + def interpolator(points): + return mplInterpolator(*points.T) + + return interpolator + + def _getInterpolator(self): + """Returns a :class:`Future` which result is the interpolator. + + The interpolator is a callable taking an array Nx2 of points + as a single argument. + The :class:`Future` result is None in case the interpolator cannot + be initialized. + + :rtype: concurrent.futures.Future + """ + if (self.__interpolatorFuture is None or + self.__interpolatorFuture.cancelled()): + # Need to init a new interpolator + x, y, values = self.getData(copy=False)[:3] + # Remove not finite points + mask = numpy.logical_and(numpy.isfinite(x), numpy.isfinite(y)) + x, y, values = x[mask], y[mask], values[mask] + + self.__interpolatorFuture = self.__getExecutor().submit_greedy( + 'interpolator', + self.__initInterpolator, self._getDelaunay(), values) + return self.__interpolatorFuture def _logFilterData(self, xPositive, yPositive): """Filter out values with x or y <= 0 on log axes @@ -89,7 +268,7 @@ class Scatter(Points, ColormapMixIn): :return: The filtered arrays or unchanged object if not filtering needed :rtype: (x, y, value, xerror, yerror) """ - # overloaded from Points to filter also value. + # overloaded from PointsBase to filter also value. value = self.getValueData(copy=False) if xPositive or yPositive: @@ -100,7 +279,7 @@ class Scatter(Points, ColormapMixIn): value = numpy.array(value, copy=True, dtype=numpy.float) value[clipped] = numpy.nan - x, y, xerror, yerror = Points._logFilterData(self, xPositive, yPositive) + x, y, xerror, yerror = PointsBase._logFilterData(self, xPositive, yPositive) return x, y, value, xerror, yerror @@ -146,7 +325,7 @@ class Scatter(Points, ColormapMixIn): self.getXErrorData(copy), self.getYErrorData(copy)) - # reimplemented from Points to handle `value` + # reimplemented from PointsBase to handle `value` def setData(self, x, y, value, xerror=None, yerror=None, alpha=None, copy=True): """Set the data of the scatter. @@ -171,6 +350,14 @@ class Scatter(Points, ColormapMixIn): assert value.ndim == 1 assert len(x) == len(value) + # Reset triangulation and interpolator + if self.__delaunayFuture is not None: + self.__delaunayFuture.cancel() + self.__delaunayFuture = None + if self.__interpolatorFuture is not None: + self.__interpolatorFuture.cancel() + self.__interpolatorFuture = None + self._value = value if alpha is not None: @@ -183,8 +370,8 @@ class Scatter(Points, ColormapMixIn): if numpy.any(numpy.logical_or(alpha < 0., alpha > 1.)): alpha = numpy.clip(alpha, 0., 1.) self.__alpha = alpha - + # set x, y, xerror, yerror # call self._updated + plot._invalidateDataRange() - Points.setData(self, x, y, xerror, yerror, copy) + PointsBase.setData(self, x, y, xerror, yerror, copy) diff --git a/silx/gui/plot/matplotlib/__init__.py b/silx/gui/plot/matplotlib/__init__.py index a4dc235..7298866 100644 --- a/silx/gui/plot/matplotlib/__init__.py +++ b/silx/gui/plot/matplotlib/__init__.py @@ -1,7 +1,7 @@ # coding: utf-8 # /*########################################################################## # -# Copyright (c) 2016-2017 European Synchrotron Radiation Facility +# Copyright (c) 2016-2019 European Synchrotron Radiation Facility # # Permission is hereby granted, free of charge, to any person obtaining a copy # of this software and associated documentation files (the "Software"), to deal @@ -25,7 +25,7 @@ from __future__ import absolute_import -"""This module inits matplotlib and setups the backend to use. +"""This module initializes matplotlib and sets-up the backend to use. It MUST be imported prior to any other import of matplotlib. @@ -38,64 +38,34 @@ __license__ = "MIT" __date__ = "02/05/2018" -import sys -import logging +from pkg_resources import parse_version +import matplotlib +from ... import qt -_logger = logging.getLogger(__name__) -_matplotlib_already_loaded = 'matplotlib' in sys.modules -"""If true, matplotlib was already loaded""" +def _matplotlib_use(backend, warn, force): + """Wrapper of `matplotlib.use` to set-up backend. -import matplotlib -from ... import qt + It adds extra initialization for PySide and PySide2 with matplotlib < 2.2. + """ + # This is kept for compatibility with matplotlib < 2.2 + if parse_version(matplotlib.__version__) < parse_version('2.2'): + if qt.BINDING == 'PySide': + matplotlib.rcParams['backend.qt4'] = 'PySide' + if qt.BINDING == 'PySide2': + matplotlib.rcParams['backend.qt5'] = 'PySide2' + matplotlib.use(backend, warn=warn, force=force) -def _configure(backend, backend_qt4=None, backend_qt5=None, check=False): - """Configure matplotlib using a specific backend. - It initialize `matplotlib.rcParams` using the requested backend, or check - if it is already configured as requested. +if qt.BINDING in ('PyQt4', 'PySide'): + _matplotlib_use('Qt4Agg', warn=True, force=False) + from matplotlib.backends.backend_qt4agg import FigureCanvasQTAgg # noqa - :param bool check: If true, the function only check that matplotlib - is already initialized as request. If not a warning is emitted. - If `check` is false, matplotlib is initialized. - """ - if check: - valid = matplotlib.rcParams['backend'] == backend - if backend_qt4 is not None: - valid = valid and matplotlib.rcParams['backend.qt4'] == backend_qt4 - if backend_qt5 is not None: - valid = valid and matplotlib.rcParams['backend.qt5'] == backend_qt5 - - if not valid: - _logger.warning('matplotlib already loaded, setting its backend may not work') - else: - matplotlib.rcParams['backend'] = backend - if backend_qt4 is not None: - matplotlib.rcParams['backend.qt4'] = backend_qt4 - if backend_qt5 is not None: - matplotlib.rcParams['backend.qt5'] = backend_qt5 - - -if qt.BINDING == 'PySide': - _configure('Qt4Agg', backend_qt4='PySide', check=_matplotlib_already_loaded) - import matplotlib.backends.backend_qt4agg as backend - -elif qt.BINDING == 'PyQt4': - _configure('Qt4Agg', check=_matplotlib_already_loaded) - import matplotlib.backends.backend_qt4agg as backend - -elif qt.BINDING == 'PySide2': - _configure('Qt5Agg', backend_qt5="PySide2", check=_matplotlib_already_loaded) - import matplotlib.backends.backend_qt5agg as backend - -elif qt.BINDING == 'PyQt5': - _configure('Qt5Agg', check=_matplotlib_already_loaded) - import matplotlib.backends.backend_qt5agg as backend +elif qt.BINDING in ('PyQt5', 'PySide2'): + _matplotlib_use('Qt5Agg', warn=True, force=False) + from matplotlib.backends.backend_qt5agg import FigureCanvasQTAgg # noqa else: - backend = None - -if backend is not None: - FigureCanvasQTAgg = backend.FigureCanvasQTAgg # noqa + raise ImportError("Unsupported Qt binding: %s" % qt.BINDING) diff --git a/silx/gui/plot/test/testAlphaSlider.py b/silx/gui/plot/test/testAlphaSlider.py index 63de441..01e6969 100644 --- a/silx/gui/plot/test/testAlphaSlider.py +++ b/silx/gui/plot/test/testAlphaSlider.py @@ -1,7 +1,7 @@ # coding: utf-8 # /*########################################################################## # -# Copyright (c) 2017 European Synchrotron Radiation Facility +# Copyright (c) 2017-2019 European Synchrotron Radiation Facility # # Permission is hereby granted, free of charge, to any person obtaining a copy # of this software and associated documentation files (the "Software"), to deal @@ -37,9 +37,6 @@ from silx.gui.utils.testutils import TestCaseQt from silx.gui.plot import PlotWidget from silx.gui.plot import AlphaSlider -# Makes sure a QApplication exists -_qapp = qt.QApplication.instance() or qt.QApplication([]) - class TestActiveImageAlphaSlider(TestCaseQt): def setUp(self): diff --git a/silx/gui/plot/test/testComplexImageView.py b/silx/gui/plot/test/testComplexImageView.py index 1933a95..051ec4d 100644 --- a/silx/gui/plot/test/testComplexImageView.py +++ b/silx/gui/plot/test/testComplexImageView.py @@ -1,7 +1,7 @@ # coding: utf-8 # /*########################################################################## # -# Copyright (c) 2017 European Synchrotron Radiation Facility +# Copyright (c) 2017-2019 European Synchrotron Radiation Facility # # Permission is hereby granted, free of charge, to any person obtaining a copy # of this software and associated documentation files (the "Software"), to deal @@ -63,10 +63,10 @@ class TestComplexImageView(PlotWidgetTestCase, ParametricTestCase): self.qWait(100) # Test all modes - modes = self.plot.getSupportedVisualizationModes() + modes = self.plot.supportedComplexModes() for mode in modes: with self.subTest(mode=mode): - self.plot.setVisualizationMode(mode) + self.plot.setComplexMode(mode) self.qWait(100) # Test origin and scale API diff --git a/silx/gui/plot/test/testCurvesROIWidget.py b/silx/gui/plot/test/testCurvesROIWidget.py index 5bcabd8..5886456 100644 --- a/silx/gui/plot/test/testCurvesROIWidget.py +++ b/silx/gui/plot/test/testCurvesROIWidget.py @@ -34,11 +34,15 @@ import os.path import unittest from collections import OrderedDict import numpy + from silx.gui import qt +from silx.gui.plot import Plot1D from silx.test.utils import temp_dir from silx.gui.utils.testutils import TestCaseQt, SignalListener from silx.gui.plot import PlotWindow, CurvesROIWidget - +from silx.gui.plot.CurvesROIWidget import ROITable +from silx.gui.utils.testutils import getQToolButtonFromAction +from silx.gui.plot.PlotInteraction import ItemsInteraction _logger = logging.getLogger(__name__) @@ -68,6 +72,18 @@ class TestCurvesROIWidget(TestCaseQt): super(TestCurvesROIWidget, self).tearDown() + def testDummyAPI(self): + """Simple test of the getRois and setRois API""" + roi_neg = CurvesROIWidget.ROI(name='negative', fromdata=-20, + todata=-10, type_='X') + roi_pos = CurvesROIWidget.ROI(name='positive', fromdata=10, + todata=20, type_='X') + + self.widget.roiWidget.setRois((roi_pos, roi_neg)) + + rois_defs = self.widget.roiWidget.getRois() + self.widget.roiWidget.setRois(rois=rois_defs) + def testWithCurves(self): """Plot with curves: test all ROI widget buttons""" for offset in range(2): @@ -301,7 +317,7 @@ class TestCurvesROIWidget(TestCaseQt): self.widget.roiWidget.setRois((roi,)) self.widget.roiWidget.roiTable.setActiveRoi(None) - self.assertTrue(len(self.widget.roiWidget.roiTable.selectedItems()) is 0) + self.assertEqual(len(self.widget.roiWidget.roiTable.selectedItems()), 0) self.widget.roiWidget.setRois((roi,)) self.plot.setActiveCurve(legend='linearCurve') self.widget.calculateROIs() @@ -314,14 +330,128 @@ class TestCurvesROIWidget(TestCaseQt): self.widget.roiWidget.sigROISignal.connect(signalListener.partial()) self.widget.show() self.qapp.processEvents() - self.assertTrue(signalListener.callCount() is 0) + self.assertEqual(signalListener.callCount(), 0) self.assertTrue(self.widget.roiWidget.roiTable.activeRoi is roi) roi.setFrom(0.0) self.qapp.processEvents() - self.assertTrue(signalListener.callCount() is 0) + self.assertEqual(signalListener.callCount(), 0) roi.setFrom(0.3) self.qapp.processEvents() - self.assertTrue(signalListener.callCount() is 1) + self.assertEqual(signalListener.callCount(), 1) + + +class TestRoiWidgetSignals(TestCaseQt): + """Test Signals emitted by the RoiWidgetSignals""" + + def setUp(self): + self.plot = Plot1D() + x = range(20) + y = range(20) + self.plot.addCurve(x, y, legend='curve0') + self.listener = SignalListener() + self.curves_roi_widget = self.plot.getCurvesRoiWidget() + self.curves_roi_widget.sigROISignal.connect(self.listener) + assert self.curves_roi_widget.isVisible() is False + assert self.listener.callCount() == 0 + self.plot.show() + self.qWaitForWindowExposed(self.plot) + + toolButton = getQToolButtonFromAction(self.plot.getRoiAction()) + self.mouseClick(widget=toolButton, button=qt.Qt.LeftButton) + + self.curves_roi_widget.show() + self.qWaitForWindowExposed(self.curves_roi_widget) + + def tearDown(self): + self.plot = None + + def testSigROISignalAddRmRois(self): + """Test SigROISignal when adding and removing ROIS""" + print(self.listener.callCount()) + self.assertEqual(self.listener.callCount(), 1) + self.listener.clear() + + roi1 = CurvesROIWidget.ROI(name='linear', fromdata=0, todata=5) + self.curves_roi_widget.roiTable.addRoi(roi1) + self.assertEqual(self.listener.callCount(), 1) + self.assertTrue(self.listener.arguments()[0][0]['current'] == 'linear') + self.listener.clear() + + roi2 = CurvesROIWidget.ROI(name='linear2', fromdata=0, todata=5) + self.curves_roi_widget.roiTable.addRoi(roi2) + self.assertEqual(self.listener.callCount(), 1) + self.assertTrue(self.listener.arguments()[0][0]['current'] == 'linear2') + self.listener.clear() + + self.curves_roi_widget.roiTable.removeROI(roi2) + self.assertEqual(self.listener.callCount(), 1) + self.assertTrue(self.curves_roi_widget.roiTable.activeRoi == roi1) + self.assertTrue(self.listener.arguments()[0][0]['current'] == 'linear') + self.listener.clear() + + self.curves_roi_widget.roiTable.deleteActiveRoi() + self.assertEqual(self.listener.callCount(), 1) + self.assertTrue(self.curves_roi_widget.roiTable.activeRoi is None) + self.assertTrue(self.listener.arguments()[0][0]['current'] is None) + self.listener.clear() + + self.curves_roi_widget.roiTable.addRoi(roi1) + self.assertEqual(self.listener.callCount(), 1) + self.assertTrue(self.listener.arguments()[0][0]['current'] == 'linear') + self.assertTrue(self.curves_roi_widget.roiTable.activeRoi == roi1) + self.listener.clear() + self.qapp.processEvents() + + self.curves_roi_widget.roiTable.removeROI(roi1) + self.qapp.processEvents() + self.assertEqual(self.listener.callCount(), 1) + self.assertTrue(self.listener.arguments()[0][0]['current'] == 'ICR') + self.listener.clear() + + def testSigROISignalModifyROI(self): + """Test SigROISignal when modifying it""" + self.curves_roi_widget.roiTable.setMiddleROIMarkerFlag(True) + roi1 = CurvesROIWidget.ROI(name='linear', fromdata=2, todata=5) + self.curves_roi_widget.roiTable.addRoi(roi1) + self.curves_roi_widget.roiTable.setActiveRoi(roi1) + + # test modify the roi2 object + self.listener.clear() + roi1.setFrom(0.56) + self.assertEqual(self.listener.callCount(), 1) + self.listener.clear() + roi1.setTo(2.56) + self.assertEqual(self.listener.callCount(), 1) + self.listener.clear() + roi1.setName('linear2') + self.assertEqual(self.listener.callCount(), 1) + self.listener.clear() + roi1.setType('new type') + self.assertEqual(self.listener.callCount(), 1) + + # modify roi limits (from the gui) + roi_marker_handler = self.curves_roi_widget.roiTable._markersHandler.getMarkerHandler(roi1.getID()) + for marker_type in ('min', 'max', 'middle'): + with self.subTest(marker_type=marker_type): + self.listener.clear() + marker = roi_marker_handler.getMarker(marker_type) + self.qapp.processEvents() + items_interaction = ItemsInteraction(plot=self.plot) + x_pix, y_pix = self.plot.dataToPixel(marker.getXPosition(), 1) + items_interaction.beginDrag(x_pix, y_pix) + self.qapp.processEvents() + items_interaction.endDrag(x_pix+10, y_pix) + self.qapp.processEvents() + self.assertEqual(self.listener.callCount(), 1) + + def testSetActiveCurve(self): + """Test sigRoiSignal when set an active curve""" + roi1 = CurvesROIWidget.ROI(name='linear', fromdata=2, todata=5) + self.curves_roi_widget.roiTable.addRoi(roi1) + self.curves_roi_widget.roiTable.setActiveRoi(roi1) + self.listener.clear() + self.plot.setActiveCurve('curve0') + self.assertEqual(self.listener.callCount(), 0) def suite(): diff --git a/silx/gui/plot/test/testItem.py b/silx/gui/plot/test/testItem.py index 993cce7..c864545 100644 --- a/silx/gui/plot/test/testItem.py +++ b/silx/gui/plot/test/testItem.py @@ -1,7 +1,7 @@ # coding: utf-8 # /*########################################################################## # -# Copyright (c) 2017 European Synchrotron Radiation Facility +# Copyright (c) 2017-2019 European Synchrotron Radiation Facility # # Permission is hereby granted, free of charge, to any person obtaining a copy # of this software and associated documentation files (the "Software"), to deal @@ -54,7 +54,7 @@ class TestSigItemChangedSignal(PlotWidgetTestCase): curve.setVisible(True) curve.setZValue(100) - # Test for signals in Points class + # Test for signals in PointsBase class curve.setData(numpy.arange(100), numpy.arange(100)) # SymbolMixIn @@ -194,14 +194,17 @@ class TestSigItemChangedSignal(PlotWidgetTestCase): # ColormapMixIn scatter.getColormap().setName('viridis') - data2 = data + 10 # Test of signals in Scatter class - scatter.setData(data2, data2, data2) + scatter.setData((0, 1, 2), (1, 0, 2), (0, 1, 2)) + + # Visualization mode changed + scatter.setVisualization(scatter.Visualization.SOLID) self.assertEqual(listener.arguments(), [(ItemChangedType.COLORMAP,), - (ItemChangedType.DATA,)]) + (ItemChangedType.DATA,), + (ItemChangedType.VISUALIZATION_MODE,)]) def testShapeChanged(self): """Test sigItemChanged for shape""" diff --git a/silx/gui/plot/test/testPlotWidget.py b/silx/gui/plot/test/testPlotWidget.py index 9d7c093..7449c12 100644 --- a/silx/gui/plot/test/testPlotWidget.py +++ b/silx/gui/plot/test/testPlotWidget.py @@ -386,6 +386,16 @@ class TestPlotImage(PlotWidgetTestCase, ParametricTestCase): self.assertTrue(numpy.all(numpy.equal(retrievedData, data))) self.assertIs(retrievedData.dtype.type, numpy.int8) + def testPlotAlphaImage(self): + """Test with an alpha image layer""" + data = numpy.random.random((10, 10)) + alpha = numpy.linspace(0, 1, 100).reshape(10, 10) + self.plot.addImage(data, legend='image') + image = self.plot.getActiveImage() + image.setData(data, alpha=alpha) + self.qapp.processEvents() + self.assertTrue(numpy.array_equal(alpha, image.getAlphaData())) + class TestPlotCurve(PlotWidgetTestCase): """Basic tests for addCurve.""" @@ -463,7 +473,34 @@ class TestPlotCurve(PlotWidgetTestCase): self.plot.addCurve(self.xData, self.yData, legend="curve 2", replace=False, resetzoom=False, - color=color, symbol='o') + color=color, symbol='o') + + +class TestPlotScatter(PlotWidgetTestCase, ParametricTestCase): + """Basic tests for addScatter""" + + def testScatter(self): + x = numpy.arange(100) + y = numpy.arange(100) + value = numpy.arange(100) + self.plot.addScatter(x, y, value) + self.plot.resetZoom() + + def testScatterVisualization(self): + self.plot.addScatter((0, 1, 2, 3), (2, 0, 2, 1), (0, 1, 2, 3)) + self.plot.resetZoom() + self.qapp.processEvents() + + scatter = self.plot.getItems()[0] + + for visualization in ('solid', + 'points', + scatter.Visualization.SOLID, + scatter.Visualization.POINTS): + with self.subTest(visualization=visualization): + scatter.setVisualization(visualization) + self.qapp.processEvents() + class TestPlotMarker(PlotWidgetTestCase): """Basic tests for add*Marker""" @@ -1524,11 +1561,19 @@ class TestPlotItemLog(PlotWidgetTestCase): def suite(): - testClasses = (TestPlotWidget, TestPlotImage, TestPlotCurve, - TestPlotMarker, TestPlotItem, TestPlotAxes, + testClasses = (TestPlotWidget, + TestPlotImage, + TestPlotCurve, + TestPlotScatter, + TestPlotMarker, + TestPlotItem, + TestPlotAxes, TestPlotActiveCurveImage, - TestPlotEmptyLog, TestPlotCurveLog, TestPlotImageLog, - TestPlotMarkerLog, TestPlotItemLog) + TestPlotEmptyLog, + TestPlotCurveLog, + TestPlotImageLog, + TestPlotMarkerLog, + TestPlotItemLog) test_suite = unittest.TestSuite() diff --git a/silx/gui/plot/test/testPlotWindow.py b/silx/gui/plot/test/testPlotWindow.py index 6d3eb8f..0a7d108 100644 --- a/silx/gui/plot/test/testPlotWindow.py +++ b/silx/gui/plot/test/testPlotWindow.py @@ -1,7 +1,7 @@ # coding: utf-8 # /*########################################################################## # -# Copyright (c) 2016 European Synchrotron Radiation Facility +# Copyright (c) 2016-2019 European Synchrotron Radiation Facility # # Permission is hereby granted, free of charge, to any person obtaining a copy # of this software and associated documentation files (the "Software"), to deal @@ -38,32 +38,6 @@ from silx.gui import qt from silx.gui.plot import PlotWindow -# Test of the docstrings # - -# Makes sure a QApplication exists -_qapp = qt.QApplication.instance() or qt.QApplication([]) - - -def _tearDownQt(docTest): - """Tear down to use for test from docstring. - - Checks that plt widget is displayed - """ - _qapp.processEvents() - for obj in docTest.globs.values(): - if isinstance(obj, PlotWindow): - # Commented out as it takes too long - # qWaitForWindowExposedAndActivate(obj) - obj.setAttribute(qt.Qt.WA_DeleteOnClose) - obj.close() - del obj - - -plotWindowDocTestSuite = doctest.DocTestSuite('silx.gui.plot.PlotWindow', - tearDown=_tearDownQt) -"""Test suite of tests from the module's docstrings.""" - - class TestPlotWindow(TestCaseQt): """Base class for tests of PlotWindow.""" @@ -128,7 +102,6 @@ class TestPlotWindow(TestCaseQt): def suite(): test_suite = unittest.TestSuite() - test_suite.addTest(plotWindowDocTestSuite) test_suite.addTest( unittest.defaultTestLoader.loadTestsFromTestCase(TestPlotWindow)) return test_suite diff --git a/silx/gui/plot/test/testProfile.py b/silx/gui/plot/test/testProfile.py index 847f404..cf40f76 100644 --- a/silx/gui/plot/test/testProfile.py +++ b/silx/gui/plot/test/testProfile.py @@ -1,7 +1,7 @@ # coding: utf-8 # /*########################################################################## # -# Copyright (c) 2016-2017 European Synchrotron Radiation Facility +# Copyright (c) 2016-2019 European Synchrotron Radiation Facility # # Permission is hereby granted, free of charge, to any person obtaining a copy # of this software and associated documentation files (the "Software"), to deal @@ -39,10 +39,6 @@ from silx.gui.plot import PlotWindow, Plot1D, Plot2D, Profile from silx.gui.plot.StackView import StackView -# Makes sure a QApplication exists -_qapp = qt.QApplication.instance() or qt.QApplication([]) - - class TestProfileToolBar(TestCaseQt, ParametricTestCase): """Tests for ProfileToolBar widget.""" diff --git a/silx/gui/plot/test/testStackView.py b/silx/gui/plot/test/testStackView.py index a5f649c..80c85d6 100644 --- a/silx/gui/plot/test/testStackView.py +++ b/silx/gui/plot/test/testStackView.py @@ -1,7 +1,7 @@ # coding: utf-8 # /*########################################################################## # -# Copyright (c) 2016-2018 European Synchrotron Radiation Facility +# Copyright (c) 2016-2019 European Synchrotron Radiation Facility # # Permission is hereby granted, free of charge, to any person obtaining a copy # of this software and associated documentation files (the "Software"), to deal @@ -41,10 +41,6 @@ from silx.gui.plot.StackView import StackViewMainWindow from silx.utils.array_like import ListOfImages -# Makes sure a QApplication exists -_qapp = qt.QApplication.instance() or qt.QApplication([]) - - class TestStackView(TestCaseQt): """Base class for tests of StackView.""" diff --git a/silx/gui/plot/test/testStats.py b/silx/gui/plot/test/testStats.py index 7fbc247..4bc2144 100644 --- a/silx/gui/plot/test/testStats.py +++ b/silx/gui/plot/test/testStats.py @@ -33,8 +33,9 @@ from silx.gui import qt from silx.gui.plot.stats import stats from silx.gui.plot import StatsWidget from silx.gui.plot.stats import statshandler -from silx.gui.utils.testutils import TestCaseQt +from silx.gui.utils.testutils import TestCaseQt, SignalListener from silx.gui.plot import Plot1D, Plot2D +from silx.utils.testutils import ParametricTestCase import unittest import logging import numpy @@ -350,7 +351,7 @@ class TestStatsHandler(unittest.TestCase): statshandler.StatsHandler(('name')) -class TestStatsWidgetWithCurves(TestCaseQt): +class TestStatsWidgetWithCurves(TestCaseQt, ParametricTestCase): """Basic test for StatsWidget with curves""" def setUp(self): TestCaseQt.setUp(self) @@ -363,7 +364,8 @@ class TestStatsWidgetWithCurves(TestCaseQt): self.plot.addCurve(x, y, legend='curve1') y = range(-2, 18) self.plot.addCurve(x, y, legend='curve2') - self.widget = StatsWidget.StatsTable(plot=self.plot) + self.widget = StatsWidget.StatsWidget(plot=self.plot) + self.statsTable = self.widget._statsTable mystats = statshandler.StatsHandler(( stats.StatMin(), @@ -376,67 +378,170 @@ class TestStatsWidgetWithCurves(TestCaseQt): stats.StatCOM() )) - self.widget.setStats(mystats) + self.statsTable.setStats(mystats) def tearDown(self): self.plot.setAttribute(qt.Qt.WA_DeleteOnClose) self.plot.close() + self.statsTable = None self.widget.setAttribute(qt.Qt.WA_DeleteOnClose) self.widget.close() self.widget = None self.plot = None TestCaseQt.tearDown(self) + def testDisplayActiveItemsSyncOptions(self): + """ + Test that the several option of the sync options are well + synchronized between the different object""" + widget = StatsWidget.StatsWidget(plot=self.plot) + table = StatsWidget.StatsTable(plot=self.plot) + + def check_display_only_active_item(only_active): + # check internal value + self.assertTrue(widget._statsTable._displayOnlyActItem is only_active) + # self.assertTrue(table._displayOnlyActItem is only_active) + # check gui display + self.assertTrue(widget._options.isActiveItemMode() is only_active) + + for displayOnlyActiveItems in (True, False): + with self.subTest(displayOnlyActiveItems=displayOnlyActiveItems): + widget.setDisplayOnlyActiveItem(displayOnlyActiveItems) + # table.setDisplayOnlyActiveItem(displayOnlyActiveItems) + check_display_only_active_item(displayOnlyActiveItems) + + check_display_only_active_item(only_active=False) + widget.setAttribute(qt.Qt.WA_DeleteOnClose) + table.setAttribute(qt.Qt.WA_DeleteOnClose) + widget.close() + table.close() + def testInit(self): """Make sure all the curves are registred on initialization""" - self.assertEqual(self.widget.rowCount(), 3) + self.assertEqual(self.statsTable.rowCount(), 3) def testRemoveCurve(self): """Make sure the Curves stats take into account the curve removal from plot""" self.plot.removeCurve('curve2') - self.assertEqual(self.widget.rowCount(), 2) + self.assertEqual(self.statsTable.rowCount(), 2) for iRow in range(2): - self.assertTrue(self.widget.item(iRow, 0).text() in ('curve0', 'curve1')) + self.assertTrue(self.statsTable.item(iRow, 0).text() in ('curve0', 'curve1')) self.plot.removeCurve('curve0') - self.assertEqual(self.widget.rowCount(), 1) + self.assertEqual(self.statsTable.rowCount(), 1) self.plot.removeCurve('curve1') - self.assertEqual(self.widget.rowCount(), 0) + self.assertEqual(self.statsTable.rowCount(), 0) def testAddCurve(self): """Make sure the Curves stats take into account the add curve action""" self.plot.addCurve(legend='curve3', x=range(10), y=range(10)) - self.assertEqual(self.widget.rowCount(), 4) + self.assertEqual(self.statsTable.rowCount(), 4) def testUpdateCurveFromAddCurve(self): """Make sure the stats of the cuve will be removed after updating a curve""" self.plot.addCurve(legend='curve0', x=range(10), y=range(10)) self.qapp.processEvents() - self.assertEqual(self.widget.rowCount(), 3) + self.assertEqual(self.statsTable.rowCount(), 3) curve = self.plot._getItem(kind='curve', legend='curve0') - tableItems = self.widget._itemToTableItems(curve) + tableItems = self.statsTable._itemToTableItems(curve) self.assertEqual(tableItems['max'].text(), '9') def testUpdateCurveFromCurveObj(self): self.plot.getCurve('curve0').setData(x=range(4), y=range(4)) self.qapp.processEvents() - self.assertEqual(self.widget.rowCount(), 3) + self.assertEqual(self.statsTable.rowCount(), 3) curve = self.plot._getItem(kind='curve', legend='curve0') - tableItems = self.widget._itemToTableItems(curve) + tableItems = self.statsTable._itemToTableItems(curve) self.assertEqual(tableItems['max'].text(), '3') def testSetAnotherPlot(self): plot2 = Plot1D() plot2.addCurve(x=range(26), y=range(26), legend='new curve') - self.widget.setPlot(plot2) - self.assertEqual(self.widget.rowCount(), 1) + self.statsTable.setPlot(plot2) + self.assertEqual(self.statsTable.rowCount(), 1) self.qapp.processEvents() plot2.setAttribute(qt.Qt.WA_DeleteOnClose) plot2.close() plot2 = None + def testUpdateMode(self): + """Make sure the update modes are well take into account""" + self.plot.setActiveCurve('curve0') + for display_only_active in (True, False): + with self.subTest(display_only_active=display_only_active): + self.widget.setDisplayOnlyActiveItem(display_only_active) + self.plot.getCurve('curve0').setData(x=range(4), y=range(4)) + self.widget.setUpdateMode(StatsWidget.UpdateMode.AUTO) + update_stats_action = self.widget._options.getUpdateStatsAction() + # test from api + self.assertTrue(self.widget.getUpdateMode() is StatsWidget.UpdateMode.AUTO) + self.widget.show() + # check stats change in auto mode + self.plot.getCurve('curve0').setData(x=range(4), y=range(-1, 3)) + self.qapp.processEvents() + tableItems = self.statsTable._itemToTableItems(self.plot.getCurve('curve0')) + curve0_min = tableItems['min'].text() + print(curve0_min) + self.assertTrue(float(curve0_min) == -1.) + + self.plot.getCurve('curve0').setData(x=range(4), y=range(1, 5)) + self.qapp.processEvents() + tableItems = self.statsTable._itemToTableItems(self.plot.getCurve('curve0')) + curve0_min = tableItems['min'].text() + self.assertTrue(float(curve0_min) == 1.) + + # check stats change in manual mode only if requested + self.widget.setUpdateMode(StatsWidget.UpdateMode.MANUAL) + self.assertTrue(self.widget.getUpdateMode() is StatsWidget.UpdateMode.MANUAL) + + self.plot.getCurve('curve0').setData(x=range(4), y=range(2, 6)) + self.qapp.processEvents() + tableItems = self.statsTable._itemToTableItems(self.plot.getCurve('curve0')) + curve0_min = tableItems['min'].text() + self.assertTrue(float(curve0_min) == 1.) + + update_stats_action.trigger() + tableItems = self.statsTable._itemToTableItems(self.plot.getCurve('curve0')) + curve0_min = tableItems['min'].text() + self.assertTrue(float(curve0_min) == 2.) + + def testItemHidden(self): + """Test if an item is hide, then the associated stats item is also + hide""" + curve0 = self.plot.getCurve('curve0') + curve1 = self.plot.getCurve('curve1') + curve2 = self.plot.getCurve('curve2') + + self.plot.show() + self.widget.show() + self.qWaitForWindowExposed(self.widget) + self.assertFalse(self.statsTable.isRowHidden(0)) + self.assertFalse(self.statsTable.isRowHidden(1)) + self.assertFalse(self.statsTable.isRowHidden(2)) + + curve0.setVisible(False) + self.qapp.processEvents() + self.assertTrue(self.statsTable.isRowHidden(0)) + curve0.setVisible(True) + self.qapp.processEvents() + self.assertFalse(self.statsTable.isRowHidden(0)) + curve1.setVisible(False) + self.qapp.processEvents() + self.assertTrue(self.statsTable.isRowHidden(1)) + tableItems = self.statsTable._itemToTableItems(curve2) + curve2_min = tableItems['min'].text() + self.assertTrue(float(curve2_min) == -2.) + + curve0.setVisible(False) + curve1.setVisible(False) + curve2.setVisible(False) + self.qapp.processEvents() + self.assertTrue(self.statsTable.isRowHidden(0)) + self.assertTrue(self.statsTable.isRowHidden(1)) + self.assertTrue(self.statsTable.isRowHidden(2)) + class TestStatsWidgetWithImages(TestCaseQt): """Basic test for StatsWidget with images""" @@ -487,6 +592,17 @@ class TestStatsWidgetWithImages(TestCaseQt): self.assertEqual(tableItems['coords min'].text(), '0.0, 0.0') self.assertEqual(tableItems['coords max'].text(), '127.0, 127.0') + def testItemHidden(self): + """Test if an item is hide, then the associated stats item is also + hide""" + self.widget.show() + self.plot.show() + self.qWaitForWindowExposed(self.widget) + self.assertFalse(self.widget.isRowHidden(0)) + self.plot.getImage(self.IMAGE_LEGEND).setVisible(False) + self.qapp.processEvents() + self.assertTrue(self.widget.isRowHidden(0)) + class TestStatsWidgetWithScatters(TestCaseQt): @@ -556,13 +672,13 @@ class TestLineWidget(TestCaseQt): self.plot = Plot1D() self.plot.show() - x = range(20) - y = range(20) - self.plot.addCurve(x, y, legend='curve0') - y = range(12, 32) - self.plot.addCurve(x, y, legend='curve1') - y = range(-2, 18) - self.plot.addCurve(x, y, legend='curve2') + self.x = range(20) + self.y0 = range(20) + self.curve0 = self.plot.addCurve(self.x, self.y0, legend='curve0') + self.y1 = range(12, 32) + self.plot.addCurve(self.x, self.y1, legend='curve1') + self.y2 = range(-2, 18) + self.plot.addCurve(self.x, self.y2, legend='curve2') self.widget = StatsWidget.BasicGridStatsWidget(plot=self.plot, kind='curve', stats=mystats) @@ -572,33 +688,112 @@ class TestLineWidget(TestCaseQt): self.plot.setAttribute(qt.Qt.WA_DeleteOnClose) self.plot.close() self.widget.setPlot(None) - self.widget._statQlineEdit.clear() + self.widget._lineStatsWidget._statQlineEdit.clear() self.widget.setAttribute(qt.Qt.WA_DeleteOnClose) self.widget.close() self.widget = None self.plot = None TestCaseQt.tearDown(self) - def test(self): - self.widget.setStatsOnVisibleData(False) + def testProcessing(self): + self.widget._lineStatsWidget.setStatsOnVisibleData(False) self.qapp.processEvents() self.plot.setActiveCurve(legend='curve0') - self.assertTrue(self.widget._statQlineEdit['min'].text() == '0.000') + self.assertTrue(self.widget._lineStatsWidget._statQlineEdit['min'].text() == '0.000') self.plot.setActiveCurve(legend='curve1') - self.assertTrue(self.widget._statQlineEdit['min'].text() == '12.000') + self.assertTrue(self.widget._lineStatsWidget._statQlineEdit['min'].text() == '12.000') self.plot.getXAxis().setLimitsConstraints(minPos=2, maxPos=5) self.widget.setStatsOnVisibleData(True) self.qapp.processEvents() - self.assertTrue(self.widget._statQlineEdit['min'].text() == '14.000') + self.assertTrue(self.widget._lineStatsWidget._statQlineEdit['min'].text() == '14.000') self.plot.setActiveCurve(None) self.assertTrue(self.plot.getActiveCurve() is None) self.widget.setStatsOnVisibleData(False) self.qapp.processEvents() - self.assertFalse(self.widget._statQlineEdit['min'].text() == '14.000') + self.assertFalse(self.widget._lineStatsWidget._statQlineEdit['min'].text() == '14.000') self.widget.setKind('image') self.plot.addImage(numpy.arange(100*100).reshape(100, 100) + 0.312) self.qapp.processEvents() - self.assertTrue(self.widget._statQlineEdit['min'].text() == '0.312') + self.assertTrue(self.widget._lineStatsWidget._statQlineEdit['min'].text() == '0.312') + + def testUpdateMode(self): + """Make sure the update modes are well take into account""" + self.plot.setActiveCurve(self.curve0) + _autoRB = self.widget._options._autoRB + _manualRB = self.widget._options._manualRB + # test from api + self.widget.setUpdateMode(StatsWidget.UpdateMode.AUTO) + self.assertTrue(_autoRB.isChecked()) + self.assertFalse(_manualRB.isChecked()) + + # check stats change in auto mode + curve0_min = self.widget._lineStatsWidget._statQlineEdit['min'].text() + new_y = numpy.array(self.y0) - 2.56 + self.plot.addCurve(x=self.x, y=new_y, legend=self.curve0) + curve0_min2 = self.widget._lineStatsWidget._statQlineEdit['min'].text() + self.assertTrue(curve0_min != curve0_min2) + + # check stats change in manual mode only if requested + self.widget.setUpdateMode(StatsWidget.UpdateMode.MANUAL) + self.assertFalse(_autoRB.isChecked()) + self.assertTrue(_manualRB.isChecked()) + + new_y = numpy.array(self.y0) - 1.2 + self.plot.addCurve(x=self.x, y=new_y, legend=self.curve0) + curve0_min3 = self.widget._lineStatsWidget._statQlineEdit['min'].text() + self.assertTrue(curve0_min3 == curve0_min2) + self.widget._options._updateRequested() + curve0_min3 = self.widget._lineStatsWidget._statQlineEdit['min'].text() + self.assertTrue(curve0_min3 != curve0_min2) + + # test from gui + self.widget.showRadioButtons(True) + self.widget._options._autoRB.toggle() + self.assertTrue(_autoRB.isChecked()) + self.assertFalse(_manualRB.isChecked()) + + self.widget._options._manualRB.toggle() + self.assertFalse(_autoRB.isChecked()) + self.assertTrue(_manualRB.isChecked()) + + +class TestUpdateModeWidget(TestCaseQt): + """Test UpdateModeWidget""" + def setUp(self): + TestCaseQt.setUp(self) + self.widget = StatsWidget.UpdateModeWidget(parent=None) + + def tearDown(self): + self.widget.setAttribute(qt.Qt.WA_DeleteOnClose) + self.widget.close() + self.widget = None + TestCaseQt.tearDown(self) + + def testSignals(self): + """Test the signal emission of the widget""" + self.widget.setUpdateMode(StatsWidget.UpdateMode.AUTO) + modeChangedListener = SignalListener() + manualUpdateListener = SignalListener() + self.widget.sigUpdateModeChanged.connect(modeChangedListener) + self.widget.sigUpdateRequested.connect(manualUpdateListener) + self.widget.setUpdateMode(StatsWidget.UpdateMode.AUTO) + self.assertTrue(self.widget.getUpdateMode() is StatsWidget.UpdateMode.AUTO) + self.assertTrue(modeChangedListener.callCount() is 0) + self.qapp.processEvents() + + self.widget.setUpdateMode(StatsWidget.UpdateMode.MANUAL) + self.assertTrue(self.widget.getUpdateMode() is StatsWidget.UpdateMode.MANUAL) + self.qapp.processEvents() + self.assertTrue(modeChangedListener.callCount() is 1) + self.assertTrue(manualUpdateListener.callCount() is 0) + self.widget._updatePB.click() + self.widget._updatePB.click() + self.assertTrue(manualUpdateListener.callCount() is 2) + + self.widget._autoRB.setChecked(True) + self.assertTrue(modeChangedListener.callCount() is 2) + self.widget._updatePB.click() + self.assertTrue(manualUpdateListener.callCount() is 2) def suite(): @@ -606,7 +801,7 @@ def suite(): for TestClass in (TestStats, TestStatsHandler, TestStatsWidgetWithScatters, TestStatsWidgetWithImages, TestStatsWidgetWithCurves, TestStatsFormatter, TestEmptyStatsWidget, - TestLineWidget): + TestLineWidget, TestUpdateModeWidget): test_suite.addTest( unittest.defaultTestLoader.loadTestsFromTestCase(TestClass)) return test_suite diff --git a/silx/gui/plot/tools/profile/ScatterProfileToolBar.py b/silx/gui/plot/tools/profile/ScatterProfileToolBar.py index fd21515..0d30651 100644 --- a/silx/gui/plot/tools/profile/ScatterProfileToolBar.py +++ b/silx/gui/plot/tools/profile/ScatterProfileToolBar.py @@ -1,7 +1,7 @@ # coding: utf-8 # /*########################################################################## # -# Copyright (c) 2018 European Synchrotron Radiation Facility +# Copyright (c) 2018-2019 European Synchrotron Radiation Facility # # Permission is hereby granted, free of charge, to any person obtaining a copy # of this software and associated documentation files (the "Software"), to deal @@ -31,196 +31,18 @@ __date__ = "28/06/2018" import logging -import threading -import time +import weakref import numpy -try: - from scipy.interpolate import LinearNDInterpolator -except ImportError: - LinearNDInterpolator = None - - # Fallback using local Delaunay and matplotlib interpolator - from silx.third_party.scipy_spatial import Delaunay - import matplotlib.tri - from ._BaseProfileToolBar import _BaseProfileToolBar -from .... import qt from ... import items +from ....utils.concurrent import submitToQtMainThread _logger = logging.getLogger(__name__) -# TODO support log scale - - -class _InterpolatorInitThread(qt.QThread): - """Thread building a scatter interpolator - - This works in greedy mode in that the signal is only emitted - when no other request is pending - """ - - sigInterpolatorReady = qt.Signal(object) - """Signal emitted whenever an interpolator is ready - - It provides a 3-tuple (points, values, interpolator) - """ - - _RUNNING_THREADS_TO_DELETE = [] - """Store reference of no more used threads but still running""" - - def __init__(self): - super(_InterpolatorInitThread, self).__init__() - self._lock = threading.RLock() - self._pendingData = None - self._firstFallbackRun = True - - def discard(self, obj=None): - """Wait for pending thread to complete and delete then - - Connect this to the destroyed signal of widget using this thread - """ - if self.isRunning(): - self.cancel() - self._RUNNING_THREADS_TO_DELETE.append(self) # Keep a reference - self.finished.connect(self.__finished) - - def __finished(self): - """Handle finished signal of threads to delete""" - try: - self._RUNNING_THREADS_TO_DELETE.remove(self) - except ValueError: - _logger.warning('Finished thread no longer in reference list') - - def request(self, points, values): - """Request new initialisation of interpolator - - :param numpy.ndarray points: Point coordinates (N, D) - :param numpy.ndarray values: Values the N points (1D array) - """ - with self._lock: - # Possibly replace already pending data - self._pendingData = points, values - - if not self.isRunning(): - self.start() - - def cancel(self): - """Cancel any running/pending requests""" - with self._lock: - self._pendingData = 'cancelled' - - def run(self): - """Run the init of the scatter interpolator""" - if LinearNDInterpolator is None: - self.run_matplotlib() - else: - self.run_scipy() - - def run_matplotlib(self): - """Run the init of the scatter interpolator""" - if self._firstFallbackRun: - self._firstFallbackRun = False - _logger.warning( - "scipy.spatial.LinearNDInterpolator not available: " - "Scatter plot interpolator initialisation can freeze the GUI.") - - while True: - with self._lock: - data = self._pendingData - self._pendingData = None - - if data in (None, 'cancelled'): - return - - points, values = data - - startTime = time.time() - try: - delaunay = Delaunay(points) - except: - _logger.warning( - "Cannot triangulate scatter data") - else: - with self._lock: - data = self._pendingData - - if data is not None: # Break point - _logger.info('Interpolator discarded after %f s', - time.time() - startTime) - else: - - x, y = points.T - triangulation = matplotlib.tri.Triangulation( - x, y, triangles=delaunay.simplices) - - interpolator = matplotlib.tri.LinearTriInterpolator( - triangulation, values) - - with self._lock: - data = self._pendingData - - if data is not None: - _logger.info('Interpolator discarded after %f s', - time.time() - startTime) - else: - # No other processing requested: emit the signal - _logger.info("Interpolator initialised in %f s", - time.time() - startTime) - - # Wrap interpolator to have same API as scipy's one - def wrapper(points): - return interpolator(*points.T) - - self.sigInterpolatorReady.emit( - (points, values, wrapper)) - - def run_scipy(self): - """Run the init of the scatter interpolator""" - while True: - with self._lock: - data = self._pendingData - self._pendingData = None - - if data in (None, 'cancelled'): - return - - points, values = data - - startTime = time.time() - try: - interpolator = LinearNDInterpolator(points, values) - except: - _logger.warning( - "Cannot initialise scatter profile interpolator") - else: - with self._lock: - data = self._pendingData - - if data is not None: # Break point - _logger.info('Interpolator discarded after %f s', - time.time() - startTime) - else: - # First call takes a while, do it here - interpolator([(0., 0.)]) - - with self._lock: - data = self._pendingData - - if data is not None: - _logger.info('Interpolator discarded after %f s', - time.time() - startTime) - else: - # No other processing requested: emit the signal - _logger.info("Interpolator initialised in %f s", - time.time() - startTime) - self.sigInterpolatorReady.emit( - (points, values, interpolator)) - - class ScatterProfileToolBar(_BaseProfileToolBar): """QToolBar providing scatter plot profiling tools @@ -233,49 +55,13 @@ class ScatterProfileToolBar(_BaseProfileToolBar): super(ScatterProfileToolBar, self).__init__(parent, plot, title) self.__nPoints = 1024 - self.__interpolator = None - self.__interpolatorCache = None # points, values, interpolator - - self.__initThread = _InterpolatorInitThread() - self.destroyed.connect(self.__initThread.discard) - self.__initThread.sigInterpolatorReady.connect( - self.__interpolatorReady) - - roiManager = self._getRoiManager() - if roiManager is None: - _logger.error( - "Error during scatter profile toolbar initialisation") - else: - roiManager.sigInteractiveModeStarted.connect( - self.__interactionStarted) - roiManager.sigInteractiveModeFinished.connect( - self.__interactionFinished) - if roiManager.isStarted(): - self.__interactionStarted(roiManager.getCurrentInteractionModeRoiClass()) - - def __interactionStarted(self, roiClass): - """Handle start of ROI interaction""" - plot = self.getPlotWidget() - if plot is None: - return - - plot.sigActiveScatterChanged.connect(self.__activeScatterChanged) - - scatter = plot._getActiveItem(kind='scatter') - legend = None if scatter is None else scatter.getLegend() - self.__activeScatterChanged(None, legend) + self.__scatterRef = None + self.__futureInterpolator = None - def __interactionFinished(self): - """Handle end of ROI interaction""" plot = self.getPlotWidget() - if plot is None: - return - - plot.sigActiveScatterChanged.disconnect(self.__activeScatterChanged) - - scatter = plot._getActiveItem(kind='scatter') - legend = None if scatter is None else scatter.getLegend() - self.__activeScatterChanged(legend, None) + if plot is not None: + self._setScatterItem(plot._getActiveItem(kind='scatter')) + plot.sigActiveScatterChanged.connect(self.__activeScatterChanged) def __activeScatterChanged(self, previous, legend): """Handle change of active scatter @@ -283,35 +69,37 @@ class ScatterProfileToolBar(_BaseProfileToolBar): :param Union[str,None] previous: :param Union[str,None] legend: """ - self.__initThread.cancel() + plot = self.getPlotWidget() + if plot is None or legend is None: + scatter = None + else: + scatter = plot.getScatter(legend) + self._setScatterItem(scatter) - # Reset interpolator - self.__interpolator = None + def _getScatterItem(self): + """Returns the scatter item currently handled by this tool. - plot = self.getPlotWidget() - if plot is None: - _logger.error("Associated PlotWidget no longer exists") + :rtype: ~silx.gui.plot.items.Scatter + """ + return None if self.__scatterRef is None else self.__scatterRef() + def _setScatterItem(self, scatter): + """Set the scatter tracked by this tool + + :param Union[None,silx.gui.plot.items.Scatter] scatter: + """ + self.__futureInterpolator = None # Reset currently expected future + + previousScatter = self._getScatterItem() + if previousScatter is not None: + previousScatter.sigItemChanged.disconnect( + self.__scatterItemChanged) + + if scatter is None: + self.__scatterRef = None else: - if previous is not None: # Disconnect signal - scatter = plot.getScatter(previous) - if scatter is not None: - scatter.sigItemChanged.disconnect( - self.__scatterItemChanged) - - if legend is not None: - scatter = plot.getScatter(legend) - if scatter is None: - _logger.error("Cannot retrieve active scatter") - - else: - scatter.sigItemChanged.connect(self.__scatterItemChanged) - points = numpy.transpose(numpy.array(( - scatter.getXData(copy=False), - scatter.getYData(copy=False)))) - values = scatter.getValueData(copy=False) - - self.__updateInterpolator(points, values) + self.__scatterRef = weakref.ref(scatter) + scatter.sigItemChanged.connect(self.__scatterItemChanged) # Refresh profile self.updateProfile() @@ -322,49 +110,15 @@ class ScatterProfileToolBar(_BaseProfileToolBar): :param ItemChangedType event: """ if event == items.ItemChangedType.DATA: - self.__interpolator = None - scatter = self.sender() - if scatter is None: - _logger.error("Cannot retrieve updated scatter item") - - else: - points = numpy.transpose(numpy.array(( - scatter.getXData(copy=False), - scatter.getYData(copy=False)))) - values = scatter.getValueData(copy=False) - - self.__updateInterpolator(points, values) - - # Handle interpolator init thread - - def __updateInterpolator(self, points, values): - """Update used interpolator with new data""" - if (self.__interpolatorCache is not None and - len(points) == len(self.__interpolatorCache[0]) and - numpy.all(numpy.equal(self.__interpolatorCache[0], points)) and - numpy.all(numpy.equal(self.__interpolatorCache[1], values))): - # Reuse previous interpolator - _logger.info( - 'Scatter changed: Reuse previous interpolator') - self.__interpolator = self.__interpolatorCache[2] - - else: - # Interpolator needs update: Start background processing - _logger.info( - 'Scatter changed: Rebuild interpolator') - self.__interpolator = None - self.__interpolatorCache = None - self.__initThread.request(points, values) - - def __interpolatorReady(self, data): - """Handle end of init interpolator thread""" - points, values, interpolator = data - self.__interpolator = interpolator - self.__interpolatorCache = None if interpolator is None else data - self.updateProfile() + self.updateProfile() # Refresh profile def hasPendingOperations(self): - return self.__initThread.isRunning() + """Returns True if waiting for an interpolator to be ready + + :rtype: bool + """ + return (self.__futureInterpolator is not None and + not self.__futureInterpolator.done()) # Number of points @@ -383,8 +137,9 @@ class ScatterProfileToolBar(_BaseProfileToolBar): npoints = int(npoints) if npoints < 1: raise ValueError("Unsupported number of points: %d" % npoints) - else: + elif npoints != self.__nPoints: self.__nPoints = npoints + self.updateProfile() # Overridden methods @@ -400,11 +155,16 @@ class ScatterProfileToolBar(_BaseProfileToolBar): """ if self.hasPendingOperations(): return 'Pre-processing data...' - else: return super(ScatterProfileToolBar, self).computeProfileTitle( x0, y0, x1, y1) + def __futureDone(self, future): + """Handle completion of the interpolator creation""" + if future is self.__futureInterpolator: + # Only handle future callbacks for the current one + submitToQtMainThread(self.updateProfile) + def computeProfile(self, x0, y0, x1, y1): """Compute corresponding profile @@ -414,16 +174,32 @@ class ScatterProfileToolBar(_BaseProfileToolBar): :param float y1: Profile end point Y coord :return: (points, values) profile data or None """ - if self.__interpolator is None: + scatter = self._getScatterItem() + if scatter is None or self.hasPendingOperations(): return None - nPoints = self.getNPoints() + # Lazy async request of the interpolator + future = scatter._getInterpolator() + if future is not self.__futureInterpolator: + # First time we request this interpolator + self.__futureInterpolator = future + if not future.done(): + future.add_done_callback(self.__futureDone) + return None + + if future.cancelled() or future.exception() is not None: + return None # Something went wrong + interpolator = future.result() + if interpolator is None: + return None # Cannot init an interpolator + + nPoints = self.getNPoints() points = numpy.transpose(( numpy.linspace(x0, x1, nPoints, endpoint=True), numpy.linspace(y0, y1, nPoints, endpoint=True))) - values = self.__interpolator(points) + values = interpolator(points) if not numpy.any(numpy.isfinite(values)): return None # Profile outside convex hull diff --git a/silx/gui/plot/tools/roi.py b/silx/gui/plot/tools/roi.py index 98295ba..eb933a0 100644 --- a/silx/gui/plot/tools/roi.py +++ b/silx/gui/plot/tools/roi.py @@ -106,6 +106,9 @@ class RegionOfInterestManager(qt.QObject): self._rois = [] # List of ROIs self._drawnROI = None # New ROI being currently drawn + # Handle unique selection of interaction mode action + self._actionGroup = qt.QActionGroup(self) + self._roiClass = None self._color = rgba('red') @@ -158,6 +161,8 @@ class RegionOfInterestManager(qt.QObject): action.setChecked(self.getCurrentInteractionModeRoiClass() is roiClass) action.setToolTip(text) + self._actionGroup.addAction(action) + action.triggered[bool].connect(functools.partial( WeakMethodProxy(self._modeActionTriggered), roiClass=roiClass)) self._modeActions[roiClass] = action @@ -171,9 +176,6 @@ class RegionOfInterestManager(qt.QObject): """ if checked: self.start(roiClass) - else: # Keep action checked - action = self.sender() - action.setChecked(True) def _updateModeActions(self): """Check/Uncheck action corresponding to current mode""" @@ -781,9 +783,9 @@ class RegionOfInterestTableWidget(qt.QTableWidget): super(RegionOfInterestTableWidget, self).__init__(parent) self._roiManagerRef = None - self.setColumnCount(5) - self.setHorizontalHeaderLabels( - ['Label', 'Edit', 'Kind', 'Coordinates', '']) + headers = ['Label', 'Edit', 'Kind', 'Coordinates', ''] + self.setColumnCount(len(headers)) + self.setHorizontalHeaderLabels(headers) horizontalHeader = self.horizontalHeader() horizontalHeader.setDefaultAlignment(qt.Qt.AlignLeft) @@ -815,9 +817,10 @@ class RegionOfInterestTableWidget(qt.QTableWidget): manager = self.getRegionOfInterestManager() roi = manager.getRois()[index] else: - roi = None + return if column == 0: + roi.setVisible(item.checkState() == qt.Qt.Checked) roi.setLabel(item.text()) elif column == 1: roi.setEditable( @@ -884,11 +887,13 @@ class RegionOfInterestTableWidget(qt.QTableWidget): for index, roi in enumerate(rois): baseFlags = qt.Qt.ItemIsSelectable | qt.Qt.ItemIsEnabled - # Label + # Label and visible label = roi.getLabel() item = qt.QTableWidgetItem(label) - item.setFlags(baseFlags | qt.Qt.ItemIsEditable) + item.setFlags(baseFlags | qt.Qt.ItemIsEditable | qt.Qt.ItemIsUserCheckable) item.setData(qt.Qt.UserRole, index) + item.setCheckState( + qt.Qt.Checked if roi.isVisible() else qt.Qt.Unchecked) self.setItem(index, 0, item) # Editable diff --git a/silx/gui/plot/tools/test/testScatterProfileToolBar.py b/silx/gui/plot/tools/test/testScatterProfileToolBar.py index 0f4b668..714746a 100644 --- a/silx/gui/plot/tools/test/testScatterProfileToolBar.py +++ b/silx/gui/plot/tools/test/testScatterProfileToolBar.py @@ -101,6 +101,7 @@ class TestScatterProfileToolBar(TestCaseQt, ParametricTestCase): self.qWait(200) if not self.profile.hasPendingOperations(): break + self.qapp.processEvents() self.assertIsNotNone(self.profile.getProfileValues()) points = self.profile.getProfilePoints() diff --git a/silx/gui/plot/tools/test/testTools.py b/silx/gui/plot/tools/test/testTools.py index f4adda0..70c8105 100644 --- a/silx/gui/plot/tools/test/testTools.py +++ b/silx/gui/plot/tools/test/testTools.py @@ -1,7 +1,7 @@ # coding: utf-8 # /*########################################################################## # -# Copyright (c) 2016-2018 European Synchrotron Radiation Facility +# Copyright (c) 2016-2019 European Synchrotron Radiation Facility # # Permission is hereby granted, free of charge, to any person obtaining a copy # of this software and associated documentation files (the "Software"), to deal @@ -41,34 +41,6 @@ from silx.gui.plot import tools from silx.gui.plot.test.utils import PlotWidgetTestCase -# Makes sure a QApplication exists -_qapp = qt.QApplication.instance() or qt.QApplication([]) - - -def _tearDownDocTest(docTest): - """Tear down to use for test from docstring. - - Checks that plot widget is displayed - """ - plot = docTest.globs['plot'] - qWaitForWindowExposedAndActivate(plot) - plot.setAttribute(qt.Qt.WA_DeleteOnClose) - plot.close() - del plot - -# Disable doctest because of -# "NameError: name 'numpy' is not defined" -# -# import doctest -# positionInfoTestSuite = doctest.DocTestSuite( -# PlotTools, tearDown=_tearDownDocTest, -# optionflags=doctest.ELLIPSIS) -# """Test suite of tests from PlotTools docstrings. -# -# Test PositionInfo and ProfileToolBar docstrings. -# """ - - class TestPositionInfo(PlotWidgetTestCase): """Tests for PositionInfo widget.""" diff --git a/silx/gui/plot/tools/toolbars.py b/silx/gui/plot/tools/toolbars.py index 28fb7f9..04d0cfc 100644 --- a/silx/gui/plot/tools/toolbars.py +++ b/silx/gui/plot/tools/toolbars.py @@ -1,7 +1,7 @@ # coding: utf-8 # /*########################################################################## # -# Copyright (c) 2018 European Synchrotron Radiation Facility +# Copyright (c) 2018-2019 European Synchrotron Radiation Facility # # Permission is hereby granted, free of charge, to any person obtaining a copy # of this software and associated documentation files (the "Software"), to deal @@ -34,6 +34,7 @@ from ... import qt from .. import actions from ..PlotWidget import PlotWidget from .. import PlotToolButtons +from ....utils.deprecation import deprecated class InteractiveModeToolBar(qt.QToolBar): @@ -302,9 +303,9 @@ class ScatterToolBar(qt.QToolBar): parent=self, plot=plot) self.addAction(self._colormapAction) - self._symbolToolButton = PlotToolButtons.SymbolToolButton( - parent=self, plot=plot) - self.addWidget(self._symbolToolButton) + self._visualizationToolButton = \ + PlotToolButtons.ScatterVisualizationToolButton(parent=self, plot=plot) + self.addWidget(self._visualizationToolButton) def getResetZoomAction(self): """Returns the QAction to reset the zoom. @@ -341,16 +342,21 @@ class ScatterToolBar(qt.QToolBar): """ return self._colormapAction - def getSymbolToolButton(self): - """Returns the QToolButton controlling symbol size and marker. - - :rtype: SymbolToolButton - """ - return self._symbolToolButton - def getKeepDataAspectRatioButton(self): """Returns the QToolButton controlling data aspect ratio. :rtype: QToolButton """ return self._keepDataAspectRatioButton + + def getScatterVisualizationToolButton(self): + """Returns the QToolButton controlling the visualization mode. + + :rtype: ScatterVisualizationToolButton + """ + return self._visualizationToolButton + + @deprecated(replacement='getScatterVisualizationToolButton', + since_version='0.11.0') + def getSymbolToolButton(self): + return self.getScatterVisualizationToolButton() |