summaryrefslogtreecommitdiff
path: root/src/silx/gui/plot
diff options
context:
space:
mode:
Diffstat (limited to 'src/silx/gui/plot')
-rw-r--r--src/silx/gui/plot/AlphaSlider.py299
-rw-r--r--src/silx/gui/plot/ColorBar.py900
-rw-r--r--src/silx/gui/plot/CompareImages.py954
-rw-r--r--src/silx/gui/plot/ComplexImageView.py508
-rw-r--r--src/silx/gui/plot/CurvesROIWidget.py1572
-rw-r--r--src/silx/gui/plot/ImageStack.py583
-rw-r--r--src/silx/gui/plot/ImageView.py1121
-rw-r--r--src/silx/gui/plot/Interaction.py352
-rw-r--r--src/silx/gui/plot/ItemsSelectionDialog.py293
-rwxr-xr-xsrc/silx/gui/plot/LegendSelector.py1042
-rw-r--r--src/silx/gui/plot/LimitsHistory.py82
-rw-r--r--src/silx/gui/plot/MaskToolsWidget.py914
-rw-r--r--src/silx/gui/plot/PlotEvents.py184
-rw-r--r--src/silx/gui/plot/PlotInteraction.py1885
-rw-r--r--src/silx/gui/plot/PlotToolButtons.py554
-rwxr-xr-xsrc/silx/gui/plot/PlotWidget.py3789
-rw-r--r--src/silx/gui/plot/PlotWindow.py1052
-rw-r--r--src/silx/gui/plot/PrintPreviewToolButton.py388
-rw-r--r--src/silx/gui/plot/Profile.py254
-rw-r--r--src/silx/gui/plot/ROIStatsWidget.py807
-rw-r--r--src/silx/gui/plot/ScatterMaskToolsWidget.py643
-rw-r--r--src/silx/gui/plot/ScatterView.py418
-rw-r--r--src/silx/gui/plot/StackView.py1248
-rw-r--r--src/silx/gui/plot/StatsWidget.py1720
-rw-r--r--src/silx/gui/plot/_BaseMaskToolsWidget.py1311
-rw-r--r--src/silx/gui/plot/__init__.py78
-rw-r--r--src/silx/gui/plot/_utils/__init__.py92
-rw-r--r--src/silx/gui/plot/_utils/dtime_ticklayout.py480
-rw-r--r--src/silx/gui/plot/_utils/panzoom.py366
-rw-r--r--src/silx/gui/plot/_utils/test/__init__.py23
-rw-r--r--src/silx/gui/plot/_utils/test/test_dtime_ticklayout.py87
-rw-r--r--src/silx/gui/plot/_utils/test/test_ticklayout.py77
-rw-r--r--src/silx/gui/plot/_utils/ticklayout.py266
-rw-r--r--src/silx/gui/plot/actions/PlotAction.py84
-rw-r--r--src/silx/gui/plot/actions/PlotToolAction.py156
-rw-r--r--src/silx/gui/plot/actions/__init__.py41
-rwxr-xr-xsrc/silx/gui/plot/actions/control.py760
-rw-r--r--src/silx/gui/plot/actions/fit.py442
-rw-r--r--src/silx/gui/plot/actions/histogram.py548
-rw-r--r--src/silx/gui/plot/actions/io.py874
-rw-r--r--src/silx/gui/plot/actions/medfilt.py140
-rw-r--r--src/silx/gui/plot/actions/mode.py139
-rwxr-xr-xsrc/silx/gui/plot/backends/BackendBase.py606
-rwxr-xr-xsrc/silx/gui/plot/backends/BackendMatplotlib.py1726
-rwxr-xr-xsrc/silx/gui/plot/backends/BackendOpenGL.py1660
-rw-r--r--src/silx/gui/plot/backends/__init__.py28
-rw-r--r--src/silx/gui/plot/backends/glutils/GLPlotCurve.py1494
-rw-r--r--src/silx/gui/plot/backends/glutils/GLPlotFrame.py1399
-rw-r--r--src/silx/gui/plot/backends/glutils/GLPlotImage.py789
-rw-r--r--src/silx/gui/plot/backends/glutils/GLPlotItem.py105
-rw-r--r--src/silx/gui/plot/backends/glutils/GLPlotTriangles.py203
-rw-r--r--src/silx/gui/plot/backends/glutils/GLSupport.py174
-rw-r--r--src/silx/gui/plot/backends/glutils/GLText.py297
-rw-r--r--src/silx/gui/plot/backends/glutils/GLTexture.py269
-rw-r--r--src/silx/gui/plot/backends/glutils/PlotImageFile.py159
-rw-r--r--src/silx/gui/plot/backends/glutils/__init__.py45
-rw-r--r--src/silx/gui/plot/items/__init__.py80
-rw-r--r--src/silx/gui/plot/items/_arc_roi.py992
-rw-r--r--src/silx/gui/plot/items/_band_roi.py376
-rw-r--r--src/silx/gui/plot/items/_pick.py71
-rw-r--r--src/silx/gui/plot/items/_roi_base.py874
-rw-r--r--src/silx/gui/plot/items/axis.py581
-rw-r--r--src/silx/gui/plot/items/complex.py382
-rw-r--r--src/silx/gui/plot/items/core.py1832
-rw-r--r--src/silx/gui/plot/items/curve.py347
-rw-r--r--src/silx/gui/plot/items/histogram.py413
-rw-r--r--src/silx/gui/plot/items/image.py673
-rw-r--r--src/silx/gui/plot/items/image_aggregated.py230
-rwxr-xr-xsrc/silx/gui/plot/items/marker.py335
-rw-r--r--src/silx/gui/plot/items/roi.py1561
-rw-r--r--src/silx/gui/plot/items/scatter.py1060
-rw-r--r--src/silx/gui/plot/items/shape.py393
-rw-r--r--src/silx/gui/plot/stats/__init__.py32
-rw-r--r--src/silx/gui/plot/stats/stats.py911
-rw-r--r--src/silx/gui/plot/stats/statshandler.py222
-rw-r--r--src/silx/gui/plot/test/__init__.py23
-rw-r--r--src/silx/gui/plot/test/conftest.py43
-rw-r--r--src/silx/gui/plot/test/testAlphaSlider.py189
-rw-r--r--src/silx/gui/plot/test/testAxis.py147
-rw-r--r--src/silx/gui/plot/test/testColorBar.py317
-rw-r--r--src/silx/gui/plot/test/testCompareImages.py236
-rw-r--r--src/silx/gui/plot/test/testComplexImageView.py82
-rw-r--r--src/silx/gui/plot/test/testCurvesROIWidget.py489
-rw-r--r--src/silx/gui/plot/test/testImageStack.py234
-rw-r--r--src/silx/gui/plot/test/testImageView.py201
-rw-r--r--src/silx/gui/plot/test/testInteraction.py79
-rw-r--r--src/silx/gui/plot/test/testItem.py565
-rw-r--r--src/silx/gui/plot/test/testLegendSelector.py139
-rw-r--r--src/silx/gui/plot/test/testLimitConstraints.py113
-rw-r--r--src/silx/gui/plot/test/testMaskToolsWidget.py321
-rw-r--r--src/silx/gui/plot/test/testPixelIntensityHistoAction.py149
-rw-r--r--src/silx/gui/plot/test/testPlotActions.py106
-rw-r--r--src/silx/gui/plot/test/testPlotInteraction.py232
-rwxr-xr-xsrc/silx/gui/plot/test/testPlotWidget.py2049
-rwxr-xr-xsrc/silx/gui/plot/test/testPlotWidgetActiveItem.py416
-rw-r--r--src/silx/gui/plot/test/testPlotWidgetDataMargins.py135
-rw-r--r--src/silx/gui/plot/test/testPlotWidgetNoBackend.py685
-rw-r--r--src/silx/gui/plot/test/testPlotWindow.py175
-rw-r--r--src/silx/gui/plot/test/testRoiStatsWidget.py260
-rw-r--r--src/silx/gui/plot/test/testSaveAction.py131
-rw-r--r--src/silx/gui/plot/test/testScatterMaskToolsWidget.py307
-rw-r--r--src/silx/gui/plot/test/testScatterView.py122
-rw-r--r--src/silx/gui/plot/test/testStackView.py266
-rw-r--r--src/silx/gui/plot/test/testStats.py1143
-rw-r--r--src/silx/gui/plot/test/testUtilsAxis.py239
-rw-r--r--src/silx/gui/plot/test/utils.py92
-rw-r--r--src/silx/gui/plot/tools/CurveLegendsWidget.py245
-rw-r--r--src/silx/gui/plot/tools/LimitsToolBar.py123
-rw-r--r--src/silx/gui/plot/tools/PlotToolButton.py92
-rw-r--r--src/silx/gui/plot/tools/PositionInfo.py357
-rw-r--r--src/silx/gui/plot/tools/RadarView.py359
-rw-r--r--src/silx/gui/plot/tools/RulerToolButton.py183
-rw-r--r--src/silx/gui/plot/tools/__init__.py49
-rw-r--r--src/silx/gui/plot/tools/compare/__init__.py29
-rw-r--r--src/silx/gui/plot/tools/compare/core.py198
-rw-r--r--src/silx/gui/plot/tools/compare/profile.py173
-rw-r--r--src/silx/gui/plot/tools/compare/statusbar.py218
-rw-r--r--src/silx/gui/plot/tools/compare/toolbar.py390
-rw-r--r--src/silx/gui/plot/tools/menus.py93
-rw-r--r--src/silx/gui/plot/tools/profile/ScatterProfileToolBar.py44
-rw-r--r--src/silx/gui/plot/tools/profile/__init__.py37
-rw-r--r--src/silx/gui/plot/tools/profile/core.py567
-rw-r--r--src/silx/gui/plot/tools/profile/editors.py309
-rw-r--r--src/silx/gui/plot/tools/profile/manager.py1109
-rw-r--r--src/silx/gui/plot/tools/profile/rois.py1176
-rw-r--r--src/silx/gui/plot/tools/profile/toolbar.py172
-rw-r--r--src/silx/gui/plot/tools/roi.py1515
-rw-r--r--src/silx/gui/plot/tools/test/__init__.py23
-rw-r--r--src/silx/gui/plot/tools/test/testCurveLegendsWidget.py117
-rw-r--r--src/silx/gui/plot/tools/test/testProfile.py604
-rw-r--r--src/silx/gui/plot/tools/test/testRoiCore.py535
-rw-r--r--src/silx/gui/plot/tools/test/testRoiItems.py313
-rw-r--r--src/silx/gui/plot/tools/test/testScatterProfileToolBar.py185
-rw-r--r--src/silx/gui/plot/tools/test/testTools.py135
-rw-r--r--src/silx/gui/plot/tools/toolbars.py357
-rw-r--r--src/silx/gui/plot/utils/__init__.py29
-rw-r--r--src/silx/gui/plot/utils/axis.py400
-rw-r--r--src/silx/gui/plot/utils/intersections.py106
138 files changed, 67772 insertions, 0 deletions
diff --git a/src/silx/gui/plot/AlphaSlider.py b/src/silx/gui/plot/AlphaSlider.py
new file mode 100644
index 0000000..8a0a711
--- /dev/null
+++ b/src/silx/gui/plot/AlphaSlider.py
@@ -0,0 +1,299 @@
+# /*##########################################################################
+#
+# Copyright (c) 2017-2021 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""This module defines slider widgets interacting with the transparency
+of an image on a :class:`PlotWidget`
+
+Classes:
+--------
+
+- :class:`BaseAlphaSlider` (abstract class)
+- :class:`NamedImageAlphaSlider`
+- :class:`ActiveImageAlphaSlider`
+
+Example:
+--------
+
+This widget can, for instance, be added to a plot toolbar.
+
+.. code-block:: python
+
+ import numpy
+ from silx.gui import qt
+ from silx.gui.plot import PlotWidget
+ from silx.gui.plot.ImageAlphaSlider import NamedImageAlphaSlider
+
+ app = qt.QApplication([])
+ pw = PlotWidget()
+
+ img0 = numpy.arange(200*150).reshape((200, 150))
+ pw.addImage(img0, legend="my background", z=0, origin=(50, 50))
+
+ x, y = numpy.meshgrid(numpy.linspace(-10, 10, 200),
+ numpy.linspace(-10, 5, 150),
+ indexing="ij")
+ img1 = numpy.asarray(numpy.sin(x * y) / (x * y),
+ dtype='float32')
+
+ pw.addImage(img1, legend="my data", z=1,
+ replace=False)
+
+ alpha_slider = NamedImageAlphaSlider(parent=pw,
+ plot=pw,
+ legend="my data")
+ alpha_slider.setOrientation(qt.Qt.Horizontal)
+
+ toolbar = qt.QToolBar("plot", pw)
+ toolbar.addWidget(alpha_slider)
+ pw.addToolBar(toolbar)
+
+ pw.show()
+ app.exec()
+
+"""
+
+__authors__ = ["P. Knobel"]
+__license__ = "MIT"
+__date__ = "24/03/2017"
+
+import logging
+
+from silx.gui import qt
+
+_logger = logging.getLogger(__name__)
+
+
+class BaseAlphaSlider(qt.QSlider):
+ """Slider widget to be used in a plot toolbar to control the
+ transparency of a plot primitive (image, scatter or curve).
+
+ Internally, the slider stores its state as an integer between
+ 0 and 255. This is the value emitted by the :attr:`valueChanged`
+ signal.
+
+ The method :meth:`getAlpha` returns the corresponding opacity/alpha
+ as a float between 0. and 1. (with a step of :math:`\frac{1}{255}`).
+
+ You must subclass this class and implement :meth:`getItem`.
+ """
+
+ sigAlphaChanged = qt.Signal(float)
+ """Emits the alpha value when the slider's value changes,
+ as a float between 0. and 1."""
+
+ def __init__(self, parent=None, plot=None):
+ """
+
+ :param parent: Parent QWidget
+ :param plot: Parent plot widget
+ """
+ assert plot is not None
+ super(BaseAlphaSlider, self).__init__(parent)
+
+ self.plot = plot
+
+ self.setRange(0, 255)
+
+ # if already connected to an item, use its alpha as initial value
+ if self.getItem() is None:
+ self.setValue(255)
+ self.setEnabled(False)
+ else:
+ alpha = self.getItem().getAlpha()
+ self.setValue(round(255 * alpha))
+
+ self.valueChanged.connect(self._valueChanged)
+
+ def getItem(self):
+ """You must implement this class to define which item
+ to work on. It must return an item that inherits
+ :class:`silx.gui.plot.items.core.AlphaMixIn`.
+
+ :return: Item on which to operate, or None
+ :rtype: :class:`silx.plot.items.Item`
+ """
+ raise NotImplementedError(
+ "BaseAlphaSlider must be subclassed to " + "implement getItem()"
+ )
+
+ def getAlpha(self):
+ """Get the opacity, as a float between 0. and 1.
+
+ :return: Alpha value in [0., 1.]
+ :rtype: float
+ """
+ return self.value() / 255.0
+
+ def _valueChanged(self, value):
+ self._updateItem()
+ self.sigAlphaChanged.emit(value / 255.0)
+
+ def _updateItem(self):
+ """Update the item's alpha channel."""
+ item = self.getItem()
+ if item is not None:
+ item.setAlpha(self.getAlpha())
+
+
+class ActiveImageAlphaSlider(BaseAlphaSlider):
+ """Slider widget to be used in a plot toolbar to control the
+ transparency of the **active image**.
+
+ :param parent: Parent QWidget
+ :param plot: Plot on which to operate
+
+ See documentation of :class:`BaseAlphaSlider`
+ """
+
+ def __init__(self, parent=None, plot=None):
+ """
+
+ :param parent: Parent QWidget
+ :param plot: Plot widget on which to operate
+ """
+ super(ActiveImageAlphaSlider, self).__init__(parent, plot)
+ plot.sigActiveImageChanged.connect(self._activeImageChanged)
+
+ def getItem(self):
+ return self.plot.getActiveImage()
+
+ def _activeImageChanged(self, previous, new):
+ """Activate or deactivate slider depending on presence of a new
+ active image.
+ Apply transparency value to new active image.
+
+ :param previous: Legend of previous active image, or None
+ :param new: Legend of new active image, or None
+ """
+ if new is not None and not self.isEnabled():
+ self.setEnabled(True)
+ elif new is None and self.isEnabled():
+ self.setEnabled(False)
+
+ self._updateItem()
+
+
+class NamedItemAlphaSlider(BaseAlphaSlider):
+ """Slider widget to be used in a plot toolbar to control the
+ transparency of an item (defined by its kind and legend).
+
+ :param parent: Parent QWidget
+ :param plot: Plot on which to operate
+ :param str kind: Kind of item whose transparency is to be
+ controlled: "scatter", "image" or "curve".
+ :param str legend: Legend of item whose transparency is to be
+ controlled.
+ """
+
+ def __init__(self, parent=None, plot=None, kind=None, legend=None):
+ self._item_legend = legend
+ self._item_kind = kind
+
+ super(NamedItemAlphaSlider, self).__init__(parent, plot)
+
+ self._updateState()
+ plot.sigContentChanged.connect(self._onContentChanged)
+
+ def _onContentChanged(self, action, kind, legend):
+ if legend == self._item_legend and kind == self._item_kind:
+ if action == "add":
+ self.setEnabled(True)
+ elif action == "remove":
+ self.setEnabled(False)
+
+ def _updateState(self):
+ """Enable or disable widget based on item's availability."""
+ if self.getItem() is not None:
+ self.setEnabled(True)
+ else:
+ self.setEnabled(False)
+
+ def getItem(self):
+ """Return plot item currently associated to this widget (can be
+ a curve, an image, a scatter...)
+
+ :rtype: subclass of :class:`silx.gui.plot.items.Item`"""
+ if self._item_legend is None or self._item_kind is None:
+ return None
+ return self.plot._getItem(kind=self._item_kind, legend=self._item_legend)
+
+ def setLegend(self, legend):
+ """Associate a different item (of the same kind) to the slider.
+
+ :param legend: New legend of item whose transparency is to be
+ controlled.
+ """
+ self._item_legend = legend
+ self._updateState()
+
+ def getLegend(self):
+ """Return legend of the item currently controlled by this slider.
+
+ :return: Image legend associated to the slider
+ """
+ return self._item_kind
+
+ def setItemKind(self, legend):
+ """Associate a different item (of the same kind) to the slider.
+
+ :param legend: New legend of item whose transparency is to be
+ controlled.
+ """
+ self._item_legend = legend
+ self._updateState()
+
+ def getItemKind(self):
+ """Return kind of the item currently controlled by this slider.
+
+ :return: Item kind ("image", "scatter"...)
+ :rtype: str on None
+ """
+ return self._item_kind
+
+
+class NamedImageAlphaSlider(NamedItemAlphaSlider):
+ """Slider widget to be used in a plot toolbar to control the
+ transparency of an image (defined by its legend).
+
+ :param parent: Parent QWidget
+ :param plot: Plot on which to operate
+ :param str legend: Legend of image whose transparency is to be
+ controlled.
+ """
+
+ def __init__(self, parent=None, plot=None, legend=None):
+ NamedItemAlphaSlider.__init__(self, parent, plot, kind="image", legend=legend)
+
+
+class NamedScatterAlphaSlider(NamedItemAlphaSlider):
+ """Slider widget to be used in a plot toolbar to control the
+ transparency of a scatter (defined by its legend).
+
+ :param parent: Parent QWidget
+ :param plot: Plot on which to operate
+ :param str legend: Legend of scatter whose transparency is to be
+ controlled.
+ """
+
+ def __init__(self, parent=None, plot=None, legend=None):
+ NamedItemAlphaSlider.__init__(self, parent, plot, kind="scatter", legend=legend)
diff --git a/src/silx/gui/plot/ColorBar.py b/src/silx/gui/plot/ColorBar.py
new file mode 100644
index 0000000..ee31f25
--- /dev/null
+++ b/src/silx/gui/plot/ColorBar.py
@@ -0,0 +1,900 @@
+# /*##########################################################################
+#
+# Copyright (c) 2016-2021 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""Module containing several widgets associated to a colormap.
+"""
+
+__authors__ = ["H. Payno", "T. Vincent"]
+__license__ = "MIT"
+__date__ = "24/04/2018"
+
+
+import logging
+import weakref
+import numpy
+
+from ._utils import ticklayout
+from .. import qt
+from ..qt import inspect as qt_inspect
+from silx.gui import colors
+from silx.math.colormap import LogarithmicNormalization
+
+
+_logger = logging.getLogger(__name__)
+
+
+class ColorBarWidget(qt.QWidget):
+ """Colorbar widget displaying a colormap
+
+ It uses a description of colormap as dict compatible with :class:`Plot`.
+
+ .. image:: img/linearColorbar.png
+ :width: 80px
+ :align: center
+
+ To run the following sample code, a QApplication must be initialized.
+
+ >>> from silx.gui.plot import Plot2D
+ >>> from silx.gui.plot.ColorBar import ColorBarWidget
+
+ >>> plot = Plot2D() # Create a plot widget
+ >>> plot.show()
+
+ >>> colorbar = ColorBarWidget(plot=plot, legend='Colormap') # Associate the colorbar with it
+ >>> colorbar.show()
+
+ Initializer parameters:
+
+ :param parent: See :class:`QWidget`
+ :param plot: PlotWidget the colorbar is attached to (optional)
+ :param str legend: the label to set to the colorbar
+ """
+
+ sigVisibleChanged = qt.Signal(bool)
+ """Emitted when the property `visible` have changed."""
+
+ def __init__(self, parent=None, plot=None, legend=None):
+ self._isConnected = False
+ self._plotRef = None
+ self._colormap = None
+ self._data = None
+
+ super(ColorBarWidget, self).__init__(parent)
+
+ self.__buildGUI()
+ self.setLegend(legend)
+ self.setPlot(plot)
+
+ def __buildGUI(self):
+ self.setLayout(qt.QHBoxLayout())
+
+ # create color scale widget
+ self._colorScale = ColorScaleBar(parent=self, colormap=None)
+ self.layout().addWidget(self._colorScale)
+
+ # legend (is the right group)
+ self.legend = _VerticalLegend("", self)
+ self.layout().addWidget(self.legend)
+
+ self.layout().setSizeConstraint(qt.QLayout.SetMinAndMaxSize)
+
+ def getPlot(self):
+ """Returns the :class:`Plot` associated to this widget or None"""
+ return None if self._plotRef is None else self._plotRef()
+
+ def setPlot(self, plot):
+ """Associate a plot to the ColorBar
+
+ :param plot: the plot to associate with the colorbar.
+ If None will remove any connection with a previous plot.
+ """
+ self._disconnectPlot()
+ self._plotRef = None if plot is None else weakref.ref(plot)
+ self._connectPlot()
+
+ def _disconnectPlot(self):
+ """Disconnect from Plot signals"""
+ if self._isConnected:
+ self._isConnected = False
+ plot = self.getPlot()
+ if plot is not None and qt_inspect.isValid(plot):
+ plot.sigActiveImageChanged.disconnect(self._activeImageChanged)
+ plot.sigActiveScatterChanged.disconnect(self._activeScatterChanged)
+ plot.sigPlotSignal.disconnect(self._defaultColormapChanged)
+
+ def _connectPlot(self):
+ """Connect to Plot signals"""
+ plot = self.getPlot()
+ if plot is not None and not self._isConnected:
+ activeImageLegend = plot.getActiveImage(just_legend=True)
+ activeScatterLegend = plot.getActiveScatter(just_legend=True)
+ if activeImageLegend is None and activeScatterLegend is None:
+ # Show plot default colormap
+ self._syncWithDefaultColormap()
+ elif activeImageLegend is not None: # Show active image colormap
+ self._activeImageChanged(None, activeImageLegend)
+ elif activeScatterLegend is not None: # Show active scatter colormap
+ self._activeScatterChanged(None, activeScatterLegend)
+
+ plot.sigActiveImageChanged.connect(self._activeImageChanged)
+ plot.sigActiveScatterChanged.connect(self._activeScatterChanged)
+ plot.sigPlotSignal.connect(self._defaultColormapChanged)
+ self._isConnected = True
+
+ def setVisible(self, isVisible):
+ qt.QWidget.setVisible(self, isVisible)
+ self.sigVisibleChanged.emit(isVisible)
+
+ def showEvent(self, event):
+ self._connectPlot()
+
+ def hideEvent(self, event):
+ self._disconnectPlot()
+
+ def getColormap(self):
+ """Returns the colormap displayed in the colorbar.
+
+ :rtype: ~silx.gui.colors.Colormap
+ """
+ return self.getColorScaleBar().getColormap()
+
+ def setColormap(self, colormap, data=None):
+ """Set the colormap to be displayed.
+
+ :param ~silx.gui.colors.Colormap colormap:
+ The colormap to apply on the ColorBarWidget
+ :param Union[numpy.ndarray,~silx.gui.plot.items.ColormapMixin] data:
+ The data to display or item, needed if the colormap require an autoscale
+ """
+ self._data = data
+ self.getColorScaleBar().setColormap(colormap=colormap, data=data)
+ if self._colormap is not None:
+ self._colormap.sigChanged.disconnect(self._colormapHasChanged)
+ self._colormap = colormap
+ if self._colormap is not None:
+ self._colormap.sigChanged.connect(self._colormapHasChanged)
+
+ def _colormapHasChanged(self):
+ """handler of the Colormap.sigChanged signal"""
+ assert self._colormap is not None
+ self.setColormap(colormap=self._colormap, data=self._data)
+
+ def setLegend(self, legend):
+ """Set the legend displayed along the colorbar
+
+ :param str legend: The label
+ """
+ if legend is None or legend == "":
+ self.legend.hide()
+ self.legend.setText("")
+ else:
+ assert type(legend) is str
+ self.legend.show()
+ self.legend.setText(legend)
+
+ def getLegend(self):
+ """
+ Returns the legend displayed along the colorbar
+
+ :return: return the legend displayed along the colorbar
+ :rtype: str
+ """
+ return self.legend.text()
+
+ def _activeScatterChanged(self, previous, legend):
+ """Handle plot active scatter changed"""
+ plot = self.getPlot()
+
+ # Do not handle active scatter while there is an image
+ if plot.getActiveImage() is not None:
+ return
+
+ if legend is None: # No active scatter, display no colormap
+ self.setColormap(colormap=None)
+ return
+
+ # Sync with active scatter
+ scatter = plot.getActiveScatter()
+
+ self.setColormap(colormap=scatter.getColormap(), data=scatter)
+
+ def _activeImageChanged(self, previous, legend):
+ """Handle plot active image changed"""
+ plot = self.getPlot()
+
+ if legend is None: # No active image, try with active scatter
+ activeScatterLegend = plot.getActiveScatter(just_legend=True)
+ # No more active image, use active scatter if any
+ self._activeScatterChanged(None, activeScatterLegend)
+ else:
+ # Sync with active image
+ image = plot.getActiveImage()
+
+ # RGB(A) image, display default colormap
+ array = image.getData(copy=False)
+ if array.ndim != 2:
+ self.setColormap(colormap=None)
+ return
+
+ # data image, sync with image colormap
+ # do we need the copy here : used in the case we are changing
+ # vmin and vmax but should have already be done by the plot
+ self.setColormap(colormap=image.getColormap(), data=image)
+
+ def _defaultColormapChanged(self, event):
+ """Handle plot default colormap changed"""
+ if event["event"] == "defaultColormapChanged":
+ plot = self.getPlot()
+ if (
+ plot is not None
+ and plot.getActiveImage() is None
+ and plot.getActiveScatter() is None
+ ):
+ # No active item, take default colormap update into account
+ self._syncWithDefaultColormap()
+
+ def _syncWithDefaultColormap(self):
+ """Update colorbar according to plot default colormap"""
+ self.setColormap(self.getPlot().getDefaultColormap())
+
+ def getColorScaleBar(self):
+ """
+
+ :return: return the :class:`ColorScaleBar` used to display ColorScale
+ and ticks"""
+ return self._colorScale
+
+
+class _VerticalLegend(qt.QLabel):
+ """Display vertically the given text"""
+
+ def __init__(self, text, parent=None):
+ """
+
+ :param text: the legend
+ :param parent: the Qt parent if any
+ """
+ qt.QLabel.__init__(self, text, parent)
+ self.setLayout(qt.QVBoxLayout())
+ self.layout().setContentsMargins(0, 0, 0, 0)
+
+ def paintEvent(self, event):
+ painter = qt.QPainter(self)
+ painter.setFont(self.font())
+
+ painter.translate(0, self.rect().height())
+ painter.rotate(270)
+ newRect = qt.QRect(0, 0, self.rect().height(), self.rect().width())
+
+ painter.drawText(newRect, qt.Qt.AlignHCenter, self.text())
+
+ fm = qt.QFontMetrics(self.font())
+ preferedHeight = fm.width(self.text())
+ preferedWidth = fm.height()
+ self.setFixedWidth(preferedWidth)
+ self.setMinimumHeight(preferedHeight)
+
+
+class ColorScaleBar(qt.QWidget):
+ """This class is making the composition of a :class:`_ColorScale` and a
+ :class:`_TickBar`.
+
+ It is the simplest widget displaying ticks and colormap gradient.
+
+ .. image:: img/colorScaleBar.png
+ :width: 150px
+ :align: center
+
+ To run the following sample code, a QApplication must be initialized.
+
+ >>> colormap = Colormap(name='gray',
+ ... norm='log',
+ ... vmin=1,
+ ... vmax=100000,
+ ... )
+ >>> colorscale = ColorScaleBar(parent=None,
+ ... colormap=colormap )
+ >>> colorscale.show()
+
+ Initializer parameters :
+
+ :param colormap: the colormap to be displayed
+ :param parent: the Qt parent if any
+ :param displayTicksValues: display the ticks value or only the '-'
+ """
+
+ _TEXT_MARGIN = 5
+ """The tick bar need a margin to display all labels at the correct place.
+ So the ColorScale should have the same margin in order for both to fit"""
+
+ def __init__(self, parent=None, colormap=None, data=None, displayTicksValues=True):
+ super(ColorScaleBar, self).__init__(parent)
+
+ self.minVal = None
+ """Value set to the _minLabel"""
+ self.maxVal = None
+ """Value set to the _maxLabel"""
+
+ self.setLayout(qt.QGridLayout())
+
+ # create the left side group (ColorScale)
+ self.colorScale = _ColorScale(
+ colormap=colormap, data=data, parent=self, margin=ColorScaleBar._TEXT_MARGIN
+ )
+ if colormap:
+ vmin, vmax = colormap.getColormapRange(data)
+ normalizer = colormap._getNormalizer()
+ else:
+ vmin, vmax = colors.DEFAULT_MIN_LIN, colors.DEFAULT_MAX_LIN
+ normalizer = None
+
+ self.tickbar = _TickBar(
+ vmin=vmin,
+ vmax=vmax,
+ normalizer=normalizer,
+ parent=self,
+ displayValues=displayTicksValues,
+ margin=ColorScaleBar._TEXT_MARGIN,
+ )
+
+ self.layout().addWidget(self.tickbar, 1, 0, 1, 1, qt.Qt.AlignRight)
+ self.layout().addWidget(self.colorScale, 1, 1, qt.Qt.AlignLeft)
+
+ self.layout().setContentsMargins(0, 0, 0, 0)
+ self.layout().setSpacing(0)
+
+ # max label
+ self._maxLabel = qt.QLabel(str(1.0), parent=self)
+ self._maxLabel.setToolTip(str(0.0))
+ self.layout().addWidget(self._maxLabel, 0, 0, 1, 2, qt.Qt.AlignRight)
+
+ # min label
+ self._minLabel = qt.QLabel(str(0.0), parent=self)
+ self._minLabel.setToolTip(str(0.0))
+ self.layout().addWidget(self._minLabel, 2, 0, 1, 2, qt.Qt.AlignRight)
+
+ self.layout().setSizeConstraint(qt.QLayout.SetMinAndMaxSize)
+ self.layout().setColumnStretch(0, 1)
+ self.layout().setRowStretch(1, 1)
+
+ def getTickBar(self):
+ """
+
+ :return: the instanciation of the :class:`_TickBar`
+ """
+ return self.tickbar
+
+ def getColorScale(self):
+ """
+
+ :return: the instanciation of the :class:`_ColorScale`
+ """
+ return self.colorScale
+
+ def getColormap(self):
+ """
+
+ :returns: the colormap.
+ :rtype: :class:`.Colormap`
+ """
+ return self.colorScale.getColormap()
+
+ def setColormap(self, colormap, data=None):
+ """Set the new colormap to be displayed
+
+ :param Colormap colormap: the colormap to set
+ :param Union[numpy.ndarray,~silx.gui.plot.items.Item] data:
+ The data or item to display, needed if the colormap requires an autoscale
+ """
+ self.colorScale.setColormap(colormap, data)
+
+ if colormap is not None:
+ vmin, vmax = colormap.getColormapRange(data)
+ normalizer = colormap._getNormalizer()
+ else:
+ vmin, vmax = None, None
+ normalizer = None
+
+ self.tickbar.update(vmin=vmin, vmax=vmax, normalizer=normalizer)
+ self._setMinMaxLabels(vmin, vmax)
+
+ def setMinMaxVisible(self, val=True):
+ """Change visibility of the min label and the max label
+
+ :param val: if True, set the labels visible, otherwise set it not visible
+ """
+ self._minLabel.setVisible(val)
+ self._maxLabel.setVisible(val)
+
+ def _updateMinMax(self):
+ """Update the min and max label if we are in the case of the
+ configuration 'minMaxValueOnly'"""
+ if self.minVal is None:
+ text, tooltip = "", ""
+ else:
+ if self.minVal == 0 or 0 <= numpy.log10(abs(self.minVal)) < 7:
+ text = "%.7g" % self.minVal
+ else:
+ text = "%.2e" % self.minVal
+ tooltip = repr(self.minVal)
+
+ self._minLabel.setText(text)
+ self._minLabel.setToolTip(tooltip)
+
+ if self.maxVal is None:
+ text, tooltip = "", ""
+ else:
+ if self.maxVal == 0 or 0 <= numpy.log10(abs(self.maxVal)) < 7:
+ text = "%.7g" % self.maxVal
+ else:
+ text = "%.2e" % self.maxVal
+ tooltip = repr(self.maxVal)
+
+ self._maxLabel.setText(text)
+ self._maxLabel.setToolTip(tooltip)
+
+ def _setMinMaxLabels(self, minVal, maxVal):
+ """Change the value of the min and max labels to be displayed.
+
+ :param minVal: the minimal value of the TickBar (not str)
+ :param maxVal: the maximal value of the TickBar (not str)
+ """
+ # bad hack to try to display has much information as possible
+ self.minVal = minVal
+ self.maxVal = maxVal
+ self._updateMinMax()
+
+ def resizeEvent(self, event):
+ qt.QWidget.resizeEvent(self, event)
+ self._updateMinMax()
+
+
+class _ColorScale(qt.QWidget):
+ """Widget displaying the colormap colorScale.
+
+ Show matching value between the gradient color (from the colormap) at mouse
+ position and value.
+
+ .. image:: img/colorScale.png
+ :width: 20px
+ :align: center
+
+
+ To run the following sample code, a QApplication must be initialized.
+
+ >>> colormap = Colormap(name='viridis',
+ ... norm='log',
+ ... vmin=1,
+ ... vmax=100000,
+ ... )
+ >>> colorscale = ColorScale(parent=None,
+ ... colormap=colormap)
+ >>> colorscale.show()
+
+ Initializer parameters :
+
+ :param colormap: the colormap to be displayed
+ :param parent: the Qt parent if any
+ :param int margin: the top and left margin to apply.
+ :param Union[None,numpy.ndarray,~silx.gui.plot.items.ColormapMixin] data:
+ The data or item to use for getting the range for autoscale colormap.
+
+ .. warning:: Value drawing will be
+ done at the center of ticks. So if no margin is done your values
+ drawing might not be fully done for extrems values.
+ """
+
+ _NB_CONTROL_POINTS = 256
+
+ def __init__(self, colormap, parent=None, margin=5, data=None):
+ qt.QWidget.__init__(self, parent)
+ self._colormap = None
+ self.margin = margin
+ self.setColormap(colormap, data)
+
+ self.setLayout(qt.QVBoxLayout())
+ self.setSizePolicy(qt.QSizePolicy.Fixed, qt.QSizePolicy.Expanding)
+ # needed to get the mouse event without waiting for button click
+ self.setMouseTracking(True)
+ self.setMargin(margin)
+ self.setContentsMargins(0, 0, 0, 0)
+
+ self.setMinimumHeight(self._NB_CONTROL_POINTS // 2 + 2 * self.margin)
+ self.setFixedWidth(25)
+
+ def setColormap(self, colormap, data=None):
+ """Set the new colormap to be displayed
+
+ :param dict colormap: the colormap to set
+ :param Union[None,numpy.ndarray,~silx.gui.plot.items.ColormapMixin] data:
+ Optional data for which to compute colormap range.
+ """
+ self._colormap = colormap
+ self.setEnabled(colormap is not None)
+
+ if colormap is None:
+ self.vmin, self.vmax = None, None
+ else:
+ assert colormap.getNormalization() in colors.Colormap.NORMALIZATIONS
+ self.vmin, self.vmax = self._colormap.getColormapRange(data=data)
+ self._updateColorGradient()
+ self.update()
+
+ def getColormap(self):
+ """Returns the colormap
+
+ :rtype: :class:`.Colormap`
+ """
+ return None if self._colormap is None else self._colormap
+
+ def _updateColorGradient(self):
+ """Compute the color gradient"""
+ colormap = self.getColormap()
+ if colormap is None:
+ return
+
+ indices = numpy.linspace(0.0, 1.0, self._NB_CONTROL_POINTS)
+ colors = colormap.getNColors(nbColors=self._NB_CONTROL_POINTS)
+ self._gradient = qt.QLinearGradient(0, 1, 0, 0)
+ self._gradient.setCoordinateMode(qt.QGradient.StretchToDeviceMode)
+ self._gradient.setStops(
+ [(i, qt.QColor(*color)) for i, color in zip(indices, colors)]
+ )
+
+ def paintEvent(self, event):
+ """"""
+ painter = qt.QPainter(self)
+ if self.getColormap() is not None:
+ painter.setBrush(self._gradient)
+ penColor = self.palette().color(qt.QPalette.Active, qt.QPalette.WindowText)
+ else:
+ penColor = self.palette().color(
+ qt.QPalette.Disabled, qt.QPalette.WindowText
+ )
+ painter.setPen(penColor)
+
+ painter.drawRect(
+ qt.QRect(
+ 0, self.margin, self.width() - 1, self.height() - 2 * self.margin - 1
+ )
+ )
+
+ def mouseMoveEvent(self, event):
+ tooltip = str(
+ self.getValueFromRelativePosition(
+ self._getRelativePosition(qt.getMouseEventPosition(event)[1])
+ )
+ )
+ if qt.BINDING == "PyQt5":
+ position = event.globalPos()
+ else: # Qt6
+ position = event.globalPosition().toPoint()
+ qt.QToolTip.showText(position, tooltip, self)
+ super(_ColorScale, self).mouseMoveEvent(event)
+
+ def _getRelativePosition(self, yPixel):
+ """yPixel : pixel position into _ColorScale widget reference"""
+ # widgets are bottom-top referencial but we display in top-bottom referential
+ height = float(self.height() - 2 * self.margin)
+ if height == 0:
+ return 0.0
+ return 1.0 - (yPixel - self.margin) / height
+
+ def getValueFromRelativePosition(self, value):
+ """Return the value in the colorMap from a relative position in the
+ ColorScaleBar (y)
+
+ :param value: float value in [0, 1]
+ :return: the value in [colormap['vmin'], colormap['vmax']]
+ """
+ colormap = self.getColormap()
+ if colormap is None:
+ return
+
+ value = numpy.clip(value, 0.0, 1.0)
+ normalizer = colormap._getNormalizer()
+ normMin, normMax = normalizer.apply(
+ [self.vmin, self.vmax], self.vmin, self.vmax
+ )
+
+ return normalizer.revert(
+ normMin + (normMax - normMin) * value, self.vmin, self.vmax
+ )
+
+ def setMargin(self, margin):
+ """Define the margin to fit with a TickBar object.
+ This is needed since we can only paint on the viewport of the widget.
+ Didn't work with a simple setContentsMargins
+
+ :param int margin: the margin to apply on the top and bottom.
+ """
+ self.margin = int(margin)
+ self.update()
+
+
+class _TickBar(qt.QWidget):
+ """Bar grouping the ticks displayed
+
+ To run the following sample code, a QApplication must be initialized.
+
+ >>> bar = _TickBar(1, 1000, norm='log', parent=None, displayValues=True)
+ >>> bar.show()
+
+ .. image:: img/tickbar.png
+ :width: 40px
+ :align: center
+
+ :param int vmin: smaller value of the range of values
+ :param int vmax: higher value of the range of values
+ :param normalizer: Normalization object.
+ :param parent: the Qt parent if any
+ :param bool displayValues: if True display the values close to the tick,
+ Otherwise only signal it by '-'
+ :param int nticks: the number of tick we want to display. Should be an
+ unsigned int ot None. If None, let the Tick bar find the optimal
+ number of ticks from the tick density.
+ :param int margin: margin to set on the top and bottom
+ """
+
+ _WIDTH_DISP_VAL = 45
+ """widget width when displayed with ticks labels"""
+ _WIDTH_NO_DISP_VAL = 10
+ """widget width when displayed without ticks labels"""
+ _FONT_SIZE = 10
+ """font size for ticks labels"""
+ _LINE_WIDTH = 10
+ """width of the line to mark a tick"""
+
+ DEFAULT_TICK_DENSITY = 0.015
+
+ def __init__(
+ self,
+ vmin,
+ vmax,
+ normalizer,
+ parent=None,
+ displayValues=True,
+ nticks=None,
+ margin=5,
+ ):
+ super(_TickBar, self).__init__(parent)
+ self.margin = margin
+ self._nticks = None
+ self.ticks = ()
+ self.subTicks = ()
+ self._forcedDisplayType = None
+ self.ticksDensity = _TickBar.DEFAULT_TICK_DENSITY
+
+ self._vmin = vmin
+ self._vmax = vmax
+ self._normalizer = normalizer
+ self.displayValues = displayValues
+ self.setTicksNumber(nticks)
+
+ self.setMargin(margin)
+ self.setContentsMargins(0, 0, 0, 0)
+
+ self._resetWidth()
+
+ def setTicksValuesVisible(self, val):
+ self.displayValues = val
+ self._resetWidth()
+
+ def _resetWidth(self):
+ width = self._WIDTH_DISP_VAL if self.displayValues else self._WIDTH_NO_DISP_VAL
+ self.setFixedWidth(width)
+
+ def update(self, vmin, vmax, normalizer):
+ self._vmin = vmin
+ self._vmax = vmax
+ self._normalizer = normalizer
+ self.computeTicks()
+ qt.QWidget.update(self)
+
+ def setMargin(self, margin):
+ """Define the margin to fit with a _ColorScale object.
+ This is needed since we can only paint on the viewport of the widget
+
+ :param int margin: the margin to apply on the top and bottom.
+ """
+ self.margin = margin
+
+ def setTicksNumber(self, nticks):
+ """Set the number of ticks to display.
+
+ :param nticks: the number of tick to be display. Should be an
+ unsigned int ot None. If None, let the :class:`_TickBar` find the
+ optimal number of ticks from the tick density.
+ """
+ self._nticks = nticks
+ self.computeTicks()
+ qt.QWidget.update(self)
+
+ def setTicksDensity(self, density):
+ """If you let :class:`_TickBar` deal with the number of ticks
+ (nticks=None) then you can specify a ticks density to be displayed.
+ """
+ if density < 0.0:
+ raise ValueError("Density should be a positive value")
+ self.ticksDensity = density
+
+ def computeTicks(self):
+ """This function compute ticks values labels. It is called at each
+ update and each resize event.
+ Deal only with linear and log scale.
+ """
+ nticks = self._nticks
+ if nticks is None:
+ nticks = self._getOptimalNbTicks()
+
+ if self._vmin == self._vmax:
+ # No range: no ticks
+ self.ticks = ()
+ self.subTicks = ()
+ elif isinstance(self._normalizer, LogarithmicNormalization):
+ self._computeTicksLog(nticks)
+ else: # Fallback: use linear
+ self._computeTicksLin(nticks)
+
+ # update the form
+ font = qt.QFont()
+ font.setPixelSize(_TickBar._FONT_SIZE)
+
+ self.form = self._getFormat(font)
+
+ def _computeTicksLog(self, nticks):
+ logMin = numpy.log10(self._vmin)
+ logMax = numpy.log10(self._vmax)
+ lowBound, highBound, spacing, self._nfrac = ticklayout.niceNumbersForLog10(
+ logMin, logMax, nticks
+ )
+ self.ticks = numpy.power(10.0, numpy.arange(lowBound, highBound, spacing))
+ if spacing == 1:
+ self.subTicks = ticklayout.computeLogSubTicks(
+ ticks=self.ticks,
+ lowBound=numpy.power(10.0, lowBound),
+ highBound=numpy.power(10.0, highBound),
+ )
+ else:
+ self.subTicks = []
+
+ def resizeEvent(self, event):
+ qt.QWidget.resizeEvent(self, event)
+ self.computeTicks()
+
+ def _computeTicksLin(self, nticks):
+ _min, _max, _spacing, self._nfrac = ticklayout.niceNumbers(
+ self._vmin, self._vmax, nticks
+ )
+
+ self.ticks = numpy.arange(_min, _max, _spacing)
+ self.subTicks = []
+
+ def _getOptimalNbTicks(self):
+ return max(2, int(round(self.ticksDensity * self.rect().height())))
+
+ def paintEvent(self, event):
+ painter = qt.QPainter(self)
+ font = painter.font()
+ font.setPixelSize(_TickBar._FONT_SIZE)
+ painter.setFont(font)
+
+ # paint ticks
+ for val in self.ticks:
+ self._paintTick(val, painter, majorTick=True)
+
+ # paint subticks
+ for val in self.subTicks:
+ self._paintTick(val, painter, majorTick=False)
+
+ def _getRelativePosition(self, val):
+ """Return the relative position of val according to min and max value"""
+ if self._normalizer is None:
+ return 0.0
+ normMin, normMax, normVal = self._normalizer.apply(
+ [self._vmin, self._vmax, val], self._vmin, self._vmax
+ )
+
+ if normMin == normMax:
+ return 0.0
+ if not numpy.isfinite(normVal):
+ return 0.0
+ return 1.0 - (normVal - normMin) / (normMax - normMin)
+
+ def _paintTick(self, val, painter, majorTick=True):
+ """
+
+ :param bool majorTick: if False will never draw text and will set a line
+ with a smaller width
+ """
+ fm = qt.QFontMetrics(painter.font())
+ viewportHeight = self.rect().height() - self.margin * 2 - 1
+ relativePos = self._getRelativePosition(val)
+ height = int(viewportHeight * relativePos + self.margin)
+ lineWidth = _TickBar._LINE_WIDTH
+ if majorTick is False:
+ lineWidth /= 2
+
+ painter.drawLine(
+ qt.QLine(int(self.width() - lineWidth), height, self.width(), height)
+ )
+
+ if self.displayValues and majorTick is True:
+ painter.drawText(
+ qt.QPoint(0, int(height + fm.height() / 2)), self.form.format(val)
+ )
+
+ def setDisplayType(self, disType):
+ """Set the type of display we want to set for ticks labels
+
+ :param str disType: The type of display we want to set. disType values
+ can be :
+
+ - 'std' for standard, meaning only a formatting on the number of
+ digits is done
+ - 'e' for scientific display
+ - None to let the _TickBar guess the best display for this kind of data.
+ """
+ if disType not in (None, "std", "e"):
+ raise ValueError(
+ "display type not recognized, value should be in (None, 'std', 'e'"
+ )
+ self._forcedDisplayType = disType
+
+ def _getStandardFormat(self):
+ return "{0:.%sf}" % self._nfrac
+
+ def _getFormat(self, font):
+ if self._forcedDisplayType is None:
+ return self._guessType(font)
+ elif self._forcedDisplayType == "std":
+ return self._getStandardFormat()
+ elif self._forcedDisplayType == "e":
+ return self._getScientificForm()
+ else:
+ err = (
+ "Forced type for display %s is not recognized" % self._forcedDisplayType
+ )
+ raise ValueError(err)
+
+ def _getScientificForm(self):
+ return "{0:.0e}"
+
+ def _guessType(self, font):
+ """Try fo find the better format to display the tick's labels
+
+ :param QFont font: the font we want to use during the painting
+ """
+ form = self._getStandardFormat()
+
+ fm = qt.QFontMetrics(font)
+ width = 0
+ for tick in self.ticks:
+ width = max(fm.boundingRect(form.format(tick)).width(), width)
+
+ # if the length of the string are too long we are moving to scientific
+ # display
+ if width > _TickBar._WIDTH_DISP_VAL - _TickBar._LINE_WIDTH:
+ return self._getScientificForm()
+ else:
+ return form
diff --git a/src/silx/gui/plot/CompareImages.py b/src/silx/gui/plot/CompareImages.py
new file mode 100644
index 0000000..3823ae2
--- /dev/null
+++ b/src/silx/gui/plot/CompareImages.py
@@ -0,0 +1,954 @@
+# /*##########################################################################
+#
+# Copyright (c) 2018-2021 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""A widget dedicated to compare 2 images.
+"""
+
+__authors__ = ["V. Valls"]
+__license__ = "MIT"
+__date__ = "23/07/2018"
+
+
+import logging
+import numpy
+import math
+
+import silx.image.bilinear
+from silx.gui import qt
+from silx.gui import plot
+from silx.gui.colors import Colormap
+from silx.gui.plot import tools
+from silx.utils.deprecation import deprecated_warning
+from silx.utils.weakref import WeakMethodProxy
+from silx.gui.plot.items import Scatter
+from silx.math.colormap import normalize
+
+from .tools.compare.core import sift
+from .tools.compare.core import VisualizationMode
+from .tools.compare.core import AlignmentMode
+from .tools.compare.core import AffineTransformation
+from .tools.compare.toolbar import CompareImagesToolBar
+from .tools.compare.statusbar import CompareImagesStatusBar
+from .tools.compare.core import _CompareImageItem
+
+
+_logger = logging.getLogger(__name__)
+
+
+class CompareImages(qt.QMainWindow):
+ """Widget providing tools to compare 2 images.
+
+ .. image:: img/CompareImages.png
+
+ :param Union[qt.QWidget,None] parent: Parent of this widget.
+ :param backend: The backend to use, in:
+ 'matplotlib' (default), 'mpl', 'opengl', 'gl', 'none'
+ or a :class:`BackendBase.BackendBase` class
+ :type backend: str or :class:`BackendBase.BackendBase`
+ """
+
+ VisualizationMode = VisualizationMode
+ """Available visualization modes"""
+
+ AlignmentMode = AlignmentMode
+ """Available alignment modes"""
+
+ sigConfigurationChanged = qt.Signal()
+ """Emitted when the configuration of the widget (visualization mode,
+ alignment mode...) have changed."""
+
+ def __init__(self, parent=None, backend=None):
+ qt.QMainWindow.__init__(self, parent)
+ self._resetZoomActive = True
+ self._colormap = Colormap()
+ """Colormap shared by all modes, except the compose images (rgb image)"""
+ self._colormapKeyPoints = Colormap("spring")
+ """Colormap used for sift keypoints"""
+
+ self._colormap.sigChanged.connect(self.__colormapChanged)
+
+ if parent is None:
+ self.setWindowTitle("Compare images")
+ else:
+ self.setWindowFlags(qt.Qt.Widget)
+
+ self.__transformation = None
+ self.__item = _CompareImageItem()
+ self.__item.setName("_virtual")
+ self.__item.setColormap(self._colormap)
+
+ self.__raw1 = None
+ self.__raw2 = None
+ self.__data1 = None
+ self.__data2 = None
+ self.__previousSeparatorPosition = None
+
+ self.__plot = plot.PlotWidget(parent=self, backend=backend)
+ self.__plot.setDefaultColormap(self._colormap)
+ self.__plot.getXAxis().setLabel("Columns")
+ self.__plot.getYAxis().setLabel("Rows")
+ if silx.config.DEFAULT_PLOT_IMAGE_Y_AXIS_ORIENTATION == "downward":
+ self.__plot.getYAxis().setInverted(True)
+ self.__plot.addItem(self.__item)
+ self.__plot.setActiveImage(self.__item)
+
+ self.__plot.setKeepDataAspectRatio(True)
+ self.__plot.sigPlotSignal.connect(self.__plotSlot)
+ self.__plot.setAxesDisplayed(False)
+
+ self.__scatter = Scatter()
+ self.__scatter.setZValue(1)
+ self.__scatter.setColormap(self._colormapKeyPoints)
+ self.__plot.addItem(self.__scatter)
+
+ self.setCentralWidget(self.__plot)
+
+ legend = VisualizationMode.VERTICAL_LINE.name
+ self.__plot.addXMarker(
+ 0,
+ legend=legend,
+ text="",
+ draggable=True,
+ color="blue",
+ constraint=WeakMethodProxy(self.__separatorConstraint),
+ )
+ self.__vline = self.__plot._getMarker(legend)
+
+ legend = VisualizationMode.HORIZONTAL_LINE.name
+ self.__plot.addYMarker(
+ 0,
+ legend=legend,
+ text="",
+ draggable=True,
+ color="blue",
+ constraint=WeakMethodProxy(self.__separatorConstraint),
+ )
+ self.__hline = self.__plot._getMarker(legend)
+
+ # default values
+ self.__visualizationMode = ""
+ self.__alignmentMode = ""
+ self.__keypointsVisible = True
+
+ self.setAlignmentMode(AlignmentMode.ORIGIN)
+ self.setVisualizationMode(VisualizationMode.VERTICAL_LINE)
+ self.setKeypointsVisible(False)
+
+ # Toolbars
+
+ self._createToolBars(self.__plot)
+ if self._interactiveModeToolBar is not None:
+ self.addToolBar(self._interactiveModeToolBar)
+ if self._imageToolBar is not None:
+ self.addToolBar(self._imageToolBar)
+ if self._compareToolBar is not None:
+ self.addToolBar(self._compareToolBar)
+
+ # Statusbar
+
+ self._createStatusBar(self.__plot)
+ if self._statusBar is not None:
+ self.setStatusBar(self._statusBar)
+
+ def __getSealedColormap(self):
+ vrange = self._colormap.getColormapRange(
+ self.__item.getColormappedData(copy=False)
+ )
+ sealed = self._colormap.copy()
+ sealed.setVRange(*vrange)
+ return sealed
+
+ def __colormapChanged(self):
+ sealed = self.__getSealedColormap()
+ if self.__image1 is not None:
+ if self.__getImageMode(self.__image1.getData(copy=False)) == "intensity":
+ self.__image1.setColormap(sealed)
+ if self.__image2 is not None:
+ if self.__getImageMode(self.__image2.getData(copy=False)) == "intensity":
+ self.__image2.setColormap(sealed)
+
+ if "COMPOSITE" in self.__visualizationMode.name:
+ self.__updateData()
+
+ def _createStatusBar(self, plot):
+ self._statusBar = CompareImagesStatusBar(self)
+ self._statusBar.setCompareWidget(self)
+
+ def _createToolBars(self, plot):
+ """Create tool bars displayed by the widget"""
+ toolBar = tools.InteractiveModeToolBar(parent=self, plot=plot)
+ self._interactiveModeToolBar = toolBar
+ toolBar = tools.ImageToolBar(parent=self, plot=plot)
+ self._imageToolBar = toolBar
+ toolBar = CompareImagesToolBar(self)
+ toolBar.setCompareWidget(self)
+ self._compareToolBar = toolBar
+
+ def _getVirtualPlotItem(self):
+ return self.__item
+
+ def getPlot(self):
+ """Returns the plot which is used to display the images.
+
+ :rtype: silx.gui.plot.PlotWidget
+ """
+ return self.__plot
+
+ def getColormap(self):
+ """
+
+ :return: colormap used for compare image
+ :rtype: silx.gui.colors.Colormap
+ """
+ return self._colormap
+
+ def getRawPixelData(self, x, y):
+ """Return the raw pixel of each image data from axes positions.
+
+ If the coordinate is outside of the image it returns None element in
+ the tuple.
+
+ The pixel is reach from the raw data image without filter or
+ transformation. But the coordinate x and y are in the reference of the
+ current displayed mode.
+
+ :param float x: X-coordinate of the pixel in the current displayed plot
+ :param float y: Y-coordinate of the pixel in the current displayed plot
+ :return: A tuple of for each images containing pixel information. It
+ could be a scalar value or an array in case of RGB/RGBA informations.
+ It also could be a string containing information is some cases.
+ :rtype: Tuple(Union[int,float,numpy.ndarray,str],Union[int,float,numpy.ndarray,str])
+ """
+ alignmentMode = self.__alignmentMode
+ raw1, raw2 = self.__raw1, self.__raw2
+
+ if raw1 is None or raw2 is None:
+ x1 = x
+ y1 = y
+ x2 = x
+ y2 = y
+ elif alignmentMode == AlignmentMode.ORIGIN:
+ x1 = x
+ y1 = y
+ x2 = x
+ y2 = y
+ elif alignmentMode == AlignmentMode.CENTER:
+ yy = max(raw1.shape[0], raw2.shape[0])
+ xx = max(raw1.shape[1], raw2.shape[1])
+ x1 = x - (xx - raw1.shape[1]) * 0.5
+ x2 = x - (xx - raw2.shape[1]) * 0.5
+ y1 = y - (yy - raw1.shape[0]) * 0.5
+ y2 = y - (yy - raw2.shape[0]) * 0.5
+ elif alignmentMode == AlignmentMode.STRETCH:
+ x1 = x
+ y1 = y
+ x2 = x * raw2.shape[1] / raw1.shape[1]
+ y2 = x * raw2.shape[1] / raw1.shape[1]
+ elif alignmentMode == AlignmentMode.AUTO:
+ x1 = x
+ y1 = y
+ # Not implemented
+ x2 = -1
+ y2 = -1
+ else:
+ assert False
+
+ x1, y1 = int(x1), int(y1)
+ x2, y2 = int(x2), int(y2)
+
+ if raw1 is None:
+ data1 = "No image A"
+ elif y1 < 0 or y1 >= raw1.shape[0] or x1 < 0 or x1 >= raw1.shape[1]:
+ data1 = ""
+ else:
+ data1 = raw1[y1, x1]
+
+ if raw2 is None:
+ data2 = "No image B"
+ elif alignmentMode == AlignmentMode.AUTO:
+ data2 = "Not implemented with sift"
+ elif y2 < 0 or y2 >= raw2.shape[0] or x2 < 0 or x2 >= raw2.shape[1]:
+ data2 = None
+ else:
+ data2 = raw2[y2, x2]
+
+ return data1, data2
+
+ def setVisualizationMode(self, mode):
+ """Set the visualization mode.
+
+ :param str mode: New visualization to display the image comparison
+ """
+ if self.__visualizationMode == mode:
+ return
+ self.__visualizationMode = mode
+ self.__item.setVizualisationMode(mode)
+ self.__vline.setVisible(mode == VisualizationMode.VERTICAL_LINE)
+ self.__hline.setVisible(mode == VisualizationMode.HORIZONTAL_LINE)
+ self.__updateData()
+ self.sigConfigurationChanged.emit()
+
+ def centerLines(self):
+ """Center the line used to compare the 2 images."""
+ if self.__image1 is None:
+ return
+ data_range = self.__plot.getDataRange()
+
+ if data_range[0] is not None:
+ cx = (data_range[0][0] + data_range[0][1]) * 0.5
+ else:
+ cx = 0
+ if data_range[1] is not None:
+ cy = (data_range[1][0] + data_range[1][1]) * 0.5
+ else:
+ cy = 0
+ self.__vline.setPosition(cx, cy)
+ self.__hline.setPosition(cx, cy)
+ self.__updateSeparators()
+
+ def getVisualizationMode(self):
+ """Returns the current interaction mode."""
+ return self.__visualizationMode
+
+ def setAlignmentMode(self, mode):
+ """Set the alignment mode.
+
+ :param str mode: New alignement to apply to images
+ """
+ if self.__alignmentMode == mode:
+ return
+ self.__alignmentMode = mode
+ self.__updateData()
+ self.sigConfigurationChanged.emit()
+
+ def getAlignmentMode(self):
+ """Returns the current selected alignemnt mode."""
+ return self.__alignmentMode
+
+ def getKeypointsVisible(self):
+ """Returns true if the keypoints are displayed"""
+ return self.__keypointsVisible
+
+ def setKeypointsVisible(self, isVisible):
+ """Set keypoints visibility.
+
+ :param bool isVisible: If True, keypoints are displayed (if some)
+ """
+ if self.__keypointsVisible == isVisible:
+ return
+ self.__keypointsVisible = isVisible
+ self.__updateKeyPoints()
+ self.sigConfigurationChanged.emit()
+
+ def __setDefaultAlignmentMode(self):
+ """Reset the alignemnt mode to the default value"""
+ self.setAlignmentMode(AlignmentMode.ORIGIN)
+
+ def __plotSlot(self, event):
+ """Handle events from the plot"""
+ if event["event"] in ("markerMoving", "markerMoved"):
+ mode = self.getVisualizationMode()
+ legend = mode.name
+ if event["label"] == legend:
+ if mode == VisualizationMode.VERTICAL_LINE:
+ value = int(float(str(event["xdata"])))
+ elif mode == VisualizationMode.HORIZONTAL_LINE:
+ value = int(float(str(event["ydata"])))
+ else:
+ assert False
+ if self.__previousSeparatorPosition != value:
+ self.__separatorMoved(value)
+ self.__previousSeparatorPosition = value
+
+ def __separatorConstraint(self, x, y):
+ """Manage contains on the separators to clamp them inside the images."""
+ if self.__data1 is None:
+ return 0, 0
+ x = int(x)
+ if x < 0:
+ x = 0
+ elif x > self.__data1.shape[1]:
+ x = self.__data1.shape[1]
+ y = int(y)
+ if y < 0:
+ y = 0
+ elif y > self.__data1.shape[0]:
+ y = self.__data1.shape[0]
+ return x, y
+
+ def __updateSeparators(self):
+ """Redraw images according to the current state of the separators."""
+ mode = self.getVisualizationMode()
+ if mode == VisualizationMode.VERTICAL_LINE:
+ pos = self.__vline.getXPosition()
+ self.__separatorMoved(pos)
+ self.__previousSeparatorPosition = pos
+ elif mode == VisualizationMode.HORIZONTAL_LINE:
+ pos = self.__hline.getYPosition()
+ self.__separatorMoved(pos)
+ self.__previousSeparatorPosition = pos
+ else:
+ self.__image1.setOrigin((0, 0))
+ if self.__image2 is not None:
+ self.__image2.setOrigin((0, 0))
+
+ def __separatorMoved(self, pos):
+ """Called when vertical or horizontal separators have moved.
+
+ Update the displayed images.
+ """
+ if self.__data1 is None:
+ return
+
+ mode = self.getVisualizationMode()
+ if mode == VisualizationMode.VERTICAL_LINE:
+ pos = int(pos)
+ if pos <= 0:
+ pos = 0
+ elif pos >= self.__data1.shape[1]:
+ pos = self.__data1.shape[1]
+ data1 = self.__data1[:, 0:pos]
+ data2 = self.__data2[:, pos:]
+ self.__image1.setData(data1, copy=False)
+ if self.__image2 is not None:
+ self.__image2.setData(data2, copy=False)
+ self.__image2.setOrigin((pos, 0))
+ elif mode == VisualizationMode.HORIZONTAL_LINE:
+ pos = int(pos)
+ if pos <= 0:
+ pos = 0
+ elif pos >= self.__data1.shape[0]:
+ pos = self.__data1.shape[0]
+ data1 = self.__data1[0:pos, :]
+ data2 = self.__data2[pos:, :]
+ self.__image1.setData(data1, copy=False)
+ if self.__image2 is not None:
+ self.__image2.setData(data2, copy=False)
+ self.__image2.setOrigin((0, pos))
+ else:
+ assert False
+
+ def clear(self):
+ self.setData(None, None)
+
+ def setData(self, image1, image2, updateColormap="deprecated"):
+ """Set images to compare.
+
+ Images can contains floating-point or integer values, or RGB and RGBA
+ values, but should have comparable intensities.
+
+ RGB and RGBA images are provided as an array as `[width,height,channels]`
+ of unsigned integer 8-bits or floating-points between 0.0 to 1.0.
+
+ :param numpy.ndarray image1: The first image
+ :param numpy.ndarray image2: The second image
+ """
+ if updateColormap != "deprecated":
+ deprecated_warning(
+ "Argument", "setData's updateColormap argument", since_version="2.0.0"
+ )
+
+ self.__raw1 = image1
+ self.__raw2 = image2
+ self.__updateData()
+ if self.isAutoResetZoom():
+ self.__plot.resetZoom()
+
+ def setImage1(self, image1, updateColormap="deprecated"):
+ """Set image1 to be compared.
+
+ Images can contains floating-point or integer values, or RGB and RGBA
+ values, but should have comparable intensities.
+
+ RGB and RGBA images are provided as an array as `[width,height,channels]`
+ of unsigned integer 8-bits or floating-points between 0.0 to 1.0.
+
+ :param numpy.ndarray image1: The first image
+ """
+ if updateColormap != "deprecated":
+ deprecated_warning(
+ "Argument", "setImage1's updateColormap argument", since_version="2.0.0"
+ )
+
+ self.__raw1 = image1
+ self.__updateData()
+ if self.isAutoResetZoom():
+ self.__plot.resetZoom()
+
+ def setImage2(self, image2, updateColormap="deprecated"):
+ """Set image2 to be compared.
+
+ Images can contains floating-point or integer values, or RGB and RGBA
+ values, but should have comparable intensities.
+
+ RGB and RGBA images are provided as an array as `[width,height,channels]`
+ of unsigned integer 8-bits or floating-points between 0.0 to 1.0.
+
+ :param numpy.ndarray image2: The second image
+ """
+ if updateColormap != "deprecated":
+ deprecated_warning(
+ "Argument", "setImage2's updateColormap argument", since_version="2.0.0"
+ )
+
+ self.__raw2 = image2
+ self.__updateData()
+ if self.isAutoResetZoom():
+ self.__plot.resetZoom()
+
+ def __updateKeyPoints(self):
+ """Update the displayed keypoints using cached keypoints."""
+ if self.__keypointsVisible and self.__matching_keypoints:
+ data = self.__matching_keypoints
+ else:
+ data = [], [], []
+ self.__scatter.setData(x=data[0], y=data[1], value=data[2])
+
+ def __updateData(self):
+ """Compute aligned image when the alignment mode changes.
+
+ This function cache input images which are used when
+ vertical/horizontal separators moves.
+ """
+ raw1, raw2 = self.__raw1, self.__raw2
+
+ alignmentMode = self.getAlignmentMode()
+ self.__transformation = None
+
+ if raw1 is None or raw2 is None:
+ # No need to realign the 2 images
+ # But create a dummy image when there is None for simplification
+ if raw1 is None:
+ data1 = numpy.empty((0, 0))
+ else:
+ data1 = raw1
+ if raw2 is None:
+ data2 = numpy.empty((0, 0))
+ else:
+ data2 = raw2
+ self.__matching_keypoints = None
+ else:
+ if alignmentMode == AlignmentMode.ORIGIN:
+ yy = max(raw1.shape[0], raw2.shape[0])
+ xx = max(raw1.shape[1], raw2.shape[1])
+ size = yy, xx
+ data1 = self.__createMarginImage(raw1, size, transparent=True)
+ data2 = self.__createMarginImage(raw2, size, transparent=True)
+ self.__matching_keypoints = [0.0], [0.0], [1.0]
+ elif alignmentMode == AlignmentMode.CENTER:
+ yy = max(raw1.shape[0], raw2.shape[0])
+ xx = max(raw1.shape[1], raw2.shape[1])
+ size = yy, xx
+ data1 = self.__createMarginImage(
+ raw1, size, transparent=True, center=True
+ )
+ data2 = self.__createMarginImage(
+ raw2, size, transparent=True, center=True
+ )
+ self.__matching_keypoints = (
+ [data1.shape[1] // 2],
+ [data1.shape[0] // 2],
+ [1.0],
+ )
+ elif alignmentMode == AlignmentMode.STRETCH:
+ data1 = raw1
+ data2 = self.__rescaleImage(raw2, data1.shape)
+ self.__matching_keypoints = (
+ [0, data1.shape[1], data1.shape[1], 0],
+ [0, 0, data1.shape[0], data1.shape[0]],
+ [1.0, 1.0, 1.0, 1.0],
+ )
+ elif alignmentMode == AlignmentMode.AUTO:
+ # TODO: sift implementation do not support RGBA images
+ yy = max(raw1.shape[0], raw2.shape[0])
+ xx = max(raw1.shape[1], raw2.shape[1])
+ size = yy, xx
+ data1 = self.__createMarginImage(raw1, size)
+ data2 = self.__createMarginImage(raw2, size)
+ self.__matching_keypoints = [0.0], [0.0], [1.0]
+ try:
+ data1, data2 = self.__createSiftData(data1, data2)
+ if data2 is None:
+ raise ValueError("Unexpected None value")
+ except Exception as e:
+ # TODO: Display it on the GUI
+ _logger.error(e)
+ self.__setDefaultAlignmentMode()
+ return
+ else:
+ assert False
+
+ self.__item.setImageData1(data1)
+ self.__item.setImageData2(data2)
+
+ mode = self.getVisualizationMode()
+ if mode == VisualizationMode.COMPOSITE_RED_BLUE_GRAY_NEG:
+ data1 = self.__composeRgbImage(data1, data2, mode)
+ data2 = None
+ elif mode == VisualizationMode.COMPOSITE_RED_BLUE_GRAY:
+ data1 = self.__composeRgbImage(data1, data2, mode)
+ data2 = None
+ elif mode == VisualizationMode.COMPOSITE_A_MINUS_B:
+ data1 = self.__composeAMinusBImage(data1, data2)
+ data2 = None
+ elif mode == VisualizationMode.ONLY_A:
+ data2 = None
+ elif mode == VisualizationMode.ONLY_B:
+ data1 = numpy.empty((0, 0))
+
+ self.__data1, self.__data2 = data1, data2
+
+ colormap = self.__getSealedColormap()
+ mode1 = self.__getImageMode(self.__data1)
+ if mode1 == "intensity":
+ colormap1 = colormap
+ else:
+ colormap1 = None
+ self.__plot.addImage(
+ data1, z=0, legend="image1", resetzoom=False, colormap=colormap1
+ )
+ self.__image1 = self.__plot.getImage("image1")
+
+ if data2 is not None:
+ mode2 = self.__getImageMode(data2)
+ if mode2 == "intensity":
+ colormap2 = colormap
+ else:
+ colormap2 = None
+ self.__plot.addImage(
+ data2, z=0, legend="image2", resetzoom=False, colormap=colormap2
+ )
+ self.__image2 = self.__plot.getImage("image2")
+ self.__image2.setVisible(True)
+ else:
+ if self.__image2 is not None:
+ self.__image2.setVisible(False)
+ self.__image2 = None
+ self.__data2 = numpy.empty((0, 0))
+ self.__updateKeyPoints()
+
+ # Set the separator into the middle
+ if self.__previousSeparatorPosition is None:
+ value = self.__data1.shape[1] // 2
+ self.__vline.setPosition(value, 0)
+ value = self.__data1.shape[0] // 2
+ self.__hline.setPosition(0, value)
+ self.__updateSeparators()
+
+ def __getImageMode(self, image):
+ """Returns a value identifying the way the image is stored in the
+ array.
+
+ :param numpy.ndarray image: Image to check
+ :rtype: str
+ """
+ if len(image.shape) == 2:
+ return "intensity"
+ elif len(image.shape) == 3:
+ if image.shape[2] == 3:
+ return "rgb"
+ elif image.shape[2] == 4:
+ return "rgba"
+ raise TypeError("'image' argument is not an image.")
+
+ def __rescaleImage(self, image, shape):
+ """Rescale an image to the requested shape.
+
+ :rtype: numpy.ndarray
+ """
+ mode = self.__getImageMode(image)
+ if mode == "intensity":
+ data = self.__rescaleArray(image, shape)
+ elif mode == "rgb":
+ data = numpy.empty((shape[0], shape[1], 3), dtype=image.dtype)
+ for c in range(3):
+ data[:, :, c] = self.__rescaleArray(image[:, :, c], shape)
+ elif mode == "rgba":
+ data = numpy.empty((shape[0], shape[1], 4), dtype=image.dtype)
+ for c in range(4):
+ data[:, :, c] = self.__rescaleArray(image[:, :, c], shape)
+ return data
+
+ def __composeRgbImage(self, data1, data2, mode):
+ """Returns an RBG image containing composition of data1 and data2 in 2
+ different channels
+
+ A data image of a size of 0 is considered as missing. This does not
+ interrupt the processing.
+
+ :param numpy.ndarray data1: First image
+ :param numpy.ndarray data1: Second image
+ :param VisualizationMode mode: Composition mode.
+ :rtype: numpy.ndarray
+ """
+ if data1.size != 0 and data2.size != 0:
+ assert data1.shape[0:2] == data2.shape[0:2]
+
+ sealed = self.__getSealedColormap()
+ vmin, vmax = sealed.getVRange()
+
+ if data1.size == 0:
+ intensity1 = numpy.zeros(data2.shape[0:2])
+ else:
+ mode1 = self.__getImageMode(data1)
+ if mode1 in ["rgb", "rgba"]:
+ intensity1 = self.__luminosityImage(data1)
+ else:
+ intensity1 = data1
+
+ if data2.size == 0:
+ intensity2 = numpy.zeros(data1.shape[0:2])
+ else:
+ mode2 = self.__getImageMode(data2)
+ if mode2 in ["rgb", "rgba"]:
+ intensity2 = self.__luminosityImage(data2)
+ else:
+ intensity2 = data2
+
+ shape = intensity1.shape
+ result = numpy.empty((shape[0], shape[1], 3), dtype=numpy.uint8)
+ a, _, _ = normalize(
+ intensity1,
+ norm=sealed.getNormalization(),
+ autoscale=sealed.getAutoscaleMode(),
+ vmin=sealed.getVMin(),
+ vmax=sealed.getVMax(),
+ gamma=sealed.getGammaNormalizationParameter(),
+ )
+ b, _, _ = normalize(
+ intensity2,
+ norm=sealed.getNormalization(),
+ autoscale=sealed.getAutoscaleMode(),
+ vmin=sealed.getVMin(),
+ vmax=sealed.getVMax(),
+ gamma=sealed.getGammaNormalizationParameter(),
+ )
+ if mode == VisualizationMode.COMPOSITE_RED_BLUE_GRAY:
+ result[:, :, 0] = a
+ result[:, :, 1] = a // 2 + b // 2
+ result[:, :, 2] = b
+ elif mode == VisualizationMode.COMPOSITE_RED_BLUE_GRAY_NEG:
+ result[:, :, 0] = 255 - b
+ result[:, :, 1] = 255 - (a // 2 + b // 2)
+ result[:, :, 2] = 255 - a
+ return result
+
+ def __composeAMinusBImage(self, data1, data2):
+ """Returns an intensity image containing the composition of `A-B`.
+
+ A data image of a size of 0 is considered as missing. This does not
+ interrupt the processing.
+
+ :param numpy.ndarray data1: First image
+ :param numpy.ndarray data1: Second image
+ :rtype: numpy.ndarray
+ """
+ if data1.size != 0 and data2.size != 0:
+ assert data1.shape[0:2] == data2.shape[0:2]
+
+ data1 = self.__asIntensityImage(data1)
+ data2 = self.__asIntensityImage(data2)
+ if data1.size == 0:
+ result = data2
+ elif data2.size == 0:
+ result = data1
+ else:
+ result = data1.astype(numpy.float32) - data2.astype(numpy.float32)
+ return result
+
+ def __asIntensityImage(self, image: numpy.ndarray):
+ """Returns an intensity image.
+
+ If the image use a single channel, it will be returned as it is.
+
+ If the image is an RBG(A) image, the luminosity (0..1) is extracted and
+ returned. The alpha channel is ignored.
+
+ :rtype: numpy.ndarray
+ """
+ mode = self.__getImageMode(image)
+ if mode in ["rgb", "rgba"]:
+ return self.__luminosityImage(image)
+ return image
+
+ def __luminosityImage(self, image: numpy.ndarray):
+ """Returns the luminosity channel from an RBG(A) image.
+
+ The alpha channel is ignored.
+
+ :rtype: numpy.ndarray
+ """
+ mode = self.__getImageMode(image)
+ assert mode in ["rgb", "rgba"]
+ is_uint8 = image.dtype.type == numpy.uint8
+ # luminosity
+ image = 0.21 * image[..., 0] + 0.72 * image[..., 1] + 0.07 * image[..., 2]
+ if is_uint8:
+ image = image / 255.0
+ return image
+
+ def __rescaleArray(self, image, shape):
+ """Rescale a 2D array to the requested shape.
+
+ :rtype: numpy.ndarray
+ """
+ y, x = numpy.ogrid[: shape[0], : shape[1]]
+ y, x = y * 1.0 * (image.shape[0] - 1) / (shape[0] - 1), x * 1.0 * (
+ image.shape[1] - 1
+ ) / (shape[1] - 1)
+ b = silx.image.bilinear.BilinearImage(image)
+ # TODO: could be optimized using strides
+ x2d = numpy.zeros_like(y) + x
+ y2d = numpy.zeros_like(x) + y
+ result = b.map_coordinates((y2d, x2d))
+ return result
+
+ def __createMarginImage(self, image, size, transparent=False, center=False):
+ """Returns a new image with margin to respect the requested size.
+
+ :rtype: numpy.ndarray
+ """
+ assert image.shape[0] <= size[0]
+ assert image.shape[1] <= size[1]
+ if image.shape == size:
+ return image
+ mode = self.__getImageMode(image)
+
+ if center:
+ pos0 = size[0] // 2 - image.shape[0] // 2
+ pos1 = size[1] // 2 - image.shape[1] // 2
+ else:
+ pos0, pos1 = 0, 0
+
+ if mode == "intensity":
+ data = numpy.zeros(size, dtype=image.dtype)
+ data[pos0 : pos0 + image.shape[0], pos1 : pos1 + image.shape[1]] = image
+ # TODO: It is maybe possible to put NaN on the margin
+ else:
+ if transparent:
+ data = numpy.zeros((size[0], size[1], 4), dtype=numpy.uint8)
+ else:
+ data = numpy.zeros((size[0], size[1], 3), dtype=numpy.uint8)
+ depth = min(data.shape[2], image.shape[2])
+ data[
+ pos0 : pos0 + image.shape[0], pos1 : pos1 + image.shape[1], 0:depth
+ ] = image[:, :, 0:depth]
+ if transparent and depth == 3:
+ data[
+ pos0 : pos0 + image.shape[0], pos1 : pos1 + image.shape[1], 3
+ ] = 255
+ return data
+
+ def __toAffineTransformation(self, sift_result):
+ """Returns an affine transformation from the sift result.
+
+ :param dict sift_result: Result of sift when using `all_result=True`
+ :rtype: AffineTransformation
+ """
+ offset = sift_result["offset"]
+ matrix = sift_result["matrix"]
+
+ tx = offset[0]
+ ty = offset[1]
+ a = matrix[0, 0]
+ b = matrix[0, 1]
+ c = matrix[1, 0]
+ d = matrix[1, 1]
+ rot = math.atan2(-b, a)
+ sx = (-1.0 if a < 0 else 1.0) * math.sqrt(a**2 + b**2)
+ sy = (-1.0 if d < 0 else 1.0) * math.sqrt(c**2 + d**2)
+ return AffineTransformation(tx, ty, sx, sy, rot)
+
+ def getTransformation(self):
+ """Returns the affine transformation applied to the second image to align
+ it to the first image.
+
+ This result is only valid for sift alignment.
+
+ :rtype: Union[None,AffineTransformation]
+ """
+ return self.__transformation
+
+ def __createSiftData(self, image, second_image):
+ """Generate key points and aligned images from 2 images.
+
+ If no keypoints matches, unaligned data are anyway returns.
+
+ :rtype: Tuple(numpy.ndarray,numpy.ndarray)
+ """
+ devicetype = "GPU"
+
+ # Compute base image
+ sift_ocl = sift.SiftPlan(template=image, devicetype=devicetype)
+ keypoints = sift_ocl(image)
+
+ # Check image compatibility
+ second_keypoints = sift_ocl(second_image)
+ mp = sift.MatchPlan()
+ match = mp(keypoints, second_keypoints)
+ _logger.info("Number of Keypoints within image 1: %i" % keypoints.size)
+ _logger.info(" within image 2: %i" % second_keypoints.size)
+
+ self.__matching_keypoints = (
+ match[:].x[:, 0],
+ match[:].y[:, 0],
+ match[:].scale[:, 0],
+ )
+ matching_keypoints = match.shape[0]
+ _logger.info("Matching keypoints: %i" % matching_keypoints)
+ if matching_keypoints == 0:
+ return image, second_image
+
+ # TODO: Problem here is we have to compute 2 time sift
+ # The first time to extract matching keypoints, second time
+ # to extract the aligned image.
+
+ # Normalize the second image
+ sa = sift.LinearAlign(image, devicetype=devicetype)
+ data1 = image
+ # TODO: Create a sift issue: if data1 is RGB and data2 intensity
+ # it returns None, while extracting manually keypoints (above) works
+ result = sa.align(second_image, return_all=True)
+ data2 = result["result"]
+ self.__transformation = self.__toAffineTransformation(result)
+ return data1, data2
+
+ def resetZoom(self, dataMargins=None):
+ """Reset the plot limits to the bounds of the data and redraw the plot."""
+ self.__plot.resetZoom(dataMargins)
+
+ def setAutoResetZoom(self, activate=True):
+ """
+
+ :param bool activate: True if we want to activate the automatic
+ plot reset zoom when setting images.
+ """
+ self._resetZoomActive = activate
+
+ def isAutoResetZoom(self):
+ """
+
+ :return: True if the automatic call to resetzoom is activated
+ :rtype: bool
+ """
+ return self._resetZoomActive
diff --git a/src/silx/gui/plot/ComplexImageView.py b/src/silx/gui/plot/ComplexImageView.py
new file mode 100644
index 0000000..654a1c1
--- /dev/null
+++ b/src/silx/gui/plot/ComplexImageView.py
@@ -0,0 +1,508 @@
+# /*##########################################################################
+#
+# Copyright (c) 2017-2023 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""This module provides a widget to view 2D complex data.
+
+The :class:`ComplexImageView` widget is dedicated to visualize a single 2D dataset
+of complex data.
+"""
+
+__authors__ = ["Vincent Favre-Nicolin", "T. Vincent"]
+__license__ = "MIT"
+__date__ = "24/04/2018"
+
+
+import logging
+import numpy
+
+from .. import qt, icons
+from .PlotWindow import Plot2D
+from . import items
+from .items import ImageComplexData
+from silx.gui.widgets.FloatEdit import FloatEdit
+
+_logger = logging.getLogger(__name__)
+
+
+# Widgets
+
+
+class _AmplitudeRangeDialog(qt.QDialog):
+ """QDialog asking for the amplitude range to display."""
+
+ sigRangeChanged = qt.Signal(tuple)
+ """Signal emitted when the range has changed.
+
+ It provides the new range as a 2-tuple: (max, delta)
+ """
+
+ def __init__(self, parent=None, amplitudeRange=None, displayedRange=(None, 2)):
+ super(_AmplitudeRangeDialog, self).__init__(parent)
+ self.setWindowTitle("Set Displayed Amplitude Range")
+
+ if amplitudeRange is not None:
+ amplitudeRange = min(amplitudeRange), max(amplitudeRange)
+ self._amplitudeRange = amplitudeRange
+ self._defaultDisplayedRange = displayedRange
+
+ layout = qt.QFormLayout()
+ self.setLayout(layout)
+
+ if self._amplitudeRange is not None:
+ min_, max_ = self._amplitudeRange
+ layout.addRow(qt.QLabel("Data Amplitude Range: [%g, %g]" % (min_, max_)))
+
+ self._maxLineEdit = FloatEdit(parent=self)
+ self._maxLineEdit.validator().setBottom(0.0)
+ self._maxLineEdit.setAlignment(qt.Qt.AlignRight)
+
+ self._maxLineEdit.editingFinished.connect(self._rangeUpdated)
+ layout.addRow("Displayed Max.:", self._maxLineEdit)
+
+ self._autoscale = qt.QCheckBox("autoscale")
+ self._autoscale.toggled.connect(self._autoscaleCheckBoxToggled)
+ layout.addRow("", self._autoscale)
+
+ self._deltaLineEdit = FloatEdit(parent=self)
+ self._deltaLineEdit.validator().setBottom(1.0)
+ self._deltaLineEdit.setAlignment(qt.Qt.AlignRight)
+ self._deltaLineEdit.editingFinished.connect(self._rangeUpdated)
+ layout.addRow("Displayed delta (log10 unit):", self._deltaLineEdit)
+
+ buttons = qt.QDialogButtonBox(self)
+ buttons.addButton(qt.QDialogButtonBox.Ok)
+ buttons.addButton(qt.QDialogButtonBox.Cancel)
+ buttons.accepted.connect(self.accept)
+ buttons.rejected.connect(self.reject)
+ layout.addRow(buttons)
+
+ # Set dialog from default values
+ self._resetDialogToDefault()
+
+ self.rejected.connect(self._handleRejected)
+
+ def _resetDialogToDefault(self):
+ """Set Widgets of the dialog from range information"""
+ max_, delta = self._defaultDisplayedRange
+
+ if max_ is not None: # Not in autoscale
+ displayedMax = max_
+ elif self._amplitudeRange is not None: # Autoscale with data
+ displayedMax = self._amplitudeRange[1]
+ else: # Autoscale without data
+ displayedMax = ""
+ if displayedMax == "":
+ self._maxLineEdit.setText("")
+ else:
+ self._maxLineEdit.setValue(displayedMax)
+ self._maxLineEdit.setEnabled(max_ is not None)
+
+ self._deltaLineEdit.setValue(delta)
+
+ self._autoscale.setChecked(self._defaultDisplayedRange[0] is None)
+
+ def getRangeInfo(self):
+ """Returns the current range as a 2-tuple (max, delta (in log10))"""
+ if self._autoscale.isChecked():
+ max_ = None
+ else:
+ maxStr = self._maxLineEdit.text()
+ max_ = self._maxLineEdit.value() if maxStr else None
+ return max_, self._deltaLineEdit.value() if self._deltaLineEdit.text() else 2
+
+ def _handleRejected(self):
+ """Reset range info to default when rejected"""
+ self._resetDialogToDefault()
+ self._rangeUpdated()
+
+ def _rangeUpdated(self):
+ """Handle QLineEdit editing finised"""
+ self.sigRangeChanged.emit(self.getRangeInfo())
+
+ def _autoscaleCheckBoxToggled(self, checked):
+ """Handle autoscale checkbox state changes"""
+ if checked: # Use default values
+ if self._amplitudeRange is None:
+ max_ = ""
+ else:
+ max_ = self._amplitudeRange[1]
+ if max_ == "":
+ self._maxLineEdit.setText("")
+ else:
+ self._maxLineEdit.setValue(max_)
+ self._maxLineEdit.setEnabled(not checked)
+ self._rangeUpdated()
+
+
+class _ComplexDataToolButton(qt.QToolButton):
+ """QToolButton providing choices of complex data visualization modes
+
+ :param parent: See :class:`QToolButton`
+ :param plot: The :class:`ComplexImageView` to control
+ """
+
+ _MODES = dict(
+ [
+ (ImageComplexData.ComplexMode.ABSOLUTE, ("math-amplitude", "Amplitude")),
+ (
+ ImageComplexData.ComplexMode.SQUARE_AMPLITUDE,
+ ("math-square-amplitude", "Square amplitude"),
+ ),
+ (ImageComplexData.ComplexMode.PHASE, ("math-phase", "Phase")),
+ (ImageComplexData.ComplexMode.REAL, ("math-real", "Real part")),
+ (
+ ImageComplexData.ComplexMode.IMAGINARY,
+ ("math-imaginary", "Imaginary part"),
+ ),
+ (
+ ImageComplexData.ComplexMode.AMPLITUDE_PHASE,
+ ("math-phase-color", "Amplitude and Phase"),
+ ),
+ (
+ ImageComplexData.ComplexMode.LOG10_AMPLITUDE_PHASE,
+ ("math-phase-color-log", "Log10(Amp.) and Phase"),
+ ),
+ ]
+ )
+
+ _RANGE_DIALOG_TEXT = "Set Amplitude Range..."
+
+ def __init__(self, parent=None, plot=None):
+ super(_ComplexDataToolButton, self).__init__(parent=parent)
+
+ assert plot is not None
+ self._plot2DComplex = plot
+
+ menu = qt.QMenu(self)
+ menu.triggered.connect(self._triggered)
+ self.setMenu(menu)
+
+ for mode, info in self._MODES.items():
+ icon, text = info
+ action = qt.QAction(icons.getQIcon(icon), text, self)
+ action.setData(mode)
+ action.setIconVisibleInMenu(True)
+ menu.addAction(action)
+
+ self._rangeDialogAction = qt.QAction(self)
+ self._rangeDialogAction.setText(self._RANGE_DIALOG_TEXT)
+ menu.addAction(self._rangeDialogAction)
+
+ self.setPopupMode(qt.QToolButton.InstantPopup)
+
+ self._modeChanged(self._plot2DComplex.getComplexMode())
+ self._plot2DComplex.sigVisualizationModeChanged.connect(self._modeChanged)
+
+ def _modeChanged(self, mode):
+ """Handle change of visualization modes"""
+ icon, text = self._MODES[mode]
+ self.setIcon(icons.getQIcon(icon))
+ self.setToolTip("Display the " + text.lower())
+ self._rangeDialogAction.setEnabled(
+ mode == ImageComplexData.ComplexMode.LOG10_AMPLITUDE_PHASE
+ )
+
+ def _triggered(self, action):
+ """Handle triggering of menu actions"""
+ actionText = action.text()
+
+ if actionText == self._RANGE_DIALOG_TEXT: # Show dialog
+ # Get amplitude range
+ data = self._plot2DComplex.getData(copy=False)
+
+ if data.size > 0:
+ absolute = numpy.absolute(data)
+ dataRange = (numpy.nanmin(absolute), numpy.nanmax(absolute))
+ else:
+ dataRange = None
+
+ # Show dialog
+ dialog = _AmplitudeRangeDialog(
+ parent=self,
+ amplitudeRange=dataRange,
+ displayedRange=self._plot2DComplex._getAmplitudeRangeInfo(),
+ )
+ dialog.sigRangeChanged.connect(self._rangeChanged)
+ dialog.exec()
+ dialog.sigRangeChanged.disconnect(self._rangeChanged)
+
+ else: # update mode
+ mode = action.data()
+ if isinstance(mode, ImageComplexData.ComplexMode):
+ self._plot2DComplex.setComplexMode(mode)
+
+ def _rangeChanged(self, range_):
+ """Handle updates of range in the dialog"""
+ self._plot2DComplex._setAmplitudeRangeInfo(*range_)
+
+
+class ComplexImageView(qt.QWidget):
+ """Display an image of complex data and allow to choose the visualization.
+
+ :param parent: See :class:`QMainWindow`
+ """
+
+ ComplexMode = ImageComplexData.ComplexMode
+ """Complex Modes enumeration"""
+
+ sigDataChanged = qt.Signal()
+ """Signal emitted when data has changed."""
+
+ sigVisualizationModeChanged = qt.Signal(object)
+ """Signal emitted when the visualization mode has changed.
+
+ It provides the new visualization mode.
+ """
+
+ def __init__(self, parent=None):
+ super(ComplexImageView, self).__init__(parent)
+ if parent is None:
+ self.setWindowTitle("ComplexImageView")
+
+ self._plot2D = Plot2D(self)
+
+ layout = qt.QHBoxLayout(self)
+ layout.setSpacing(0)
+ layout.setContentsMargins(0, 0, 0, 0)
+ layout.addWidget(self._plot2D)
+ self.setLayout(layout)
+
+ # Create and add image to the plot
+ self._plotImage = ImageComplexData()
+ self._plotImage.setName("__ComplexImageView__complex_image__")
+ self._plotImage.sigItemChanged.connect(self._itemChanged)
+ self._plot2D.addItem(self._plotImage)
+ self._plot2D.setActiveImage(self._plotImage)
+
+ toolBar = qt.QToolBar("Complex", self)
+ toolBar.addWidget(_ComplexDataToolButton(parent=self, plot=self))
+
+ self._plot2D.insertToolBar(self._plot2D.getProfileToolbar(), toolBar)
+
+ def _itemChanged(self, event):
+ """Handle item changed signal"""
+ if event is items.ItemChangedType.DATA:
+ self.sigDataChanged.emit()
+ elif event is items.ItemChangedType.VISUALIZATION_MODE:
+ mode = self.getComplexMode()
+ self.sigVisualizationModeChanged.emit(mode)
+
+ def getPlot(self):
+ """Return the PlotWidget displaying the data"""
+ return self._plot2D
+
+ def setData(self, data=None, copy=True):
+ """Set the complex data to display.
+
+ :param numpy.ndarray data: 2D complex data
+ :param bool copy: True (default) to copy the data,
+ False to use provided data (do not modify!).
+ """
+ if data is None:
+ data = numpy.zeros((0, 0), dtype=numpy.complex64)
+
+ previousData = self._plotImage.getComplexData(copy=False)
+
+ self._plotImage.setData(data, copy=copy)
+
+ if previousData.shape != data.shape:
+ self.getPlot().resetZoom()
+
+ def getData(self, copy=True):
+ """Get the currently displayed complex data.
+
+ :param bool copy: True (default) to return a copy of the data,
+ False to return internal data (do not modify!).
+ :return: The complex data array.
+ :rtype: numpy.ndarray of complex with 2 dimensions
+ """
+ return self._plotImage.getComplexData(copy=copy)
+
+ def getDisplayedData(self, copy=True):
+ """Returns the displayed data depending on the visualization mode
+
+ WARNING: The returned data can be a uint8 RGBA image
+
+ :param bool copy: True (default) to return a copy of the data,
+ False to return internal data (do not modify!)
+ :rtype: numpy.ndarray of float with 2 dims or RGBA image (uint8).
+ """
+ mode = self.getComplexMode()
+ if mode in (
+ self.ComplexMode.AMPLITUDE_PHASE,
+ self.ComplexMode.LOG10_AMPLITUDE_PHASE,
+ ):
+ return self._plotImage.getRgbaImageData(copy=copy)
+ else:
+ return self._plotImage.getData(copy=copy)
+
+ # Backward compatibility
+
+ Mode = ComplexMode
+
+ # Image item proxy
+
+ @staticmethod
+ def supportedComplexModes():
+ """Returns the supported visualization modes.
+
+ Supported visualization modes are:
+
+ - amplitude: The absolute value provided by numpy.absolute
+ - phase: The phase (or argument) provided by numpy.angle
+ - real: Real part
+ - imaginary: Imaginary part
+ - amplitude_phase: Color-coded phase with amplitude as alpha.
+ - log10_amplitude_phase:
+ Color-coded phase with log10(amplitude) as alpha.
+
+ :rtype: List[ComplexMode]
+ """
+ return ImageComplexData.supportedComplexModes()
+
+ def setComplexMode(self, mode):
+ """Set the mode of visualization of the complex data.
+
+ See :meth:`supportedComplexModes` for the list of
+ supported modes.
+
+ How-to change visualization mode::
+
+ widget = ComplexImageView()
+ widget.setComplexMode(ComplexImageView.ComplexMode.PHASE)
+ # or
+ widget.setComplexMode('phase')
+
+ :param Unions[ComplexMode,str] mode: The mode to use.
+ """
+ self._plotImage.setComplexMode(mode)
+
+ def getComplexMode(self):
+ """Get the current visualization mode of the complex data.
+
+ :rtype: ComplexMode
+ """
+ return self._plotImage.getComplexMode()
+
+ def _setAmplitudeRangeInfo(self, max_=None, delta=2):
+ """Set the amplitude range to display for 'log10_amplitude_phase' mode.
+
+ :param max_: Max of the amplitude range.
+ If None it autoscales to data max.
+ :param float delta: Delta range in log10 to display
+ """
+ self._plotImage._setAmplitudeRangeInfo(max_, delta)
+
+ def _getAmplitudeRangeInfo(self):
+ """Returns the amplitude range to use for 'log10_amplitude_phase' mode.
+
+ :return: (max, delta), if max is None, then it autoscales to data max
+ :rtype: 2-tuple"""
+ return self._plotImage._getAmplitudeRangeInfo()
+
+ def setColormap(self, colormap, mode=None):
+ """Set the colormap to use for amplitude, phase, real or imaginary.
+
+ WARNING: This colormap is not used when displaying both
+ amplitude and phase.
+
+ :param ~silx.gui.colors.Colormap colormap: The colormap
+ :param ComplexMode mode: If specified, set the colormap of this specific mode
+ """
+ self._plotImage.setColormap(colormap, mode)
+
+ def getColormap(self, mode=None):
+ """Returns the colormap used to display the data.
+
+ :param ComplexMode mode: If specified, set the colormap of this specific mode
+ :rtype: ~silx.gui.colors.Colormap
+ """
+ return self._plotImage.getColormap(mode=mode)
+
+ def getOrigin(self):
+ """Returns the offset from origin at which to display the image.
+
+ :rtype: 2-tuple of float
+ """
+ return self._plotImage.getOrigin()
+
+ def setOrigin(self, origin):
+ """Set the offset from origin at which to display the image.
+
+ :param origin: (ox, oy) Offset from origin
+ :type origin: float or 2-tuple of float
+ """
+ self._plotImage.setOrigin(origin)
+
+ def getScale(self):
+ """Returns the scale of the image in data coordinates.
+
+ :rtype: 2-tuple of float
+ """
+ return self._plotImage.getScale()
+
+ def setScale(self, scale):
+ """Set the scale of the image
+
+ :param scale: (sx, sy) Scale of the image
+ :type scale: float or 2-tuple of float
+ """
+ self._plotImage.setScale(scale)
+
+ # PlotWidget API proxy
+
+ def getXAxis(self):
+ """Returns the X axis
+
+ :rtype: :class:`.items.Axis`
+ """
+ return self.getPlot().getXAxis()
+
+ def getYAxis(self):
+ """Returns an Y axis
+
+ :rtype: :class:`.items.Axis`
+ """
+ return self.getPlot().getYAxis(axis="left")
+
+ def getGraphTitle(self):
+ """Return the plot main title as a str."""
+ return self.getPlot().getGraphTitle()
+
+ def setGraphTitle(self, title=""):
+ """Set the plot main title.
+
+ :param str title: Main title of the plot (default: '')
+ """
+ self.getPlot().setGraphTitle(title)
+
+ def setKeepDataAspectRatio(self, flag):
+ """Set whether the plot keeps data aspect ratio or not.
+
+ :param bool flag: True to respect data aspect ratio
+ """
+ self.getPlot().setKeepDataAspectRatio(flag)
+
+ def isKeepDataAspectRatio(self):
+ """Returns whether the plot is keeping data aspect ratio or not."""
+ return self.getPlot().isKeepDataAspectRatio()
diff --git a/src/silx/gui/plot/CurvesROIWidget.py b/src/silx/gui/plot/CurvesROIWidget.py
new file mode 100644
index 0000000..bd47da0
--- /dev/null
+++ b/src/silx/gui/plot/CurvesROIWidget.py
@@ -0,0 +1,1572 @@
+# /*##########################################################################
+#
+# Copyright (c) 2004-2023 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""
+Widget to handle regions of interest (:class:`ROI`) on curves displayed in a
+:class:`PlotWindow`.
+
+This widget is meant to work with :class:`PlotWindow`.
+"""
+
+__authors__ = ["V.A. Sole", "T. Vincent", "H. Payno"]
+__license__ = "MIT"
+__date__ = "13/03/2018"
+
+import logging
+import os
+import sys
+import functools
+import numpy
+from silx.io import dictdump
+from silx.utils.weakref import WeakMethodProxy
+from silx.utils.proxy import docstring
+from .. import icons, qt
+from silx.math.combo import min_max
+import weakref
+from silx.gui.widgets.TableWidget import TableWidget
+from . import items
+from .items.roi import _RegionOfInterestBase
+
+
+_logger = logging.getLogger(__name__)
+
+
+class CurvesROIWidget(qt.QWidget):
+ """
+ Widget displaying a table of ROI information.
+
+ Implements also the following behavior:
+
+ * if the roiTable has no ROI when showing create the default ICR one
+
+ :param parent: See :class:`QWidget`
+ :param str name: The title of this widget
+ """
+
+ sigROIWidgetSignal = qt.Signal(object)
+ """Signal of ROIs modifications.
+
+ Modification information if given as a dict with an 'event' key
+ providing the type of events.
+
+ Type of events:
+
+ - AddROI, DelROI, LoadROI and ResetROI with keys: 'roilist', 'roidict'
+ - selectionChanged with keys: 'row', 'col' 'roi', 'key', 'colheader',
+ 'rowheader'
+ """
+
+ sigROISignal = qt.Signal(object)
+
+ def __init__(self, parent=None, name=None, plot=None):
+ super(CurvesROIWidget, self).__init__(parent)
+ if name is not None:
+ self.setWindowTitle(name)
+ self.__lastSigROISignal = None
+ """Store the last value emitted for the sigRoiSignal. In the case the
+ active curve change we need to add this extra step in order to make
+ sure we won't send twice the sigROISignal.
+ This come from the fact sigROISignal is connected to the
+ activeROIChanged signal which is emitted when raw and net counts
+ values are changing but are not embed in the sigROISignal.
+ """
+ assert plot is not None
+ self._plotRef = weakref.ref(plot)
+ self._showAllMarkers = False
+ self.currentROI = None
+
+ layout = qt.QVBoxLayout(self)
+ layout.setContentsMargins(0, 0, 0, 0)
+ layout.setSpacing(0)
+
+ self.headerLabel = qt.QLabel(self)
+ self.headerLabel.setAlignment(qt.Qt.AlignHCenter)
+ self.setHeader()
+ layout.addWidget(self.headerLabel)
+
+ widgetAllCheckbox = qt.QWidget(parent=self)
+ self._showAllCheckBox = qt.QCheckBox("show all ROI", parent=widgetAllCheckbox)
+ widgetAllCheckbox.setLayout(qt.QHBoxLayout())
+ spacer = qt.QWidget(parent=widgetAllCheckbox)
+ spacer.setSizePolicy(qt.QSizePolicy.Expanding, qt.QSizePolicy.Fixed)
+ widgetAllCheckbox.layout().addWidget(spacer)
+ widgetAllCheckbox.layout().addWidget(self._showAllCheckBox)
+ layout.addWidget(widgetAllCheckbox)
+
+ self.roiTable = ROITable(self, plot=plot)
+ rheight = self.roiTable.horizontalHeader().sizeHint().height()
+ self.roiTable.setMinimumHeight(4 * rheight)
+ layout.addWidget(self.roiTable)
+ self._roiFileDir = qt.QDir.home().absolutePath()
+ self._showAllCheckBox.toggled.connect(self.roiTable.showAllMarkers)
+
+ hbox = qt.QWidget(self)
+ hboxlayout = qt.QHBoxLayout(hbox)
+ hboxlayout.setContentsMargins(0, 0, 0, 0)
+ hboxlayout.setSpacing(0)
+
+ hboxlayout.addStretch(0)
+
+ self.addButton = qt.QPushButton(hbox)
+ self.addButton.setText("Add ROI")
+ self.addButton.setToolTip("Create a new ROI")
+ self.delButton = qt.QPushButton(hbox)
+ self.delButton.setText("Delete ROI")
+ self.addButton.setToolTip("Remove the selected ROI")
+ self.resetButton = qt.QPushButton(hbox)
+ self.resetButton.setText("Reset")
+ self.addButton.setToolTip(
+ "Clear all created ROIs. We only let the " "default ROI"
+ )
+
+ hboxlayout.addWidget(self.addButton)
+ hboxlayout.addWidget(self.delButton)
+ hboxlayout.addWidget(self.resetButton)
+
+ hboxlayout.addStretch(0)
+
+ self.loadButton = qt.QPushButton(hbox)
+ self.loadButton.setText("Load")
+ self.loadButton.setToolTip("Load ROIs from a .ini file")
+ self.saveButton = qt.QPushButton(hbox)
+ self.saveButton.setText("Save")
+ self.loadButton.setToolTip("Save ROIs to a .ini file")
+ hboxlayout.addWidget(self.loadButton)
+ hboxlayout.addWidget(self.saveButton)
+ layout.setStretchFactor(self.headerLabel, 0)
+ layout.setStretchFactor(self.roiTable, 1)
+ layout.setStretchFactor(hbox, 0)
+
+ layout.addWidget(hbox)
+
+ # Signal / Slot connections
+ self.addButton.clicked.connect(self._add)
+ self.delButton.clicked.connect(self._del)
+ self.resetButton.clicked.connect(self._reset)
+
+ self.loadButton.clicked.connect(self._load)
+ self.saveButton.clicked.connect(self._save)
+
+ self.roiTable.activeROIChanged.connect(self._emitCurrentROISignal)
+
+ self._isConnected = False # True if connected to plot signals
+ self._isInit = False
+
+ def getROIListAndDict(self):
+ return self.roiTable.getROIListAndDict()
+
+ def getPlotWidget(self):
+ """Returns the associated PlotWidget or None
+
+ :rtype: Union[~silx.gui.plot.PlotWidget,None]
+ """
+ return None if self._plotRef is None else self._plotRef()
+
+ def showEvent(self, event):
+ self._visibilityChangedHandler(visible=True)
+ qt.QWidget.showEvent(self, event)
+
+ @property
+ def roiFileDir(self):
+ """The directory from which to load/save ROI from/to files."""
+ if not os.path.isdir(self._roiFileDir):
+ self._roiFileDir = qt.QDir.home().absolutePath()
+ return self._roiFileDir
+
+ @roiFileDir.setter
+ def roiFileDir(self, roiFileDir):
+ self._roiFileDir = str(roiFileDir)
+
+ def setRois(self, rois, order=None):
+ return self.roiTable.setRois(rois, order)
+
+ def getRois(self, order=None):
+ return self.roiTable.getRois(order)
+
+ def setMiddleROIMarkerFlag(self, flag=True):
+ return self.roiTable.setMiddleROIMarkerFlag(flag)
+
+ def _add(self):
+ """Add button clicked handler"""
+
+ def getNextRoiName():
+ rois = self.roiTable.getRois(order=None)
+ roisNames = []
+ [roisNames.append(roiName) for roiName in rois]
+ nrois = len(rois)
+ if nrois == 0:
+ return "ICR"
+ else:
+ i = 1
+ newroi = "newroi %d" % i
+ while newroi in roisNames:
+ i += 1
+ newroi = "newroi %d" % i
+ return newroi
+
+ roi = ROI(name=getNextRoiName())
+
+ if roi.getName() == "ICR":
+ roi.setType("Default")
+ else:
+ roi.setType(self.getPlotWidget().getXAxis().getLabel())
+
+ xmin, xmax = self.getPlotWidget().getXAxis().getLimits()
+ fromdata = xmin + 0.25 * (xmax - xmin)
+ todata = xmin + 0.75 * (xmax - xmin)
+ if roi.isICR():
+ fromdata, dummy0, todata, dummy1 = self._getAllLimits()
+ roi.setFrom(fromdata)
+ roi.setTo(todata)
+ self.roiTable.addRoi(roi)
+
+ # back compatibility pymca roi signals
+ ddict = {}
+ ddict["event"] = "AddROI"
+ ddict["roilist"] = self.roiTable.roidict.values()
+ ddict["roidict"] = self.roiTable.roidict
+ self.sigROIWidgetSignal.emit(ddict)
+ # end back compatibility pymca roi signals
+
+ def _del(self):
+ """Delete button clicked handler"""
+ self.roiTable.deleteActiveRoi()
+
+ # back compatibility pymca roi signals
+ ddict = {}
+ ddict["event"] = "DelROI"
+ ddict["roilist"] = self.roiTable.roidict.values()
+ ddict["roidict"] = self.roiTable.roidict
+ self.sigROIWidgetSignal.emit(ddict)
+ # end back compatibility pymca roi signals
+
+ def _reset(self):
+ """Reset button clicked handler"""
+ self.roiTable.clear()
+ old = self.blockSignals(True) # avoid several sigROISignal emission
+ self._add()
+ self.blockSignals(old)
+
+ # back compatibility pymca roi signals
+ ddict = {}
+ ddict["event"] = "ResetROI"
+ ddict["roilist"] = self.roiTable.roidict.values()
+ ddict["roidict"] = self.roiTable.roidict
+ self.sigROIWidgetSignal.emit(ddict)
+ # end back compatibility pymca roi signals
+
+ def _load(self):
+ """Load button clicked handler"""
+ dialog = qt.QFileDialog(self)
+ dialog.setNameFilters(["INI File *.ini", "JSON File *.json", "All *.*"])
+ dialog.setFileMode(qt.QFileDialog.ExistingFile)
+ dialog.setDirectory(self.roiFileDir)
+ if not dialog.exec():
+ dialog.close()
+ return
+
+ # pyflakes bug http://bugs.debian.org/cgi-bin/bugreport.cgi?bug=666494
+ outputFile = dialog.selectedFiles()[0]
+ dialog.close()
+
+ self.roiFileDir = os.path.dirname(outputFile)
+ self.roiTable.load(outputFile)
+
+ # back compatibility pymca roi signals
+ ddict = {}
+ ddict["event"] = "LoadROI"
+ ddict["roilist"] = self.roiTable.roidict.values()
+ ddict["roidict"] = self.roiTable.roidict
+ self.sigROIWidgetSignal.emit(ddict)
+ # end back compatibility pymca roi signals
+
+ def load(self, filename):
+ """Load ROI widget information from a file storing a dict of ROI.
+
+ :param str filename: The file from which to load ROI
+ """
+ self.roiTable.load(filename)
+
+ def _save(self):
+ """Save button clicked handler"""
+ dialog = qt.QFileDialog(self)
+ dialog.setNameFilters(["INI File *.ini", "JSON File *.json"])
+ dialog.setFileMode(qt.QFileDialog.AnyFile)
+ dialog.setAcceptMode(qt.QFileDialog.AcceptSave)
+ dialog.setDirectory(self.roiFileDir)
+ if not dialog.exec():
+ dialog.close()
+ return
+
+ outputFile = dialog.selectedFiles()[0]
+ extension = "." + dialog.selectedNameFilter().split(".")[-1]
+ dialog.close()
+
+ if not outputFile.endswith(extension):
+ outputFile += extension
+
+ if os.path.exists(outputFile):
+ try:
+ os.remove(outputFile)
+ except IOError:
+ msg = qt.QMessageBox(self)
+ msg.setIcon(qt.QMessageBox.Critical)
+ msg.setText("Input Output Error: %s" % (sys.exc_info()[1]))
+ msg.exec()
+ return
+ self.roiFileDir = os.path.dirname(outputFile)
+ self.save(outputFile)
+
+ def save(self, filename):
+ """Save current ROIs of the widget as a dict of ROI to a file.
+
+ :param str filename: The file to which to save the ROIs
+ """
+ self.roiTable.save(filename)
+
+ def setHeader(self, text="ROIs"):
+ """Set the header text of this widget"""
+ self.headerLabel.setText("<b>%s<\b>" % text)
+
+ def calculateRois(self, roiList=None, roiDict=None):
+ """Compute ROI information"""
+ return self.roiTable.calculateRois()
+
+ def showAllMarkers(self, _show=True):
+ self.roiTable.showAllMarkers(_show)
+
+ def _getAllLimits(self):
+ """Retrieve the limits based on the curves."""
+ plot = self.getPlotWidget()
+ curves = () if plot is None else plot.getAllCurves()
+ if not curves:
+ return 1.0, 1.0, 100.0, 100.0
+
+ xmin, ymin = None, None
+ xmax, ymax = None, None
+
+ for curve in curves:
+ x = curve.getXData(copy=False)
+ y = curve.getYData(copy=False)
+ if xmin is None:
+ xmin = x.min()
+ else:
+ xmin = min(xmin, x.min())
+ if xmax is None:
+ xmax = x.max()
+ else:
+ xmax = max(xmax, x.max())
+ if ymin is None:
+ ymin = y.min()
+ else:
+ ymin = min(ymin, y.min())
+ if ymax is None:
+ ymax = y.max()
+ else:
+ ymax = max(ymax, y.max())
+
+ return xmin, ymin, xmax, ymax
+
+ def showEvent(self, event):
+ self._visibilityChangedHandler(visible=True)
+ qt.QWidget.showEvent(self, event)
+
+ def hideEvent(self, event):
+ self._visibilityChangedHandler(visible=False)
+ qt.QWidget.hideEvent(self, event)
+
+ def _visibilityChangedHandler(self, visible):
+ """Handle widget's visibility updates.
+
+ It is connected to plot signals only when visible.
+ """
+ if visible:
+ # if no ROI existing yet, add the default one
+ if self.roiTable.rowCount() == 0:
+ old = self.blockSignals(True) # avoid several sigROISignal emission
+ self._add()
+ self.blockSignals(old)
+ self.calculateRois()
+
+ def fillFromROIDict(self, *args, **kwargs):
+ self.roiTable.fillFromROIDict(*args, **kwargs)
+
+ def _emitCurrentROISignal(self):
+ ddict = {}
+ ddict["event"] = "currentROISignal"
+ if self.roiTable.activeRoi is not None:
+ ddict["ROI"] = self.roiTable.activeRoi.toDict()
+ ddict["current"] = self.roiTable.activeRoi.getName()
+ else:
+ ddict["current"] = None
+
+ if self.__lastSigROISignal != ddict:
+ self.__lastSigROISignal = ddict
+ self.sigROISignal.emit(ddict)
+
+ @property
+ def currentRoi(self):
+ return self.roiTable.activeRoi
+
+
+class _FloatItem(qt.QTableWidgetItem):
+ """
+ Simple QTableWidgetItem overloading the < operator to deal with ordering
+ """
+
+ def __init__(self):
+ qt.QTableWidgetItem.__init__(self, type=qt.QTableWidgetItem.Type)
+
+ def __lt__(self, other):
+ if self.text() in ("", ROITable.INFO_NOT_FOUND):
+ return False
+ if other.text() in ("", ROITable.INFO_NOT_FOUND):
+ return True
+ return float(self.text()) < float(other.text())
+
+
+class ROITable(TableWidget):
+ """Table widget displaying ROI information.
+
+ See :class:`QTableWidget` for constructor arguments.
+
+ Behavior: listen at the active curve changed only when the widget is
+ visible. Otherwise won't compute the row and net counts...
+ """
+
+ activeROIChanged = qt.Signal()
+ """Signal emitted when the active roi changed or when the value of the
+ active roi are changing"""
+
+ COLUMNS_INDEX = dict(
+ [
+ ("ID", 0),
+ ("ROI", 1),
+ ("Type", 2),
+ ("From", 3),
+ ("To", 4),
+ ("Raw Counts", 5),
+ ("Net Counts", 6),
+ ("Raw Area", 7),
+ ("Net Area", 8),
+ ]
+ )
+
+ COLUMNS = list(COLUMNS_INDEX.keys())
+
+ INFO_NOT_FOUND = "????????"
+
+ def __init__(self, parent=None, plot=None, rois=None):
+ super(ROITable, self).__init__(parent)
+ self._showAllMarkers = False
+ self._userIsEditingRoi = False
+ """bool used to avoid conflict when editing the ROI object"""
+ self._isConnected = False
+ self._roiToItems = {}
+ self._roiDict = {}
+ """dict of ROI object. Key is ROi id, value is the ROI object"""
+ self._markersHandler = _RoiMarkerManager()
+
+ """
+ Associate for each marker legend used when the `_showAllMarkers` option
+ is active a roi.
+ """
+ self.setColumnCount(len(self.COLUMNS))
+ self.setPlot(plot)
+ self.__setTooltip()
+ self.setSortingEnabled(True)
+ self.itemChanged.connect(self._itemChanged)
+
+ @property
+ def roidict(self):
+ return self._getRoiDict()
+
+ @property
+ def activeRoi(self):
+ return self._markersHandler._activeRoi
+
+ def _getRoiDict(self):
+ ddict = {}
+ for id in self._roiDict:
+ ddict[self._roiDict[id].getName()] = self._roiDict[id]
+ return ddict
+
+ def clear(self):
+ """
+ .. note:: clear the interface only. keep the roidict...
+ """
+ self._markersHandler.clear()
+ self._roiToItems = {}
+ self._roiDict = {}
+
+ qt.QTableWidget.clear(self)
+ self.setRowCount(0)
+ self.setHorizontalHeaderLabels(self.COLUMNS)
+ header = self.horizontalHeader()
+ header.setSectionResizeMode(qt.QHeaderView.ResizeToContents)
+ self.sortByColumn(0, qt.Qt.AscendingOrder)
+ self.hideColumn(self.COLUMNS_INDEX["ID"])
+
+ def setPlot(self, plot):
+ self.clear()
+ self.plot = plot
+
+ def __setTooltip(self):
+ self.horizontalHeaderItem(self.COLUMNS_INDEX["ROI"]).setToolTip(
+ "Region of interest identifier"
+ )
+ self.horizontalHeaderItem(self.COLUMNS_INDEX["Type"]).setToolTip(
+ "Type of the ROI"
+ )
+ self.horizontalHeaderItem(self.COLUMNS_INDEX["From"]).setToolTip(
+ "X-value of the min point"
+ )
+ self.horizontalHeaderItem(self.COLUMNS_INDEX["To"]).setToolTip(
+ "X-value of the max point"
+ )
+ self.horizontalHeaderItem(self.COLUMNS_INDEX["Raw Counts"]).setToolTip(
+ "Estimation of the integral between y=0 and the selected curve"
+ )
+ self.horizontalHeaderItem(self.COLUMNS_INDEX["Net Counts"]).setToolTip(
+ "Estimation of the integral between the segment [maxPt, minPt] "
+ "and the selected curve"
+ )
+
+ def setRois(self, rois, order=None):
+ """Set the ROIs by providing a dictionary of ROI information.
+
+ The dictionary keys are the ROI names.
+ Each value is a sub-dictionary of ROI info with the following fields:
+
+ - ``"from"``: x coordinate of the left limit, as a float
+ - ``"to"``: x coordinate of the right limit, as a float
+ - ``"type"``: type of ROI, as a string (e.g "channels", "energy")
+
+
+ :param roidict: Dictionary of ROIs
+ :param str order: Field used for ordering the ROIs.
+ One of "from", "to", "type".
+ None (default) for no ordering, or same order as specified
+ in parameter ``rois`` if provided as a dict.
+ """
+ assert order in [None, "from", "to", "type"]
+ self.clear()
+
+ # backward compatibility since 0.10.0
+ if isinstance(rois, dict):
+ for roiName, roi in rois.items():
+ if isinstance(roi, ROI):
+ _roi = roi
+ else:
+ roi["name"] = roiName
+ _roi = ROI._fromDict(roi)
+ self.addRoi(_roi)
+ else:
+ for roi in rois:
+ assert isinstance(roi, ROI)
+ self.addRoi(roi)
+ self._updateMarkers()
+
+ def addRoi(self, roi):
+ """
+
+ :param :class:`ROI` roi: roi to add to the table
+ """
+ assert isinstance(roi, ROI)
+ self._getItem(name="ID", row=None, roi=roi)
+ self._roiDict[roi.getID()] = roi
+ self._markersHandler.add(roi, _RoiMarkerHandler(roi, self.plot))
+ self._updateRoiInfo(roi.getID())
+ callback = functools.partial(WeakMethodProxy(self._updateRoiInfo), roi.getID())
+ roi.sigChanged.connect(callback)
+ # set it as the active one
+ self.setActiveRoi(roi)
+
+ def _getItem(self, name, row, roi):
+ if row:
+ item = self.item(row, self.COLUMNS_INDEX[name])
+ else:
+ item = None
+ if item:
+ return item
+ else:
+ if name == "ID":
+ assert roi
+ if roi.getID() in self._roiToItems:
+ return self._roiToItems[roi.getID()]
+ else:
+ # create a new row
+ row = self.rowCount()
+ self.setRowCount(self.rowCount() + 1)
+ item = qt.QTableWidgetItem(
+ str(roi.getID()), type=qt.QTableWidgetItem.Type
+ )
+ self._roiToItems[roi.getID()] = item
+ elif name == "ROI":
+ item = qt.QTableWidgetItem(
+ roi.getName() if roi else "", type=qt.QTableWidgetItem.Type
+ )
+ if roi.getName().upper() in ("ICR", "DEFAULT"):
+ item.setFlags(qt.Qt.ItemIsSelectable | qt.Qt.ItemIsEnabled)
+ else:
+ item.setFlags(
+ qt.Qt.ItemIsSelectable
+ | qt.Qt.ItemIsEnabled
+ | qt.Qt.ItemIsEditable
+ )
+ elif name == "Type":
+ item = qt.QTableWidgetItem(type=qt.QTableWidgetItem.Type)
+ item.setFlags((qt.Qt.ItemIsSelectable | qt.Qt.ItemIsEnabled))
+ elif name in ("To", "From"):
+ item = _FloatItem()
+ if roi.getName().upper() in ("ICR", "DEFAULT"):
+ item.setFlags(qt.Qt.ItemIsSelectable | qt.Qt.ItemIsEnabled)
+ else:
+ item.setFlags(
+ qt.Qt.ItemIsSelectable
+ | qt.Qt.ItemIsEnabled
+ | qt.Qt.ItemIsEditable
+ )
+ elif name in ("Raw Counts", "Net Counts", "Raw Area", "Net Area"):
+ item = _FloatItem()
+ item.setFlags((qt.Qt.ItemIsSelectable | qt.Qt.ItemIsEnabled))
+ else:
+ raise ValueError("item type not recognized")
+
+ self.setItem(row, self.COLUMNS_INDEX[name], item)
+ return item
+
+ def _itemChanged(self, item):
+ def getRoi():
+ IDItem = self.item(item.row(), self.COLUMNS_INDEX["ID"])
+ assert IDItem
+ id = int(IDItem.text())
+ assert id in self._roiDict
+ roi = self._roiDict[id]
+ return roi
+
+ def signalChanged(roi):
+ if self.activeRoi and roi.getID() == self.activeRoi.getID():
+ self.activeROIChanged.emit()
+
+ self._userIsEditingRoi = True
+ if item.column() in (self.COLUMNS_INDEX["To"], self.COLUMNS_INDEX["From"]):
+ roi = getRoi()
+
+ if item.text() not in ("", self.INFO_NOT_FOUND):
+ try:
+ value = float(item.text())
+ except ValueError:
+ value = 0
+ changed = False
+ if item.column() == self.COLUMNS_INDEX["To"]:
+ if value != roi.getTo():
+ roi.setTo(value)
+ changed = True
+ else:
+ assert item.column() == self.COLUMNS_INDEX["From"]
+ if value != roi.getFrom():
+ roi.setFrom(value)
+ changed = True
+ if changed:
+ self._updateMarker(roi.getName())
+ signalChanged(roi)
+
+ if item.column() is self.COLUMNS_INDEX["ROI"]:
+ roi = getRoi()
+ if roi.getName() != item.text():
+ roi.setName(item.text())
+ self._markersHandler.getMarkerHandler(roi.getID()).updateTexts()
+ signalChanged(roi)
+
+ self._userIsEditingRoi = False
+
+ def deleteActiveRoi(self):
+ """
+ remove the current active roi
+ """
+ activeItems = self.selectedItems()
+ if len(activeItems) == 0:
+ return
+ old = self.blockSignals(True) # avoid several emission of sigROISignal
+ roiToRm = set()
+ for item in activeItems:
+ row = item.row()
+ itemID = self.item(row, self.COLUMNS_INDEX["ID"])
+ roiToRm.add(self._roiDict[int(itemID.text())])
+ [self.removeROI(roi) for roi in roiToRm]
+ self.blockSignals(old)
+ self.setActiveRoi(None)
+
+ def removeROI(self, roi):
+ """
+ remove the requested roi
+
+ :param str name: the name of the roi to remove from the table
+ """
+ if roi and roi.getID() in self._roiToItems:
+ item = self._roiToItems[roi.getID()]
+ self.removeRow(item.row())
+ del self._roiToItems[roi.getID()]
+
+ assert roi.getID() in self._roiDict
+ del self._roiDict[roi.getID()]
+ self._markersHandler.remove(roi)
+
+ callback = functools.partial(
+ WeakMethodProxy(self._updateRoiInfo), roi.getID()
+ )
+ roi.sigChanged.connect(callback)
+
+ def setActiveRoi(self, roi):
+ """
+ Define the given roi as the active one.
+
+ .. warning:: this roi should already be registred / added to the table
+
+ :param :class:`ROI` roi: the roi to defined as active
+ """
+ if roi is None:
+ self.clearSelection()
+ self._markersHandler.setActiveRoi(None)
+ self.activeROIChanged.emit()
+ else:
+ assert isinstance(roi, ROI)
+ if roi and roi.getID() in self._roiToItems.keys():
+ # avoid several call back to setActiveROI
+ old = self.blockSignals(True)
+ self.selectRow(self._roiToItems[roi.getID()].row())
+ self.blockSignals(old)
+ self._markersHandler.setActiveRoi(roi)
+ self.activeROIChanged.emit()
+
+ def _updateRoiInfo(self, roiID):
+ if self._userIsEditingRoi is True:
+ return
+ if roiID not in self._roiDict:
+ return
+ roi = self._roiDict[roiID]
+ if roi.isICR():
+ activeCurve = self.plot.getActiveCurve()
+ if activeCurve:
+ xData = activeCurve.getXData()
+ if len(xData) > 0:
+ min, max = min_max(xData)
+ roi.blockSignals(True)
+ roi.setFrom(min)
+ roi.setTo(max)
+ roi.blockSignals(False)
+
+ itemID = self._getItem(name="ID", roi=roi, row=None)
+ itemName = self._getItem(name="ROI", row=itemID.row(), roi=roi)
+ itemName.setText(roi.getName())
+
+ itemType = self._getItem(name="Type", row=itemID.row(), roi=roi)
+ itemType.setText(roi.getType() or self.INFO_NOT_FOUND)
+
+ itemFrom = self._getItem(name="From", row=itemID.row(), roi=roi)
+ fromdata = (
+ str(roi.getFrom()) if roi.getFrom() is not None else self.INFO_NOT_FOUND
+ )
+ itemFrom.setText(fromdata)
+
+ itemTo = self._getItem(name="To", row=itemID.row(), roi=roi)
+ todata = str(roi.getTo()) if roi.getTo() is not None else self.INFO_NOT_FOUND
+ itemTo.setText(todata)
+
+ rawCounts, netCounts = roi.computeRawAndNetCounts(
+ curve=self.plot.getActiveCurve(just_legend=False)
+ )
+ itemRawCounts = self._getItem(name="Raw Counts", row=itemID.row(), roi=roi)
+ rawCounts = str(rawCounts) if rawCounts is not None else self.INFO_NOT_FOUND
+ itemRawCounts.setText(rawCounts)
+
+ itemNetCounts = self._getItem(name="Net Counts", row=itemID.row(), roi=roi)
+ netCounts = str(netCounts) if netCounts is not None else self.INFO_NOT_FOUND
+ itemNetCounts.setText(netCounts)
+
+ rawArea, netArea = roi.computeRawAndNetArea(
+ curve=self.plot.getActiveCurve(just_legend=False)
+ )
+ itemRawArea = self._getItem(name="Raw Area", row=itemID.row(), roi=roi)
+ rawArea = str(rawArea) if rawArea is not None else self.INFO_NOT_FOUND
+ itemRawArea.setText(rawArea)
+
+ itemNetArea = self._getItem(name="Net Area", row=itemID.row(), roi=roi)
+ netArea = str(netArea) if netArea is not None else self.INFO_NOT_FOUND
+ itemNetArea.setText(netArea)
+
+ if self.activeRoi and roi.getID() == self.activeRoi.getID():
+ self.activeROIChanged.emit()
+
+ def currentChanged(self, current, previous):
+ if previous and current.row() != previous.row() and current.row() >= 0:
+ roiItem = self.item(current.row(), self.COLUMNS_INDEX["ID"])
+
+ assert roiItem
+ self.setActiveRoi(self._roiDict[int(roiItem.text())])
+ self._markersHandler.updateAllMarkers()
+ qt.QTableWidget.currentChanged(self, current, previous)
+
+ def calculateRois(self):
+ """Update values of all registred rois (raw and net counts in particular)"""
+ for roiID in self._roiDict:
+ self._updateRoiInfo(roiID)
+
+ def _updateMarker(self, roiID):
+ """Make sure the marker of the given roi name is updated"""
+ if self._showAllMarkers or (
+ self.activeRoi and self.activeRoi.getName() == roiID
+ ):
+ self._updateMarkers()
+
+ def _updateMarkers(self):
+ if self._showAllMarkers is True:
+ self._markersHandler.updateMarkers()
+ else:
+ if not self.activeRoi or not self.plot:
+ return
+ assert isinstance(self.activeRoi, ROI)
+ markerHandler = self._markersHandler.getMarkerHandler(
+ self.activeRoi.getID()
+ )
+ if markerHandler is not None:
+ markerHandler.updateMarkers()
+
+ def getRois(self, order):
+ """
+ Return the currently defined ROIs, as an ordered dict.
+
+ The dictionary keys are the ROI names.
+ Each value is a :class:`ROI` object..
+
+ :param order: Field used for ordering the ROIs.
+ One of "from", "to", "type", "netcounts", "rawcounts".
+ None (default) to get the same order as displayed in the widget.
+ :return: Ordered dictionary of ROI information
+ """
+
+ if order is None or order.lower() == "none":
+ ordered_roilist = list(self._roiDict.values())
+ res = dict(
+ [(roi.getName(), self._roiDict[roi.getID()]) for roi in ordered_roilist]
+ )
+ else:
+ assert order in ["from", "to", "type", "netcounts", "rawcounts"]
+ ordered_roilist = sorted(
+ self._roiDict.keys(),
+ key=lambda roi_id: self._roiDict[roi_id].get(order),
+ )
+ res = dict([(roi.getName(), self._roiDict[id]) for id in ordered_roilist])
+
+ return res
+
+ def save(self, filename):
+ """
+ Save current ROIs of the widget as a dict of ROI to a file.
+
+ :param str filename: The file to which to save the ROIs
+ """
+ roilist = []
+ roidict = {}
+ for roiID, roi in self._roiDict.items():
+ roilist.append(roi.toDict())
+ roidict[roi.getName()] = roi.toDict()
+ datadict = {"ROI": {"roilist": roilist, "roidict": roidict}}
+ dictdump.dump(datadict, filename)
+
+ def load(self, filename):
+ """
+ Load ROI widget information from a file storing a dict of ROI.
+
+ :param str filename: The file from which to load ROI
+ """
+ roisDict = dictdump.load(filename)
+ rois = []
+
+ # Remove rawcounts and netcounts from ROIs
+ for roiDict in roisDict["ROI"]["roidict"].values():
+ roiDict.pop("rawcounts", None)
+ roiDict.pop("netcounts", None)
+ rois.append(ROI._fromDict(roiDict))
+
+ self.setRois(rois)
+
+ def showAllMarkers(self, _show=True):
+ """
+
+ :param bool _show: if true show all the markers of all the ROIs
+ boundaries otherwise will only show the one of
+ the active ROI.
+ """
+ self._markersHandler.setShowAllMarkers(_show)
+
+ def setMiddleROIMarkerFlag(self, flag=True):
+ """
+ Activate or deactivate middle marker.
+
+ This allows shifting both min and max limits at once, by dragging
+ a marker located in the middle.
+
+ :param bool flag: True to activate middle ROI marker
+ """
+ self._markersHandler._middleROIMarkerFlag = flag
+
+ def _handleROIMarkerEvent(self, ddict):
+ """Handle plot signals related to marker events."""
+ if ddict["event"] == "markerMoved":
+ label = ddict["label"]
+ roiID = self._markersHandler.getRoiID(markerID=label)
+ if roiID is not None:
+ # avoid several emission of sigROISignal
+ old = self.blockSignals(True)
+ self._markersHandler.changePosition(markerID=label, x=ddict["x"])
+ self.blockSignals(old)
+ self._updateRoiInfo(roiID)
+
+ def showEvent(self, event):
+ self._visibilityChangedHandler(visible=True)
+ qt.QWidget.showEvent(self, event)
+
+ def hideEvent(self, event):
+ self._visibilityChangedHandler(visible=False)
+ qt.QWidget.hideEvent(self, event)
+
+ def _visibilityChangedHandler(self, visible):
+ """Handle widget's visibility updates.
+
+ It is connected to plot signals only when visible.
+ """
+ if visible:
+ assert self.plot
+ if self._isConnected is False:
+ self.plot.sigPlotSignal.connect(self._handleROIMarkerEvent)
+ self.plot.sigActiveCurveChanged.connect(self._activeCurveChanged)
+ self._isConnected = True
+ self.calculateRois()
+ else:
+ if self._isConnected:
+ self.plot.sigPlotSignal.disconnect(self._handleROIMarkerEvent)
+ self.plot.sigActiveCurveChanged.disconnect(self._activeCurveChanged)
+ self._isConnected = False
+
+ def _activeCurveChanged(self, curve):
+ self.calculateRois()
+
+ def setCountsVisible(self, visible):
+ """
+ Display the columns relative to areas or not
+
+ :param bool visible: True if the columns 'Raw Area' and 'Net Area'
+ should be visible.
+ """
+ if visible is True:
+ self.showColumn(self.COLUMNS_INDEX["Raw Counts"])
+ self.showColumn(self.COLUMNS_INDEX["Net Counts"])
+ else:
+ self.hideColumn(self.COLUMNS_INDEX["Raw Counts"])
+ self.hideColumn(self.COLUMNS_INDEX["Net Counts"])
+
+ def setAreaVisible(self, visible):
+ """
+ Display the columns relative to areas or not
+
+ :param bool visible: True if the columns 'Raw Area' and 'Net Area'
+ should be visible.
+ """
+ if visible is True:
+ self.showColumn(self.COLUMNS_INDEX["Raw Area"])
+ self.showColumn(self.COLUMNS_INDEX["Net Area"])
+ else:
+ self.hideColumn(self.COLUMNS_INDEX["Raw Area"])
+ self.hideColumn(self.COLUMNS_INDEX["Net Area"])
+
+ def fillFromROIDict(self, roilist=(), roidict=None, currentroi=None):
+ """
+ This function API is kept for compatibility.
+ But `setRois` should be preferred.
+
+ Set the ROIs by providing a list of ROI names and a dictionary
+ of ROI information for each ROI.
+ The ROI names must match an existing dictionary key.
+ The name list is used to provide an order for the ROIs.
+ The dictionary's values are sub-dictionaries containing 3
+ mandatory fields:
+
+ - ``"from"``: x coordinate of the left limit, as a float
+ - ``"to"``: x coordinate of the right limit, as a float
+ - ``"type"``: type of ROI, as a string (e.g "channels", "energy")
+
+ :param roilist: List of ROI names (keys of roidict)
+ :type roilist: List
+ :param dict roidict: Dict of ROI information
+ :param currentroi: Name of the selected ROI or None (no selection)
+ """
+ if roidict is not None:
+ self.setRois(roidict)
+ else:
+ self.setRois(roilist)
+ if currentroi:
+ self.setActiveRoi(currentroi)
+
+
+_indexNextROI = 0
+
+
+class ROI(_RegionOfInterestBase):
+ """The Region Of Interest is defined by:
+
+ - A name
+ - A type. The type is the label of the x axis. This can be used to apply or
+ not some ROI to a curve and do some post processing.
+ - The x coordinate of the left limit (fromdata)
+ - The x coordinate of the right limit (todata)
+
+ :param str: name of the ROI
+ :param fromdata: left limit of the roi
+ :param todata: right limit of the roi
+ :param type: type of the ROI
+ """
+
+ sigChanged = qt.Signal()
+ """Signal emitted when the ROI is edited"""
+
+ def __init__(self, name, fromdata=None, todata=None, type_=None):
+ _RegionOfInterestBase.__init__(self)
+ self.setName(name)
+ global _indexNextROI
+ self._id = _indexNextROI
+ _indexNextROI += 1
+
+ self._fromdata = fromdata
+ self._todata = todata
+ self._type = type_ or "Default"
+
+ self.sigItemChanged.connect(self.__itemChanged)
+
+ def __itemChanged(self, event):
+ """Handle name change"""
+ if event == items.ItemChangedType.NAME:
+ self.sigChanged.emit()
+
+ def getID(self):
+ """
+
+ :return int: the unique ID of the ROI
+ """
+ return self._id
+
+ def setType(self, type_):
+ """
+
+ :param str type_:
+ """
+ if self._type != type_:
+ self._type = type_
+ self.sigChanged.emit()
+
+ def getType(self):
+ """
+
+ :return str: the type of the ROI.
+ """
+ return self._type
+
+ def setFrom(self, frm):
+ """
+
+ :param frm: set x coordinate of the left limit
+ """
+ if self._fromdata != frm:
+ self._fromdata = frm
+ self.sigChanged.emit()
+
+ def getFrom(self):
+ """
+
+ :return: x coordinate of the left limit
+ """
+ return self._fromdata
+
+ def setTo(self, to):
+ """
+
+ :param to: x coordinate of the right limit
+ """
+ if self._todata != to:
+ self._todata = to
+ self.sigChanged.emit()
+
+ def getTo(self):
+ """
+
+ :return: x coordinate of the right limit
+ """
+ return self._todata
+
+ def getMiddle(self):
+ """
+
+ :return: middle position between 'from' and 'to' values
+ """
+ return 0.5 * (self.getFrom() + self.getTo())
+
+ def toDict(self):
+ """
+
+ :return: dict containing the roi parameters
+ """
+ ddict = {
+ "type": self._type,
+ "name": self.getName(),
+ "from": self._fromdata,
+ "to": self._todata,
+ }
+ if hasattr(self, "_extraInfo"):
+ ddict.update(self._extraInfo)
+ return ddict
+
+ @staticmethod
+ def _fromDict(dic):
+ assert "name" in dic
+ roi = ROI(name=dic["name"])
+ roi._extraInfo = {}
+ for key in dic:
+ if key == "from":
+ roi.setFrom(dic["from"])
+ elif key == "to":
+ roi.setTo(dic["to"])
+ elif key == "type":
+ roi.setType(dic["type"])
+ else:
+ roi._extraInfo[key] = dic[key]
+
+ return roi
+
+ def isICR(self):
+ """
+
+ :return: True if the ROI is the `ICR`
+ """
+ return self.getName() == "ICR"
+
+ def computeRawAndNetCounts(self, curve):
+ """Compute the Raw and net counts in the ROI for the given curve.
+
+ - Raw count: Points values sum of the curve in the defined Region Of
+ Interest.
+
+ .. image:: img/rawCounts.png
+
+ - Net count: Raw counts minus background
+
+ .. image:: img/netCounts.png
+
+ :param CurveItem curve:
+ :return tuple: rawCount, netCount
+ """
+ assert isinstance(curve, items.Curve) or curve is None
+
+ if curve is None:
+ return None, None
+
+ x = curve.getXData(copy=False)
+ y = curve.getYData(copy=False)
+
+ idx = numpy.nonzero((self._fromdata <= x) & (x <= self._todata))[0]
+ if len(idx):
+ xw = x[idx]
+ yw = y[idx]
+ rawCounts = yw.sum(dtype=numpy.float64)
+ deltaX = xw[-1] - xw[0]
+ deltaY = yw[-1] - yw[0]
+ if deltaX > 0.0:
+ slope = deltaY / deltaX
+ background = yw[0] + slope * (xw - xw[0])
+ netCounts = rawCounts - background.sum(dtype=numpy.float64)
+ else:
+ netCounts = 0.0
+ else:
+ rawCounts = 0.0
+ netCounts = 0.0
+ return rawCounts, netCounts
+
+ def computeRawAndNetArea(self, curve):
+ """Compute the Raw and net counts in the ROI for the given curve.
+
+ - Raw area: integral of the curve between the min ROI point and the
+ max ROI point to the y = 0 line.
+
+ .. image:: img/rawArea.png
+
+ - Net area: Raw counts minus background
+
+ .. image:: img/netArea.png
+
+ :param CurveItem curve:
+ :return tuple: rawArea, netArea
+ """
+ assert isinstance(curve, items.Curve) or curve is None
+
+ if curve is None:
+ return None, None
+
+ x = curve.getXData(copy=False)
+ y = curve.getYData(copy=False)
+
+ y = y[(x >= self._fromdata) & (x <= self._todata)]
+ x = x[(x >= self._fromdata) & (x <= self._todata)]
+
+ if x.size == 0:
+ return 0.0, 0.0
+
+ rawArea = numpy.trapz(y, x=x)
+ # to speed up and avoid an intersection calculation we are taking the
+ # closest index to the ROI
+ closestXLeftIndex = (numpy.abs(x - self.getFrom())).argmin()
+ closestXRightIndex = (numpy.abs(x - self.getTo())).argmin()
+ yBackground = y[closestXLeftIndex], y[closestXRightIndex]
+ background = numpy.trapz(yBackground, x=x)
+ netArea = rawArea - background
+ return rawArea, netArea
+
+ @docstring(_RegionOfInterestBase)
+ def contains(self, position):
+ return self._fromdata <= position[0] <= self._todata
+
+
+class _RoiMarkerManager(object):
+ """
+ Deal with all the ROI markers
+ """
+
+ def __init__(self):
+ self._roiMarkerHandlers = {}
+ self._middleROIMarkerFlag = False
+ self._showAllMarkers = False
+ self._activeRoi = None
+
+ def setActiveRoi(self, roi):
+ self._activeRoi = roi
+ self.updateAllMarkers()
+
+ def setShowAllMarkers(self, show):
+ if show != self._showAllMarkers:
+ self._showAllMarkers = show
+ self.updateAllMarkers()
+
+ def add(self, roi, markersHandler):
+ assert isinstance(roi, ROI)
+ assert isinstance(markersHandler, _RoiMarkerHandler)
+ if roi.getID() in self._roiMarkerHandlers:
+ raise ValueError("roi with the same ID already existing")
+ else:
+ self._roiMarkerHandlers[roi.getID()] = markersHandler
+
+ def getMarkerHandler(self, roiID):
+ if roiID in self._roiMarkerHandlers:
+ return self._roiMarkerHandlers[roiID]
+ else:
+ return None
+
+ def clear(self):
+ roisHandler = list(self._roiMarkerHandlers.values())
+ for roiHandler in roisHandler:
+ self.remove(roiHandler.roi)
+
+ def remove(self, roi):
+ if roi is None:
+ return
+ assert isinstance(roi, ROI)
+ if roi.getID() in self._roiMarkerHandlers:
+ self._roiMarkerHandlers[roi.getID()].clear()
+ del self._roiMarkerHandlers[roi.getID()]
+
+ def hasMarker(self, markerID):
+ assert type(markerID) is str
+ return self.getMarker(markerID) is not None
+
+ def changePosition(self, markerID, x):
+ markerHandler = self.getMarker(markerID)
+ if markerHandler is None:
+ raise ValueError("Marker %s not register" % markerID)
+ markerHandler.changePosition(markerID=markerID, x=x)
+
+ def updateMarker(self, markerID):
+ markerHandler = self.getMarker(markerID)
+ if markerHandler is None:
+ raise ValueError("Marker %s not register" % markerID)
+ roiID = self.getRoiID(markerID)
+ visible = (
+ self._activeRoi and self._activeRoi.getID() == roiID
+ ) or self._showAllMarkers is True
+ markerHandler.setVisible(visible)
+ markerHandler.updateAllMarkers()
+
+ def updateRoiMarkers(self, roiID):
+ if roiID in self._roiMarkerHandlers:
+ visible = (
+ self._activeRoi and self._activeRoi.getID() == roiID
+ ) or self._showAllMarkers is True
+ _roi = self._roiMarkerHandlers[roiID]._roi()
+ if _roi and not _roi.isICR():
+ self._roiMarkerHandlers[roiID].showMiddleMarker(
+ self._middleROIMarkerFlag
+ )
+ self._roiMarkerHandlers[roiID].setVisible(visible)
+ self._roiMarkerHandlers[roiID].updateMarkers()
+
+ def getMarker(self, markerID):
+ assert type(markerID) is str
+ for marker in list(self._roiMarkerHandlers.values()):
+ if marker.hasMarker(markerID):
+ return marker
+
+ def updateMarkers(self):
+ for markerHandler in list(self._roiMarkerHandlers.values()):
+ markerHandler.updateMarkers()
+
+ def getRoiID(self, markerID):
+ for roiID, markerHandler in self._roiMarkerHandlers.items():
+ if markerHandler.hasMarker(markerID):
+ return roiID
+ return None
+
+ def setShowMiddleMarkers(self, show):
+ self._middleROIMarkerFlag = show
+ self._roiMarkerHandlers.updateAllMarkers()
+
+ def updateAllMarkers(self):
+ for roiID in self._roiMarkerHandlers:
+ self.updateRoiMarkers(roiID)
+
+ def getVisibleRois(self):
+ res = {}
+ for roiID, roiHandler in self._roiMarkerHandlers.items():
+ markers = (
+ roiHandler.getMarker("min"),
+ roiHandler.getMarker("max"),
+ roiHandler.getMarker("middle"),
+ )
+ for marker in markers:
+ if marker.isVisible():
+ if roiID not in res:
+ res[roiID] = []
+ res[roiID].append(marker)
+ return res
+
+
+class _RoiMarkerHandler(object):
+ """Used to deal with ROI markers used in ROITable"""
+
+ def __init__(self, roi, plot):
+ assert roi and isinstance(roi, ROI)
+ assert plot
+
+ self._roi = weakref.ref(roi)
+ self._plot = weakref.ref(plot)
+ self._draggable = False if roi.isICR() else True
+ self._color = "black" if roi.isICR() else "blue"
+ self._displayMidMarker = False
+ self._visible = True
+
+ @property
+ def draggable(self):
+ return self._draggable
+
+ @property
+ def plot(self):
+ return self._plot()
+
+ def clear(self):
+ if self.plot and self.roi:
+ self.plot.removeMarker(self._markerID("min"))
+ self.plot.removeMarker(self._markerID("max"))
+ self.plot.removeMarker(self._markerID("middle"))
+
+ @property
+ def roi(self):
+ return self._roi()
+
+ def setVisible(self, visible):
+ if visible != self._visible:
+ self._visible = visible
+ self.updateMarkers()
+
+ def showMiddleMarker(self, visible):
+ if self.draggable is False and visible is True:
+ _logger.warning("ROI is not draggable. Won't display middle marker")
+ return
+ self._displayMidMarker = visible
+ self.getMarker("middle").setVisible(self._displayMidMarker)
+
+ def updateMarkers(self):
+ if self.roi is None:
+ return
+ self._updateMinMarkerPos()
+ self._updateMaxMarkerPos()
+ self._updateMiddleMarkerPos()
+
+ def _updateMinMarkerPos(self):
+ self.getMarker("min").setPosition(x=self.roi.getFrom(), y=None)
+ self.getMarker("min").setVisible(self._visible)
+
+ def _updateMaxMarkerPos(self):
+ self.getMarker("max").setPosition(x=self.roi.getTo(), y=None)
+ self.getMarker("max").setVisible(self._visible)
+
+ def _updateMiddleMarkerPos(self):
+ self.getMarker("middle").setPosition(x=self.roi.getMiddle(), y=None)
+ self.getMarker("middle").setVisible(self._displayMidMarker and self._visible)
+
+ def getMarker(self, markerType):
+ if self.plot is None:
+ return None
+ assert markerType in ("min", "max", "middle")
+ if self.plot._getMarker(self._markerID(markerType)) is None:
+ assert self.roi
+ if markerType == "min":
+ val = self.roi.getFrom()
+ elif markerType == "max":
+ val = self.roi.getTo()
+ else:
+ val = self.roi.getMiddle()
+
+ _color = self._color
+ if markerType == "middle":
+ _color = "yellow"
+ self.plot.addXMarker(
+ val,
+ legend=self._markerID(markerType),
+ text=self.getMarkerName(markerType),
+ color=_color,
+ draggable=self.draggable,
+ )
+ return self.plot._getMarker(self._markerID(markerType))
+
+ def _markerID(self, markerType):
+ assert markerType in ("min", "max", "middle")
+ assert self.roi
+ return "_".join((str(self.roi.getID()), markerType))
+
+ def getMarkerName(self, markerType):
+ assert markerType in ("min", "max", "middle")
+ assert self.roi
+ return " ".join((self.roi.getName(), markerType))
+
+ def updateTexts(self):
+ self.getMarker("min").setText(self.getMarkerName("min"))
+ self.getMarker("max").setText(self.getMarkerName("max"))
+ self.getMarker("middle").setText(self.getMarkerName("middle"))
+
+ def changePosition(self, markerID, x):
+ assert self.hasMarker(markerID)
+ markerType = self._getMarkerType(markerID)
+ assert markerType is not None
+ if self.roi is None:
+ return
+ if markerType == "min":
+ self.roi.setFrom(x)
+ self._updateMiddleMarkerPos()
+ elif markerType == "max":
+ self.roi.setTo(x)
+ self._updateMiddleMarkerPos()
+ else:
+ delta = x - 0.5 * (self.roi.getFrom() + self.roi.getTo())
+ self.roi.setFrom(self.roi.getFrom() + delta)
+ self.roi.setTo(self.roi.getTo() + delta)
+ self._updateMinMarkerPos()
+ self._updateMaxMarkerPos()
+
+ def hasMarker(self, marker):
+ return marker in (
+ self._markerID("min"),
+ self._markerID("max"),
+ self._markerID("middle"),
+ )
+
+ def _getMarkerType(self, markerID):
+ if markerID.endswith("_min"):
+ return "min"
+ elif markerID.endswith("_max"):
+ return "max"
+ elif markerID.endswith("_middle"):
+ return "middle"
+ else:
+ return None
+
+
+class CurvesROIDockWidget(qt.QDockWidget):
+ """QDockWidget with a :class:`CurvesROIWidget` connected to a PlotWindow.
+
+ It makes the link between the :class:`CurvesROIWidget` and the PlotWindow.
+
+ :param parent: See :class:`QDockWidget`
+ :param plot: :class:`.PlotWindow` instance on which to operate
+ :param name: See :class:`QDockWidget`
+ """
+
+ sigROISignal = qt.Signal(object)
+ """Deprecated signal for backward compatibility with silx < 0.7.
+ Prefer connecting directly to :attr:`CurvesRoiWidget.sigRoiSignal`
+ """
+
+ def __init__(self, parent=None, plot=None, name=None):
+ super(CurvesROIDockWidget, self).__init__(name, parent)
+
+ assert plot is not None
+ self.plot = plot
+ self.roiWidget = CurvesROIWidget(self, name, plot=plot)
+ """Main widget of type :class:`CurvesROIWidget`"""
+
+ # convenience methods to offer a simpler API allowing to ignore
+ # the details of the underlying implementation
+ # (ALL DEPRECATED)
+ self.calculateROIs = self.calculateRois = self.roiWidget.calculateRois
+ self.setRois = self.roiWidget.setRois
+ self.getRois = self.roiWidget.getRois
+
+ self.roiWidget.sigROISignal.connect(self._forwardSigROISignal)
+
+ self.layout().setContentsMargins(0, 0, 0, 0)
+ self.setWidget(self.roiWidget)
+
+ self.setAreaVisible = self.roiWidget.roiTable.setAreaVisible
+ self.setCountsVisible = self.roiWidget.roiTable.setCountsVisible
+
+ def _forwardSigROISignal(self, ddict):
+ # emit deprecated signal for backward compatibility (silx < 0.7)
+ self.sigROISignal.emit(ddict)
+
+ def toggleViewAction(self):
+ """Returns a checkable action that shows or closes this widget.
+
+ See :class:`QMainWindow`.
+ """
+ action = super(CurvesROIDockWidget, self).toggleViewAction()
+ action.setIcon(icons.getQIcon("plot-roi"))
+ return action
+
+ @property
+ def currentROI(self):
+ return self.roiWidget.currentRoi
diff --git a/src/silx/gui/plot/ImageStack.py b/src/silx/gui/plot/ImageStack.py
new file mode 100644
index 0000000..175d6e4
--- /dev/null
+++ b/src/silx/gui/plot/ImageStack.py
@@ -0,0 +1,583 @@
+# /*##########################################################################
+#
+# Copyright (c) 2020-2023 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""Image stack view with data prefetch capabilty."""
+
+from __future__ import annotations
+
+__authors__ = ["H. Payno"]
+__license__ = "MIT"
+__date__ = "04/03/2019"
+
+
+from silx.gui import qt
+from silx.gui.plot import Plot2D
+from silx.io.url import DataUrl
+from silx.io.utils import get_data
+from silx.gui.widgets.FrameBrowser import HorizontalSliderWithBrowser
+from silx.gui.widgets.UrlList import UrlList
+from silx.gui.utils import blockSignals
+from silx.utils.deprecation import deprecated
+
+import typing
+import logging
+from silx.gui.widgets.WaitingOverlay import WaitingOverlay
+from collections.abc import Iterable
+
+_logger = logging.getLogger(__name__)
+
+
+class _HorizontalSlider(HorizontalSliderWithBrowser):
+ sigCurrentUrlIndexChanged = qt.Signal(int)
+
+ def __init__(self, parent):
+ super().__init__(parent=parent)
+ # connect signal / slot
+ self.valueChanged.connect(self._urlChanged)
+
+ def setUrlIndex(self, index):
+ self.setValue(index)
+ self.sigCurrentUrlIndexChanged.emit(index)
+
+ def _urlChanged(self, value):
+ self.sigCurrentUrlIndexChanged.emit(value)
+
+
+class _ToggleableUrlSelectionTable(qt.QWidget):
+ _BUTTON_ICON = qt.QStyle.SP_ToolBarHorizontalExtensionButton # noqa
+
+ sigCurrentUrlChanged = qt.Signal(str)
+ """Signal emitted when the active/current url change"""
+
+ sigUrlRemoved = qt.Signal(str)
+
+ def __init__(self, parent=None) -> None:
+ super().__init__(parent)
+ self.setLayout(qt.QGridLayout())
+ self._toggleButton = qt.QPushButton(parent=self)
+ self.layout().addWidget(self._toggleButton, 0, 2, 1, 1)
+ self._toggleButton.setSizePolicy(qt.QSizePolicy.Fixed, qt.QSizePolicy.Fixed)
+
+ self._urlsTable = UrlList(parent=self)
+
+ self.layout().addWidget(self._urlsTable, 1, 1, 1, 2)
+
+ # set up
+ self._setButtonIcon(show=True)
+
+ # Signal / slot connection
+ self._toggleButton.clicked.connect(self.toggleUrlSelectionTable)
+ self._urlsTable.sigCurrentUrlChanged.connect(self.sigCurrentUrlChanged)
+ self._urlsTable.sigUrlRemoved.connect(self.sigUrlRemoved)
+
+ def toggleUrlSelectionTable(self):
+ visible = not self.urlSelectionTableIsVisible()
+ self._setButtonIcon(show=visible)
+ self._urlsTable.setVisible(visible)
+
+ def _setButtonIcon(self, show):
+ style = qt.QApplication.instance().style()
+ # return a QIcon
+ icon = style.standardIcon(self._BUTTON_ICON)
+ if show is False:
+ pixmap = icon.pixmap(32, 32).transformed(qt.QTransform().scale(-1, 1))
+ icon = qt.QIcon(pixmap)
+ self._toggleButton.setIcon(icon)
+
+ def urlSelectionTableIsVisible(self):
+ return self._urlsTable.isVisibleTo(self)
+
+ def clear(self):
+ self._urlsTable.clear()
+
+ # expose UrlList API
+ @deprecated(replacement="addUrls", since_version="2.0")
+ def setUrls(self, urls: Iterable[DataUrl]):
+ self._urlsTable.addUrls(urls=urls)
+
+ def addUrls(self, urls: Iterable[DataUrl]):
+ self._urlsTable.addUrls(urls=urls)
+
+ def setUrl(self, url: typing.Optional[DataUrl]):
+ self._urlsTable.setUrl(url=url)
+
+ def removeUrl(self, url: str):
+ self._urlsTable.removeUrl(url)
+
+ def currentItem(self):
+ return self._urlsTable.currentItem()
+
+
+class UrlLoader(qt.QThread):
+ """
+ Thread use to load DataUrl
+ """
+
+ def __init__(self, parent, url):
+ super().__init__(parent=parent)
+ assert isinstance(url, DataUrl)
+ self.url = url
+ self.data = None
+
+ def run(self):
+ try:
+ self.data = get_data(self.url)
+ except IOError:
+ self.data = None
+
+
+class ImageStack(qt.QMainWindow):
+ """Widget loading on the fly images contained the given urls.
+
+ It prefetches images close to the displayed one.
+ """
+
+ N_PRELOAD = 10
+
+ sigLoaded = qt.Signal(str)
+ """Signal emitted when new data is available"""
+
+ sigCurrentUrlChanged = qt.Signal(str)
+ """Signal emitted when the current url change"""
+
+ def __init__(self, parent=None) -> None:
+ super().__init__(parent)
+ self.__n_prefetch = ImageStack.N_PRELOAD
+ self._loadingThreads = []
+ self.setWindowFlags(qt.Qt.Widget)
+ self._current_url = None
+ self._url_loader = UrlLoader
+ "class to instantiate for loading urls"
+ self._autoResetZoom = True
+
+ # main widget
+ self._plot = Plot2D(parent=self)
+ self._plot.setAttribute(qt.Qt.WA_DeleteOnClose, True)
+ self._waitingOverlay = WaitingOverlay(self._plot)
+ self._waitingOverlay.setIconSize(qt.QSize(30, 30))
+ self._waitingOverlay.hide()
+ self.setWindowTitle("Image stack")
+ self.setCentralWidget(self._plot)
+
+ # dock widget: url table
+ self._tableDockWidget = qt.QDockWidget(parent=self)
+ self._urlsTable = _ToggleableUrlSelectionTable(parent=self)
+ self._tableDockWidget.setWidget(self._urlsTable)
+ self._tableDockWidget.setFeatures(qt.QDockWidget.DockWidgetMovable)
+ self.addDockWidget(qt.Qt.RightDockWidgetArea, self._tableDockWidget)
+ # dock widget: qslider
+ self._sliderDockWidget = qt.QDockWidget(parent=self)
+ self._slider = _HorizontalSlider(parent=self)
+ self._sliderDockWidget.setWidget(self._slider)
+ self.addDockWidget(qt.Qt.BottomDockWidgetArea, self._sliderDockWidget)
+ self._sliderDockWidget.setFeatures(qt.QDockWidget.DockWidgetMovable)
+
+ self.reset()
+
+ # connect signal / slot
+ self._urlsTable.sigCurrentUrlChanged.connect(self.setCurrentUrl)
+ self._urlsTable.sigUrlRemoved.connect(self.removeUrl)
+ self._slider.sigCurrentUrlIndexChanged.connect(self.setCurrentUrlIndex)
+
+ def close(self) -> bool:
+ self._freeLoadingThreads()
+ self._waitingOverlay.close()
+ self._plot.close()
+ super().close()
+
+ def setUrlLoaderClass(self, urlLoader: typing.Type[UrlLoader]) -> None:
+ """
+
+ :param urlLoader: define the class to call for loading urls.
+ warning: this should be a class object and not a
+ class instance.
+ """
+ assert isinstance(urlLoader, type(UrlLoader))
+ self._url_loader = urlLoader
+
+ def getUrlLoaderClass(self):
+ """
+
+ :return: class to instantiate for loading urls
+ :rtype: typing.Type[UrlLoader]
+ """
+ return self._url_loader
+
+ def _freeLoadingThreads(self):
+ for thread in self._loadingThreads:
+ thread.blockSignals(True)
+ thread.wait(5)
+ self._loadingThreads.clear()
+
+ def getPlotWidget(self) -> Plot2D:
+ """
+ Returns the PlotWidget contained in this window
+
+ :return: PlotWidget contained in this window
+ :rtype: Plot2D
+ """
+ return self._plot
+
+ def reset(self) -> None:
+ """Clear the plot and remove any link to url"""
+ self._freeLoadingThreads()
+ self._urls = None
+ self._urlIndexes = None
+ self._urlData = {}
+ self._current_url = None
+ self._plot.clear()
+ self._urlsTable.clear()
+ self._slider.setMaximum(-1)
+
+ def _preFetch(self, urls: list) -> None:
+ """Pre-fetch the given urls if necessary
+
+ :param urls: list of DataUrl to prefetch
+ :type: list
+ """
+ for url in urls:
+ if url.path() not in self._urlData:
+ self._load(url)
+
+ def _load(self, url):
+ """
+ Launch background load of a DataUrl
+
+ :param url:
+ :type: DataUrl
+ """
+ assert isinstance(url, DataUrl)
+ url_path = url.path()
+ assert url_path in self._urlIndexes
+ loader = self._url_loader(parent=self, url=url)
+ loader.finished.connect(self._urlLoaded, qt.Qt.QueuedConnection)
+ self._loadingThreads.append(loader)
+ loader.start()
+
+ def _urlLoaded(self) -> None:
+ """
+
+ :param url: restul of DataUrl.path() function
+ :return:
+ """
+ sender = self.sender()
+ assert isinstance(sender, UrlLoader)
+ url = sender.url.path()
+ if url in self._urlIndexes:
+ self._urlData[url] = sender.data
+ if self.getCurrentUrl().path() == url:
+ self._waitingOverlay.setVisible(False)
+ self._plot.addImage(self._urlData[url], resetzoom=self._autoResetZoom)
+ if sender in self._loadingThreads:
+ self._loadingThreads.remove(sender)
+ self.sigLoaded.emit(url)
+
+ def setNPrefetch(self, n: int) -> None:
+ """
+ Define the number of url to prefetch around
+
+ :param int n: number of url to prefetch on left and right sides.
+ In total n*2 DataUrl will be prefetch
+ """
+ self.__n_prefetch = n
+ current_url = self.getCurrentUrl()
+ if current_url is not None:
+ self.setCurrentUrl(current_url)
+
+ def getNPrefetch(self) -> int:
+ """
+
+ :return: number of url to prefetch on left and right sides. In total
+ will load 2* NPrefetch DataUrls
+ """
+ return self.__n_prefetch
+
+ def setUrlsEditable(self, editable: bool):
+ self._urlsTable._urlsTable.setEditable(editable)
+ if editable:
+ selection_mode = qt.QAbstractItemView.ExtendedSelection
+ else:
+ selection_mode = qt.QAbstractItemView.SingleSelection
+ self._urlsTable._urlsTable.setSelectionMode(selection_mode)
+
+ @staticmethod
+ def createUrlIndexes(urls: tuple):
+ indexes = {}
+ for index, url in enumerate(urls):
+ assert isinstance(
+ url, DataUrl
+ ), f"url is expected to be a DataUrl. Get {type(url)}"
+ indexes[index] = url
+ return indexes
+
+ def _resetSlider(self):
+ with blockSignals(self._slider):
+ self._slider.setMinimum(0)
+ self._slider.setMaximum(len(self._urls) - 1)
+
+ def setUrls(self, urls: list) -> None:
+ """list of urls within an index. Warning: urls should contain an image
+ compatible with the silx.gui.plot.Plot class
+
+ :param urls: urls we want to set in the stack. Key is the index
+ (position in the stack), value is the DataUrl
+ :type: list
+ """
+ urls_with_indexes = self.createUrlIndexes(urls=urls)
+ urlsToIndex = self._urlsToIndex(urls_with_indexes)
+ self.reset()
+ self._urls = urls_with_indexes
+ self._urlIndexes = urlsToIndex
+
+ with blockSignals(self._urlsTable):
+ self._urlsTable.addUrls(urls=list(self._urls.values()))
+
+ self._resetSlider()
+
+ if self.getCurrentUrl() in self._urls:
+ self.setCurrentUrl(self.getCurrentUrl())
+ else:
+ if len(self._urls.keys()) > 0:
+ first_url = self._urls[list(self._urls.keys())[0]]
+ self.setCurrentUrl(first_url)
+
+ def removeUrl(self, url: str) -> None:
+ """
+ Remove provided URL from the table
+
+ :param url: URL as str
+ """
+ # remove the given urls from self._urls and self._urlIndexes
+ if not isinstance(url, str):
+ raise TypeError("url is expected to be the str representation of the url")
+
+ # try to get reset the url displayed
+ current_url = self.getCurrentUrl()
+ with blockSignals(self._urlsTable):
+ self._urlsTable.removeUrl(url)
+ # update urls
+ urls_with_indexes = self.createUrlIndexes(
+ filter(
+ lambda a: a.path() != url,
+ self._urls.values(),
+ )
+ )
+ urlsToIndex = self._urlsToIndex(urls_with_indexes)
+ self._urls = urls_with_indexes
+ self._urlIndexes = urlsToIndex
+ self._resetSlider()
+
+ if current_url != url:
+ self.setCurrentUrl(current_url)
+
+ def getUrls(self) -> tuple:
+ """
+
+ :return: tuple of urls
+ :rtype: tuple
+ """
+ return tuple(self._urlIndexes.keys())
+
+ def _getNextUrl(self, url: DataUrl) -> typing.Union[None, DataUrl]:
+ """
+ return the next url in the stack
+
+ :param url: url for which we want the next url
+ :type: DataUrl
+ :return: next url in the stack or None if `url` is the last one
+ :rtype: Union[None, DataUrl]
+ """
+ assert isinstance(url, DataUrl)
+ if self._urls is None:
+ return None
+ else:
+ index = self._urlIndexes[url.path()]
+ indexes = list(self._urls.keys())
+ res = list(filter(lambda x: x > index, indexes))
+ if len(res) == 0:
+ return None
+ else:
+ return self._urls[res[0]]
+
+ def _getPreviousUrl(self, url: DataUrl) -> typing.Union[None, DataUrl]:
+ """
+ return the previous url in the stack
+
+ :param url: url for which we want the previous url
+ :type: DataUrl
+ :return: next url in the stack or None if `url` is the last one
+ :rtype: Union[None, DataUrl]
+ """
+ if self._urls is None:
+ return None
+ else:
+ index = self._urlIndexes[url.path()]
+ indexes = list(self._urls.keys())
+ res = list(filter(lambda x: x < index, indexes))
+ if len(res) == 0:
+ return None
+ else:
+ return self._urls[res[-1]]
+
+ def _getNNextUrls(self, n: int, url: DataUrl) -> list:
+ """
+ Deduce the next urls in the stack after `url`
+
+ :param n: the number of url store after `url`
+ :type: int
+ :param url: url for which we want n next url
+ :type: DataUrl
+ :return: list of next urls.
+ :rtype: list
+ """
+ res = []
+ next_free = self._getNextUrl(url=url)
+ while len(res) < n and next_free is not None:
+ assert isinstance(next_free, DataUrl)
+ res.append(next_free)
+ next_free = self._getNextUrl(res[-1])
+ return res
+
+ def _getNPreviousUrls(self, n: int, url: DataUrl):
+ """
+ Deduce the previous urls in the stack after `url`
+
+ :param n: the number of url store after `url`
+ :type: int
+ :param url: url for which we want n previous url
+ :type: DataUrl
+ :return: list of previous urls.
+ :rtype: list
+ """
+ res = []
+ next_free = self._getPreviousUrl(url=url)
+ while len(res) < n and next_free is not None:
+ res.insert(0, next_free)
+ next_free = self._getPreviousUrl(res[0])
+ return res
+
+ def setCurrentUrlIndex(self, index: int):
+ """
+ Define the url to be displayed
+
+ :param index: url to be displayed
+ :type: int
+ """
+ if index < 0:
+ return
+ if self._urls is None:
+ return
+ elif index >= len(self._urls):
+ raise ValueError("requested index out of bounds")
+ else:
+ return self.setCurrentUrl(self._urls[index])
+
+ def setCurrentUrl(self, url: typing.Optional[typing.Union[DataUrl, str]]) -> None:
+ """
+ Define the url to be displayed
+
+ :param url: url to be displayed
+ :type: DataUrl
+ :raises KeyError: raised if the url is not know
+ """
+ assert isinstance(url, (DataUrl, str, type(None)))
+ if url == "":
+ url = None
+ elif isinstance(url, str):
+ url = DataUrl(path=url)
+ if url is not None and url != self._current_url:
+ self._current_url = url
+ self.sigCurrentUrlChanged.emit(url.path())
+
+ with blockSignals(self._urlsTable):
+ with blockSignals(self._slider):
+ self._urlsTable.setUrl(url)
+ if url is not None:
+ self._slider.setUrlIndex(self._urlIndexes[url.path()])
+ if self._current_url is None:
+ self._plot.clear()
+ else:
+ if self._current_url.path() in self._urlData:
+ self._waitingOverlay.setVisible(False)
+ self._plot.addImage(
+ self._urlData[url.path()], resetzoom=self._autoResetZoom
+ )
+ else:
+ self._plot.clear()
+ self._load(url)
+ self._waitingOverlay.setVisible(True)
+ self._preFetch(self._getNNextUrls(self.__n_prefetch, url))
+ self._preFetch(self._getNPreviousUrls(self.__n_prefetch, url))
+
+ def getCurrentUrl(self) -> typing.Union[None, DataUrl]:
+ """
+
+ :return: url currently displayed
+ :rtype: Union[None, DataUrl]
+ """
+ return self._current_url
+
+ def getCurrentUrlIndex(self) -> typing.Union[None, int]:
+ """
+
+ :return: index of the url currently displayed
+ :rtype: Union[None, int]
+ """
+ if self._current_url is None:
+ return None
+ else:
+ return self._urlIndexes[self._current_url.path()]
+
+ @staticmethod
+ def _urlsToIndex(urls):
+ """util, return a dictionary with url as key and index as value"""
+ res = {}
+ for index, url in urls.items():
+ res[url.path()] = index
+ return res
+
+ def setAutoResetZoom(self, reset):
+ """
+ Should we reset the zoom when adding an image (eq. when browsing)
+
+ :param bool reset:
+ """
+ self._autoResetZoom = reset
+ if self._autoResetZoom:
+ self._plot.resetZoom()
+
+ def isAutoResetZoom(self) -> bool:
+ """
+
+ :return: True if a reset is done when the image change
+ :rtype: bool
+ """
+ return self._autoResetZoom
+
+ def getWaiterOverlay(self):
+ """
+
+ :return: Return the instance of `WaitingOverlay` used to display if processing or not
+ :rtype: WaitingOverlay
+ """
+ return self._waitingOverlay
diff --git a/src/silx/gui/plot/ImageView.py b/src/silx/gui/plot/ImageView.py
new file mode 100644
index 0000000..eaca42b
--- /dev/null
+++ b/src/silx/gui/plot/ImageView.py
@@ -0,0 +1,1121 @@
+# /*##########################################################################
+#
+# Copyright (c) 2015-2023 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""QWidget displaying a 2D image with histograms on its sides.
+
+The :class:`ImageView` implements this widget, and
+:class:`ImageViewMainWindow` provides a main window with additional toolbar
+and status bar.
+
+Basic usage of :class:`ImageView` is through the following methods:
+
+- :meth:`ImageView.getColormap`, :meth:`ImageView.setColormap` to update the
+ default colormap to use and update the currently displayed image.
+- :meth:`ImageView.setImage` to update the displayed image.
+
+For an example of use, see `imageview.py` in :ref:`sample-code`.
+"""
+
+__authors__ = ["T. Vincent"]
+__license__ = "MIT"
+__date__ = "26/04/2018"
+
+
+import logging
+import numpy
+import collections
+from typing import Union
+
+import silx
+from .. import qt
+from .. import colors
+from .. import icons
+
+from . import items, PlotWindow, PlotWidget, actions
+from ..colors import Colormap
+from ..colors import cursorColorForColormap
+from .tools import LimitsToolBar
+from .Profile import ProfileToolBar
+from ...utils.proxy import docstring
+from ...utils.deprecation import deprecated
+from ...utils.enum import Enum
+from .tools.RadarView import RadarView
+from .utils.axis import SyncAxes
+from ..utils import blockSignals
+from . import _utils
+from .tools.profile import rois
+from .actions import PlotAction
+
+_logger = logging.getLogger(__name__)
+
+
+ProfileSumResult = collections.namedtuple(
+ "ProfileResult",
+ [
+ "dataXRange",
+ "dataYRange",
+ "histoH",
+ "histoHRange",
+ "histoV",
+ "histoVRange",
+ "xCoords",
+ "xData",
+ "yCoords",
+ "yData",
+ ],
+)
+
+
+def computeProfileSumOnRange(imageItem, xRange, yRange, cache=None):
+ """
+ Compute a full vertical and horizontal profile on an image item using a
+ a range in the plot referential.
+
+ Optionally takes a previous computed result to be able to skip the
+ computation.
+
+ :rtype: ProfileSumResult
+ """
+ data = imageItem.getValueData(copy=False)
+ origin = imageItem.getOrigin()
+ scale = imageItem.getScale()
+ height, width = data.shape
+
+ xMin, xMax = xRange
+ yMin, yMax = yRange
+
+ # Convert plot area limits to image coordinates
+ # and work in image coordinates (i.e., in pixels)
+ xMin = int((xMin - origin[0]) / scale[0])
+ xMax = int((xMax - origin[0]) / scale[0])
+ yMin = int((yMin - origin[1]) / scale[1])
+ yMax = int((yMax - origin[1]) / scale[1])
+
+ if xMin >= width or xMax < 0 or yMin >= height or yMax < 0:
+ return None
+
+ # The image is at least partly in the plot area
+ # Get the visible bounds in image coords (i.e., in pixels)
+ subsetXMin = 0 if xMin < 0 else xMin
+ subsetXMax = (width if xMax >= width else xMax) + 1
+ subsetYMin = 0 if yMin < 0 else yMin
+ subsetYMax = (height if yMax >= height else yMax) + 1
+
+ if cache is not None:
+ if (subsetXMin, subsetXMax) == cache.dataXRange and (
+ subsetYMin,
+ subsetYMax,
+ ) == cache.dataYRange:
+ # The visible area of data is the same
+ return cache
+
+ # Rebuild histograms for visible area
+ visibleData = data[subsetYMin:subsetYMax, subsetXMin:subsetXMax]
+ histoHVisibleData = numpy.nansum(visibleData, axis=0)
+ histoVVisibleData = numpy.nansum(visibleData, axis=1)
+ histoHMin = numpy.nanmin(histoHVisibleData)
+ histoHMax = numpy.nanmax(histoHVisibleData)
+ histoVMin = numpy.nanmin(histoVVisibleData)
+ histoVMax = numpy.nanmax(histoVVisibleData)
+
+ # Convert to histogram curve and update plots
+ # Taking into account origin and scale
+ coords = numpy.arange(2 * histoHVisibleData.size)
+ xCoords = (coords + 1) // 2 + subsetXMin
+ xCoords = origin[0] + scale[0] * xCoords
+ xData = numpy.take(histoHVisibleData, coords // 2)
+ coords = numpy.arange(2 * histoVVisibleData.size)
+ yCoords = (coords + 1) // 2 + subsetYMin
+ yCoords = origin[1] + scale[1] * yCoords
+ yData = numpy.take(histoVVisibleData, coords // 2)
+
+ result = ProfileSumResult(
+ dataXRange=(subsetXMin, subsetXMax),
+ dataYRange=(subsetYMin, subsetYMax),
+ histoH=histoHVisibleData,
+ histoHRange=(histoHMin, histoHMax),
+ histoV=histoVVisibleData,
+ histoVRange=(histoVMin, histoVMax),
+ xCoords=xCoords,
+ xData=xData,
+ yCoords=yCoords,
+ yData=yData,
+ )
+
+ return result
+
+
+class _SideHistogram(PlotWidget):
+ """
+ Widget displaying one of the side profile of the ImageView.
+
+ Implement ProfileWindow
+ """
+
+ sigClose = qt.Signal()
+
+ sigMouseMoved = qt.Signal(float, float)
+
+ def __init__(self, parent=None, backend=None, direction=qt.Qt.Horizontal):
+ super(_SideHistogram, self).__init__(parent=parent, backend=backend)
+ self._direction = direction
+ self.sigPlotSignal.connect(self._plotEvents)
+ self._color = "blue"
+ self.__profile = None
+ self.__profileSum = None
+
+ def _plotEvents(self, eventDict):
+ """Callback for horizontal histogram plot events."""
+ if eventDict["event"] == "mouseMoved":
+ self.sigMouseMoved.emit(eventDict["x"], eventDict["y"])
+
+ def setProfileColor(self, color):
+ self._color = color
+
+ def setProfileSum(self, result):
+ self.__profileSum = result
+ if self.__profile is None:
+ self.__drawProfileSum()
+
+ def prepareWidget(self, roi):
+ """Implements `ProfileWindow`"""
+ pass
+
+ def setRoiProfile(self, roi):
+ """Implements `ProfileWindow`"""
+ if roi is None:
+ return
+ self._roiColor = colors.rgba(roi.getColor())
+
+ def getProfile(self):
+ """Implements `ProfileWindow`"""
+ return self.__profile
+
+ def setProfile(self, data):
+ """Implements `ProfileWindow`"""
+ self.__profile = data
+ if data is None:
+ self.__drawProfileSum()
+ else:
+ self.__drawProfile()
+
+ def __drawProfileSum(self):
+ """Only draw the profile sum on the plot.
+
+ Other elements are removed
+ """
+ profileSum = self.__profileSum
+
+ try:
+ self.removeCurve("profile")
+ except Exception:
+ pass
+
+ if profileSum is None:
+ try:
+ self.removeCurve("profilesum")
+ except Exception:
+ pass
+ return
+
+ if self._direction == qt.Qt.Horizontal:
+ xx, yy = profileSum.xCoords, profileSum.xData
+ elif self._direction == qt.Qt.Vertical:
+ xx, yy = profileSum.yData, profileSum.yCoords
+ else:
+ assert False
+
+ self.addCurve(
+ xx,
+ yy,
+ xlabel="",
+ ylabel="",
+ legend="profilesum",
+ color=self._color,
+ linestyle="-",
+ selectable=False,
+ resetzoom=False,
+ )
+
+ self.__updateLimits()
+
+ def __drawProfile(self):
+ """Only draw the profile on the plot.
+
+ Other elements are removed
+ """
+ profile = self.__profile
+
+ try:
+ self.removeCurve("profilesum")
+ except Exception:
+ pass
+
+ if profile is None:
+ try:
+ self.removeCurve("profile")
+ except Exception:
+ pass
+ self.setProfileSum(self.__profileSum)
+ return
+
+ if self._direction == qt.Qt.Horizontal:
+ xx, yy = profile.coords, profile.profile
+ elif self._direction == qt.Qt.Vertical:
+ xx, yy = profile.profile, profile.coords
+ else:
+ assert False
+
+ self.addCurve(xx, yy, legend="profile", color=self._roiColor, resetzoom=False)
+
+ self.__updateLimits()
+
+ def __updateLimits(self):
+ if self.__profile:
+ data = self.__profile.profile
+ vMin = numpy.nanmin(data)
+ vMax = numpy.nanmax(data)
+ elif self.__profileSum is not None:
+ if self._direction == qt.Qt.Horizontal:
+ vMin, vMax = self.__profileSum.histoHRange
+ elif self._direction == qt.Qt.Vertical:
+ vMin, vMax = self.__profileSum.histoVRange
+ else:
+ assert False
+ else:
+ vMin, vMax = 0, 0
+
+ # Tune the result using the data margins
+ margins = self.getDataMargins()
+ if self._direction == qt.Qt.Horizontal:
+ _, _, vMin, vMax = _utils.addMarginsToLimits(
+ margins, False, False, 0, 0, vMin, vMax
+ )
+ elif self._direction == qt.Qt.Vertical:
+ vMin, vMax, _, _ = _utils.addMarginsToLimits(
+ margins, False, False, vMin, vMax, 0, 0
+ )
+ else:
+ assert False
+
+ if self._direction == qt.Qt.Horizontal:
+ dataAxis = self.getYAxis()
+ elif self._direction == qt.Qt.Vertical:
+ dataAxis = self.getXAxis()
+ else:
+ assert False
+
+ with blockSignals(dataAxis):
+ dataAxis.setLimits(vMin, vMax)
+
+
+class ShowSideHistogramsAction(PlotAction):
+ """QAction to change visibility of side histogram of a :class:`.ImageView`.
+
+ :param plot: :class:`.ImageView` instance on which to operate
+ :param parent: See :class:`QAction`
+ """
+
+ def __init__(self, plot, parent=None):
+ super(ShowSideHistogramsAction, self).__init__(
+ plot,
+ icon="side-histograms",
+ text="Show/hide side histograms",
+ tooltip="Show/hide side histogram",
+ triggered=self._actionTriggered,
+ checkable=True,
+ parent=parent,
+ )
+
+ def _actionTriggered(self, checked=False):
+ if self.plot.isSideHistogramDisplayed() != checked:
+ self.plot.setSideHistogramDisplayed(checked)
+
+
+class AggregationModeAction(qt.QWidgetAction):
+ """Action providing few filters to the image"""
+
+ sigAggregationModeChanged = qt.Signal()
+
+ def __init__(self, parent):
+ qt.QWidgetAction.__init__(self, parent)
+
+ toolButton = qt.QToolButton(parent)
+
+ filterAction = qt.QAction(self)
+ filterAction.setText("No filter")
+ filterAction.setCheckable(True)
+ filterAction.setChecked(True)
+ filterAction.setProperty(
+ "aggregation", items.ImageDataAggregated.Aggregation.NONE
+ )
+ densityNoFilterAction = filterAction
+
+ filterAction = qt.QAction(self)
+ filterAction.setText("Max filter")
+ filterAction.setCheckable(True)
+ filterAction.setProperty(
+ "aggregation", items.ImageDataAggregated.Aggregation.MAX
+ )
+ densityMaxFilterAction = filterAction
+
+ filterAction = qt.QAction(self)
+ filterAction.setText("Mean filter")
+ filterAction.setCheckable(True)
+ filterAction.setProperty(
+ "aggregation", items.ImageDataAggregated.Aggregation.MEAN
+ )
+ densityMeanFilterAction = filterAction
+
+ filterAction = qt.QAction(self)
+ filterAction.setText("Min filter")
+ filterAction.setCheckable(True)
+ filterAction.setProperty(
+ "aggregation", items.ImageDataAggregated.Aggregation.MIN
+ )
+ densityMinFilterAction = filterAction
+
+ densityGroup = qt.QActionGroup(self)
+ densityGroup.setExclusive(True)
+ densityGroup.addAction(densityNoFilterAction)
+ densityGroup.addAction(densityMaxFilterAction)
+ densityGroup.addAction(densityMeanFilterAction)
+ densityGroup.addAction(densityMinFilterAction)
+ densityGroup.triggered.connect(self._aggregationModeChanged)
+ self.__densityGroup = densityGroup
+
+ filterMenu = qt.QMenu(toolButton)
+ filterMenu.addAction(densityNoFilterAction)
+ filterMenu.addAction(densityMaxFilterAction)
+ filterMenu.addAction(densityMeanFilterAction)
+ filterMenu.addAction(densityMinFilterAction)
+
+ toolButton.setPopupMode(qt.QToolButton.InstantPopup)
+ toolButton.setMenu(filterMenu)
+ toolButton.setText("Data filters")
+ toolButton.setToolTip("Enable/disable filter on the image")
+ icon = icons.getQIcon("aggregation-mode")
+ toolButton.setIcon(icon)
+ toolButton.setText("Pixel aggregation filter")
+
+ self.setDefaultWidget(toolButton)
+
+ def _aggregationModeChanged(self):
+ self.sigAggregationModeChanged.emit()
+
+ def setAggregationMode(self, mode):
+ """Set an Aggregated enum from ImageDataAggregated"""
+ for a in self.__densityGroup.actions():
+ if a.property("aggregation") is mode:
+ a.setChecked(True)
+
+ def getAggregationMode(self):
+ """Returns an Aggregated enum from ImageDataAggregated"""
+ densityAction = self.__densityGroup.checkedAction()
+ if densityAction is None:
+ return items.ImageDataAggregated.Aggregation.NONE
+ return densityAction.property("aggregation")
+
+
+class ImageView(PlotWindow):
+ """Display a single image with horizontal and vertical histograms.
+
+ Use :meth:`setImage` to control the displayed image.
+ This class also provides the :class:`silx.gui.plot.Plot` API.
+
+ The :class:`ImageView` inherits from :class:`.PlotWindow` (which provides
+ the toolbars) and also exposes :class:`.PlotWidget` API for further
+ plot control (plot title, axes labels, aspect ratio, ...).
+
+ :param parent: The parent of this widget or None.
+ :param backend: The backend to use for the plot (default: matplotlib).
+ See :class:`.PlotWidget` for the list of supported backend.
+ :type backend: str or :class:`BackendBase.BackendBase`
+ """
+
+ HISTOGRAMS_COLOR = "blue"
+ """Color to use for the side histograms."""
+
+ HISTOGRAMS_HEIGHT = 200
+ """Height in pixels of the side histograms."""
+
+ IMAGE_MIN_SIZE = 200
+ """Minimum size in pixels of the image area."""
+
+ # Qt signals
+ valueChanged = qt.Signal(float, float, object)
+ """Signals that the data value under the cursor has changed.
+
+ It provides: row, column, data value.
+
+ When the cursor is over an histogram, either row or column is Nan
+ and the provided data value is the histogram value
+ (i.e., the sum along the corresponding row/column).
+ Row and columns are either Nan or integer values.
+ """
+
+ class ProfileWindowBehavior(Enum):
+ """ImageView's profile window behavior options"""
+
+ POPUP = "popup"
+ """All profiles are displayed in pop-up windows"""
+
+ EMBEDDED = "embedded"
+ """Horizontal, vertical and cross profiles are displayed in
+ sides widgets, others are displayed in pop-up windows.
+ """
+
+ def __init__(self, parent=None, backend=None):
+ self._imageLegend = "__ImageView__image" + str(id(self))
+ self._cache = None # Store currently visible data information
+
+ super(ImageView, self).__init__(
+ parent=parent,
+ backend=backend,
+ resetzoom=True,
+ autoScale=False,
+ logScale=False,
+ grid=False,
+ curveStyle=False,
+ colormap=True,
+ aspectRatio=True,
+ yInverted=True,
+ copy=True,
+ save=True,
+ print_=True,
+ control=False,
+ position=False,
+ roi=False,
+ mask=True,
+ )
+
+ # Enable mask synchronisation to use it in profiles
+ maskToolsWidget = self.getMaskToolsDockWidget().widget()
+ maskToolsWidget.setItemMaskUpdated(True)
+
+ self.__showSideHistogramsAction = ShowSideHistogramsAction(self, self)
+ self.__showSideHistogramsAction.setChecked(True)
+
+ self.__aggregationModeAction = AggregationModeAction(self)
+ self.__aggregationModeAction.sigAggregationModeChanged.connect(
+ self._aggregationModeChanged
+ )
+
+ if parent is None:
+ self.setWindowTitle("ImageView")
+
+ if silx.config.DEFAULT_PLOT_IMAGE_Y_AXIS_ORIENTATION == "downward":
+ self.getYAxis().setInverted(True)
+
+ self._initWidgets(backend)
+
+ toolBar = self.toolBar()
+ toolBar.addAction(self.__showSideHistogramsAction)
+ toolBar.addAction(self.__aggregationModeAction)
+
+ self.__profileWindowBehavior = self.ProfileWindowBehavior.POPUP
+ self.__profile = ProfileToolBar(plot=self)
+ self.addToolBar(self.__profile)
+
+ def _initWidgets(self, backend):
+ """Set-up layout and plots."""
+ self._histoHPlot = _SideHistogram(
+ backend=backend, parent=self, direction=qt.Qt.Horizontal
+ )
+ widgetHandle = self._histoHPlot.getWidgetHandle()
+ widgetHandle.setMinimumHeight(self.HISTOGRAMS_HEIGHT)
+ widgetHandle.setMaximumHeight(self.HISTOGRAMS_HEIGHT)
+ self._histoHPlot.setInteractiveMode("zoom")
+ self._histoHPlot.setDataMargins(0.0, 0.0, 0.1, 0.1)
+ self._histoHPlot.sigMouseMoved.connect(self._mouseMovedOnHistoH)
+ self._histoHPlot.setProfileColor(self.HISTOGRAMS_COLOR)
+
+ self._histoVPlot = _SideHistogram(
+ backend=backend, parent=self, direction=qt.Qt.Vertical
+ )
+ widgetHandle = self._histoVPlot.getWidgetHandle()
+ widgetHandle.setMinimumWidth(self.HISTOGRAMS_HEIGHT)
+ widgetHandle.setMaximumWidth(self.HISTOGRAMS_HEIGHT)
+ # Trick to align the histogram to the main plot
+ self._histoVPlot.setGraphTitle(" ")
+ self._histoVPlot.setInteractiveMode("zoom")
+ self._histoVPlot.setDataMargins(0.1, 0.1, 0.0, 0.0)
+ self._histoVPlot.sigMouseMoved.connect(self._mouseMovedOnHistoV)
+ self._histoVPlot.setProfileColor(self.HISTOGRAMS_COLOR)
+
+ self.setPanWithArrowKeys(True)
+ self.setInteractiveMode("zoom") # Color set in setColormap
+ self.sigPlotSignal.connect(self._imagePlotCB)
+ self.sigActiveImageChanged.connect(self._activeImageChangedSlot)
+
+ self._radarView = RadarView(parent=self)
+ self._radarView.setPlotWidget(self)
+
+ self.__syncXAxis = SyncAxes([self.getXAxis(), self._histoHPlot.getXAxis()])
+ self.__syncYAxis = SyncAxes([self.getYAxis(), self._histoVPlot.getYAxis()])
+
+ self.__setCentralWidget()
+
+ def __setCentralWidget(self):
+ """Set central widget with all its content"""
+ layout = qt.QGridLayout()
+ layout.addWidget(self.getWidgetHandle(), 0, 0)
+ layout.addWidget(self._histoVPlot, 0, 1)
+ layout.addWidget(self._histoHPlot, 1, 0)
+ layout.addWidget(self._radarView, 1, 1, 1, 2)
+ layout.addWidget(self.getColorBarWidget(), 0, 2)
+
+ self._radarView.setMinimumWidth(self.IMAGE_MIN_SIZE)
+ self._radarView.setMinimumHeight(self.HISTOGRAMS_HEIGHT)
+ self._histoHPlot.setMinimumWidth(self.IMAGE_MIN_SIZE)
+ self._histoVPlot.setMinimumHeight(self.HISTOGRAMS_HEIGHT)
+
+ layout.setColumnStretch(0, 1)
+ layout.setColumnStretch(1, 0)
+ layout.setRowStretch(0, 1)
+ layout.setRowStretch(1, 0)
+
+ layout.setSpacing(0)
+ layout.setContentsMargins(0, 0, 0, 0)
+
+ centralWidget = qt.QWidget(self)
+ centralWidget.setLayout(layout)
+ self.setCentralWidget(centralWidget)
+
+ @docstring(PlotWidget)
+ def setBackend(self, backend):
+ # Use PlotWidget here since we override PlotWindow behavior
+ PlotWidget.setBackend(self, backend)
+ self.__setCentralWidget()
+
+ def _dirtyCache(self):
+ self._cache = None
+
+ def getAggregationModeAction(self):
+ return self.__aggregationModeAction
+
+ def _aggregationModeChanged(self):
+ item = self._getItem("image", self._imageLegend)
+ if item is None:
+ return
+ aggregationMode = self.__aggregationModeAction.getAggregationMode()
+ if aggregationMode is not None and isinstance(item, items.ImageDataAggregated):
+ item.setAggregationMode(aggregationMode)
+ else:
+ # It means the item type have to be changed
+ self.removeImage(self._imageLegend)
+ image = item.getData(copy=False)
+ if image is None:
+ return
+ origin = item.getOrigin()
+ scale = item.getScale()
+ self.setImage(image, origin, scale, copy=False, resetzoom=False)
+
+ def getShowSideHistogramsAction(self):
+ return self.__showSideHistogramsAction
+
+ def setSideHistogramDisplayed(self, show):
+ """Display or not the side histograms"""
+ if self.isSideHistogramDisplayed() == show:
+ return
+ self._histoHPlot.setVisible(show)
+ self._histoVPlot.setVisible(show)
+ self._radarView.setVisible(show)
+ self.__showSideHistogramsAction.setChecked(show)
+ if show:
+ # Probably have to be computed
+ self._updateHistograms()
+
+ def isSideHistogramDisplayed(self):
+ """True if the side histograms are displayed"""
+ return self._histoHPlot.isVisibleTo(self)
+
+ def _updateHistograms(self):
+ """Update histograms content using current active image."""
+ if not self.isSideHistogramDisplayed():
+ # The histogram computation can be skipped
+ return
+
+ activeImage = self.getActiveImage()
+ if activeImage is not None:
+ xRange = self.getXAxis().getLimits()
+ yRange = self.getYAxis().getLimits()
+ result = computeProfileSumOnRange(activeImage, xRange, yRange, self._cache)
+ self._cache = result
+ self._histoHPlot.setProfileSum(result)
+ self._histoVPlot.setProfileSum(result)
+
+ # Plots event listeners
+
+ def _imagePlotCB(self, eventDict):
+ """Callback for imageView plot events."""
+ if eventDict["event"] == "mouseMoved":
+ activeImage = self.getActiveImage()
+ if activeImage is not None:
+ data = activeImage.getData(copy=False)
+ height, width = data.shape[0:2]
+
+ # Get corresponding coordinate in image
+ origin = activeImage.getOrigin()
+ scale = activeImage.getScale()
+ if eventDict["x"] >= origin[0] and eventDict["y"] >= origin[1]:
+ x = int((eventDict["x"] - origin[0]) / scale[0])
+ y = int((eventDict["y"] - origin[1]) / scale[1])
+
+ if x >= 0 and x < width and y >= 0 and y < height:
+ self.valueChanged.emit(float(x), float(y), data[y][x])
+
+ elif eventDict["event"] == "limitsChanged":
+ self._updateHistograms()
+
+ def _mouseMovedOnHistoH(self, x, y):
+ if self._cache is None:
+ return
+ activeImage = self.getActiveImage()
+ if activeImage is None:
+ return
+
+ xOrigin = activeImage.getOrigin()[0]
+ xScale = activeImage.getScale()[0]
+
+ minValue = xOrigin + xScale * self._cache.dataXRange[0]
+
+ if x >= minValue:
+ data = self._cache.histoH
+ column = int((x - minValue) / xScale)
+ if column >= 0 and column < data.shape[0]:
+ self.valueChanged.emit(
+ float("nan"),
+ float(column + self._cache.dataXRange[0]),
+ data[column],
+ )
+
+ def _mouseMovedOnHistoV(self, x, y):
+ if self._cache is None:
+ return
+ activeImage = self.getActiveImage()
+ if activeImage is None:
+ return
+
+ yOrigin = activeImage.getOrigin()[1]
+ yScale = activeImage.getScale()[1]
+
+ minValue = yOrigin + yScale * self._cache.dataYRange[0]
+
+ if y >= minValue:
+ data = self._cache.histoV
+ row = int((y - minValue) / yScale)
+ if row >= 0 and row < data.shape[0]:
+ self.valueChanged.emit(
+ float(row + self._cache.dataYRange[0]), float("nan"), data[row]
+ )
+
+ def _activeImageChangedSlot(self, previous, legend):
+ """Handle Plot active image change.
+
+ Resets side histograms cache
+ """
+ self._dirtyCache()
+ self._updateHistograms()
+
+ def setProfileWindowBehavior(self, behavior: Union[str, ProfileWindowBehavior]):
+ """Set where profile widgets are displayed.
+
+ :param ProfileWindowBehavior behavior:
+ - 'popup': All profiles are displayed in pop-up windows
+ - 'embedded': Horizontal, vertical and cross profiles are displayed in
+ sides widgets, others are displayed in pop-up windows.
+ """
+ behavior = self.ProfileWindowBehavior.from_value(behavior)
+ if behavior is not self.getProfileWindowBehavior():
+ manager = self.__profile.getProfileManager()
+ manager.clearProfile()
+ manager.requestUpdateAllProfile()
+
+ if behavior is self.ProfileWindowBehavior.EMBEDDED:
+ horizontalProfileWindow = self._histoHPlot
+ verticalProfileWindow = self._histoVPlot
+ else:
+ horizontalProfileWindow = None
+ verticalProfileWindow = None
+
+ manager.setSpecializedProfileWindow(
+ rois.ProfileImageHorizontalLineROI, horizontalProfileWindow
+ )
+ manager.setSpecializedProfileWindow(
+ rois.ProfileImageVerticalLineROI, verticalProfileWindow
+ )
+ self.__profileWindowBehavior = behavior
+
+ def getProfileWindowBehavior(self) -> ProfileWindowBehavior:
+ """Returns current profile display behavior.
+
+ See :meth:`setProfileWindowBehavior` and :class:`ProfileWindowBehavior`
+ """
+ return self.__profileWindowBehavior
+
+ def getProfileToolBar(self):
+ """Returns profile tools attached to this plot.
+
+ :rtype: silx.gui.plot.PlotTools.ProfileToolBar
+ """
+ return self.__profile
+
+ @property
+ @deprecated(replacement="getProfileToolBar()")
+ def profile(self):
+ return self.getProfileToolBar()
+
+ def getHistogram(self, axis):
+ """Return the histogram and corresponding row or column extent.
+
+ The returned value when an histogram is available is a dict with keys:
+
+ - 'data': numpy array of the histogram values.
+ - 'extent': (start, end) row or column index.
+ end index is not included in the histogram.
+
+ :param str axis: 'x' for horizontal, 'y' for vertical
+ :return: The histogram and its extent as a dict or None.
+ :rtype: dict
+ """
+ assert axis in ("x", "y")
+ if self._cache is None:
+ return None
+ else:
+ if axis == "x":
+ return dict(
+ data=numpy.array(self._cache.histoH, copy=True),
+ extent=self._cache.dataXRange,
+ )
+ else:
+ return dict(
+ data=numpy.array(self._cache.histoV, copy=True),
+ extent=(self._cache.dataYRange),
+ )
+
+ def radarView(self):
+ """Get the lower right radarView widget."""
+ return self._radarView
+
+ def setRadarView(self, radarView):
+ """Change the lower right radarView widget.
+
+ :param RadarView radarView: Widget subclassing RadarView to replace
+ the lower right corner widget.
+ """
+ self._radarView = radarView
+ self._radarView.setPlotWidget(self)
+ self.centralWidget().layout().addWidget(self._radarView, 1, 1)
+
+ # High-level API
+
+ def getColormap(self):
+ """Get the default colormap description.
+
+ :return: A description of the current colormap.
+ See :meth:`setColormap` for details.
+ :rtype: dict
+ """
+ return self.getDefaultColormap()
+
+ def setColormap(
+ self,
+ colormap=None,
+ normalization=None,
+ autoscale=None,
+ vmin=None,
+ vmax=None,
+ colors=None,
+ ):
+ """Set the default colormap and update active image.
+
+ Parameters that are not provided are taken from the current colormap.
+
+ The colormap parameter can also be a dict with the following keys:
+
+ - *name*: string. The colormap to use:
+ 'gray', 'reversed gray', 'temperature', 'red', 'green', 'blue'.
+ - *normalization*: string. The mapping to use for the colormap:
+ either 'linear' or 'log'.
+ - *autoscale*: bool. Whether to use autoscale (True)
+ or range provided by keys 'vmin' and 'vmax' (False).
+ - *vmin*: float. The minimum value of the range to use if 'autoscale'
+ is False.
+ - *vmax*: float. The maximum value of the range to use if 'autoscale'
+ is False.
+ - *colors*: optional. Nx3 or Nx4 array of float in [0, 1] or uint8.
+ List of RGB or RGBA colors to use (only if name is None)
+
+ :param colormap: Name of the colormap in
+ 'gray', 'reversed gray', 'temperature', 'red', 'green', 'blue'.
+ Or the description of the colormap as a dict.
+ :type colormap: dict or str.
+ :param str normalization: Colormap mapping: 'linear' or 'log'.
+ :param bool autoscale: Whether to use autoscale (True)
+ or [vmin, vmax] range (False).
+ :param float vmin: The minimum value of the range to use if
+ 'autoscale' is False.
+ :param float vmax: The maximum value of the range to use if
+ 'autoscale' is False.
+ :param numpy.ndarray colors: Only used if name is None.
+ Custom colormap colors as Nx3 or Nx4 RGB or RGBA arrays
+ """
+ cmap = self.getDefaultColormap()
+
+ if isinstance(colormap, Colormap):
+ # Replace colormap
+ cmap = colormap
+
+ self.setDefaultColormap(cmap)
+
+ # Update active image colormap
+ activeImage = self.getActiveImage()
+ if isinstance(activeImage, items.ColormapMixIn):
+ activeImage.setColormap(cmap)
+
+ elif isinstance(colormap, dict):
+ # Support colormap parameter as a dict
+ assert normalization is None
+ assert autoscale is None
+ assert vmin is None
+ assert vmax is None
+ assert colors is None
+ cmap._setFromDict(colormap)
+
+ else:
+ if colormap is not None:
+ cmap.setName(colormap)
+ if normalization is not None:
+ cmap.setNormalization(normalization)
+ if autoscale:
+ cmap.setVRange(None, None)
+ else:
+ if vmin is not None:
+ cmap.setVMin(vmin)
+ if vmax is not None:
+ cmap.setVMax(vmax)
+ if colors is not None:
+ cmap.setColormapLUT(colors)
+
+ cursorColor = cursorColorForColormap(cmap.getName())
+ self.setInteractiveMode("zoom", color=cursorColor)
+
+ def setImage(
+ self,
+ image,
+ origin=(0, 0),
+ scale=(1.0, 1.0),
+ copy=True,
+ reset=None,
+ resetzoom=True,
+ ):
+ """Set the image to display.
+
+ :param image: A 2D array representing the image or None to empty plot.
+ :type image: numpy.ndarray-like with 2 dimensions or None.
+ :param origin: The (x, y) position of the origin of the image.
+ Default: (0, 0).
+ The origin is the lower left corner of the image when
+ the Y axis is not inverted.
+ :type origin: Tuple of 2 floats: (origin x, origin y).
+ :param scale: The scale factor to apply to the image on X and Y axes.
+ Default: (1, 1).
+ It is the size of a pixel in the coordinates of the axes.
+ Scales must be positive numbers.
+ :type scale: Tuple of 2 floats: (scale x, scale y).
+ :param bool copy: Whether to copy image data (default) or not.
+ :param bool reset: Deprecated. Alias for `resetzoom`.
+ :param bool resetzoom: Whether to reset zoom and ROI (default) or not.
+ """
+ self._dirtyCache()
+
+ if reset is not None:
+ resetzoom = reset
+
+ assert len(origin) == 2
+ assert len(scale) == 2
+ assert scale[0] > 0
+ assert scale[1] > 0
+
+ if image is None:
+ self.remove(self._imageLegend, kind="image")
+ return
+
+ data = numpy.array(image, order="C", copy=copy)
+ if data.size == 0:
+ self.remove(self._imageLegend, kind="image")
+ return
+
+ assert data.ndim == 2 or (data.ndim == 3 and data.shape[2] in (3, 4))
+
+ aggregation = self.getAggregationModeAction().getAggregationMode()
+ if data.ndim != 2 and aggregation is not None:
+ # RGB/A with aggregation is not supported
+ aggregation = items.ImageDataAggregated.Aggregation.NONE
+
+ if aggregation is items.ImageDataAggregated.Aggregation.NONE:
+ self.addImage(
+ data,
+ legend=self._imageLegend,
+ origin=origin,
+ scale=scale,
+ colormap=self.getColormap(),
+ resetzoom=False,
+ )
+ else:
+ item = self._getItem("image", self._imageLegend)
+ if isinstance(item, items.ImageDataAggregated):
+ item.setData(data)
+ item.setOrigin(origin)
+ item.setScale(scale)
+ else:
+ if isinstance(item, items.ImageDataAggregated):
+ imageItem = item
+ wasCreated = False
+ else:
+ if item is not None:
+ self.removeImage(self._imageLegend)
+ imageItem = items.ImageDataAggregated()
+ imageItem.setName(self._imageLegend)
+ imageItem.setColormap(self.getColormap())
+ wasCreated = True
+ imageItem.setData(data)
+ imageItem.setOrigin(origin)
+ imageItem.setScale(scale)
+ imageItem.setAggregationMode(aggregation)
+ if wasCreated:
+ self.addItem(imageItem)
+
+ self.setActiveImage(self._imageLegend)
+ self._updateHistograms()
+ if resetzoom:
+ self.resetZoom()
+
+
+# ImageViewMainWindow #########################################################
+
+
+class ImageViewMainWindow(ImageView):
+ """:class:`ImageView` with additional toolbars
+
+ Adds extra toolbar and a status bar to :class:`ImageView`.
+ """
+
+ def __init__(self, parent=None, backend=None):
+ self._dataInfo = None
+ super(ImageViewMainWindow, self).__init__(parent, backend)
+ self.setWindowFlags(qt.Qt.Window)
+
+ self.getXAxis().setLabel("X")
+ self.getYAxis().setLabel("Y")
+ self.setGraphTitle("Image")
+
+ # Add toolbars and status bar
+ self.addToolBar(qt.Qt.BottomToolBarArea, LimitsToolBar(plot=self))
+
+ menu = self.menuBar().addMenu("File")
+ menu.addAction(self.getOutputToolBar().getSaveAction())
+ menu.addAction(self.getOutputToolBar().getPrintAction())
+ menu.addSeparator()
+ action = menu.addAction("Quit")
+ action.triggered[bool].connect(qt.QApplication.instance().quit)
+
+ menu = self.menuBar().addMenu("Edit")
+ menu.addAction(self.getOutputToolBar().getCopyAction())
+ menu.addSeparator()
+ menu.addAction(self.getResetZoomAction())
+ menu.addAction(self.getColormapAction())
+ menu.addAction(actions.control.KeepAspectRatioAction(self, self))
+ menu.addAction(actions.control.YAxisInvertedAction(self, self))
+ menu.addAction(self.getShowSideHistogramsAction())
+
+ self.__profileMenu = self.menuBar().addMenu("Profile")
+ self.__updateProfileMenu()
+
+ # Connect to ImageView's signal
+ self.valueChanged.connect(self._statusBarSlot)
+
+ def __updateProfileMenu(self):
+ """Update actions available in 'Profile' menu"""
+ profile = self.getProfileToolBar()
+ self.__profileMenu.clear()
+ self.__profileMenu.addAction(profile.hLineAction)
+ self.__profileMenu.addAction(profile.vLineAction)
+ self.__profileMenu.addAction(profile.crossAction)
+ self.__profileMenu.addAction(profile.lineAction)
+ self.__profileMenu.addAction(profile.clearAction)
+
+ def _formatValueToString(self, value):
+ try:
+ if isinstance(value, numpy.ndarray):
+ if len(value) == 4:
+ return "RGBA: %.3g, %.3g, %.3g, %.3g" % (
+ value[0],
+ value[1],
+ value[2],
+ value[3],
+ )
+ elif len(value) == 3:
+ return "RGB: %.3g, %.3g, %.3g" % (value[0], value[1], value[2])
+ else:
+ return "Value: %g" % value
+ except Exception:
+ _logger.error("Error while formatting pixel value", exc_info=True)
+ pass
+ return "Value: %s" % value
+
+ def _statusBarSlot(self, row, column, value):
+ """Update status bar with coordinates/value from plots."""
+ if numpy.isnan(row):
+ msg = "Column: %d, Sum: %g" % (int(column), value)
+ elif numpy.isnan(column):
+ msg = "Row: %d, Sum: %g" % (int(row), value)
+ else:
+ msg_value = self._formatValueToString(value)
+ msg = "Position: (%d, %d), %s" % (int(row), int(column), msg_value)
+ if self._dataInfo is not None:
+ msg = self._dataInfo + ", " + msg
+
+ self.statusBar().showMessage(msg)
+
+ @docstring(ImageView)
+ def setProfileWindowBehavior(self, behavior: str):
+ super().setProfileWindowBehavior(behavior)
+ self.__updateProfileMenu()
+
+ @docstring(ImageView)
+ def setImage(self, image, *args, **kwargs):
+ if hasattr(image, "dtype") and hasattr(image, "shape"):
+ assert image.ndim == 2 or (image.ndim == 3 and image.shape[2] in (3, 4))
+ height, width = image.shape[0:2]
+ dataInfo = "Data: %dx%d (%s)" % (width, height, str(image.dtype))
+ else:
+ dataInfo = None
+
+ if self._dataInfo != dataInfo:
+ self._dataInfo = dataInfo
+ self.statusBar().showMessage(self._dataInfo)
+
+ # Set the new image in ImageView widget
+ super(ImageViewMainWindow, self).setImage(image, *args, **kwargs)
diff --git a/src/silx/gui/plot/Interaction.py b/src/silx/gui/plot/Interaction.py
new file mode 100644
index 0000000..2d8bf63
--- /dev/null
+++ b/src/silx/gui/plot/Interaction.py
@@ -0,0 +1,352 @@
+# /*##########################################################################
+#
+# Copyright (c) 2014-2020 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""This module provides an implementation of state machines for interaction.
+
+Sample code of a state machine with two states ('idle' and 'active')
+with transitions on left button press/release:
+
+.. code-block:: python
+
+ from silx.gui.plot.Interaction import *
+
+ class SampleStateMachine(StateMachine):
+
+ class Idle(State):
+ def onPress(self, x, y, btn):
+ if btn == LEFT_BTN:
+ self.goto('active')
+
+ class Active(State):
+ def enterState(self):
+ print('Enabled') # Handle enter active state here
+
+ def leaveState(self):
+ print('Disabled') # Handle leave active state here
+
+ def onRelease(self, x, y, btn):
+ if btn == LEFT_BTN:
+ self.goto('idle')
+
+ def __init__(self):
+ # State machine has 2 states
+ states = {
+ 'idle': SampleStateMachine.Idle,
+ 'active': SampleStateMachine.Active
+ }
+ super(TwoStates, self).__init__(states, 'idle')
+ # idle is the initial state
+
+ stateMachine = SampleStateMachine()
+
+ # Triggers a transition to the Active state:
+ stateMachine.handleEvent('press', 0, 0, LEFT_BTN)
+
+ # Triggers a transition to the Idle state:
+ stateMachine.handleEvent('release', 0, 0, LEFT_BTN)
+
+See :class:`ClickOrDrag` for another example of a state machine.
+
+See `Renaud Blanch, Michel Beaudouin-Lafon.
+Programming Rich Interactions using the Hierarchical State Machine Toolkit.
+In Proceedings of AVI 2006. p 51-58.
+<http://iihm.imag.fr/en/publication/BB06a/>`_
+for a discussion of using (hierarchical) state machines for interaction.
+"""
+
+__authors__ = ["T. Vincent"]
+__license__ = "MIT"
+__date__ = "18/02/2016"
+
+
+import weakref
+
+
+# state machine ###############################################################
+
+
+class State(object):
+ """Base class for the states of a state machine.
+
+ This class is meant to be subclassed.
+ """
+
+ def __init__(self, machine):
+ """State instances should be created by the :class:`StateMachine`.
+
+ They are not intended to be used outside this context.
+
+ :param machine: The state machine instance this state belongs to.
+ :type machine: StateMachine
+ """
+ self._machineRef = weakref.ref(machine) # Prevent cyclic reference
+
+ @property
+ def machine(self):
+ """The state machine this state belongs to.
+
+ Useful to access data or methods that are shared across states.
+ """
+ machine = self._machineRef()
+ if machine is not None:
+ return machine
+ else:
+ raise RuntimeError("Associated StateMachine is not valid")
+
+ def goto(self, state, *args, **kwargs):
+ """Performs a transition to a new state.
+
+ Extra arguments are passed to the :meth:`enterState` method of the
+ new state.
+
+ :param str state: The name of the state to go to.
+ """
+ self.machine._goto(state, *args, **kwargs)
+
+ def enterState(self, *args, **kwargs):
+ """Called when the state machine enters this state.
+
+ Arguments are those provided to the :meth:`goto` method that
+ triggered the transition to this state.
+ """
+ pass
+
+ def leaveState(self):
+ """Called when the state machine leaves this state
+ (i.e., when :meth:`goto` is called).
+ """
+ pass
+
+ def validate(self):
+ """Called externally to validate the current interaction in case of a
+ creation.
+ """
+ pass
+
+
+class StateMachine(object):
+ """State machine controller.
+
+ This is the entry point of a state machine.
+ It is in charge of dispatching received event and handling the
+ current active state.
+ """
+
+ def __init__(self, states, initState, *args, **kwargs):
+ """Create a state machine controller with an initial state.
+
+ Extra arguments are passed to the :meth:`enterState` method
+ of the initState.
+
+ :param states: All states of the state machine
+ :type states: dict of: {str name: State subclass}
+ :param str initState: Key of the initial state in states
+ """
+ self.states = states
+
+ self.state = self.states[initState](self)
+ self.state.enterState(*args, **kwargs)
+
+ def _goto(self, state, *args, **kwargs):
+ self.state.leaveState()
+ self.state = self.states[state](self)
+ self.state.enterState(*args, **kwargs)
+
+ def handleEvent(self, eventName, *args, **kwargs):
+ """Process an event with the state machine.
+
+ This method looks up for an event handler in the current state
+ and then in the :class:`StateMachine` instance.
+ Handler are looked up as 'onEventName' method.
+ If a handler is found, it is called with the provided extra
+ arguments, and this method returns the return value of the
+ handler.
+ If no handler is found, this method returns None.
+
+ :param str eventName: Name of the event to handle
+ :returns: The return value of the handler or None
+ """
+ handlerName = "on" + eventName[0].upper() + eventName[1:]
+ try:
+ handler = getattr(self.state, handlerName)
+ except AttributeError:
+ try:
+ handler = getattr(self, handlerName)
+ except AttributeError:
+ handler = None
+ if handler is not None:
+ return handler(*args, **kwargs)
+
+ def validate(self):
+ """Called externally to validate the current interaction in case of a
+ creation.
+ """
+ self.state.validate()
+
+
+# clickOrDrag #################################################################
+
+LEFT_BTN = "left"
+"""Left mouse button."""
+
+RIGHT_BTN = "right"
+"""Right mouse button."""
+
+MIDDLE_BTN = "middle"
+"""Middle mouse button."""
+
+
+class ClickOrDrag(StateMachine):
+ """State machine for left and right click and left drag interaction.
+
+ It is intended to be used through subclassing by overriding
+ :meth:`click`, :meth:`beginDrag`, :meth:`drag` and :meth:`endDrag`.
+
+ :param Set[str] clickButtons: Set of buttons that provides click interaction
+ :param Set[str] dragButtons: Set of buttons that provides drag interaction
+ """
+
+ DRAG_THRESHOLD_SQUARE_DIST = 5**2
+
+ class Idle(State):
+ def onPress(self, x, y, btn):
+ if btn in self.machine.dragButtons:
+ self.goto("clickOrDrag", x, y, btn)
+ return True
+ elif btn in self.machine.clickButtons:
+ self.goto("click", x, y, btn)
+ return True
+
+ class Click(State):
+ def enterState(self, x, y, btn):
+ self.initPos = x, y
+ self.button = btn
+
+ def onMove(self, x, y):
+ dx2 = (x - self.initPos[0]) ** 2
+ dy2 = (y - self.initPos[1]) ** 2
+ if (dx2 + dy2) >= self.machine.DRAG_THRESHOLD_SQUARE_DIST:
+ self.goto("idle")
+
+ def onRelease(self, x, y, btn):
+ if btn == self.button:
+ self.machine.click(x, y, btn)
+ self.goto("idle")
+
+ class ClickOrDrag(State):
+ def enterState(self, x, y, btn):
+ self.initPos = x, y
+ self.button = btn
+
+ def onMove(self, x, y):
+ dx2 = (x - self.initPos[0]) ** 2
+ dy2 = (y - self.initPos[1]) ** 2
+ if (dx2 + dy2) >= self.machine.DRAG_THRESHOLD_SQUARE_DIST:
+ self.goto("drag", self.initPos, (x, y), self.button)
+
+ def onRelease(self, x, y, btn):
+ if btn == self.button:
+ if btn in self.machine.clickButtons:
+ self.machine.click(x, y, btn)
+ self.goto("idle")
+
+ class Drag(State):
+ def enterState(self, initPos, curPos, btn):
+ self.initPos = initPos
+ self.button = btn
+ self.machine.beginDrag(*initPos, btn)
+ self.machine.drag(*curPos, btn)
+
+ def onMove(self, x, y):
+ self.machine.drag(x, y, self.button)
+
+ def onRelease(self, x, y, btn):
+ if btn == self.button:
+ self.machine.endDrag(self.initPos, (x, y), btn)
+ self.goto("idle")
+
+ def __init__(self, clickButtons=(LEFT_BTN, RIGHT_BTN), dragButtons=(LEFT_BTN,)):
+ states = {
+ "idle": self.Idle,
+ "click": self.Click,
+ "clickOrDrag": self.ClickOrDrag,
+ "drag": self.Drag,
+ }
+ self.__clickButtons = set(clickButtons)
+ self.__dragButtons = set(dragButtons)
+ super(ClickOrDrag, self).__init__(states, "idle")
+
+ clickButtons = property(
+ lambda self: self.__clickButtons,
+ doc="Buttons with click interaction (Set[int])",
+ )
+
+ dragButtons = property(
+ lambda self: self.__dragButtons, doc="Buttons with drag interaction (Set[int])"
+ )
+
+ def click(self, x, y, btn):
+ """Called upon a button supporting click.
+
+ Override in subclass.
+
+ :param int x: X mouse position in pixels.
+ :param int y: Y mouse position in pixels.
+ :param str btn: The mouse button which was clicked.
+ """
+ pass
+
+ def beginDrag(self, x, y, btn):
+ """Called at the beginning of a drag gesture with mouse button pressed.
+
+ Override in subclass.
+
+ :param int x: X mouse position in pixels.
+ :param int y: Y mouse position in pixels.
+ :param str btn: The mouse button for which a drag is starting.
+ """
+ pass
+
+ def drag(self, x, y, btn):
+ """Called on mouse moved during a drag gesture.
+
+ Override in subclass.
+
+ :param int x: X mouse position in pixels.
+ :param int y: Y mouse position in pixels.
+ :param str btn: The mouse button for which a drag is in progress.
+ """
+ pass
+
+ def endDrag(self, startPoint, endPoint, btn):
+ """Called at the end of a drag gesture when the mouse button is released.
+
+ Override in subclass.
+
+ :param List[int] startPoint:
+ (x, y) mouse position in pixels at the beginning of the drag.
+ :param List[int] endPoint:
+ (x, y) mouse position in pixels at the end of the drag.
+ :param str btn: The mouse button for which a drag is done.
+ """
+ pass
diff --git a/src/silx/gui/plot/ItemsSelectionDialog.py b/src/silx/gui/plot/ItemsSelectionDialog.py
new file mode 100644
index 0000000..b4e4f9e
--- /dev/null
+++ b/src/silx/gui/plot/ItemsSelectionDialog.py
@@ -0,0 +1,293 @@
+# /*##########################################################################
+#
+# Copyright (c) 2017-2021 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""This module provides a dialog widget to select plot items.
+
+.. autoclass:: ItemsSelectionDialog
+
+"""
+
+__authors__ = ["P. Knobel"]
+__license__ = "MIT"
+__date__ = "28/06/2017"
+
+import logging
+
+from silx.gui import qt
+from silx.gui.plot.PlotWidget import PlotWidget
+
+_logger = logging.getLogger(__name__)
+
+
+class KindsSelector(qt.QListWidget):
+ """List widget allowing to select plot item kinds
+ ("curve", "scatter", "image"...)
+ """
+
+ sigSelectedKindsChanged = qt.Signal(list)
+
+ def __init__(self, parent=None, kinds=None):
+ """
+
+ :param parent: Parent QWidget or None
+ :param tuple(str) kinds: Sequence of kinds. If None, the default
+ behavior is to provide a checkbox for all possible item kinds.
+ """
+ qt.QListWidget.__init__(self, parent)
+
+ self.plot_item_kinds = []
+
+ self.setAvailableKinds(kinds if kinds is not None else PlotWidget.ITEM_KINDS)
+
+ self.setSelectionMode(qt.QAbstractItemView.ExtendedSelection)
+ self.selectAll()
+
+ self.itemSelectionChanged.connect(self.emitSigKindsSelectionChanged)
+
+ def emitSigKindsSelectionChanged(self):
+ self.sigSelectedKindsChanged.emit(self.selectedKinds)
+
+ @property
+ def selectedKinds(self):
+ """Tuple of all selected kinds (as strings)."""
+ # check for updates when self.itemSelectionChanged
+ return [item.text() for item in self.selectedItems()]
+
+ def setAvailableKinds(self, kinds):
+ """Set a list of kinds to be displayed.
+
+ :param list[str] kinds: Sequence of kinds
+ """
+ self.plot_item_kinds = kinds
+
+ self.clear()
+ for kind in self.plot_item_kinds:
+ item = qt.QListWidgetItem(self)
+ item.setText(kind)
+ self.addItem(item)
+
+ def selectAll(self):
+ """Select all available kinds."""
+ if self.selectionMode() in [
+ qt.QAbstractItemView.SingleSelection,
+ qt.QAbstractItemView.NoSelection,
+ ]:
+ raise RuntimeError("selectAll requires a multiple selection mode")
+ for i in range(self.count()):
+ self.item(i).setSelected(True)
+
+
+class PlotItemsSelector(qt.QTableWidget):
+ """Table widget displaying the legend and kind of all
+ plot items corresponding to a list of specified kinds.
+
+ Selected plot items are provided as property :attr:`selectedPlotItems`.
+ You can be warned of selection changes by listening to signal
+ :attr:`itemSelectionChanged`.
+ """
+
+ def __init__(self, parent=None, plot=None):
+ if plot is None or not isinstance(plot, PlotWidget):
+ raise AttributeError("parameter plot is required")
+ self.plot = plot
+ """:class:`PlotWidget` instance"""
+
+ self.plot_item_kinds = None
+ """List of plot item kinds (strings)"""
+
+ qt.QTableWidget.__init__(self, parent)
+
+ self.setColumnCount(2)
+
+ self.setSelectionBehavior(qt.QTableWidget.SelectRows)
+
+ def _clear(self):
+ self.clear()
+ self.setHorizontalHeaderLabels(["legend", "type"])
+
+ def setAllKindsFilter(self):
+ """Display all kinds of plot items."""
+ self.setKindsFilter(PlotWidget.ITEM_KINDS)
+
+ def setKindsFilter(self, kinds):
+ """Set list of all kinds of plot items to be displayed.
+
+ :param list[str] kinds: Sequence of kinds
+ """
+ if not set(kinds) <= set(PlotWidget.ITEM_KINDS):
+ raise KeyError(
+ "Illegal plot item kinds: %s" % set(kinds) - set(PlotWidget.ITEM_KINDS)
+ )
+ self.plot_item_kinds = kinds
+
+ self.updatePlotItems()
+
+ def updatePlotItems(self):
+ self._clear()
+
+ # respect order of kinds as set in method setKindsFilter
+ itemsAndKind = []
+ for kind in self.plot_item_kinds:
+ itemClasses = self.plot._KIND_TO_CLASSES[kind]
+ for item in self.plot.getItems():
+ if isinstance(item, itemClasses) and item.isVisible():
+ itemsAndKind.append((item, kind))
+
+ self.setRowCount(len(itemsAndKind))
+
+ for index, (item, kind) in enumerate(itemsAndKind):
+ legend_twitem = qt.QTableWidgetItem(item.getName())
+ self.setItem(index, 0, legend_twitem)
+
+ kind_twitem = qt.QTableWidgetItem(kind)
+ self.setItem(index, 1, kind_twitem)
+
+ @property
+ def selectedPlotItems(self):
+ """List of all selected items"""
+ selection_model = self.selectionModel()
+ selected_rows_idx = selection_model.selectedRows()
+ selected_rows = [idx.row() for idx in selected_rows_idx]
+
+ items = []
+ for row in selected_rows:
+ legend = self.item(row, 0).text()
+ kind = self.item(row, 1).text()
+ item = self.plot._getItem(kind, legend)
+ if item is not None:
+ items.append(item)
+
+ return items
+
+
+class ItemsSelectionDialog(qt.QDialog):
+ """This widget is a modal dialog allowing to select one or more plot
+ items, in a table displaying their legend and kind.
+
+ Public methods:
+
+ - :meth:`getSelectedItems`
+ - :meth:`setAvailableKinds`
+ - :meth:`setItemsSelectionMode`
+
+ This widget inherits QDialog and therefore implements the usual
+ dialog methods, e.g. :meth:`exec`.
+
+ A trivial usage example would be::
+
+ isd = ItemsSelectionDialog(plot=my_plot_widget)
+ isd.setItemsSelectionMode(qt.QTableWidget.SingleSelection)
+ result = isd.exec()
+ if result:
+ for item in isd.getSelectedItems():
+ print(item.getName(), type(item))
+ else:
+ print("Selection cancelled")
+ """
+
+ def __init__(self, parent=None, plot=None):
+ if plot is None or not isinstance(plot, PlotWidget):
+ raise AttributeError("parameter plot is required")
+ qt.QDialog.__init__(self, parent)
+
+ self.setWindowTitle("Plot items selector")
+
+ kind_selector_label = qt.QLabel("Filter item kinds:", self)
+ item_selector_label = qt.QLabel("Select items:", self)
+
+ self.kind_selector = KindsSelector(self)
+ self.kind_selector.setToolTip(
+ "select one or more item kinds to show them in the item list"
+ )
+
+ self.item_selector = PlotItemsSelector(self, plot)
+ self.item_selector.setToolTip("select items")
+
+ self.item_selector.setKindsFilter(self.kind_selector.selectedKinds)
+ self.kind_selector.sigSelectedKindsChanged.connect(
+ self.item_selector.setKindsFilter
+ )
+
+ okb = qt.QPushButton("OK", self)
+ okb.clicked.connect(self.accept)
+
+ cancelb = qt.QPushButton("Cancel", self)
+ cancelb.clicked.connect(self.reject)
+
+ layout = qt.QGridLayout(self)
+ layout.addWidget(kind_selector_label, 0, 0)
+ layout.addWidget(item_selector_label, 0, 1)
+ layout.addWidget(self.kind_selector, 1, 0)
+ layout.addWidget(self.item_selector, 1, 1)
+ layout.addWidget(okb, 2, 0)
+ layout.addWidget(cancelb, 2, 1)
+
+ self.setLayout(layout)
+
+ def getSelectedItems(self):
+ """Return a list of selected plot items
+
+ :return: List of selected plot items
+ :rtype: list[silx.gui.plot.items.Item]"""
+ return self.item_selector.selectedPlotItems
+
+ def setAvailableKinds(self, kinds):
+ """Set a list of kinds to be displayed.
+
+ :param list[str] kinds: Sequence of kinds
+ """
+ self.kind_selector.setAvailableKinds(kinds)
+
+ def selectAllKinds(self):
+ self.kind_selector.selectAll()
+
+ def setItemsSelectionMode(self, mode):
+ """Set selection mode for plot item (single item selection,
+ multiple...).
+
+ :param mode: One of :class:`QTableWidget` selection modes
+ """
+ if mode == self.item_selector.SingleSelection:
+ self.item_selector.setToolTip("Select one item by clicking on it.")
+ elif mode == self.item_selector.MultiSelection:
+ self.item_selector.setToolTip(
+ "Select one or more items by clicking with the left mouse"
+ " button.\nYou can unselect items by clicking them again.\n"
+ "Multiple items can be toggled by dragging the mouse over them."
+ )
+ elif mode == self.item_selector.ExtendedSelection:
+ self.item_selector.setToolTip(
+ "Select one or more items. You can select multiple items "
+ "by keeping the Ctrl key pushed when clicking.\nYou can "
+ "select a range of items by clicking on the first and "
+ "last while keeping the Shift key pushed."
+ )
+ elif mode == self.item_selector.ContiguousSelection:
+ self.item_selector.setToolTip(
+ "Select one item by clicking on it. If you press the Shift"
+ " key while clicking on a second item,\nall items between "
+ "the two will be selected."
+ )
+ elif mode == self.item_selector.NoSelection:
+ raise ValueError("The NoSelection mode is not allowed " "in this context.")
+ self.item_selector.setSelectionMode(mode)
diff --git a/src/silx/gui/plot/LegendSelector.py b/src/silx/gui/plot/LegendSelector.py
new file mode 100755
index 0000000..22348fb
--- /dev/null
+++ b/src/silx/gui/plot/LegendSelector.py
@@ -0,0 +1,1042 @@
+# /*##########################################################################
+#
+# Copyright (c) 2004-2023 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""Widget displaying curves legends and allowing to operate on curves.
+
+This widget is meant to work with :class:`PlotWindow`.
+"""
+
+__authors__ = ["V.A. Sole", "T. Rueter", "T. Vincent"]
+__license__ = "MIT"
+__data__ = "16/10/2017"
+
+
+import logging
+import weakref
+
+import numpy
+
+from .. import qt, colors
+from ..widgets.LegendIconWidget import LegendIconWidget
+from . import items
+from ...utils.deprecation import deprecated
+
+
+_logger = logging.getLogger(__name__)
+
+
+class LegendIcon(LegendIconWidget):
+ """Object displaying a curve linestyle and symbol.
+
+ :param QWidget parent: See :class:`QWidget`
+ :param Union[~silx.gui.plot.items.Curve,None] curve:
+ Curve with which to synchronize
+ """
+
+ def __init__(self, parent=None, curve=None):
+ super(LegendIcon, self).__init__(parent)
+ self._curveRef = None
+ self.setCurve(curve)
+
+ def getCurve(self):
+ """Returns curve associated to this widget
+
+ :rtype: Union[~silx.gui.plot.items.Curve,None]
+ """
+ return None if self._curveRef is None else self._curveRef()
+
+ def setCurve(self, curve):
+ """Set the curve with which to synchronize this widget.
+
+ :param curve: Union[~silx.gui.plot.items.Curve,None]
+ """
+ assert curve is None or isinstance(curve, items.Curve)
+
+ previousCurve = self.getCurve()
+ if curve == previousCurve:
+ return
+
+ if previousCurve is not None:
+ previousCurve.sigItemChanged.disconnect(self._curveChanged)
+
+ self._curveRef = None if curve is None else weakref.ref(curve)
+
+ if curve is not None:
+ curve.sigItemChanged.connect(self._curveChanged)
+
+ self._update()
+
+ def _update(self):
+ """Update widget according to current curve state."""
+ curve = self.getCurve()
+ if curve is None:
+ _logger.error("Curve no more exists")
+ self.setEnabled(False)
+ return
+
+ style = curve.getCurrentStyle()
+
+ self.setEnabled(curve.isVisible())
+ self.setSymbol(style.getSymbol())
+ self.setLineWidth(style.getLineWidth())
+ self.setLineStyle(style.getLineStyle())
+
+ color = style.getColor()
+ if numpy.array(color, copy=False).ndim != 1:
+ # array of colors, use transparent black
+ color = 0.0, 0.0, 0.0, 0.0
+ color = colors.rgba(color) # Make sure it is float in [0, 1]
+ alpha = curve.getAlpha()
+ color = qt.QColor.fromRgbF(color[0], color[1], color[2], color[3] * alpha)
+ self.setLineColor(color)
+ self.setSymbolColor(color)
+ self.update() # TODO this should not be needed
+
+ def _curveChanged(self, event):
+ """Handle update of curve item
+
+ :param event: Kind of change
+ """
+ if event in (
+ items.ItemChangedType.VISIBLE,
+ items.ItemChangedType.SYMBOL,
+ items.ItemChangedType.SYMBOL_SIZE,
+ items.ItemChangedType.LINE_WIDTH,
+ items.ItemChangedType.LINE_STYLE,
+ items.ItemChangedType.COLOR,
+ items.ItemChangedType.ALPHA,
+ items.ItemChangedType.HIGHLIGHTED,
+ items.ItemChangedType.HIGHLIGHTED_STYLE,
+ ):
+ self._update()
+
+
+class LegendModel(qt.QAbstractListModel):
+ """Data model of curve legends.
+
+ It holds the information of the curve:
+
+ - color
+ - line width
+ - line style
+ - visibility of the lines
+ - symbol
+ - visibility of the symbols
+ """
+
+ iconColorRole = qt.Qt.UserRole + 0
+ iconLineWidthRole = qt.Qt.UserRole + 1
+ iconLineStyleRole = qt.Qt.UserRole + 2
+ showLineRole = qt.Qt.UserRole + 3
+ iconSymbolRole = qt.Qt.UserRole + 4
+ showSymbolRole = qt.Qt.UserRole + 5
+ itemRole = qt.Qt.UserRole + 6
+
+ def __init__(self, legendList=None, parent=None):
+ super(LegendModel, self).__init__(parent)
+ if legendList is None:
+ legendList = []
+ self.legendList = []
+ self.insertLegendList(0, legendList)
+ self._palette = qt.QPalette()
+
+ def __getitem__(self, idx):
+ if idx >= len(self.legendList):
+ raise IndexError("list index out of range")
+ return self.legendList[idx]
+
+ def rowCount(self, modelIndex=None):
+ return len(self.legendList)
+
+ def flags(self, index):
+ return qt.Qt.ItemIsEditable | qt.Qt.ItemIsEnabled | qt.Qt.ItemIsSelectable
+
+ def data(self, modelIndex, role):
+ if modelIndex.isValid:
+ idx = modelIndex.row()
+ else:
+ return None
+ if idx >= len(self.legendList):
+ raise IndexError("list index out of range")
+
+ item = self.legendList[idx]
+ isActive = item[1].get("active", False)
+ if role == qt.Qt.DisplayRole:
+ # Data to be rendered in the form of text
+ legend = str(item[0])
+ return legend
+ elif role == qt.Qt.SizeHintRole:
+ # size = qt.QSize(200,50)
+ _logger.warning("LegendModel -- size hint role not implemented")
+ return qt.QSize()
+ elif role == qt.Qt.TextAlignmentRole:
+ alignment = qt.Qt.AlignVCenter | qt.Qt.AlignLeft
+ return alignment
+ elif role == qt.Qt.BackgroundRole:
+ # Background color, must be QBrush
+ if isActive:
+ brush = self._palette.brush(qt.QPalette.Active, qt.QPalette.Highlight)
+ elif idx % 2:
+ brush = qt.QBrush(qt.QColor(240, 240, 240))
+ else:
+ brush = qt.QBrush(qt.Qt.white)
+ return brush
+ elif role == qt.Qt.ForegroundRole:
+ # ForegroundRole color, must be QBrush
+ if isActive:
+ brush = self._palette.brush(
+ qt.QPalette.Active, qt.QPalette.HighlightedText
+ )
+ else:
+ brush = self._palette.brush(qt.QPalette.Active, qt.QPalette.WindowText)
+ return brush
+ elif role == qt.Qt.CheckStateRole:
+ return bool(item[2]) # item[2] == True
+ elif role == qt.Qt.ToolTipRole or role == qt.Qt.StatusTipRole:
+ return ""
+ elif role == self.iconColorRole:
+ return item[1]["color"]
+ elif role == self.iconLineWidthRole:
+ return item[1]["linewidth"]
+ elif role == self.iconLineStyleRole:
+ return item[1]["linestyle"]
+ elif role == self.iconSymbolRole:
+ return item[1]["symbol"]
+ elif role == self.showLineRole:
+ return item[3]
+ elif role == self.showSymbolRole:
+ return item[4]
+ elif role == self.itemRole:
+ return item[5]
+ else:
+ _logger.info("Unkown role requested: %s", str(role))
+ return None
+
+ def setData(self, modelIndex, value, role):
+ if modelIndex.isValid:
+ idx = modelIndex.row()
+ else:
+ return None
+ if idx >= len(self.legendList):
+ # raise IndexError('list index out of range')
+ _logger.warning("setData -- List index out of range, idx: %d", idx)
+ return None
+
+ item = self.legendList[idx]
+ try:
+ if role == qt.Qt.DisplayRole:
+ # Set legend
+ item[0] = str(value)
+ elif role == self.iconColorRole:
+ item[1]["color"] = qt.QColor(value)
+ elif role == self.iconLineWidthRole:
+ item[1]["linewidth"] = int(value)
+ elif role == self.iconLineStyleRole:
+ item[1]["linestyle"] = value
+ elif role == self.iconSymbolRole:
+ item[1]["symbol"] = str(value)
+ elif role == qt.Qt.CheckStateRole:
+ item[2] = value
+ elif role == self.showLineRole:
+ item[3] = value
+ elif role == self.showSymbolRole:
+ item[4] = value
+ elif role == self.itemRole:
+ item[5] = value
+ except ValueError:
+ _logger.warning(
+ "Conversion failed:\n\tvalue: %s\n\trole: %s", str(value), str(role)
+ )
+ # Can that be right? Read docs again..
+ self.dataChanged.emit(modelIndex, modelIndex)
+ return True
+
+ def insertLegendList(self, row, llist):
+ """
+ :param int row: Determines after which row the items are inserted
+ :param llist: Carries the new legend information
+ :type llist: List
+ """
+ modelIndex = self.createIndex(row, 0)
+ count = len(llist)
+ super(LegendModel, self).beginInsertRows(modelIndex, row, row + count)
+ head = self.legendList[0:row]
+ tail = self.legendList[row:]
+ new = []
+ for legend, icon in llist:
+ linestyle = icon.get("linestyle", None)
+ if LegendIconWidget.isEmptyLineStyle(linestyle):
+ # Curve had no line, give it one and hide it
+ # So when toggle line, it will display a solid line
+ showLine = False
+ icon["linestyle"] = "-"
+ else:
+ showLine = True
+
+ symbol = icon.get("symbol", None)
+ if LegendIconWidget.isEmptySymbol(symbol):
+ # Curve had no symbol, give it one and hide it
+ # So when toggle symbol, it will display 'o'
+ showSymbol = False
+ icon["symbol"] = "o"
+ else:
+ showSymbol = True
+
+ selected = icon.get("selected", True)
+ item = [
+ legend,
+ icon,
+ selected,
+ showLine,
+ showSymbol,
+ icon.get("item", None),
+ ]
+ new.append(item)
+ self.legendList = head + new + tail
+ super(LegendModel, self).endInsertRows()
+ return True
+
+ def insertRows(self, row, count, modelIndex=qt.QModelIndex()):
+ raise NotImplementedError("Use LegendModel.insertLegendList instead")
+
+ def removeRow(self, row):
+ return self.removeRows(row, 1)
+
+ def removeRows(self, row, count, modelIndex=qt.QModelIndex()):
+ length = len(self.legendList)
+ if length == 0:
+ # Nothing to do..
+ return True
+ if row < 0 or row >= length:
+ raise IndexError(
+ "Index out of range -- " + "idx: %d, len: %d" % (row, length)
+ )
+ if count == 0:
+ return False
+ super(LegendModel, self).beginRemoveRows(modelIndex, row, row + count)
+ del self.legendList[row : row + count]
+ super(LegendModel, self).endRemoveRows()
+ return True
+
+ def setEditor(self, event, editor):
+ """
+ :param str event: String that identifies the editor
+ :param editor: Widget used to change data in the underlying model
+ :type editor: QWidget
+ """
+ if event not in self.eventList:
+ raise ValueError("setEditor -- Event must be in %s" % str(self.eventList))
+ self.editorDict[event] = editor
+
+
+class LegendListItemWidget(qt.QItemDelegate):
+ """Object displaying a single item (i.e., a row) in the list."""
+
+ # Notice: LegendListItem does NOT inherit
+ # from QObject, it cannot emit signals!
+
+ def __init__(self, parent=None, itemType=0):
+ super(LegendListItemWidget, self).__init__(parent)
+
+ # Dictionary to render checkboxes
+ self.cbDict = {}
+ self.labelDict = {}
+ self.iconDict = {}
+
+ # Keep checkbox and legend to get sizeHint
+ self.checkbox = qt.QCheckBox()
+ self.legend = qt.QLabel()
+ self.icon = LegendIcon()
+
+ # Context Menu and Editors
+ self.contextMenu = None
+
+ def paint(self, painter, option, modelIndex):
+ """
+ Here be docs..
+
+ :param QPainter painter:
+ :param QStyleOptionViewItem option:
+ :param QModelIndex modelIndex:
+ """
+ painter.save()
+ rect = option.rect
+
+ # Calculate the icon rectangle
+ iconSize = self.icon.sizeHint()
+ # Calculate icon position
+ x = rect.left() + 2
+ y = rect.top() + int(0.5 * (rect.height() - iconSize.height()))
+ iconRect = qt.QRect(qt.QPoint(x, y), iconSize)
+
+ # Calculate label rectangle
+ legendSize = qt.QSize(rect.width() - iconSize.width() - 30, rect.height())
+ # Calculate label position
+ x = rect.left() + iconRect.width()
+ y = rect.top()
+ labelRect = qt.QRect(qt.QPoint(x, y), legendSize)
+ labelRect.translate(qt.QPoint(10, 0))
+
+ # Calculate the checkbox rectangle
+ x = rect.right() - 30
+ y = rect.top()
+ chBoxRect = qt.QRect(qt.QPoint(x, y), rect.bottomRight())
+
+ # Remember the rectangles
+ idx = modelIndex.row()
+ self.cbDict[idx] = chBoxRect
+ self.iconDict[idx] = iconRect
+ self.labelDict[idx] = labelRect
+
+ # Draw background first!
+ if option.state & qt.QStyle.State_MouseOver:
+ backgroundBrush = option.palette.highlight()
+ else:
+ backgroundBrush = modelIndex.data(qt.Qt.BackgroundRole)
+ painter.fillRect(rect, backgroundBrush)
+
+ # Draw label
+ legendText = modelIndex.data(qt.Qt.DisplayRole)
+ textBrush = modelIndex.data(qt.Qt.ForegroundRole)
+ textAlign = modelIndex.data(qt.Qt.TextAlignmentRole)
+ painter.setBrush(textBrush)
+ painter.setFont(self.legend.font())
+ painter.setPen(textBrush.color())
+ painter.drawText(labelRect, textAlign, legendText)
+
+ # Draw icon
+ iconColor = modelIndex.data(LegendModel.iconColorRole)
+ iconLineWidth = modelIndex.data(LegendModel.iconLineWidthRole)
+ iconLineStyle = modelIndex.data(LegendModel.iconLineStyleRole)
+ iconSymbol = modelIndex.data(LegendModel.iconSymbolRole)
+ icon = LegendIcon()
+ icon.resize(iconRect.size())
+ icon.move(iconRect.topRight())
+ icon.showSymbol = modelIndex.data(LegendModel.showSymbolRole)
+ icon.showLine = modelIndex.data(LegendModel.showLineRole)
+ icon.setSymbolColor(iconColor)
+ icon.setLineColor(iconColor)
+ icon.setLineWidth(iconLineWidth)
+ icon.setLineStyle(iconLineStyle)
+ icon.setSymbol(iconSymbol)
+ icon.symbolOutlineBrush = backgroundBrush
+ icon.paint(painter, iconRect, option.palette)
+
+ # Draw the checkbox
+ if modelIndex.data(qt.Qt.CheckStateRole):
+ checkState = qt.Qt.Checked
+ else:
+ checkState = qt.Qt.Unchecked
+
+ self.drawCheck(painter, qt.QStyleOptionViewItem(), chBoxRect, checkState)
+
+ painter.restore()
+
+ def editorEvent(self, event, model, option, modelIndex):
+ # From the docs:
+ # Mouse events are sent to editorEvent()
+ # even if they don't start editing of the item.
+ if event.button() == qt.Qt.RightButton and self.contextMenu:
+ if qt.BINDING == "PyQt5":
+ position = event.globalPos()
+ else: # Qt6
+ position = event.globalPosition().toPoint()
+ self.contextMenu.exec(position, modelIndex)
+ return True
+ elif event.button() == qt.Qt.LeftButton:
+ # Check if checkbox was clicked
+ idx = modelIndex.row()
+ cbRect = self.cbDict[idx]
+ if cbRect.contains(event.pos()):
+ # Toggle checkbox
+ model.setData(
+ modelIndex,
+ not modelIndex.data(qt.Qt.CheckStateRole),
+ qt.Qt.CheckStateRole,
+ )
+ event.ignore()
+ return True
+ else:
+ return super(LegendListItemWidget, self).editorEvent(
+ event, model, option, modelIndex
+ )
+
+ def createEditor(self, parent, option, idx):
+ _logger.info("### Editor request ###")
+
+ def sizeHint(self, option, idx):
+ # return qt.QSize(68,24)
+ iconSize = self.icon.sizeHint()
+ legendSize = self.legend.sizeHint()
+ checkboxSize = self.checkbox.sizeHint()
+ height = (
+ max([iconSize.height(), legendSize.height(), checkboxSize.height()]) + 4
+ )
+ width = iconSize.width() + legendSize.width() + checkboxSize.width()
+ return qt.QSize(width, height)
+
+
+class LegendListView(qt.QListView):
+ """Widget displaying a list of curve legends, line style and symbol."""
+
+ sigLegendSignal = qt.Signal(object)
+ """Signal emitting a dict when an action is triggered by the user."""
+
+ __mouseClickedEvent = "mouseClicked"
+ __checkBoxClickedEvent = "checkBoxClicked"
+ __legendClickedEvent = "legendClicked"
+
+ def __init__(self, parent=None, model=None, contextMenu=None):
+ super(LegendListView, self).__init__(parent)
+ self.__lastButton = None
+ self.__lastClickPos = None
+ self.__lastModelIdx = None
+ # Set default delegate
+ self.setItemDelegate(LegendListItemWidget())
+ # Set default editors
+ # self.setSizePolicy(qt.QSizePolicy.MinimumExpanding,
+ # qt.QSizePolicy.MinimumExpanding)
+ # Set edit triggers by hand using self.edit(QModelIndex)
+ # in mousePressEvent (better to control than signals)
+ self.setEditTriggers(qt.QAbstractItemView.NoEditTriggers)
+
+ # Control layout
+ # self.setBatchSize(2)
+ # self.setLayoutMode(qt.QListView.Batched)
+ # self.setFlow(qt.QListView.LeftToRight)
+
+ # Control selection
+ self.setSelectionMode(qt.QAbstractItemView.NoSelection)
+
+ if model is None:
+ model = LegendModel(parent=self)
+ self.setModel(model)
+ self.setContextMenu(contextMenu)
+
+ def setLegendList(self, legendList, row=None):
+ if row is not None:
+ model = self.model()
+ model.insertLegendList(row, legendList)
+ elif len(legendList) != self.model().rowCount():
+ self.clear()
+ model = self.model()
+ model.insertLegendList(0, legendList)
+ else:
+ model = self.model()
+ for i, (new_legend, icon) in enumerate(legendList):
+ modelIndex = model.index(i)
+ legend = str(modelIndex.data(qt.Qt.DisplayRole))
+ if new_legend != legend:
+ model.setData(modelIndex, new_legend, qt.Qt.DisplayRole)
+
+ color = modelIndex.data(LegendModel.iconColorRole)
+ new_color = icon.get("color", None)
+ if new_color != color:
+ model.setData(modelIndex, new_color, LegendModel.iconColorRole)
+
+ linewidth = modelIndex.data(LegendModel.iconLineWidthRole)
+ new_linewidth = icon.get("linewidth", 1.0)
+ if new_linewidth != linewidth:
+ model.setData(
+ modelIndex, new_linewidth, LegendModel.iconLineWidthRole
+ )
+
+ linestyle = modelIndex.data(LegendModel.iconLineStyleRole)
+ new_linestyle = icon.get("linestyle", None)
+ visible = not LegendIconWidget.isEmptyLineStyle(new_linestyle)
+ model.setData(modelIndex, visible, LegendModel.showLineRole)
+ if new_linestyle != linestyle:
+ model.setData(
+ modelIndex, new_linestyle, LegendModel.iconLineStyleRole
+ )
+
+ symbol = modelIndex.data(LegendModel.iconSymbolRole)
+ new_symbol = icon.get("symbol", None)
+ visible = not LegendIconWidget.isEmptySymbol(new_symbol)
+ model.setData(modelIndex, visible, LegendModel.showSymbolRole)
+ if new_symbol != symbol:
+ model.setData(modelIndex, new_symbol, LegendModel.iconSymbolRole)
+
+ selected = modelIndex.data(qt.Qt.CheckStateRole)
+ new_selected = icon.get("selected", True)
+ if new_selected != selected:
+ model.setData(modelIndex, new_selected, qt.Qt.CheckStateRole)
+
+ item = modelIndex.data(LegendModel.itemRole)
+ newItem = icon.get("item", None)
+ if item is not newItem:
+ model.setData(modelIndex, newItem, LegendModel.itemRole)
+ _logger.debug("LegendListView.setLegendList(legendList) finished")
+
+ def clear(self):
+ model = self.model()
+ model.removeRows(0, model.rowCount())
+ _logger.debug("LegendListView.clear() finished")
+
+ def setContextMenu(self, contextMenu=None):
+ delegate = self.itemDelegate()
+ if isinstance(delegate, LegendListItemWidget) and self.model():
+ if contextMenu is None:
+ delegate.contextMenu = LegendListContextMenu(self.model())
+ delegate.contextMenu.sigContextMenu.connect(self._contextMenuSlot)
+ else:
+ delegate.contextMenu = contextMenu
+
+ def __getitem__(self, idx):
+ model = self.model()
+ try:
+ item = model[idx]
+ except ValueError:
+ item = None
+ return item
+
+ def _contextMenuSlot(self, ddict):
+ self.sigLegendSignal.emit(ddict)
+
+ def mousePressEvent(self, event):
+ self.__lastButton = event.button()
+ self.__lastPosition = event.pos()
+ super(LegendListView, self).mousePressEvent(event)
+ # call _handleMouseClick after editing was handled
+ # If right click (context menu) is aborted, no
+ # signal is emitted..
+ self._handleMouseClick(self.indexAt(self.__lastPosition))
+
+ def mouseDoubleClickEvent(self, event):
+ self.__lastButton = event.button()
+ self.__lastPosition = event.pos()
+ super(LegendListView, self).mouseDoubleClickEvent(event)
+ # call _handleMouseClick after editing was handled
+ # If right click (context menu) is aborted, no
+ # signal is emitted..
+ self._handleMouseClick(self.indexAt(self.__lastPosition))
+
+ def mouseMoveEvent(self, event):
+ # LegendListView.mouseMoveEvent is overwritten
+ # to suppress unwanted behavior in the delegate.
+ pass
+
+ def mouseReleaseEvent(self, event):
+ # LegendListView.mouseReleaseEvent is overwritten
+ # to subpress unwanted behavior in the delegate.
+ pass
+
+ def _handleMouseClick(self, modelIndex):
+ """
+ Distinguish between mouse click on Legend
+ and mouse click on CheckBox by setting the
+ currentCheckState attribute in LegendListItem.
+
+ Emits signal sigLegendSignal(ddict)
+
+ :param QModelIndex modelIndex: index of the clicked item
+ """
+ _logger.debug("self._handleMouseClick called")
+ if self.__lastButton not in [qt.Qt.LeftButton, qt.Qt.RightButton]:
+ return
+ if not modelIndex.isValid():
+ _logger.debug("_handleMouseClick -- Invalid QModelIndex")
+ return
+ # model = self.model()
+ idx = modelIndex.row()
+
+ delegate = self.itemDelegate()
+ cbClicked = False
+ if isinstance(delegate, LegendListItemWidget):
+ for cbRect in delegate.cbDict.values():
+ if cbRect.contains(self.__lastPosition):
+ cbClicked = True
+ break
+
+ # TODO: Check for doubleclicks on legend/icon and spawn editors
+
+ ddict = {
+ "legend": str(modelIndex.data(qt.Qt.DisplayRole)),
+ "icon": {
+ "linewidth": str(modelIndex.data(LegendModel.iconLineWidthRole)),
+ "linestyle": modelIndex.data(LegendModel.iconLineStyleRole),
+ "symbol": str(modelIndex.data(LegendModel.iconSymbolRole)),
+ },
+ "selected": modelIndex.data(qt.Qt.CheckStateRole),
+ "type": str(modelIndex.data()),
+ "item": modelIndex.data(LegendModel.itemRole),
+ }
+ if self.__lastButton == qt.Qt.RightButton:
+ _logger.debug("Right clicked")
+ ddict["button"] = "right"
+ ddict["event"] = self.__mouseClickedEvent
+ elif cbClicked:
+ _logger.debug("CheckBox clicked")
+ ddict["button"] = "left"
+ ddict["event"] = self.__checkBoxClickedEvent
+ else:
+ _logger.debug("Legend clicked")
+ ddict["button"] = "left"
+ ddict["event"] = self.__legendClickedEvent
+ _logger.debug(" idx: %d\n ddict: %s", idx, str(ddict))
+ self.sigLegendSignal.emit(ddict)
+
+
+class LegendListContextMenu(qt.QMenu):
+ """Contextual menu associated to items in a :class:`LegendListView`."""
+
+ sigContextMenu = qt.Signal(object)
+ """Signal emitting a dict upon contextual menu actions."""
+
+ def __init__(self, model):
+ super(LegendListContextMenu, self).__init__(parent=None)
+ self.model = model
+
+ self.addAction("Set Active", self.setActiveAction)
+ self.addAction("Map to left", self.mapToLeftAction)
+ self.addAction("Map to right", self.mapToRightAction)
+
+ self._pointsAction = self.addAction("Points", self.togglePointsAction)
+ self._pointsAction.setCheckable(True)
+
+ self._linesAction = self.addAction("Lines", self.toggleLinesAction)
+ self._linesAction.setCheckable(True)
+
+ self.addAction("Remove curve", self.removeItemAction)
+ self.addAction("Rename curve", self.renameItemAction)
+
+ def exec(self, pos, idx):
+ self.__currentIdx = idx
+
+ # Set checkable action state
+ modelIndex = self.currentIdx()
+ self._pointsAction.setChecked(modelIndex.data(LegendModel.showSymbolRole))
+ self._linesAction.setChecked(modelIndex.data(LegendModel.showLineRole))
+
+ super(LegendListContextMenu, self).popup(pos)
+
+ def exec_(self, pos, idx): # Qt5-like compatibility
+ return self.exec(pos, idx)
+
+ def currentIdx(self):
+ return self.__currentIdx
+
+ def mapToLeftAction(self):
+ _logger.debug("LegendListContextMenu.mapToLeftAction called")
+ modelIndex = self.currentIdx()
+ legend = str(modelIndex.data(qt.Qt.DisplayRole))
+ ddict = {
+ "legend": legend,
+ "label": legend,
+ "selected": modelIndex.data(qt.Qt.CheckStateRole),
+ "type": str(modelIndex.data()),
+ "item": modelIndex.data(LegendModel.itemRole),
+ "event": "mapToLeft",
+ }
+ self.sigContextMenu.emit(ddict)
+
+ def mapToRightAction(self):
+ _logger.debug("LegendListContextMenu.mapToRightAction called")
+ modelIndex = self.currentIdx()
+ legend = str(modelIndex.data(qt.Qt.DisplayRole))
+ ddict = {
+ "legend": legend,
+ "label": legend,
+ "selected": modelIndex.data(qt.Qt.CheckStateRole),
+ "type": str(modelIndex.data()),
+ "item": modelIndex.data(LegendModel.itemRole),
+ "event": "mapToRight",
+ }
+ self.sigContextMenu.emit(ddict)
+
+ def removeItemAction(self):
+ _logger.debug("LegendListContextMenu.removeCurveAction called")
+ modelIndex = self.currentIdx()
+ legend = str(modelIndex.data(qt.Qt.DisplayRole))
+ ddict = {
+ "legend": legend,
+ "label": legend,
+ "selected": modelIndex.data(qt.Qt.CheckStateRole),
+ "type": str(modelIndex.data()),
+ "item": modelIndex.data(LegendModel.itemRole),
+ "event": "removeCurve",
+ }
+ self.model.removeRow(modelIndex.row())
+ self.sigContextMenu.emit(ddict)
+
+ def renameItemAction(self):
+ _logger.debug("LegendListContextMenu.renameCurveAction called")
+ modelIndex = self.currentIdx()
+ legend = str(modelIndex.data(qt.Qt.DisplayRole))
+ ddict = {
+ "legend": legend,
+ "label": legend,
+ "selected": modelIndex.data(qt.Qt.CheckStateRole),
+ "type": str(modelIndex.data()),
+ "item": modelIndex.data(LegendModel.itemRole),
+ "event": "renameCurve",
+ }
+ self.sigContextMenu.emit(ddict)
+
+ def toggleLinesAction(self):
+ modelIndex = self.currentIdx()
+ legend = str(modelIndex.data(qt.Qt.DisplayRole))
+ ddict = {
+ "legend": legend,
+ "label": legend,
+ "selected": modelIndex.data(qt.Qt.CheckStateRole),
+ "item": modelIndex.data(LegendModel.itemRole),
+ "type": str(modelIndex.data()),
+ }
+ linestyle = modelIndex.data(LegendModel.iconLineStyleRole)
+ visible = not modelIndex.data(LegendModel.showLineRole)
+ _logger.debug("toggleLinesAction -- lines visible: %s", str(visible))
+ ddict["event"] = "toggleLine"
+ ddict["line"] = visible
+ ddict["linestyle"] = linestyle if visible else ""
+ self.model.setData(modelIndex, visible, LegendModel.showLineRole)
+ self.sigContextMenu.emit(ddict)
+
+ def togglePointsAction(self):
+ modelIndex = self.currentIdx()
+ legend = str(modelIndex.data(qt.Qt.DisplayRole))
+ ddict = {
+ "legend": legend,
+ "label": legend,
+ "selected": modelIndex.data(qt.Qt.CheckStateRole),
+ "type": str(modelIndex.data()),
+ "item": modelIndex.data(LegendModel.itemRole),
+ }
+ flag = modelIndex.data(LegendModel.showSymbolRole)
+ symbol = modelIndex.data(LegendModel.iconSymbolRole)
+ visible = not flag or LegendIconWidget.isEmptySymbol(symbol)
+ _logger.debug("togglePointsAction -- Symbols visible: %s", str(visible))
+
+ ddict["event"] = "togglePoints"
+ ddict["points"] = visible
+ ddict["symbol"] = symbol if visible else ""
+ self.model.setData(modelIndex, visible, LegendModel.showSymbolRole)
+ self.sigContextMenu.emit(ddict)
+
+ def setActiveAction(self):
+ modelIndex = self.currentIdx()
+ legend = str(modelIndex.data(qt.Qt.DisplayRole))
+ _logger.debug("setActiveAction -- active curve: %s", legend)
+ ddict = {
+ "legend": legend,
+ "label": legend,
+ "selected": modelIndex.data(qt.Qt.CheckStateRole),
+ "type": str(modelIndex.data()),
+ "item": modelIndex.data(LegendModel.itemRole),
+ "event": "setActiveCurve",
+ }
+ self.sigContextMenu.emit(ddict)
+
+
+class RenameCurveDialog(qt.QDialog):
+ """Dialog box to input the name of a curve."""
+
+ def __init__(self, parent=None, current="", curves=()):
+ super(RenameCurveDialog, self).__init__(parent)
+ self.setWindowTitle("Rename Curve %s" % current)
+ self.curves = curves
+ layout = qt.QVBoxLayout(self)
+ self.lineEdit = qt.QLineEdit(self)
+ self.lineEdit.setText(current)
+ self.hbox = qt.QWidget(self)
+ self.hboxLayout = qt.QHBoxLayout(self.hbox)
+ self.hboxLayout.addStretch(1)
+ self.okButton = qt.QPushButton(self.hbox)
+ self.okButton.setText("OK")
+ self.hboxLayout.addWidget(self.okButton)
+ self.cancelButton = qt.QPushButton(self.hbox)
+ self.cancelButton.setText("Cancel")
+ self.hboxLayout.addWidget(self.cancelButton)
+ self.hboxLayout.addStretch(1)
+ layout.addWidget(self.lineEdit)
+ layout.addWidget(self.hbox)
+ self.okButton.clicked.connect(self.preAccept)
+ self.cancelButton.clicked.connect(self.reject)
+
+ def preAccept(self):
+ text = str(self.lineEdit.text())
+ addedText = ""
+ if len(text):
+ if text not in self.curves:
+ self.accept()
+ return
+ else:
+ addedText = "Curve already exists."
+ text = "Invalid Curve Name"
+ msg = qt.QMessageBox(self)
+ msg.setIcon(qt.QMessageBox.Critical)
+ msg.setWindowTitle(text)
+ text += "\n%s" % addedText
+ msg.setText(text)
+ msg.exec()
+
+ def getText(self):
+ return str(self.lineEdit.text())
+
+
+class LegendsDockWidget(qt.QDockWidget):
+ """QDockWidget with a :class:`LegendSelector` connected to a PlotWindow.
+
+ It makes the link between the LegendListView widget and the PlotWindow.
+
+ :param parent: See :class:`QDockWidget`
+ :param plot: :class:`.PlotWindow` instance on which to operate
+ """
+
+ def __init__(self, parent=None, plot=None):
+ assert plot is not None
+ self._plotRef = weakref.ref(plot)
+ self._isConnected = False # True if widget connected to plot signals
+
+ super(LegendsDockWidget, self).__init__("Legends", parent)
+
+ self._legendWidget = LegendListView()
+
+ self.layout().setContentsMargins(0, 0, 0, 0)
+ self.setWidget(self._legendWidget)
+
+ self.visibilityChanged.connect(self._visibilityChangedHandler)
+
+ self._legendWidget.sigLegendSignal.connect(self._legendSignalHandler)
+
+ @property
+ def plot(self):
+ """The :class:`.PlotWindow` this widget is attached to."""
+ return self._plotRef()
+
+ @deprecated(reason="No longer needed", since_version="2.0.0")
+ def renameCurve(self, oldLegend, newLegend):
+ """Change the name of a curve using remove and addCurve
+
+ :param str oldLegend: The legend of the curve to be changed
+ :param str newLegend: The new legend of the curve
+ """
+ is_active = self.plot.getActiveCurve(just_legend=True) == oldLegend
+ curve = self.plot.getCurve(oldLegend)
+ self.plot.remove(oldLegend, kind="curve")
+ self.plot.addCurve(
+ curve.getXData(copy=False),
+ curve.getYData(copy=False),
+ legend=newLegend,
+ info=curve.getInfo(),
+ color=curve.getColor(),
+ symbol=curve.getSymbol(),
+ linewidth=curve.getLineWidth(),
+ linestyle=curve.getLineStyle(),
+ xlabel=curve.getXLabel(),
+ ylabel=curve.getYLabel(),
+ xerror=curve.getXErrorData(copy=False),
+ yerror=curve.getYErrorData(copy=False),
+ z=curve.getZValue(),
+ selectable=curve.isSelectable(),
+ fill=curve.isFill(),
+ resetzoom=False,
+ )
+ if is_active:
+ self.plot.setActiveCurve(newLegend)
+
+ def _legendSignalHandler(self, ddict):
+ """Handles events from the LegendListView signal"""
+ _logger.debug("Legend signal ddict = %s", str(ddict))
+ # If item is not provided, retrieve it from its legend
+ curve = ddict.get("item", None)
+ if curve is None:
+ curve = self.plot.getCurve(ddict["legend"])
+
+ if ddict["event"] == "legendClicked":
+ if ddict["button"] == "left":
+ self.plot.setActiveCurve(curve)
+
+ elif ddict["event"] == "removeCurve":
+ self.plot.removeItem(curve)
+
+ elif ddict["event"] == "renameCurve":
+ curveList = self.plot.getAllCurves(just_legend=True)
+ oldLegend = ddict["legend"]
+ dialog = RenameCurveDialog(self.plot, oldLegend, curveList)
+ ret = dialog.exec()
+ if ret:
+ newLegend = dialog.getText()
+ wasActive = self.plot.getActiveCurve() is curve
+ self.plot.removeItem(curve)
+ curve.setName(newLegend)
+ self.plot.addItem(curve)
+ if wasActive:
+ self.plot.setActiveCurve(curve)
+
+ elif ddict["event"] == "setActiveCurve":
+ self.plot.setActiveCurve(curve)
+
+ elif ddict["event"] == "checkBoxClicked":
+ curve.setVisible(ddict["selected"])
+
+ elif ddict["event"] in ["mapToRight", "mapToLeft"]:
+ curve.setYAxis("right" if ddict["event"] == "mapToRight" else "left")
+
+ elif ddict["event"] == "togglePoints":
+ curve.setSymbol(ddict["symbol"] if ddict["points"] else "")
+
+ elif ddict["event"] == "toggleLine":
+ curve.setLineStyle(ddict["linestyle"] if ddict["line"] else "")
+
+ else:
+ _logger.debug("unhandled event %s", str(ddict["event"]))
+
+ def updateLegends(self, *args):
+ """Sync the LegendSelector widget displayed info with the plot."""
+ legendList = []
+ for curve in self.plot.getAllCurves(withhidden=True):
+ legend = curve.getName()
+ # Use active color if curve is active
+ isActive = legend == self.plot.getActiveCurve(just_legend=True)
+ style = curve.getCurrentStyle()
+ color = style.getColor()
+ if numpy.array(color, copy=False).ndim != 1:
+ # array of colors, use transparent black
+ color = 0.0, 0.0, 0.0, 0.0
+
+ curveInfo = {
+ "color": qt.QColor.fromRgbF(*color),
+ "linewidth": style.getLineWidth(),
+ "linestyle": style.getLineStyle(),
+ "symbol": style.getSymbol(),
+ "selected": not self.plot.isCurveHidden(legend),
+ "active": isActive,
+ "item": curve,
+ }
+ legendList.append((legend, curveInfo))
+
+ self._legendWidget.setLegendList(legendList)
+
+ def _visibilityChangedHandler(self, visible):
+ if visible:
+ self.updateLegends()
+ if not self._isConnected:
+ self.plot.sigContentChanged.connect(self.updateLegends)
+ self.plot.sigActiveCurveChanged.connect(self.updateLegends)
+ self._isConnected = True
+ else:
+ if self._isConnected:
+ self.plot.sigContentChanged.disconnect(self.updateLegends)
+ self.plot.sigActiveCurveChanged.disconnect(self.updateLegends)
+ self._isConnected = False
diff --git a/src/silx/gui/plot/LimitsHistory.py b/src/silx/gui/plot/LimitsHistory.py
new file mode 100644
index 0000000..f4e0afc
--- /dev/null
+++ b/src/silx/gui/plot/LimitsHistory.py
@@ -0,0 +1,82 @@
+# /*##########################################################################
+#
+# Copyright (c) 2017 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""This module provides handling of :class:`PlotWidget` limits history.
+"""
+
+__authors__ = ["T. Vincent"]
+__license__ = "MIT"
+__date__ = "19/07/2017"
+
+
+from .. import qt
+
+
+class LimitsHistory(qt.QObject):
+ """Class handling history of limits of a :class:`PlotWidget`.
+
+ :param PlotWidget parent: The plot widget this object is bound to.
+ """
+
+ def __init__(self, parent):
+ self._history = []
+ super(LimitsHistory, self).__init__(parent)
+ self.setParent(parent)
+
+ def setParent(self, parent):
+ """See :meth:`QObject.setParent`.
+
+ :param PlotWidget parent: The PlotWidget this object is bound to.
+ """
+ self.clear() # Clear history when changing parent
+ super(LimitsHistory, self).setParent(parent)
+
+ def push(self):
+ """Append current limits to the history."""
+ plot = self.parent()
+ xmin, xmax = plot.getXAxis().getLimits()
+ ymin, ymax = plot.getYAxis(axis="left").getLimits()
+ y2min, y2max = plot.getYAxis(axis="right").getLimits()
+ self._history.append((xmin, xmax, ymin, ymax, y2min, y2max))
+
+ def pop(self):
+ """Restore previously limits stored in the history.
+
+ :return: True if limits were restored, False if history was empty.
+ :rtype: bool
+ """
+ plot = self.parent()
+ if self._history:
+ limits = self._history.pop(-1)
+ plot.setLimits(*limits)
+ return True
+ else:
+ plot.resetZoom()
+ return False
+
+ def clear(self):
+ """Clear stored limits states."""
+ self._history = []
+
+ def __len__(self):
+ return len(self._history)
diff --git a/src/silx/gui/plot/MaskToolsWidget.py b/src/silx/gui/plot/MaskToolsWidget.py
new file mode 100644
index 0000000..40b2717
--- /dev/null
+++ b/src/silx/gui/plot/MaskToolsWidget.py
@@ -0,0 +1,914 @@
+# /*##########################################################################
+#
+# Copyright (c) 2017-2023 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""Widget providing a set of tools to draw masks on a PlotWidget.
+
+This widget is meant to work with :class:`silx.gui.plot.PlotWidget`.
+
+- :class:`ImageMask`: Handle mask bitmap update and history
+- :class:`MaskToolsWidget`: GUI for :class:`Mask`
+- :class:`MaskToolsDockWidget`: DockWidget to integrate in :class:`PlotWindow`
+"""
+
+__authors__ = ["T. Vincent", "P. Knobel"]
+__license__ = "MIT"
+__date__ = "08/12/2020"
+
+import os
+import sys
+import numpy
+import logging
+import h5py
+
+import fabio
+from fabio.edfimage import EdfImage
+from fabio.TiffIO import TiffIO
+
+from silx.image import shapes
+from silx.io.utils import NEXUS_HDF5_EXT, is_dataset
+from silx.gui.dialog.DatasetDialog import DatasetDialog
+
+from ._BaseMaskToolsWidget import BaseMask, BaseMaskToolsWidget, BaseMaskToolsDockWidget
+from . import items
+from ..colors import cursorColorForColormap, rgba
+from .. import qt
+from ..utils import LockReentrant
+
+
+_logger = logging.getLogger(__name__)
+
+_HDF5_EXT_STR = " ".join(["*" + ext for ext in NEXUS_HDF5_EXT])
+
+
+def _selectDataset(filename, mode=DatasetDialog.SaveMode):
+ """Open a dialog to prompt the user to select a dataset in
+ a hdf5 file.
+
+ :param str filename: name of an existing HDF5 file
+ :param mode: DatasetDialog.SaveMode or DatasetDialog.LoadMode
+ :rtype: str
+ :return: Name of selected dataset
+ """
+ dialog = DatasetDialog()
+ dialog.addFile(filename)
+ dialog.setWindowTitle("Select a 2D dataset")
+ dialog.setMode(mode)
+ if not dialog.exec():
+ return None
+ return dialog.getSelectedDataUrl().data_path()
+
+
+class ImageMask(BaseMask):
+ """A 2D mask field with update operations.
+
+ Coords follows (row, column) convention and are in mask array coords.
+
+ This is meant for internal use by :class:`MaskToolsWidget`.
+ """
+
+ def __init__(self, image=None):
+ """
+
+ :param image: :class:`silx.gui.plot.items.ImageBase` instance
+ """
+ BaseMask.__init__(self, image)
+ self.reset(shape=(0, 0)) # Init the mask with a 2D shape
+
+ def getDataValues(self):
+ """Return image data as a 2D or 3D array (if it is a RGBA image).
+
+ :rtype: 2D or 3D numpy.ndarray
+ """
+ return self._dataItem.getData(copy=False)
+
+ def save(self, filename, kind):
+ """Save current mask in a file
+
+ :param str filename: The file where to save to mask
+ :param str kind: The kind of file to save in 'edf', 'tif', 'npy', 'h5'
+ or 'msk' (if FabIO is installed)
+ :raise Exception: Raised if the file writing fail
+ """
+ if kind == "edf":
+ EdfImage(
+ data=self.getMask(),
+ header={"program_name": "silx-mask", "masked_value": "nonzero"},
+ ).write(filename)
+
+ elif kind == "tif":
+ tiffFile = TiffIO(filename, mode="w")
+ tiffFile.writeImage(self.getMask(copy=False), software="silx")
+
+ elif kind == "npy":
+ try:
+ numpy.save(filename, self.getMask(copy=False))
+ except IOError:
+ raise RuntimeError("Mask file can't be written")
+
+ elif ("." + kind) in NEXUS_HDF5_EXT:
+ self._saveToHdf5(filename, self.getMask(copy=False))
+
+ elif kind == "msk":
+ try:
+ data = self.getMask(copy=False)
+ image = fabio.fabioimage.FabioImage(data=data)
+ image = image.convert(fabio.fit2dmaskimage.Fit2dMaskImage)
+ image.save(filename)
+ except Exception:
+ _logger.debug("Backtrace", exc_info=True)
+ raise RuntimeError("Mask file can't be written")
+ else:
+ raise ValueError("Format '%s' is not supported" % kind)
+
+ @staticmethod
+ def _saveToHdf5(filename, mask):
+ """Save a mask array to a HDF5 file.
+
+ :param str filename: name of an existing HDF5 file
+ :param numpy.ndarray mask: Mask array.
+ :returns: True if operation succeeded, False otherwise.
+ """
+ if not os.path.exists(filename):
+ # create new file
+ with h5py.File(filename, "w") as _h5f:
+ pass
+ dataPath = _selectDataset(filename)
+ if dataPath is None:
+ return False
+ with h5py.File(filename, "a") as h5f:
+ existing_ds = h5f.get(dataPath)
+ if existing_ds is not None:
+ reply = qt.QMessageBox.question(
+ None,
+ "Confirm overwrite",
+ "Do you want to overwrite an existing dataset?",
+ qt.QMessageBox.Yes | qt.QMessageBox.No,
+ )
+ if reply != qt.QMessageBox.Yes:
+ return False
+ del h5f[dataPath]
+ try:
+ h5f.create_dataset(dataPath, data=mask)
+ except Exception:
+ return False
+ return True
+
+ # Drawing operations
+ def updateRectangle(self, level, row, col, height, width, mask=True):
+ """Mask/Unmask a rectangle of the given mask level.
+
+ :param int level: Mask level to update.
+ :param int row: Starting row of the rectangle
+ :param int col: Starting column of the rectangle
+ :param int height:
+ :param int width:
+ :param bool mask: True to mask (default), False to unmask.
+ """
+ assert 0 < level < 256
+ if row + height <= 0 or col + width <= 0:
+ return # Rectangle outside image, avoid negative indices
+ selection = self._mask[
+ max(0, row) : row + height + 1, max(0, col) : col + width + 1
+ ]
+ if mask:
+ selection[:, :] = level
+ else:
+ selection[selection == level] = 0
+ self._notify()
+
+ def updatePolygon(self, level, vertices, mask=True):
+ """Mask/Unmask a polygon of the given mask level.
+
+ :param int level: Mask level to update.
+ :param vertices: Nx2 array of polygon corners as (row, col)
+ :param bool mask: True to mask (default), False to unmask.
+ """
+ fill = shapes.polygon_fill_mask(vertices, self._mask.shape)
+ if mask:
+ self._mask[fill != 0] = level
+ else:
+ self._mask[numpy.logical_and(fill != 0, self._mask == level)] = 0
+ self._notify()
+
+ def updatePoints(self, level, rows, cols, mask=True):
+ """Mask/Unmask points with given coordinates.
+
+ :param int level: Mask level to update.
+ :param rows: Rows of selected points
+ :type rows: 1D numpy.ndarray
+ :param cols: Columns of selected points
+ :type cols: 1D numpy.ndarray
+ :param bool mask: True to mask (default), False to unmask.
+ """
+ valid = numpy.logical_and(
+ numpy.logical_and(rows >= 0, cols >= 0),
+ numpy.logical_and(rows < self._mask.shape[0], cols < self._mask.shape[1]),
+ )
+ rows, cols = rows[valid], cols[valid]
+
+ if mask:
+ self._mask[rows, cols] = level
+ else:
+ inMask = self._mask[rows, cols] == level
+ self._mask[rows[inMask], cols[inMask]] = 0
+ self._notify()
+
+ def updateDisk(self, level, crow, ccol, radius, mask=True):
+ """Mask/Unmask a disk of the given mask level.
+
+ :param int level: Mask level to update.
+ :param int crow: Disk center row.
+ :param int ccol: Disk center column.
+ :param float radius: Radius of the disk in mask array unit
+ :param bool mask: True to mask (default), False to unmask.
+ """
+ rows, cols = shapes.circle_fill(crow, ccol, radius)
+ self.updatePoints(level, rows, cols, mask)
+
+ def updateEllipse(self, level, crow, ccol, radius_r, radius_c, mask=True):
+ """Mask/Unmask an ellipse of the given mask level.
+
+ :param int level: Mask level to update.
+ :param int crow: Row of the center of the ellipse
+ :param int ccol: Column of the center of the ellipse
+ :param float radius_r: Radius of the ellipse in the row
+ :param float radius_c: Radius of the ellipse in the column
+ :param bool mask: True to mask (default), False to unmask.
+ """
+ rows, cols = shapes.ellipse_fill(crow, ccol, radius_r, radius_c)
+ self.updatePoints(level, rows, cols, mask)
+
+ def updateLine(self, level, row0, col0, row1, col1, width, mask=True):
+ """Mask/Unmask a line of the given mask level.
+
+ :param int level: Mask level to update.
+ :param int row0: Row of the starting point.
+ :param int col0: Column of the starting point.
+ :param int row1: Row of the end point.
+ :param int col1: Column of the end point.
+ :param int width: Width of the line in mask array unit.
+ :param bool mask: True to mask (default), False to unmask.
+ """
+ rows, cols = shapes.draw_line(row0, col0, row1, col1, width)
+ self.updatePoints(level, rows, cols, mask)
+
+
+class MaskToolsWidget(BaseMaskToolsWidget):
+ """Widget with tools for drawing mask on an image in a PlotWidget."""
+
+ _maxLevelNumber = 255
+
+ def __init__(self, parent=None, plot=None):
+ super(MaskToolsWidget, self).__init__(parent, plot, mask=ImageMask())
+ self._origin = (0.0, 0.0) # Mask origin in plot
+ self._scale = (1.0, 1.0) # Mask scale in plot
+ self._z = 1 # Mask layer in plot
+ self._data = numpy.zeros((0, 0), dtype=numpy.uint8) # Store image
+
+ self.__itemMaskUpdatedLock = LockReentrant()
+ self.__itemMaskUpdated = False
+
+ def __maskStateChanged(self) -> None:
+ """Handle mask commit to update item mask"""
+ item = self._mask.getDataItem()
+ if item is not None:
+ with self.__itemMaskUpdatedLock:
+ item.setMaskData(self._mask.getMask(copy=True), copy=False)
+
+ def setItemMaskUpdated(self, enabled: bool) -> None:
+ """Toggle item mask and mask tool synchronisation.
+
+ :param bool enabled: True to synchronise. Default: False
+ """
+ enabled = bool(enabled)
+ if enabled != self.__itemMaskUpdated:
+ if self.__itemMaskUpdated:
+ self._mask.sigStateChanged.disconnect(self.__maskStateChanged)
+ self.__itemMaskUpdated = enabled
+ if self.__itemMaskUpdated:
+ # Synchronize item and tool mask
+ self._setMaskedImage(self._mask.getDataItem())
+ self._mask.sigStateChanged.connect(self.__maskStateChanged)
+
+ def isItemMaskUpdated(self) -> bool:
+ """Returns whether or not item and mask tool masks are synchronised.
+
+ :rtype: bool
+ """
+ return self.__itemMaskUpdated
+
+ def setSelectionMask(self, mask, copy=True):
+ """Set the mask to a new array.
+
+ :param numpy.ndarray mask:
+ The array to use for the mask or None to reset the mask.
+ :type mask: numpy.ndarray of uint8 of dimension 2, C-contiguous.
+ Array of other types are converted.
+ :param bool copy: True (the default) to copy the array,
+ False to use it as is if possible.
+ :return: None if failed, shape of mask as 2-tuple if successful.
+ The mask can be cropped or padded to fit active image,
+ the returned shape is that of the active image.
+ """
+ if mask is None:
+ self.resetSelectionMask()
+ return self._data.shape[:2]
+
+ mask = numpy.array(mask, copy=False, dtype=numpy.uint8)
+ if len(mask.shape) != 2:
+ _logger.error("Not an image, shape: %d", len(mask.shape))
+ return None
+
+ # Handle mask with single level
+ if self.multipleMasks() == "single":
+ mask = numpy.array(mask != 0, dtype=numpy.uint8)
+
+ # if mask has not changed, do nothing
+ if numpy.array_equal(mask, self.getSelectionMask()):
+ return mask.shape
+
+ if self._data.shape[0:2] == (0, 0) or mask.shape == self._data.shape[0:2]:
+ self._mask.setMask(mask, copy=copy)
+ self._mask.commit()
+ return mask.shape
+ else:
+ _logger.warning(
+ "Mask has not the same size as current image."
+ " Mask will be cropped or padded to fit image"
+ " dimensions. %s != %s",
+ str(mask.shape),
+ str(self._data.shape),
+ )
+ resizedMask = numpy.zeros(self._data.shape[0:2], dtype=numpy.uint8)
+ height = min(self._data.shape[0], mask.shape[0])
+ width = min(self._data.shape[1], mask.shape[1])
+ resizedMask[:height, :width] = mask[:height, :width]
+ self._mask.setMask(resizedMask, copy=False)
+ self._mask.commit()
+ return resizedMask.shape
+
+ # Handle mask refresh on the plot
+ def _updatePlotMask(self):
+ """Update mask image in plot"""
+ mask = self.getSelectionMask(copy=False)
+ if mask is not None:
+ # get the mask from the plot
+ maskItem = self.plot.getImage(self._maskName)
+ mustBeAdded = maskItem is None
+ if mustBeAdded:
+ maskItem = items.MaskImageData()
+ maskItem.setName(self._maskName)
+ # update the items
+ maskItem.setData(mask, copy=False)
+ maskItem.setColormap(self._colormap)
+ maskItem.setOrigin(self._origin)
+ maskItem.setScale(self._scale)
+ maskItem.setZValue(self._z)
+
+ if mustBeAdded:
+ self.plot.addItem(maskItem)
+
+ elif self.plot.getImage(self._maskName):
+ self.plot.remove(self._maskName, kind="image")
+
+ def showEvent(self, event):
+ try:
+ self.plot.sigActiveImageChanged.disconnect(
+ self._activeImageChangedAfterCare
+ )
+ except (RuntimeError, TypeError):
+ pass
+
+ # Sync with current active image
+ self._setMaskedImage(self.plot.getActiveImage())
+ self.plot.sigActiveImageChanged.connect(self._activeImageChanged)
+
+ def hideEvent(self, event):
+ try:
+ self.plot.sigActiveImageChanged.disconnect(self._activeImageChanged)
+ except (RuntimeError, TypeError):
+ pass
+
+ image = self.getMaskedItem()
+ if image is not None:
+ try:
+ image.sigItemChanged.disconnect(self.__imageChanged)
+ except (RuntimeError, TypeError):
+ pass # TODO should not happen
+
+ if self.isMaskInteractionActivated():
+ # Disable drawing tool
+ self.plot.resetInteractiveMode()
+
+ if self.isItemMaskUpdated(): # No "after-care"
+ self._data = numpy.zeros((0, 0), dtype=numpy.uint8)
+ self._mask.setDataItem(None)
+ self._mask.reset()
+
+ if self.plot.getImage(self._maskName):
+ self.plot.remove(self._maskName, kind="image")
+
+ elif self.getSelectionMask(copy=False) is not None:
+ self.plot.sigActiveImageChanged.connect(self._activeImageChangedAfterCare)
+
+ def _activeImageChanged(self, previous, current):
+ """Reacts upon active image change.
+
+ Only handle change of active image items here.
+ """
+ if previous != current:
+ image = self.plot.getActiveImage()
+ if image is not None and image.getName() == self._maskName:
+ image = None # Active image is the mask
+ self._setMaskedImage(image)
+
+ def _setOverlayColorForImage(self, image):
+ """Set the color of overlay adapted to image
+
+ :param image: :class:`.items.ImageBase` object to set color for.
+ """
+ if isinstance(image, items.ColormapMixIn):
+ colormap = image.getColormap()
+ self._defaultOverlayColor = rgba(cursorColorForColormap(colormap["name"]))
+ else:
+ self._defaultOverlayColor = rgba("black")
+
+ def _activeImageChangedAfterCare(self, *args):
+ """Check synchro of active image and mask when mask widget is hidden.
+
+ If active image has no more the same size as the mask, the mask is
+ removed, otherwise it is adjusted to origin, scale and z.
+ """
+ activeImage = self.plot.getActiveImage()
+ if activeImage is None or activeImage.getName() == self._maskName:
+ # No active image or active image is the mask...
+ self._data = numpy.zeros((0, 0), dtype=numpy.uint8)
+ self._mask.setDataItem(None)
+ self._mask.reset()
+
+ if self.plot.getImage(self._maskName):
+ self.plot.remove(self._maskName, kind="image")
+
+ self.plot.sigActiveImageChanged.disconnect(
+ self._activeImageChangedAfterCare
+ )
+ else:
+ self._setOverlayColorForImage(activeImage)
+ self._setMaskColors(
+ self.levelSpinBox.value(),
+ self.transparencySlider.value() / self.transparencySlider.maximum(),
+ )
+
+ self._origin = activeImage.getOrigin()
+ self._scale = activeImage.getScale()
+ self._z = activeImage.getZValue() + 1
+ self._data = activeImage.getData(copy=False)
+ if self._data.shape[:2] != self._mask.getMask(copy=False).shape:
+ # Image has not the same size, remove mask and stop listening
+ if self.plot.getImage(self._maskName):
+ self.plot.remove(self._maskName, kind="image")
+
+ self.plot.sigActiveImageChanged.disconnect(
+ self._activeImageChangedAfterCare
+ )
+ else:
+ # Refresh in case origin, scale, z changed
+ self._mask.setDataItem(activeImage)
+ self._updatePlotMask()
+
+ def _setMaskedImage(self, image):
+ """Change the image that is used a reference to author the mask"""
+ previous = self.getMaskedItem()
+ if previous is not None and self.isVisible():
+ # Disconnect from previous image
+ try:
+ previous.sigItemChanged.disconnect(self.__imageChanged)
+ except (RuntimeError, TypeError):
+ pass # TODO fixme should not happen
+
+ # Set the image
+ self._mask.setDataItem(image)
+
+ if image is None: # No image, disable mask
+ self.setEnabled(False)
+
+ self._data = numpy.zeros((0, 0), dtype=numpy.uint8)
+ self._mask.reset()
+ self._mask.commit()
+
+ self._updateInteractiveMode()
+
+ else: # Update and connect to image's sigItemChanged
+ if self.isItemMaskUpdated():
+ if image.getMaskData(copy=False) is None:
+ # Image item has no mask: use current mask from the tool
+ image.setMaskData(self.getSelectionMask(copy=False), copy=True)
+ else: # Image item has a mask: set it in tool
+ self.setSelectionMask(image.getMaskData(copy=False), copy=True)
+ self._mask.resetHistory()
+ self.__imageUpdated()
+ if self.isVisible():
+ image.sigItemChanged.connect(self.__imageChanged)
+
+ def __imageChanged(self, event):
+ """Reacts upon image item changes"""
+ image = self._mask.getDataItem()
+ if image is None:
+ _logger.error("Mask is not attached to an image")
+ return
+
+ if event in (
+ items.ItemChangedType.COLORMAP,
+ items.ItemChangedType.DATA,
+ items.ItemChangedType.POSITION,
+ items.ItemChangedType.SCALE,
+ items.ItemChangedType.VISIBLE,
+ items.ItemChangedType.ZVALUE,
+ ):
+ self.__imageUpdated()
+
+ elif (
+ event == items.ItemChangedType.MASK
+ and self.isItemMaskUpdated()
+ and not self.__itemMaskUpdatedLock.locked()
+ ):
+ # Update mask from the image item unless mask tool is updating it
+ self.setSelectionMask(image.getMaskData(copy=False), copy=True)
+
+ def __imageUpdated(self):
+ """Synchronize mask with current state of the image"""
+ image = self._mask.getDataItem()
+ if image is None:
+ _logger.error("No active image while expecting one")
+ return
+
+ self._setOverlayColorForImage(image)
+
+ self._setMaskColors(
+ self.levelSpinBox.value(),
+ self.transparencySlider.value() / self.transparencySlider.maximum(),
+ )
+
+ self._origin = image.getOrigin()
+ self._scale = image.getScale()
+ self._z = image.getZValue() + 1
+ self._data = image.getData(copy=False)
+ self._mask.setDataItem(image)
+ if self._data.shape[:2] != self._mask.getMask(copy=False).shape:
+ self._mask.reset(self._data.shape[:2])
+ self._mask.commit()
+ else:
+ # Refresh in case origin, scale, z changed
+ self._updatePlotMask()
+
+ # Visible and with data
+ self.setEnabled(image.isVisible() and self._data.size != 0)
+
+ # Threshold tools only available for data with colormap
+ self.thresholdGroup.setEnabled(self._data.ndim == 2)
+
+ self._updateInteractiveMode()
+
+ # Handle whole mask operations
+ def load(self, filename):
+ """Load a mask from an image file.
+
+ :param str filename: File name from which to load the mask
+ :raise Exception: An exception in case of failure
+ :raise RuntimeWarning: In case the mask was applied but with some
+ import changes to notice
+ """
+ _, extension = os.path.splitext(filename)
+ extension = extension.lower()[1:]
+
+ if extension == "npy":
+ try:
+ mask = numpy.load(filename)
+ except IOError:
+ _logger.error("Can't load filename '%s'", filename)
+ _logger.debug("Backtrace", exc_info=True)
+ raise RuntimeError('File "%s" is not a numpy file.', filename)
+ elif extension in ("edf", "msk", "tif", "tiff"):
+ try:
+ mask = fabio.open(filename).data
+ except Exception as e:
+ _logger.error(f"Can't load filename {filename}")
+ _logger.debug("Backtrace", exc_info=True)
+ raise e
+ elif ("." + extension) in NEXUS_HDF5_EXT:
+ mask = self._loadFromHdf5(filename)
+ if mask is None:
+ raise IOError("Could not load mask from HDF5 dataset")
+ else:
+ msg = "Extension '%s' is not supported."
+ raise RuntimeError(msg % extension)
+
+ effectiveMaskShape = self.setSelectionMask(mask, copy=False)
+ if effectiveMaskShape is None:
+ return
+ if mask.shape != effectiveMaskShape:
+ msg = "Mask was resized from %s to %s"
+ msg = msg % (str(mask.shape), str(effectiveMaskShape))
+ raise RuntimeWarning(msg)
+
+ def _loadMask(self):
+ """Open load mask dialog"""
+ dialog = qt.QFileDialog(self)
+ dialog.setWindowTitle("Load Mask")
+ dialog.setModal(1)
+
+ extensions = {}
+ extensions["EDF files"] = "*.edf"
+ extensions["TIFF files"] = "*.tif *.tiff"
+ extensions["NumPy binary files"] = "*.npy"
+ extensions["HDF5 files"] = _HDF5_EXT_STR
+ # Fit2D mask is displayed anyway fabio is here or not
+ # to show to the user that the option exists
+ extensions["Fit2D mask files"] = "*.msk"
+
+ filters = []
+ filters.append("All supported files (%s)" % " ".join(extensions.values()))
+ for name, extension in extensions.items():
+ filters.append("%s (%s)" % (name, extension))
+ filters.append("All files (*)")
+
+ dialog.setNameFilters(filters)
+ dialog.setFileMode(qt.QFileDialog.ExistingFile)
+ dialog.setDirectory(self.maskFileDir)
+ if not dialog.exec():
+ dialog.close()
+ return
+
+ filename = dialog.selectedFiles()[0]
+ dialog.close()
+
+ # Update the directory according to the user selection
+ self.maskFileDir = os.path.dirname(filename)
+
+ try:
+ self.load(filename)
+ except RuntimeWarning as e:
+ message = e.args[0]
+ msg = qt.QMessageBox(self)
+ msg.setIcon(qt.QMessageBox.Warning)
+ msg.setText("Mask loaded but an operation was applied.\n" + message)
+ msg.exec()
+ except Exception as e:
+ message = e.args[0]
+ msg = qt.QMessageBox(self)
+ msg.setIcon(qt.QMessageBox.Critical)
+ msg.setText("Cannot load mask from file. " + message)
+ msg.exec()
+
+ @staticmethod
+ def _loadFromHdf5(filename):
+ """Load a mask array from a HDF5 file.
+
+ :param str filename: name of an existing HDF5 file
+ :returns: A mask as a numpy array, or None if the interactive dialog
+ was cancelled
+ """
+ dataPath = _selectDataset(filename, mode=DatasetDialog.LoadMode)
+ if dataPath is None:
+ return None
+
+ with h5py.File(filename, "r") as h5f:
+ dataset = h5f.get(dataPath)
+ if not is_dataset(dataset):
+ raise IOError("%s is not a dataset" % dataPath)
+ mask = dataset[()]
+ return mask
+
+ def _saveMask(self):
+ """Open Save mask dialog"""
+ dialog = qt.QFileDialog(self)
+ dialog.setWindowTitle("Save Mask")
+ dialog.setOption(qt.QFileDialog.DontUseNativeDialog)
+ dialog.setModal(1)
+ hdf5Filter = "HDF5 (%s)" % _HDF5_EXT_STR
+ filters = [
+ "EDF (*.edf)",
+ "TIFF (*.tif)",
+ "NumPy binary file (*.npy)",
+ hdf5Filter,
+ # Fit2D mask is displayed anyway fabio is here or not
+ # to show to the user that the option exists
+ "Fit2D mask (*.msk)",
+ ]
+ dialog.setNameFilters(filters)
+ dialog.setFileMode(qt.QFileDialog.AnyFile)
+ dialog.setAcceptMode(qt.QFileDialog.AcceptSave)
+ dialog.setDirectory(self.maskFileDir)
+
+ def onFilterSelection(filt_):
+ # disable overwrite confirmation for HDF5,
+ # because we append the data to existing files
+ if filt_ == hdf5Filter:
+ dialog.setOption(qt.QFileDialog.DontConfirmOverwrite)
+ else:
+ dialog.setOption(qt.QFileDialog.DontConfirmOverwrite, False)
+
+ dialog.filterSelected.connect(onFilterSelection)
+ if not dialog.exec():
+ dialog.close()
+ return
+
+ nameFilter = dialog.selectedNameFilter()
+ filename = dialog.selectedFiles()[0]
+ dialog.close()
+
+ if "HDF5" in nameFilter:
+ has_allowed_ext = False
+ for ext in NEXUS_HDF5_EXT:
+ if (
+ len(filename) > len(ext)
+ and filename[-len(ext) :].lower() == ext.lower()
+ ):
+ has_allowed_ext = True
+ extension = ext
+ if not has_allowed_ext:
+ extension = ".h5"
+ filename += ".h5"
+ else:
+ # convert filter name to extension name with the .
+ extension = nameFilter.split()[-1][2:-1]
+ if not filename.lower().endswith(extension):
+ filename += extension
+
+ if os.path.exists(filename) and "HDF5" not in nameFilter:
+ try:
+ os.remove(filename)
+ except IOError as e:
+ msg = qt.QMessageBox(self)
+ msg.setWindowTitle("Removing existing file")
+ msg.setIcon(qt.QMessageBox.Critical)
+
+ if hasattr(e, "strerror"):
+ strerror = e.strerror
+ else:
+ strerror = sys.exc_info()[1]
+ msg.setText("Cannot save.\n" "Input Output Error: %s" % strerror)
+ msg.exec()
+ return
+
+ # Update the directory according to the user selection
+ self.maskFileDir = os.path.dirname(filename)
+
+ try:
+ self.save(filename, extension[1:])
+ except Exception as e:
+ msg = qt.QMessageBox(self)
+ msg.setWindowTitle("Saving mask file")
+ msg.setIcon(qt.QMessageBox.Critical)
+
+ if hasattr(e, "strerror"):
+ strerror = e.strerror
+ else:
+ strerror = sys.exc_info()[1]
+ msg.setText("Cannot save file %s\n%s" % (filename, strerror))
+ msg.exec()
+
+ def resetSelectionMask(self):
+ """Reset the mask"""
+ self._mask.reset(shape=self._data.shape[:2])
+ self._mask.commit()
+
+ def _plotDrawEvent(self, event):
+ """Handle draw events from the plot"""
+ if self._drawingMode is None or event["event"] not in (
+ "drawingProgress",
+ "drawingFinished",
+ ):
+ return
+
+ if not len(self._data):
+ return
+
+ level = self.levelSpinBox.value()
+
+ if self._drawingMode == "rectangle":
+ if event["event"] == "drawingFinished":
+ # Convert from plot to array coords
+ doMask = self._isMasking()
+ ox, oy = self._origin
+ sx, sy = self._scale
+
+ height = int(abs(event["height"] / sy))
+ width = int(abs(event["width"] / sx))
+
+ row = int((event["y"] - oy) / sy)
+ if sy < 0:
+ row -= height
+
+ col = int((event["x"] - ox) / sx)
+ if sx < 0:
+ col -= width
+
+ self._mask.updateRectangle(
+ level, row=row, col=col, height=height, width=width, mask=doMask
+ )
+ self._mask.commit()
+
+ elif self._drawingMode == "ellipse":
+ if event["event"] == "drawingFinished":
+ doMask = self._isMasking()
+ # Convert from plot to array coords
+ center = (event["points"][0] - self._origin) / self._scale
+ size = event["points"][1] / self._scale
+ center = center.astype(numpy.int64) # (row, col)
+ self._mask.updateEllipse(
+ level, center[1], center[0], size[1], size[0], doMask
+ )
+ self._mask.commit()
+
+ elif self._drawingMode == "polygon":
+ if event["event"] == "drawingFinished":
+ doMask = self._isMasking()
+ # Convert from plot to array coords
+ vertices = (event["points"] - self._origin) / self._scale
+ vertices = vertices.astype(numpy.int64)[:, (1, 0)] # (row, col)
+ self._mask.updatePolygon(level, vertices, doMask)
+ self._mask.commit()
+
+ elif self._drawingMode == "pencil":
+ doMask = self._isMasking()
+ # convert from plot to array coords
+ col, row = (event["points"][-1] - self._origin) / self._scale
+ col, row = int(col), int(row)
+ brushSize = self._getPencilWidth()
+
+ if self._lastPencilPos != (row, col):
+ if self._lastPencilPos is not None:
+ # Draw the line
+ self._mask.updateLine(
+ level,
+ self._lastPencilPos[0],
+ self._lastPencilPos[1],
+ row,
+ col,
+ brushSize,
+ doMask,
+ )
+
+ # Draw the very first, or last point
+ self._mask.updateDisk(level, row, col, brushSize / 2.0, doMask)
+
+ if event["event"] == "drawingFinished":
+ self._mask.commit()
+ self._lastPencilPos = None
+ else:
+ self._lastPencilPos = row, col
+ else:
+ _logger.error("Drawing mode %s unsupported", self._drawingMode)
+
+ def _loadRangeFromColormapTriggered(self):
+ """Set range from active image colormap range"""
+ activeImage = self.plot.getActiveImage()
+ if (
+ isinstance(activeImage, items.ColormapMixIn)
+ and activeImage.getName() != self._maskName
+ ):
+ # Update thresholds according to colormap
+ colormap = activeImage.getColormap()
+ if colormap["autoscale"]:
+ min_ = numpy.nanmin(activeImage.getData(copy=False))
+ max_ = numpy.nanmax(activeImage.getData(copy=False))
+ else:
+ min_, max_ = colormap["vmin"], colormap["vmax"]
+ self.minLineEdit.setText(str(min_))
+ self.maxLineEdit.setText(str(max_))
+
+
+class MaskToolsDockWidget(BaseMaskToolsDockWidget):
+ """:class:`MaskToolsWidget` embedded in a QDockWidget.
+
+ For integration in a :class:`PlotWindow`.
+
+ :param parent: See :class:`QDockWidget`
+ :param plot: The PlotWidget this widget is operating on
+ :paran str name: The title of this widget
+ """
+
+ def __init__(self, parent=None, plot=None, name="Mask"):
+ widget = MaskToolsWidget(plot=plot)
+ super(MaskToolsDockWidget, self).__init__(parent, name, widget)
diff --git a/src/silx/gui/plot/PlotEvents.py b/src/silx/gui/plot/PlotEvents.py
new file mode 100644
index 0000000..b4cbe30
--- /dev/null
+++ b/src/silx/gui/plot/PlotEvents.py
@@ -0,0 +1,184 @@
+# /*##########################################################################
+#
+# Copyright (c) 2004-2016 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""Functions to prepare events to be sent to Plot callback."""
+
+__author__ = ["V.A. Sole", "T. Vincent"]
+__license__ = "MIT"
+__date__ = "18/02/2016"
+
+
+import numpy as np
+
+
+def prepareDrawingSignal(event, type_, points, parameters=None):
+ """See Plot documentation for content of events"""
+ assert event in ("drawingProgress", "drawingFinished")
+
+ if parameters is None:
+ parameters = {}
+
+ eventDict = {}
+ eventDict["event"] = event
+ eventDict["type"] = type_
+ points = np.array(points, dtype=np.float32)
+ points.shape = -1, 2
+ eventDict["points"] = points
+ eventDict["xdata"] = points[:, 0]
+ eventDict["ydata"] = points[:, 1]
+ if type_ in ("rectangle",):
+ eventDict["x"] = eventDict["xdata"].min()
+ eventDict["y"] = eventDict["ydata"].min()
+ eventDict["width"] = eventDict["xdata"].max() - eventDict["x"]
+ eventDict["height"] = eventDict["ydata"].max() - eventDict["y"]
+ eventDict["parameters"] = parameters.copy()
+ return eventDict
+
+
+def prepareMouseSignal(eventType, button, xData, yData, xPixel, yPixel):
+ """See Plot documentation for content of events"""
+ assert eventType in ("mouseMoved", "mouseClicked", "mouseDoubleClicked")
+ assert button in (None, "left", "middle", "right")
+
+ return {
+ "event": eventType,
+ "x": xData,
+ "y": yData,
+ "xpixel": xPixel,
+ "ypixel": yPixel,
+ "button": button,
+ }
+
+
+def prepareHoverSignal(label, type_, posData, posPixel, draggable, selectable):
+ """See Plot documentation for content of events"""
+ return {
+ "event": "hover",
+ "label": label,
+ "type": type_,
+ "x": posData[0],
+ "y": posData[1],
+ "xpixel": posPixel[0],
+ "ypixel": posPixel[1],
+ "draggable": draggable,
+ "selectable": selectable,
+ }
+
+
+def prepareMarkerSignal(
+ eventType,
+ button,
+ label,
+ type_,
+ draggable,
+ selectable,
+ posDataMarker,
+ posPixelCursor=None,
+ posDataCursor=None,
+):
+ """See Plot documentation for content of events"""
+ if eventType == "markerClicked":
+ assert posPixelCursor is not None
+ assert posDataCursor is None
+
+ posDataCursor = list(posDataMarker)
+ if hasattr(posDataCursor[0], "__len__"):
+ posDataCursor[0] = posDataCursor[0][-1]
+ if hasattr(posDataCursor[1], "__len__"):
+ posDataCursor[1] = posDataCursor[1][-1]
+
+ elif eventType == "markerMoving":
+ assert posPixelCursor is not None
+ assert posDataCursor is not None
+
+ elif eventType == "markerMoved":
+ assert posPixelCursor is None
+ assert posDataCursor is None
+
+ posDataCursor = posDataMarker
+ else:
+ raise NotImplementedError("Unknown event type {0}".format(eventType))
+
+ eventDict = {
+ "event": eventType,
+ "button": button,
+ "label": label,
+ "type": type_,
+ "x": posDataCursor[0],
+ "y": posDataCursor[1],
+ "xdata": posDataMarker[0],
+ "ydata": posDataMarker[1],
+ "draggable": draggable,
+ "selectable": selectable,
+ }
+
+ if eventType in ("markerMoving", "markerClicked"):
+ eventDict["xpixel"] = posPixelCursor[0]
+ eventDict["ypixel"] = posPixelCursor[1]
+
+ return eventDict
+
+
+def prepareImageSignal(button, item, col, row, x, y, xPixel, yPixel):
+ """See Plot documentation for content of events"""
+ return {
+ "event": "imageClicked",
+ "button": button,
+ "item": item,
+ "label": item.getName(),
+ "type": "image",
+ "col": col,
+ "row": row,
+ "x": x,
+ "y": y,
+ "xpixel": xPixel,
+ "ypixel": yPixel,
+ }
+
+
+def prepareCurveSignal(button, item, xData, yData, x, y, xPixel, yPixel):
+ """See Plot documentation for content of events"""
+ return {
+ "event": "curveClicked",
+ "button": button,
+ "item": item,
+ "label": item.getName(),
+ "type": "curve",
+ "xdata": xData,
+ "ydata": yData,
+ "x": x,
+ "y": y,
+ "xpixel": xPixel,
+ "ypixel": yPixel,
+ }
+
+
+def prepareLimitsChangedSignal(sourceObj, xRange, yRange, y2Range):
+ """See Plot documentation for content of events"""
+ return {
+ "event": "limitsChanged",
+ "source": id(sourceObj),
+ "xdata": xRange,
+ "ydata": yRange,
+ "y2data": y2Range,
+ }
diff --git a/src/silx/gui/plot/PlotInteraction.py b/src/silx/gui/plot/PlotInteraction.py
new file mode 100644
index 0000000..d19bb6d
--- /dev/null
+++ b/src/silx/gui/plot/PlotInteraction.py
@@ -0,0 +1,1885 @@
+# /*##########################################################################
+#
+# Copyright (c) 2014-2023 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""Implementation of the interaction for the :class:`Plot`."""
+
+from __future__ import annotations
+
+__authors__ = ["T. Vincent"]
+__license__ = "MIT"
+__date__ = "15/02/2019"
+
+
+import math
+import numpy
+import time
+import weakref
+from typing import NamedTuple, Optional
+
+from silx.gui import qt
+from .. import colors
+from . import items
+from .Interaction import (
+ ClickOrDrag,
+ LEFT_BTN,
+ RIGHT_BTN,
+ MIDDLE_BTN,
+ State,
+ StateMachine,
+)
+from .PlotEvents import (
+ prepareCurveSignal,
+ prepareDrawingSignal,
+ prepareHoverSignal,
+ prepareImageSignal,
+ prepareMarkerSignal,
+ prepareMouseSignal,
+)
+
+from .backends.BackendBase import (
+ CURSOR_POINTING,
+ CURSOR_SIZE_HOR,
+ CURSOR_SIZE_VER,
+ CURSOR_SIZE_ALL,
+)
+
+from ._utils import (
+ FLOAT32_SAFE_MIN,
+ FLOAT32_MINPOS,
+ FLOAT32_SAFE_MAX,
+ applyZoomToPlot,
+ EnabledAxes,
+)
+
+
+# Base class ##################################################################
+
+
+class _PlotInteraction(object):
+ """Base class for interaction handler.
+
+ It provides a weakref to the plot and methods to set/reset overlay.
+ """
+
+ def __init__(self, plot):
+ """Init.
+
+ :param plot: The plot to apply modifications to.
+ """
+ self._needReplot = False
+ self._selectionAreas = set()
+ self._plot = weakref.ref(plot) # Avoid cyclic-ref
+
+ @property
+ def plot(self):
+ plot = self._plot()
+ assert plot is not None
+ return plot
+
+ def setSelectionArea(self, points, fill, color, name="", shape="polygon"):
+ """Set a polygon selection area overlaid on the plot.
+ Multiple simultaneous areas are supported through the name parameter.
+
+ :param points: The 2D coordinates of the points of the polygon
+ :type points: An iterable of (x, y) coordinates
+ :param str fill: The fill mode: 'hatch', 'solid' or 'none'
+ :param color: RGBA color to use or None to disable display
+ :type color: list or tuple of 4 float in the range [0, 1]
+ :param name: The key associated with this selection area
+ :param str shape: Shape of the area in 'polygon', 'polylines'
+ """
+ assert shape in ("polygon", "polylines")
+
+ if color is None:
+ return
+
+ points = numpy.asarray(points)
+
+ # TODO Not very nice, but as is for now
+ legend = "__SELECTION_AREA__" + name
+
+ fill = fill != "none" # TODO not very nice either
+
+ greyed = colors.greyed(color)[0]
+ if greyed < 0.5:
+ color2 = "white"
+ else:
+ color2 = "black"
+
+ self.plot.addShape(
+ points[:, 0],
+ points[:, 1],
+ legend=legend,
+ replace=False,
+ shape=shape,
+ fill=fill,
+ color=color,
+ gapcolor=color2,
+ linestyle="--",
+ overlay=True,
+ )
+
+ self._selectionAreas.add(legend)
+
+ def resetSelectionArea(self):
+ """Remove all selection areas set by setSelectionArea."""
+ for legend in self._selectionAreas:
+ self.plot.remove(legend, kind="item")
+ self._selectionAreas = set()
+
+
+# Zoom/Pan ####################################################################
+
+
+class _PlotInteractionWithClickEvents(ClickOrDrag, _PlotInteraction):
+ """:class:`ClickOrDrag` state machine emitting click and double click events.
+
+ Base class for :class:`Pan` and :class:`Zoom`
+ """
+
+ _DOUBLE_CLICK_TIMEOUT = 0.4
+
+ def click(self, x, y, btn):
+ """Handle clicks by sending events
+
+ :param int x: Mouse X position in pixels
+ :param int y: Mouse Y position in pixels
+ :param btn: Clicked mouse button
+ """
+ if btn == LEFT_BTN:
+ lastClickTime, lastClickPos = self._lastClick
+
+ # Signal mouse double clicked event first
+ if (time.time() - lastClickTime) <= self._DOUBLE_CLICK_TIMEOUT:
+ # Use position of first click
+ eventDict = prepareMouseSignal(
+ "mouseDoubleClicked", "left", *lastClickPos
+ )
+ self.plot.notify(**eventDict)
+
+ self._lastClick = 0.0, None
+ else:
+ # Signal mouse clicked event
+ dataPos = self.plot.pixelToData(x, y)
+ assert dataPos is not None
+ eventDict = prepareMouseSignal(
+ "mouseClicked", "left", dataPos[0], dataPos[1], x, y
+ )
+ self.plot.notify(**eventDict)
+
+ self._lastClick = time.time(), (dataPos[0], dataPos[1], x, y)
+
+ elif btn == RIGHT_BTN:
+ # Signal mouse clicked event
+ dataPos = self.plot.pixelToData(x, y)
+ assert dataPos is not None
+ eventDict = prepareMouseSignal(
+ "mouseClicked", "right", dataPos[0], dataPos[1], x, y
+ )
+ self.plot.notify(**eventDict)
+
+ def __init__(self, plot, **kwargs):
+ """Init.
+
+ :param plot: The plot to apply modifications to.
+ """
+ self._lastClick = 0.0, None
+
+ _PlotInteraction.__init__(self, plot)
+ ClickOrDrag.__init__(self, **kwargs)
+
+
+# Pan #########################################################################
+
+
+class Pan(_PlotInteractionWithClickEvents):
+ """Pan plot content and zoom on wheel state machine."""
+
+ def _pixelToData(self, x, y):
+ xData, yData = self.plot.pixelToData(x, y)
+ _, y2Data = self.plot.pixelToData(x, y, axis="right")
+ return xData, yData, y2Data
+
+ def beginDrag(self, x, y, btn):
+ self._previousDataPos = self._pixelToData(x, y)
+
+ def drag(self, x, y, btn):
+ xData, yData, y2Data = self._pixelToData(x, y)
+ lastX, lastY, lastY2 = self._previousDataPos
+
+ xMin, xMax = self.plot.getXAxis().getLimits()
+ yMin, yMax = self.plot.getYAxis().getLimits()
+ y2Min, y2Max = self.plot.getYAxis(axis="right").getLimits()
+
+ if self.plot.getXAxis()._isLogarithmic():
+ try:
+ dx = math.log10(xData) - math.log10(lastX)
+ newXMin = pow(10.0, (math.log10(xMin) - dx))
+ newXMax = pow(10.0, (math.log10(xMax) - dx))
+ except (ValueError, OverflowError):
+ newXMin, newXMax = xMin, xMax
+
+ # Makes sure both values stays in positive float32 range
+ if newXMin < FLOAT32_MINPOS or newXMax > FLOAT32_SAFE_MAX:
+ newXMin, newXMax = xMin, xMax
+ else:
+ dx = xData - lastX
+ newXMin, newXMax = xMin - dx, xMax - dx
+
+ # Makes sure both values stays in float32 range
+ if newXMin < FLOAT32_SAFE_MIN or newXMax > FLOAT32_SAFE_MAX:
+ newXMin, newXMax = xMin, xMax
+
+ if self.plot.getYAxis()._isLogarithmic():
+ try:
+ dy = math.log10(yData) - math.log10(lastY)
+ newYMin = pow(10.0, math.log10(yMin) - dy)
+ newYMax = pow(10.0, math.log10(yMax) - dy)
+
+ dy2 = math.log10(y2Data) - math.log10(lastY2)
+ newY2Min = pow(10.0, math.log10(y2Min) - dy2)
+ newY2Max = pow(10.0, math.log10(y2Max) - dy2)
+ except (ValueError, OverflowError):
+ newYMin, newYMax = yMin, yMax
+ newY2Min, newY2Max = y2Min, y2Max
+
+ # Makes sure y and y2 stays in positive float32 range
+ if (
+ newYMin < FLOAT32_MINPOS
+ or newYMax > FLOAT32_SAFE_MAX
+ or newY2Min < FLOAT32_MINPOS
+ or newY2Max > FLOAT32_SAFE_MAX
+ ):
+ newYMin, newYMax = yMin, yMax
+ newY2Min, newY2Max = y2Min, y2Max
+ else:
+ dy = yData - lastY
+ dy2 = y2Data - lastY2
+ newYMin, newYMax = yMin - dy, yMax - dy
+ newY2Min, newY2Max = y2Min - dy2, y2Max - dy2
+
+ # Makes sure y and y2 stays in float32 range
+ if (
+ newYMin < FLOAT32_SAFE_MIN
+ or newYMax > FLOAT32_SAFE_MAX
+ or newY2Min < FLOAT32_SAFE_MIN
+ or newY2Max > FLOAT32_SAFE_MAX
+ ):
+ newYMin, newYMax = yMin, yMax
+ newY2Min, newY2Max = y2Min, y2Max
+
+ self.plot.setLimits(newXMin, newXMax, newYMin, newYMax, newY2Min, newY2Max)
+
+ self._previousDataPos = self._pixelToData(x, y)
+
+ def endDrag(self, startPos, endPos, btn):
+ del self._previousDataPos
+
+ def cancel(self):
+ pass
+
+
+# Zoom ########################################################################
+
+
+class AxesExtent(NamedTuple):
+ xmin: float
+ xmax: float
+ ymin: float
+ ymax: float
+ y2min: float
+ y2max: float
+
+
+class Zoom(_PlotInteractionWithClickEvents):
+ """Zoom-in/out state machine.
+
+ Zoom-in on selected area, zoom-out on right click,
+ and zoom on mouse wheel.
+ """
+
+ SURFACE_THRESHOLD = 5
+
+ def __init__(self, plot, color):
+ self.color = color
+ self.enabledAxes = EnabledAxes()
+
+ super(Zoom, self).__init__(plot)
+ self.plot.getLimitsHistory().clear()
+
+ def _getAxesExtent(
+ self,
+ x0: float,
+ y0: float,
+ x1: float,
+ y1: float,
+ enabledAxes: Optional[EnabledAxes] = None,
+ ) -> AxesExtent:
+ """Convert selection coordinates (pixels) to axes coordinates (data)
+
+ This takes into account axes selected for zoom and aspect ratio.
+ """
+ if enabledAxes is None:
+ enabledAxes = self.enabledAxes
+
+ y2_0, y2_1 = y0, y1
+ left, top, width, height = self.plot.getPlotBoundsInPixels()
+
+ if not all(enabledAxes) and not self.plot.isKeepDataAspectRatio():
+ # Handle axes disabled for zoom if plot is not keeping aspec ratio
+ if not enabledAxes.xaxis:
+ x0, x1 = left, left + width
+ if not enabledAxes.yaxis:
+ y0, y1 = top, top + height
+ if not enabledAxes.y2axis:
+ y2_0, y2_1 = top, top + height
+
+ if self.plot.isKeepDataAspectRatio() and height != 0 and width != 0:
+ ratio = width / height
+ xextent, yextent = math.fabs(x1 - x0), math.fabs(y1 - y0)
+ if xextent != 0 and yextent != 0:
+ if xextent / yextent > ratio:
+ areaHeight = xextent / ratio
+ center = 0.5 * (y0 + y1)
+ y0 = center - numpy.sign(y1 - y0) * 0.5 * areaHeight
+ y1 = center + numpy.sign(y1 - y0) * 0.5 * areaHeight
+ else:
+ areaWidth = yextent * ratio
+ center = 0.5 * (x0 + x1)
+ x0 = center - numpy.sign(x1 - x0) * 0.5 * areaWidth
+ x1 = center + numpy.sign(x1 - x0) * 0.5 * areaWidth
+
+ # Convert to data space
+ x0, y0 = self.plot.pixelToData(x0, y0, check=False)
+ x1, y1 = self.plot.pixelToData(x1, y1, check=False)
+ y2_0 = self.plot.pixelToData(None, y2_0, axis="right", check=False)[1]
+ y2_1 = self.plot.pixelToData(None, y2_1, axis="right", check=False)[1]
+
+ return AxesExtent(
+ min(x0, x1),
+ max(x0, x1),
+ min(y0, y1),
+ max(y0, y1),
+ min(y2_0, y2_1),
+ max(y2_0, y2_1),
+ )
+
+ def beginDrag(self, x, y, btn):
+ dataPos = self.plot.pixelToData(x, y)
+ assert dataPos is not None
+ self.x0, self.y0 = x, y
+
+ def drag(self, x1, y1, btn):
+ if self.color is None:
+ return # Do not draw zoom area
+
+ dataPos = self.plot.pixelToData(x1, y1)
+ assert dataPos is not None
+
+ if self.plot.isKeepDataAspectRatio() or not all(self.enabledAxes):
+ # Patch enabledAxes to display the right Y axis area on the left Y axis
+ # since the selection area is always displayed on the left Y axis
+ isY2Visible = self.plot.getYAxis("right").isVisible()
+ areaZoomEnabledAxes = EnabledAxes(
+ self.enabledAxes.xaxis,
+ self.enabledAxes.yaxis and (not isY2Visible or self.enabledAxes.y2axis),
+ self.enabledAxes.y2axis,
+ )
+ extents = self._getAxesExtent(self.x0, self.y0, x1, y1, areaZoomEnabledAxes)
+ areaCorners = (
+ (extents.xmin, extents.ymin),
+ (extents.xmax, extents.ymin),
+ (extents.xmax, extents.ymax),
+ (extents.xmin, extents.ymax),
+ )
+
+ if self.color != "video inverted":
+ areaColor = list(self.color)
+ areaColor[3] *= 0.25
+ else:
+ areaColor = [1.0, 1.0, 1.0, 1.0]
+
+ self.setSelectionArea(
+ areaCorners, fill="none", color=areaColor, name="zoomedArea"
+ )
+
+ corners = ((self.x0, self.y0), (self.x0, y1), (x1, y1), (x1, self.y0))
+ corners = numpy.array(
+ [self.plot.pixelToData(x, y, check=False) for (x, y) in corners]
+ )
+
+ self.setSelectionArea(corners, fill="none", color=self.color)
+
+ def _zoom(self, x0, y0, x1, y1):
+ """Zoom to the rectangle view x0,y0 x1,y1."""
+ # Store current zoom state in stack
+ self.plot.getLimitsHistory().push()
+
+ extents = self._getAxesExtent(x0, y0, x1, y1)
+ self.plot.setLimits(
+ extents.xmin,
+ extents.xmax,
+ extents.ymin,
+ extents.ymax,
+ extents.y2min,
+ extents.y2max,
+ )
+
+ def endDrag(self, startPos, endPos, btn):
+ x0, y0 = startPos
+ x1, y1 = endPos
+
+ if abs(x0 - x1) * abs(y0 - y1) >= self.SURFACE_THRESHOLD:
+ # Avoid empty zoom area
+ self._zoom(x0, y0, x1, y1)
+
+ self.resetSelectionArea()
+
+ def cancel(self):
+ if isinstance(self.state, self.states["drag"]):
+ self.resetSelectionArea()
+
+
+# Select ######################################################################
+
+
+class Select(StateMachine, _PlotInteraction):
+ """Base class for drawing selection areas."""
+
+ def __init__(self, plot, parameters, states, state):
+ """Init a state machine.
+
+ :param plot: The plot to apply changes to.
+ :param dict parameters: A dict of parameters such as color.
+ :param dict states: The states of the state machine.
+ :param str state: The name of the initial state.
+ """
+ _PlotInteraction.__init__(self, plot)
+ self.parameters = parameters
+ StateMachine.__init__(self, states, state)
+
+ @property
+ def color(self):
+ return self.parameters.get("color", None)
+
+
+class SelectPolygon(Select):
+ """Drawing selection polygon area state machine."""
+
+ DRAG_THRESHOLD_DIST = 4
+
+ class Idle(State):
+ def onPress(self, x, y, btn):
+ if btn == LEFT_BTN:
+ self.goto("select", x, y)
+ return True
+
+ class Select(State):
+ def enterState(self, x, y):
+ dataPos = self.machine.plot.pixelToData(x, y)
+ assert dataPos is not None
+ self._firstPos = dataPos
+ self.points = [dataPos, dataPos]
+
+ self.updateFirstPoint()
+
+ def updateFirstPoint(self):
+ """Update drawing first point, using self._firstPos"""
+ x, y = self.machine.plot.dataToPixel(*self._firstPos, check=False)
+
+ offset = self.machine.getDragThreshold()
+ points = [
+ (x - offset, y - offset),
+ (x - offset, y + offset),
+ (x + offset, y + offset),
+ (x + offset, y - offset),
+ ]
+ points = [
+ self.machine.plot.pixelToData(xpix, ypix, check=False)
+ for xpix, ypix in points
+ ]
+ self.machine.setSelectionArea(
+ points, fill=None, color=self.machine.color, name="first_point"
+ )
+
+ def updateSelectionArea(self):
+ """Update drawing selection area using self.points"""
+ self.machine.setSelectionArea(
+ self.points, fill="hatch", color=self.machine.color
+ )
+ eventDict = prepareDrawingSignal(
+ "drawingProgress", "polygon", self.points, self.machine.parameters
+ )
+ self.machine.plot.notify(**eventDict)
+
+ def validate(self):
+ if len(self.points) > 2:
+ self.closePolygon()
+ else:
+ # It would be nice to have a cancel event.
+ # The plot is not aware that the interaction was cancelled
+ self.machine.cancel()
+
+ def closePolygon(self):
+ self.machine.resetSelectionArea()
+ self.points[-1] = self.points[0]
+ eventDict = prepareDrawingSignal(
+ "drawingFinished", "polygon", self.points, self.machine.parameters
+ )
+ self.machine.plot.notify(**eventDict)
+ self.goto("idle")
+
+ def onWheel(self, x, y, angle):
+ self.machine.onWheel(x, y, angle)
+ self.updateFirstPoint()
+
+ def onRelease(self, x, y, btn):
+ if btn == LEFT_BTN:
+ # checking if the position is close to the first point
+ # if yes : closing the "loop"
+ firstPos = self.machine.plot.dataToPixel(*self._firstPos, check=False)
+ dx, dy = abs(firstPos[0] - x), abs(firstPos[1] - y)
+
+ threshold = self.machine.getDragThreshold()
+
+ # Only allow to close polygon after first point
+ if len(self.points) > 2 and dx <= threshold and dy <= threshold:
+ self.closePolygon()
+ return False
+
+ # Update polygon last point not too close to previous one
+ dataPos = self.machine.plot.pixelToData(x, y)
+ assert dataPos is not None
+ self.updateSelectionArea()
+
+ # checking that the new points isnt the same (within range)
+ # of the previous one
+ # This has to be done because sometimes the mouse release event
+ # is caught right after entering the Select state (i.e : press
+ # in Idle state, but with a slightly different position that
+ # the mouse press. So we had the two first vertices that were
+ # almost identical.
+ previousPos = self.machine.plot.dataToPixel(
+ *self.points[-2], check=False
+ )
+ dx, dy = abs(previousPos[0] - x), abs(previousPos[1] - y)
+ if dx >= threshold or dy >= threshold:
+ self.points.append(dataPos)
+ else:
+ self.points[-1] = dataPos
+
+ return True
+ return False
+
+ def onMove(self, x, y):
+ firstPos = self.machine.plot.dataToPixel(*self._firstPos, check=False)
+ dx, dy = abs(firstPos[0] - x), abs(firstPos[1] - y)
+ threshold = self.machine.getDragThreshold()
+
+ if dx <= threshold and dy <= threshold:
+ x, y = firstPos # Snap to first point
+
+ dataPos = self.machine.plot.pixelToData(x, y)
+ assert dataPos is not None
+ self.points[-1] = dataPos
+ self.updateSelectionArea()
+
+ def __init__(self, plot, parameters):
+ states = {"idle": SelectPolygon.Idle, "select": SelectPolygon.Select}
+ super(SelectPolygon, self).__init__(plot, parameters, states, "idle")
+
+ def cancel(self):
+ if isinstance(self.state, self.states["select"]):
+ self.resetSelectionArea()
+
+ def getDragThreshold(self):
+ """Return dragging ratio with device to pixel ratio applied.
+
+ :rtype: float
+ """
+ ratio = self.plot.window().windowHandle().devicePixelRatio()
+ return self.DRAG_THRESHOLD_DIST * ratio
+
+
+class Select2Points(Select):
+ """Base class for drawing selection based on 2 input points."""
+
+ class Idle(State):
+ def onPress(self, x, y, btn):
+ if btn == LEFT_BTN:
+ self.goto("start", x, y)
+ return True
+
+ class Start(State):
+ def enterState(self, x, y):
+ self.machine.beginSelect(x, y)
+
+ def onMove(self, x, y):
+ self.goto("select", x, y)
+
+ def onRelease(self, x, y, btn):
+ if btn == LEFT_BTN:
+ self.goto("select", x, y)
+ return True
+
+ class Select(State):
+ def enterState(self, x, y):
+ self.onMove(x, y)
+
+ def onMove(self, x, y):
+ self.machine.select(x, y)
+
+ def onRelease(self, x, y, btn):
+ if btn == LEFT_BTN:
+ self.machine.endSelect(x, y)
+ self.goto("idle")
+
+ def __init__(self, plot, parameters):
+ states = {
+ "idle": Select2Points.Idle,
+ "start": Select2Points.Start,
+ "select": Select2Points.Select,
+ }
+ super(Select2Points, self).__init__(plot, parameters, states, "idle")
+
+ def beginSelect(self, x, y):
+ pass
+
+ def select(self, x, y):
+ pass
+
+ def endSelect(self, x, y):
+ pass
+
+ def cancelSelect(self):
+ pass
+
+ def cancel(self):
+ if isinstance(self.state, self.states["select"]):
+ self.cancelSelect()
+
+
+class SelectEllipse(Select2Points):
+ """Drawing ellipse selection area state machine."""
+
+ def beginSelect(self, x, y):
+ self.center = self.plot.pixelToData(x, y)
+ assert self.center is not None
+
+ def _getEllipseSize(self, pointInEllipse):
+ """
+ Returns the size from the center to the bounding box of the ellipse.
+
+ :param Tuple[float,float] pointInEllipse: A point of the ellipse
+ :rtype: Tuple[float,float]
+ """
+ x = abs(self.center[0] - pointInEllipse[0])
+ y = abs(self.center[1] - pointInEllipse[1])
+ if x == 0 or y == 0:
+ return x, y
+ # Ellipse definitions
+ # e: eccentricity
+ # a: length fron center to bounding box width
+ # b: length fron center to bounding box height
+ # Equations
+ # (1) b < a
+ # (2) For x,y a point in the ellipse: x^2/a^2 + y^2/b^2 = 1
+ # (3) b = a * sqrt(1-e^2)
+ # (4) e = sqrt(a^2 - b^2) / a
+
+ # The eccentricity of the ellipse defined by a,b=x,y is the same
+ # as the one we are searching for.
+ swap = x < y
+ if swap:
+ x, y = y, x
+ e = math.sqrt(x**2 - y**2) / x
+ # From (2) using (3) to replace b
+ # a^2 = x^2 + y^2 / (1-e^2)
+ a = math.sqrt(x**2 + y**2 / (1.0 - e**2))
+ b = a * math.sqrt(1 - e**2)
+ if swap:
+ a, b = b, a
+ return a, b
+
+ def select(self, x, y):
+ dataPos = self.plot.pixelToData(x, y)
+ assert dataPos is not None
+ width, height = self._getEllipseSize(dataPos)
+
+ # Circle used for circle preview
+ nbpoints = 27.0
+ angles = numpy.arange(nbpoints) * numpy.pi * 2.0 / nbpoints
+ circleShape = numpy.array(
+ (numpy.cos(angles) * width, numpy.sin(angles) * height)
+ ).T
+ circleShape += numpy.array(self.center)
+
+ self.setSelectionArea(
+ circleShape, shape="polygon", fill="hatch", color=self.color
+ )
+
+ eventDict = prepareDrawingSignal(
+ "drawingProgress",
+ "ellipse",
+ (self.center, (width, height)),
+ self.parameters,
+ )
+ self.plot.notify(**eventDict)
+
+ def endSelect(self, x, y):
+ self.resetSelectionArea()
+
+ dataPos = self.plot.pixelToData(x, y)
+ assert dataPos is not None
+ width, height = self._getEllipseSize(dataPos)
+
+ eventDict = prepareDrawingSignal(
+ "drawingFinished",
+ "ellipse",
+ (self.center, (width, height)),
+ self.parameters,
+ )
+ self.plot.notify(**eventDict)
+
+ def cancelSelect(self):
+ self.resetSelectionArea()
+
+
+class SelectRectangle(Select2Points):
+ """Drawing rectangle selection area state machine."""
+
+ def beginSelect(self, x, y):
+ self.startPt = self.plot.pixelToData(x, y)
+ assert self.startPt is not None
+
+ def select(self, x, y):
+ dataPos = self.plot.pixelToData(x, y)
+ assert dataPos is not None
+
+ self.setSelectionArea(
+ (
+ self.startPt,
+ (self.startPt[0], dataPos[1]),
+ dataPos,
+ (dataPos[0], self.startPt[1]),
+ ),
+ fill="hatch",
+ color=self.color,
+ )
+
+ eventDict = prepareDrawingSignal(
+ "drawingProgress", "rectangle", (self.startPt, dataPos), self.parameters
+ )
+ self.plot.notify(**eventDict)
+
+ def endSelect(self, x, y):
+ self.resetSelectionArea()
+
+ dataPos = self.plot.pixelToData(x, y)
+ assert dataPos is not None
+
+ eventDict = prepareDrawingSignal(
+ "drawingFinished", "rectangle", (self.startPt, dataPos), self.parameters
+ )
+ self.plot.notify(**eventDict)
+
+ def cancelSelect(self):
+ self.resetSelectionArea()
+
+
+class SelectLine(Select2Points):
+ """Drawing line selection area state machine."""
+
+ def beginSelect(self, x, y):
+ self.startPt = self.plot.pixelToData(x, y)
+ assert self.startPt is not None
+
+ def select(self, x, y):
+ dataPos = self.plot.pixelToData(x, y)
+ assert dataPos is not None
+
+ self.setSelectionArea((self.startPt, dataPos), fill="hatch", color=self.color)
+
+ eventDict = prepareDrawingSignal(
+ "drawingProgress", "line", (self.startPt, dataPos), self.parameters
+ )
+ self.plot.notify(**eventDict)
+
+ def endSelect(self, x, y):
+ self.resetSelectionArea()
+
+ dataPos = self.plot.pixelToData(x, y)
+ assert dataPos is not None
+
+ eventDict = prepareDrawingSignal(
+ "drawingFinished", "line", (self.startPt, dataPos), self.parameters
+ )
+ self.plot.notify(**eventDict)
+
+ def cancelSelect(self):
+ self.resetSelectionArea()
+
+
+class Select1Point(Select):
+ """Base class for drawing selection area based on one input point."""
+
+ class Idle(State):
+ def onPress(self, x, y, btn):
+ if btn == LEFT_BTN:
+ self.goto("select", x, y)
+ return True
+
+ class Select(State):
+ def enterState(self, x, y):
+ self.onMove(x, y)
+
+ def onMove(self, x, y):
+ self.machine.select(x, y)
+
+ def onRelease(self, x, y, btn):
+ if btn == LEFT_BTN:
+ self.machine.endSelect(x, y)
+ self.goto("idle")
+
+ def onWheel(self, x, y, angle):
+ self.machine.onWheel(x, y, angle) # Call select default wheel
+ self.machine.select(x, y)
+
+ def __init__(self, plot, parameters):
+ states = {"idle": Select1Point.Idle, "select": Select1Point.Select}
+ super(Select1Point, self).__init__(plot, parameters, states, "idle")
+
+ def select(self, x, y):
+ pass
+
+ def endSelect(self, x, y):
+ pass
+
+ def cancelSelect(self):
+ pass
+
+ def cancel(self):
+ if isinstance(self.state, self.states["select"]):
+ self.cancelSelect()
+
+
+class SelectHLine(Select1Point):
+ """Drawing a horizontal line selection area state machine."""
+
+ def _hLine(self, y):
+ """Return points in data coords of the segment visible in the plot.
+
+ Supports non-orthogonal axes.
+ """
+ left, _top, width, _height = self.plot.getPlotBoundsInPixels()
+
+ dataPos1 = self.plot.pixelToData(left, y, check=False)
+ dataPos2 = self.plot.pixelToData(left + width, y, check=False)
+ return dataPos1, dataPos2
+
+ def select(self, x, y):
+ points = self._hLine(y)
+ self.setSelectionArea(points, fill="hatch", color=self.color)
+
+ eventDict = prepareDrawingSignal(
+ "drawingProgress", "hline", points, self.parameters
+ )
+ self.plot.notify(**eventDict)
+
+ def endSelect(self, x, y):
+ self.resetSelectionArea()
+
+ eventDict = prepareDrawingSignal(
+ "drawingFinished", "hline", self._hLine(y), self.parameters
+ )
+ self.plot.notify(**eventDict)
+
+ def cancelSelect(self):
+ self.resetSelectionArea()
+
+
+class SelectVLine(Select1Point):
+ """Drawing a vertical line selection area state machine."""
+
+ def _vLine(self, x):
+ """Return points in data coords of the segment visible in the plot.
+
+ Supports non-orthogonal axes.
+ """
+ _left, top, _width, height = self.plot.getPlotBoundsInPixels()
+
+ dataPos1 = self.plot.pixelToData(x, top, check=False)
+ dataPos2 = self.plot.pixelToData(x, top + height, check=False)
+ return dataPos1, dataPos2
+
+ def select(self, x, y):
+ points = self._vLine(x)
+ self.setSelectionArea(points, fill="hatch", color=self.color)
+
+ eventDict = prepareDrawingSignal(
+ "drawingProgress", "vline", points, self.parameters
+ )
+ self.plot.notify(**eventDict)
+
+ def endSelect(self, x, y):
+ self.resetSelectionArea()
+
+ eventDict = prepareDrawingSignal(
+ "drawingFinished", "vline", self._vLine(x), self.parameters
+ )
+ self.plot.notify(**eventDict)
+
+ def cancelSelect(self):
+ self.resetSelectionArea()
+
+
+class DrawFreeHand(Select):
+ """Interaction for drawing pencil. It display the preview of the pencil
+ before pressing the mouse.
+ """
+
+ class Idle(State):
+ def onPress(self, x, y, btn):
+ if btn == LEFT_BTN:
+ self.goto("select", x, y)
+ return True
+
+ def onMove(self, x, y):
+ self.machine.updatePencilShape(x, y)
+
+ def onLeave(self):
+ self.machine.cancel()
+
+ class Select(State):
+ def enterState(self, x, y):
+ self.__isOut = False
+ self.machine.setFirstPoint(x, y)
+
+ def onMove(self, x, y):
+ self.machine.updatePencilShape(x, y)
+ self.machine.select(x, y)
+
+ def onRelease(self, x, y, btn):
+ if btn == LEFT_BTN:
+ if self.__isOut:
+ self.machine.resetSelectionArea()
+ self.machine.endSelect(x, y)
+ self.goto("idle")
+
+ def onEnter(self):
+ self.__isOut = False
+
+ def onLeave(self):
+ self.__isOut = True
+
+ def __init__(self, plot, parameters):
+ # Circle used for pencil preview
+ angle = numpy.arange(13.0) * numpy.pi * 2.0 / 13.0
+ size = parameters.get("width", 1.0) * 0.5
+ self._circle = size * numpy.array((numpy.cos(angle), numpy.sin(angle))).T
+
+ states = {"idle": DrawFreeHand.Idle, "select": DrawFreeHand.Select}
+ super(DrawFreeHand, self).__init__(plot, parameters, states, "idle")
+
+ @property
+ def width(self):
+ return self.parameters.get("width", None)
+
+ def setFirstPoint(self, x, y):
+ self._points = []
+ self.select(x, y)
+
+ def updatePencilShape(self, x, y):
+ center = self.plot.pixelToData(x, y, check=False)
+ assert center is not None
+
+ polygon = center + self._circle
+
+ self.setSelectionArea(polygon, fill="none", color=self.color)
+
+ def select(self, x, y):
+ pos = self.plot.pixelToData(x, y, check=False)
+ if len(self._points) > 0:
+ if self._points[-1] == pos:
+ # Skip same points
+ return
+ self._points.append(pos)
+ eventDict = prepareDrawingSignal(
+ "drawingProgress", "polylines", self._points, self.parameters
+ )
+ self.plot.notify(**eventDict)
+
+ def endSelect(self, x, y):
+ pos = self.plot.pixelToData(x, y, check=False)
+ if len(self._points) > 0:
+ if self._points[-1] != pos:
+ # Append if different
+ self._points.append(pos)
+
+ eventDict = prepareDrawingSignal(
+ "drawingFinished", "polylines", self._points, self.parameters
+ )
+ self.plot.notify(**eventDict)
+ self._points = None
+
+ def cancelSelect(self):
+ self.resetSelectionArea()
+
+ def cancel(self):
+ self.resetSelectionArea()
+
+
+class SelectFreeLine(ClickOrDrag, _PlotInteraction):
+ """Base class for drawing free lines with tools such as pencil."""
+
+ def __init__(self, plot, parameters):
+ """Init a state machine.
+
+ :param plot: The plot to apply changes to.
+ :param dict parameters: A dict of parameters such as color.
+ """
+ # self.DRAG_THRESHOLD_SQUARE_DIST = 1 # Disable first move threshold
+ self._points = []
+ ClickOrDrag.__init__(self)
+ _PlotInteraction.__init__(self, plot)
+ self.parameters = parameters
+
+ @property
+ def color(self):
+ return self.parameters.get("color", None)
+
+ def click(self, x, y, btn):
+ if btn == LEFT_BTN:
+ self._processEvent(x, y, isLast=True)
+
+ def beginDrag(self, x, y, btn):
+ self._processEvent(x, y, isLast=False)
+
+ def drag(self, x, y, btn):
+ self._processEvent(x, y, isLast=False)
+
+ def endDrag(self, startPos, endPos, btn):
+ x, y = endPos
+ self._processEvent(x, y, isLast=True)
+
+ def cancel(self):
+ self.resetSelectionArea()
+ self._points = []
+
+ def _processEvent(self, x, y, isLast):
+ dataPos = self.plot.pixelToData(x, y, check=False)
+ isNewPoint = not self._points or dataPos != self._points[-1]
+
+ if isNewPoint:
+ self._points.append(dataPos)
+
+ if isNewPoint or isLast:
+ eventDict = prepareDrawingSignal(
+ "drawingFinished" if isLast else "drawingProgress",
+ "polylines",
+ self._points,
+ self.parameters,
+ )
+ self.plot.notify(**eventDict)
+
+ if not isLast:
+ self.setSelectionArea(
+ self._points, fill="none", color=self.color, shape="polylines"
+ )
+ else:
+ self.cancel()
+
+
+# ItemInteraction #############################################################
+
+
+class ItemsInteraction(ClickOrDrag, _PlotInteraction):
+ """Interaction with items (markers, curves and images).
+
+ This class provides selection and dragging of plot primitives
+ that support those interaction.
+ It is also meant to be combined with the zoom interaction.
+ """
+
+ class Idle(ClickOrDrag.Idle):
+ def __init__(self, *args, **kw):
+ super(ItemsInteraction.Idle, self).__init__(*args, **kw)
+ self._hoverMarker = None
+
+ def enterState(self):
+ widget = self.machine.plot.getWidgetHandle()
+ if widget is None or not widget.isVisible():
+ return
+ position = widget.mapFromGlobal(qt.QCursor.pos())
+ self.onMove(position.x(), position.y())
+
+ def onMove(self, x, y):
+ marker = self.machine.plot._getMarkerAt(x, y)
+
+ if marker is not None:
+ dataPos = self.machine.plot.pixelToData(x, y)
+ assert dataPos is not None
+ eventDict = prepareHoverSignal(
+ marker.getName(),
+ "marker",
+ dataPos,
+ (x, y),
+ marker.isDraggable(),
+ marker.isSelectable(),
+ )
+ self.machine.plot.notify(**eventDict)
+
+ if marker != self._hoverMarker:
+ self._hoverMarker = marker
+ self.machine._setCursorForMarker(marker)
+
+ return True
+
+ def __init__(self, plot):
+ self._pan = Pan(plot)
+
+ _PlotInteraction.__init__(self, plot)
+ ClickOrDrag.__init__(
+ self, clickButtons=(LEFT_BTN, RIGHT_BTN), dragButtons=(LEFT_BTN, MIDDLE_BTN)
+ )
+
+ def _setCursorForMarker(self, marker: Optional[items.MarkerBase] = None):
+ """Set mouse cursor for given marker"""
+ if marker is None:
+ cursor = None
+
+ elif marker.isDraggable():
+ if isinstance(marker, items.YMarker):
+ cursor = CURSOR_SIZE_VER
+ elif isinstance(marker, items.XMarker):
+ cursor = CURSOR_SIZE_HOR
+ else:
+ cursor = CURSOR_SIZE_ALL
+
+ elif marker.isSelectable():
+ cursor = CURSOR_POINTING
+
+ else:
+ cursor = None
+
+ self.plot.setGraphCursorShape(cursor)
+
+ def click(self, x, y, btn):
+ """Handle mouse click
+
+ :param x: X position of the mouse in pixels
+ :param y: Y position of the mouse in pixels
+ :param btn: Pressed button id
+ :return: True if click is catched by an item, False otherwise
+ """
+ # Signal mouse clicked event
+ dataPos = self.plot.pixelToData(x, y)
+ assert dataPos is not None
+ eventDict = prepareMouseSignal(
+ "mouseClicked", btn, dataPos[0], dataPos[1], x, y
+ )
+ self.plot.notify(**eventDict)
+
+ eventDict = self._handleClick(x, y, btn)
+ if eventDict is not None:
+ self.plot.notify(**eventDict)
+
+ def _handleClick(self, x, y, btn):
+ """Perform picking and prepare event if click is handled here
+
+ :param x: X position of the mouse in pixels
+ :param y: Y position of the mouse in pixels
+ :param btn: Pressed button id
+ :return: event description to send of None if not handling event.
+ :rtype: dict or None
+ """
+
+ if btn == LEFT_BTN:
+ result = self.plot._pickTopMost(x, y, lambda i: i.isSelectable())
+ if result is None:
+ return None
+
+ item = result.getItem()
+
+ if isinstance(item, items.MarkerBase):
+ xData, yData = item.getPosition()
+ if xData is None:
+ xData = [0, 1]
+ if yData is None:
+ yData = [0, 1]
+
+ eventDict = prepareMarkerSignal(
+ "markerClicked",
+ "left",
+ item.getName(),
+ "marker",
+ item.isDraggable(),
+ item.isSelectable(),
+ (xData, yData),
+ (x, y),
+ None,
+ )
+ return eventDict
+
+ elif isinstance(item, items.Curve):
+ dataPos = self.plot.pixelToData(x, y)
+ assert dataPos is not None
+
+ xData = item.getXData(copy=False)
+ yData = item.getYData(copy=False)
+
+ indices = result.getIndices(copy=False)
+ eventDict = prepareCurveSignal(
+ "left",
+ item,
+ xData[indices],
+ yData[indices],
+ dataPos[0],
+ dataPos[1],
+ x,
+ y,
+ )
+ return eventDict
+
+ elif isinstance(item, items.ImageBase):
+ dataPos = self.plot.pixelToData(x, y)
+ assert dataPos is not None
+
+ indices = result.getIndices(copy=False)
+ row, column = indices[0][0], indices[1][0]
+ eventDict = prepareImageSignal(
+ "left", item, column, row, dataPos[0], dataPos[1], x, y
+ )
+ return eventDict
+
+ return None
+
+ def _signalMarkerMovingEvent(self, eventType, marker, x, y):
+ assert marker is not None
+
+ xData, yData = marker.getPosition()
+ if xData is None:
+ xData = [0, 1]
+ if yData is None:
+ yData = [0, 1]
+
+ posDataCursor = self.plot.pixelToData(x, y)
+ assert posDataCursor is not None
+
+ eventDict = prepareMarkerSignal(
+ eventType,
+ "left",
+ marker.getName(),
+ "marker",
+ marker.isDraggable(),
+ marker.isSelectable(),
+ (xData, yData),
+ (x, y),
+ posDataCursor,
+ )
+ self.plot.notify(**eventDict)
+
+ @staticmethod
+ def __isDraggableItem(item):
+ return isinstance(item, items.DraggableMixIn) and item.isDraggable()
+
+ def __terminateDrag(self, x, y):
+ """Finalize a drag operation by reseting to initial state"""
+ self._setCursorForMarker(self.plot._getMarkerAt(x, y))
+ self.draggedItemRef = None
+
+ def beginDrag(self, x, y, btn):
+ """Handle begining of drag interaction
+
+ :param x: X position of the mouse in pixels
+ :param y: Y position of the mouse in pixels
+ :param str btn: The mouse button for which a drag is starting.
+ :return: True if drag is catched by an item, False otherwise
+ """
+ if btn == LEFT_BTN:
+ self._lastPos = self.plot.pixelToData(x, y)
+ assert self._lastPos is not None
+
+ result = self.plot._pickTopMost(x, y, self.__isDraggableItem)
+ item = result.getItem() if result is not None else None
+
+ self.draggedItemRef = None if item is None else weakref.ref(item)
+
+ if item is None:
+ self.__terminateDrag(x, y)
+ return False
+
+ if isinstance(item, items.MarkerBase):
+ self._signalMarkerMovingEvent("markerMoving", item, x, y)
+ item._startDrag()
+
+ return True
+ elif btn == MIDDLE_BTN:
+ self._pan.beginDrag(x, y, btn)
+ return True
+
+ def drag(self, x, y, btn):
+ if btn == LEFT_BTN:
+ dataPos = self.plot.pixelToData(x, y)
+ assert dataPos is not None
+
+ item = None if self.draggedItemRef is None else self.draggedItemRef()
+ if item is not None:
+ item.drag(self._lastPos, dataPos)
+
+ if isinstance(item, items.MarkerBase):
+ self._signalMarkerMovingEvent("markerMoving", item, x, y)
+
+ self._lastPos = dataPos
+ elif btn == MIDDLE_BTN:
+ self._pan.drag(x, y, btn)
+
+ def endDrag(self, startPos, endPos, btn):
+ if btn == LEFT_BTN:
+ item = None if self.draggedItemRef is None else self.draggedItemRef()
+ if isinstance(item, items.MarkerBase):
+ posData = list(item.getPosition())
+ if posData[0] is None:
+ posData[0] = 1.0
+ if posData[1] is None:
+ posData[1] = 1.0
+
+ eventDict = prepareMarkerSignal(
+ "markerMoved",
+ "left",
+ item.getLegend(),
+ "marker",
+ item.isDraggable(),
+ item.isSelectable(),
+ posData,
+ )
+ self.plot.notify(**eventDict)
+ item._endDrag()
+
+ self.__terminateDrag(*endPos)
+ elif btn == MIDDLE_BTN:
+ self._pan.endDrag(startPos, endPos, btn)
+
+ def cancel(self):
+ self._pan.cancel()
+ widget = self.plot.getWidgetHandle()
+ if widget is None or not widget.isVisible():
+ return
+ position = widget.mapFromGlobal(qt.QCursor.pos())
+ self.__terminateDrag(position.x(), position.y())
+
+
+class ItemsInteractionForCombo(ItemsInteraction):
+ """Interaction with items to combine through :class:`FocusManager`."""
+
+ class Idle(ItemsInteraction.Idle):
+ @staticmethod
+ def __isItemSelectableOrDraggable(item):
+ return item.isSelectable() or (
+ isinstance(item, items.DraggableMixIn) and item.isDraggable()
+ )
+
+ def onPress(self, x, y, btn):
+ if btn == LEFT_BTN:
+ result = self.machine.plot._pickTopMost(
+ x, y, self.__isItemSelectableOrDraggable
+ )
+ if result is not None: # Request focus and handle interaction
+ self.goto("clickOrDrag", x, y, btn)
+ return True
+ else: # Do not request focus
+ return False
+ else:
+ return super().onPress(x, y, btn)
+
+
+# FocusManager ################################################################
+
+
+class FocusManager(StateMachine):
+ """Manages focus across multiple event handlers
+
+ On press an event handler can acquire focus.
+ By default it looses focus when all buttons are released.
+ """
+
+ class Idle(State):
+ def onPress(self, x, y, btn):
+ if btn == LEFT_BTN:
+ for eventHandler in self.machine.eventHandlers:
+ requestFocus = eventHandler.handleEvent("press", x, y, btn)
+ if requestFocus:
+ self.goto("focus", eventHandler, btn)
+ break
+
+ def _processEvent(self, *args):
+ for eventHandler in self.machine.eventHandlers:
+ consumeEvent = eventHandler.handleEvent(*args)
+ if consumeEvent:
+ break
+
+ def onMove(self, x, y):
+ self._processEvent("move", x, y)
+
+ def onRelease(self, x, y, btn):
+ if btn == LEFT_BTN:
+ self._processEvent("release", x, y, btn)
+
+ def onWheel(self, x, y, angle):
+ self._processEvent("wheel", x, y, angle)
+
+ class Focus(State):
+ def enterState(self, eventHandler, btn):
+ self.eventHandler = eventHandler
+ self.focusBtns = {btn}
+
+ def validate(self):
+ self.eventHandler.validate()
+ self.goto("idle")
+
+ def onPress(self, x, y, btn):
+ if btn == LEFT_BTN:
+ self.focusBtns.add(btn)
+ self.eventHandler.handleEvent("press", x, y, btn)
+
+ def onMove(self, x, y):
+ self.eventHandler.handleEvent("move", x, y)
+
+ def onRelease(self, x, y, btn):
+ if btn == LEFT_BTN:
+ self.focusBtns.discard(btn)
+ requestFocus = self.eventHandler.handleEvent("release", x, y, btn)
+ if len(self.focusBtns) == 0 and not requestFocus:
+ self.goto("idle")
+
+ def onWheel(self, x, y, angleInDegrees):
+ self.eventHandler.handleEvent("wheel", x, y, angleInDegrees)
+
+ def __init__(self, eventHandlers=()):
+ self.eventHandlers = list(eventHandlers)
+
+ states = {"idle": FocusManager.Idle, "focus": FocusManager.Focus}
+ super(FocusManager, self).__init__(states, "idle")
+
+ def cancel(self):
+ for handler in self.eventHandlers:
+ handler.cancel()
+
+
+class ZoomAndSelect(ItemsInteraction):
+ """Combine Zoom and ItemInteraction state machine.
+
+ :param plot: The Plot to which this interaction is attached
+ :param color: The color to use for the zoom area bounding box
+ """
+
+ def __init__(self, plot, color):
+ super(ZoomAndSelect, self).__init__(plot)
+ self._zoom = Zoom(plot, color)
+ self._doZoom = False
+
+ @property
+ def color(self):
+ """Color of the zoom area"""
+ return self._zoom.color
+
+ @property
+ def zoomEnabledAxes(self) -> EnabledAxes:
+ """Whether or not to apply zoom for each axis"""
+ return self._zoom.enabledAxes
+
+ @zoomEnabledAxes.setter
+ def zoomEnabledAxes(self, enabledAxes: EnabledAxes):
+ self._zoom.enabledAxes = enabledAxes
+
+ def click(self, x, y, btn):
+ """Handle mouse click
+
+ :param x: X position of the mouse in pixels
+ :param y: Y position of the mouse in pixels
+ :param btn: Pressed button id
+ :return: True if click is catched by an item, False otherwise
+ """
+ eventDict = self._handleClick(x, y, btn)
+
+ if eventDict is not None:
+ # Signal mouse clicked event
+ dataPos = self.plot.pixelToData(x, y)
+ assert dataPos is not None
+ clickedEventDict = prepareMouseSignal(
+ "mouseClicked", btn, dataPos[0], dataPos[1], x, y
+ )
+ self.plot.notify(**clickedEventDict)
+
+ self.plot.notify(**eventDict)
+
+ else:
+ self._zoom.click(x, y, btn)
+
+ def beginDrag(self, x, y, btn):
+ """Handle start drag and switching between zoom and item drag.
+
+ :param x: X position in pixels
+ :param y: Y position in pixels
+ :param str btn: The mouse button for which a drag is starting.
+ """
+ self._doZoom = not super(ZoomAndSelect, self).beginDrag(x, y, btn)
+ if self._doZoom:
+ self._zoom.beginDrag(x, y, btn)
+
+ def drag(self, x, y, btn):
+ """Handle drag, eventually forwarding to zoom.
+
+ :param x: X position in pixels
+ :param y: Y position in pixels
+ :param str btn: The mouse button for which a drag is in progress.
+ """
+ if self._doZoom:
+ return self._zoom.drag(x, y, btn)
+ else:
+ return super(ZoomAndSelect, self).drag(x, y, btn)
+
+ def endDrag(self, startPos, endPos, btn):
+ """Handle end of drag, eventually forwarding to zoom.
+
+ :param startPos: (x, y) position at the beginning of the drag
+ :param endPos: (x, y) position at the end of the drag
+ :param str btn: The mouse button for which a drag is done.
+ """
+ if self._doZoom:
+ return self._zoom.endDrag(startPos, endPos, btn)
+ else:
+ return super(ZoomAndSelect, self).endDrag(startPos, endPos, btn)
+
+
+class PanAndSelect(ItemsInteraction):
+ """Combine Pan and ItemInteraction state machine.
+
+ :param plot: The Plot to which this interaction is attached
+ """
+
+ def __init__(self, plot):
+ super(PanAndSelect, self).__init__(plot)
+ self._pan = Pan(plot)
+ self._doPan = False
+
+ def click(self, x, y, btn):
+ """Handle mouse click
+
+ :param x: X position of the mouse in pixels
+ :param y: Y position of the mouse in pixels
+ :param btn: Pressed button id
+ :return: True if click is catched by an item, False otherwise
+ """
+ eventDict = self._handleClick(x, y, btn)
+
+ if eventDict is not None:
+ # Signal mouse clicked event
+ dataPos = self.plot.pixelToData(x, y)
+ assert dataPos is not None
+ clickedEventDict = prepareMouseSignal(
+ "mouseClicked", btn, dataPos[0], dataPos[1], x, y
+ )
+ self.plot.notify(**clickedEventDict)
+
+ self.plot.notify(**eventDict)
+
+ else:
+ self._pan.click(x, y, btn)
+
+ def beginDrag(self, x, y, btn):
+ """Handle start drag and switching between zoom and item drag.
+
+ :param x: X position in pixels
+ :param y: Y position in pixels
+ :param str btn: The mouse button for which a drag is starting.
+ """
+ self._doPan = not super(PanAndSelect, self).beginDrag(x, y, btn)
+ if self._doPan:
+ self._pan.beginDrag(x, y, btn)
+
+ def drag(self, x, y, btn):
+ """Handle drag, eventually forwarding to zoom.
+
+ :param x: X position in pixels
+ :param y: Y position in pixels
+ :param str btn: The mouse button for which a drag is in progress.
+ """
+ if self._doPan:
+ return self._pan.drag(x, y, btn)
+ else:
+ return super(PanAndSelect, self).drag(x, y, btn)
+
+ def endDrag(self, startPos, endPos, btn):
+ """Handle end of drag, eventually forwarding to zoom.
+
+ :param startPos: (x, y) position at the beginning of the drag
+ :param endPos: (x, y) position at the end of the drag
+ :param str btn: The mouse button for which a drag is done.
+ """
+ if self._doPan:
+ return self._pan.endDrag(startPos, endPos, btn)
+ else:
+ return super(PanAndSelect, self).endDrag(startPos, endPos, btn)
+
+
+# Interaction mode control ####################################################
+
+# Mapping of draw modes: event handler
+_DRAW_MODES = {
+ "polygon": SelectPolygon,
+ "rectangle": SelectRectangle,
+ "ellipse": SelectEllipse,
+ "line": SelectLine,
+ "vline": SelectVLine,
+ "hline": SelectHLine,
+ "polylines": SelectFreeLine,
+ "pencil": DrawFreeHand,
+}
+
+
+class DrawMode(FocusManager):
+ """Interactive mode for draw and select"""
+
+ def __init__(self, plot, shape, label, color, width):
+ eventHandlerClass = _DRAW_MODES[shape]
+ parameters = {
+ "shape": shape,
+ "label": label,
+ "color": color,
+ "width": width,
+ }
+ super().__init__(
+ (
+ Pan(plot, clickButtons=(), dragButtons=(MIDDLE_BTN,)),
+ eventHandlerClass(plot, parameters),
+ )
+ )
+
+ def getDescription(self):
+ """Returns the dict describing this interactive mode"""
+ params = self.eventHandlers[1].parameters.copy()
+ params["mode"] = "draw"
+ return params
+
+
+class DrawSelectMode(FocusManager):
+ """Interactive mode for draw and select"""
+
+ def __init__(self, plot, shape, label, color, width):
+ eventHandlerClass = _DRAW_MODES[shape]
+ self._pan = Pan(plot)
+ self._panStart = None
+ parameters = {
+ "shape": shape,
+ "label": label,
+ "color": color,
+ "width": width,
+ }
+ super().__init__(
+ (ItemsInteractionForCombo(plot), eventHandlerClass(plot, parameters))
+ )
+
+ def handleEvent(self, eventName, *args, **kwargs):
+ # Hack to add pan interaction to select-draw
+ # See issue Refactor PlotWidget interaction #3292
+ if eventName == "press" and args[2] == MIDDLE_BTN:
+ self._panStart = args[:2]
+ self._pan.beginDrag(*args)
+ return # Consume middle click events
+ elif eventName == "release" and args[2] == MIDDLE_BTN:
+ self._panStart = None
+ self._pan.endDrag(self._panStart, args[:2], MIDDLE_BTN)
+ return # Consume middle click events
+ elif self._panStart is not None and eventName == "move":
+ x, y = args[:2]
+ self._pan.drag(x, y, MIDDLE_BTN)
+
+ super().handleEvent(eventName, *args, **kwargs)
+
+ def getDescription(self):
+ """Returns the dict describing this interactive mode"""
+ params = self.eventHandlers[1].parameters.copy()
+ params["mode"] = "select-draw"
+ return params
+
+
+class PlotInteraction(qt.QObject):
+ """PlotWidget user interaction handler.
+
+ :param plot: The :class:`PlotWidget` to apply interaction to
+ """
+
+ sigChanged = qt.Signal()
+ """Signal emitted when the interaction configuration has changed"""
+
+ _DRAW_MODES = {
+ "polygon": SelectPolygon,
+ "rectangle": SelectRectangle,
+ "ellipse": SelectEllipse,
+ "line": SelectLine,
+ "vline": SelectVLine,
+ "hline": SelectHLine,
+ "polylines": SelectFreeLine,
+ "pencil": DrawFreeHand,
+ }
+
+ def __init__(self, parent):
+ super().__init__(parent)
+ self.__zoomOnWheel = True
+ self.__zoomEnabledAxes = EnabledAxes()
+
+ # Default event handler
+ self._eventHandler = ItemsInteraction(parent)
+
+ def isZoomOnWheelEnabled(self) -> bool:
+ """Returns whether or not wheel interaction triggers zoom"""
+ return self.__zoomOnWheel
+
+ def setZoomOnWheelEnabled(self, enabled: bool):
+ """Toggle zoom on wheel interaction"""
+ if enabled != self.__zoomOnWheel:
+ self.__zoomOnWheel = enabled
+ self.sigChanged.emit()
+
+ def setZoomEnabledAxes(self, xaxis: bool, yaxis: bool, y2axis: bool):
+ """Toggle zoom interaction for each axis
+
+ This is taken into account only if the plot does not keep aspect ratio.
+ """
+ zoomEnabledAxes = EnabledAxes(xaxis, yaxis, y2axis)
+ if zoomEnabledAxes != self.__zoomEnabledAxes:
+ self.__zoomEnabledAxes = zoomEnabledAxes
+ if isinstance(self._eventHandler, ZoomAndSelect):
+ self._eventHandler.zoomEnabledAxes = zoomEnabledAxes
+ self.sigChanged.emit()
+
+ def getZoomEnabledAxes(self) -> EnabledAxes:
+ """Returns axes for which zoom is enabled"""
+ return self.__zoomEnabledAxes
+
+ def _getInteractiveMode(self):
+ """Returns the current interactive mode as a dict.
+
+ The returned dict contains at least the key 'mode'.
+ Mode can be: 'draw', 'pan', 'select', 'select-draw', 'zoom'.
+ It can also contains extra keys (e.g., 'color') specific to a mode
+ as provided to :meth:`_setInteractiveMode`.
+ """
+ if isinstance(self._eventHandler, ZoomAndSelect):
+ return {"mode": "zoom", "color": self._eventHandler.color}
+
+ elif isinstance(self._eventHandler, (DrawMode, DrawSelectMode)):
+ return self._eventHandler.getDescription()
+
+ elif isinstance(self._eventHandler, PanAndSelect):
+ return {"mode": "pan"}
+
+ else:
+ return {"mode": "select"}
+
+ def _validate(self):
+ """Validate the current interaction if possible
+
+ If was designed to close the polygon interaction.
+ """
+ self._eventHandler.validate()
+
+ def _setInteractiveMode(
+ self, mode, color="black", shape="polygon", label=None, width=None
+ ):
+ """Switch the interactive mode.
+
+ :param str mode: The name of the interactive mode.
+ In 'draw', 'pan', 'select', 'select-draw', 'zoom'.
+ :param color: Only for 'draw' and 'zoom' modes.
+ Color to use for drawing selection area. Default black.
+ If None, selection area is not drawn.
+ :type color: Color description: The name as a str or
+ a tuple of 4 floats or None.
+ :param str shape: Only for 'draw' mode. The kind of shape to draw.
+ In 'polygon', 'rectangle', 'line', 'vline', 'hline',
+ 'polylines'.
+ Default is 'polygon'.
+ :param str label: Only for 'draw' mode.
+ :param float width: Width of the pencil. Only for draw pencil mode.
+ """
+ assert mode in ("draw", "pan", "select", "select-draw", "zoom")
+
+ plotWidget = self.parent()
+ assert plotWidget is not None
+
+ if isinstance(color, numpy.ndarray) or color not in (None, "video inverted"):
+ color = colors.rgba(color)
+
+ if mode in ("draw", "select-draw"):
+ self._eventHandler.cancel()
+ handlerClass = DrawMode if mode == "draw" else DrawSelectMode
+ self._eventHandler = handlerClass(plotWidget, shape, label, color, width)
+
+ elif mode == "pan":
+ # Ignores color, shape and label
+ self._eventHandler.cancel()
+ self._eventHandler = PanAndSelect(plotWidget)
+
+ elif mode == "zoom":
+ # Ignores shape and label
+ self._eventHandler.cancel()
+ self._eventHandler = ZoomAndSelect(plotWidget, color)
+ self._eventHandler.zoomEnabledAxes = self.getZoomEnabledAxes()
+
+ else: # Default mode: interaction with plot objects
+ # Ignores color, shape and label
+ self._eventHandler.cancel()
+ self._eventHandler = ItemsInteraction(plotWidget)
+
+ self.sigChanged.emit()
+
+ def handleEvent(self, event, *args, **kwargs):
+ """Forward event to current interactive mode state machine."""
+ if event == "wheel": # Handle wheel events directly
+ self._onWheel(*args, **kwargs)
+ return
+
+ self._eventHandler.handleEvent(event, *args, **kwargs)
+
+ def _onWheel(self, x: float, y: float, angle: float):
+ """Handle wheel events"""
+ if not self.isZoomOnWheelEnabled():
+ return
+
+ plotWidget = self.parent()
+ if plotWidget is None:
+ return
+
+ # All axes are enabled if keep aspect ratio is on
+ enabledAxes = (
+ EnabledAxes()
+ if plotWidget.isKeepDataAspectRatio()
+ else self.getZoomEnabledAxes()
+ )
+ if enabledAxes.isDisabled():
+ return
+
+ scale = 1.1 if angle > 0 else 1.0 / 1.1
+ applyZoomToPlot(plotWidget, scale, (x, y), enabledAxes)
diff --git a/src/silx/gui/plot/PlotToolButtons.py b/src/silx/gui/plot/PlotToolButtons.py
new file mode 100644
index 0000000..e132877
--- /dev/null
+++ b/src/silx/gui/plot/PlotToolButtons.py
@@ -0,0 +1,554 @@
+# /*##########################################################################
+#
+# Copyright (c) 2004-2020 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""This module provides a set of QToolButton to use with
+:class:`~silx.gui.plot.PlotWidget`.
+
+The following QToolButton are available:
+
+- :class:`.AspectToolButton`
+- :class:`.YAxisOriginToolButton`
+- :class:`.ProfileToolButton`
+- :class:`.RulerToolButton`
+- :class:`.SymbolToolButton`
+
+"""
+
+__authors__ = ["V. Valls", "H. Payno"]
+__license__ = "MIT"
+__date__ = "27/06/2017"
+
+
+import functools
+import logging
+
+from .. import icons
+from .. import qt
+from ... import config
+from .tools.PlotToolButton import PlotToolButton
+
+from .items import SymbolMixIn, Scatter
+
+
+_logger = logging.getLogger(__name__)
+
+
+class AspectToolButton(PlotToolButton):
+ """Tool button to switch keep aspect ratio of a plot"""
+
+ STATE = None
+ """Lazy loaded states used to feed AspectToolButton"""
+
+ def __init__(self, parent=None, plot=None):
+ if self.STATE is None:
+ self.STATE = {}
+ # dont keep ratio
+ self.STATE[False, "icon"] = icons.getQIcon("shape-ellipse-solid")
+ self.STATE[False, "state"] = "Aspect ratio is not kept"
+ self.STATE[False, "action"] = "Do no keep data aspect ratio"
+ # keep ratio
+ self.STATE[True, "icon"] = icons.getQIcon("shape-circle-solid")
+ self.STATE[True, "state"] = "Aspect ratio is kept"
+ self.STATE[True, "action"] = "Keep data aspect ratio"
+
+ super(AspectToolButton, self).__init__(parent=parent, plot=plot)
+
+ keepAction = self._createAction(True)
+ keepAction.triggered.connect(self.keepDataAspectRatio)
+ keepAction.setIconVisibleInMenu(True)
+
+ dontKeepAction = self._createAction(False)
+ dontKeepAction.triggered.connect(self.dontKeepDataAspectRatio)
+ dontKeepAction.setIconVisibleInMenu(True)
+
+ menu = qt.QMenu(self)
+ menu.addAction(keepAction)
+ menu.addAction(dontKeepAction)
+ self.setMenu(menu)
+ self.setPopupMode(qt.QToolButton.InstantPopup)
+
+ def _createAction(self, keepAspectRatio):
+ icon = self.STATE[keepAspectRatio, "icon"]
+ text = self.STATE[keepAspectRatio, "action"]
+ return qt.QAction(icon, text, self)
+
+ def _connectPlot(self, plot):
+ plot.sigSetKeepDataAspectRatio.connect(self._keepDataAspectRatioChanged)
+ self._keepDataAspectRatioChanged(plot.isKeepDataAspectRatio())
+
+ def _disconnectPlot(self, plot):
+ plot.sigSetKeepDataAspectRatio.disconnect(self._keepDataAspectRatioChanged)
+
+ def keepDataAspectRatio(self):
+ """Configure the plot to keep the aspect ratio"""
+ plot = self.plot()
+ if plot is not None:
+ # This will trigger _keepDataAspectRatioChanged
+ plot.setKeepDataAspectRatio(True)
+
+ def dontKeepDataAspectRatio(self):
+ """Configure the plot to not keep the aspect ratio"""
+ plot = self.plot()
+ if plot is not None:
+ # This will trigger _keepDataAspectRatioChanged
+ plot.setKeepDataAspectRatio(False)
+
+ def _keepDataAspectRatioChanged(self, aspectRatio):
+ """Handle Plot set keep aspect ratio signal"""
+ icon, toolTip = (
+ self.STATE[aspectRatio, "icon"],
+ self.STATE[aspectRatio, "state"],
+ )
+ self.setIcon(icon)
+ self.setToolTip(toolTip)
+
+
+class YAxisOriginToolButton(PlotToolButton):
+ """Tool button to switch the Y axis orientation of a plot."""
+
+ STATE = None
+ """Lazy loaded states used to feed YAxisOriginToolButton"""
+
+ def __init__(self, parent=None, plot=None):
+ if self.STATE is None:
+ self.STATE = {}
+ # is down
+ self.STATE[False, "icon"] = icons.getQIcon("plot-ydown")
+ self.STATE[False, "state"] = "Y-axis is oriented downward"
+ self.STATE[False, "action"] = "Orient Y-axis downward"
+ # keep ration
+ self.STATE[True, "icon"] = icons.getQIcon("plot-yup")
+ self.STATE[True, "state"] = "Y-axis is oriented upward"
+ self.STATE[True, "action"] = "Orient Y-axis upward"
+
+ super(YAxisOriginToolButton, self).__init__(parent=parent, plot=plot)
+
+ upwardAction = self._createAction(True)
+ upwardAction.triggered.connect(self.setYAxisUpward)
+ upwardAction.setIconVisibleInMenu(True)
+
+ downwardAction = self._createAction(False)
+ downwardAction.triggered.connect(self.setYAxisDownward)
+ downwardAction.setIconVisibleInMenu(True)
+
+ menu = qt.QMenu(self)
+ menu.addAction(upwardAction)
+ menu.addAction(downwardAction)
+ self.setMenu(menu)
+ self.setPopupMode(qt.QToolButton.InstantPopup)
+
+ def _createAction(self, isUpward):
+ icon = self.STATE[isUpward, "icon"]
+ text = self.STATE[isUpward, "action"]
+ return qt.QAction(icon, text, self)
+
+ def _connectPlot(self, plot):
+ yAxis = plot.getYAxis()
+ yAxis.sigInvertedChanged.connect(self._yAxisInvertedChanged)
+ self._yAxisInvertedChanged(yAxis.isInverted())
+
+ def _disconnectPlot(self, plot):
+ plot.getYAxis().sigInvertedChanged.disconnect(self._yAxisInvertedChanged)
+
+ def setYAxisUpward(self):
+ """Configure the plot to use y-axis upward"""
+ plot = self.plot()
+ if plot is not None:
+ # This will trigger _yAxisInvertedChanged
+ plot.getYAxis().setInverted(False)
+
+ def setYAxisDownward(self):
+ """Configure the plot to use y-axis downward"""
+ plot = self.plot()
+ if plot is not None:
+ # This will trigger _yAxisInvertedChanged
+ plot.getYAxis().setInverted(True)
+
+ def _yAxisInvertedChanged(self, inverted):
+ """Handle Plot set y axis inverted signal"""
+ isUpward = not inverted
+ icon, toolTip = self.STATE[isUpward, "icon"], self.STATE[isUpward, "state"]
+ self.setIcon(icon)
+ self.setToolTip(toolTip)
+
+
+class ProfileOptionToolButton(PlotToolButton):
+ """Button to define option on the profile"""
+
+ sigMethodChanged = qt.Signal(str)
+
+ def __init__(self, parent=None, plot=None):
+ PlotToolButton.__init__(self, parent=parent, plot=plot)
+
+ self.STATE = {}
+ # is down
+ self.STATE["sum", "icon"] = icons.getQIcon("math-sigma")
+ self.STATE["sum", "state"] = "Compute profile sum"
+ self.STATE["sum", "action"] = "Compute profile sum"
+ # keep ration
+ self.STATE["mean", "icon"] = icons.getQIcon("math-mean")
+ self.STATE["mean", "state"] = "Compute profile mean"
+ self.STATE["mean", "action"] = "Compute profile mean"
+
+ self.sumAction = self._createAction("sum")
+ self.sumAction.triggered.connect(self.setSum)
+ self.sumAction.setIconVisibleInMenu(True)
+ self.sumAction.setCheckable(True)
+ self.sumAction.setChecked(True)
+
+ self.meanAction = self._createAction("mean")
+ self.meanAction.triggered.connect(self.setMean)
+ self.meanAction.setIconVisibleInMenu(True)
+ self.meanAction.setCheckable(True)
+
+ menu = qt.QMenu(self)
+ menu.addAction(self.sumAction)
+ menu.addAction(self.meanAction)
+ self.setMenu(menu)
+ self.setPopupMode(qt.QToolButton.InstantPopup)
+ self._method = "mean"
+ self._update()
+
+ def _createAction(self, method):
+ icon = self.STATE[method, "icon"]
+ text = self.STATE[method, "action"]
+ return qt.QAction(icon, text, self)
+
+ def setSum(self):
+ self.setMethod("sum")
+
+ def _update(self):
+ icon = self.STATE[self._method, "icon"]
+ toolTip = self.STATE[self._method, "state"]
+ self.setIcon(icon)
+ self.setToolTip(toolTip)
+ self.sumAction.setChecked(self._method == "sum")
+ self.meanAction.setChecked(self._method == "mean")
+
+ def setMean(self):
+ self.setMethod("mean")
+
+ def setMethod(self, method):
+ """Set the method to use.
+
+ :param str method: Either 'sum' or 'mean'
+ """
+ if method != self._method:
+ if method in ("sum", "mean"):
+ self._method = method
+ self.sigMethodChanged.emit(self._method)
+ self._update()
+ else:
+ _logger.warning("Unsupported method '%s'. Setting ignored.", method)
+
+ def getMethod(self):
+ """Returns the current method in use (See :meth:`setMethod`).
+
+ :rtype: str
+ """
+ return self._method
+
+
+class ProfileToolButton(PlotToolButton):
+ """Button used in Profile3DToolbar to switch between 2D profile
+ and 1D profile."""
+
+ STATE = None
+ """Lazy loaded states used to feed ProfileToolButton"""
+
+ sigDimensionChanged = qt.Signal(int)
+
+ def __init__(self, parent=None, plot=None):
+ if self.STATE is None:
+ self.STATE = {
+ (1, "icon"): icons.getQIcon("profile1D"),
+ (1, "state"): "1D profile is computed on visible image",
+ (1, "action"): "1D profile on visible image",
+ (2, "icon"): icons.getQIcon("profile2D"),
+ (
+ 2,
+ "state",
+ ): "2D profile is computed, one 1D profile for each image in the stack",
+ (2, "action"): "2D profile on image stack",
+ }
+ # Compute 1D profile
+ # Compute 2D profile
+
+ super(ProfileToolButton, self).__init__(parent=parent, plot=plot)
+
+ self._dimension = 1
+
+ profile1DAction = self._createAction(1)
+ profile1DAction.triggered.connect(self.computeProfileIn1D)
+ profile1DAction.setIconVisibleInMenu(True)
+ profile1DAction.setCheckable(True)
+ profile1DAction.setChecked(True)
+ self._profile1DAction = profile1DAction
+
+ profile2DAction = self._createAction(2)
+ profile2DAction.triggered.connect(self.computeProfileIn2D)
+ profile2DAction.setIconVisibleInMenu(True)
+ profile2DAction.setCheckable(True)
+ self._profile2DAction = profile2DAction
+
+ menu = qt.QMenu(self)
+ menu.addAction(profile1DAction)
+ menu.addAction(profile2DAction)
+ self.setMenu(menu)
+ self.setPopupMode(qt.QToolButton.InstantPopup)
+ menu.setTitle("Select profile dimension")
+ self.computeProfileIn1D()
+
+ def _createAction(self, profileDimension):
+ icon = self.STATE[profileDimension, "icon"]
+ text = self.STATE[profileDimension, "action"]
+ return qt.QAction(icon, text, self)
+
+ def _profileDimensionChanged(self, profileDimension):
+ """Update icon in toolbar, emit number of dimensions for profile"""
+ self.setIcon(self.STATE[profileDimension, "icon"])
+ self.setToolTip(self.STATE[profileDimension, "state"])
+ self._dimension = profileDimension
+ self.sigDimensionChanged.emit(profileDimension)
+ self._profile1DAction.setChecked(profileDimension == 1)
+ self._profile2DAction.setChecked(profileDimension == 2)
+
+ def computeProfileIn1D(self):
+ self._profileDimensionChanged(1)
+
+ def computeProfileIn2D(self):
+ self._profileDimensionChanged(2)
+
+ def setDimension(self, dimension):
+ """Set the selected dimension"""
+ assert dimension in [1, 2]
+ if self._dimension == dimension:
+ return
+ if dimension == 1:
+ self.computeProfileIn1D()
+ elif dimension == 2:
+ self.computeProfileIn2D()
+ else:
+ _logger.warning("Unsupported dimension '%s'. Setting ignored.", dimension)
+
+ def getDimension(self):
+ """Get the selected dimension.
+
+ :rtype: int (1 or 2)
+ """
+ return self._dimension
+
+
+class _SymbolToolButtonBase(PlotToolButton):
+ """Base class for PlotToolButton setting marker and size.
+
+ :param parent: See QWidget
+ :param plot: The `~silx.gui.plot.PlotWidget` to control
+ """
+
+ def __init__(self, parent=None, plot=None):
+ super(_SymbolToolButtonBase, self).__init__(parent=parent, plot=plot)
+
+ def _addSizeSliderToMenu(self, menu):
+ """Add a slider to set size to the given menu
+
+ :param QMenu menu:
+ """
+ slider = qt.QSlider(qt.Qt.Horizontal)
+ slider.setRange(1, 20)
+ slider.setValue(int(config.DEFAULT_PLOT_SYMBOL_SIZE))
+ slider.setTracking(False)
+ slider.valueChanged.connect(self._sizeChanged)
+ widgetAction = qt.QWidgetAction(menu)
+ widgetAction.setDefaultWidget(slider)
+ menu.addAction(widgetAction)
+
+ def _addSymbolsToMenu(self, menu):
+ """Add symbols to the given menu
+
+ :param QMenu menu:
+ """
+ for marker, name in zip(
+ SymbolMixIn.getSupportedSymbols(), SymbolMixIn.getSupportedSymbolNames()
+ ):
+ action = qt.QAction(name, menu)
+ action.setCheckable(False)
+ action.triggered.connect(functools.partial(self._markerChanged, marker))
+ menu.addAction(action)
+
+ def _sizeChanged(self, value):
+ """Manage slider value changed
+
+ :param int value: Marker size
+ """
+ plot = self.plot()
+ if plot is None:
+ return
+
+ for item in plot.getItems():
+ if isinstance(item, SymbolMixIn):
+ item.setSymbolSize(value)
+
+ def _markerChanged(self, marker):
+ """Manage change of marker.
+
+ :param str marker: Letter describing the marker
+ """
+ plot = self.plot()
+ if plot is None:
+ return
+
+ for item in plot.getItems():
+ if isinstance(item, SymbolMixIn):
+ item.setSymbol(marker)
+
+
+class SymbolToolButton(_SymbolToolButtonBase):
+ """A tool button with a drop-down menu to control symbol size and marker.
+
+ :param parent: See QWidget
+ :param plot: The `~silx.gui.plot.PlotWidget` to control
+ """
+
+ def __init__(self, parent=None, plot=None):
+ super(SymbolToolButton, self).__init__(parent=parent, plot=plot)
+
+ self.setToolTip("Set symbol size and marker")
+ self.setIcon(icons.getQIcon("plot-symbols"))
+
+ menu = qt.QMenu(self)
+ self._addSizeSliderToMenu(menu)
+ menu.addSeparator()
+ self._addSymbolsToMenu(menu)
+
+ self.setMenu(menu)
+ self.setPopupMode(qt.QToolButton.InstantPopup)
+
+
+class ScatterVisualizationToolButton(_SymbolToolButtonBase):
+ """QToolButton to select the visualization mode of scatter plot
+
+ :param parent: See QWidget
+ :param plot: The `~silx.gui.plot.PlotWidget` to control
+ """
+
+ def __init__(self, parent=None, plot=None):
+ super(ScatterVisualizationToolButton, self).__init__(parent=parent, plot=plot)
+
+ self.setToolTip("Set scatter visualization mode, symbol marker and size")
+ self.setIcon(icons.getQIcon("eye"))
+
+ menu = qt.QMenu(self)
+
+ # Add visualization modes
+
+ for mode in Scatter.supportedVisualizations():
+ if mode is not Scatter.Visualization.BINNED_STATISTIC:
+ name = mode.value.capitalize()
+ action = qt.QAction(name, menu)
+ action.setCheckable(False)
+ action.triggered.connect(
+ functools.partial(self._visualizationChanged, mode, None)
+ )
+ menu.addAction(action)
+
+ if Scatter.Visualization.BINNED_STATISTIC in Scatter.supportedVisualizations():
+ reductions = Scatter.supportedVisualizationParameterValues(
+ Scatter.VisualizationParameter.BINNED_STATISTIC_FUNCTION
+ )
+ if reductions:
+ submenu = menu.addMenu("Binned Statistic")
+ for reduction in reductions:
+ name = reduction.capitalize()
+ action = qt.QAction(name, menu)
+ action.setCheckable(False)
+ action.triggered.connect(
+ functools.partial(
+ self._visualizationChanged,
+ Scatter.Visualization.BINNED_STATISTIC,
+ {
+ Scatter.VisualizationParameter.BINNED_STATISTIC_FUNCTION: reduction
+ },
+ )
+ )
+ submenu.addAction(action)
+
+ submenu.addSeparator()
+ binsmenu = submenu.addMenu("N Bins")
+
+ slider = qt.QSlider(qt.Qt.Horizontal)
+ slider.setRange(10, 1000)
+ slider.setValue(100)
+ slider.setTracking(False)
+ slider.valueChanged.connect(self._binningChanged)
+ widgetAction = qt.QWidgetAction(binsmenu)
+ widgetAction.setDefaultWidget(slider)
+ binsmenu.addAction(widgetAction)
+
+ menu.addSeparator()
+
+ submenu = menu.addMenu(icons.getQIcon("plot-symbols"), "Symbol")
+ self._addSymbolsToMenu(submenu)
+
+ submenu = menu.addMenu(icons.getQIcon("plot-symbols"), "Symbol Size")
+ self._addSizeSliderToMenu(submenu)
+
+ self.setMenu(menu)
+ self.setPopupMode(qt.QToolButton.InstantPopup)
+
+ def _visualizationChanged(self, mode, parameters=None):
+ """Handle change of visualization mode.
+
+ :param ScatterVisualizationMixIn.Visualization mode:
+ The visualization mode to use for scatter
+ :param Union[dict,None] parameters:
+ Dict of VisualizationParameter: parameter_value to set
+ with the visualization.
+ """
+ plot = self.plot()
+ if plot is None:
+ return
+
+ for item in plot.getItems():
+ if isinstance(item, Scatter):
+ if parameters:
+ for parameter, value in parameters.items():
+ item.setVisualizationParameter(parameter, value)
+ item.setVisualization(mode)
+
+ def _binningChanged(self, value):
+ """Handle change of binning.
+
+ :param int value: The number of bin on each dimension.
+ """
+ plot = self.plot()
+ if plot is None:
+ return
+
+ for item in plot.getItems():
+ if isinstance(item, Scatter):
+ item.setVisualizationParameter(
+ Scatter.VisualizationParameter.BINNED_STATISTIC_SHAPE,
+ (value, value),
+ )
+ item.setVisualization(Scatter.Visualization.BINNED_STATISTIC)
diff --git a/src/silx/gui/plot/PlotWidget.py b/src/silx/gui/plot/PlotWidget.py
new file mode 100755
index 0000000..a01ca48
--- /dev/null
+++ b/src/silx/gui/plot/PlotWidget.py
@@ -0,0 +1,3789 @@
+# /*##########################################################################
+#
+# Copyright (c) 2004-2023 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+# ###########################################################################*/
+"""Qt widget providing plot API for 1D and 2D data.
+
+The :class:`PlotWidget` implements the plot API initially provided in PyMca.
+"""
+
+from __future__ import annotations
+
+__authors__ = ["V.A. Sole", "T. Vincent"]
+__license__ = "MIT"
+__date__ = "21/12/2018"
+
+import logging
+
+_logger = logging.getLogger(__name__)
+
+
+from collections import namedtuple
+from collections.abc import Sequence
+from contextlib import contextmanager
+from typing import Optional, Union
+import datetime as dt
+import itertools
+import numbers
+import warnings
+
+import numpy
+
+import silx
+from silx.utils.weakref import WeakMethodProxy
+
+try:
+ # Import matplotlib now to init matplotlib our way
+ import silx.gui.utils.matplotlib # noqa
+except ImportError:
+ _logger.debug("matplotlib not available")
+
+from ..colors import Colormap
+from .. import colors
+from . import PlotInteraction
+from . import PlotEvents
+from .LimitsHistory import LimitsHistory
+from . import _utils
+
+from . import items
+from .items.curve import CurveStyle
+from .items.axis import TickMode # noqa
+
+from .. import qt
+from ._utils.panzoom import ViewConstraints
+from ...gui.plot._utils.dtime_ticklayout import timestamp
+from ...utils.deprecation import deprecated_warning
+
+
+"""
+Object returned when requesting the data range.
+"""
+_PlotDataRange = namedtuple("PlotDataRange", ["x", "y", "yright"])
+
+
+class _PlotWidgetSelection(qt.QObject):
+ """Object managing a :class:`PlotWidget` selection.
+
+ It is a wrapper over :class:`PlotWidget`'s active items API.
+
+ :param PlotWidget parent:
+ """
+
+ sigCurrentItemChanged = qt.Signal(object, object)
+ """This signal is emitted whenever the current item changes.
+
+ It provides the current and previous items.
+ """
+
+ sigSelectedItemsChanged = qt.Signal()
+ """Signal emitted whenever the list of selected items changes."""
+
+ def __init__(self, parent):
+ assert isinstance(parent, PlotWidget)
+ super(_PlotWidgetSelection, self).__init__(parent=parent)
+
+ # Init history
+ self.__history = [ # Store active items from most recent to oldest
+ item
+ for item in (
+ parent.getActiveCurve(),
+ parent.getActiveImage(),
+ parent.getActiveScatter(),
+ )
+ if item is not None
+ ]
+
+ self.__current = self.__mostRecentActiveItem()
+
+ parent.sigActiveImageChanged.connect(self._activeImageChanged)
+ parent.sigActiveCurveChanged.connect(self._activeCurveChanged)
+ parent.sigActiveScatterChanged.connect(self._activeScatterChanged)
+
+ def __mostRecentActiveItem(self) -> Optional[items.Item]:
+ """Returns most recent active item."""
+ return self.__history[0] if len(self.__history) >= 1 else None
+
+ def getSelectedItems(self) -> tuple[items.Item]:
+ """Returns the list of currently selected items in the :class:`PlotWidget`.
+
+ The list is given from most recently current item to oldest one."""
+ plot = self.parent()
+ if plot is None:
+ return ()
+
+ active = tuple(self.__history)
+
+ current = self.getCurrentItem()
+ if current is not None and current not in active:
+ # Current might not be an active item, if so add it
+ active = (current,) + active
+
+ return active
+
+ def getCurrentItem(self) -> Optional[items.Item]:
+ """Returns the current item in the :class:`PlotWidget` or None."""
+ return self.__current
+
+ def setCurrentItem(self, item: Optional[items.Item]):
+ """Set the current item in the :class:`PlotWidget`.
+
+ :param item:
+ The new item to select or None to clear the selection.
+ :raise ValueError: If the item is not the :class:`PlotWidget`
+ """
+ previous = self.getCurrentItem()
+ if previous is item:
+ return
+
+ previousSelected = self.getSelectedItems()
+
+ if item is None:
+ self.__current = None
+
+ # Reset all PlotWidget active items
+ plot = self.parent()
+ if plot is not None:
+ for kind in PlotWidget._ACTIVE_ITEM_KINDS:
+ if plot._getActiveItem(kind) is not None:
+ plot._setActiveItem(kind, None)
+
+ elif isinstance(item, items.Item):
+ plot = self.parent()
+ if plot is None or item.getPlot() is not plot:
+ raise ValueError("Item is not in the PlotWidget: %s" % str(item))
+ self.__current = item
+
+ kind = plot._itemKind(item)
+
+ # Clean-up history to be safe
+ self.__history = [
+ item for item in self.__history if PlotWidget._itemKind(item) != kind
+ ]
+
+ # Sync active item if needed
+ if kind in plot._ACTIVE_ITEM_KINDS and item is not plot._getActiveItem(
+ kind
+ ):
+ plot._setActiveItem(kind, item)
+ else:
+ raise ValueError("Not an Item: %s" % str(item))
+
+ self.sigCurrentItemChanged.emit(previous, item)
+
+ if previousSelected != self.getSelectedItems():
+ self.sigSelectedItemsChanged.emit()
+
+ def __activeItemChanged(
+ self, kind: str, previous: Optional[str], legend: Optional[str]
+ ):
+ """Set current item from kind and legend"""
+ if previous == legend:
+ return # No-op for update of item
+
+ plot = self.parent()
+ if plot is None:
+ return
+
+ previousSelected = self.getSelectedItems()
+
+ # Remove items of this kind from the history
+ self.__history = [
+ item for item in self.__history if PlotWidget._itemKind(item) != kind
+ ]
+
+ # Retrieve current item
+ if legend is None: # Use most recent active item
+ currentItem = self.__mostRecentActiveItem()
+ else:
+ currentItem = plot._getItem(kind=kind, legend=legend)
+ if currentItem is None: # Fallback in case something went wrong
+ currentItem = self.__mostRecentActiveItem()
+
+ # Update history
+ if currentItem is not None:
+ while currentItem in self.__history:
+ self.__history.remove(currentItem)
+ self.__history.insert(0, currentItem)
+
+ if currentItem != self.__current:
+ previousItem = self.__current
+ self.__current = currentItem
+ self.sigCurrentItemChanged.emit(previousItem, currentItem)
+
+ if previousSelected != self.getSelectedItems():
+ self.sigSelectedItemsChanged.emit()
+
+ def _activeImageChanged(self, previous, current):
+ """Handle active image change"""
+ self.__activeItemChanged("image", previous, current)
+
+ def _activeCurveChanged(self, previous, current):
+ """Handle active curve change"""
+ self.__activeItemChanged("curve", previous, current)
+
+ def _activeScatterChanged(self, previous, current):
+ """Handle active scatter change"""
+ self.__activeItemChanged("scatter", previous, current)
+
+
+class PlotWidget(qt.QMainWindow):
+ """Qt Widget providing a 1D/2D plot.
+
+ This widget is a QMainWindow.
+ This class implements the plot API initially provided in PyMca.
+
+ Supported backends:
+
+ - 'matplotlib' and 'mpl': Matplotlib with Qt.
+ - 'opengl' and 'gl': OpenGL backend (requires PyOpenGL and OpenGL >= 2.1)
+ - 'none': No backend, to run headless for testing purpose.
+
+ :param parent: The parent of this widget or None (default).
+ :param backend: The backend to use, in:
+ 'matplotlib' (default), 'mpl', 'opengl', 'gl', 'none'
+ or a :class:`BackendBase.BackendBase` class
+ :type backend: str or :class:`BackendBase.BackendBase`
+ """
+
+ # The following 2 class attributes are no longer used
+ # but there is no way to warn about deprecation
+ colorList = silx.config.DEFAULT_PLOT_CURVE_COLORS
+ colorDict = colors.COLORDICT
+
+ sigPlotSignal = qt.Signal(object)
+ """Signal for all events of the plot.
+
+ The signal information is provided as a dict.
+ See the :ref:`plot signal documentation page <plot_signal>` for
+ information about the content of the dict
+ """
+
+ sigSetKeepDataAspectRatio = qt.Signal(bool)
+ """Signal emitted when plot keep aspect ratio has changed"""
+
+ sigSetGraphGrid = qt.Signal(str)
+ """Signal emitted when plot grid has changed"""
+
+ sigSetGraphCursor = qt.Signal(bool)
+ """Signal emitted when plot crosshair cursor has changed"""
+
+ sigSetPanWithArrowKeys = qt.Signal(bool)
+ """Signal emitted when pan with arrow keys has changed"""
+
+ _sigAxesVisibilityChanged = qt.Signal(bool)
+ """Signal emitted when the axes visibility changed"""
+
+ sigContentChanged = qt.Signal(str, str, str)
+ """Signal emitted when the content of the plot is changed.
+
+ It provides the following information:
+
+ - action: The change of the plot: 'add' or 'remove'
+ - kind: The kind of primitive changed:
+ 'curve', 'image', 'scatter', 'histogram', 'item' or 'marker'
+ - legend: The legend of the primitive changed.
+ """
+
+ sigActiveCurveChanged = qt.Signal(object, object)
+ """Signal emitted when the active curve has changed.
+
+ It provides the following information:
+
+ - previous: The legend of the previous active curve or None
+ - legend: The legend of the new active curve or None if no curve is active
+ """
+
+ sigActiveImageChanged = qt.Signal(object, object)
+ """Signal emitted when the active image has changed.
+
+ It provides the following information:
+
+ - previous: The legend of the previous active image or None
+ - legend: The legend of the new active image or None if no image is active
+ """
+
+ sigActiveScatterChanged = qt.Signal(object, object)
+ """Signal emitted when the active Scatter has changed.
+
+ It provides the following information:
+
+ - previous: The legend of the previous active scatter or None
+ - legend: The legend of the new active image or None if no image is active
+ """
+
+ sigInteractiveModeChanged = qt.Signal(object)
+ """Signal emitted when the interactive mode has changed
+
+ It provides the source as passed to :meth:`setInteractiveMode`.
+ """
+
+ sigItemAdded = qt.Signal(items.Item)
+ """Signal emitted when an item was just added to the plot
+
+ It provides the added item.
+ """
+
+ sigItemAboutToBeRemoved = qt.Signal(items.Item)
+ """Signal emitted right before an item is removed from the plot.
+
+ It provides the item that will be removed.
+ """
+
+ sigItemRemoved = qt.Signal(items.Item)
+ """Signal emitted right after an item was removed from the plot.
+
+ It provides the item that was removed.
+ """
+
+ sigVisibilityChanged = qt.Signal(bool)
+ """Signal emitted when the widget becomes visible (or invisible).
+ This happens when the widget is hidden or shown.
+
+ It provides the visible state.
+ """
+
+ _sigDefaultContextMenu = qt.Signal(qt.QMenu)
+ """Signal emitted when the default context menu of the plot is feed.
+
+ It provides the menu which will be displayed.
+ """
+
+ sigBackendChanged = qt.Signal()
+ """Signal emitted when the backend have changed."""
+
+ def __init__(self, parent=None, backend=None):
+ self._autoreplot = False
+ self._dirty = False
+ self._cursorInPlot = False
+ self.__muteActiveItemChanged = False
+
+ self._panWithArrowKeys = True
+ self._viewConstrains = None
+
+ super(PlotWidget, self).__init__(parent)
+ if parent is not None:
+ # behave as a widget
+ self.setWindowFlags(qt.Qt.Widget)
+ else:
+ self.setWindowTitle("PlotWidget")
+
+ # Init the backend
+ self._backend = self.__getBackendClass(backend)(self, self)
+
+ self.setCallback() # set _callback
+
+ # Items handling
+ self.__items = []
+ self.__itemsToUpdate = [] # Used as an OrderedSet
+ self.__activeItems = {"curve": None, "image": None, "scatter": None}
+
+ self._dataRange = None
+
+ # line types
+ self._defaultColors = None
+ self._styleList = ["-", "--", "-.", ":"]
+ self._colorIndex = 0
+ self._styleIndex = 0
+
+ self._activeCurveSelectionMode = "atmostone"
+ self._activeCurveStyle = CurveStyle(
+ color=silx.config.DEFAULT_PLOT_ACTIVE_CURVE_COLOR,
+ linewidth=silx.config.DEFAULT_PLOT_ACTIVE_CURVE_LINEWIDTH,
+ )
+
+ # plot colors (updated later to sync backend)
+ self._foregroundColor = 0.0, 0.0, 0.0, 1.0
+ self._gridColor = 0.7, 0.7, 0.7, 1.0
+ self._backgroundColor = 1.0, 1.0, 1.0, 1.0
+ self._dataBackgroundColor = None
+
+ # default properties
+ self._cursorConfiguration = None
+
+ self._xAxis = items.XAxis(self)
+ self._yAxis = items.YAxis(self)
+ self._yRightAxis = items.YRightAxis(self, self._yAxis)
+
+ self._grid = None
+ self._graphTitle = ""
+ self.__graphCursorShape = "default"
+
+ # Set axes margins
+ self.__axesDisplayed = True
+ self.__axesMargins = 0.0, 0.0, 0.0, 0.0
+ self.setAxesMargins(0.15, 0.1, 0.1, 0.15)
+
+ self.setGraphTitle()
+ self.setGraphXLabel()
+ self.setGraphYLabel()
+ self.setGraphYLabel("", axis="right")
+
+ self.setDefaultColormap() # Init default colormap
+
+ self.setDefaultPlotPoints(silx.config.DEFAULT_PLOT_CURVE_SYMBOL_MODE)
+ self.setDefaultPlotLines(True)
+
+ self._limitsHistory = LimitsHistory(self)
+
+ self._eventHandler = PlotInteraction.PlotInteraction(self)
+ self._eventHandler._setInteractiveMode("zoom", color=(0.0, 0.0, 0.0, 1.0))
+ self._eventHandler.sigChanged.connect(self.__interactionChanged)
+ self.__isInteractionSignalForwarded = True
+ self._previousDefaultMode = "zoom", True
+
+ self._pressedButtons = [] # Currently pressed mouse buttons
+
+ self._defaultDataMargins = (0.0, 0.0, 0.0, 0.0)
+
+ # Only activate autoreplot at the end
+ # This avoids errors when loaded in Qt designer
+ self._dirty = False
+ self._autoreplot = True
+
+ widget = self.getWidgetHandle()
+ if widget is not None:
+ self.setCentralWidget(widget)
+ else:
+ _logger.info("PlotWidget backend does not support widget")
+
+ self.setFocusPolicy(qt.Qt.StrongFocus)
+ self.setFocus(qt.Qt.OtherFocusReason)
+
+ # Set default limits
+ self.setGraphXLimits(0.0, 100.0)
+ self.setGraphYLimits(0.0, 100.0, axis="right")
+ self.setGraphYLimits(0.0, 100.0, axis="left")
+
+ # Sync backend colors with default ones
+ self._foregroundColorsUpdated()
+ self._backgroundColorsUpdated()
+
+ # selection handling
+ self.__selection = None
+
+ def __getBackendClass(self, backend):
+ """Returns backend class corresponding to backend.
+
+ If multiple backends are provided, the first available one is used.
+
+ :param Union[str,BackendBase,List[Union[str,BackendBase]]] backend:
+ The name of the backend or its class or an iterable of those.
+ :rtype: BackendBase
+ :raise ValueError: In case the backend is not supported
+ :raise RuntimeError: If a backend is not available
+ """
+ if backend is None:
+ backend = silx.config.DEFAULT_PLOT_BACKEND
+
+ if callable(backend):
+ return backend
+
+ elif isinstance(backend, str):
+ backend = backend.lower()
+ if backend in ("matplotlib", "mpl"):
+ try:
+ from .backends.BackendMatplotlib import (
+ BackendMatplotlibQt as backendClass,
+ )
+ except ImportError:
+ _logger.debug("Backtrace", exc_info=True)
+ raise RuntimeError("matplotlib backend is not available")
+
+ elif backend in ("gl", "opengl"):
+ from ..utils.glutils import isOpenGLAvailable
+
+ checkOpenGL = isOpenGLAvailable(version=(2, 1), runtimeCheck=False)
+ if not checkOpenGL:
+ _logger.debug("OpenGL check failed")
+ raise RuntimeError(
+ "OpenGL backend is not available: %s" % checkOpenGL.error
+ )
+
+ try:
+ from .backends.BackendOpenGL import BackendOpenGL as backendClass
+ except ImportError:
+ _logger.debug("Backtrace", exc_info=True)
+ raise RuntimeError("OpenGL backend is not available")
+
+ elif backend == "none":
+ from .backends.BackendBase import BackendBase as backendClass
+
+ else:
+ raise ValueError("Backend not supported %s" % backend)
+
+ return backendClass
+
+ elif isinstance(backend, (tuple, list)):
+ for b in backend:
+ try:
+ return self.__getBackendClass(b)
+ except RuntimeError:
+ pass
+ else: # No backend was found
+ raise RuntimeError("None of the request backends are available")
+
+ raise ValueError("Backend not supported %s" % str(backend))
+
+ def selection(self):
+ """Returns the selection hander"""
+ if self.__selection is None: # Lazy initialization
+ self.__selection = _PlotWidgetSelection(parent=self)
+ return self.__selection
+
+ def setBackend(self, backend):
+ """Set the backend to use for rendering.
+
+ Supported backends:
+
+ - 'matplotlib' and 'mpl': Matplotlib with Qt.
+ - 'opengl' and 'gl': OpenGL backend (requires PyOpenGL and OpenGL >= 2.1)
+ - 'none': No backend, to run headless for testing purpose.
+
+ :param Union[str,BackendBase,List[Union[str,BackendBase]]] backend:
+ The backend to use, in:
+ 'matplotlib' (default), 'mpl', 'opengl', 'gl', 'none',
+ a :class:`BackendBase.BackendBase` class.
+ If multiple backends are provided, the first available one is used.
+ :raises ValueError: Unsupported backend descriptor
+ :raises RuntimeError: Error while loading a backend
+ """
+ backend = self.__getBackendClass(backend)(self, self)
+
+ # First save state that is stored in the backend
+ xaxis = self.getXAxis()
+ xmin, xmax = xaxis.getLimits()
+ ymin, ymax = self.getYAxis(axis="left").getLimits()
+ y2min, y2max = self.getYAxis(axis="right").getLimits()
+ isKeepDataAspectRatio = self.isKeepDataAspectRatio()
+ xTimeZone = xaxis.getTimeZone()
+ isXAxisTimeSeries = xaxis.getTickMode() == TickMode.TIME_SERIES
+
+ isYAxisInverted = self.getYAxis().isInverted()
+
+ # Remove all items from previous backend
+ for item in self.getItems():
+ item._removeBackendRenderer(self._backend)
+
+ # Switch backend
+ self._backend = backend
+ widget = self._backend.getWidgetHandle()
+ self.setCentralWidget(widget)
+ if widget is None:
+ _logger.info("PlotWidget backend does not support widget")
+
+ # Mark as newly dirty
+ self._dirty = False
+ self._setDirtyPlot()
+
+ # Synchronize/restore state
+ self._foregroundColorsUpdated()
+ self._backgroundColorsUpdated()
+
+ self._backend.setGraphCursorShape(self.getGraphCursorShape())
+ crosshairConfig = self.getGraphCursor()
+ if crosshairConfig is None:
+ self._backend.setGraphCursor(False, "black", 1, "-")
+ else:
+ self._backend.setGraphCursor(True, *crosshairConfig)
+
+ self._backend.setGraphTitle(self.getGraphTitle())
+ self._backend.setGraphGrid(self.getGraphGrid())
+ if self.isAxesDisplayed():
+ self._backend.setAxesMargins(*self.getAxesMargins())
+ else:
+ self._backend.setAxesMargins(0.0, 0.0, 0.0, 0.0)
+
+ # Set axes
+ xaxis = self.getXAxis()
+ self._backend.setGraphXLabel(xaxis.getLabel())
+ self._backend.setXAxisTimeZone(xTimeZone)
+ self._backend.setXAxisTimeSeries(isXAxisTimeSeries)
+ self._backend.setXAxisLogarithmic(xaxis.getScale() == items.Axis.LOGARITHMIC)
+
+ for axis in ("left", "right"):
+ self._backend.setGraphYLabel(self.getYAxis(axis).getLabel(), axis)
+ self._backend.setYAxisInverted(isYAxisInverted)
+ self._backend.setYAxisLogarithmic(
+ self.getYAxis().getScale() == items.Axis.LOGARITHMIC
+ )
+
+ # Finally restore aspect ratio and limits
+ self._backend.setKeepDataAspectRatio(isKeepDataAspectRatio)
+ self.setLimits(xmin, xmax, ymin, ymax, y2min, y2max)
+
+ # Mark all items for update with new backend
+ for item in self.getItems():
+ item._updated()
+
+ self.sigBackendChanged.emit()
+
+ def getBackend(self):
+ """Returns the backend currently used by :class:`PlotWidget`.
+
+ :rtype: ~silx.gui.plot.backend.BackendBase.BackendBase
+ """
+ return self._backend
+
+ def _getDirtyPlot(self):
+ """Return the plot dirty flag.
+
+ If False, the plot has not changed since last replot.
+ If True, the full plot need to be redrawn.
+ If 'overlay', only the overlay has changed since last replot.
+
+ It can be accessed by backend to check the dirty state.
+
+ :return: False, True, 'overlay'
+ """
+ return self._dirty
+
+ # Default Qt context menu
+
+ def contextMenuEvent(self, event):
+ """Override QWidget.contextMenuEvent to implement the context menu"""
+ menu = qt.QMenu(self)
+ from .actions.control import ZoomBackAction # Avoid cyclic import
+
+ zoomBackAction = ZoomBackAction(plot=self, parent=menu)
+ menu.addAction(zoomBackAction)
+
+ mode = self.getInteractiveMode()
+ if "shape" in mode and mode["shape"] == "polygon":
+ from .actions.control import (
+ ClosePolygonInteractionAction,
+ ) # Avoid cyclic import
+
+ action = ClosePolygonInteractionAction(plot=self, parent=menu)
+ menu.addAction(action)
+
+ self._sigDefaultContextMenu.emit(menu)
+
+ # Make sure the plot is updated, especially when the plot is in
+ # draw interaction mode
+ menu.aboutToHide.connect(self.__simulateMouseMove)
+
+ menu.exec(event.globalPos())
+
+ def _setDirtyPlot(self, overlayOnly=False):
+ """Mark the plot as needing redraw
+
+ :param bool overlayOnly: True to redraw only the overlay,
+ False to redraw everything
+ """
+ wasDirty = self._dirty
+
+ if not self._dirty and overlayOnly:
+ self._dirty = "overlay"
+ else:
+ self._dirty = True
+
+ if self._autoreplot and not wasDirty and self.isVisible():
+ self._backend.postRedisplay()
+
+ def _foregroundColorsUpdated(self):
+ """Handle change of foreground/grid color"""
+ if self._gridColor is None:
+ gridColor = self._foregroundColor
+ else:
+ gridColor = self._gridColor
+ self._backend.setForegroundColors(self._foregroundColor, gridColor)
+ self._setDirtyPlot()
+
+ def getForegroundColor(self):
+ """Returns the RGBA colors used to display the foreground of this widget
+
+ :rtype: qt.QColor
+ """
+ return qt.QColor.fromRgbF(*self._foregroundColor)
+
+ def setForegroundColor(self, color):
+ """Set the foreground color of this widget.
+
+ :param Union[List[int],List[float],QColor] color:
+ The new RGB(A) color.
+ """
+ color = colors.rgba(color)
+ if self._foregroundColor != color:
+ self._foregroundColor = color
+ self._foregroundColorsUpdated()
+
+ def getGridColor(self):
+ """Returns the RGBA colors used to display the grid lines
+
+ An invalid QColor is returned if there is no grid color,
+ in which case the foreground color is used.
+
+ :rtype: qt.QColor
+ """
+ if self._gridColor is None:
+ return qt.QColor() # An invalid color
+ else:
+ return qt.QColor.fromRgbF(*self._gridColor)
+
+ def setGridColor(self, color):
+ """Set the grid lines color
+
+ :param Union[List[int],List[float],QColor,None] color:
+ The new RGB(A) color.
+ """
+ if isinstance(color, qt.QColor) and not color.isValid():
+ color = None
+ if color is not None:
+ color = colors.rgba(color)
+ if self._gridColor != color:
+ self._gridColor = color
+ self._foregroundColorsUpdated()
+
+ def _backgroundColorsUpdated(self):
+ """Handle change of background/data background color"""
+ if self._dataBackgroundColor is None:
+ dataBGColor = self._backgroundColor
+ else:
+ dataBGColor = self._dataBackgroundColor
+ self._backend.setBackgroundColors(self._backgroundColor, dataBGColor)
+ self._setDirtyPlot()
+
+ def getBackgroundColor(self):
+ """Returns the RGBA colors used to display the background of this widget.
+
+ :rtype: qt.QColor
+ """
+ return qt.QColor.fromRgbF(*self._backgroundColor)
+
+ def setBackgroundColor(self, color):
+ """Set the background color of this widget.
+
+ :param Union[List[int],List[float],QColor] color:
+ The new RGB(A) color.
+ """
+ color = colors.rgba(color)
+ if self._backgroundColor != color:
+ self._backgroundColor = color
+ self._backgroundColorsUpdated()
+
+ def getDataBackgroundColor(self):
+ """Returns the RGBA colors used to display the background of the plot
+ view displaying the data.
+
+ An invalid QColor is returned if there is no data background color.
+
+ :rtype: qt.QColor
+ """
+ if self._dataBackgroundColor is None:
+ # An invalid color
+ return qt.QColor()
+ else:
+ return qt.QColor.fromRgbF(*self._dataBackgroundColor)
+
+ def setDataBackgroundColor(self, color):
+ """Set the background color of the plot area.
+
+ Set to None or an invalid QColor to use the background color.
+
+ :param Union[List[int],List[float],QColor,None] color:
+ The new RGB(A) color.
+ """
+ if isinstance(color, qt.QColor) and not color.isValid():
+ color = None
+ if color is not None:
+ color = colors.rgba(color)
+ if self._dataBackgroundColor != color:
+ self._dataBackgroundColor = color
+ self._backgroundColorsUpdated()
+
+ dataBackgroundColor = qt.Property(
+ qt.QColor, getDataBackgroundColor, setDataBackgroundColor
+ )
+
+ backgroundColor = qt.Property(qt.QColor, getBackgroundColor, setBackgroundColor)
+
+ foregroundColor = qt.Property(qt.QColor, getForegroundColor, setForegroundColor)
+
+ gridColor = qt.Property(qt.QColor, getGridColor, setGridColor)
+
+ def showEvent(self, event):
+ if self._autoreplot and self._dirty:
+ self._backend.postRedisplay()
+ super(PlotWidget, self).showEvent(event)
+ self.sigVisibilityChanged.emit(True)
+
+ def hideEvent(self, event):
+ super(PlotWidget, self).hideEvent(event)
+ if qt.BINDING == "PySide6":
+ # Workaround RuntimeError: The SignalInstance object was already deleted
+ try:
+ self.sigVisibilityChanged.emit(False)
+ except RuntimeError as e:
+ _logger.error(f"Exception occured: {e}")
+ else:
+ self.sigVisibilityChanged.emit(False)
+
+ def _invalidateDataRange(self):
+ """
+ Notifies this PlotWidget instance that the range has changed
+ and will have to be recomputed.
+ """
+ self._dataRange = None
+
+ def _updateDataRange(self):
+ """
+ Recomputes the range of the data displayed on this PlotWidget.
+ """
+ xMin = yMinLeft = yMinRight = float("nan")
+ xMax = yMaxLeft = yMaxRight = float("nan")
+
+ for item in self.getItems():
+ if item.isVisible():
+ bounds = item.getBounds()
+ if bounds is not None:
+ with warnings.catch_warnings():
+ warnings.simplefilter("ignore", category=RuntimeWarning)
+ # Ignore All-NaN slice encountered
+ xMin = numpy.nanmin([xMin, bounds[0]])
+ xMax = numpy.nanmax([xMax, bounds[1]])
+ # Take care of right axis
+ if (
+ isinstance(item, items.YAxisMixIn)
+ and item.getYAxis() == "right"
+ ):
+ with warnings.catch_warnings():
+ warnings.simplefilter("ignore", category=RuntimeWarning)
+ # Ignore All-NaN slice encountered
+ yMinRight = numpy.nanmin([yMinRight, bounds[2]])
+ yMaxRight = numpy.nanmax([yMaxRight, bounds[3]])
+ else:
+ with warnings.catch_warnings():
+ warnings.simplefilter("ignore", category=RuntimeWarning)
+ # Ignore All-NaN slice encountered
+ yMinLeft = numpy.nanmin([yMinLeft, bounds[2]])
+ yMaxLeft = numpy.nanmax([yMaxLeft, bounds[3]])
+
+ def lGetRange(x, y):
+ return None if numpy.isnan(x) and numpy.isnan(y) else (x, y)
+
+ xRange = lGetRange(xMin, xMax)
+ yLeftRange = lGetRange(yMinLeft, yMaxLeft)
+ yRightRange = lGetRange(yMinRight, yMaxRight)
+
+ self._dataRange = _PlotDataRange(x=xRange, y=yLeftRange, yright=yRightRange)
+
+ def getDataRange(self):
+ """
+ Returns this PlotWidget's data range.
+
+ :return: a namedtuple with the following members :
+ x, y (left y axis), yright. Each member is a tuple (min, max)
+ or None if no data is associated with the axis.
+ :rtype: namedtuple
+ """
+ if self._dataRange is None:
+ self._updateDataRange()
+ return self._dataRange
+
+ # Content management
+
+ _KIND_TO_CLASSES = {
+ "curve": (items.Curve,),
+ "image": (items.ImageBase,),
+ "scatter": (items.Scatter,),
+ "marker": (items.MarkerBase,),
+ "item": (
+ items.Line,
+ items.Shape,
+ items.BoundingRect,
+ items.XAxisExtent,
+ items.YAxisExtent,
+ ),
+ "histogram": (items.Histogram,),
+ }
+ """Mapping kind to item classes of this kind"""
+
+ @classmethod
+ def _itemKind(cls, item):
+ """Returns the "kind" of a given item
+
+ :param Item item: The item get the kind
+ :rtype: str
+ """
+ for kind, itemClasses in cls._KIND_TO_CLASSES.items():
+ if isinstance(item, itemClasses):
+ return kind
+ return "other"
+
+ def _notifyContentChanged(self, item):
+ self.notify(
+ "contentChanged",
+ action="add",
+ kind=self._itemKind(item),
+ legend=item.getName(),
+ )
+
+ def _itemRequiresUpdate(self, item):
+ """Called by items in the plot for asynchronous update
+
+ :param Item item: The item that required update
+ """
+ assert item.getPlot() == self
+ # Put item at the end of the list
+ if item in self.__itemsToUpdate:
+ self.__itemsToUpdate.remove(item)
+ self.__itemsToUpdate.append(item)
+ self._setDirtyPlot(overlayOnly=item.isOverlay())
+
+ def addItem(self, item):
+ """Add an item to the plot content.
+
+ :param ~silx.gui.plot.items.Item item: The item to add.
+ :raises ValueError: If item is already in the plot.
+ """
+ if not isinstance(item, items.Item):
+ raise ValueError(f"argument must be a subclass of Item")
+
+ if item in self.getItems():
+ raise ValueError("Item already in the plot")
+
+ # Add item to plot
+ self.__items.append(item)
+ item._setPlot(self)
+ self._itemRequiresUpdate(item)
+ if isinstance(item, items.DATA_ITEMS):
+ self._invalidateDataRange() # TODO handle this automatically
+
+ self._notifyContentChanged(item)
+ self.sigItemAdded.emit(item)
+
+ def removeItem(self, item):
+ """Remove the item from the plot.
+
+ :param ~silx.gui.plot.items.Item item: Item to remove from the plot.
+ :raises ValueError: If item is not in the plot.
+ """
+ if not isinstance(item, items.Item):
+ raise ValueError("argument must be an Item")
+
+ if item not in self.getItems():
+ raise ValueError("Item not in the plot")
+
+ self.sigItemAboutToBeRemoved.emit(item)
+
+ kind = self._itemKind(item)
+
+ if kind in self._ACTIVE_ITEM_KINDS:
+ if self._getActiveItem(kind) == item:
+ # Reset active item
+ self._setActiveItem(kind, None)
+
+ # Remove item from plot
+ self.__items.remove(item)
+ if item in self.__itemsToUpdate:
+ self.__itemsToUpdate.remove(item)
+ if item.isVisible():
+ self._setDirtyPlot(overlayOnly=item.isOverlay())
+ if item.getBounds() is not None:
+ self._invalidateDataRange()
+ item._removeBackendRenderer(self._backend)
+ item._setPlot(None)
+
+ if kind == "curve" and not self.getAllCurves(just_legend=True, withhidden=True):
+ self._resetColorAndStyle()
+
+ self.sigItemRemoved.emit(item)
+
+ self.notify("contentChanged", action="remove", kind=kind, legend=item.getName())
+
+ def discardItem(self, item) -> bool:
+ """Remove the item from the plot.
+
+ Same as :meth:`removeItem` but do not raise an exception.
+
+ :param ~silx.gui.plot.items.Item item: Item to remove from the plot.
+ :returns: True if the item was present, False otherwise.
+ """
+ try:
+ self.removeItem(item)
+ except ValueError:
+ return False
+ else:
+ return True
+
+ def getItems(self):
+ """Returns the list of items in the plot
+
+ :rtype: List[silx.gui.plot.items.Item]
+ """
+ return tuple(self.__items)
+
+ @contextmanager
+ def _muteActiveItemChangedSignal(self):
+ self.__muteActiveItemChanged = True
+ yield
+ self.__muteActiveItemChanged = False
+
+ # Add
+
+ # add * input arguments management:
+ # If an arg is set, then use it.
+ # Else:
+ # If a curve with the same legend exists, then use its arg value
+ # Else, use a default value.
+ # Store used value.
+ # This value is used when curve is updated either internally or by user.
+
+ def addCurve(
+ self,
+ x,
+ y,
+ legend=None,
+ info=None,
+ replace=False,
+ color=None,
+ symbol=None,
+ linewidth=None,
+ linestyle=None,
+ xlabel=None,
+ ylabel=None,
+ yaxis=None,
+ xerror=None,
+ yerror=None,
+ z=None,
+ selectable=None,
+ fill=None,
+ resetzoom=True,
+ histogram=None,
+ copy=True,
+ baseline=None,
+ ):
+ """Add a 1D curve given by x an y to the graph.
+
+ Curves are uniquely identified by their legend.
+ To add multiple curves, call :meth:`addCurve` multiple times with
+ different legend argument.
+ To replace an existing curve, call :meth:`addCurve` with the
+ existing curve legend.
+ If you want to display the curve values as an histogram see the
+ histogram parameter or :meth:`addHistogram`.
+
+ When curve parameters are not provided, if a curve with the
+ same legend is displayed in the plot, its parameters are used.
+
+ :param numpy.ndarray x: The data corresponding to the x coordinates.
+ If you attempt to plot an histogram you can set edges values in x.
+ In this case len(x) = len(y) + 1.
+ If x contains datetime objects the XAxis tickMode is set to
+ TickMode.TIME_SERIES.
+ :param numpy.ndarray y: The data corresponding to the y coordinates
+ :param str legend: The legend to be associated to the curve (or None)
+ :param info: User-defined information associated to the curve
+ :param bool replace: True to delete already existing curves
+ (the default is False)
+ :param color: color(s) to be used
+ :type color: str ("#RRGGBB") or (npoints, 4) unsigned byte array or
+ one of the predefined color names defined in colors.py
+ :param str symbol: Symbol to be drawn at each (x, y) position::
+
+ - 'o' circle
+ - '.' point
+ - ',' pixel
+ - '+' cross
+ - 'x' x-cross
+ - 'd' diamond
+ - 's' square
+ - None (the default) to use default symbol
+
+ :param float linewidth: The width of the curve in pixels (Default: 1).
+ :param str linestyle: Type of line::
+
+ - ' ' no line
+ - '-' solid line
+ - '--' dashed line
+ - '-.' dash-dot line
+ - ':' dotted line
+ - None (the default) to use default line style
+
+ :param str xlabel: Label to show on the X axis when the curve is active
+ or None to keep default axis label.
+ :param str ylabel: Label to show on the Y axis when the curve is active
+ or None to keep default axis label.
+ :param str yaxis: The Y axis this curve is attached to.
+ Either 'left' (the default) or 'right'
+ :param xerror: Values with the uncertainties on the x values
+ :type xerror: A float, or a numpy.ndarray of float32.
+ If it is an array, it can either be a 1D array of
+ same length as the data or a 2D array with 2 rows
+ of same length as the data: row 0 for lower errors,
+ row 1 for upper errors.
+ :param yerror: Values with the uncertainties on the y values
+ :type yerror: A float, or a numpy.ndarray of float32. See xerror.
+ :param int z: Layer on which to draw the curve (default: 1)
+ This allows to control the overlay.
+ :param bool selectable: Indicate if the curve can be selected.
+ (Default: True)
+ :param bool fill: True to fill the curve, False otherwise (default).
+ :param bool resetzoom: True (the default) to reset the zoom.
+ :param str histogram: if not None then the curve will be draw as an
+ histogram. The step for each values of the curve can be set to the
+ left, center or right of the original x curve values.
+ If histogram is not None and len(x) == len(y)+1 then x is directly
+ take as edges of the histogram.
+ Type of histogram::
+
+ - None (default)
+ - 'left'
+ - 'right'
+ - 'center'
+ :param bool copy: True make a copy of the data (default),
+ False to use provided arrays.
+ :param baseline: curve baseline
+ :type: Union[None,float,numpy.ndarray]
+ :returns: The curve item
+ """
+ # This is an histogram, use addHistogram
+ if histogram is not None:
+ histo = self.addHistogram(
+ histogram=y,
+ edges=x,
+ legend=legend,
+ color=color,
+ fill=fill,
+ align=histogram,
+ copy=copy,
+ )
+
+ histo.setInfo(info)
+ if linewidth is not None:
+ histo.setLineWidth(linewidth)
+ if linestyle is not None:
+ histo.setLineStyle(linestyle)
+ if xlabel is not None:
+ _logger.warning("addCurve: Histogram does not support xlabel argument")
+ if ylabel is not None:
+ _logger.warning("addCurve: Histogram does not support ylabel argument")
+ if yaxis is not None:
+ histo.setYAxis(yaxis)
+ if z is not None:
+ histo.setZValue(z)
+ if selectable is not None:
+ _logger.warning(
+ "addCurve: Histogram does not support selectable argument"
+ )
+
+ return histo
+
+ legend = "Unnamed curve 1.1" if legend is None else str(legend)
+
+ if replace:
+ self._resetColorAndStyle()
+
+ # Create/Update curve object
+ curve = self.getCurve(legend)
+ mustBeAdded = curve is None
+ if curve is None:
+ # No previous curve, create a default one and add it to the plot
+ curve = items.Curve() if histogram is None else items.Histogram()
+ curve.setName(legend)
+ # Set default color, linestyle and symbol
+ default_color, default_linestyle = self._getColorAndStyle()
+ curve.setColor(default_color)
+ curve.setLineStyle(default_linestyle)
+ curve.setSymbol(self._defaultPlotPoints)
+ curve._setBaseline(baseline=baseline)
+
+ # Do not emit sigActiveCurveChanged,
+ # it will be sent once with _setActiveItem
+ with self._muteActiveItemChangedSignal():
+ # Override previous/default values with provided ones
+ curve.setInfo(info)
+ if color is not None:
+ curve.setColor(
+ colors.rgba(color, colors=self.getDefaultColors())
+ if isinstance(color, str)
+ else color
+ )
+ if symbol is not None:
+ curve.setSymbol(symbol)
+ if linewidth is not None:
+ curve.setLineWidth(linewidth)
+ if linestyle is not None:
+ curve.setLineStyle(linestyle)
+ if xlabel is not None:
+ curve._setXLabel(xlabel)
+ if ylabel is not None:
+ curve._setYLabel(ylabel)
+ if yaxis is not None:
+ curve.setYAxis(yaxis)
+ if z is not None:
+ curve.setZValue(z)
+ if selectable is not None:
+ curve._setSelectable(selectable)
+ if fill is not None:
+ curve.setFill(fill)
+
+ # Set curve data
+ # If errors not provided, reuse previous ones
+ # TODO: Issue if size of data change but not that of errors
+ if xerror is None:
+ xerror = curve.getXErrorData(copy=False)
+ if yerror is None:
+ yerror = curve.getYErrorData(copy=False)
+
+ # Convert x to timestamps so that the internal representation
+ # remains floating points. The user is expected to set the axis'
+ # tickMode to TickMode.TIME_SERIES and, if necessary, set the axis
+ # to the correct time zone.
+ if len(x) > 0 and isinstance(x[0], dt.datetime):
+ x = [timestamp(d) for d in x]
+
+ curve.setData(x, y, xerror, yerror, baseline=baseline, copy=copy)
+
+ if replace: # Then remove all other curves
+ for c in self.getAllCurves(withhidden=True):
+ if c is not curve:
+ self.removeItem(c)
+
+ if mustBeAdded:
+ self.addItem(curve)
+ else:
+ self._notifyContentChanged(curve)
+
+ if curve is self.getActiveCurve() or (
+ self.getActiveCurveSelectionMode() == "legacy"
+ and self.getActiveCurve() is None
+ and len(self.getAllCurves(just_legend=True, withhidden=False)) == 1
+ and curve.isVisible()
+ ):
+ self.setActiveCurve(curve)
+
+ if resetzoom:
+ # We ask for a zoom reset in order to handle the plot scaling
+ # if the user does not want that, autoscale of the different
+ # axes has to be set to off.
+ self.resetZoom()
+
+ return curve
+
+ def addHistogram(
+ self,
+ histogram,
+ edges,
+ legend=None,
+ color=None,
+ fill=None,
+ align="center",
+ resetzoom=True,
+ copy=True,
+ z=None,
+ baseline=None,
+ ):
+ """Add an histogram to the graph.
+
+ This is NOT computing the histogram, this method takes as parameter
+ already computed histogram values.
+
+ Histogram are uniquely identified by their legend.
+ To add multiple histograms, call :meth:`addHistogram` multiple times
+ with different legend argument.
+
+ When histogram parameters are not provided, if an histogram with the
+ same legend is displayed in the plot, its parameters are used.
+
+ :param numpy.ndarray histogram: The values of the histogram.
+ :param numpy.ndarray edges:
+ The bin edges of the histogram.
+ If histogram and edges have the same length, the bin edges
+ are computed according to the align parameter.
+ :param str legend:
+ The legend to be associated to the histogram (or None)
+ :param color: color to be used
+ :type color: str ("#RRGGBB") or RGB unsigned byte array or
+ one of the predefined color names defined in colors.py
+ :param bool fill: True to fill the curve, False otherwise (default).
+ :param str align:
+ In case histogram values and edges have the same length N,
+ the N+1 bin edges are computed according to the alignment in:
+ 'center' (default), 'left', 'right'.
+ :param bool resetzoom: True (the default) to reset the zoom.
+ :param bool copy: True make a copy of the data (default),
+ False to use provided arrays.
+ :param int z: Layer on which to draw the histogram
+ :param baseline: histogram baseline
+ :type: Union[None,float,numpy.ndarray]
+ :returns: The histogram item
+ """
+ legend = "Unnamed histogram" if legend is None else str(legend)
+
+ # Create/Update histogram object
+ histo = self.getHistogram(legend)
+ mustBeAdded = histo is None
+ if histo is None:
+ # No previous histogram, create a default one and
+ # add it to the plot
+ histo = items.Histogram()
+ histo.setName(legend)
+ histo.setColor(self._getColorAndStyle()[0])
+
+ # Override previous/default values with provided ones
+ if color is not None:
+ histo.setColor(
+ colors.rgba(color, colors=self.getDefaultColors())
+ if isinstance(color, str)
+ else color
+ )
+ if fill is not None:
+ histo.setFill(fill)
+ if z is not None:
+ histo.setZValue(z=z)
+
+ # Set histogram data
+ histo.setData(
+ histogram=histogram, edges=edges, baseline=baseline, align=align, copy=copy
+ )
+
+ if mustBeAdded:
+ self.addItem(histo)
+ else:
+ self._notifyContentChanged(histo)
+
+ if resetzoom:
+ # We ask for a zoom reset in order to handle the plot scaling
+ # if the user does not want that, autoscale of the different
+ # axes has to be set to off.
+ self.resetZoom()
+
+ return histo
+
+ def addImage(
+ self,
+ data,
+ legend=None,
+ info=None,
+ replace=False,
+ z=None,
+ selectable=None,
+ draggable=None,
+ colormap=None,
+ pixmap=None,
+ xlabel=None,
+ ylabel=None,
+ origin=None,
+ scale=None,
+ resetzoom=True,
+ copy=True,
+ ):
+ """Add a 2D dataset or an image to the plot.
+
+ It displays either an array of data using a colormap or a RGB(A) image.
+
+ Images are uniquely identified by their legend.
+ To add multiple images, call :meth:`addImage` multiple times with
+ different legend argument.
+ To replace/update an existing image, call :meth:`addImage` with the
+ existing image legend.
+
+ When image parameters are not provided, if an image with the
+ same legend is displayed in the plot, its parameters are used.
+
+ :param numpy.ndarray data:
+ (nrows, ncolumns) data or
+ (nrows, ncolumns, RGBA) ubyte array
+ Note: boolean values are converted to int8.
+ :param str legend: The legend to be associated to the image (or None)
+ :param info: User-defined information associated to the image
+ :param bool replace:
+ True to delete already existing images (Default: False).
+ :param int z: Layer on which to draw the image (default: 0)
+ This allows to control the overlay.
+ :param bool selectable: Indicate if the image can be selected.
+ (default: False)
+ :param bool draggable: Indicate if the image can be moved.
+ (default: False)
+ :param colormap: Colormap object to use (or None).
+ This is ignored if data is a RGB(A) image.
+ :type colormap: Union[~silx.gui.colors.Colormap, dict]
+ :param pixmap: Pixmap representation of the data (if any)
+ :type pixmap: (nrows, ncolumns, RGBA) ubyte array or None (default)
+ :param str xlabel: X axis label to show when this curve is active,
+ or None to keep default axis label.
+ :param str ylabel: Y axis label to show when this curve is active,
+ or None to keep default axis label.
+ :param origin: (origin X, origin Y) of the data.
+ It is possible to pass a single float if both
+ coordinates are equal.
+ Default: (0., 0.)
+ :type origin: float or 2-tuple of float
+ :param scale: (scale X, scale Y) of the data.
+ It is possible to pass a single float if both
+ coordinates are equal.
+ Default: (1., 1.)
+ :type scale: float or 2-tuple of float
+ :param bool resetzoom: True (the default) to reset the zoom.
+ :param bool copy: True make a copy of the data (default),
+ False to use provided arrays.
+ :returns: The image item
+ """
+ legend = "Unnamed Image 1.1" if legend is None else str(legend)
+
+ data = numpy.array(data, copy=False)
+ assert data.ndim in (2, 3)
+
+ image = self.getImage(legend)
+ if image is not None and image.getData(copy=False).ndim != data.ndim:
+ # Update a data image with RGBA image or the other way around:
+ # Remove previous image
+ # In this case, we don't retrieve defaults from the previous image
+ self.removeItem(image)
+ image = None
+
+ mustBeAdded = image is None
+ if image is None:
+ # No previous image, create a default one and add it to the plot
+ if data.ndim == 2:
+ image = items.ImageData()
+ image.setColormap(self.getDefaultColormap())
+ else:
+ image = items.ImageRgba()
+ image.setName(legend)
+
+ # Do not emit sigActiveImageChanged,
+ # it will be sent once with _setActiveItem
+ with self._muteActiveItemChangedSignal():
+ # Override previous/default values with provided ones
+ image.setInfo(info)
+ if origin is not None:
+ image.setOrigin(origin)
+ if scale is not None:
+ image.setScale(scale)
+ if z is not None:
+ image.setZValue(z)
+ if selectable is not None:
+ image._setSelectable(selectable)
+ if draggable is not None:
+ image._setDraggable(draggable)
+ if colormap is not None and isinstance(image, items.ColormapMixIn):
+ if isinstance(colormap, dict):
+ image.setColormap(Colormap._fromDict(colormap))
+ else:
+ assert isinstance(colormap, Colormap)
+ image.setColormap(colormap)
+ if xlabel is not None:
+ image._setXLabel(xlabel)
+ if ylabel is not None:
+ image._setYLabel(ylabel)
+
+ if data.ndim == 2:
+ image.setData(data, alternative=pixmap, copy=copy)
+ else: # RGB(A) image
+ if pixmap is not None:
+ _logger.warning(
+ "addImage: pixmap argument ignored when data is RGB(A)"
+ )
+ image.setData(data, copy=copy)
+
+ if replace:
+ for img in self.getAllImages():
+ if img is not image:
+ self.removeItem(img)
+
+ if mustBeAdded:
+ self.addItem(image)
+ else:
+ self._notifyContentChanged(image)
+
+ if len(self.getAllImages()) == 1 or image is self.getActiveImage():
+ self.setActiveImage(image)
+
+ if resetzoom:
+ # We ask for a zoom reset in order to handle the plot scaling
+ # if the user does not want that, autoscale of the different
+ # axes has to be set to off.
+ self.resetZoom()
+
+ return image
+
+ def addScatter(
+ self,
+ x,
+ y,
+ value,
+ legend=None,
+ colormap=None,
+ info=None,
+ symbol=None,
+ xerror=None,
+ yerror=None,
+ z=None,
+ copy=True,
+ ):
+ """Add a (x, y, value) scatter to the graph.
+
+ Scatters are uniquely identified by their legend.
+ To add multiple scatters, call :meth:`addScatter` multiple times with
+ different legend argument.
+ To replace/update an existing scatter, call :meth:`addScatter` with the
+ existing scatter legend.
+
+ When scatter parameters are not provided, if a scatter with the
+ same legend is displayed in the plot, its parameters are used.
+
+ :param numpy.ndarray x: The data corresponding to the x coordinates.
+ :param numpy.ndarray y: The data corresponding to the y coordinates
+ :param numpy.ndarray value: The data value associated with each point
+ :param str legend: The legend to be associated to the scatter (or None)
+ :param ~silx.gui.colors.Colormap colormap:
+ Colormap object to be used for the scatter (or None)
+ :param info: User-defined information associated to the curve
+ :param str symbol: Symbol to be drawn at each (x, y) position::
+
+ - 'o' circle
+ - '.' point
+ - ',' pixel
+ - '+' cross
+ - 'x' x-cross
+ - 'd' diamond
+ - 's' square
+ - None (the default) to use default symbol
+
+ :param xerror: Values with the uncertainties on the x values
+ :type xerror: A float, or a numpy.ndarray of float32.
+ If it is an array, it can either be a 1D array of
+ same length as the data or a 2D array with 2 rows
+ of same length as the data: row 0 for lower errors,
+ row 1 for upper errors.
+ :param yerror: Values with the uncertainties on the y values
+ :type yerror: A float, or a numpy.ndarray of float32. See xerror.
+ :param int z: Layer on which to draw the scatter (default: 1)
+ This allows to control the overlay.
+
+ :param bool copy: True make a copy of the data (default),
+ False to use provided arrays.
+ :returns: The scatter item
+ """
+ legend = "Unnamed scatter 1.1" if legend is None else str(legend)
+
+ # Create/Update curve object
+ scatter = self._getItem(kind="scatter", legend=legend)
+ mustBeAdded = scatter is None
+ if scatter is None:
+ # No previous scatter, create a default one and add it to the plot
+ scatter = items.Scatter()
+ scatter.setName(legend)
+ scatter.setColormap(self.getDefaultColormap())
+
+ # Do not emit sigActiveScatterChanged,
+ # it will be sent once with _setActiveItem
+ with self._muteActiveItemChangedSignal():
+ # Override previous/default values with provided ones
+ scatter.setInfo(info)
+ if symbol is not None:
+ scatter.setSymbol(symbol)
+ if z is not None:
+ scatter.setZValue(z)
+ if colormap is not None:
+ if isinstance(colormap, dict):
+ scatter.setColormap(Colormap._fromDict(colormap))
+ else:
+ assert isinstance(colormap, Colormap)
+ scatter.setColormap(colormap)
+
+ # Set scatter data
+ # If errors not provided, reuse previous ones
+ if xerror is None:
+ xerror = scatter.getXErrorData(copy=False)
+ if xerror is not None and len(xerror) != len(x):
+ xerror = None
+ if yerror is None:
+ yerror = scatter.getYErrorData(copy=False)
+ if yerror is not None and len(yerror) != len(y):
+ yerror = None
+
+ scatter.setData(x, y, value, xerror, yerror, copy=copy)
+
+ if mustBeAdded:
+ self.addItem(scatter)
+ else:
+ self._notifyContentChanged(scatter)
+
+ scatters = [
+ item
+ for item in self.getItems()
+ if isinstance(item, items.Scatter) and item.isVisible()
+ ]
+ if len(scatters) == 1 or scatter is self.getActiveScatter():
+ self.setActiveScatter(scatter)
+
+ return scatter
+
+ def addShape(
+ self,
+ xdata,
+ ydata,
+ legend=None,
+ info=None,
+ replace=False,
+ shape="polygon",
+ color="black",
+ fill=True,
+ overlay=False,
+ z=None,
+ linestyle="-",
+ linewidth=1.0,
+ linebgcolor="deprecated",
+ gapcolor=None,
+ ):
+ """Add an item (i.e. a shape) to the plot.
+
+ Items are uniquely identified by their legend.
+ To add multiple items, call :meth:`addItem` multiple times with
+ different legend argument.
+ To replace/update an existing item, call :meth:`addItem` with the
+ existing item legend.
+
+ :param numpy.ndarray xdata: The X coords of the points of the shape
+ :param numpy.ndarray ydata: The Y coords of the points of the shape
+ :param str legend: The legend to be associated to the item
+ :param info: User-defined information associated to the item
+ :param bool replace: True to delete already existing items
+ (the default is False)
+ :param str shape: Type of item to be drawn in
+ hline, polygon (the default), rectangle, vline,
+ polylines
+ :param str color: Color of the item, e.g., 'blue', 'b', '#FF0000'
+ (Default: 'black')
+ :param bool fill: True (the default) to fill the shape
+ :param bool overlay: True if item is an overlay (Default: False).
+ This allows for rendering optimization if this
+ item is changed often.
+ :param int z: Layer on which to draw the item (default: 2)
+ :param str linestyle: Style of the line.
+ Only relevant for line markers where X or Y is None.
+ Value in:
+
+ - ' ' no line
+ - '-' solid line
+ - '--' dashed line
+ - '-.' dash-dot line
+ - ':' dotted line
+ :param float linewidth: Width of the line.
+ Only relevant for line markers where X or Y is None.
+ :param str gapcolor: Gap color of the line, e.g., 'blue', 'b',
+ '#FF0000'. It is used to draw dotted line using a second color.
+ :returns: The shape item
+ """
+ # expected to receive the same parameters as the signal
+
+ legend = "Unnamed Item 1.1" if legend is None else str(legend)
+
+ z = int(z) if z is not None else 2
+
+ if replace:
+ self.remove(kind="item")
+ else:
+ self.remove(legend, kind="item")
+
+ item = items.Shape(shape)
+ item.setName(legend)
+ item.setInfo(info)
+ item.setColor(color)
+ item.setFill(fill)
+ item.setOverlay(overlay)
+ item.setZValue(z)
+ item.setPoints(numpy.array((xdata, ydata)).T)
+ item.setLineStyle(linestyle)
+ item.setLineWidth(linewidth)
+ if linebgcolor != "deprecated":
+ deprecated_warning(
+ type_="Argument",
+ name="linebgcolor",
+ replacement="gapcolor",
+ since_version="2.0.0",
+ )
+ gapcolor = linebgcolor if gapcolor is None else gapcolor
+ item.setLineGapColor(gapcolor)
+
+ self.addItem(item)
+
+ return item
+
+ def addXMarker(
+ self,
+ x,
+ legend=None,
+ text=None,
+ color=None,
+ selectable=False,
+ draggable=False,
+ constraint=None,
+ yaxis="left",
+ ):
+ """Add a vertical line marker to the plot.
+
+ Markers are uniquely identified by their legend.
+ As opposed to curves, images and items, two calls to
+ :meth:`addXMarker` without legend argument adds two markers with
+ different identifying legends.
+
+ :param x: Position of the marker on the X axis in data coordinates
+ :type x: Union[None, float]
+ :param str legend: Legend associated to the marker to identify it
+ :param str text: Text to display on the marker.
+ :param str color: Color of the marker, e.g., 'blue', 'b', '#FF0000'
+ (Default: 'black')
+ :param bool selectable: Indicate if the marker can be selected.
+ (default: False)
+ :param bool draggable: Indicate if the marker can be moved.
+ (default: False)
+ :param constraint: A function filtering marker displacement by
+ dragging operations or None for no filter.
+ This function is called each time a marker is
+ moved.
+ This parameter is only used if draggable is True.
+ :type constraint: None or a callable that takes the coordinates of
+ the current cursor position in the plot as input
+ and that returns the filtered coordinates.
+ :param str yaxis: The Y axis this marker belongs to in: 'left', 'right'
+ :return: The marker item
+ """
+ return self._addMarker(
+ x=x,
+ y=None,
+ legend=legend,
+ text=text,
+ color=color,
+ selectable=selectable,
+ draggable=draggable,
+ symbol=None,
+ constraint=constraint,
+ yaxis=yaxis,
+ )
+
+ def addYMarker(
+ self,
+ y,
+ legend=None,
+ text=None,
+ color=None,
+ selectable=False,
+ draggable=False,
+ constraint=None,
+ yaxis="left",
+ ):
+ """Add a horizontal line marker to the plot.
+
+ Markers are uniquely identified by their legend.
+ As opposed to curves, images and items, two calls to
+ :meth:`addYMarker` without legend argument adds two markers with
+ different identifying legends.
+
+ :param float y: Position of the marker on the Y axis in data
+ coordinates
+ :param str legend: Legend associated to the marker to identify it
+ :param str text: Text to display next to the marker.
+ :param str color: Color of the marker, e.g., 'blue', 'b', '#FF0000'
+ (Default: 'black')
+ :param bool selectable: Indicate if the marker can be selected.
+ (default: False)
+ :param bool draggable: Indicate if the marker can be moved.
+ (default: False)
+ :param constraint: A function filtering marker displacement by
+ dragging operations or None for no filter.
+ This function is called each time a marker is
+ moved.
+ This parameter is only used if draggable is True.
+ :type constraint: None or a callable that takes the coordinates of
+ the current cursor position in the plot as input
+ and that returns the filtered coordinates.
+ :param str yaxis: The Y axis this marker belongs to in: 'left', 'right'
+ :return: The marker item
+ """
+ return self._addMarker(
+ x=None,
+ y=y,
+ legend=legend,
+ text=text,
+ color=color,
+ selectable=selectable,
+ draggable=draggable,
+ symbol=None,
+ constraint=constraint,
+ yaxis=yaxis,
+ )
+
+ def addMarker(
+ self,
+ x,
+ y,
+ legend=None,
+ text=None,
+ color=None,
+ selectable=False,
+ draggable=False,
+ symbol="+",
+ constraint=None,
+ yaxis="left",
+ ):
+ """Add a point marker to the plot.
+
+ Markers are uniquely identified by their legend.
+ As opposed to curves, images and items, two calls to
+ :meth:`addMarker` without legend argument adds two markers with
+ different identifying legends.
+
+ :param float x: Position of the marker on the X axis in data
+ coordinates
+ :param float y: Position of the marker on the Y axis in data
+ coordinates
+ :param str legend: Legend associated to the marker to identify it
+ :param str text: Text to display next to the marker
+ :param str color: Color of the marker, e.g., 'blue', 'b', '#FF0000'
+ (Default: 'black')
+ :param bool selectable: Indicate if the marker can be selected.
+ (default: False)
+ :param bool draggable: Indicate if the marker can be moved.
+ (default: False)
+ :param str symbol: Symbol representing the marker in::
+
+ - 'o' circle
+ - '.' point
+ - ',' pixel
+ - '+' cross (the default)
+ - 'x' x-cross
+ - 'd' diamond
+ - 's' square
+
+ :param constraint: A function filtering marker displacement by
+ dragging operations or None for no filter.
+ This function is called each time a marker is
+ moved.
+ This parameter is only used if draggable is True.
+ :type constraint: None or a callable that takes the coordinates of
+ the current cursor position in the plot as input
+ and that returns the filtered coordinates.
+ :param str yaxis: The Y axis this marker belongs to in: 'left', 'right'
+ :return: The marker item
+ """
+ if x is None:
+ xmin, xmax = self._xAxis.getLimits()
+ x = 0.5 * (xmax + xmin)
+
+ if y is None:
+ ymin, ymax = self._yAxis.getLimits()
+ y = 0.5 * (ymax + ymin)
+
+ return self._addMarker(
+ x=x,
+ y=y,
+ legend=legend,
+ text=text,
+ color=color,
+ selectable=selectable,
+ draggable=draggable,
+ symbol=symbol,
+ constraint=constraint,
+ yaxis=yaxis,
+ )
+
+ def _addMarker(
+ self,
+ x,
+ y,
+ legend,
+ text,
+ color,
+ selectable,
+ draggable,
+ symbol,
+ constraint,
+ yaxis=None,
+ ):
+ """Common method for adding point, vline and hline marker.
+
+ See :meth:`addMarker` for argument documentation.
+ """
+ assert (x, y) != (None, None)
+
+ if legend is None: # Find an unused legend
+ markerLegends = [
+ item.getName()
+ for item in self.getItems()
+ if isinstance(item, items.MarkerBase)
+ ]
+ for index in itertools.count():
+ legend = "Unnamed Marker %d" % index
+ if legend not in markerLegends:
+ break # Keep this legend
+ legend = str(legend)
+
+ if x is None:
+ markerClass = items.YMarker
+ elif y is None:
+ markerClass = items.XMarker
+ else:
+ markerClass = items.Marker
+
+ # Create/Update marker object
+ marker = self._getMarker(legend)
+ if marker is not None and not isinstance(marker, markerClass):
+ _logger.warning(
+ "Adding marker with same legend" " but different type replaces it"
+ )
+ self.removeItem(marker)
+ marker = None
+
+ mustBeAdded = marker is None
+ if marker is None:
+ # No previous marker, create one
+ marker = markerClass()
+ marker.setName(legend)
+
+ if text is not None:
+ marker.setText(text)
+ if color is not None:
+ marker.setColor(color)
+ if selectable is not None:
+ marker._setSelectable(selectable)
+ if draggable is not None:
+ marker._setDraggable(draggable)
+ if symbol is not None:
+ marker.setSymbol(symbol)
+ marker.setYAxis(yaxis)
+
+ # TODO to improve, but this ensure constraint is applied
+ marker.setPosition(x, y)
+ if constraint is not None:
+ marker._setConstraint(constraint)
+ marker.setPosition(x, y)
+
+ if mustBeAdded:
+ self.addItem(marker)
+ else:
+ self._notifyContentChanged(marker)
+
+ return marker
+
+ # Hide
+
+ def isCurveHidden(self, legend):
+ """Returns True if the curve associated to legend is hidden, else False
+
+ :param str legend: The legend key identifying the curve
+ :return: True if the associated curve is hidden, False otherwise
+ """
+ curve = self._getItem("curve", legend)
+ return curve is not None and not curve.isVisible()
+
+ def hideCurve(self, legend, flag=True):
+ """Show/Hide the curve associated to legend.
+
+ Even when hidden, the curve is kept in the list of curves.
+
+ :param str legend: The legend associated to the curve to be hidden
+ :param bool flag: True (default) to hide the curve, False to show it
+ """
+ curve = self._getItem("curve", legend)
+ if curve is None:
+ _logger.warning("Curve not in plot: %s", legend)
+ return
+
+ isVisible = not flag
+ if isVisible != curve.isVisible():
+ curve.setVisible(isVisible)
+
+ # Remove
+
+ ITEM_KINDS = "curve", "image", "scatter", "item", "marker", "histogram"
+ """List of supported kind of items in the plot."""
+
+ _ACTIVE_ITEM_KINDS = "curve", "scatter", "image"
+ """List of item's kind which have a active item."""
+
+ def remove(
+ self,
+ legend: str | items.Item | None = None,
+ kind: str | Sequence[str] = ITEM_KINDS,
+ ):
+ """Remove one or all element(s) of the given legend and kind.
+
+ Examples:
+
+ - ``remove()`` clears the plot
+ - ``remove(kind='curve')`` removes all curves from the plot
+ - ``remove('myCurve', kind='curve')`` removes the curve with
+ legend 'myCurve' from the plot.
+ - ``remove('myImage, kind='image')`` removes the image with
+ legend 'myImage' from the plot.
+ - ``remove('myImage')`` removes elements (for instance curve, image,
+ item and marker) with legend 'myImage'.
+
+ :param legend:
+ The legend of the item to remove or the item itself.
+ If None all items of given kind are removed.
+ :param kind: The kind of items to remove from the plot.
+ See :attr:`ITEM_KINDS`.
+ By default, it removes all kind of elements.
+ """
+ if isinstance(legend, items.Item):
+ return self.removeItem(legend)
+
+ if kind == "all": # Replace all by tuple of all kinds
+ kind = self.ITEM_KINDS
+
+ if kind in self.ITEM_KINDS: # Kind is a str, make it a tuple
+ kind = (kind,)
+
+ for aKind in kind:
+ assert aKind in self.ITEM_KINDS
+
+ if legend is None: # This is a clear
+ # Clear each given kind
+ for aKind in kind:
+ for item in self.getItems():
+ if (
+ isinstance(item, self._KIND_TO_CLASSES[aKind])
+ and item.getPlot() is self
+ ): # Make sure item is still in the plot
+ self.removeItem(item)
+
+ else: # This is removing a single element
+ # Remove each given kind
+ for aKind in kind:
+ item = self._getItem(aKind, legend)
+ if item is not None:
+ self.removeItem(item)
+
+ def removeCurve(self, legend: str | items.Curve | None):
+ """Remove the curve associated to legend from the graph.
+
+ :param legend:
+ The legend of the curve to be deleted or the curve item
+ """
+ if legend is None:
+ return
+ if isinstance(legend, items.Item):
+ return self.removeItem(legend)
+ self.remove(legend, kind="curve")
+
+ def removeImage(self, legend: str | items.ImageBase | None):
+ """Remove the image associated to legend from the graph.
+
+ :param legend:
+ The legend of the image to be deleted or the image item
+ """
+ if legend is None:
+ return
+ if isinstance(legend, items.Item):
+ return self.removeItem(legend)
+ self.remove(legend, kind="image")
+
+ def removeMarker(self, legend: str | items.Marker | None):
+ """Remove the marker associated to legend from the graph.
+
+ :param legend:
+ The legend of the marker to be deleted or the marker item
+ """
+ if legend is None:
+ return
+ if isinstance(legend, items.Item):
+ return self.removeItem(legend)
+ self.remove(legend, kind="marker")
+
+ # Clear
+
+ def clear(self):
+ """Remove everything from the plot."""
+ for item in self.getItems():
+ if item.getPlot() is self: # Make sure item is still in the plot
+ self.removeItem(item)
+
+ def clearCurves(self):
+ """Remove all the curves from the plot."""
+ self.remove(kind="curve")
+
+ def clearImages(self):
+ """Remove all the images from the plot."""
+ self.remove(kind="image")
+
+ def clearItems(self):
+ """Remove all the items from the plot."""
+ self.remove(kind="item")
+
+ def clearMarkers(self):
+ """Remove all the markers from the plot."""
+ self.remove(kind="marker")
+
+ # Interaction
+
+ def getGraphCursor(self):
+ """Returns the state of the crosshair cursor.
+
+ See :meth:`setGraphCursor`.
+
+ :return: None if the crosshair cursor is not active,
+ else a tuple (color, linewidth, linestyle).
+ """
+ return self._cursorConfiguration
+
+ def setGraphCursor(self, flag=False, color="black", linewidth=1, linestyle="-"):
+ """Toggle the display of a crosshair cursor and set its attributes.
+
+ :param bool flag: Toggle the display of a crosshair cursor.
+ The crosshair cursor is hidden by default.
+ :param color: The color to use for the crosshair.
+ :type color: A string (either a predefined color name in colors.py
+ or "#RRGGBB")) or a 4 columns unsigned byte array
+ (Default: black).
+ :param int linewidth: The width of the lines of the crosshair
+ (Default: 1).
+ :param str linestyle: Type of line::
+
+ - ' ' no line
+ - '-' solid line (the default)
+ - '--' dashed line
+ - '-.' dash-dot line
+ - ':' dotted line
+ """
+ if flag:
+ self._cursorConfiguration = color, linewidth, linestyle
+ else:
+ self._cursorConfiguration = None
+
+ self._backend.setGraphCursor(
+ flag=flag, color=color, linewidth=linewidth, linestyle=linestyle
+ )
+ self._setDirtyPlot()
+ self.notify("setGraphCursor", state=self._cursorConfiguration is not None)
+
+ def pan(self, direction, factor=0.1):
+ """Pan the graph in the given direction by the given factor.
+
+ Warning: Pan of right Y axis not implemented!
+
+ :param str direction: One of 'up', 'down', 'left', 'right'.
+ :param float factor: Proportion of the range used to pan the graph.
+ Must be strictly positive.
+ """
+ assert direction in ("up", "down", "left", "right")
+ assert factor > 0.0
+
+ if direction in ("left", "right"):
+ xFactor = factor if direction == "right" else -factor
+ xMin, xMax = self._xAxis.getLimits()
+
+ xMin, xMax = _utils.applyPan(
+ xMin, xMax, xFactor, self._xAxis.getScale() == self._xAxis.LOGARITHMIC
+ )
+ self._xAxis.setLimits(xMin, xMax)
+
+ else: # direction in ('up', 'down')
+ sign = -1.0 if self._yAxis.isInverted() else 1.0
+ yFactor = sign * (factor if direction == "up" else -factor)
+ yMin, yMax = self._yAxis.getLimits()
+ yIsLog = self._yAxis.getScale() == self._yAxis.LOGARITHMIC
+
+ yMin, yMax = _utils.applyPan(yMin, yMax, yFactor, yIsLog)
+ self._yAxis.setLimits(yMin, yMax)
+
+ y2Min, y2Max = self._yRightAxis.getLimits()
+
+ y2Min, y2Max = _utils.applyPan(y2Min, y2Max, yFactor, yIsLog)
+ self._yRightAxis.setLimits(y2Min, y2Max)
+
+ # Active Curve/Image
+
+ def isActiveCurveHandling(self):
+ """Returns True if active curve selection is enabled.
+
+ :rtype: bool
+ """
+ return self.getActiveCurveSelectionMode() != "none"
+
+ def setActiveCurveHandling(self, flag=True):
+ """Enable/Disable active curve selection.
+
+ :param bool flag: True to enable 'atmostone' active curve selection,
+ False to disable active curve selection.
+ """
+ self.setActiveCurveSelectionMode("atmostone" if flag else "none")
+
+ def getActiveCurveStyle(self):
+ """Returns the current style applied to active curve
+
+ :rtype: CurveStyle
+ """
+ return self._activeCurveStyle
+
+ def setActiveCurveStyle(
+ self, color=None, linewidth=None, linestyle=None, symbol=None, symbolsize=None
+ ):
+ """Set the style of active curve
+
+ :param color: Color
+ :param Union[str,None] linestyle: Style of the line
+ :param Union[float,None] linewidth: Width of the line
+ :param Union[str,None] symbol: Symbol of the markers
+ :param Union[float,None] symbolsize: Size of the symbols
+ """
+ self._activeCurveStyle = CurveStyle(
+ color=color,
+ linewidth=linewidth,
+ linestyle=linestyle,
+ symbol=symbol,
+ symbolsize=symbolsize,
+ )
+ curve = self.getActiveCurve()
+ if curve is not None:
+ curve.setHighlightedStyle(self.getActiveCurveStyle())
+
+ def getActiveCurve(self, just_legend=False):
+ """Return the currently active curve.
+
+ It returns None in case of not having an active curve.
+
+ :param bool just_legend: True to get the legend of the curve,
+ False (the default) to get the curve data
+ and info.
+ :return: Active curve's legend or corresponding
+ :class:`.items.Curve`
+ :rtype: str or :class:`.items.Curve` or None
+ """
+ if not self.isActiveCurveHandling():
+ return None
+
+ return self._getActiveItem(kind="curve", just_legend=just_legend)
+
+ def setActiveCurve(self, legend):
+ """Make the curve associated to legend the active curve.
+
+ :param legend: The legend associated to the curve
+ or None to have no active curve.
+ :type legend: str or None
+ """
+ if not self.isActiveCurveHandling():
+ return
+ if legend is None and self.getActiveCurveSelectionMode() == "legacy":
+ _logger.info(
+ "setActiveCurve(None) ignored due to active curve selection mode"
+ )
+ return
+
+ return self._setActiveItem(kind="curve", item=legend)
+
+ def setActiveCurveSelectionMode(self, mode):
+ """Sets the current selection mode.
+
+ :param str mode: The active curve selection mode to use.
+ It can be: 'legacy', 'atmostone' or 'none'.
+ """
+ assert mode in ("legacy", "atmostone", "none")
+
+ if mode != self._activeCurveSelectionMode:
+ self._activeCurveSelectionMode = mode
+ if mode == "none": # reset active curve
+ self._setActiveItem(kind="curve", item=None)
+
+ elif mode == "legacy" and self.getActiveCurve() is None:
+ # Select an active curve
+ curves = self.getAllCurves(just_legend=False, withhidden=False)
+ if len(curves) == 1:
+ if curves[0].isVisible():
+ self.setActiveCurve(curves[0].getName())
+
+ def getActiveCurveSelectionMode(self):
+ """Returns the current selection mode.
+
+ It can be "atmostone", "legacy" or "none".
+
+ :rtype: str
+ """
+ return self._activeCurveSelectionMode
+
+ def getActiveImage(self, just_legend=False):
+ """Returns the currently active image.
+
+ It returns None in case of not having an active image.
+
+ :param bool just_legend: True to get the legend of the image,
+ False (the default) to get the image data
+ and info.
+ :return: Active image's legend or corresponding image object
+ :rtype: str, :class:`.items.ImageData`, :class:`.items.ImageRgba`
+ or None
+ """
+ return self._getActiveItem(kind="image", just_legend=just_legend)
+
+ def setActiveImage(self, legend):
+ """Make the image associated to legend the active image.
+
+ :param str legend: The legend associated to the image
+ or None to have no active image.
+ """
+ return self._setActiveItem(kind="image", item=legend)
+
+ def getActiveScatter(self, just_legend=False):
+ """Returns the currently active scatter.
+
+ It returns None in case of not having an active scatter.
+
+ :param bool just_legend: True to get the legend of the scatter,
+ False (the default) to get the scatter data
+ and info.
+ :return: Active scatter's legend or corresponding scatter object
+ :rtype: str, :class:`.items.Scatter` or None
+ """
+ return self._getActiveItem(kind="scatter", just_legend=just_legend)
+
+ def setActiveScatter(self, legend):
+ """Make the scatter associated to legend the active scatter.
+
+ :param str legend: The legend associated to the scatter
+ or None to have no active scatter.
+ """
+ return self._setActiveItem(kind="scatter", item=legend)
+
+ def _getActiveItem(
+ self,
+ kind: str | None,
+ just_legend: bool = False,
+ ) -> items.Curve | items.Scatter | items.ImageBase | None:
+ """Return the currently active item of given kind if any.
+
+ :param kind: Type of item: 'curve', 'scatter' or 'image'
+ :param just_legend:
+ True to get the item's legend, False (the default) to get the item
+ """
+ assert kind in self._ACTIVE_ITEM_KINDS
+ item = self.__activeItems[kind]
+ if item is not None and just_legend:
+ return item.getName()
+ return item
+
+ def _setActiveItem(
+ self,
+ kind: str,
+ item: items.Curve | items.ImageBase | items.Scatter | str | None,
+ ) -> str | None:
+ """Make the given item active.
+
+ Note: There is one active item per "kind" of item.
+ """
+ assert kind in self._ACTIVE_ITEM_KINDS
+
+ if item is None:
+ legend = None
+ elif isinstance(item, items.Item):
+ legend = item.getName()
+ else:
+ legend = str(item)
+ item = self._getItem(kind, legend)
+ if item is None:
+ _logger.warning("This %s does not exist: %s", kind, legend)
+
+ oldActiveItem = self._getActiveItem(kind=kind)
+
+ if oldActiveItem is None and item is None:
+ return None
+
+ if oldActiveItem is not None:
+ # Stop listening previous active item
+ oldActiveItem.sigItemChanged.disconnect(self._activeItemChanged)
+ # Curve specific: Reset highlight of previous active curve
+ if kind == "curve":
+ oldActiveItem.setHighlighted(False)
+
+ self.__activeItems[kind] = item
+
+ xLabel = None
+ yLabel = None
+ yRightLabel = None
+
+ if item is not None:
+ # Curve specific: handle highlight
+ if kind == "curve":
+ item.setHighlightedStyle(self.getActiveCurveStyle())
+ item.setHighlighted(True)
+
+ if isinstance(item, items.LabelsMixIn):
+ xLabel = item.getXLabel()
+ if isinstance(item, items.YAxisMixIn) and item.getYAxis() == "right":
+ yRightLabel = item.getYLabel()
+ else:
+ yLabel = item.getYLabel()
+
+ # Start listening new active item
+ item.sigItemChanged.connect(self._activeItemChanged)
+
+ # Store current labels and update plot
+ self._xAxis._setCurrentLabel(xLabel)
+ self._yAxis._setCurrentLabel(yLabel)
+ self._yRightAxis._setCurrentLabel(yRightLabel)
+
+ self._setDirtyPlot()
+
+ self.notify(
+ f"active{kind.capitalize()}Changed",
+ updated=oldActiveItem is not item,
+ previous=None if oldActiveItem is None else oldActiveItem.getName(),
+ legend=legend,
+ )
+ return legend
+
+ def _activeItemChanged(self, type_):
+ """Listen for active item changed signal and broadcast signal
+
+ :param item.ItemChangedType type_: The type of item change
+ """
+ if not self.__muteActiveItemChanged:
+ item = self.sender()
+ if item is not None:
+ kind = self._itemKind(item)
+ self.notify(
+ "active" + kind[0].upper() + kind[1:] + "Changed",
+ updated=False,
+ previous=item.getName(),
+ legend=item.getName(),
+ )
+
+ # Getters
+
+ def getAllCurves(self, just_legend=False, withhidden=False):
+ """Returns all curves legend or info and data.
+
+ It returns an empty list in case of not having any curve.
+
+ If just_legend is False, it returns a list of :class:`items.Curve`
+ objects describing the curves.
+ If just_legend is True, it returns a list of curves' legend.
+
+ :param bool just_legend: True to get the legend of the curves,
+ False (the default) to get the curves' data
+ and info.
+ :param bool withhidden: False (default) to skip hidden curves.
+ :return: list of curves' legend or :class:`.items.Curve`
+ :rtype: list of str or list of :class:`.items.Curve`
+ """
+ curves = [
+ item
+ for item in self.getItems()
+ if isinstance(item, items.Curve) and (withhidden or item.isVisible())
+ ]
+ return [curve.getName() for curve in curves] if just_legend else curves
+
+ def getCurve(self, legend: str | items.Curve | None = None) -> items.Curve:
+ """Get the object describing a specific curve.
+
+ It returns None in case no matching curve is found.
+
+ :param legend:
+ The legend identifying the curve.
+ If not provided or None (the default), the active curve is returned
+ or if there is no active curve, the latest updated curve that is
+ not hidden is returned if there are curves in the plot.
+ :return: None or :class:`.items.Curve` object
+ """
+ if isinstance(legend, items.Curve):
+ _logger.warning("getCurve call not needed: legend is already an item")
+ return legend
+ return self._getItem(kind="curve", legend=legend)
+
+ def getAllImages(self, just_legend=False):
+ """Returns all images legend or objects.
+
+ It returns an empty list in case of not having any image.
+
+ If just_legend is False, it returns a list of :class:`items.ImageBase`
+ objects describing the images.
+ If just_legend is True, it returns a list of legends.
+
+ :param bool just_legend: True to get the legend of the images,
+ False (the default) to get the images'
+ object.
+ :return: list of images' legend or :class:`.items.ImageBase`
+ :rtype: list of str or list of :class:`.items.ImageBase`
+ """
+ images = [item for item in self.getItems() if isinstance(item, items.ImageBase)]
+ return [image.getName() for image in images] if just_legend else images
+
+ def getImage(self, legend: str | items.ImageBase | None = None) -> items.ImageBase:
+ """Get the object describing a specific image.
+
+ It returns None in case no matching image is found.
+
+ :param legend:
+ The legend identifying the image.
+ If not provided or None (the default), the active image is returned
+ or if there is no active image, the latest updated image
+ is returned if there are images in the plot.
+ :return: None or :class:`.items.ImageBase` object
+ """
+ if isinstance(legend, items.ImageBase):
+ _logger.warning("getImage call not needed: legend is already an item")
+ return legend
+ return self._getItem(kind="image", legend=legend)
+
+ def getScatter(self, legend: str | items.Scatter | None = None) -> items.Scatter:
+ """Get the object describing a specific scatter.
+
+ It returns None in case no matching scatter is found.
+
+ :param legend:
+ The legend identifying the scatter.
+ If not provided or None (the default), the active scatter is
+ returned or if there is no active scatter, the latest updated
+ scatter is returned if there are scatters in the plot.
+ :return: None or :class:`.items.Scatter` object
+ """
+ if isinstance(legend, items.Scatter):
+ _logger.warning("getScatter call not needed: legend is already an item")
+ return legend
+ return self._getItem(kind="scatter", legend=legend)
+
+ def getHistogram(
+ self, legend: str | items.Histogram | None = None
+ ) -> items.Histogram:
+ """Get the object describing a specific histogram.
+
+ It returns None in case no matching histogram is found.
+
+ :param legend:
+ The legend identifying the histogram.
+ If not provided or None (the default), the latest updated scatter
+ is returned if there are histograms in the plot.
+ :return: None or :class:`.items.Histogram` object
+ """
+ if isinstance(legend, items.Histogram):
+ _logger.warning("getHistogram call not needed: legend is already an item")
+ return legend
+ return self._getItem(kind="histogram", legend=legend)
+
+ def _getItem(self, kind, legend=None) -> items.Item:
+ """Get an item from the plot: either an image or a curve.
+
+ Returns None if no match found.
+
+ :param str kind: Type of item to retrieve,
+ see :attr:`ITEM_KINDS`.
+ :param str legend: Legend of the item or
+ None to get active or last item
+ :return: Object describing the item or None
+ """
+ if isinstance(legend, items.Item):
+ _logger.warning("_getItem call not needed: legend is already an item")
+ return legend
+
+ assert kind in self.ITEM_KINDS
+
+ if legend is not None:
+ for item in self.getItems():
+ if item.getName() == legend and kind == self._itemKind(item):
+ return item
+ return None # No item found
+
+ if kind in self._ACTIVE_ITEM_KINDS:
+ item = self._getActiveItem(kind=kind)
+ if item is not None: # Return active item if available
+ return item
+ # Return last visible item if any
+ itemClasses = self._KIND_TO_CLASSES[kind]
+ allItems = [
+ item
+ for item in self.getItems()
+ if isinstance(item, itemClasses) and item.isVisible()
+ ]
+ return allItems[-1] if allItems else None
+
+ # Limits
+
+ def _notifyLimitsChanged(self, emitSignal=True):
+ """Send an event when plot area limits are changed."""
+ xRange = self._xAxis.getLimits()
+ yRange = self._yAxis.getLimits()
+ y2Range = self._yRightAxis.getLimits()
+ if emitSignal:
+ axes = self.getXAxis(), self.getYAxis(), self.getYAxis(axis="right")
+ ranges = xRange, yRange, y2Range
+ for axis, limits in zip(axes, ranges):
+ axis.sigLimitsChanged.emit(*limits)
+ event = PlotEvents.prepareLimitsChangedSignal(
+ id(self.getWidgetHandle()), xRange, yRange, y2Range
+ )
+ self.notify(**event)
+
+ def getLimitsHistory(self):
+ """Returns the object handling the history of limits of the plot"""
+ return self._limitsHistory
+
+ def getGraphXLimits(self):
+ """Get the graph X (bottom) limits.
+
+ :return: Minimum and maximum values of the X axis
+ """
+ return self._backend.getGraphXLimits()
+
+ def setGraphXLimits(self, xmin, xmax):
+ """Set the graph X (bottom) limits.
+
+ :param float xmin: minimum bottom axis value
+ :param float xmax: maximum bottom axis value
+ """
+ self._xAxis.setLimits(xmin, xmax)
+
+ def getGraphYLimits(self, axis="left"):
+ """Get the graph Y limits.
+
+ :param str axis: The axis for which to get the limits:
+ Either 'left' or 'right'
+ :return: Minimum and maximum values of the X axis
+ """
+ assert axis in ("left", "right")
+ yAxis = self._yAxis if axis == "left" else self._yRightAxis
+ return yAxis.getLimits()
+
+ def setGraphYLimits(self, ymin, ymax, axis="left"):
+ """Set the graph Y limits.
+
+ :param float ymin: minimum bottom axis value
+ :param float ymax: maximum bottom axis value
+ :param str axis: The axis for which to get the limits:
+ Either 'left' or 'right'
+ """
+ assert axis in ("left", "right")
+ yAxis = self._yAxis if axis == "left" else self._yRightAxis
+ return yAxis.setLimits(ymin, ymax)
+
+ def setLimits(
+ self,
+ xmin: float,
+ xmax: float,
+ ymin: float,
+ ymax: float,
+ y2min: Optional[float] = None,
+ y2max: Optional[float] = None,
+ margins: Union[bool, tuple[float, float, float, float]] = False,
+ ):
+ """Set the limits of the X and Y axes at once.
+
+ If y2min or y2max is None, the right Y axis limits are not updated.
+
+ :param xmin: minimum bottom axis value
+ :param xmax: maximum bottom axis value
+ :param ymin: minimum left axis value
+ :param ymax: maximum left axis value
+ :param y2min: minimum right axis value or None (the default)
+ :param y2max: maximum right axis value or None (the default)
+ :param margins:
+ Data margins to add to the limits or a boolean telling
+ whether or not to add margins from :meth:`getDataMargins`.
+ """
+ limits = [
+ *self.getXAxis()._checkLimits(xmin, xmax),
+ *self.getYAxis()._checkLimits(ymin, ymax),
+ ]
+
+ # Only consider y2 axis if both limits are not None
+ if None not in (y2min, y2max):
+ limits.extend(self.getYAxis(axis="right")._checkLimits(y2min, y2max))
+
+ if margins: # Add margins around limits inside the plot area
+ limits = list(
+ _utils.addMarginsToLimits(
+ self.getDataMargins() if margins is True else margins,
+ self.getXAxis()._isLogarithmic(),
+ self.getYAxis()._isLogarithmic(),
+ *limits,
+ )
+ )
+
+ if self.isKeepDataAspectRatio():
+ # Use limits with margins to keep ratio
+ xmin, xmax, ymin, ymax = limits[:4]
+
+ # Compute bbox wth figure aspect ratio
+ plotWidth, plotHeight = self.getPlotBoundsInPixels()[2:]
+ if plotWidth > 0 and plotHeight > 0:
+ plotRatio = plotHeight / plotWidth
+ dataRatio = (ymax - ymin) / (xmax - xmin)
+ if dataRatio < plotRatio:
+ # Increase y range
+ ycenter = 0.5 * (ymax + ymin)
+ yrange = (xmax - xmin) * plotRatio
+ limits[2] = ycenter - 0.5 * yrange
+ limits[3] = ycenter + 0.5 * yrange
+
+ elif dataRatio > plotRatio:
+ # Increase x range
+ xcenter = 0.5 * (xmax + xmin)
+ xrange_ = (ymax - ymin) / plotRatio
+ limits[0] = xcenter - 0.5 * xrange_
+ limits[1] = xcenter + 0.5 * xrange_
+
+ if self._viewConstrains:
+ limits[:4] = self._viewConstrains.normalize(*limits[:4])
+
+ self._backend.setLimits(*limits)
+ self._setDirtyPlot()
+ self._notifyLimitsChanged()
+
+ def _getViewConstraints(self):
+ """Return the plot object managing constaints on the plot view.
+
+ :rtype: ViewConstraints
+ """
+ if self._viewConstrains is None:
+ self._viewConstrains = ViewConstraints()
+ return self._viewConstrains
+
+ # Title and labels
+
+ def getGraphTitle(self):
+ """Return the plot main title as a str."""
+ return self._graphTitle
+
+ def setGraphTitle(self, title=""):
+ """Set the plot main title.
+
+ :param str title: Main title of the plot (default: '')
+ """
+ self._graphTitle = str(title)
+ self._backend.setGraphTitle(title)
+ self._setDirtyPlot()
+
+ def getGraphXLabel(self):
+ """Return the current X axis label as a str."""
+ return self._xAxis.getLabel()
+
+ def setGraphXLabel(self, label="X"):
+ """Set the plot X axis label.
+
+ The provided label can be temporarily replaced by the X label of the
+ active curve if any.
+
+ :param str label: The X axis label (default: 'X')
+ """
+ self._xAxis.setLabel(label)
+
+ def getGraphYLabel(self, axis="left"):
+ """Return the current Y axis label as a str.
+
+ :param str axis: The Y axis for which to get the label (left or right)
+ """
+ assert axis in ("left", "right")
+ yAxis = self._yAxis if axis == "left" else self._yRightAxis
+ return yAxis.getLabel()
+
+ def setGraphYLabel(self, label="Y", axis="left"):
+ """Set the plot Y axis label.
+
+ The provided label can be temporarily replaced by the Y label of the
+ active curve if any.
+
+ :param str label: The Y axis label (default: 'Y')
+ :param str axis: The Y axis for which to set the label (left or right)
+ """
+ assert axis in ("left", "right")
+ yAxis = self._yAxis if axis == "left" else self._yRightAxis
+ return yAxis.setLabel(label)
+
+ # Axes
+
+ def getXAxis(self):
+ """Returns the X axis
+
+ .. versionadded:: 0.6
+
+ :rtype: :class:`.items.Axis`
+ """
+ return self._xAxis
+
+ def getYAxis(self, axis="left"):
+ """Returns an Y axis
+
+ .. versionadded:: 0.6
+
+ :param str axis: The Y axis to return
+ ('left' or 'right').
+ :rtype: :class:`.items.Axis`
+ """
+ assert axis in ["left", "right"]
+ return self._yAxis if axis == "left" else self._yRightAxis
+
+ def setAxesDisplayed(self, displayed: bool):
+ """Display or not the axes.
+
+ :param bool displayed: If `True` axes are displayed. If `False` axes
+ are not anymore visible and the margin used for them is removed.
+ """
+ if displayed != self.__axesDisplayed:
+ self.__axesDisplayed = displayed
+ if displayed:
+ self._backend.setAxesMargins(*self.__axesMargins)
+ else:
+ self._backend.setAxesMargins(0.0, 0.0, 0.0, 0.0)
+ self._setDirtyPlot()
+ self._sigAxesVisibilityChanged.emit(displayed)
+
+ def isAxesDisplayed(self) -> bool:
+ """Returns whether or not axes are currently displayed
+
+ :rtype: bool
+ """
+ return self.__axesDisplayed
+
+ def setAxesMargins(self, left: float, top: float, right: float, bottom: float):
+ """Set ratios of margins surrounding data plot area.
+
+ All ratios must be within [0., 1.].
+ Sums of ratios of opposed side must be < 1.
+
+ :param float left: Left-side margin ratio.
+ :param float top: Top margin ratio
+ :param float right: Right-side margin ratio
+ :param float bottom: Bottom margin ratio
+ :raises ValueError:
+ """
+ for value in (left, top, right, bottom):
+ if value < 0.0 or value > 1.0:
+ raise ValueError("Margin ratios must be within [0., 1.]")
+ if left + right >= 1.0 or top + bottom >= 1.0:
+ raise ValueError("Sum of ratios of opposed sides >= 1")
+ margins = left, top, right, bottom
+
+ if margins != self.__axesMargins:
+ self.__axesMargins = margins
+ if self.isAxesDisplayed(): # Only apply if axes are displayed
+ self._backend.setAxesMargins(*margins)
+ self._setDirtyPlot()
+
+ def getAxesMargins(self):
+ """Returns ratio of margins surrounding data plot area.
+
+ :return: (left, top, right, bottom)
+ :rtype: List[float]
+ """
+ return self.__axesMargins
+
+ def setYAxisInverted(self, flag=True):
+ """Set the Y axis orientation.
+
+ :param bool flag: True for Y axis going from top to bottom,
+ False for Y axis going from bottom to top
+ """
+ self._yAxis.setInverted(flag)
+
+ def isYAxisInverted(self):
+ """Return True if Y axis goes from top to bottom, False otherwise."""
+ return self._yAxis.isInverted()
+
+ def isXAxisLogarithmic(self):
+ """Return True if X axis scale is logarithmic, False if linear."""
+ return self._xAxis._isLogarithmic()
+
+ def setXAxisLogarithmic(self, flag):
+ """Set the bottom X axis scale (either linear or logarithmic).
+
+ :param bool flag: True to use a logarithmic scale, False for linear.
+ """
+ self._xAxis._setLogarithmic(flag)
+
+ def isYAxisLogarithmic(self):
+ """Return True if Y axis scale is logarithmic, False if linear."""
+ return self._yAxis._isLogarithmic()
+
+ def setYAxisLogarithmic(self, flag):
+ """Set the Y axes scale (either linear or logarithmic).
+
+ :param bool flag: True to use a logarithmic scale, False for linear.
+ """
+ self._yAxis._setLogarithmic(flag)
+
+ def isXAxisAutoScale(self):
+ """Return True if X axis is automatically adjusting its limits."""
+ return self._xAxis.isAutoScale()
+
+ def setXAxisAutoScale(self, flag=True):
+ """Set the X axis limits adjusting behavior of :meth:`resetZoom`.
+
+ :param bool flag: True to resize limits automatically,
+ False to disable it.
+ """
+ self._xAxis.setAutoScale(flag)
+
+ def isYAxisAutoScale(self):
+ """Return True if Y axes are automatically adjusting its limits."""
+ return self._yAxis.isAutoScale()
+
+ def setYAxisAutoScale(self, flag=True):
+ """Set the Y axis limits adjusting behavior of :meth:`resetZoom`.
+
+ :param bool flag: True to resize limits automatically,
+ False to disable it.
+ """
+ self._yAxis.setAutoScale(flag)
+
+ def isKeepDataAspectRatio(self):
+ """Returns whether the plot is keeping data aspect ratio or not."""
+ return self._backend.isKeepDataAspectRatio()
+
+ def setKeepDataAspectRatio(self, flag=True):
+ """Set whether the plot keeps data aspect ratio or not.
+
+ :param bool flag: True to respect data aspect ratio
+ """
+ flag = bool(flag)
+ if flag == self.isKeepDataAspectRatio():
+ return
+ self._backend.setKeepDataAspectRatio(flag=flag)
+ self._setDirtyPlot()
+ self._forceResetZoom()
+ self.notify("setKeepDataAspectRatio", state=flag)
+
+ def getGraphGrid(self):
+ """Return the current grid mode, either None, 'major' or 'both'.
+
+ See :meth:`setGraphGrid`.
+ """
+ return self._grid
+
+ def setGraphGrid(self, which=True):
+ """Set the type of grid to display.
+
+ :param which: None or False to disable the grid,
+ 'major' or True for grid on major ticks (the default),
+ 'both' for grid on both major and minor ticks.
+ :type which: str of bool
+ """
+ assert which in (None, True, False, "both", "major")
+ if not which:
+ which = None
+ elif which is True:
+ which = "major"
+ self._grid = which
+ self._backend.setGraphGrid(which)
+ self._setDirtyPlot()
+ self.notify("setGraphGrid", which=str(which))
+
+ # Defaults
+
+ def isDefaultPlotPoints(self):
+ """Return True if the default Curve symbol is set and False if not."""
+ return self._defaultPlotPoints == silx.config.DEFAULT_PLOT_SYMBOL
+
+ def setDefaultPlotPoints(self, flag):
+ """Set the default symbol of all curves.
+
+ When called, this reset the symbol of all existing curves.
+
+ :param bool flag: True to use 'o' as the default curve symbol,
+ False to use no symbol.
+ """
+ self._defaultPlotPoints = silx.config.DEFAULT_PLOT_SYMBOL if flag else ""
+
+ # Reset symbol of all curves
+ curves = self.getAllCurves(just_legend=False, withhidden=True)
+
+ if curves:
+ for curve in curves:
+ curve.setSymbol(self._defaultPlotPoints)
+
+ def isDefaultPlotLines(self):
+ """Return True for line as default line style, False for no line."""
+ return self._plotLines
+
+ def setDefaultPlotLines(self, flag):
+ """Toggle the use of lines as the default curve line style.
+
+ :param bool flag: True to use a line as the default line style,
+ False to use no line as the default line style.
+ """
+ self._plotLines = bool(flag)
+
+ linestyle = "-" if self._plotLines else " "
+
+ # Reset linestyle of all curves
+ curves = self.getAllCurves(withhidden=True)
+
+ if curves:
+ for curve in curves:
+ curve.setLineStyle(linestyle)
+
+ def getDefaultColormap(self):
+ """Return the default colormap used by :meth:`addImage`.
+
+ :rtype: ~silx.gui.colors.Colormap
+ """
+ return self._defaultColormap
+
+ def setDefaultColormap(self, colormap=None):
+ """Set the default colormap used by :meth:`addImage`.
+
+ Setting the default colormap do not change any currently displayed
+ image.
+ It only affects future calls to :meth:`addImage` without the colormap
+ parameter.
+
+ :param ~silx.gui.colors.Colormap colormap:
+ The description of the default colormap, or
+ None to set the colormap to a linear
+ autoscale gray colormap.
+ """
+ if colormap is None:
+ colormap = Colormap(
+ name=silx.config.DEFAULT_COLORMAP_NAME,
+ normalization="linear",
+ vmin=None,
+ vmax=None,
+ )
+ if isinstance(colormap, dict):
+ self._defaultColormap = Colormap._fromDict(colormap)
+ else:
+ assert isinstance(colormap, Colormap)
+ self._defaultColormap = colormap
+ self.notify("defaultColormapChanged")
+
+ @staticmethod
+ def getSupportedColormaps():
+ """Get the supported colormap names as a tuple of str.
+
+ The list contains at least:
+ ('gray', 'reversed gray', 'temperature', 'red', 'green', 'blue',
+ 'magma', 'inferno', 'plasma', 'viridis')
+ """
+ return Colormap.getSupportedColormaps()
+
+ def setDefaultColors(self, colors: Optional[Tuple[str, ...]]):
+ """Set the list of colors to use as default for curves and histograms.
+
+ Set to None to use `silx.config.DEFAULT_PLOT_CURVE_COLORS`.
+ """
+ self._defaultColors = None if colors is None else tuple(colors)
+ self._resetColorAndStyle()
+
+ def getDefaultColors(self) -> Tuple[str, ...]:
+ """Returns the list of default colors for curves and histograms"""
+ if self._defaultColors is None:
+ return tuple(silx.config.DEFAULT_PLOT_CURVE_COLORS)
+ return self._defaultColors
+
+ def _resetColorAndStyle(self):
+ self._colorIndex = 0
+ self._styleIndex = 0
+
+ def _getColorAndStyle(self) -> Tuple[str, str]:
+ defaultColors = self.getDefaultColors()
+ if self._colorIndex >= len(defaultColors): # Handle list length updated
+ self._colorIndex = 0
+
+ color = defaultColors[self._colorIndex]
+ style = self._styleList[self._styleIndex]
+
+ # Loop over color and then styles
+ self._colorIndex += 1
+ if self._colorIndex >= len(defaultColors):
+ self._colorIndex = 0
+ self._styleIndex = (self._styleIndex + 1) % len(self._styleList)
+
+ # If color is the one of active curve, take the next one
+ if colors.rgba(color) == self.getActiveCurveStyle().getColor():
+ color, style = self._getColorAndStyle()
+
+ if not self._plotLines:
+ style = " "
+
+ return color, style
+
+ # Misc.
+
+ def getWidgetHandle(self):
+ """Return the widget the plot is displayed in.
+
+ This widget is owned by the backend.
+ """
+ return self._backend.getWidgetHandle()
+
+ def notify(self, event, **kwargs):
+ """Send an event to the listeners and send signals.
+
+ Event are passed to the registered callback as a dict with an 'event'
+ key for backward compatibility with PyMca.
+
+ :param str event: The type of event
+ :param kwargs: The information of the event.
+ """
+ eventDict = kwargs.copy()
+ eventDict["event"] = event
+ self.sigPlotSignal.emit(eventDict)
+
+ if event == "setKeepDataAspectRatio":
+ self.sigSetKeepDataAspectRatio.emit(kwargs["state"])
+ elif event == "setGraphGrid":
+ self.sigSetGraphGrid.emit(kwargs["which"])
+ elif event == "setGraphCursor":
+ self.sigSetGraphCursor.emit(kwargs["state"])
+ elif event == "contentChanged":
+ self.sigContentChanged.emit(
+ kwargs["action"], kwargs["kind"], kwargs["legend"]
+ )
+ elif event == "activeCurveChanged":
+ self.sigActiveCurveChanged.emit(kwargs["previous"], kwargs["legend"])
+ elif event == "activeImageChanged":
+ self.sigActiveImageChanged.emit(kwargs["previous"], kwargs["legend"])
+ elif event == "activeScatterChanged":
+ self.sigActiveScatterChanged.emit(kwargs["previous"], kwargs["legend"])
+ elif event == "interactiveModeChanged":
+ self.sigInteractiveModeChanged.emit(kwargs["source"])
+
+ eventDict = kwargs.copy()
+ eventDict["event"] = event
+ self._callback(eventDict)
+
+ def setCallback(self, callbackFunction=None):
+ """Attach a listener to the backend.
+
+ Limitation: Only one listener at a time.
+
+ :param callbackFunction: function accepting a dictionary as input
+ to handle the graph events
+ If None (default), use a default listener.
+ """
+ # TODO allow multiple listeners
+ # allow register listener by event type
+ if callbackFunction is None:
+ callbackFunction = WeakMethodProxy(self.graphCallback)
+ self._callback = callbackFunction
+
+ def graphCallback(self, ddict=None):
+ """This callback is going to receive all the events from the plot.
+
+ Those events will consist on a dictionary and among the dictionary
+ keys the key 'event' is mandatory to describe the type of event.
+ This default implementation only handles setting the active curve.
+ """
+
+ if ddict is None:
+ ddict = {}
+ _logger.debug("Received dict keys = %s", str(ddict.keys()))
+ _logger.debug(str(ddict))
+ if ddict["event"] == "curveClicked":
+ if ddict["button"] == "left":
+ self.setActiveCurve(ddict["item"])
+ qt.QToolTip.showText(self.cursor().pos(), ddict["label"])
+ elif ddict["event"] == "mouseClicked" and ddict["button"] == "left":
+ self.setActiveCurve(None)
+
+ def saveGraph(self, filename, fileFormat=None, dpi=None):
+ """Save a snapshot of the plot.
+
+ Supported file formats depends on the backend in use.
+ The following file formats are always supported: "png", "svg".
+ The matplotlib backend supports more formats:
+ "pdf", "ps", "eps", "tiff", "jpeg", "jpg".
+
+ :param filename: Destination
+ :type filename: str, StringIO or BytesIO
+ :param str fileFormat: String specifying the format
+ :return: False if cannot save the plot, True otherwise
+ """
+ if fileFormat is None:
+ if not hasattr(filename, "lower"):
+ _logger.warning("saveGraph cancelled, cannot define file format.")
+ return False
+ else:
+ fileFormat = (filename.split(".")[-1]).lower()
+
+ supportedFormats = (
+ "png",
+ "svg",
+ "pdf",
+ "ps",
+ "eps",
+ "tif",
+ "tiff",
+ "jpeg",
+ "jpg",
+ )
+
+ if fileFormat not in supportedFormats:
+ _logger.warning("Unsupported format %s", fileFormat)
+ return False
+ else:
+ self._backend.saveGraph(filename, fileFormat=fileFormat, dpi=dpi)
+ return True
+
+ def getDataMargins(self) -> tuple[float, float, float, float]:
+ """Get the default data margin ratios, see :meth:`setDataMargins`.
+
+ :return: The margin ratios for each side (xMin, xMax, yMin, yMax).
+ """
+ return self._defaultDataMargins
+
+ def setDataMargins(
+ self,
+ xMinMargin: float = 0.0,
+ xMaxMargin: float = 0.0,
+ yMinMargin: float = 0.0,
+ yMaxMargin: float = 0.0,
+ ):
+ """Set the default data margins to use in :meth:`resetZoom`.
+
+ Set the default ratios of margins to add around the data
+ inside the plot area for each side.
+ """
+ self._defaultDataMargins = (xMinMargin, xMaxMargin, yMinMargin, yMaxMargin)
+
+ def getAutoReplot(self):
+ """Return True if replot is automatically handled, False otherwise.
+
+ See :meth`setAutoReplot`.
+ """
+ return self._autoreplot
+
+ def setAutoReplot(self, autoreplot=True):
+ """Set automatic replot mode.
+
+ When enabled, the plot is redrawn automatically when changed.
+ When disabled, the plot is not redrawn when its content change.
+ Instead, it :meth:`replot` must be called.
+
+ :param bool autoreplot: True to enable it (default),
+ False to disable it.
+ """
+ self._autoreplot = bool(autoreplot)
+
+ # If the plot is dirty before enabling autoreplot,
+ # then _backend.postRedisplay will never be called from _setDirtyPlot
+ if self._autoreplot and self._getDirtyPlot():
+ self._backend.postRedisplay()
+
+ @contextmanager
+ def _paintContext(self):
+ """This context MUST surround backend rendering.
+
+ It is in charge of performing required PlotWidget operations
+ """
+ for item in self.__itemsToUpdate:
+ item._update(self._backend)
+
+ self.__itemsToUpdate = []
+ yield
+ self._dirty = False # reset dirty flag
+
+ def replot(self):
+ """Request to draw the plot."""
+ self._backend.replot()
+
+ def _forceResetZoom(
+ self,
+ dataMargins: Optional[tuple[float, float, float, float]] = None,
+ ):
+ """Reset the plot limits to the bounds of the data and redraw the plot.
+
+ This method forces a reset zoom and does not check axis autoscale.
+
+ Extra margins can be added around the data inside the plot area
+ (see :meth:`setDataMargins`).
+ Margins are given as one ratio of the data range per limit of the
+ data (xMin, xMax, yMin and yMax limits).
+ For log scale, extra margins are applied in log10 of the data.
+
+ :param dataMargins:
+ Ratios of margins to add around the data inside the plot area for each side.
+ If None (the default), use margins from :meth:`getDataMargins`.
+ """
+ # Get data range
+ ranges = self.getDataRange()
+ xmin, xmax = (1.0, 100.0) if ranges.x is None else ranges.x
+ ymin, ymax = (1.0, 100.0) if ranges.y is None else ranges.y
+ if ranges.yright is None:
+ y2min, y2max = ymin, ymax
+ else:
+ y2min, y2max = ranges.yright
+ if ranges.y is None:
+ ymin, ymax = ranges.yright
+
+ self.setLimits(
+ xmin,
+ xmax,
+ ymin,
+ ymax,
+ y2min,
+ y2max,
+ margins=dataMargins if dataMargins is not None else True,
+ )
+
+ def resetZoom(self, dataMargins=None):
+ """Reset the plot limits to the bounds of the data and redraw the plot.
+
+ It automatically scale limits of axes that are in autoscale mode
+ (see :meth:`getXAxis`, :meth:`getYAxis` and :meth:`Axis.setAutoScale`).
+ It keeps current limits on axes that are not in autoscale mode.
+
+ Extra margins can be added around the data inside the plot area
+ (see :meth:`setDataMargins`).
+ Margins are given as one ratio of the data range per limit of the
+ data (xMin, xMax, yMin and yMax limits).
+ For log scale, extra margins are applied in log10 of the data.
+
+ :param dataMargins: Ratios of margins to add around the data inside
+ the plot area for each side (default: no margins).
+ :type dataMargins: A 4-tuple of float as (xMin, xMax, yMin, yMax).
+ """
+ xLimits = self._xAxis.getLimits()
+ yLimits = self._yAxis.getLimits()
+ y2Limits = self._yRightAxis.getLimits()
+
+ xAuto = self._xAxis.isAutoScale()
+ yAuto = self._yAxis.isAutoScale()
+
+ # With log axes, autoscale if limits are <= 0
+ # This avoids issues with toggling log scale with matplotlib 2.1.0
+ if self._xAxis.getScale() == self._xAxis.LOGARITHMIC and xLimits[0] <= 0:
+ xAuto = True
+ if self._yAxis.getScale() == self._yAxis.LOGARITHMIC and (
+ yLimits[0] <= 0 or y2Limits[0] <= 0
+ ):
+ yAuto = True
+
+ if not xAuto and not yAuto:
+ _logger.debug("Nothing to autoscale")
+ else: # Some axes to autoscale
+ self._forceResetZoom(dataMargins=dataMargins)
+
+ # Restore limits for axis not in autoscale
+ if not xAuto and yAuto:
+ self.setGraphXLimits(*xLimits)
+ elif xAuto and not yAuto:
+ if y2Limits is not None:
+ self.setGraphYLimits(y2Limits[0], y2Limits[1], axis="right")
+ if yLimits is not None:
+ self.setGraphYLimits(yLimits[0], yLimits[1], axis="left")
+
+ if (
+ xLimits != self._xAxis.getLimits()
+ or yLimits != self._yAxis.getLimits()
+ or y2Limits != self._yRightAxis.getLimits()
+ ):
+ self._notifyLimitsChanged()
+
+ # Coord conversion
+
+ def dataToPixel(self, x=None, y=None, axis="left", check=True):
+ """Convert a position in data coordinates to a position in pixels.
+
+ :param x: The X coordinate in data space. If None (default)
+ the middle position of the displayed data is used.
+ :type x: float or 1D numpy array of float
+ :param y: The Y coordinate in data space. If None (default)
+ the middle position of the displayed data is used.
+ :type y: float or 1D numpy array of float
+
+ :param str axis: The Y axis to use for the conversion
+ ('left' or 'right').
+ :param bool check: True to return None if outside displayed area,
+ False to convert to pixels anyway
+ :returns: The corresponding position in pixels or
+ None if the data position is not in the displayed area and
+ check is True.
+ :rtype: A tuple of 2 floats or 2 arrays of float: (xPixel, yPixel) or None.
+ """
+ assert axis in ("left", "right")
+
+ xmin, xmax = self._xAxis.getLimits()
+ yAxis = self.getYAxis(axis=axis)
+ ymin, ymax = yAxis.getLimits()
+
+ if x is None:
+ x = 0.5 * (xmax + xmin)
+ if y is None:
+ y = 0.5 * (ymax + ymin)
+
+ if isinstance(x, numbers.Real) != isinstance(y, numbers.Real):
+ raise ValueError("x and y must be of the same type")
+ if not isinstance(x, numbers.Real) and (x.shape != y.shape or x.ndim != 1):
+ raise ValueError("x and y must be 1D arrays of the same length")
+
+ if check:
+ isOutside = numpy.logical_or(
+ numpy.logical_or(x > xmax, x < xmin),
+ numpy.logical_or(y > ymax, y < ymin),
+ )
+
+ if numpy.any(isOutside):
+ if isinstance(x, numbers.Real):
+ return None
+ else: # Filter-out points that are outside
+ x = numpy.array(x, copy=True, dtype=numpy.float64)
+ x[isOutside] = numpy.nan
+
+ y = numpy.array(y, copy=True, dtype=numpy.float64)
+ y[isOutside] = numpy.nan
+
+ return self._backend.dataToPixel(x, y, axis=axis)
+
+ def pixelToData(self, x, y, axis="left", check=False):
+ """Convert a position in pixels to a position in data coordinates.
+
+ :param float x: The X coordinate in pixels. If None (default)
+ the center of the widget is used.
+ :param float y: The Y coordinate in pixels. If None (default)
+ the center of the widget is used.
+ :param str axis: The Y axis to use for the conversion
+ ('left' or 'right').
+ :param bool check: Toggle checking if pixel is in plot area.
+ If False, this method never returns None.
+ :returns: The corresponding position in data space or
+ None if the pixel position is not in the plot area.
+ :rtype: A tuple of 2 floats: (xData, yData) or None.
+ """
+ assert axis in ("left", "right")
+
+ if x is None:
+ x = self.width() // 2
+ if y is None:
+ y = self.height() // 2
+
+ if check:
+ left, top, width, height = self.getPlotBoundsInPixels()
+ isOutside = numpy.logical_or(
+ numpy.logical_or(x < left, x > left + width),
+ numpy.logical_or(y < top, y > top + height),
+ )
+ if numpy.any(isOutside):
+ return None
+
+ return self._backend.pixelToData(x, y, axis)
+
+ def getPlotBoundsInPixels(self):
+ """Plot area bounds in widget coordinates in pixels.
+
+ :return: bounds as a 4-tuple of int: (left, top, width, height)
+ """
+ return self._backend.getPlotBoundsInPixels()
+
+ # Interaction support
+
+ def getGraphCursorShape(self):
+ """Returns the current cursor shape.
+
+ :rtype: str
+ """
+ return self.__graphCursorShape
+
+ def setGraphCursorShape(self, cursor=None):
+ """Set the cursor shape.
+
+ :param str cursor: Name of the cursor shape
+ """
+ self.__graphCursorShape = cursor
+ self._backend.setGraphCursorShape(cursor)
+
+ def _getMarkerAt(self, x, y):
+ """Return the most interactive marker at a location, else None
+
+ :param float x: X position in pixels
+ :param float y: Y position in pixels
+ :rtype: None of marker object
+ """
+
+ def checkDraggable(item):
+ return isinstance(item, items.MarkerBase) and item.isDraggable()
+
+ def checkSelectable(item):
+ return isinstance(item, items.MarkerBase) and item.isSelectable()
+
+ def check(item):
+ return isinstance(item, items.MarkerBase)
+
+ result = self._pickTopMost(x, y, checkDraggable)
+ if not result:
+ result = self._pickTopMost(x, y, checkSelectable)
+ if not result:
+ result = self._pickTopMost(x, y, check)
+ marker = result.getItem() if result is not None else None
+ return marker
+
+ def _getMarker(self, legend=None):
+ """Get the object describing a specific marker.
+
+ It returns None in case no matching marker is found
+
+ :param str legend: The legend of the marker to retrieve
+ :rtype: None of marker object
+ """
+ return self._getItem(kind="marker", legend=legend)
+
+ def pickItems(self, x, y, condition=None):
+ """Generator of picked items in the plot at given position.
+
+ Items are returned from front to back.
+
+ :param float x: X position in pixels
+ :param float y: Y position in pixels
+ :param callable condition:
+ Callable taking an item as input and returning False for items
+ to skip during picking. If None (default) no item is skipped.
+ :return: Iterable of :class:`PickingResult` objects at picked position.
+ Items are ordered from front to back.
+ """
+ for item in reversed(
+ self._backend.getItemsFromBackToFront(condition=condition)
+ ):
+ result = item.pick(x, y)
+ if result is not None:
+ yield result
+
+ def _pickTopMost(self, x, y, condition=None):
+ """Returns top-most picked item in the plot at given position.
+
+ Items are checked from front to back.
+
+ :param float x: X position in pixels
+ :param float y: Y position in pixels
+ :param callable condition:
+ Callable taking an item as input and returning False for items
+ to skip during picking. If None (default) no item is skipped.
+ :return: :class:`PickingResult` object at picked position.
+ If no item is picked, it returns None
+ :rtype: Union[None,PickingResult]
+ """
+ for result in self.pickItems(x, y, condition):
+ return result
+ return None
+
+ # User event handling #
+
+ def _isPositionInPlotArea(self, x, y):
+ """Project position in pixel to the closest point in the plot area
+
+ :param float x: X coordinate in widget coordinate (in pixel)
+ :param float y: Y coordinate in widget coordinate (in pixel)
+ :return: (x, y) in widget coord (in pixel) in the plot area
+ """
+ left, top, width, height = self.getPlotBoundsInPixels()
+ xPlot = numpy.clip(x, left, left + width)
+ yPlot = numpy.clip(y, top, top + height)
+ return xPlot, yPlot
+
+ def onMousePress(self, xPixel, yPixel, btn):
+ """Handle mouse press event.
+
+ :param float xPixel: X mouse position in pixels
+ :param float yPixel: Y mouse position in pixels
+ :param str btn: Mouse button in 'left', 'middle', 'right'
+ """
+ if self._isPositionInPlotArea(xPixel, yPixel) == (xPixel, yPixel):
+ self._pressedButtons.append(btn)
+ self._eventHandler.handleEvent("press", xPixel, yPixel, btn)
+
+ def onMouseMove(self, xPixel, yPixel):
+ """Handle mouse move event.
+
+ :param float xPixel: X mouse position in pixels
+ :param float yPixel: Y mouse position in pixels
+ """
+ inXPixel, inYPixel = self._isPositionInPlotArea(xPixel, yPixel)
+ isCursorInPlot = inXPixel == xPixel and inYPixel == yPixel
+
+ if self._cursorInPlot != isCursorInPlot:
+ self._cursorInPlot = isCursorInPlot
+ self._eventHandler.handleEvent("enter" if self._cursorInPlot else "leave")
+
+ if isCursorInPlot:
+ # Signal mouse move event
+ dataPos = self.pixelToData(inXPixel, inYPixel)
+ assert dataPos is not None
+
+ btn = self._pressedButtons[-1] if self._pressedButtons else None
+ event = PlotEvents.prepareMouseSignal(
+ "mouseMoved", btn, dataPos[0], dataPos[1], xPixel, yPixel
+ )
+ self.notify(**event)
+
+ # Either button was pressed in the plot or cursor is in the plot
+ if isCursorInPlot or self._pressedButtons:
+ self._eventHandler.handleEvent("move", inXPixel, inYPixel)
+
+ def onMouseRelease(self, xPixel, yPixel, btn):
+ """Handle mouse release event.
+
+ :param float xPixel: X mouse position in pixels
+ :param float yPixel: Y mouse position in pixels
+ :param str btn: Mouse button in 'left', 'middle', 'right'
+ """
+ try:
+ self._pressedButtons.remove(btn)
+ except ValueError:
+ pass
+ else:
+ xPixel, yPixel = self._isPositionInPlotArea(xPixel, yPixel)
+ self._eventHandler.handleEvent("release", xPixel, yPixel, btn)
+
+ def onMouseWheel(self, xPixel, yPixel, angleInDegrees):
+ """Handle mouse wheel event.
+
+ :param float xPixel: X mouse position in pixels
+ :param float yPixel: Y mouse position in pixels
+ :param float angleInDegrees: Angle corresponding to wheel motion.
+ Positive for movement away from the user,
+ negative for movement toward the user.
+ """
+ if self._isPositionInPlotArea(xPixel, yPixel) == (xPixel, yPixel):
+ self._eventHandler.handleEvent("wheel", xPixel, yPixel, angleInDegrees)
+
+ def onMouseLeaveWidget(self):
+ """Handle mouse leave widget event."""
+ if self._cursorInPlot:
+ self._cursorInPlot = False
+ self._eventHandler.handleEvent("leave")
+
+ # Interaction modes #
+
+ def interaction(self) -> PlotInteraction:
+ """Returns the interaction handler for this PlotWidget"""
+ return self._eventHandler
+
+ def __interactionChanged(self):
+ """Handle PlotInteraction updates"""
+ if self.__isInteractionSignalForwarded:
+ self.sigInteractiveModeChanged.emit(None)
+
+ def getInteractiveMode(self):
+ """Returns the current interactive mode as a dict.
+
+ The returned dict contains at least the key 'mode'.
+ Mode can be: 'draw', 'pan', 'select', 'select-draw', 'zoom'.
+ It can also contains extra keys (e.g., 'color') specific to a mode
+ as provided to :meth:`setInteractiveMode`.
+ """
+ return self.interaction()._getInteractiveMode()
+
+ def resetInteractiveMode(self):
+ """Reset the interactive mode to use the previous basic interactive
+ mode used.
+
+ It can be one of "zoom" or "pan".
+ """
+ mode, zoomOnWheel = self._previousDefaultMode
+ self.setInteractiveMode(mode=mode, zoomOnWheel=zoomOnWheel)
+
+ def setInteractiveMode(
+ self,
+ mode: str,
+ color: Union[str, Sequence[numbers.Real]] = "black",
+ shape: str = "polygon",
+ label: Optional[str] = None,
+ zoomOnWheel: bool = True,
+ source=None,
+ width: Optional[float] = None,
+ ):
+ """Switch the interactive mode.
+
+ :param mode: The name of the interactive mode.
+ In 'draw', 'pan', 'select', 'select-draw', 'zoom'.
+ :param color: Only for 'draw' and 'zoom' modes.
+ Color to use for drawing selection area. Default black.
+ :type color: Color description: The name as a str or
+ a tuple of 4 floats.
+ :param shape: Only for 'draw' mode. The kind of shape to draw.
+ In 'polygon', 'rectangle', 'line', 'vline', 'hline',
+ 'freeline'.
+ Default is 'polygon'.
+ :param label: Only for 'draw' mode, sent in drawing events.
+ :param zoomOnWheel: Toggle zoom on wheel support
+ :param source: A user-defined object (typically the caller object)
+ that will be send in the interactiveModeChanged event,
+ to identify which object required a mode change.
+ Default: None
+ :param width: Width of the pencil. Only for draw pencil mode.
+ """
+ self.__isInteractionSignalForwarded = False
+ try:
+ self._eventHandler._setInteractiveMode(mode, color, shape, label, width)
+ self._eventHandler.setZoomOnWheelEnabled(zoomOnWheel)
+ finally:
+ self.__isInteractionSignalForwarded = True
+
+ if mode in ["pan", "zoom"]:
+ self._previousDefaultMode = mode, zoomOnWheel
+
+ self.notify("interactiveModeChanged", source=source)
+
+ # Panning with arrow keys
+
+ def isPanWithArrowKeys(self):
+ """Returns whether or not panning the graph with arrow keys is enabled.
+
+ See :meth:`setPanWithArrowKeys`.
+ """
+ return self._panWithArrowKeys
+
+ def setPanWithArrowKeys(self, pan=False):
+ """Enable/Disable panning the graph with arrow keys.
+
+ This grabs the keyboard.
+
+ :param bool pan: True to enable panning, False to disable.
+ """
+ pan = bool(pan)
+ panHasChanged = self._panWithArrowKeys != pan
+
+ self._panWithArrowKeys = pan
+ if not self._panWithArrowKeys:
+ self.setFocusPolicy(qt.Qt.NoFocus)
+ else:
+ self.setFocusPolicy(qt.Qt.StrongFocus)
+ self.setFocus(qt.Qt.OtherFocusReason)
+
+ if panHasChanged:
+ self.sigSetPanWithArrowKeys.emit(pan)
+
+ # Dict to convert Qt arrow key code to direction str.
+ _ARROWS_TO_PAN_DIRECTION = {
+ qt.Qt.Key_Left: "left",
+ qt.Qt.Key_Right: "right",
+ qt.Qt.Key_Up: "up",
+ qt.Qt.Key_Down: "down",
+ }
+
+ def __simulateMouseMove(self):
+ qapp = qt.QApplication.instance()
+ event = qt.QMouseEvent(
+ qt.QEvent.MouseMove,
+ qt.QPointF(self.getWidgetHandle().mapFromGlobal(qt.QCursor.pos())),
+ qt.Qt.NoButton,
+ qapp.mouseButtons(),
+ qapp.keyboardModifiers(),
+ )
+ qapp.sendEvent(self.getWidgetHandle(), event)
+
+ def keyPressEvent(self, event):
+ """Key event handler handling panning on arrow keys.
+
+ Overrides base class implementation.
+ """
+ key = event.key()
+ if self._panWithArrowKeys and key in self._ARROWS_TO_PAN_DIRECTION:
+ self.pan(self._ARROWS_TO_PAN_DIRECTION[key], factor=0.1)
+
+ # Send a mouse move event to the plot widget to take into account
+ # that even if mouse didn't move on the screen, it moved relative
+ # to the plotted data.
+ self.__simulateMouseMove()
+ else:
+ # Only call base class implementation when key is not handled.
+ # See QWidget.keyPressEvent for details.
+ super(PlotWidget, self).keyPressEvent(event)
diff --git a/src/silx/gui/plot/PlotWindow.py b/src/silx/gui/plot/PlotWindow.py
new file mode 100644
index 0000000..9aa8c78
--- /dev/null
+++ b/src/silx/gui/plot/PlotWindow.py
@@ -0,0 +1,1052 @@
+# /*##########################################################################
+#
+# Copyright (c) 2004-2023 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""A :class:`.PlotWidget` with additional toolbars.
+
+The :class:`PlotWindow` is a subclass of :class:`.PlotWidget`.
+"""
+
+__authors__ = ["V.A. Sole", "T. Vincent"]
+__license__ = "MIT"
+__date__ = "12/04/2019"
+
+from collections import abc
+import logging
+import weakref
+
+import silx
+from silx.utils.weakref import WeakMethodProxy
+from silx.utils.proxy import docstring
+
+from . import PlotWidget
+from . import actions
+from . import items
+from .actions import medfilt as actions_medfilt
+from .actions import fit as actions_fit
+from .actions import control as actions_control
+from .actions import histogram as actions_histogram
+from . import PlotToolButtons
+from . import tools
+from .Profile import ProfileToolBar
+from .LegendSelector import LegendsDockWidget
+from .CurvesROIWidget import CurvesROIDockWidget
+from .MaskToolsWidget import MaskToolsDockWidget
+from .StatsWidget import BasicStatsWidget
+from .ColorBar import ColorBarWidget
+
+try:
+ from ..console import IPythonDockWidget
+except ImportError:
+ IPythonDockWidget = None
+
+from .. import qt
+
+
+_logger = logging.getLogger(__name__)
+
+
+class PlotWindow(PlotWidget):
+ """Qt Widget providing a 1D/2D plot area and additional tools.
+
+ This widgets inherits from :class:`.PlotWidget` and provides its plot API.
+
+ Initialiser parameters:
+
+ :param parent: The parent of this widget or None.
+ :param backend: The backend to use for the plot (default: matplotlib).
+ See :class:`.PlotWidget` for the list of supported backend.
+ :type backend: str or :class:`BackendBase.BackendBase`
+ :param bool resetzoom: Toggle visibility of reset zoom action.
+ :param bool autoScale: Toggle visibility of axes autoscale actions.
+ :param bool logScale: Toggle visibility of axes log scale actions.
+ :param bool grid: Toggle visibility of grid mode action.
+ :param bool curveStyle: Toggle visibility of curve style action.
+ :param bool colormap: Toggle visibility of colormap action.
+ :param bool aspectRatio: Toggle visibility of aspect ratio button.
+ :param bool yInverted: Toggle visibility of Y axis direction button.
+ :param bool copy: Toggle visibility of copy action.
+ :param bool save: Toggle visibility of save action.
+ :param bool print_: Toggle visibility of print action.
+ :param bool control: True to display an Options button with a sub-menu
+ to show legends, toggle crosshair and pan with arrows.
+ (Default: False)
+ :param position: True to display widget with (x, y) mouse position
+ (Default: False).
+ It also supports a list of (name, funct(x, y)->value)
+ to customize the displayed values.
+ See :class:`~silx.gui.plot.tools.PositionInfo`.
+ :param bool roi: Toggle visibilty of ROI action.
+ :param bool mask: Toggle visibilty of mask action.
+ :param bool fit: Toggle visibilty of fit action.
+ """
+
+ def __init__(
+ self,
+ parent=None,
+ backend=None,
+ resetzoom=True,
+ autoScale=True,
+ logScale=True,
+ grid=True,
+ curveStyle=True,
+ colormap=True,
+ aspectRatio=True,
+ yInverted=True,
+ copy=True,
+ save=True,
+ print_=True,
+ control=False,
+ position=False,
+ roi=True,
+ mask=True,
+ fit=False,
+ ):
+ super(PlotWindow, self).__init__(parent=parent, backend=backend)
+ if parent is None:
+ self.setWindowTitle("PlotWindow")
+
+ self._dockWidgets = []
+
+ # lazy loaded dock widgets
+ self._legendsDockWidget = None
+ self._curvesROIDockWidget = None
+ self._maskToolsDockWidget = None
+ self._consoleDockWidget = None
+ self._statsDockWidget = None
+
+ # Create color bar, hidden by default for backward compatibility
+ self._colorbar = ColorBarWidget(parent=self, plot=self)
+
+ # Init actions
+ self.group = qt.QActionGroup(self)
+ self.group.setExclusive(False)
+
+ self.resetZoomAction = self.group.addAction(
+ actions.control.ResetZoomAction(self, parent=self)
+ )
+ self.resetZoomAction.setVisible(resetzoom)
+ self.addAction(self.resetZoomAction)
+
+ self.zoomInAction = self.group.addAction(
+ actions.control.ZoomInAction(self, parent=self)
+ )
+ self.zoomInAction.setVisible(False)
+ self.addAction(self.zoomInAction)
+
+ self.zoomOutAction = self.group.addAction(
+ actions.control.ZoomOutAction(self, parent=self)
+ )
+ self.zoomOutAction.setVisible(False)
+ self.addAction(self.zoomOutAction)
+
+ self.xAxisAutoScaleAction = self.group.addAction(
+ actions.control.XAxisAutoScaleAction(self, parent=self)
+ )
+ self.xAxisAutoScaleAction.setVisible(autoScale)
+ self.addAction(self.xAxisAutoScaleAction)
+
+ self.yAxisAutoScaleAction = self.group.addAction(
+ actions.control.YAxisAutoScaleAction(self, parent=self)
+ )
+ self.yAxisAutoScaleAction.setVisible(autoScale)
+ self.addAction(self.yAxisAutoScaleAction)
+
+ self.xAxisLogarithmicAction = self.group.addAction(
+ actions.control.XAxisLogarithmicAction(self, parent=self)
+ )
+ self.xAxisLogarithmicAction.setVisible(logScale)
+ self.addAction(self.xAxisLogarithmicAction)
+
+ self.yAxisLogarithmicAction = self.group.addAction(
+ actions.control.YAxisLogarithmicAction(self, parent=self)
+ )
+ self.yAxisLogarithmicAction.setVisible(logScale)
+ self.addAction(self.yAxisLogarithmicAction)
+
+ self.gridAction = self.group.addAction(
+ actions.control.GridAction(self, gridMode="both", parent=self)
+ )
+ self.gridAction.setVisible(grid)
+ self.addAction(self.gridAction)
+
+ self.curveStyleAction = self.group.addAction(
+ actions.control.CurveStyleAction(self, parent=self)
+ )
+ self.curveStyleAction.setVisible(curveStyle)
+ self.addAction(self.curveStyleAction)
+
+ self.colormapAction = self.group.addAction(
+ actions.control.ColormapAction(self, parent=self)
+ )
+ self.colormapAction.setVisible(colormap)
+ self.addAction(self.colormapAction)
+
+ self.colorbarAction = self.group.addAction(
+ actions_control.ColorBarAction(self, parent=self)
+ )
+ self.colorbarAction.setVisible(False)
+ self.addAction(self.colorbarAction)
+ self._colorbar.setVisible(False)
+
+ self.keepDataAspectRatioButton = PlotToolButtons.AspectToolButton(
+ parent=self, plot=self
+ )
+ self.keepDataAspectRatioButton.setVisible(aspectRatio)
+
+ self.yAxisInvertedButton = PlotToolButtons.YAxisOriginToolButton(
+ parent=self, plot=self
+ )
+ self.yAxisInvertedButton.setVisible(yInverted)
+
+ self.group.addAction(self.getRoiAction())
+ self.getRoiAction().setVisible(roi)
+
+ self.group.addAction(self.getMaskAction())
+ self.getMaskAction().setVisible(mask)
+
+ self._intensityHistoAction = self.group.addAction(
+ actions_histogram.PixelIntensitiesHistoAction(self, parent=self)
+ )
+ self._intensityHistoAction.setVisible(False)
+
+ self._medianFilter2DAction = self.group.addAction(
+ actions_medfilt.MedianFilter2DAction(self, parent=self)
+ )
+ self._medianFilter2DAction.setVisible(False)
+
+ self._medianFilter1DAction = self.group.addAction(
+ actions_medfilt.MedianFilter1DAction(self, parent=self)
+ )
+ self._medianFilter1DAction.setVisible(False)
+
+ self.fitAction = self.group.addAction(actions_fit.FitAction(self, parent=self))
+ self.fitAction.setVisible(fit)
+ self.addAction(self.fitAction)
+
+ # lazy loaded actions needed by the controlButton menu
+ self._consoleAction = None
+ self._panWithArrowKeysAction = None
+ self._crosshairAction = None
+
+ # Make colorbar background white
+ self._colorbar.setAutoFillBackground(True)
+ self._sigAxesVisibilityChanged.connect(self._updateColorBarBackground)
+ self._updateColorBarBackground()
+
+ if control: # Create control button only if requested
+ self.controlButton = qt.QToolButton()
+ self.controlButton.setText("Options")
+ self.controlButton.setToolButtonStyle(qt.Qt.ToolButtonTextBesideIcon)
+ self.controlButton.setAutoRaise(True)
+ self.controlButton.setPopupMode(qt.QToolButton.InstantPopup)
+ menu = qt.QMenu(self)
+ menu.aboutToShow.connect(self._customControlButtonMenu)
+ self.controlButton.setMenu(menu)
+
+ self._positionWidget = None
+ if position: # Add PositionInfo widget to the bottom of the plot
+ if isinstance(position, abc.Iterable):
+ # Use position as a set of converters
+ converters = position
+ else:
+ converters = None
+ self._positionWidget = tools.PositionInfo(plot=self, converters=converters)
+ # Set a snapping mode that is consistent with legacy one
+ self._positionWidget.setSnappingMode(
+ tools.PositionInfo.SNAPPING_CROSSHAIR
+ | tools.PositionInfo.SNAPPING_ACTIVE_ONLY
+ | tools.PositionInfo.SNAPPING_SYMBOLS_ONLY
+ | tools.PositionInfo.SNAPPING_CURVE
+ | tools.PositionInfo.SNAPPING_SCATTER
+ )
+
+ self.__setCentralWidget()
+
+ # Creating the toolbar also create actions for toolbuttons
+ self._interactiveModeToolBar = tools.InteractiveModeToolBar(
+ parent=self, plot=self
+ )
+ self.addToolBar(self._interactiveModeToolBar)
+
+ self._toolbar = self._createToolBar(title="Plot", parent=self)
+ self.addToolBar(self._toolbar)
+
+ self._outputToolBar = tools.OutputToolBar(parent=self, plot=self)
+ self._outputToolBar.getCopyAction().setVisible(copy)
+ self._outputToolBar.getSaveAction().setVisible(save)
+ self._outputToolBar.getPrintAction().setVisible(print_)
+ self.addToolBar(self._outputToolBar)
+
+ # Activate shortcuts in PlotWindow widget:
+ for toolbar in (self._interactiveModeToolBar, self._outputToolBar):
+ for action in toolbar.actions():
+ self.addAction(action)
+
+ def __setCentralWidget(self):
+ """Set central widget to host plot backend, colorbar, and bottom bar"""
+ gridLayout = qt.QGridLayout()
+ gridLayout.setSpacing(0)
+ gridLayout.setContentsMargins(0, 0, 0, 0)
+ gridLayout.addWidget(self.getWidgetHandle(), 0, 0)
+ gridLayout.addWidget(self._colorbar, 0, 1)
+ gridLayout.setRowStretch(0, 1)
+ gridLayout.setColumnStretch(0, 1)
+ centralWidget = qt.QWidget(self)
+ centralWidget.setLayout(gridLayout)
+
+ if hasattr(self, "controlButton") or self._positionWidget is not None:
+ hbox = qt.QHBoxLayout()
+ hbox.setContentsMargins(0, 0, 0, 0)
+
+ if hasattr(self, "controlButton"):
+ hbox.addWidget(self.controlButton)
+
+ if self._positionWidget is not None:
+ hbox.addWidget(self._positionWidget)
+
+ hbox.addStretch(1)
+ bottomBar = qt.QWidget(centralWidget)
+ bottomBar.setLayout(hbox)
+
+ gridLayout.addWidget(bottomBar, 1, 0, 1, -1)
+
+ self.setCentralWidget(centralWidget)
+
+ @docstring(PlotWidget)
+ def setBackend(self, backend):
+ super(PlotWindow, self).setBackend(backend)
+ self.__setCentralWidget() # Recreate PlotWindow's central widget
+
+ @docstring(PlotWidget)
+ def setBackgroundColor(self, color):
+ super(PlotWindow, self).setBackgroundColor(color)
+ self._updateColorBarBackground()
+
+ @docstring(PlotWidget)
+ def setDataBackgroundColor(self, color):
+ super(PlotWindow, self).setDataBackgroundColor(color)
+ self._updateColorBarBackground()
+
+ @docstring(PlotWidget)
+ def setForegroundColor(self, color):
+ super(PlotWindow, self).setForegroundColor(color)
+ self._updateColorBarBackground()
+
+ def _updateColorBarBackground(self):
+ """Update the colorbar background according to the state of the plot"""
+ if self.isAxesDisplayed():
+ color = self.getBackgroundColor()
+ else:
+ color = self.getDataBackgroundColor()
+ if not color.isValid():
+ # If no color defined, use the background one
+ color = self.getBackgroundColor()
+
+ foreground = self.getForegroundColor()
+
+ palette = self._colorbar.palette()
+ palette.setColor(qt.QPalette.Window, color)
+ palette.setColor(qt.QPalette.WindowText, foreground)
+ palette.setColor(qt.QPalette.Text, foreground)
+ self._colorbar.setPalette(palette)
+
+ def getInteractiveModeToolBar(self):
+ """Returns QToolBar controlling interactive mode.
+
+ :rtype: QToolBar
+ """
+ return self._interactiveModeToolBar
+
+ def getOutputToolBar(self):
+ """Returns QToolBar containing save, copy and print actions
+
+ :rtype: QToolBar
+ """
+ return self._outputToolBar
+
+ def getPositionInfoWidget(self):
+ """Returns the widget displaying current cursor position information
+
+ :rtype: ~silx.gui.plot.tools.PositionInfo
+ """
+ return self._positionWidget
+
+ def getSelectionMask(self):
+ """Return the current mask handled by :attr:`maskToolsDockWidget`.
+
+ :return: The array of the mask with dimension of the 'active' image.
+ If there is no active image, an empty array is returned.
+ :rtype: 2D numpy.ndarray of uint8
+ """
+ return self.getMaskToolsDockWidget().getSelectionMask()
+
+ def setSelectionMask(self, mask):
+ """Set the mask handled by :attr:`maskToolsDockWidget`.
+
+ If the provided mask has not the same dimension as the 'active'
+ image, it will by cropped or padded.
+
+ :param mask: The array to use for the mask.
+ :type mask: numpy.ndarray of uint8 of dimension 2, C-contiguous.
+ Array of other types are converted.
+ :return: True if success, False if failed
+ """
+ return bool(self.getMaskToolsDockWidget().setSelectionMask(mask))
+
+ def _toggleConsoleVisibility(self, isChecked=False):
+ """Create IPythonDockWidget if needed,
+ show it or hide it."""
+ # create widget if needed (first call)
+ if self._consoleDockWidget is None:
+ available_vars = {"plt": weakref.proxy(self)}
+ banner = "The variable 'plt' is available. Use the 'whos' "
+ banner += "and 'help(plt)' commands for more information.\n\n"
+ self._consoleDockWidget = IPythonDockWidget(
+ available_vars=available_vars, custom_banner=banner, parent=self
+ )
+ self.addTabbedDockWidget(self._consoleDockWidget)
+ self._consoleDockWidget.toggleViewAction().toggled.connect(
+ self._consoleDockWidgetToggled
+ )
+
+ self._consoleDockWidget.setVisible(isChecked)
+
+ def _consoleVisibilityTriggered(self, isChecked):
+ if isChecked and self.isVisible():
+ self._consoleDockWidget.show()
+ self._consoleDockWidget.raise_()
+
+ def _consoleDockWidgetToggled(self, isChecked):
+ if self.isVisible():
+ self.getConsoleAction().setChecked(isChecked)
+
+ def _createToolBar(self, title, parent):
+ """Create a QToolBar from the QAction of the PlotWindow.
+
+ :param str title: The title of the QMenu
+ :param qt.QWidget parent: See :class:`QToolBar`
+ """
+ toolbar = qt.QToolBar(title, parent)
+
+ # Order widgets with actions
+ objects = self.group.actions()
+
+ # Add push buttons to list
+ index = objects.index(self.colormapAction)
+ objects.insert(index + 1, self.keepDataAspectRatioButton)
+ objects.insert(index + 2, self.yAxisInvertedButton)
+
+ for obj in objects:
+ if isinstance(obj, qt.QAction):
+ toolbar.addAction(obj)
+ else:
+ # Add action for toolbutton in order to allow changing
+ # visibility (see doc QToolBar.addWidget doc)
+ if obj is self.keepDataAspectRatioButton:
+ self.keepDataAspectRatioAction = toolbar.addWidget(obj)
+ elif obj is self.yAxisInvertedButton:
+ self.yAxisInvertedAction = toolbar.addWidget(obj)
+ else:
+ raise RuntimeError("unknow action to be defined")
+ return toolbar
+
+ def toolBar(self):
+ """Return a QToolBar from the QAction of the PlotWindow."""
+ return self._toolbar
+
+ def menu(self, title="Plot", parent=None):
+ """Return a QMenu from the QAction of the PlotWindow.
+
+ :param str title: The title of the QMenu
+ :param parent: See :class:`QMenu`
+ """
+ menu = qt.QMenu(title, parent)
+ for action in self.group.actions():
+ menu.addAction(action)
+ return menu
+
+ def _customControlButtonMenu(self):
+ """Display Options button sub-menu."""
+ controlMenu = self.controlButton.menu()
+ controlMenu.clear()
+ controlMenu.addAction(self.getLegendsDockWidget().toggleViewAction())
+ controlMenu.addAction(self.getRoiAction())
+ controlMenu.addAction(self.getStatsAction())
+ controlMenu.addAction(self.getMaskAction())
+ controlMenu.addAction(self.getConsoleAction())
+
+ controlMenu.addSeparator()
+ controlMenu.addAction(self.getCrosshairAction())
+ controlMenu.addAction(self.getPanWithArrowKeysAction())
+
+ def addTabbedDockWidget(self, dock_widget):
+ """Add a dock widget as a new tab if there are already dock widgets
+ in the plot. When the first tab is added, the area is chosen
+ depending on the plot geometry:
+ if the window is much wider than it is high, the right dock area
+ is used, else the bottom dock area is used.
+
+ :param dock_widget: Instance of :class:`QDockWidget` to be added.
+ """
+ if dock_widget not in self._dockWidgets:
+ self._dockWidgets.append(dock_widget)
+ if len(self._dockWidgets) == 1:
+ # The first created dock widget must be added to a Widget area
+ width = self.centralWidget().width()
+ height = self.centralWidget().height()
+ if width > (1.25 * height):
+ area = qt.Qt.RightDockWidgetArea
+ else:
+ area = qt.Qt.BottomDockWidgetArea
+ self.addDockWidget(area, dock_widget)
+ else:
+ # Other dock widgets are added as tabs to the same widget area
+ self.tabifyDockWidget(self._dockWidgets[0], dock_widget)
+
+ def removeDockWidget(self, dockwidget):
+ """Removes the *dockwidget* from the main window layout and hides it.
+
+ Note that the *dockwidget* is *not* deleted.
+
+ :param QDockWidget dockwidget:
+ """
+ if dockwidget in self._dockWidgets:
+ self._dockWidgets.remove(dockwidget)
+ super(PlotWindow, self).removeDockWidget(dockwidget)
+
+ def _handleFirstDockWidgetShow(self, visible):
+ """Handle QDockWidget.visibilityChanged
+
+ It calls :meth:`addTabbedDockWidget` for the `sender` widget.
+ This allows to call `addTabbedDockWidget` lazily.
+
+ It disconnect itself from the signal once done.
+
+ :param bool visible:
+ """
+ if visible:
+ dockWidget = self.sender()
+ dockWidget.visibilityChanged.disconnect(self._handleFirstDockWidgetShow)
+ self.addTabbedDockWidget(dockWidget)
+
+ def _handleDockWidgetViewActionTriggered(self, checked):
+ if checked:
+ action = self.sender()
+ if action is None:
+ return
+ dockWidget = action.parent()
+ if dockWidget is None:
+ return
+ dockWidget.show() # Show needed here for raise to have an effect
+ dockWidget.raise_()
+
+ def getColorBarWidget(self):
+ """Returns the embedded :class:`ColorBarWidget` widget.
+
+ :rtype: ColorBarWidget
+ """
+ return self._colorbar
+
+ # getters for dock widgets
+
+ def getLegendsDockWidget(self):
+ """DockWidget with Legend panel"""
+ if self._legendsDockWidget is None:
+ self._legendsDockWidget = LegendsDockWidget(plot=self)
+ self._legendsDockWidget.hide()
+ self._legendsDockWidget.toggleViewAction().triggered.connect(
+ self._handleDockWidgetViewActionTriggered
+ )
+ self._legendsDockWidget.visibilityChanged.connect(
+ self._handleFirstDockWidgetShow
+ )
+ return self._legendsDockWidget
+
+ def getCurvesRoiDockWidget(self):
+ # Undocumented for a "soft deprecation" in version 0.7.0
+ # (still used internally for lazy loading)
+ if self._curvesROIDockWidget is None:
+ self._curvesROIDockWidget = CurvesROIDockWidget(
+ plot=self, name="Regions Of Interest"
+ )
+ self._curvesROIDockWidget.hide()
+ self._curvesROIDockWidget.toggleViewAction().triggered.connect(
+ self._handleDockWidgetViewActionTriggered
+ )
+ self._curvesROIDockWidget.visibilityChanged.connect(
+ self._handleFirstDockWidgetShow
+ )
+ return self._curvesROIDockWidget
+
+ def getCurvesRoiWidget(self):
+ """Return the :class:`CurvesROIWidget`.
+
+ :class:`silx.gui.plot.CurvesROIWidget.CurvesROIWidget` offers a getter
+ and a setter for the ROI data:
+
+ - :meth:`CurvesROIWidget.getRois`
+ - :meth:`CurvesROIWidget.setRois`
+ """
+ return self.getCurvesRoiDockWidget().roiWidget
+
+ def getMaskToolsDockWidget(self):
+ """DockWidget with image mask panel (lazy-loaded)."""
+ if self._maskToolsDockWidget is None:
+ self._maskToolsDockWidget = MaskToolsDockWidget(plot=self, name="Mask")
+ self._maskToolsDockWidget.hide()
+ self._maskToolsDockWidget.toggleViewAction().triggered.connect(
+ self._handleDockWidgetViewActionTriggered
+ )
+ self._maskToolsDockWidget.visibilityChanged.connect(
+ self._handleFirstDockWidgetShow
+ )
+ return self._maskToolsDockWidget
+
+ def getStatsWidget(self):
+ """Returns a BasicStatsWidget connected to this plot
+
+ :rtype: BasicStatsWidget
+ """
+ if self._statsDockWidget is None:
+ self._statsDockWidget = qt.QDockWidget()
+ self._statsDockWidget.setWindowTitle("Curves stats")
+ self._statsDockWidget.layout().setContentsMargins(0, 0, 0, 0)
+ statsWidget = BasicStatsWidget(parent=self, plot=self)
+ self._statsDockWidget.setWidget(statsWidget)
+ self._statsDockWidget.hide()
+ self._statsDockWidget.toggleViewAction().triggered.connect(
+ self._handleDockWidgetViewActionTriggered
+ )
+ self._statsDockWidget.visibilityChanged.connect(
+ self._handleFirstDockWidgetShow
+ )
+ return self._statsDockWidget.widget()
+
+ # getters for actions
+
+ def getConsoleAction(self):
+ """QAction handling the IPython console activation.
+
+ By default, it is connected to a method that initializes the
+ console widget the first time the user clicks the "Console" menu
+ button. The following clicks, after initialization is done,
+ will toggle the visibility of the console widget.
+
+ :rtype: QAction
+ """
+ if self._consoleAction is None:
+ self._consoleAction = qt.QAction("Console", self)
+ self._consoleAction.setCheckable(True)
+ if IPythonDockWidget is not None:
+ self._consoleAction.toggled.connect(self._toggleConsoleVisibility)
+ self._consoleAction.triggered.connect(self._consoleVisibilityTriggered)
+
+ else:
+ self._consoleAction.setEnabled(False)
+ return self._consoleAction
+
+ def getCrosshairAction(self):
+ """Action toggling crosshair cursor mode.
+
+ :rtype: actions.PlotAction
+ """
+ if self._crosshairAction is None:
+ self._crosshairAction = actions.control.CrosshairAction(self, color="red")
+ return self._crosshairAction
+
+ def getMaskAction(self):
+ """QAction toggling image mask dock widget
+
+ :rtype: QAction
+ """
+ return self.getMaskToolsDockWidget().toggleViewAction()
+
+ def getPanWithArrowKeysAction(self):
+ """Action toggling pan with arrow keys.
+
+ :rtype: actions.PlotAction
+ """
+ if self._panWithArrowKeysAction is None:
+ self._panWithArrowKeysAction = actions.control.PanWithArrowKeysAction(self)
+ return self._panWithArrowKeysAction
+
+ def getStatsAction(self):
+ return self.getStatsWidget().parent().toggleViewAction()
+
+ def getRoiAction(self):
+ """QAction toggling curve ROI dock widget
+
+ :rtype: QAction
+ """
+ return self.getCurvesRoiDockWidget().toggleViewAction()
+
+ def getResetZoomAction(self):
+ """Action resetting the zoom
+
+ :rtype: actions.PlotAction
+ """
+ return self.resetZoomAction
+
+ def getZoomInAction(self):
+ """Action to zoom in
+
+ :rtype: actions.PlotAction
+ """
+ return self.zoomInAction
+
+ def getZoomOutAction(self):
+ """Action to zoom out
+
+ :rtype: actions.PlotAction
+ """
+ return self.zoomOutAction
+
+ def getXAxisAutoScaleAction(self):
+ """Action to toggle the X axis autoscale on zoom reset
+
+ :rtype: actions.PlotAction
+ """
+ return self.xAxisAutoScaleAction
+
+ def getYAxisAutoScaleAction(self):
+ """Action to toggle the Y axis autoscale on zoom reset
+
+ :rtype: actions.PlotAction
+ """
+ return self.yAxisAutoScaleAction
+
+ def getXAxisLogarithmicAction(self):
+ """Action to toggle logarithmic X axis
+
+ :rtype: actions.PlotAction
+ """
+ return self.xAxisLogarithmicAction
+
+ def getYAxisLogarithmicAction(self):
+ """Action to toggle logarithmic Y axis
+
+ :rtype: actions.PlotAction
+ """
+ return self.yAxisLogarithmicAction
+
+ def getGridAction(self):
+ """Action to toggle the grid visibility in the plot
+
+ :rtype: actions.PlotAction
+ """
+ return self.gridAction
+
+ def getCurveStyleAction(self):
+ """Action to change curve line and markers styles
+
+ :rtype: actions.PlotAction
+ """
+ return self.curveStyleAction
+
+ def getColormapAction(self):
+ """Action open a colormap dialog to change active image
+ and default colormap.
+
+ :rtype: actions.PlotAction
+ """
+ return self.colormapAction
+
+ def getKeepDataAspectRatioButton(self):
+ """Button to toggle aspect ratio preservation
+
+ :rtype: PlotToolButtons.AspectToolButton
+ """
+ return self.keepDataAspectRatioButton
+
+ def getKeepDataAspectRatioAction(self):
+ """Action associated to keepDataAspectRatioButton.
+ Use this to change the visibility of keepDataAspectRatioButton in the
+ toolbar (See :meth:`QToolBar.addWidget` documentation).
+
+ :rtype: actions.PlotAction
+ """
+ return self.keepDataAspectRatioAction
+
+ def getYAxisInvertedButton(self):
+ """Button to switch the Y axis orientation
+
+ :rtype: PlotToolButtons.YAxisOriginToolButton
+ """
+ return self.yAxisInvertedButton
+
+ def getYAxisInvertedAction(self):
+ """Action associated to yAxisInvertedButton.
+ Use this to change the visibility yAxisInvertedButton in the toolbar.
+ (See :meth:`QToolBar.addWidget` documentation).
+
+ :rtype: actions.PlotAction
+ """
+ return self.yAxisInvertedAction
+
+ def getIntensityHistogramAction(self):
+ """Action toggling the histogram intensity Plot widget
+
+ :rtype: actions.PlotAction
+ """
+ return self._intensityHistoAction
+
+ def getCopyAction(self):
+ """Action to copy plot snapshot to clipboard
+
+ :rtype: actions.PlotAction
+ """
+ return self.getOutputToolBar().getCopyAction()
+
+ def getSaveAction(self):
+ """Action to save plot
+
+ :rtype: actions.PlotAction
+ """
+ return self.getOutputToolBar().getSaveAction()
+
+ def getPrintAction(self):
+ """Action to print plot
+
+ :rtype: actions.PlotAction
+ """
+ return self.getOutputToolBar().getPrintAction()
+
+ def getFitAction(self):
+ """Action to fit selected curve
+
+ :rtype: actions.PlotAction
+ """
+ return self.fitAction
+
+ def getMedianFilter1DAction(self):
+ """Action toggling the 1D median filter
+
+ :rtype: actions.PlotAction
+ """
+ return self._medianFilter1DAction
+
+ def getMedianFilter2DAction(self):
+ """Action toggling the 2D median filter
+
+ :rtype: actions.PlotAction
+ """
+ return self._medianFilter2DAction
+
+ def getColorBarAction(self):
+ """Action toggling the colorbar show/hide action
+
+ .. warning:: to show/hide the plot colorbar call directly the ColorBar
+ widget using getColorBarWidget()
+
+ :rtype: actions.PlotAction
+ """
+ return self.colorbarAction
+
+
+class Plot1D(PlotWindow):
+ """PlotWindow with tools specific for curves.
+
+ This widgets provides the plot API of :class:`.PlotWidget`.
+
+ :param parent: The parent of this widget
+ :param backend: The backend to use for the plot (default: matplotlib).
+ See :class:`.PlotWidget` for the list of supported backend.
+ :type backend: str or :class:`BackendBase.BackendBase`
+ """
+
+ def __init__(self, parent=None, backend=None):
+ super(Plot1D, self).__init__(
+ parent=parent,
+ backend=backend,
+ resetzoom=True,
+ autoScale=True,
+ logScale=True,
+ grid=True,
+ curveStyle=True,
+ colormap=False,
+ aspectRatio=False,
+ yInverted=False,
+ copy=True,
+ save=True,
+ print_=True,
+ control=True,
+ position=True,
+ roi=True,
+ mask=False,
+ fit=True,
+ )
+ if parent is None:
+ self.setWindowTitle("Plot1D")
+ self.getXAxis().setLabel("X")
+ self.getYAxis().setLabel("Y")
+ action = self.getFitAction()
+ action.setXRangeUpdatedOnZoom(True)
+ action.setFittedItemUpdatedFromActiveCurve(True)
+
+ self.getInteractiveModeToolBar().getZoomModeAction().setAxesMenuEnabled(True)
+
+
+class Plot2D(PlotWindow):
+ """PlotWindow with a toolbar specific for images.
+
+ This widgets provides the plot API of :~:`.PlotWidget`.
+
+ :param parent: The parent of this widget
+ :param backend: The backend to use for the plot (default: matplotlib).
+ See :class:`.PlotWidget` for the list of supported backend.
+ :type backend: str or :class:`BackendBase.BackendBase`
+ """
+
+ def __init__(self, parent=None, backend=None):
+ # List of information to display at the bottom of the plot
+ posInfo = [
+ ("X", lambda x, y: x),
+ ("Y", lambda x, y: y),
+ ("Data", WeakMethodProxy(self._getImageValue)),
+ ("Dims", WeakMethodProxy(self._getImageDims)),
+ ]
+
+ super(Plot2D, self).__init__(
+ parent=parent,
+ backend=backend,
+ resetzoom=True,
+ autoScale=False,
+ logScale=False,
+ grid=False,
+ curveStyle=False,
+ colormap=True,
+ aspectRatio=True,
+ yInverted=True,
+ copy=True,
+ save=True,
+ print_=True,
+ control=False,
+ position=posInfo,
+ roi=False,
+ mask=True,
+ )
+ if parent is None:
+ self.setWindowTitle("Plot2D")
+ self.getXAxis().setLabel("Columns")
+ self.getYAxis().setLabel("Rows")
+
+ if silx.config.DEFAULT_PLOT_IMAGE_Y_AXIS_ORIENTATION == "downward":
+ self.getYAxis().setInverted(True)
+
+ self.profile = ProfileToolBar(plot=self)
+ self.addToolBar(self.profile)
+
+ self.colorbarAction.setVisible(True)
+ self.getColorBarWidget().setVisible(True)
+
+ # Put colorbar action after colormap action
+ actions = self.toolBar().actions()
+ for action in actions:
+ if action is self.getColormapAction():
+ break
+
+ self.sigActiveImageChanged.connect(self.__activeImageChanged)
+
+ def __activeImageChanged(self, previous, legend):
+ """Handle change of active image
+
+ :param Union[str,None] previous: Legend of previous active image
+ :param Union[str,None] legend: Legend of current active image
+ """
+ if previous is not None:
+ item = self.getImage(previous)
+ if item is not None:
+ item.sigItemChanged.disconnect(self.__imageChanged)
+
+ if legend is not None:
+ item = self.getImage(legend)
+ item.sigItemChanged.connect(self.__imageChanged)
+
+ positionInfo = self.getPositionInfoWidget()
+ if positionInfo is not None:
+ positionInfo.updateInfo()
+
+ def __imageChanged(self, event):
+ """Handle update of active image item
+
+ :param event: Type of changed event
+ """
+ if event == items.ItemChangedType.DATA:
+ positionInfo = self.getPositionInfoWidget()
+ if positionInfo is not None:
+ positionInfo.updateInfo()
+
+ def _getImageValue(self, x, y):
+ """Get status bar value of top most image at position (x, y)
+
+ :param float x: X position in plot coordinates
+ :param float y: Y position in plot coordinates
+ :return: The value at that point or '-'
+ """
+ pickedMask = None
+ for picked in self.pickItems(
+ *self.dataToPixel(x, y, check=False),
+ lambda item: isinstance(item, items.ImageBase),
+ ):
+ if isinstance(picked.getItem(), items.MaskImageData):
+ if pickedMask is None: # Use top-most if many masks
+ pickedMask = picked
+ else:
+ image = picked.getItem()
+
+ indices = picked.getIndices(copy=False)
+ if indices is not None:
+ row, col = indices[0][0], indices[1][0]
+ value = image.getData(copy=False)[row, col]
+
+ if pickedMask is not None: # Check if masked
+ maskItem = pickedMask.getItem()
+ indices = pickedMask.getIndices()
+ row, col = indices[0][0], indices[1][0]
+ if maskItem.getData(copy=False)[row, col] != 0:
+ return value, "Masked"
+ return value
+
+ return "-" # No image picked
+
+ def _getImageDims(self, *args):
+ activeImage = self.getActiveImage()
+ if activeImage is not None and activeImage.getData(copy=False) is not None:
+ dims = activeImage.getData(copy=False).shape[1::-1]
+ return "x".join(str(dim) for dim in dims)
+ else:
+ return "-"
+
+ def getProfileToolbar(self):
+ """Profile tools attached to this plot
+
+ See :class:`silx.gui.plot.Profile.ProfileToolBar`
+ """
+ return self.profile
+
+ def getProfilePlot(self):
+ """Return plot window used to display profile curve.
+
+ :return: :class:`Plot1D`
+ """
+ return self.profile.getProfilePlot()
diff --git a/src/silx/gui/plot/PrintPreviewToolButton.py b/src/silx/gui/plot/PrintPreviewToolButton.py
new file mode 100644
index 0000000..0812420
--- /dev/null
+++ b/src/silx/gui/plot/PrintPreviewToolButton.py
@@ -0,0 +1,388 @@
+# /*##########################################################################
+#
+# Copyright (c) 2017-2023 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""
+This modules provides tool buttons to send the content of a plot to a
+print preview page.
+The plot content can then be moved on the page and resized prior to printing.
+
+Classes
+-------
+
+- :class:`PrintPreviewToolButton`
+- :class:`SingletonPrintPreviewToolButton`
+
+Examples
+--------
+
+Simple example
+++++++++++++++
+
+.. code-block:: python
+
+ from silx.gui import qt
+ from silx.gui.plot import PlotWidget
+ from silx.gui.plot.PrintPreviewToolButton import PrintPreviewToolButton
+ import numpy
+
+ app = qt.QApplication([])
+
+ pw = PlotWidget()
+ toolbar = qt.QToolBar(pw)
+ toolbutton = PrintPreviewToolButton(parent=toolbar, plot=pw)
+ pw.addToolBar(toolbar)
+ toolbar.addWidget(toolbutton)
+ pw.show()
+
+ x = numpy.arange(1000)
+ y = x / numpy.sin(x)
+ pw.addCurve(x, y)
+
+ app.exec()
+
+Singleton example
++++++++++++++++++
+
+This example illustrates how to print the content of several different
+plots on the same page. The plots all instantiate a
+:class:`SingletonPrintPreviewToolButton`, which relies on a singleton widget
+(:class:`silx.gui.widgets.PrintPreview.SingletonPrintPreviewDialog`).
+
+.. image:: img/printPreviewMultiPlot.png
+
+.. code-block:: python
+
+ from silx.gui import qt
+ from silx.gui.plot import PlotWidget
+ from silx.gui.plot.PrintPreviewToolButton import SingletonPrintPreviewToolButton
+ import numpy
+
+ app = qt.QApplication([])
+
+ plot_widgets = []
+
+ for i in range(3):
+ pw = PlotWidget()
+ toolbar = qt.QToolBar(pw)
+ toolbutton = SingletonPrintPreviewToolButton(parent=toolbar,
+ plot=pw)
+ pw.addToolBar(toolbar)
+ toolbar.addWidget(toolbutton)
+ pw.show()
+ plot_widgets.append(pw)
+
+ x = numpy.arange(1000)
+
+ plot_widgets[0].addCurve(x, numpy.sin(x * 2 * numpy.pi / 1000))
+ plot_widgets[1].addCurve(x, numpy.cos(x * 2 * numpy.pi / 1000))
+ plot_widgets[2].addCurve(x, numpy.tan(x * 2 * numpy.pi / 1000))
+
+ app.exec()
+
+"""
+
+import logging
+from io import StringIO
+
+from .. import qt
+from .. import icons
+from . import PlotWidget
+from ..widgets.PrintPreview import PrintPreviewDialog, SingletonPrintPreviewDialog
+from ..widgets.PrintGeometryDialog import PrintGeometryDialog
+
+__authors__ = ["P. Knobel"]
+__license__ = "MIT"
+__date__ = "20/12/2018"
+
+_logger = logging.getLogger(__name__)
+# _logger.setLevel(logging.DEBUG)
+
+
+class PrintPreviewToolButton(qt.QToolButton):
+ """QToolButton to open a :class:`PrintPreviewDialog` (if not already open)
+ and add the current plot to its page to be printed.
+
+ :param parent: See :class:`QAction`
+ :param plot: :class:`.PlotWidget` instance on which to operate
+ """
+
+ def __init__(self, parent=None, plot=None):
+ super(PrintPreviewToolButton, self).__init__(parent)
+
+ if not isinstance(plot, PlotWidget):
+ raise TypeError("plot parameter must be a PlotWidget")
+ self._plot = plot
+
+ self.setIcon(icons.getQIcon("document-print"))
+
+ printGeomAction = qt.QAction("Print geometry", self)
+ printGeomAction.setToolTip(
+ "Define a print geometry prior to sending "
+ "the plot to the print preview dialog"
+ )
+ printGeomAction.setIcon(icons.getQIcon("shape-rectangle"))
+ printGeomAction.triggered.connect(self._setPrintConfiguration)
+
+ printPreviewAction = qt.QAction("Print preview", self)
+ printPreviewAction.setToolTip("Send plot to the print preview dialog")
+ printPreviewAction.setIcon(icons.getQIcon("document-print"))
+ printPreviewAction.triggered.connect(self._plotToPrintPreview)
+
+ menu = qt.QMenu(self)
+ menu.addAction(printGeomAction)
+ menu.addAction(printPreviewAction)
+ self.setMenu(menu)
+ self.setPopupMode(qt.QToolButton.InstantPopup)
+
+ self._printPreviewDialog = None
+ self._printConfigurationDialog = None
+
+ self._printGeometry = {
+ "xOffset": 0.1,
+ "yOffset": 0.1,
+ "width": 0.9,
+ "height": 0.9,
+ "units": "page",
+ "keepAspectRatio": True,
+ }
+
+ @property
+ def printPreviewDialog(self):
+ """Lazy loaded :class:`PrintPreviewDialog`"""
+ # if changes are made here, don't forget making them in
+ # SingletonPrintPreviewToolButton.printPreviewDialog as well
+ if self._printPreviewDialog is None:
+ self._printPreviewDialog = PrintPreviewDialog(self.parent())
+ return self._printPreviewDialog
+
+ def getTitle(self):
+ """Implement this method to fetch the title in the plot.
+
+ :return: Title to be printed above the plot, or None (no title added)
+ :rtype: str or None
+ """
+ return None
+
+ def getCommentAndPosition(self):
+ """Implement this method to fetch the legend to be printed below the
+ figure and its position.
+
+ :return: Legend to be printed below the figure and its position:
+ "CENTER", "LEFT" or "RIGHT"
+ :rtype: (str, str) or (None, None)
+ """
+ return None, None
+
+ def getPlot(self):
+ """Return the :class:`.PlotWidget` associated with this tool button.
+
+ :rtype: :class:`.PlotWidget`
+ """
+ return self._plot
+
+ def _plotToPrintPreview(self):
+ """Grab the plot widget and send it to the print preview dialog.
+ Make sure the print preview dialog is shown and raised."""
+ if not self.printPreviewDialog.ensurePrinterIsSet():
+ return
+
+ comment, commentPosition = self.getCommentAndPosition()
+
+ if qt.HAS_SVG:
+ svgRenderer, viewBox = self._getSvgRendererAndViewbox()
+ self.printPreviewDialog.addSvgItem(
+ svgRenderer,
+ title=self.getTitle(),
+ comment=comment,
+ commentPosition=commentPosition,
+ viewBox=viewBox,
+ keepRatio=self._printGeometry["keepAspectRatio"],
+ )
+ else:
+ _logger.warning("Missing QtSvg library, using a raster image")
+ pixmap = self._plot.centralWidget().grab()
+ self.printPreviewDialog.addPixmap(
+ pixmap,
+ title=self.getTitle(),
+ comment=comment,
+ commentPosition=commentPosition,
+ )
+ self.printPreviewDialog.show()
+ self.printPreviewDialog.raise_()
+
+ def _getSvgRendererAndViewbox(self):
+ """Return a SVG renderer displaying the plot and its viewbox
+ (interactively specified by the user the first time this is called).
+
+ The size of the renderer is adjusted to the printer configuration
+ and to the geometry configuration (width, height, ratio) specified
+ by the user."""
+ imgData = StringIO()
+ assert self._plot.saveGraph(imgData, fileFormat="svg"), "Unable to save graph"
+ imgData.flush()
+ imgData.seek(0)
+ svgData = imgData.read()
+
+ svgRenderer = qt.QSvgRenderer()
+
+ viewbox = self._getViewBox()
+
+ svgRenderer.setViewBox(viewbox)
+
+ xml_stream = qt.QXmlStreamReader(svgData.encode(errors="replace"))
+
+ # This is for PyMca compatibility, to share a print preview with PyMca plots
+ svgRenderer._viewBox = viewbox
+ svgRenderer._svgRawData = svgData.encode(errors="replace")
+ svgRenderer._svgRendererData = xml_stream
+
+ if not svgRenderer.load(xml_stream):
+ raise RuntimeError("Cannot interpret svg data")
+
+ return svgRenderer, viewbox
+
+ def _getViewBox(self):
+ """ """
+ printer = self.printPreviewDialog.printer
+ dpix = printer.logicalDpiX()
+ dpiy = printer.logicalDpiY()
+ availableWidth = printer.width()
+ availableHeight = printer.height()
+
+ config = self._printGeometry
+ width = config["width"]
+ height = config["height"]
+ xOffset = config["xOffset"]
+ yOffset = config["yOffset"]
+ units = config["units"]
+ keepAspectRatio = config["keepAspectRatio"]
+ aspectRatio = self._getPlotAspectRatio()
+
+ # convert the offsets to dots
+ if units.lower() in ["inch", "inches"]:
+ xOffset = xOffset * dpix
+ yOffset = yOffset * dpiy
+ if width is not None:
+ width = width * dpix
+ if height is not None:
+ height = height * dpiy
+ elif units.lower() in ["cm", "centimeters"]:
+ xOffset = (xOffset / 2.54) * dpix
+ yOffset = (yOffset / 2.54) * dpiy
+ if width is not None:
+ width = (width / 2.54) * dpix
+ if height is not None:
+ height = (height / 2.54) * dpiy
+ else:
+ # page units
+ xOffset = availableWidth * xOffset
+ yOffset = availableHeight * yOffset
+ if width is not None:
+ width = availableWidth * width
+ if height is not None:
+ height = availableHeight * height
+
+ availableWidth -= xOffset
+ availableHeight -= yOffset
+
+ if width is not None:
+ if (availableWidth + 0.1) < width:
+ txt = "Available width %f is less than requested width %f" % (
+ availableWidth,
+ width,
+ )
+ raise ValueError(txt)
+ if height is not None:
+ if (availableHeight + 0.1) < height:
+ txt = "Available height %f is less than requested height %f" % (
+ availableHeight,
+ height,
+ )
+ raise ValueError(txt)
+
+ if keepAspectRatio:
+ bodyWidth = width or availableWidth
+ bodyHeight = bodyWidth * aspectRatio
+
+ if bodyHeight > availableHeight:
+ bodyHeight = availableHeight
+ bodyWidth = bodyHeight / aspectRatio
+
+ else:
+ bodyWidth = width or availableWidth
+ bodyHeight = height or availableHeight
+
+ return qt.QRectF(xOffset, yOffset, bodyWidth, bodyHeight)
+
+ def _setPrintConfiguration(self):
+ """Open a dialog to prompt the user to adjust print
+ geometry parameters."""
+ self.printPreviewDialog.ensurePrinterIsSet()
+ if self._printConfigurationDialog is None:
+ self._printConfigurationDialog = PrintGeometryDialog(self.parent())
+
+ self._printConfigurationDialog.setPrintGeometry(self._printGeometry)
+ if self._printConfigurationDialog.exec():
+ self._printGeometry = self._printConfigurationDialog.getPrintGeometry()
+
+ def _getPlotAspectRatio(self):
+ widget = self._plot.centralWidget()
+ graphWidth = float(widget.width())
+ graphHeight = float(widget.height())
+ return graphHeight / graphWidth
+
+
+class SingletonPrintPreviewToolButton(PrintPreviewToolButton):
+ """This class is similar to its parent class :class:`PrintPreviewToolButton`
+ but it uses a singleton print preview widget.
+
+ This allows for several plots to send their content to the
+ same print page, and for users to arrange them."""
+
+ def __init__(self, parent=None, plot=None):
+ PrintPreviewToolButton.__init__(self, parent, plot)
+
+ @property
+ def printPreviewDialog(self):
+ if self._printPreviewDialog is None:
+ self._printPreviewDialog = SingletonPrintPreviewDialog(self.parent())
+ return self._printPreviewDialog
+
+
+if __name__ == "__main__":
+ import numpy
+
+ app = qt.QApplication([])
+
+ pw = PlotWidget()
+ toolbar = qt.QToolBar(pw)
+ toolbutton = PrintPreviewToolButton(parent=toolbar, plot=pw)
+ pw.addToolBar(toolbar)
+ toolbar.addWidget(toolbutton)
+ pw.show()
+
+ x = numpy.arange(1000)
+ y = x / numpy.sin(x)
+ pw.addCurve(x, y)
+
+ app.exec()
diff --git a/src/silx/gui/plot/Profile.py b/src/silx/gui/plot/Profile.py
new file mode 100644
index 0000000..f89f780
--- /dev/null
+++ b/src/silx/gui/plot/Profile.py
@@ -0,0 +1,254 @@
+# /*##########################################################################
+#
+# Copyright (c) 2004-2023 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""Utility functions, toolbars and actions to create profile on images
+and stacks of images"""
+
+
+__authors__ = ["V.A. Sole", "T. Vincent", "P. Knobel", "H. Payno"]
+__license__ = "MIT"
+__date__ = "12/04/2019"
+
+
+import weakref
+
+from .. import qt
+from . import actions
+from .tools.profile import manager
+from .tools.profile import rois
+from silx.gui.widgets.MultiModeAction import MultiModeAction
+
+from .tools import roi as roi_mdl
+from silx.gui.plot import items
+
+
+class _CustomProfileManager(manager.ProfileManager):
+ """This custom profile manager uses a single predefined profile window
+ if it is specified. Else the behavior is the same as the default
+ ProfileManager"""
+
+ def __init__(self, *args, **kwargs):
+ super().__init__(*args, **kwargs)
+ self.__profileWindow = None
+ self.__specializedProfileWindows = {}
+
+ def setSpecializedProfileWindow(self, roiClass, profileWindow):
+ """Set a profile window for a given class or ROI.
+
+ Setting profileWindow to None removes the roiClass from the list.
+
+ :param roiClass:
+ :param profileWindow:
+ """
+ if profileWindow is None:
+ self.__specializedProfileWindows.pop(roiClass, None)
+ else:
+ self.__specializedProfileWindows[roiClass] = profileWindow
+
+ def setProfileWindow(self, profileWindow):
+ self.__profileWindow = profileWindow
+
+ def createProfileWindow(self, plot, roi):
+ for (
+ roiClass,
+ specializedProfileWindow,
+ ) in self.__specializedProfileWindows.items():
+ if isinstance(roi, roiClass):
+ return specializedProfileWindow
+
+ if self.__profileWindow is not None:
+ return self.__profileWindow
+ else:
+ return super(_CustomProfileManager, self).createProfileWindow(plot, roi)
+
+ def clearProfileWindow(self, profileWindow):
+ for specializedProfileWindow in self.__specializedProfileWindows.values():
+ if profileWindow is specializedProfileWindow:
+ profileWindow.setProfile(None)
+ return
+
+ if self.__profileWindow is not None:
+ self.__profileWindow.setProfile(None)
+ else:
+ return super(_CustomProfileManager, self).clearProfileWindow(profileWindow)
+
+
+class ProfileToolBar(qt.QToolBar):
+ """QToolBar providing profile tools operating on a :class:`PlotWindow`.
+
+ Attributes:
+
+ - plot: Associated :class:`PlotWindow` on which the profile line is drawn.
+ - actionGroup: :class:`QActionGroup` of available actions.
+
+ To run the following sample code, a QApplication must be initialized.
+ First, create a PlotWindow and add a :class:`ProfileToolBar`.
+
+ >>> from silx.gui.plot import PlotWindow
+ >>> from silx.gui.plot.Profile import ProfileToolBar
+
+ >>> plot = PlotWindow() # Create a PlotWindow
+ >>> toolBar = ProfileToolBar(plot=plot) # Create a profile toolbar
+ >>> plot.addToolBar(toolBar) # Add it to plot
+ >>> plot.show() # To display the PlotWindow with the profile toolbar
+
+ :param plot: :class:`PlotWindow` instance on which to operate.
+ :param profileWindow: Plot widget instance where to
+ display the profile curve or None to create one.
+ :param parent: See :class:`QToolBar`.
+ """
+
+ def __init__(self, parent=None, plot=None, profileWindow=None):
+ super(ProfileToolBar, self).__init__(parent)
+ assert plot is not None
+
+ self._plotRef = weakref.ref(plot)
+
+ # If a profileWindow is defined,
+ # It will be used to display all the profiles
+ self._manager = self.createProfileManager(self, plot)
+ self._manager.setProfileWindow(profileWindow)
+ self._manager.setDefaultColorFromCursorColor(True)
+ self._manager.setItemType(image=True)
+ self._manager.setActiveItemTracking(True)
+
+ # Actions
+ self._browseAction = actions.mode.ZoomModeAction(plot, parent=self)
+ self._browseAction.setVisible(False)
+ self.freeLineAction = None
+ self._createProfileActions()
+ self._editor = self._manager.createEditorAction(self)
+
+ # ActionGroup
+ self.actionGroup = qt.QActionGroup(self)
+ self.actionGroup.addAction(self._browseAction)
+ self.actionGroup.addAction(self.hLineAction)
+ self.actionGroup.addAction(self.vLineAction)
+ self.actionGroup.addAction(self.lineAction)
+ self.actionGroup.addAction(self._editor)
+
+ modes = MultiModeAction(self)
+ modes.addAction(self.hLineAction)
+ modes.addAction(self.vLineAction)
+ modes.addAction(self.lineAction)
+ if self.freeLineAction is not None:
+ modes.addAction(self.freeLineAction)
+ modes.addAction(self.crossAction)
+ self.__multiAction = modes
+
+ # Add actions to ToolBar
+ self.addAction(self._browseAction)
+ self.addAction(modes)
+ self.addAction(self._editor)
+ self.addAction(self.clearAction)
+
+ plot.sigActiveImageChanged.connect(self._activeImageChanged)
+ self._activeImageChanged()
+
+ def createProfileManager(self, parent, plot):
+ return _CustomProfileManager(parent, plot)
+
+ def _createProfileActions(self):
+ self.hLineAction = self._manager.createProfileAction(
+ rois.ProfileImageHorizontalLineROI, self
+ )
+ self.vLineAction = self._manager.createProfileAction(
+ rois.ProfileImageVerticalLineROI, self
+ )
+ self.lineAction = self._manager.createProfileAction(
+ rois.ProfileImageLineROI, self
+ )
+ self.freeLineAction = self._manager.createProfileAction(
+ rois.ProfileImageDirectedLineROI, self
+ )
+ self.crossAction = self._manager.createProfileAction(
+ rois.ProfileImageCrossROI, self
+ )
+ self.clearAction = self._manager.createClearAction(self)
+
+ def getPlotWidget(self):
+ """The :class:`.PlotWidget` associated to the toolbar."""
+ return self._plotRef()
+
+ def _setRoiActionEnabled(self, itemKind, enabled):
+ for action in self.__multiAction.getMenu().actions():
+ if not isinstance(action, roi_mdl.CreateRoiModeAction):
+ continue
+ roiClass = action.getRoiClass()
+ if issubclass(itemKind, roiClass.ITEM_KIND):
+ action.setEnabled(enabled)
+
+ def _activeImageChanged(self, previous=None, legend=None):
+ """Handle active image change to toggle actions"""
+ if legend is None:
+ self._setRoiActionEnabled(items.ImageStack, False)
+ self._setRoiActionEnabled(items.ImageBase, False)
+ else:
+ plot = self.getPlotWidget()
+ image = plot.getActiveImage()
+ # Disable for empty image
+ enabled = image.getData(copy=False).size > 0
+ self._setRoiActionEnabled(type(image), enabled)
+
+ def getProfileManager(self):
+ """Return the manager of the profiles.
+
+ :rtype: ProfileManager
+ """
+ return self._manager
+
+ def clearProfile(self):
+ """Remove profile curve and profile area."""
+ self._manager.clearProfile()
+
+
+class Profile3DToolBar(ProfileToolBar):
+ def __init__(self, parent=None, stackview=None):
+ """QToolBar providing profile tools for an image or a stack of images.
+
+ :param parent: the parent QWidget
+ :param stackview: :class:`StackView` instance on which to operate.
+ :param parent: See :class:`QToolBar`.
+ """
+ # TODO: add param profileWindow (specify the plot used for profiles)
+ super(Profile3DToolBar, self).__init__(
+ parent=parent, plot=stackview.getPlotWidget()
+ )
+
+ self.stackView = stackview
+ """:class:`StackView` instance"""
+
+ def _createProfileActions(self):
+ self.hLineAction = self._manager.createProfileAction(
+ rois.ProfileImageStackHorizontalLineROI, self
+ )
+ self.vLineAction = self._manager.createProfileAction(
+ rois.ProfileImageStackVerticalLineROI, self
+ )
+ self.lineAction = self._manager.createProfileAction(
+ rois.ProfileImageStackLineROI, self
+ )
+ self.crossAction = self._manager.createProfileAction(
+ rois.ProfileImageStackCrossROI, self
+ )
+ self.clearAction = self._manager.createClearAction(self)
diff --git a/src/silx/gui/plot/ROIStatsWidget.py b/src/silx/gui/plot/ROIStatsWidget.py
new file mode 100644
index 0000000..36f3391
--- /dev/null
+++ b/src/silx/gui/plot/ROIStatsWidget.py
@@ -0,0 +1,807 @@
+# /*##########################################################################
+#
+# Copyright (c) 2016-2023 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""This module provides widget for displaying statistics relative to a
+Region of interest and an item
+"""
+
+
+__authors__ = ["H. Payno"]
+__license__ = "MIT"
+__date__ = "22/07/2019"
+
+
+from contextlib import contextmanager
+from silx.gui import qt
+from silx.gui import icons
+from silx.gui.plot.StatsWidget import _StatsWidgetBase, _Container
+from silx.gui.plot.StatsWidget import UpdateMode
+from silx.gui.widgets.TableWidget import TableWidget
+from silx.gui.plot.items.roi import RegionOfInterest
+from silx.gui.plot import items as plotitems
+from silx.gui.plot.items.core import ItemChangedType
+from silx.gui.plot3d import items as plot3ditems
+from silx.gui.plot.CurvesROIWidget import ROI
+from silx.gui.plot import stats as statsmdl
+from silx.utils.proxy import docstring
+import silx.gui.plot.items.marker
+import silx.gui.plot.items.shape
+import functools
+import logging
+
+_logger = logging.getLogger(__name__)
+
+
+class _GetROIItemCoupleDialog(qt.QDialog):
+ """
+ Dialog used to know which plot item and which roi he wants
+ """
+
+ _COMPATIBLE_KINDS = ("curve", "image", "scatter", "histogram")
+
+ def __init__(self, parent=None, plot=None, rois=None):
+ qt.QDialog.__init__(self, parent=parent)
+ assert plot is not None
+ assert rois is not None
+ self._plot = plot
+ self._rois = rois
+
+ self.setLayout(qt.QVBoxLayout())
+
+ # define the selection widget
+ self._selection_widget = qt.QWidget()
+ self._selection_widget.setLayout(qt.QHBoxLayout())
+ self._kindCB = qt.QComboBox(parent=self)
+ self._selection_widget.layout().addWidget(self._kindCB)
+ self._itemCB = qt.QComboBox(parent=self)
+ self._selection_widget.layout().addWidget(self._itemCB)
+ self._roiCB = qt.QComboBox(parent=self)
+ self._selection_widget.layout().addWidget(self._roiCB)
+ self.layout().addWidget(self._selection_widget)
+
+ # define modal buttons
+ types = qt.QDialogButtonBox.Ok | qt.QDialogButtonBox.Cancel
+ self._buttonsModal = qt.QDialogButtonBox(parent=self)
+ self._buttonsModal.setStandardButtons(types)
+ self.layout().addWidget(self._buttonsModal)
+ self._buttonsModal.accepted.connect(self.accept)
+ self._buttonsModal.rejected.connect(self.reject)
+
+ # connect signal / slot
+ self._kindCB.currentIndexChanged.connect(self._updateValidItemAndRoi)
+
+ def _getCompatibleRois(self, kind):
+ """Return compatible rois for the given item kind"""
+
+ def is_compatible(roi, kind):
+ if isinstance(roi, RegionOfInterest):
+ return kind in ("image", "scatter")
+ elif isinstance(roi, ROI):
+ return kind in ("curve", "histogram")
+ else:
+ raise ValueError("kind not managed")
+
+ return list(filter(lambda x: is_compatible(x, kind), self._rois))
+
+ def exec(self):
+ self._kindCB.clear()
+ self._itemCB.clear()
+ # filter kind without any items
+ self._valid_kinds = {}
+ # key is item type, value kinds
+ self._valid_rois = {}
+ # key is item type, value rois
+ self._kind_name_to_roi = {}
+ # key is (kind, roi name) value is roi
+ self._kind_name_to_item = {}
+ # key is (kind, legend name) value is item
+ for kind in _GetROIItemCoupleDialog._COMPATIBLE_KINDS:
+
+ def getItems(kind):
+ output = []
+ for item in self._plot.getItems():
+ type_ = self._plot._itemKind(item)
+ if type_ in kind and item.isVisible():
+ output.append(item)
+ return output
+
+ items = getItems(kind=kind)
+ rois = self._getCompatibleRois(kind=kind)
+ if len(items) > 0 and len(rois) > 0:
+ self._valid_kinds[kind] = items
+ self._valid_rois[kind] = rois
+ for roi in rois:
+ name = roi.getName()
+ self._kind_name_to_roi[(kind, name)] = roi
+ for item in items:
+ self._kind_name_to_item[(kind, item.getLegend())] = item
+
+ # filter roi according to kinds
+ if len(self._valid_kinds) == 0:
+ _logger.warning("no couple item/roi detected for displaying stats")
+ return self.reject()
+
+ for kind in self._valid_kinds:
+ self._kindCB.addItem(kind)
+ self._updateValidItemAndRoi()
+
+ return qt.QDialog.exec(self)
+
+ def exec_(self): # Qt5 compatibility
+ return self.exec()
+
+ def _updateValidItemAndRoi(self, *args, **kwargs):
+ self._itemCB.clear()
+ self._roiCB.clear()
+ kind = self._kindCB.currentText()
+ for roi in self._valid_rois[kind]:
+ self._roiCB.addItem(roi.getName())
+ for item in self._valid_kinds[kind]:
+ self._itemCB.addItem(item.getLegend())
+
+ def getROI(self):
+ kind = self._kindCB.currentText()
+ roi_name = self._roiCB.currentText()
+ return self._kind_name_to_roi[(kind, roi_name)]
+
+ def getItem(self):
+ kind = self._kindCB.currentText()
+ item_name = self._itemCB.currentText()
+ return self._kind_name_to_item[(kind, item_name)]
+
+
+class ROIStatsItemHelper(object):
+ """Item utils to associate a plot item and a roi
+
+ Display on one row statistics regarding the couple
+ (Item (plot item) / roi).
+
+ :param Item plot_item: item for which we want statistics
+ :param Union[ROI,RegionOfInterest]: region of interest to use for
+ statistics.
+ """
+
+ def __init__(self, plot_item, roi):
+ self._plot_item = plot_item
+ self._roi = roi
+
+ @property
+ def roi(self):
+ """roi"""
+ return self._roi
+
+ def roi_name(self):
+ if isinstance(self._roi, ROI):
+ return self._roi.getName()
+ elif isinstance(self._roi, RegionOfInterest):
+ return self._roi.getName()
+ else:
+ raise TypeError("Unmanaged roi type")
+
+ @property
+ def roi_kind(self):
+ """roi class"""
+ return self._roi.__class__
+
+ # TODO: should call a util function from the wrapper ?
+ def item_kind(self):
+ """item kind"""
+ if isinstance(self._plot_item, plotitems.Curve):
+ return "curve"
+ elif isinstance(self._plot_item, plotitems.ImageData):
+ return "image"
+ elif isinstance(self._plot_item, plotitems.Scatter):
+ return "scatter"
+ elif isinstance(self._plot_item, plotitems.Histogram):
+ return "histogram"
+ elif isinstance(
+ self._plot_item, (plot3ditems.ImageData, plot3ditems.ScalarField3D)
+ ):
+ return "image"
+ elif isinstance(
+ self._plot_item, (plot3ditems.Scatter2D, plot3ditems.Scatter3D)
+ ):
+ return "scatter"
+
+ @property
+ def item_legend(self):
+ """legend of the plot Item"""
+ return self._plot_item.getLegend()
+
+ def id_key(self):
+ """unique key to represent the couple (item, roi)"""
+ return (self.item_kind(), self.item_legend, self.roi_kind, self.roi_name())
+
+
+class _StatsROITable(_StatsWidgetBase, TableWidget):
+ """
+ Table sued to display some statistics regarding a couple (item/roi)
+ """
+
+ _LEGEND_HEADER_DATA = "legend"
+
+ _KIND_HEADER_DATA = "kind"
+
+ _ROI_HEADER_DATA = "roi"
+
+ sigUpdateModeChanged = qt.Signal(object)
+ """Signal emitted when the update mode changed"""
+
+ def __init__(self, parent, plot):
+ TableWidget.__init__(self, parent)
+ _StatsWidgetBase.__init__(
+ self, statsOnVisibleData=False, displayOnlyActItem=False
+ )
+ self.__region_edition_callback = {}
+ """We need to keep trace of the roi signals connection because
+ the roi emits the sigChanged during roi edition"""
+ self._items = {}
+ self.setRowCount(0)
+ self.setColumnCount(3)
+
+ # Init headers
+ headerItem = qt.QTableWidgetItem(self._LEGEND_HEADER_DATA.title())
+ headerItem.setData(qt.Qt.UserRole, self._LEGEND_HEADER_DATA)
+ self.setHorizontalHeaderItem(0, headerItem)
+ headerItem = qt.QTableWidgetItem(self._KIND_HEADER_DATA.title())
+ headerItem.setData(qt.Qt.UserRole, self._KIND_HEADER_DATA)
+ self.setHorizontalHeaderItem(1, headerItem)
+ headerItem = qt.QTableWidgetItem(self._ROI_HEADER_DATA.title())
+ headerItem.setData(qt.Qt.UserRole, self._ROI_HEADER_DATA)
+ self.setHorizontalHeaderItem(2, headerItem)
+
+ self.setSortingEnabled(True)
+ self.setPlot(plot)
+
+ self.__plotItemToItems = {}
+ """Key is plotItem, values is list of __RoiStatsItemWidget"""
+ self.__roiToItems = {}
+ """Key is roi, values is list of __RoiStatsItemWidget"""
+ self.__roisKeyToRoi = {}
+
+ def add(self, item):
+ assert isinstance(item, ROIStatsItemHelper)
+ if item.id_key() in self._items:
+ _logger.warning("Item %s is already present", item.id_key())
+ return None
+ self._items[item.id_key()] = item
+ self._addItem(item)
+ return item
+
+ def _addItem(self, item):
+ """
+ Add a _RoiStatsItemWidget item to the table.
+
+ :param item:
+ :return: True if successfully added.
+ """
+ if not isinstance(item, ROIStatsItemHelper):
+ # skipped because also receive all new plot item (Marker...) that
+ # we don't want to manage in this case.
+ return
+ # plotItem = item.getItem()
+ # roi = item.getROI()
+ kind = item.item_kind()
+ if kind not in statsmdl.BASIC_COMPATIBLE_KINDS:
+ _logger.info("Item has not a supported type: %s", item)
+ return False
+
+ # register the roi and the kind
+ self._registerPlotItem(item)
+ self._registerROI(item)
+
+ # Prepare table items
+ tableItems = [
+ qt.QTableWidgetItem(), # Legend
+ qt.QTableWidgetItem(), # Kind
+ qt.QTableWidgetItem(),
+ ] # roi
+
+ for column in range(3, self.columnCount()):
+ header = self.horizontalHeaderItem(column)
+ name = header.data(qt.Qt.UserRole)
+
+ formatter = self._statsHandler.formatters[name]
+ if formatter:
+ tableItem = formatter.tabWidgetItemClass()
+ else:
+ tableItem = qt.QTableWidgetItem()
+
+ tooltip = self._statsHandler.stats[name].getToolTip(kind=kind)
+ if tooltip is not None:
+ tableItem.setToolTip(tooltip)
+
+ tableItems.append(tableItem)
+
+ # Disable sorting while adding table items
+ with self._disableSorting():
+ # Add a row to the table
+ self.setRowCount(self.rowCount() + 1)
+
+ # Add table items to the last row
+ row = self.rowCount() - 1
+ for column, tableItem in enumerate(tableItems):
+ tableItem.setData(qt.Qt.UserRole, _Container(item))
+ tableItem.setFlags(qt.Qt.ItemIsEnabled | qt.Qt.ItemIsSelectable)
+ self.setItem(row, column, tableItem)
+
+ # Update table items content
+ self._updateStats(item, data_changed=True)
+
+ # Listen for item changes
+ # Using queued connection to avoid issue with sender
+ # being that of the signal calling the signal
+ item._plot_item.sigItemChanged.connect(
+ self._plotItemChanged, qt.Qt.QueuedConnection
+ )
+ return True
+
+ def _removeAllItems(self):
+ for row in range(self.rowCount()):
+ tableItem = self.item(row, 0)
+ # item = self._tableItemToItem(tableItem)
+ # item.sigItemChanged.disconnect(self._plotItemChanged)
+ self.clearContents()
+ self.setRowCount(0)
+
+ def clear(self):
+ self._removeAllItems()
+
+ def setStats(self, statsHandler):
+ """Set which stats to display and the associated formatting.
+
+ :param StatsHandler statsHandler:
+ Set the statistics to be displayed and how to format them using
+ """
+ self._removeAllItems()
+ _StatsWidgetBase.setStats(self, statsHandler)
+
+ self.setRowCount(0)
+ self.setColumnCount(
+ len(self._statsHandler.stats) + 3
+ ) # + legend, kind and roi # noqa
+
+ for index, stat in enumerate(self._statsHandler.stats.values()):
+ headerItem = qt.QTableWidgetItem(stat.name.capitalize())
+ headerItem.setData(qt.Qt.UserRole, stat.name)
+ if stat.description is not None:
+ headerItem.setToolTip(stat.description)
+ self.setHorizontalHeaderItem(3 + index, headerItem)
+
+ horizontalHeader = self.horizontalHeader()
+ horizontalHeader.setSectionResizeMode(qt.QHeaderView.ResizeToContents)
+
+ self._updateItemObserve()
+
+ def _updateItemObserve(self, *args):
+ pass
+
+ def _dataChanged(self, item):
+ pass
+
+ def _updateStats(self, item, data_changed=False, roi_changed=False):
+ assert isinstance(item, ROIStatsItemHelper)
+ plotItem = item._plot_item
+ roi = item._roi
+ if item is None:
+ return
+ plot = self.getPlot()
+ if plot is None:
+ _logger.info("Plot not available")
+ return
+
+ row = self._itemToRow(item)
+ if row is None:
+ _logger.error("This item is not in the table: %s", str(item))
+ return
+
+ statsHandler = self.getStatsHandler()
+ if statsHandler is not None:
+ stats = statsHandler.calculate(
+ plotItem,
+ plot,
+ onlimits=self._statsOnVisibleData,
+ roi=roi,
+ data_changed=data_changed,
+ roi_changed=roi_changed,
+ )
+ else:
+ stats = {}
+
+ with self._disableSorting():
+ for name, tableItem in self._itemToTableItems(item).items():
+ if name == self._LEGEND_HEADER_DATA:
+ text = self._plotWrapper.getLabel(plotItem)
+ tableItem.setText(text)
+ elif name == self._KIND_HEADER_DATA:
+ tableItem.setText(self._plotWrapper.getKind(plotItem))
+ elif name == self._ROI_HEADER_DATA:
+ name = roi.getName()
+ tableItem.setText(name)
+ else:
+ value = stats.get(name)
+ if value is None:
+ _logger.error("Value not found for: %s", name)
+ tableItem.setText("-")
+ else:
+ tableItem.setText(str(value))
+
+ @contextmanager
+ def _disableSorting(self):
+ """Context manager that disables table sorting
+
+ Previous state is restored when leaving
+ """
+ sorting = self.isSortingEnabled()
+ if sorting:
+ self.setSortingEnabled(False)
+ yield
+ if sorting:
+ self.setSortingEnabled(sorting)
+
+ def _itemToRow(self, item):
+ """Find the row corresponding to a plot item
+
+ :param item: The plot item
+ :return: The corresponding row index
+ :rtype: Union[int,None]
+ """
+ for row in range(self.rowCount()):
+ tableItem = self.item(row, 0)
+ if self._tableItemToItem(tableItem) == item:
+ return row
+ return None
+
+ def _tableItemToItem(self, tableItem):
+ """Find the plot item corresponding to a table item
+
+ :param QTableWidgetItem tableItem:
+ :rtype: QObject
+ """
+ container = tableItem.data(qt.Qt.UserRole)
+ return container()
+
+ def _itemToTableItems(self, item):
+ """Find all table items corresponding to a plot item
+
+ :param item: The plot item
+ :return: An ordered dict of column name to QTableWidgetItem mapping
+ for the given plot item.
+ :rtype: dict
+ """
+ result = {}
+ row = self._itemToRow(item)
+ if row is not None:
+ for column in range(self.columnCount()):
+ tableItem = self.item(row, column)
+ if self._tableItemToItem(tableItem) != item:
+ _logger.error("Table item/plot item mismatch")
+ else:
+ header = self.horizontalHeaderItem(column)
+ name = header.data(qt.Qt.UserRole)
+ result[name] = tableItem
+ return result
+
+ def _plotItemToItems(self, plotItem):
+ """Return all _RoiStatsItemWidget associated to the plotItem
+ Needed for updating on itemChanged signal
+ """
+ if plotItem in self.__plotItemToItems:
+ return []
+ else:
+ return self.__plotItemToItems[plotItem]
+
+ def _registerPlotItem(self, item):
+ if item._plot_item not in self.__plotItemToItems:
+ self.__plotItemToItems[item._plot_item] = set()
+ self.__plotItemToItems[item._plot_item].add(item)
+
+ def _roiToItems(self, roi):
+ """Return all _RoiStatsItemWidget associated to the roi
+ Needed for updating on roiChanged signal
+ """
+ if roi in self.__roiToItems:
+ return []
+ else:
+ return self.__roiToItems[roi]
+
+ def _registerROI(self, item):
+ if item._roi not in self.__roiToItems:
+ self.__roiToItems[item._roi] = set()
+ # TODO: normalize also sig name
+ if isinstance(item._roi, RegionOfInterest):
+ # item connection within sigRegionChanged should only be
+ # stopped during the region edition
+ self.__region_edition_callback[item._roi] = functools.partial(
+ self._updateAllStats, False, True
+ )
+ item._roi.sigRegionChanged.connect(
+ self.__region_edition_callback[item._roi]
+ )
+ item._roi.sigEditingStarted.connect(
+ functools.partial(self._startFiltering, item._roi)
+ )
+ item._roi.sigEditingFinished.connect(
+ functools.partial(self._endFiltering, item._roi)
+ )
+ else:
+ item._roi.sigChanged.connect(
+ functools.partial(self._updateAllStats, False, True)
+ )
+ self.__roiToItems[item._roi].add(item)
+
+ def _startFiltering(self, roi):
+ roi.sigRegionChanged.disconnect(self.__region_edition_callback[roi])
+
+ def _endFiltering(self, roi):
+ roi.sigRegionChanged.connect(self.__region_edition_callback[roi])
+ self._updateAllStats(roi_changed=True)
+
+ def unregisterROI(self, roi):
+ if roi in self.__roiToItems:
+ del self.__roiToItems[roi]
+ if isinstance(roi, RegionOfInterest):
+ roi.sigRegionEditionStarted.disconnect(
+ functools.partial(self._startFiltering, roi)
+ )
+ roi.sigRegionEditionFinished.disconnect(
+ functools.partial(self._startFiltering, roi)
+ )
+ try:
+ roi.sigRegionChanged.disconnect(self._updateAllStats)
+ except:
+ pass
+ else:
+ roi.sigChanged.disconnect(self._updateAllStats)
+
+ def _plotItemChanged(self, event):
+ """Handle modifications of the items.
+
+ :param event:
+ """
+ if event is ItemChangedType.DATA:
+ if self.getUpdateMode() is UpdateMode.MANUAL:
+ return
+ if self._skipPlotItemChangedEvent(event) is True:
+ return
+ else:
+ sender = self.sender()
+ for item in self.__plotItemToItems[sender]:
+ # TODO: get all concerned items
+ self._updateStats(item, data_changed=True)
+ # deal with stat items visibility
+ if event is ItemChangedType.VISIBLE:
+ if len(self._itemToTableItems(item).items()) > 0:
+ item_0 = list(self._itemToTableItems(item).values())[0]
+ row_index = item_0.row()
+ self.setRowHidden(row_index, not item.isVisible())
+
+ def _removeItem(self, itemKey):
+ if isinstance(
+ itemKey,
+ (silx.gui.plot.items.marker.Marker, silx.gui.plot.items.shape.Shape),
+ ):
+ return
+ if itemKey not in self._items:
+ _logger.warning("key not recognized. Won't remove any item")
+ return
+ item = self._items[itemKey]
+ row = self._itemToRow(item)
+ if row is None:
+ kind = self._plotWrapper.getKind(item)
+ if kind in statsmdl.BASIC_COMPATIBLE_KINDS:
+ _logger.error("Removing item that is not in table: %s", str(item))
+ return
+ item._plot_item.sigItemChanged.disconnect(self._plotItemChanged)
+ self.removeRow(row)
+ del self._items[itemKey]
+
+ def _updateAllStats(self, is_request=False, roi_changed=False):
+ """Update stats for all rows in the table
+
+ :param bool is_request: True if come from a manual request
+ """
+ if (
+ self.getUpdateMode() is UpdateMode.MANUAL
+ and not is_request
+ and not roi_changed
+ ):
+ return
+
+ with self._disableSorting():
+ for row in range(self.rowCount()):
+ tableItem = self.item(row, 0)
+ item = self._tableItemToItem(tableItem)
+ self._updateStats(
+ item, roi_changed=roi_changed, data_changed=is_request
+ )
+
+ def _plotCurrentChanged(self, *args):
+ pass
+
+ def _getRoi(self, kind, name):
+ """return the roi fitting the requirement kind, name. This information
+ is enough to be sure it is unique (in the widget)"""
+ for roi in self.__roiToItems:
+ roiName = roi.getName()
+ if isinstance(roi, kind) and name == roiName:
+ return roi
+ return None
+
+ def _getPlotItem(self, kind, legend):
+ """return the plotItem fitting the requirement kind, legend.
+ This information is enough to be sure it is unique (in the widget)"""
+ for plotItem in self.__plotItemToItems:
+ if (
+ legend == plotItem.getLegend()
+ and self._plotWrapper.getKind(plotItem) == kind
+ ):
+ return plotItem
+ return None
+
+
+class ROIStatsWidget(qt.QMainWindow):
+ """
+ Widget used to define stats item for a couple(roi, plotItem).
+ Stats will be computing on a given item (curve, image...) in the given
+ region of interest.
+
+ It also provide an interface for adding and removing items.
+
+ .. snapshotqt:: img/ROIStatsWidget.png
+ :width: 300px
+ :align: center
+
+ from silx.gui import qt
+ from silx.gui.plot import Plot2D
+ from silx.gui.plot.ROIStatsWidget import ROIStatsWidget
+ from silx.gui.plot.items.roi import RectangleROI
+ import numpy
+ plot = Plot2D()
+ plot.addImage(numpy.arange(10000).reshape(100, 100), legend='img')
+ plot.show()
+ rectangleROI = RectangleROI()
+ rectangleROI.setGeometry(origin=(0, 100), size=(20, 20))
+ rectangleROI.setName('Initial ROI')
+ widget = ROIStatsWidget(plot=plot)
+ widget.setStats([('sum', numpy.sum), ('mean', numpy.mean)])
+ widget.registerROI(rectangleROI)
+ widget.addItem(roi=rectangleROI, plotItem=plot.getImage('img'))
+ widget.show()
+
+ :param Union[qt.QWidget,None] parent: parent qWidget
+ :param PlotWindow plot: plot widget containing the items
+ :param stats: stats to display
+ :param tuple rois: tuple of rois to manage
+ """
+
+ def __init__(self, parent=None, plot=None, stats=None, rois=None):
+ qt.QMainWindow.__init__(self, parent)
+
+ toolbar = qt.QToolBar(self)
+ icon = icons.getQIcon("add")
+ self._rois = list(rois) if rois is not None else []
+ self._addAction = qt.QAction(icon, "add item/roi", toolbar)
+ self._addAction.triggered.connect(self._addRoiStatsItem)
+ icon = icons.getQIcon("rm")
+ self._removeAction = qt.QAction(icon, "remove item/roi", toolbar)
+ self._removeAction.triggered.connect(self._removeCurrentRow)
+
+ toolbar.addAction(self._addAction)
+ toolbar.addAction(self._removeAction)
+ self.addToolBar(toolbar)
+
+ self._plot = plot
+ self._statsROITable = _StatsROITable(parent=self, plot=self._plot)
+ self.setStats(stats=stats)
+ self.setCentralWidget(self._statsROITable)
+ self.setWindowFlags(qt.Qt.Widget)
+
+ # expose API
+ self._setUpdateMode = self._statsROITable.setUpdateMode
+ self._updateAllStats = self._statsROITable._updateAllStats
+
+ # setup
+ self._statsROITable.setSelectionBehavior(qt.QTableWidget.SelectRows)
+
+ def registerROI(self, roi):
+ """For now there is no direct link between roi and plot. That is why
+ we need to add/register them to be able to associate them"""
+ self._rois.append(roi)
+
+ def setPlot(self, plot):
+ """Define the plot to interact with
+
+ :param Union[PlotWidget,SceneWidget,None] plot:
+ The plot containing the items on which statistics are applied
+ """
+ self._plot = plot
+
+ def getPlot(self):
+ return self._plot
+
+ @docstring(_StatsROITable)
+ def setStats(self, stats):
+ if stats is not None:
+ self._statsROITable.setStats(statsHandler=stats)
+
+ @docstring(_StatsROITable)
+ def getStatsHandler(self):
+ """
+
+ :return:
+ """
+ return self._statsROITable.getStatsHandler()
+
+ def _addRoiStatsItem(self):
+ """Ask the user what couple ROI / item he want to display"""
+ dialog = _GetROIItemCoupleDialog(parent=self, plot=self._plot, rois=self._rois)
+ if dialog.exec():
+ self.addItem(roi=dialog.getROI(), plotItem=dialog.getItem())
+
+ def addItem(self, plotItem, roi):
+ """
+ Add a row of statitstic regarding the couple (plotItem, roi)
+
+ :param Item plotItem: item to use for statistics
+ :param roi: region of interest to limit the statistic.
+ :type: Union[ROI, RegionOfInterest]
+ :return: None of failed to add the item
+ :rtype: Union[None,ROIStatsItemHelper]
+ """
+ statsItem = ROIStatsItemHelper(roi=roi, plot_item=plotItem)
+ return self._statsROITable.add(item=statsItem)
+
+ def removeItem(self, plotItem, roi):
+ """
+ Remove the row associated to the couple (plotItem, roi)
+
+ :param Item plotItem: item to use for statistics
+ :param roi: region of interest to limit the statistic.
+ :type: Union[ROI,RegionOfInterest]
+ """
+ statsItem = ROIStatsItemHelper(roi=roi, plot_item=plotItem)
+ self._statsROITable._removeItem(itemKey=statsItem.id_key())
+
+ def _removeCurrentRow(self):
+ def is1DKind(kind):
+ if kind in ("curve", "histogram", "scatter"):
+ return True
+ else:
+ return False
+
+ currentRow = self._statsROITable.currentRow()
+ item_kind = self._statsROITable.item(currentRow, 1).text()
+ item_legend = self._statsROITable.item(currentRow, 0).text()
+
+ roi_name = self._statsROITable.item(currentRow, 2).text()
+ roi_kind = ROI if is1DKind(item_kind) else RegionOfInterest
+ roi = self._statsROITable._getRoi(kind=roi_kind, name=roi_name)
+ if roi is None:
+ _logger.warning("failed to retrieve the roi you want to remove")
+ return False
+ plot_item = self._statsROITable._getPlotItem(kind=item_kind, legend=item_legend)
+ if plot_item is None:
+ _logger.warning("failed to retrieve the plot item you want to" "remove")
+ return False
+ return self.removeItem(plotItem=plot_item, roi=roi)
diff --git a/src/silx/gui/plot/ScatterMaskToolsWidget.py b/src/silx/gui/plot/ScatterMaskToolsWidget.py
new file mode 100644
index 0000000..300f3a6
--- /dev/null
+++ b/src/silx/gui/plot/ScatterMaskToolsWidget.py
@@ -0,0 +1,643 @@
+# /*##########################################################################
+#
+# Copyright (c) 2018-2022 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""Widget providing a set of tools to draw masks on a PlotWidget.
+
+This widget is meant to work with a modified :class:`silx.gui.plot.PlotWidget`
+
+- :class:`ScatterMask`: Handle scatter mask update and history
+- :class:`ScatterMaskToolsWidget`: GUI for :class:`ScatterMask`
+- :class:`ScatterMaskToolsDockWidget`: DockWidget to integrate in :class:`PlotWindow`
+"""
+
+__authors__ = ["P. Knobel"]
+__license__ = "MIT"
+__date__ = "15/02/2019"
+
+
+import math
+import logging
+import os
+import numpy
+import sys
+
+from .. import qt
+from ...math.combo import min_max
+from ...image import shapes
+
+from .items import ItemChangedType, Scatter
+from ._BaseMaskToolsWidget import BaseMask, BaseMaskToolsWidget, BaseMaskToolsDockWidget
+from ..colors import cursorColorForColormap, rgba
+
+
+_logger = logging.getLogger(__name__)
+
+
+class ScatterMask(BaseMask):
+ """A 1D mask for scatter data."""
+
+ def __init__(self, scatter=None):
+ """
+
+ :param scatter: :class:`silx.gui.plot.items.Scatter` instance
+ """
+ BaseMask.__init__(self, scatter)
+
+ def _getXY(self):
+ x = self._dataItem.getXData(copy=False)
+ y = self._dataItem.getYData(copy=False)
+ return x, y
+
+ def getDataValues(self):
+ """Return scatter data values as a 1D array.
+
+ :rtype: 1D numpy.ndarray
+ """
+ return self._dataItem.getValueData(copy=False)
+
+ def save(self, filename, kind):
+ if kind == "npy":
+ try:
+ numpy.save(filename, self.getMask(copy=False))
+ except IOError:
+ raise RuntimeError("Mask file can't be written")
+ elif kind in ["csv", "txt"]:
+ try:
+ numpy.savetxt(filename, self.getMask(copy=False))
+ except IOError:
+ raise RuntimeError("Mask file can't be written")
+
+ def updatePoints(self, level, indices, mask=True):
+ """Mask/Unmask points with given indices.
+
+ :param int level: Mask level to update.
+ :param indices: Sequence or 1D array of indices of points to be
+ updated
+ :param bool mask: True to mask (default), False to unmask.
+ """
+ if mask:
+ self._mask[indices] = level
+ else:
+ # unmask only where mask level is the specified value
+ indices_stencil = numpy.zeros_like(self._mask, dtype=bool)
+ indices_stencil[indices] = True
+ self._mask[numpy.logical_and(self._mask == level, indices_stencil)] = 0
+ self._notify()
+
+ # update shapes
+ def updatePolygon(self, level, vertices, mask=True):
+ """Mask/Unmask a polygon of the given mask level.
+
+ :param int level: Mask level to update.
+ :param vertices: Nx2 array of polygon corners as (y, x) or (row, col)
+ :param bool mask: True to mask (default), False to unmask.
+ """
+ polygon = shapes.Polygon(vertices)
+ x, y = self._getXY()
+
+ # TODO: this could be optimized if necessary
+ indices_in_polygon = [
+ idx for idx in range(len(x)) if polygon.is_inside(y[idx], x[idx])
+ ]
+
+ self.updatePoints(level, indices_in_polygon, mask)
+
+ def updateRectangle(self, level, y, x, height, width, mask=True):
+ """Mask/Unmask data inside a rectangle
+
+ :param int level: Mask level to update.
+ :param float y: Y coordinate of bottom left corner of the rectangle
+ :param float x: X coordinate of bottom left corner of the rectangle
+ :param float height:
+ :param float width:
+ :param bool mask: True to mask (default), False to unmask.
+ """
+ vertices = [(y, x), (y + height, x), (y + height, x + width), (y, x + width)]
+ self.updatePolygon(level, vertices, mask)
+
+ def updateDisk(self, level, cy, cx, radius, mask=True):
+ """Mask/Unmask a disk of the given mask level.
+
+ :param int level: Mask level to update.
+ :param float cy: Disk center (y).
+ :param float cx: Disk center (x).
+ :param float radius: Radius of the disk in mask array unit
+ :param bool mask: True to mask (default), False to unmask.
+ """
+ x, y = self._getXY()
+ stencil = (y - cy) ** 2 + (x - cx) ** 2 < radius**2
+ self.updateStencil(level, stencil, mask)
+
+ def updateEllipse(self, level, crow, ccol, radius_r, radius_c, mask=True):
+ """Mask/Unmask an ellipse of the given mask level.
+
+ :param int level: Mask level to update.
+ :param int crow: Row of the center of the ellipse
+ :param int ccol: Column of the center of the ellipse
+ :param float radius_r: Radius of the ellipse in the row
+ :param float radius_c: Radius of the ellipse in the column
+ :param bool mask: True to mask (default), False to unmask.
+ """
+
+ def is_inside(px, py):
+ return (px - ccol) ** 2 / radius_c**2 + (
+ py - crow
+ ) ** 2 / radius_r**2 <= 1.0
+
+ x, y = self._getXY()
+ indices_inside = [idx for idx in range(len(x)) if is_inside(x[idx], y[idx])]
+ self.updatePoints(level, indices_inside, mask)
+
+ def updateLine(self, level, y0, x0, y1, x1, width, mask=True):
+ """Mask/Unmask points inside a rectangle defined by a line (two
+ end points) and a width.
+
+ :param int level: Mask level to update.
+ :param float y0: Row of the starting point.
+ :param float x0: Column of the starting point.
+ :param float row1: Row of the end point.
+ :param float col1: Column of the end point.
+ :param float width: Width of the line.
+ :param bool mask: True to mask (default), False to unmask.
+ """
+ # theta is the angle between the horizontal and the line
+ theta = math.atan((y1 - y0) / (x1 - x0)) if x1 - x0 else 0
+ w_over_2_sin_theta = width / 2.0 * math.sin(theta)
+ w_over_2_cos_theta = width / 2.0 * math.cos(theta)
+
+ vertices = [
+ (y0 - w_over_2_cos_theta, x0 + w_over_2_sin_theta),
+ (y0 + w_over_2_cos_theta, x0 - w_over_2_sin_theta),
+ (y1 + w_over_2_cos_theta, x1 - w_over_2_sin_theta),
+ (y1 - w_over_2_cos_theta, x1 + w_over_2_sin_theta),
+ ]
+
+ self.updatePolygon(level, vertices, mask)
+
+
+class ScatterMaskToolsWidget(BaseMaskToolsWidget):
+ """Widget with tools for masking data points on a scatter in a
+ :class:`PlotWidget`."""
+
+ def __init__(self, parent=None, plot=None):
+ super(ScatterMaskToolsWidget, self).__init__(parent, plot, mask=ScatterMask())
+ self._z = 2 # Mask layer in plot
+ self._data_scatter = None
+ """plot Scatter item for data"""
+
+ self._data_extent = None
+ """Maximum extent of the data i.e., max(xMax-xMin, yMax-yMin)"""
+
+ self._mask_scatter = None
+ """plot Scatter item for representing the mask"""
+
+ def setSelectionMask(self, mask, copy=True):
+ """Set the mask to a new array.
+
+ :param numpy.ndarray mask:
+ The array to use for the mask or None to reset the mask.
+ :type mask: numpy.ndarray of uint8, C-contiguous.
+ Array of other types are converted.
+ :param bool copy: True (the default) to copy the array,
+ False to use it as is if possible.
+ :return: None if failed, shape of mask as 1-tuple if successful.
+ The mask can be cropped or padded to fit active scatter,
+ the returned shape is that of the scatter data.
+ """
+ if self._data_scatter is None:
+ # this can happen if the mask tools widget has never been shown
+ self._data_scatter = self.plot.getActiveScatter()
+ if self._data_scatter is None:
+ return None
+ self._adjustColorAndBrushSize(self._data_scatter)
+
+ if mask is None:
+ self.resetSelectionMask()
+ return self._data_scatter.getXData(copy=False).shape
+
+ mask = numpy.array(mask, copy=False, dtype=numpy.uint8)
+
+ if (
+ self._data_scatter.getXData(copy=False).shape == (0,)
+ or mask.shape == self._data_scatter.getXData(copy=False).shape
+ ):
+ self._mask.setMask(mask, copy=copy)
+ self._mask.commit()
+ return mask.shape
+ else:
+ raise ValueError("Mask does not have the same shape as the data")
+
+ # Handle mask refresh on the plot
+
+ def _updatePlotMask(self):
+ """Update mask image in plot"""
+ mask = self.getSelectionMask(copy=False)
+ if mask is not None:
+ self.plot.addScatter(
+ self._data_scatter.getXData(),
+ self._data_scatter.getYData(),
+ mask,
+ legend=self._maskName,
+ colormap=self._colormap,
+ z=self._z,
+ )
+ self._mask_scatter = self.plot._getItem(
+ kind="scatter", legend=self._maskName
+ )
+ self._mask_scatter.setSymbolSize(self._data_scatter.getSymbolSize() + 2.0)
+ self._mask_scatter.sigItemChanged.connect(self.__maskScatterChanged)
+ elif self.plot._getItem(kind="scatter", legend=self._maskName) is not None:
+ self.plot.remove(self._maskName, kind="scatter")
+
+ def __maskScatterChanged(self, event):
+ """Handles update of mask scatter"""
+ if (
+ event is ItemChangedType.VISUALIZATION_MODE
+ and self._mask_scatter is not None
+ ):
+ self._mask_scatter.setVisualization(Scatter.Visualization.POINTS)
+
+ # track widget visibility and plot active image changes
+
+ def showEvent(self, event):
+ try:
+ self.plot.sigActiveScatterChanged.disconnect(
+ self._activeScatterChangedAfterCare
+ )
+ except (RuntimeError, TypeError):
+ pass
+ self._activeScatterChanged(None, None) # Init mask + enable/disable widget
+ self.plot.sigActiveScatterChanged.connect(self._activeScatterChanged)
+
+ def hideEvent(self, event):
+ try:
+ # if the method is not connected this raises a TypeError and there is no way
+ # to know the connected slots
+ self.plot.sigActiveScatterChanged.disconnect(self._activeScatterChanged)
+ except (RuntimeError, TypeError):
+ _logger.info(sys.exc_info()[1])
+
+ if self.isMaskInteractionActivated():
+ # Disable drawing tool
+ self.plot.resetInteractiveMode()
+
+ if self.getSelectionMask(copy=False) is not None:
+ self.plot.sigActiveScatterChanged.connect(
+ self._activeScatterChangedAfterCare
+ )
+
+ def _adjustColorAndBrushSize(self, activeScatter):
+ colormap = activeScatter.getColormap()
+ self._defaultOverlayColor = rgba(cursorColorForColormap(colormap["name"]))
+ self._setMaskColors(
+ self.levelSpinBox.value(),
+ self.transparencySlider.value() / self.transparencySlider.maximum(),
+ )
+ self._z = activeScatter.getZValue() + 1
+ self._data_scatter = activeScatter
+
+ # Adjust brush size to data range
+ xData = self._data_scatter.getXData(copy=False)
+ yData = self._data_scatter.getYData(copy=False)
+ # Adjust brush size to data range
+ if xData.size > 0 and yData.size > 0:
+ xMin, xMax = min_max(xData)
+ yMin, yMax = min_max(yData)
+ self._data_extent = max(xMax - xMin, yMax - yMin)
+ else:
+ self._data_extent = None
+
+ def _activeScatterChangedAfterCare(self, previous, next):
+ """Check synchro of active scatter and mask when mask widget is hidden.
+
+ If active image has no more the same size as the mask, the mask is
+ removed, otherwise it is adjusted to z.
+ """
+ # check that content changed was the active scatter
+ activeScatter = self.plot.getActiveScatter()
+
+ if activeScatter is None or activeScatter.getName() == self._maskName:
+ # No active scatter or active scatter is the mask...
+ self.plot.sigActiveScatterChanged.disconnect(
+ self._activeScatterChangedAfterCare
+ )
+ self._data_extent = None
+ self._data_scatter = None
+
+ else:
+ self._adjustColorAndBrushSize(activeScatter)
+
+ if (
+ self._data_scatter.getXData(copy=False).shape
+ != self._mask.getMask(copy=False).shape
+ ):
+ # scatter has not the same size, remove mask and stop listening
+ if self.plot._getItem(kind="scatter", legend=self._maskName):
+ self.plot.remove(self._maskName, kind="scatter")
+
+ self.plot.sigActiveScatterChanged.disconnect(
+ self._activeScatterChangedAfterCare
+ )
+ self._data_extent = None
+ self._data_scatter = None
+
+ else:
+ # Refresh in case z changed
+ self._mask.setDataItem(self._data_scatter)
+ self._updatePlotMask()
+
+ def _activeScatterChanged(self, previous, next):
+ """Update widget and mask according to active scatter changes"""
+ activeScatter = self.plot.getActiveScatter()
+
+ if activeScatter is None or activeScatter.getName() == self._maskName:
+ # No active scatter or active scatter is the mask...
+ self.setEnabled(False)
+
+ self._data_scatter = None
+ self._data_extent = None
+ self._mask.reset()
+ self._mask.commit()
+
+ else: # There is an active scatter
+ self.setEnabled(True)
+ self._adjustColorAndBrushSize(activeScatter)
+
+ self._mask.setDataItem(self._data_scatter)
+ if (
+ self._data_scatter.getXData(copy=False).shape
+ != self._mask.getMask(copy=False).shape
+ ):
+ self._mask.reset(self._data_scatter.getXData(copy=False).shape)
+ self._mask.commit()
+ else:
+ # Refresh in case z changed
+ self._updatePlotMask()
+
+ self._updateInteractiveMode()
+
+ # Handle whole mask operations
+
+ def load(self, filename):
+ """Load a mask from an image file.
+
+ :param str filename: File name from which to load the mask
+ :raise Exception: An exception in case of failure
+ :raise RuntimeWarning: In case the mask was applied but with some
+ import changes to notice
+ """
+ _, extension = os.path.splitext(filename)
+ extension = extension.lower()[1:]
+ if extension == "npy":
+ try:
+ mask = numpy.load(filename)
+ except IOError:
+ _logger.error("Can't load filename '%s'", filename)
+ _logger.debug("Backtrace", exc_info=True)
+ raise RuntimeError('File "%s" is not a numpy file.', filename)
+ elif extension in ["txt", "csv"]:
+ try:
+ mask = numpy.loadtxt(filename)
+ except IOError:
+ _logger.error("Can't load filename '%s'", filename)
+ _logger.debug("Backtrace", exc_info=True)
+ raise RuntimeError('File "%s" is not a numpy txt file.', filename)
+ else:
+ msg = "Extension '%s' is not supported."
+ raise RuntimeError(msg % extension)
+
+ self.setSelectionMask(mask, copy=False)
+
+ def _loadMask(self):
+ """Open load mask dialog"""
+ dialog = qt.QFileDialog(self)
+ dialog.setWindowTitle("Load Mask")
+ dialog.setModal(1)
+ filters = [
+ "NumPy binary file (*.npy)",
+ "CSV text file (*.csv)",
+ ]
+ dialog.setNameFilters(filters)
+ dialog.setFileMode(qt.QFileDialog.ExistingFile)
+ dialog.setDirectory(self.maskFileDir)
+ if not dialog.exec():
+ dialog.close()
+ return
+
+ filename = dialog.selectedFiles()[0]
+ dialog.close()
+
+ # Update the directory according to the user selection
+ self.maskFileDir = os.path.dirname(filename)
+
+ try:
+ self.load(filename)
+ # except RuntimeWarning as e:
+ # message = e.args[0]
+ # msg = qt.QMessageBox(self)
+ # msg.setIcon(qt.QMessageBox.Warning)
+ # msg.setText("Mask loaded but an operation was applied.\n" + message)
+ # msg.exec()
+ except Exception as e:
+ message = e.args[0]
+ msg = qt.QMessageBox(self)
+ msg.setIcon(qt.QMessageBox.Critical)
+ msg.setText("Cannot load mask from file. " + message)
+ msg.exec()
+
+ def _saveMask(self):
+ """Open Save mask dialog"""
+ dialog = qt.QFileDialog(self)
+ dialog.setWindowTitle("Save Mask")
+ dialog.setModal(1)
+ filters = [
+ "NumPy binary file (*.npy)",
+ "CSV text file (*.csv)",
+ ]
+ dialog.setNameFilters(filters)
+ dialog.setFileMode(qt.QFileDialog.AnyFile)
+ dialog.setAcceptMode(qt.QFileDialog.AcceptSave)
+ dialog.setDirectory(self.maskFileDir)
+ if not dialog.exec():
+ dialog.close()
+ return
+
+ # convert filter name to extension name with the .
+ extension = dialog.selectedNameFilter().split()[-1][2:-1]
+ filename = dialog.selectedFiles()[0]
+ dialog.close()
+
+ if not filename.lower().endswith(extension):
+ filename += extension
+
+ if os.path.exists(filename):
+ try:
+ os.remove(filename)
+ except IOError as e:
+ msg = qt.QMessageBox(self)
+ msg.setWindowTitle("Removing existing file")
+ msg.setIcon(qt.QMessageBox.Critical)
+
+ if hasattr(e, "strerror"):
+ strerror = e.strerror
+ else:
+ strerror = sys.exc_info()[1]
+ msg.setText("Cannot save.\n" "Input Output Error: %s" % strerror)
+ msg.exec()
+ return
+
+ # Update the directory according to the user selection
+ self.maskFileDir = os.path.dirname(filename)
+
+ try:
+ self.save(filename, extension[1:])
+ except Exception as e:
+ msg = qt.QMessageBox(self)
+ msg.setWindowTitle("Saving mask file")
+ msg.setIcon(qt.QMessageBox.Critical)
+
+ if hasattr(e, "strerror"):
+ strerror = e.strerror
+ else:
+ strerror = sys.exc_info()[1]
+ msg.setText("Cannot save file %s\n%s" % (filename, strerror))
+ msg.exec()
+
+ def resetSelectionMask(self):
+ """Reset the mask"""
+ self._mask.reset(shape=self._data_scatter.getXData(copy=False).shape)
+ self._mask.commit()
+
+ def _getPencilWidth(self):
+ """Returns the width of the pencil to use in data coordinates`
+
+ :rtype: float
+ """
+ width = super(ScatterMaskToolsWidget, self)._getPencilWidth()
+ if self._data_extent is not None:
+ width *= 0.01 * self._data_extent
+ return width
+
+ def _plotDrawEvent(self, event):
+ """Handle draw events from the plot"""
+ if self._drawingMode is None or event["event"] not in (
+ "drawingProgress",
+ "drawingFinished",
+ ):
+ return
+
+ if not len(self._data_scatter.getXData(copy=False)):
+ return
+
+ level = self.levelSpinBox.value()
+
+ if self._drawingMode == "rectangle":
+ if event["event"] == "drawingFinished":
+ doMask = self._isMasking()
+
+ self._mask.updateRectangle(
+ level,
+ y=event["y"],
+ x=event["x"],
+ height=abs(event["height"]),
+ width=abs(event["width"]),
+ mask=doMask,
+ )
+ self._mask.commit()
+
+ elif self._drawingMode == "ellipse":
+ if event["event"] == "drawingFinished":
+ doMask = self._isMasking()
+ center = event["points"][0]
+ size = event["points"][1]
+ self._mask.updateEllipse(
+ level, center[1], center[0], size[1], size[0], doMask
+ )
+ self._mask.commit()
+
+ elif self._drawingMode == "polygon":
+ if event["event"] == "drawingFinished":
+ doMask = self._isMasking()
+ vertices = event["points"]
+ vertices = vertices[:, (1, 0)] # (y, x)
+ self._mask.updatePolygon(level, vertices, doMask)
+ self._mask.commit()
+
+ elif self._drawingMode == "pencil":
+ doMask = self._isMasking()
+ # convert from plot to array coords
+ x, y = event["points"][-1]
+
+ brushSize = self._getPencilWidth()
+
+ if self._lastPencilPos != (y, x):
+ if self._lastPencilPos is not None:
+ # Draw the line
+ self._mask.updateLine(
+ level,
+ self._lastPencilPos[0],
+ self._lastPencilPos[1],
+ y,
+ x,
+ brushSize,
+ doMask,
+ )
+
+ # Draw the very first, or last point
+ self._mask.updateDisk(level, y, x, brushSize / 2.0, doMask)
+
+ if event["event"] == "drawingFinished":
+ self._mask.commit()
+ self._lastPencilPos = None
+ else:
+ self._lastPencilPos = y, x
+ else:
+ _logger.error("Drawing mode %s unsupported", self._drawingMode)
+
+ def _loadRangeFromColormapTriggered(self):
+ """Set range from active scatter colormap range"""
+ if self._data_scatter is not None:
+ # Update thresholds according to colormap
+ colormap = self._data_scatter.getColormap()
+ if colormap["autoscale"]:
+ min_ = numpy.nanmin(self._data_scatter.getValueData(copy=False))
+ max_ = numpy.nanmax(self._data_scatter.getValueData(copy=False))
+ else:
+ min_, max_ = colormap["vmin"], colormap["vmax"]
+ self.minLineEdit.setText(str(min_))
+ self.maxLineEdit.setText(str(max_))
+
+
+class ScatterMaskToolsDockWidget(BaseMaskToolsDockWidget):
+ """:class:`ScatterMaskToolsWidget` embedded in a QDockWidget.
+
+ For integration in a :class:`PlotWindow`.
+
+ :param parent: See :class:`QDockWidget`
+ :param plot: The PlotWidget this widget is operating on
+ :paran str name: The title of this widget
+ """
+
+ def __init__(self, parent=None, plot=None, name="Mask"):
+ widget = ScatterMaskToolsWidget(plot=plot)
+ super(ScatterMaskToolsDockWidget, self).__init__(parent, name, widget)
diff --git a/src/silx/gui/plot/ScatterView.py b/src/silx/gui/plot/ScatterView.py
new file mode 100644
index 0000000..06475e3
--- /dev/null
+++ b/src/silx/gui/plot/ScatterView.py
@@ -0,0 +1,418 @@
+# /*##########################################################################
+#
+# Copyright (c) 2018-2020 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""A widget dedicated to display scatter plots
+
+It is based on a :class:`~silx.gui.plot.PlotWidget` with additional tools
+for scatter plots.
+"""
+
+__authors__ = ["T. Vincent"]
+__license__ = "MIT"
+__date__ = "14/06/2018"
+
+
+import logging
+import weakref
+
+import numpy
+
+from . import items
+from . import PlotWidget
+from . import tools
+from .actions import histogram as actions_histogram
+from .tools.profile import ScatterProfileToolBar
+from .ColorBar import ColorBarWidget
+from .ScatterMaskToolsWidget import ScatterMaskToolsWidget
+
+from ..widgets.BoxLayoutDockWidget import BoxLayoutDockWidget
+from .. import qt, icons
+from ...utils.proxy import docstring
+from ...utils.weakref import WeakMethodProxy
+
+
+_logger = logging.getLogger(__name__)
+
+
+class ScatterView(qt.QMainWindow):
+ """Main window with a PlotWidget and tools specific for scatter plots.
+
+ :param parent: The parent of this widget
+ :param backend: The backend to use for the plot (default: matplotlib).
+ See :class:`~silx.gui.plot.PlotWidget` for the list of supported backend.
+ :type backend: Union[str,~silx.gui.plot.backends.BackendBase.BackendBase]
+ """
+
+ _SCATTER_LEGEND = " "
+ """Legend used for the scatter item"""
+
+ def __init__(self, parent=None, backend=None):
+ super(ScatterView, self).__init__(parent=parent)
+ if parent is not None:
+ # behave as a widget
+ self.setWindowFlags(qt.Qt.Widget)
+ else:
+ self.setWindowTitle("ScatterView")
+
+ # Create plot widget
+ plot = PlotWidget(parent=self, backend=backend)
+ self._plot = weakref.ref(plot)
+
+ # Add an empty scatter
+ self.__createEmptyScatter()
+
+ # Create colorbar widget with white background
+ self._colorbar = ColorBarWidget(parent=self, plot=plot)
+ self._colorbar.setAutoFillBackground(True)
+ palette = self._colorbar.palette()
+ palette.setColor(qt.QPalette.Window, qt.Qt.white)
+ self._colorbar.setPalette(palette)
+
+ # Create PositionInfo widget
+ self.__lastPickingPos = None
+ self.__pickingCache = None
+ self._positionInfo = tools.PositionInfo(
+ plot=plot,
+ converters=(
+ ("X", WeakMethodProxy(self._getPickedX)),
+ ("Y", WeakMethodProxy(self._getPickedY)),
+ ("Data", WeakMethodProxy(self._getPickedValue)),
+ ("Index", WeakMethodProxy(self._getPickedIndex)),
+ ),
+ )
+
+ # Combine plot, position info and colorbar into central widget
+ gridLayout = qt.QGridLayout()
+ gridLayout.setSpacing(0)
+ gridLayout.setContentsMargins(0, 0, 0, 0)
+ gridLayout.addWidget(plot, 0, 0)
+ gridLayout.addWidget(self._colorbar, 0, 1)
+ gridLayout.addWidget(self._positionInfo, 1, 0, 1, -1)
+ gridLayout.setRowStretch(0, 1)
+ gridLayout.setColumnStretch(0, 1)
+ centralWidget = qt.QWidget(self)
+ centralWidget.setLayout(gridLayout)
+ self.setCentralWidget(centralWidget)
+
+ # Create mask tool dock widget
+ self._maskToolsWidget = ScatterMaskToolsWidget(parent=self, plot=plot)
+ self._maskDock = BoxLayoutDockWidget()
+ self._maskDock.setWindowTitle("Scatter Mask")
+ self._maskDock.setWidget(self._maskToolsWidget)
+ self._maskDock.setVisible(False)
+ self.addDockWidget(qt.Qt.BottomDockWidgetArea, self._maskDock)
+
+ self._maskAction = self._maskDock.toggleViewAction()
+ self._maskAction.setIcon(icons.getQIcon("image-mask"))
+ self._maskAction.setToolTip("Display/hide mask tools")
+
+ self._intensityHistoAction = actions_histogram.PixelIntensitiesHistoAction(
+ plot=plot, parent=self
+ )
+
+ # Create toolbars
+ self._interactiveModeToolBar = tools.InteractiveModeToolBar(
+ parent=self, plot=plot
+ )
+
+ self._scatterToolBar = tools.ScatterToolBar(parent=self, plot=plot)
+ self._scatterToolBar.addAction(self._maskAction)
+ self._scatterToolBar.addAction(self._intensityHistoAction)
+
+ self._profileToolBar = ScatterProfileToolBar(parent=self, plot=plot)
+
+ self._outputToolBar = tools.OutputToolBar(parent=self, plot=plot)
+
+ # Activate shortcuts in PlotWindow widget:
+ for toolbar in (
+ self._interactiveModeToolBar,
+ self._scatterToolBar,
+ self._profileToolBar,
+ self._outputToolBar,
+ ):
+ self.addToolBar(toolbar)
+ for action in toolbar.actions():
+ self.addAction(action)
+
+ def __createEmptyScatter(self):
+ """Create an empty scatter item that is used to display the data
+
+ :rtype: ~silx.gui.plot.items.Scatter
+ """
+ plot = self.getPlotWidget()
+ plot.addScatter(x=(), y=(), value=(), legend=self._SCATTER_LEGEND)
+ scatter = plot._getItem(kind="scatter", legend=self._SCATTER_LEGEND)
+ # Profile is not selectable,
+ # so it does not interfere with profile interaction
+ scatter._setSelectable(False)
+ return scatter
+
+ def _pickScatterData(self, x, y):
+ """Get data and index and value of top most scatter plot at position (x, y)
+
+ :param float x: X position in plot coordinates
+ :param float y: Y position in plot coordinates
+ :return: The data index and value at that point or None
+ """
+ pickingPos = x, y
+ if self.__lastPickingPos != pickingPos:
+ self.__pickingCache = None
+ self.__lastPickingPos = pickingPos
+
+ plot = self.getPlotWidget()
+ if plot is not None:
+ pixelPos = plot.dataToPixel(x, y)
+ if pixelPos is not None:
+ # Start from top-most item
+ result = plot._pickTopMost(
+ pixelPos[0],
+ pixelPos[1],
+ lambda item: isinstance(item, items.Scatter),
+ )
+ if result is not None:
+ item = result.getItem()
+ if (
+ item.getVisualization()
+ is items.Scatter.Visualization.BINNED_STATISTIC
+ ):
+ # Get highest index of closest points
+ selected = result.getIndices(copy=False)[::-1]
+ dataIndex = selected[
+ numpy.argmin(
+ (item.getXData(copy=False)[selected] - x) ** 2
+ + (item.getYData(copy=False)[selected] - y) ** 2
+ )
+ ]
+ else:
+ # Get last index
+ # with matplotlib it should be the top-most point
+ dataIndex = result.getIndices(copy=False)[-1]
+ self.__pickingCache = (
+ dataIndex,
+ item.getXData(copy=False)[dataIndex],
+ item.getYData(copy=False)[dataIndex],
+ item.getValueData(copy=False)[dataIndex],
+ )
+
+ return self.__pickingCache
+
+ def _getPickedIndex(self, x, y):
+ """Get data index of top most scatter plot at position (x, y)
+
+ :param float x: X position in plot coordinates
+ :param float y: Y position in plot coordinates
+ :return: The data index at that point or '-'
+ """
+ picking = self._pickScatterData(x, y)
+ return "-" if picking is None else picking[0]
+
+ def _getPickedX(self, x, y):
+ """Returns X position snapped to scatter plot when close enough
+
+ :param float x:
+ :param float y:
+ :rtype: float
+ """
+ picking = self._pickScatterData(x, y)
+ return x if picking is None else picking[1]
+
+ def _getPickedY(self, x, y):
+ """Returns Y position snapped to scatter plot when close enough
+
+ :param float x:
+ :param float y:
+ :rtype: float
+ """
+ picking = self._pickScatterData(x, y)
+ return y if picking is None else picking[2]
+
+ def _getPickedValue(self, x, y):
+ """Get data value of top most scatter plot at position (x, y)
+
+ :param float x: X position in plot coordinates
+ :param float y: Y position in plot coordinates
+ :return: The data value at that point or '-'
+ """
+ picking = self._pickScatterData(x, y)
+ return "-" if picking is None else picking[3]
+
+ def _mouseInPlotArea(self, x, y):
+ """Clip mouse coordinates to plot area coordinates
+
+ :param float x: X position in pixels
+ :param float y: Y position in pixels
+ :return: (x, y) in data coordinates
+ """
+ plot = self.getPlotWidget()
+ left, top, width, height = plot.getPlotBoundsInPixels()
+ xPlot = numpy.clip(x, left, left + width - 1)
+ yPlot = numpy.clip(y, top, top + height - 1)
+ return xPlot, yPlot
+
+ def getPlotWidget(self):
+ """Returns the :class:`~silx.gui.plot.PlotWidget` this window is based on.
+
+ :rtype: ~silx.gui.plot.PlotWidget
+ """
+ return self._plot()
+
+ def getPositionInfoWidget(self):
+ """Returns the widget display mouse coordinates information.
+
+ :rtype: ~silx.gui.plot.tools.PositionInfo
+ """
+ return self._positionInfo
+
+ def getMaskToolsWidget(self):
+ """Returns the widget controlling mask drawing
+
+ :rtype: ~silx.gui.plot.ScatterMaskToolsWidget
+ """
+ return self._maskToolsWidget
+
+ def getInteractiveModeToolBar(self):
+ """Returns QToolBar controlling interactive mode.
+
+ :rtype: ~silx.gui.plot.tools.InteractiveModeToolBar
+ """
+ return self._interactiveModeToolBar
+
+ def getScatterToolBar(self):
+ """Returns QToolBar providing scatter plot tools.
+
+ :rtype: ~silx.gui.plot.tools.ScatterToolBar
+ """
+ return self._scatterToolBar
+
+ def getScatterProfileToolBar(self):
+ """Returns QToolBar providing scatter profile tools.
+
+ :rtype: ~silx.gui.plot.tools.profile.ScatterProfileToolBar
+ """
+ return self._profileToolBar
+
+ def getOutputToolBar(self):
+ """Returns QToolBar containing save, copy and print actions
+
+ :rtype: ~silx.gui.plot.tools.OutputToolBar
+ """
+ return self._outputToolBar
+
+ def setColormap(self, colormap=None):
+ """Set the colormap for the displayed scatter and the
+ default plot colormap.
+
+ :param ~silx.gui.colors.Colormap colormap:
+ The description of the colormap.
+ """
+ self.getScatterItem().setColormap(colormap)
+ # Resilient to call to PlotWidget API (e.g., clear)
+ self.getPlotWidget().setDefaultColormap(colormap)
+
+ def getColormap(self):
+ """Return the colormap object in use.
+
+ :return: Colormap currently in use
+ :rtype: ~silx.gui.colors.Colormap
+ """
+ return self.getScatterItem().getColormap()
+
+ # Control displayed scatter plot
+
+ def setData(self, x, y, value, xerror=None, yerror=None, alpha=None, copy=True):
+ """Set the data of the scatter plot.
+
+ To reset the scatter plot, set x, y and value to None.
+
+ :param Union[numpy.ndarray,None] x: X coordinates.
+ :param Union[numpy.ndarray,None] y: Y coordinates.
+ :param Union[numpy.ndarray,None] value:
+ The data corresponding to the value of the data points.
+ :param xerror: Values with the uncertainties on the x values.
+ If it is an array, it can either be a 1D array of
+ same length as the data or a 2D array with 2 rows
+ of same length as the data: row 0 for positive errors,
+ row 1 for negative errors.
+ :type xerror: A float, or a numpy.ndarray of float32.
+
+ :param yerror: Values with the uncertainties on the y values
+ :type yerror: A float, or a numpy.ndarray of float32. See xerror.
+ :param alpha: Values with the transparency (between 0 and 1)
+ :type alpha: A float, or a numpy.ndarray of float32
+ :param bool copy: True make a copy of the data (default),
+ False to use provided arrays.
+ """
+ x = () if x is None else x
+ y = () if y is None else y
+ value = () if value is None else value
+
+ self.getScatterItem().setData(
+ x=x, y=y, value=value, xerror=xerror, yerror=yerror, alpha=alpha, copy=copy
+ )
+
+ @docstring(items.Scatter)
+ def getData(self, *args, **kwargs):
+ return self.getScatterItem().getData(*args, **kwargs)
+
+ def getScatterItem(self):
+ """Returns the plot item displaying the scatter data.
+
+ This allows to set the style of the displayed scatter.
+
+ :rtype: ~silx.gui.plot.items.Scatter
+ """
+ plot = self.getPlotWidget()
+ scatter = plot._getItem(kind="scatter", legend=self._SCATTER_LEGEND)
+ if scatter is None: # Resilient to call to PlotWidget API (e.g., clear)
+ scatter = self.__createEmptyScatter()
+ return scatter
+
+ # Convenient proxies
+
+ @docstring(PlotWidget)
+ def getXAxis(self, *args, **kwargs):
+ return self.getPlotWidget().getXAxis(*args, **kwargs)
+
+ @docstring(PlotWidget)
+ def getYAxis(self, *args, **kwargs):
+ return self.getPlotWidget().getYAxis(*args, **kwargs)
+
+ @docstring(PlotWidget)
+ def setGraphTitle(self, *args, **kwargs):
+ return self.getPlotWidget().setGraphTitle(*args, **kwargs)
+
+ @docstring(PlotWidget)
+ def getGraphTitle(self, *args, **kwargs):
+ return self.getPlotWidget().getGraphTitle(*args, **kwargs)
+
+ @docstring(PlotWidget)
+ def resetZoom(self, *args, **kwargs):
+ return self.getPlotWidget().resetZoom(*args, **kwargs)
+
+ @docstring(ScatterMaskToolsWidget)
+ def getSelectionMask(self, *args, **kwargs):
+ return self.getMaskToolsWidget().getSelectionMask(*args, **kwargs)
+
+ @docstring(ScatterMaskToolsWidget)
+ def setSelectionMask(self, *args, **kwargs):
+ return self.getMaskToolsWidget().setSelectionMask(*args, **kwargs)
diff --git a/src/silx/gui/plot/StackView.py b/src/silx/gui/plot/StackView.py
new file mode 100644
index 0000000..36560fd
--- /dev/null
+++ b/src/silx/gui/plot/StackView.py
@@ -0,0 +1,1248 @@
+# /*##########################################################################
+#
+# Copyright (c) 2016-2023 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""QWidget displaying a 3D volume as a stack of 2D images.
+
+The :class:`StackView` class implements this widget.
+
+Basic usage of :class:`StackView` is through the following methods:
+
+- :meth:`StackView.getColormap`, :meth:`StackView.setColormap` to update the
+ default colormap to use and update the currently displayed image.
+- :meth:`StackView.setStack` to update the displayed image.
+
+The :class:`StackView` uses :class:`PlotWindow` and also
+exposes a subset of the :class:`silx.gui.plot.Plot` API for further control
+(plot title, axes labels, ...).
+
+The :class:`StackViewMainWindow` class implements a widget that adds a status
+bar displaying the 3D index and the value under the mouse cursor.
+
+Example::
+
+ import numpy
+ import sys
+ from silx.gui import qt
+ from silx.gui.plot.StackView import StackViewMainWindow
+
+
+ app = qt.QApplication(sys.argv[1:])
+
+ # synthetic data, stack of 100 images of size 200x300
+ mystack = numpy.fromfunction(
+ lambda i, j, k: numpy.sin(i/15.) + numpy.cos(j/4.) + 2 * numpy.sin(k/6.),
+ (100, 200, 300)
+ )
+
+
+ sv = StackViewMainWindow()
+ sv.setColormap("viridis", vmin=-4, vmax=4)
+ sv.setStack(mystack)
+ sv.setLabels(["1st dim (0-99)", "2nd dim (0-199)",
+ "3rd dim (0-299)"])
+ sv.show()
+
+ app.exec()
+
+"""
+
+__authors__ = ["P. Knobel", "H. Payno"]
+__license__ = "MIT"
+__date__ = "10/10/2018"
+
+import numpy
+import logging
+
+import silx
+from silx.gui import qt
+from .. import icons
+from . import items, PlotWindow, actions
+from .items.image import ImageStack
+from ..colors import Colormap
+from ..colors import cursorColorForColormap
+from .tools import LimitsToolBar
+from .Profile import Profile3DToolBar
+from ..widgets.FrameBrowser import HorizontalSliderWithBrowser
+
+from silx.gui.plot.actions import io as silx_io
+from silx.io.nxdata import save_NXdata
+from silx.utils.array_like import DatasetView, ListOfImages
+from silx.math import calibration
+
+from silx.io.utils import is_dataset
+
+_logger = logging.getLogger(__name__)
+
+
+class StackView(qt.QMainWindow):
+ """Stack view widget, to display and browse through stack of
+ images.
+
+ The profile tool can be switched to "3D" mode, to compute the profile
+ on each image of the stack (not only the active image currently displayed)
+ and display the result as a slice.
+
+ :param QWidget parent: the Qt parent, or None
+ :param backend: The backend to use for the plot (default: matplotlib).
+ See :class:`.PlotWidget` for the list of supported backend.
+ :type backend: str or :class:`BackendBase.BackendBase`
+ :param bool resetzoom: Toggle visibility of reset zoom action.
+ :param bool autoScale: Toggle visibility of axes autoscale actions.
+ :param bool logScale: Toggle visibility of axes log scale actions.
+ :param bool grid: Toggle visibility of grid mode action.
+ :param bool colormap: Toggle visibility of colormap action.
+ :param bool aspectRatio: Toggle visibility of aspect ratio button.
+ :param bool yInverted: Toggle visibility of Y axis direction button.
+ :param bool copy: Toggle visibility of copy action.
+ :param bool save: Toggle visibility of save action.
+ :param bool print_: Toggle visibility of print action.
+ :param bool control: True to display an Options button with a sub-menu
+ to show legends, toggle crosshair and pan with arrows.
+ (Default: False)
+ :param position: True to display widget with (x, y) mouse position
+ (Default: False).
+ It also supports a list of (name, funct(x, y)->value)
+ to customize the displayed values.
+ See :class:`silx.gui.plot.PlotTools.PositionInfo`.
+ :param bool mask: Toggle visibilty of mask action.
+ """
+
+ # Qt signals
+ valueChanged = qt.Signal(object, object, object)
+ """Signals that the data value under the cursor has changed.
+
+ It provides: row, column, data value.
+ """
+
+ sigPlaneSelectionChanged = qt.Signal(int)
+ """Signal emitted when there is a change is perspective/displayed axes.
+
+ It provides the perspective as an integer, with the following meaning:
+
+ - 0: axis Y is the 2nd dimension, axis X is the 3rd dimension
+ - 1: axis Y is the 1st dimension, axis X is the 3rd dimension
+ - 2: axis Y is the 1st dimension, axis X is the 2nd dimension
+ """
+
+ sigStackChanged = qt.Signal(int)
+ """Signal emitted when the stack is changed.
+ This happens when a new volume is loaded, or when the current volume
+ is transposed (change in perspective).
+
+ The signal provides the size (number of pixels) of the stack.
+ This will be 0 if the stack is cleared, else it will be a positive
+ integer.
+ """
+
+ sigFrameChanged = qt.Signal(int)
+ """Signal emitter when the frame number has changed.
+
+ This signal provides the current frame number.
+ """
+
+ IMAGE_STACK_FILTER_NXDATA = (
+ "Stack of images as NXdata (%s)" % silx_io._NEXUS_HDF5_EXT_STR
+ )
+
+ def __init__(
+ self,
+ parent=None,
+ resetzoom=True,
+ backend=None,
+ autoScale=False,
+ logScale=False,
+ grid=False,
+ colormap=True,
+ aspectRatio=True,
+ yinverted=True,
+ copy=True,
+ save=True,
+ print_=True,
+ control=False,
+ position=None,
+ mask=True,
+ ):
+ qt.QMainWindow.__init__(self, parent)
+ if parent is not None:
+ # behave as a widget
+ self.setWindowFlags(qt.Qt.Widget)
+ else:
+ self.setWindowTitle("StackView")
+
+ self._stack = None
+ """Loaded stack, as a 3D array, a 3D dataset or a list of 2D arrays."""
+ self.__transposed_view = None
+ """View on :attr:`_stack` with the axes sorted, to have
+ the orthogonal dimension first"""
+ self._perspective = 0
+ """Orthogonal dimension (depth) in :attr:`_stack`"""
+
+ self._stackItem = ImageStack()
+ """Hold the item displaying the stack"""
+ imageLegend = "__StackView__image" + str(id(self))
+ self._stackItem.setName(imageLegend)
+
+ self.__dimensionsLabels = ["Dimension 0", "Dimension 1", "Dimension 2"]
+ """These labels are displayed on the X and Y axes.
+ :meth:`setLabels` updates this attribute."""
+
+ self._first_stack_dimension = 0
+ """Used for dimension labels and combobox"""
+
+ self._titleCallback = self._defaultTitleCallback
+ """Function returning the plot title based on the frame index.
+ It can be set to a custom function using :meth:`setTitleCallback`"""
+
+ self.calibrations3D = (
+ calibration.NoCalibration(),
+ calibration.NoCalibration(),
+ calibration.NoCalibration(),
+ )
+
+ central_widget = qt.QWidget(self)
+
+ self._plot = PlotWindow(
+ parent=central_widget,
+ backend=backend,
+ resetzoom=resetzoom,
+ autoScale=autoScale,
+ logScale=logScale,
+ grid=grid,
+ curveStyle=False,
+ colormap=colormap,
+ aspectRatio=aspectRatio,
+ yInverted=yinverted,
+ copy=copy,
+ save=save,
+ print_=print_,
+ control=control,
+ position=position,
+ roi=False,
+ mask=mask,
+ )
+ self._plot.addItem(self._stackItem)
+ self._plot.getIntensityHistogramAction().setVisible(True)
+ self.sigInteractiveModeChanged = self._plot.sigInteractiveModeChanged
+ self.sigActiveImageChanged = self._plot.sigActiveImageChanged
+ self.sigPlotSignal = self._plot.sigPlotSignal
+
+ if silx.config.DEFAULT_PLOT_IMAGE_Y_AXIS_ORIENTATION == "downward":
+ self._plot.getYAxis().setInverted(True)
+
+ self._plot.getColorBarAction().setVisible(True)
+ self._plot.getColorBarWidget().setVisible(True)
+
+ self._profileToolBar = Profile3DToolBar(parent=self._plot, stackview=self)
+ self._plot.addToolBar(self._profileToolBar)
+ self._plot.getXAxis().setLabel("Columns")
+ self._plot.getYAxis().setLabel("Rows")
+ self._plot.sigPlotSignal.connect(self._plotCallback)
+ self._plot.getSaveAction().setFileFilter(
+ "image",
+ self.IMAGE_STACK_FILTER_NXDATA,
+ func=self._saveImageStack,
+ appendToFile=True,
+ )
+
+ self.__planeSelection = PlanesWidget(self._plot)
+ self.__planeSelection.sigPlaneSelectionChanged.connect(self.setPerspective)
+
+ self._browser_label = qt.QLabel("Image index (Dim0):")
+
+ self._browser = HorizontalSliderWithBrowser(central_widget)
+ self._browser.setRange(0, 0)
+ self._browser.valueChanged[int].connect(self.__updateFrameNumber)
+ self._browser.setEnabled(False)
+
+ layout = qt.QGridLayout()
+ layout.setContentsMargins(0, 0, 0, 0)
+ layout.addWidget(self._plot, 0, 0, 1, 3)
+ layout.addWidget(self.__planeSelection, 1, 0)
+ layout.addWidget(self._browser_label, 1, 1)
+ layout.addWidget(self._browser, 1, 2)
+
+ central_widget.setLayout(layout)
+ self.setCentralWidget(central_widget)
+
+ # clear profile lines when the perspective changes (plane browsed changed)
+ self.__planeSelection.sigPlaneSelectionChanged.connect(
+ self._profileToolBar.clearProfile
+ )
+
+ def _saveImageStack(self, plot, filename, nameFilter):
+ """Save all images from the stack into a volume.
+
+ :param str filename: The name of the file to write
+ :param str nameFilter: The selected name filter
+ :return: False if format is not supported or save failed,
+ True otherwise.
+ :raises: ValueError if nameFilter is invalid
+ """
+ if not nameFilter == self.IMAGE_STACK_FILTER_NXDATA:
+ raise ValueError("Wrong callback")
+ entryPath = silx_io.SaveAction._selectWriteableOutputGroup(
+ filename, parent=self
+ )
+ if entryPath is None:
+ return False
+ return save_NXdata(
+ filename,
+ nxentry_name=entryPath,
+ signal=self.getStack(copy=False, returnNumpyArray=True)[0],
+ signal_name="image_stack",
+ )
+
+ def _plotCallback(self, eventDict):
+ """Callback for plot events.
+
+ Emit :attr:`valueChanged` signal, with (x, y, value) tuple of the
+ cursor location in the plot."""
+ if eventDict["event"] == "mouseMoved":
+ activeImage = self.getActiveImage()
+ if activeImage is not None:
+ data = activeImage.getData()
+ height, width = data.shape
+
+ # Get corresponding coordinate in image
+ origin = activeImage.getOrigin()
+ scale = activeImage.getScale()
+ x = int((eventDict["x"] - origin[0]) / scale[0])
+ y = int((eventDict["y"] - origin[1]) / scale[1])
+
+ if 0 <= x < width and 0 <= y < height:
+ self.valueChanged.emit(float(x), float(y), data[y][x])
+ else:
+ self.valueChanged.emit(float(x), float(y), None)
+
+ def getPerspective(self):
+ """Returns the index of the dimension the stack is browsed with
+
+ Possible values are: 0, 1, or 2.
+
+ :rtype: int
+ """
+ return self._perspective
+
+ def setPerspective(self, perspective):
+ """Set the index of the dimension the stack is browsed with:
+
+ - slice plane Dim1-Dim2: perspective 0
+ - slice plane Dim0-Dim2: perspective 1
+ - slice plane Dim0-Dim1: perspective 2
+
+ :param int perspective: Orthogonal dimension number (0, 1, or 2)
+ """
+ if perspective == self._perspective:
+ return
+ else:
+ if perspective > 2 or perspective < 0:
+ raise ValueError("Perspective must be 0, 1 or 2, not %s" % perspective)
+
+ self._perspective = int(perspective)
+ self.__createTransposedView()
+ self.__updateFrameNumber(self._browser.value())
+ self._plot.resetZoom()
+ self.__updatePlotLabels()
+ self._updateTitle()
+ self._browser_label.setText(
+ "Image index (Dim%d):" % (self._first_stack_dimension + perspective)
+ )
+
+ self.sigPlaneSelectionChanged.emit(perspective)
+ self.sigStackChanged.emit(
+ self._stack.size if self._stack is not None else 0
+ )
+ self.__planeSelection.sigPlaneSelectionChanged.disconnect(
+ self.setPerspective
+ )
+ self.__planeSelection.setPerspective(self._perspective)
+ self.__planeSelection.sigPlaneSelectionChanged.connect(self.setPerspective)
+
+ def __updatePlotLabels(self):
+ """Update plot axes labels depending on perspective"""
+ y, x = (
+ (1, 2)
+ if self._perspective == 0
+ else (0, 2)
+ if self._perspective == 1
+ else (0, 1)
+ )
+ self.setGraphXLabel(self.__dimensionsLabels[x])
+ self.setGraphYLabel(self.__dimensionsLabels[y])
+
+ def __createTransposedView(self):
+ """Create the new view on the stack depending on the perspective
+ (set orthogonal axis browsed on the viewer as first dimension)
+ """
+ assert self._stack is not None
+ assert 0 <= self._perspective < 3
+
+ # ensure we have the stack encapsulated in an array-like object
+ # having a transpose() method
+ if isinstance(self._stack, numpy.ndarray):
+ self.__transposed_view = self._stack
+
+ elif is_dataset(self._stack) or isinstance(self._stack, DatasetView):
+ self.__transposed_view = DatasetView(self._stack)
+
+ elif isinstance(self._stack, ListOfImages):
+ self.__transposed_view = ListOfImages(self._stack)
+
+ # transpose the array-like object if necessary
+ if self._perspective == 1:
+ self.__transposed_view = self.__transposed_view.transpose((1, 0, 2))
+ elif self._perspective == 2:
+ self.__transposed_view = self.__transposed_view.transpose((2, 0, 1))
+
+ self._browser.setRange(0, self.__transposed_view.shape[0] - 1)
+ self._browser.setValue(0)
+
+ # Update the item structure
+ self._stackItem.setStackData(self.__transposed_view, 0, copy=False)
+ self._stackItem.setColormap(self.getColormap())
+ self._stackItem.setOrigin(self._getImageOrigin())
+ self._stackItem.setScale(self._getImageScale())
+
+ def __updateFrameNumber(self, index):
+ """Update the current image.
+
+ :param index: index of the frame to be displayed
+ """
+ if self.__transposed_view is None:
+ # no data set
+ return
+
+ self._stackItem.setStackPosition(index)
+
+ self._updateTitle()
+ self.sigFrameChanged.emit(index)
+
+ def _set3DScaleAndOrigin(self, calibrations):
+ """Set scale and origin for all 3 axes, to be used when plotting
+ an image.
+
+ See setStack for parameter documentation
+ """
+ if calibrations is None:
+ self.calibrations3D = (
+ calibration.NoCalibration(),
+ calibration.NoCalibration(),
+ calibration.NoCalibration(),
+ )
+ else:
+ self.calibrations3D = []
+ for i, calib in enumerate(calibrations):
+ if hasattr(calib, "__len__") and len(calib) == 2:
+ calib = calibration.LinearCalibration(calib[0], calib[1])
+ elif calib is None:
+ calib = calibration.NoCalibration()
+ elif not isinstance(calib, calibration.AbstractCalibration):
+ raise TypeError(
+ "calibration must be a 2-tuple, None or"
+ + " an instance of an AbstractCalibration "
+ + "subclass"
+ )
+ elif not calib.is_affine():
+ _logger.warning(
+ "Calibration for dimension %d is not linear, "
+ "it will be ignored for scaling the graph axes.",
+ i,
+ )
+ self.calibrations3D.append(calib)
+
+ def getCalibrations(self, order="array"):
+ """Returns currently used calibrations for each axis
+
+ Returned calibrations might differ from the ones that were set as
+ non-linear calibrations used for image axes are temporarily ignored.
+
+ :param str order:
+ 'array' to sort calibrations as data array (dim0, dim1, dim2),
+ 'axes' to sort calibrations as currently selected x, y and z axes.
+ :return: Calibrations ordered depending on order
+ :rtype: List[~silx.math.calibration.AbstractCalibration]
+ """
+ assert order in ("array", "axes")
+ calibs = []
+
+ # filter out non-linear calibration for graph axes
+ for index, calib in enumerate(self.calibrations3D):
+ if index != self._perspective and not calib.is_affine():
+ calib = calibration.NoCalibration()
+ calibs.append(calib)
+
+ if order == "axes": # Move 'z' axis to the end
+ xy_dims = [d for d in (0, 1, 2) if d != self._perspective]
+ calibs = [
+ calibs[max(xy_dims)],
+ calibs[min(xy_dims)],
+ calibs[self._perspective],
+ ]
+
+ return tuple(calibs)
+
+ def _getImageScale(self):
+ """
+ :return: 2-tuple (XScale, YScale) for current image view
+ """
+ xcalib, ycalib, _zcalib = self.getCalibrations(order="axes")
+ return xcalib.get_slope(), ycalib.get_slope()
+
+ def _getImageOrigin(self):
+ """
+ :return: 2-tuple (XOrigin, YOrigin) for current image view
+ """
+ xcalib, ycalib, _zcalib = self.getCalibrations(order="axes")
+ return xcalib(0), ycalib(0)
+
+ def _getImageZ(self, index):
+ """
+ :param idx: 0-based image index in the stack
+ :return: calibrated Z value corresponding to the image idx
+ """
+ _xcalib, _ycalib, zcalib = self.getCalibrations(order="axes")
+ return zcalib(index)
+
+ def _updateTitle(self):
+ frame_idx = self._browser.value()
+ self._plot.setGraphTitle(self._titleCallback(frame_idx))
+
+ def _defaultTitleCallback(self, index):
+ return "Image z=%g" % self._getImageZ(index)
+
+ # public API, stack specific methods
+ def setStack(self, stack, perspective=None, reset=True, calibrations=None):
+ """Set the 3D stack.
+
+ The perspective parameter is used to define which dimension of the 3D
+ array is to be used as frame index. The lowest remaining dimension
+ number is the row index of the displayed image (Y axis), and the highest
+ remaining dimension is the column index (X axis).
+
+ :param stack: 3D stack, or `None` to clear plot.
+ :type stack: 3D numpy.ndarray, or 3D h5py.Dataset, or list/tuple of 2D
+ numpy arrays, or None.
+ :param int perspective: Dimension for the frame index: 0, 1 or 2.
+ Use ``None`` to keep the current perspective (default).
+ :param bool reset: Whether to reset zoom or not.
+ :param calibrations: Sequence of 3 calibration objects for each axis.
+ These objects can be a subclass of :class:`AbstractCalibration`,
+ or 2-tuples *(a, b)* where *a* is the y-intercept and *b* is the
+ slope of a linear calibration (:math:`x \\mapsto a + b x`)
+ """
+ if stack is None:
+ self.clear()
+ self.sigStackChanged.emit(0)
+ return
+
+ self._set3DScaleAndOrigin(calibrations)
+
+ # stack as list of 2D arrays: must be converted into an array_like
+ if not isinstance(stack, numpy.ndarray):
+ if not is_dataset(stack):
+ try:
+ assert hasattr(stack, "__len__")
+ for img in stack:
+ assert hasattr(img, "shape")
+ assert len(img.shape) == 2
+ except AssertionError:
+ raise ValueError(
+ "Stack must be a 3D array/dataset or a list of " + "2D arrays."
+ )
+ stack = ListOfImages(stack)
+
+ assert len(stack.shape) == 3, "data must be 3D"
+
+ self._stack = stack
+ self.__createTransposedView()
+
+ perspective_changed = False
+ if perspective not in [None, self._perspective]:
+ perspective_changed = True
+ self.setPerspective(perspective)
+
+ # init plot
+ self._stackItem.setStackData(self.__transposed_view, 0, copy=False)
+ self._stackItem.setColormap(self.getColormap())
+ self._stackItem.setOrigin(self._getImageOrigin())
+ self._stackItem.setScale(self._getImageScale())
+ self._stackItem.setVisible(True)
+
+ # Put back the item in the plot in case it was cleared
+ exists = self._plot.getImage(self._stackItem.getName())
+ if exists is None:
+ self._plot.addItem(self._stackItem)
+
+ self._plot.setActiveImage(self._stackItem)
+ self.__updatePlotLabels()
+ self._updateTitle()
+
+ if reset:
+ self._plot.resetZoom()
+
+ # enable and init browser
+ self._browser.setEnabled(True)
+
+ if not perspective_changed: # avoid double signal (see self.setPerspective)
+ self.sigStackChanged.emit(stack.size)
+
+ def getStack(self, copy=True, returnNumpyArray=False):
+ """Get the original stack, as a 3D array or dataset.
+
+ The output has the form: [data, params]
+ where params is a dictionary containing display parameters.
+
+ :param bool copy: If True (default), then the object is copied
+ and returned as a numpy array.
+ Else, a reference to original data is returned, if possible.
+ If the original data is not a numpy array and parameter
+ returnNumpyArray is True, a copy will be made anyway.
+ :param bool returnNumpyArray: If True, the returned object is
+ guaranteed to be a numpy array.
+ :return: 3D stack and parameters.
+ :rtype: (numpy.ndarray, dict)
+ """
+ if self._stack is None:
+ return None
+
+ image = self._stackItem
+ colormap = image.getColormap()
+
+ params = {
+ "info": image.getInfo(),
+ "origin": image.getOrigin(),
+ "scale": image.getScale(),
+ "z": image.getZValue(),
+ "selectable": image.isSelectable(),
+ "draggable": image.isDraggable(),
+ "colormap": colormap,
+ "xlabel": image.getXLabel(),
+ "ylabel": image.getYLabel(),
+ }
+ if returnNumpyArray or copy:
+ return numpy.array(self._stack, copy=copy), params
+
+ # if a list of 2D arrays was cast into a ListOfImages,
+ # return the original list
+ if isinstance(self._stack, ListOfImages):
+ return self._stack.images, params
+
+ return self._stack, params
+
+ def getCurrentView(self, copy=True, returnNumpyArray=False):
+ """Get the stack, as it is currently displayed.
+
+ The first index of the returned stack is always the frame
+ index. If the perspective has been changed in the widget since the
+ data was first loaded, this will be reflected in the order of the
+ dimensions of the returned object.
+
+ The output has the form: [data, params]
+ where params is a dictionary containing display parameters.
+
+ :param bool copy: If True (default), then the object is copied
+ and returned as a numpy array.
+ Else, a reference to original data is returned, if possible.
+ If the original data is not a numpy array and parameter
+ `returnNumpyArray` is `True`, a copy will be made anyway.
+ :param bool returnNumpyArray: If `True`, the returned object is
+ guaranteed to be a numpy array.
+ :return: 3D stack and parameters.
+ :rtype: (numpy.ndarray, dict)
+ """
+ image = self.getActiveImage()
+ if image is None:
+ return None
+
+ if isinstance(image, items.ColormapMixIn):
+ colormap = image.getColormap()
+ else:
+ colormap = None
+
+ params = {
+ "info": image.getInfo(),
+ "origin": image.getOrigin(),
+ "scale": image.getScale(),
+ "z": image.getZValue(),
+ "selectable": image.isSelectable(),
+ "draggable": image.isDraggable(),
+ "colormap": colormap,
+ "xlabel": image.getXLabel(),
+ "ylabel": image.getYLabel(),
+ }
+ if returnNumpyArray or copy:
+ return numpy.array(self.__transposed_view, copy=copy), params
+ return self.__transposed_view, params
+
+ def setFrameNumber(self, number):
+ """Set the frame selection to a specific value
+
+ :param int number: Number of the frame
+ """
+ self._browser.setValue(number)
+
+ def getFrameNumber(self):
+ """Set the frame selection to a specific value
+
+ :return: Index of currently displayed frame
+ :rtype: int
+ """
+ return self._browser.value()
+
+ def setFirstStackDimension(self, first_stack_dimension):
+ """When viewing the last 3 dimensions of an n-D array (n>3), you can
+ use this method to change the text in the combobox.
+
+ For instance, for a 7-D array, first stack dim is 4, so the default
+ "Dim1-Dim2" text should be replaced with "Dim5-Dim6" (dimensions
+ numbers are 0-based).
+
+ :param int first_stack_dim: First stack dimension (n-3) when viewing the
+ last 3 dimensions of an n-D array.
+ """
+ old_state = self.__planeSelection.blockSignals(True)
+ self.__planeSelection.setFirstStackDimension(first_stack_dimension)
+ self.__planeSelection.blockSignals(old_state)
+ self._first_stack_dimension = first_stack_dimension
+ self._browser_label.setText("Image index (Dim%d):" % first_stack_dimension)
+
+ def setTitleCallback(self, callback):
+ """Set a user defined function to generate the plot title based on the
+ image/frame index.
+
+ The callback function must accept an integer as a its first positional
+ parameter and must not require any other mandatory parameter.
+ It must return a string.
+
+ To switch back the default behavior, you can pass ``None``::
+
+ mystackview.setTitleCallback(None)
+
+ To have no title, pass a function that returns an empty string::
+
+ mystackview.setTitleCallback(lambda idx: "")
+
+ :param callback: Callback function generating the stack title based
+ on the frame number.
+ """
+
+ if callback is None:
+ self._titleCallback = self._defaultTitleCallback
+ elif callable(callback):
+ self._titleCallback = callback
+ else:
+ raise TypeError("Provided callback is not callable")
+ self._updateTitle()
+
+ def clear(self):
+ """Clear the widget:
+
+ - clear the plot
+ - clear the loaded data volume
+ """
+ self._stack = None
+ self.__transposed_view = None
+ self._perspective = 0
+ self._browser.setEnabled(False)
+ # reset browser range
+ self._browser.setRange(0, 0)
+ self._plot.clear()
+
+ def setLabels(self, labels=None):
+ """Set the labels to be displayed on the plot axes.
+
+ You must provide a sequence of 3 strings, corresponding to the 3
+ dimensions of the original data volume.
+ The proper label will automatically be selected for each plot axis
+ when the volume is rotated (when different axes are selected as the
+ X and Y axes).
+
+ :param List[str] labels: 3 labels corresponding to the 3 dimensions
+ of the data volumes.
+ """
+
+ default_labels = [
+ "Dimension %d" % self._first_stack_dimension,
+ "Dimension %d" % (self._first_stack_dimension + 1),
+ "Dimension %d" % (self._first_stack_dimension + 2),
+ ]
+ if labels is None:
+ new_labels = default_labels
+ else:
+ # filter-out None
+ new_labels = []
+ for i, label in enumerate(labels):
+ new_labels.append(label or default_labels[i])
+
+ self.__dimensionsLabels = new_labels
+ self.__updatePlotLabels()
+
+ def getLabels(self):
+ """Return dimension labels displayed on the plot axes
+
+ :return: List of three strings corresponding to the 3 dimensions
+ of the stack: (name_dim0, name_dim1, name_dim2)
+ """
+ return self.__dimensionsLabels
+
+ def getColormap(self):
+ """Get the current colormap description.
+
+ :return: A description of the current colormap.
+ See :meth:`setColormap` for details.
+ :rtype: dict
+ """
+ # "default" colormap used by addImage when image is added without
+ # specifying a special colormap
+ return self._plot.getDefaultColormap()
+
+ def scaleColormapRangeToStack(self):
+ """Scale colormap range according to current stack data.
+
+ If no stack has been set through :meth:`setStack`, this has no effect.
+
+ The range scaling mode is given by current :class:`Colormap`'s
+ :meth:`Colormap.getAutoscaleMode`.
+ """
+ stack = self.getStack(copy=False, returnNumpyArray=True)
+ if stack is None:
+ return # No-op
+
+ colormap = self.getColormap()
+ vmin, vmax = colormap.getColormapRange(data=stack[0])
+ colormap.setVRange(vmin=vmin, vmax=vmax)
+
+ def setColormap(
+ self, colormap=None, normalization=None, vmin=None, vmax=None, colors=None
+ ):
+ """Set the colormap and update active image.
+
+ Parameters that are not provided are taken from the current colormap.
+
+ The colormap parameter can also be a dict with the following keys:
+
+ - *name*: string. The colormap to use:
+ 'gray', 'reversed gray', 'temperature', 'red', 'green', 'blue'.
+ - *normalization*: string. The mapping to use for the colormap:
+ either 'linear' or 'log'.
+ - *autoscale*: bool. Whether to use autoscale (True) or range
+ provided by keys
+ 'vmin' and 'vmax' (False).
+ - *vmin*: float. The minimum value of the range to use if 'autoscale'
+ is False.
+ - *vmax*: float. The maximum value of the range to use if 'autoscale'
+ is False.
+ - *colors*: optional. Nx3 or Nx4 array of float in [0, 1] or uint8.
+ List of RGB or RGBA colors to use (only if name is None)
+
+ :param colormap: Name of the colormap in
+ 'gray', 'reversed gray', 'temperature', 'red', 'green', 'blue'.
+ Or a :class`.Colormap` object.
+ :type colormap: dict or str.
+ :param str normalization: Colormap mapping: 'linear' or 'log'.
+ :param float vmin: The minimum value of the range to use.
+ :param float vmax: The maximum value of the range to use.
+ :param numpy.ndarray colors: Only used if name is None.
+ Custom colormap colors as Nx3 or Nx4 RGB or RGBA arrays
+ """
+ # if is a colormap object or a dictionary
+ if isinstance(colormap, Colormap) or isinstance(colormap, dict):
+ # Support colormap parameter as a dict
+ errmsg = (
+ "If colormap is provided as a Colormap object, all other parameters"
+ )
+ errmsg += " must not be specified when calling setColormap"
+ assert normalization is None, errmsg
+ assert vmin is None, errmsg
+ assert vmax is None, errmsg
+ assert colors is None, errmsg
+
+ _colormap = colormap
+ else:
+ norm = normalization if normalization is not None else "linear"
+ name = colormap if colormap is not None else "gray"
+ _colormap = Colormap(
+ name=name, normalization=norm, vmin=vmin, vmax=vmax, colors=colors
+ )
+
+ cursorColor = cursorColorForColormap(_colormap.getName())
+ self._plot.setInteractiveMode("zoom", color=cursorColor)
+
+ self._plot.setDefaultColormap(_colormap)
+
+ # Update active image colormap
+ activeImage = self.getActiveImage()
+ if isinstance(activeImage, items.ColormapMixIn):
+ activeImage.setColormap(self.getColormap())
+
+ def getPlotWidget(self):
+ """Return the :class:`PlotWidget`.
+
+ This gives access to advanced plot configuration options.
+ Be warned that modifying the plot can cause issues, and some changes
+ you make to the plot could be overwritten by the :class:`StackView`
+ widget's internal methods and callbacks.
+
+ :return: instance of :class:`PlotWidget` used in widget
+ """
+ return self._plot
+
+ def setOptionVisible(self, isVisible):
+ """
+ Set the visibility of the browsing options.
+
+ :param bool isVisible: True to have the options visible, else False
+ """
+ self._browser.setVisible(isVisible)
+ self.__planeSelection.setVisible(isVisible)
+
+ # proxies to PlotWidget or PlotWindow methods
+ def getProfileToolbar(self):
+ """Profile tools attached to this plot"""
+ return self._profileToolBar
+
+ def getGraphTitle(self):
+ """Return the plot main title as a str."""
+ return self._plot.getGraphTitle()
+
+ def setGraphTitle(self, title=""):
+ """Set the plot main title.
+
+ :param str title: Main title of the plot (default: '')
+ """
+ return self._plot.setGraphTitle(title)
+
+ def getGraphXLabel(self):
+ """Return the current horizontal axis label as a str."""
+ return self._plot.getXAxis().getLabel()
+
+ def setGraphXLabel(self, label=None):
+ """Set the plot horizontal axis label.
+
+ :param str label: The horizontal axis label
+ """
+ if label is None:
+ label = self.__dimensionsLabels[1 if self._perspective == 2 else 2]
+ self._plot.getXAxis().setLabel(label)
+
+ def getGraphYLabel(self, axis="left"):
+ """Return the current vertical axis label as a str.
+
+ :param str axis: The Y axis for which to get the label (left or right)
+ """
+ return self._plot.getYAxis().getLabel(axis)
+
+ def setGraphYLabel(self, label=None, axis="left"):
+ """Set the vertical axis label on the plot.
+
+ :param str label: The Y axis label
+ :param str axis: The Y axis for which to set the label (left or right)
+ """
+ if label is None:
+ label = self.__dimensionsLabels[1 if self._perspective == 0 else 0]
+ self._plot.getYAxis(axis=axis).setLabel(label)
+
+ def resetZoom(self):
+ """Reset the plot limits to the bounds of the data and redraw the plot.
+
+ This method is a simple proxy to the legacy :class:`PlotWidget` method
+ of the same name. Using the object oriented approach is now
+ preferred::
+
+ stackview.getPlot().resetZoom()
+ """
+ self._plot.resetZoom()
+
+ def setYAxisInverted(self, flag=True):
+ """Set the Y axis orientation.
+
+ This method is a simple proxy to the legacy :class:`PlotWidget` method
+ of the same name. Using the object oriented approach is now
+ preferred::
+
+ stackview.getPlot().setYAxisInverted(flag)
+
+ :param bool flag: True for Y axis going from top to bottom,
+ False for Y axis going from bottom to top
+ """
+ self._plot.setYAxisInverted(flag)
+
+ def isYAxisInverted(self):
+ """Return True if Y axis goes from top to bottom, False otherwise.
+
+ This method is a simple proxy to the legacy :class:`PlotWidget` method
+ of the same name. Using the object oriented approach is now
+ preferred::
+
+ stackview.getPlot().isYAxisInverted()"""
+ return self._plot.isYAxisInverted()
+
+ def getSupportedColormaps(self):
+ """Get the supported colormap names as a tuple of str.
+
+ The list should at least contain and start by:
+ ('gray', 'reversed gray', 'temperature', 'red', 'green', 'blue')
+
+ This method is a simple proxy to the legacy :class:`PlotWidget` method
+ of the same name. Using the object oriented approach is now
+ preferred::
+
+ stackview.getPlot().getSupportedColormaps()
+ """
+ return self._plot.getSupportedColormaps()
+
+ def isKeepDataAspectRatio(self):
+ """Returns whether the plot is keeping data aspect ratio or not.
+
+ This method is a simple proxy to the legacy :class:`PlotWidget` method
+ of the same name. Using the object oriented approach is now
+ preferred::
+
+ stackview.getPlot().isKeepDataAspectRatio()"""
+ return self._plot.isKeepDataAspectRatio()
+
+ def setKeepDataAspectRatio(self, flag=True):
+ """Set whether the plot keeps data aspect ratio or not.
+
+ This method is a simple proxy to the legacy :class:`PlotWidget` method
+ of the same name. Using the object oriented approach is now
+ preferred::
+
+ stackview.getPlot().setKeepDataAspectRatio(flag)
+
+ :param bool flag: True to respect data aspect ratio
+ """
+ self._plot.setKeepDataAspectRatio(flag)
+
+ # kind of private methods, but needed by Profile
+ def getActiveImage(self, just_legend=False):
+ """Returns the stack image object."""
+ if just_legend:
+ return self._stackItem.getName()
+ return self._stackItem
+
+ def getColorBarAction(self):
+ """Returns the action managing the visibility of the colorbar.
+
+ .. warning:: to show/hide the plot colorbar call directly the ColorBar
+ widget using getColorBarWidget()
+
+ :rtype: QAction
+ """
+ return self._plot.getColorBarAction()
+
+ def remove(self, legend=None, kind=("curve", "image", "item", "marker")):
+ """See :meth:`Plot.Plot.remove`"""
+ self._plot.remove(legend, kind)
+
+ def setInteractiveMode(self, *args, **kwargs):
+ """
+ See :meth:`Plot.Plot.setInteractiveMode`
+ """
+ self._plot.setInteractiveMode(*args, **kwargs)
+
+ def addShape(self, *args, **kwargs):
+ """
+ See :meth:`Plot.Plot.addShape`
+ """
+ self._plot.addShape(*args, **kwargs)
+
+
+class PlanesWidget(qt.QWidget):
+ """Widget for the plane/perspective selection
+
+ :param parent: the parent QWidget
+ """
+
+ sigPlaneSelectionChanged = qt.Signal(int)
+
+ def __init__(self, parent):
+ super(PlanesWidget, self).__init__(parent)
+
+ self.setSizePolicy(qt.QSizePolicy.Minimum, qt.QSizePolicy.Minimum)
+ layout0 = qt.QHBoxLayout()
+ self.setLayout(layout0)
+ layout0.setContentsMargins(0, 0, 0, 0)
+
+ layout0.addWidget(qt.QLabel("Axes selection:"))
+
+ # By default, the first dimension (dim0) is the frame index/depth/z,
+ # the second dimension is the image row number/y axis
+ # and the third dimension is the image column index/x axis
+
+ # 1
+ # | 0
+ # |/__2
+ self.qcbAxisSelection = qt.QComboBox(self)
+ self._setCBChoices(first_stack_dimension=0)
+ self.qcbAxisSelection.currentIndexChanged[int].connect(
+ self.__planeSelectionChanged
+ )
+
+ layout0.addWidget(self.qcbAxisSelection)
+
+ def __planeSelectionChanged(self, idx):
+ """Callback function when the combobox selection changes
+
+ idx is the dimension number orthogonal to the slice plane,
+ following the convention:
+
+ - slice plane Dim1-Dim2: perspective 0
+ - slice plane Dim0-Dim2: perspective 1
+ - slice plane Dim0-Dim1: perspective 2
+ """
+ self.sigPlaneSelectionChanged.emit(idx)
+
+ def _setCBChoices(self, first_stack_dimension):
+ self.qcbAxisSelection.clear()
+
+ dim1dim2 = "Dim%d-Dim%d" % (
+ first_stack_dimension + 1,
+ first_stack_dimension + 2,
+ )
+ dim0dim2 = "Dim%d-Dim%d" % (first_stack_dimension, first_stack_dimension + 2)
+ dim0dim1 = "Dim%d-Dim%d" % (first_stack_dimension, first_stack_dimension + 1)
+
+ self.qcbAxisSelection.addItem(icons.getQIcon("cube-front"), dim1dim2)
+ self.qcbAxisSelection.addItem(icons.getQIcon("cube-bottom"), dim0dim2)
+ self.qcbAxisSelection.addItem(icons.getQIcon("cube-left"), dim0dim1)
+
+ def setFirstStackDimension(self, first_stack_dim):
+ """When viewing the last 3 dimensions of an n-D array (n>3), you can
+ use this method to change the text in the combobox.
+
+ For instance, for a 7-D array, first stack dim is 4, so the default
+ "Dim1-Dim2" text should be replaced with "Dim5-Dim6" (dimensions
+ numbers are 0-based).
+
+ :param int first_stack_dim: First stack dimension (n-3) when viewing the
+ last 3 dimensions of an n-D array.
+ """
+ self._setCBChoices(first_stack_dim)
+
+ def setPerspective(self, perspective):
+ """Update the combobox selection.
+
+ - slice plane Dim1-Dim2: perspective 0
+ - slice plane Dim0-Dim2: perspective 1
+ - slice plane Dim0-Dim1: perspective 2
+
+ :param perspective: Orthogonal dimension number (0, 1, or 2)
+ """
+ self.qcbAxisSelection.setCurrentIndex(perspective)
+
+
+class StackViewMainWindow(StackView):
+ """This class is a :class:`StackView` with a menu, an additional toolbar
+ to set the plot limits, and a status bar to display the value and 3D
+ index of the data samples hovered by the mouse cursor.
+
+ :param QWidget parent: Parent widget, or None
+ """
+
+ def __init__(self, parent=None):
+ self._dataInfo = None
+ super(StackViewMainWindow, self).__init__(parent)
+ self.setWindowFlags(qt.Qt.Window)
+
+ # Add toolbars and status bar
+ self.addToolBar(qt.Qt.BottomToolBarArea, LimitsToolBar(plot=self._plot))
+
+ self.statusBar()
+
+ menu = self.menuBar().addMenu("File")
+ menu.addAction(self._plot.getOutputToolBar().getSaveAction())
+ menu.addAction(self._plot.getOutputToolBar().getPrintAction())
+ menu.addSeparator()
+ action = menu.addAction("Quit")
+ action.triggered[bool].connect(qt.QApplication.instance().quit)
+
+ menu = self.menuBar().addMenu("Edit")
+ menu.addAction(self._plot.getOutputToolBar().getCopyAction())
+ menu.addSeparator()
+ menu.addAction(self._plot.getResetZoomAction())
+ menu.addAction(self._plot.getColormapAction())
+ menu.addAction(self.getColorBarAction())
+
+ menu.addAction(actions.control.KeepAspectRatioAction(self._plot, self))
+ menu.addAction(actions.control.YAxisInvertedAction(self._plot, self))
+
+ menu = self.menuBar().addMenu("Profile")
+ profileToolBar = self._profileToolBar
+ menu.addAction(profileToolBar.hLineAction)
+ menu.addAction(profileToolBar.vLineAction)
+ menu.addAction(profileToolBar.lineAction)
+ menu.addAction(profileToolBar.crossAction)
+ menu.addSeparator()
+ menu.addAction(profileToolBar._editor)
+ menu.addSeparator()
+ menu.addAction(profileToolBar.clearAction)
+
+ # Connect to StackView's signal
+ self.valueChanged.connect(self._statusBarSlot)
+
+ def _statusBarSlot(self, x, y, value):
+ """Update status bar with coordinates/value from plots."""
+ # todo (after implementing calibration):
+ # - use floats for (x, y, z)
+ # - display both indices (dim0, dim1, dim2) and (x, y, z)
+ msg = "Cursor out of range"
+ if x is not None and y is not None:
+ img_idx = self._browser.value()
+
+ if self._perspective == 0:
+ dim0, dim1, dim2 = img_idx, int(y), int(x)
+ elif self._perspective == 1:
+ dim0, dim1, dim2 = int(y), img_idx, int(x)
+ elif self._perspective == 2:
+ dim0, dim1, dim2 = int(y), int(x), img_idx
+
+ msg = "Position: (%d, %d, %d)" % (dim0, dim1, dim2)
+ if value is not None:
+ msg += ", Value: %g" % value
+ if self._dataInfo is not None:
+ msg = self._dataInfo + ", " + msg
+
+ self.statusBar().showMessage(msg)
+
+ def setStack(self, stack, *args, **kwargs):
+ """Set the displayed stack.
+
+ See :meth:`StackView.setStack` for details.
+ """
+ if hasattr(stack, "dtype") and hasattr(stack, "shape"):
+ assert len(stack.shape) == 3
+ nframes, height, width = stack.shape
+ self._dataInfo = "Data: %dx%dx%d (%s)" % (
+ nframes,
+ height,
+ width,
+ str(stack.dtype),
+ )
+ self.statusBar().showMessage(self._dataInfo)
+ else:
+ self._dataInfo = None
+
+ # Set the new stack in StackView widget
+ super(StackViewMainWindow, self).setStack(stack, *args, **kwargs)
+ self.setStatusBar(None)
diff --git a/src/silx/gui/plot/StatsWidget.py b/src/silx/gui/plot/StatsWidget.py
new file mode 100644
index 0000000..0c37f52
--- /dev/null
+++ b/src/silx/gui/plot/StatsWidget.py
@@ -0,0 +1,1720 @@
+# /*##########################################################################
+#
+# Copyright (c) 2017-2023 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""
+Module containing widgets displaying stats from items of a plot.
+"""
+
+__authors__ = ["H. Payno"]
+__license__ = "MIT"
+__date__ = "24/07/2018"
+
+
+from contextlib import contextmanager
+import logging
+import weakref
+import functools
+import numpy
+import enum
+from silx.utils.proxy import docstring
+from silx.utils.enum import Enum as _Enum
+from silx.gui import qt
+from silx.gui import icons
+from silx.gui.plot import stats as statsmdl
+from silx.gui.widgets.TableWidget import TableWidget
+from silx.gui.plot.stats.statshandler import StatsHandler, StatFormatter
+from silx.gui.plot.items.core import ItemChangedType
+from silx.gui.widgets.FlowLayout import FlowLayout
+from . import PlotWidget
+from . import items as plotitems
+
+
+_logger = logging.getLogger(__name__)
+
+
+@enum.unique
+class UpdateMode(_Enum):
+ AUTO = "auto"
+ MANUAL = "manual"
+
+
+# Helper class to handle specific calls to PlotWidget and SceneWidget
+
+
+class _Wrapper(qt.QObject):
+ """Base class for connection with PlotWidget and SceneWidget.
+
+ This class is used when no PlotWidget or SceneWidget is connected.
+
+ :param plot: The plot to be used
+ """
+
+ sigItemAdded = qt.Signal(object)
+ """Signal emitted when a new item is added.
+
+ It provides the added item.
+ """
+
+ sigItemRemoved = qt.Signal(object)
+ """Signal emitted when an item is (about to be) removed.
+
+ It provides the removed item.
+ """
+
+ sigCurrentChanged = qt.Signal(object)
+ """Signal emitted when the current item has changed.
+
+ It provides the current item.
+ """
+
+ sigVisibleDataChanged = qt.Signal()
+ """Signal emitted when the visible data area has changed"""
+
+ def __init__(self, plot=None):
+ super(_Wrapper, self).__init__(parent=None)
+ self._plotRef = None if plot is None else weakref.ref(plot)
+
+ def getPlot(self):
+ """Returns the plot attached to this widget"""
+ return None if self._plotRef is None else self._plotRef()
+
+ def getItems(self):
+ """Returns the list of items in the plot
+
+ :rtype: List[object]
+ """
+ return ()
+
+ def getSelectedItems(self):
+ """Returns the list of selected items in the plot
+
+ :rtype: List[object]
+ """
+ return ()
+
+ def setCurrentItem(self, item):
+ """Set the current/active item in the plot
+
+ :param item: The plot item to set as active/current
+ """
+ pass
+
+ def getLabel(self, item):
+ """Returns the label of the given item.
+
+ :param item:
+ :rtype: str
+ """
+ return ""
+
+ def getKind(self, item):
+ """Returns the kind of an item or None if not supported
+
+ :param item:
+ :rtype: Union[str,None]
+ """
+ return None
+
+
+class _PlotWidgetWrapper(_Wrapper):
+ """Class handling PlotWidget specific calls and signal connections
+
+ See :class:`._Wrapper` for documentation
+
+ :param PlotWidget plot:
+ """
+
+ def __init__(self, plot):
+ assert isinstance(plot, PlotWidget)
+ super(_PlotWidgetWrapper, self).__init__(plot)
+ plot.sigItemAdded.connect(self.sigItemAdded.emit)
+ plot.sigItemAboutToBeRemoved.connect(self.sigItemRemoved.emit)
+ plot.sigActiveCurveChanged.connect(self._activeCurveChanged)
+ plot.sigActiveImageChanged.connect(self._activeImageChanged)
+ plot.sigActiveScatterChanged.connect(self._activeScatterChanged)
+ plot.sigPlotSignal.connect(self._limitsChanged)
+
+ def _activeChanged(self, kind):
+ """Handle change of active curve/image/scatter"""
+ plot = self.getPlot()
+ if plot is not None:
+ item = plot._getActiveItem(kind=kind)
+ if item is None or self.getKind(item) is not None:
+ self.sigCurrentChanged.emit(item)
+
+ def _activeCurveChanged(self, previous, current):
+ self._activeChanged(kind="curve")
+
+ def _activeImageChanged(self, previous, current):
+ self._activeChanged(kind="image")
+
+ def _activeScatterChanged(self, previous, current):
+ self._activeChanged(kind="scatter")
+
+ def _limitsChanged(self, event):
+ """Handle change of plot area limits."""
+ if event["event"] == "limitsChanged":
+ self.sigVisibleDataChanged.emit()
+
+ def getItems(self):
+ plot = self.getPlot()
+ if plot is None:
+ return ()
+ else:
+ return [item for item in plot.getItems() if item.isVisible()]
+
+ def getSelectedItems(self):
+ plot = self.getPlot()
+ items = []
+ if plot is not None:
+ for kind in plot._ACTIVE_ITEM_KINDS:
+ item = plot._getActiveItem(kind=kind)
+ if item is not None:
+ items.append(item)
+ return tuple(items)
+
+ def setCurrentItem(self, item):
+ plot = self.getPlot()
+ if plot is not None:
+ kind = self.getKind(item)
+ if kind in plot._ACTIVE_ITEM_KINDS:
+ if plot._getActiveItem(kind) != item:
+ plot._setActiveItem(kind, item)
+
+ def getLabel(self, item):
+ return item.getName()
+
+ def getKind(self, item):
+ if isinstance(item, plotitems.Curve):
+ return "curve"
+ elif isinstance(item, plotitems.ImageData):
+ return "image"
+ elif isinstance(item, plotitems.Scatter):
+ return "scatter"
+ elif isinstance(item, plotitems.Histogram):
+ return "histogram"
+ else:
+ return None
+
+
+class _SceneWidgetWrapper(_Wrapper):
+ """Class handling SceneWidget specific calls and signal connections
+
+ See :class:`._Wrapper` for documentation
+
+ :param SceneWidget plot:
+ """
+
+ def __init__(self, plot):
+ # Lazy-import to avoid circular imports
+ from ..plot3d.SceneWidget import SceneWidget
+
+ assert isinstance(plot, SceneWidget)
+ super(_SceneWidgetWrapper, self).__init__(plot)
+ plot.getSceneGroup().sigItemAdded.connect(self.sigItemAdded)
+ plot.getSceneGroup().sigItemRemoved.connect(self.sigItemRemoved)
+ plot.selection().sigCurrentChanged.connect(self._currentChanged)
+ # sigVisibleDataChanged is never emitted
+
+ def _currentChanged(self, current, previous):
+ self.sigCurrentChanged.emit(current)
+
+ def getItems(self):
+ plot = self.getPlot()
+ return () if plot is None else tuple(plot.getSceneGroup().visit())
+
+ def getSelectedItems(self):
+ plot = self.getPlot()
+ return () if plot is None else (plot.selection().getCurrentItem(),)
+
+ def setCurrentItem(self, item):
+ plot = self.getPlot()
+ if plot is not None:
+ plot.selection().setCurrentItem(item)
+
+ def getLabel(self, item):
+ return item.getLabel()
+
+ def getKind(self, item):
+ from ..plot3d import items as plot3ditems
+
+ if isinstance(item, (plot3ditems.ImageData, plot3ditems.ScalarField3D)):
+ return "image"
+ elif isinstance(item, (plot3ditems.Scatter2D, plot3ditems.Scatter3D)):
+ return "scatter"
+ else:
+ return None
+
+
+class _ScalarFieldViewWrapper(_Wrapper):
+ """Class handling ScalarFieldView specific calls and signal connections
+
+ See :class:`._Wrapper` for documentation
+
+ :param SceneWidget plot:
+ """
+
+ def __init__(self, plot):
+ # Lazy-import to avoid circular imports
+ from ..plot3d.ScalarFieldView import ScalarFieldView
+ from ..plot3d.items import ScalarField3D
+
+ assert isinstance(plot, ScalarFieldView)
+ super(_ScalarFieldViewWrapper, self).__init__(plot)
+ self._item = ScalarField3D()
+ self._dataChanged()
+ plot.sigDataChanged.connect(self._dataChanged)
+ # sigItemAdded, sigItemRemoved, sigVisibleDataChanged are never emitted
+
+ def _dataChanged(self):
+ plot = self.getPlot()
+ if plot is not None:
+ self._item.setData(plot.getData(copy=False), copy=False)
+ self.sigCurrentChanged.emit(self._item)
+
+ def getItems(self):
+ plot = self.getPlot()
+ return () if plot is None else (self._item,)
+
+ def getSelectedItems(self):
+ return self.getItems()
+
+ def setCurrentItem(self, item):
+ pass
+
+ def getLabel(self, item):
+ return "Data"
+
+ def getKind(self, item):
+ return "image"
+
+
+class _Container(object):
+ """Class to contain a plot item.
+
+ This is apparently needed for compatibility with PySide2,
+
+ :param QObject obj:
+ """
+
+ def __init__(self, obj):
+ self._obj = obj
+
+ def __call__(self):
+ return self._obj
+
+
+class _StatsWidgetBase(object):
+ """
+ Base class for all widgets which want to display statistics
+ """
+
+ def __init__(self, statsOnVisibleData, displayOnlyActItem):
+ self._displayOnlyActItem = displayOnlyActItem
+ self._statsOnVisibleData = statsOnVisibleData
+ self._statsHandler = None
+ self._updateMode = UpdateMode.AUTO
+
+ self.__default_skipped_events = (
+ ItemChangedType.ALPHA,
+ ItemChangedType.COLOR,
+ ItemChangedType.COLORMAP,
+ ItemChangedType.SYMBOL,
+ ItemChangedType.SYMBOL_SIZE,
+ ItemChangedType.LINE_WIDTH,
+ ItemChangedType.LINE_STYLE,
+ ItemChangedType.LINE_BG_COLOR,
+ ItemChangedType.FILL,
+ ItemChangedType.HIGHLIGHTED_COLOR,
+ ItemChangedType.HIGHLIGHTED_STYLE,
+ ItemChangedType.TEXT,
+ ItemChangedType.OVERLAY,
+ ItemChangedType.VISUALIZATION_MODE,
+ )
+
+ self._plotWrapper = _Wrapper()
+ self._dealWithPlotConnection(create=True)
+
+ def setPlot(self, plot):
+ """Define the plot to interact with
+
+ :param Union[PlotWidget,SceneWidget,None] plot:
+ The plot containing the items on which statistics are applied
+ """
+ try:
+ import OpenGL
+ except ImportError:
+ has_opengl = False
+ else:
+ has_opengl = True
+ from ..plot3d.SceneWidget import SceneWidget # Lazy import
+ self._dealWithPlotConnection(create=False)
+ self.clear()
+ if plot is None:
+ self._plotWrapper = _Wrapper()
+ elif isinstance(plot, PlotWidget):
+ self._plotWrapper = _PlotWidgetWrapper(plot)
+ else:
+ if has_opengl is True:
+ if isinstance(plot, SceneWidget):
+ self._plotWrapper = _SceneWidgetWrapper(plot)
+ else: # Expect a ScalarFieldView
+ self._plotWrapper = _ScalarFieldViewWrapper(plot)
+ else:
+ _logger.warning(
+ "OpenGL not installed, %s not managed"
+ % ("SceneWidget qnd ScalarFieldView")
+ )
+ self._dealWithPlotConnection(create=True)
+
+ def setStats(self, statsHandler):
+ """Set which stats to display and the associated formatting.
+
+ :param StatsHandler statsHandler:
+ Set the statistics to be displayed and how to format them using
+ """
+ if statsHandler is None:
+ statsHandler = StatsHandler(statFormatters=())
+ elif isinstance(statsHandler, (list, tuple)):
+ statsHandler = StatsHandler(statsHandler)
+ assert isinstance(statsHandler, StatsHandler)
+
+ self._statsHandler = statsHandler
+
+ def getStatsHandler(self):
+ """Returns the :class:`StatsHandler` in use.
+
+ :rtype: StatsHandler
+ """
+ return self._statsHandler
+
+ def getPlot(self):
+ """Returns the plot attached to this widget
+
+ :rtype: Union[PlotWidget,SceneWidget,None]
+ """
+ return self._plotWrapper.getPlot()
+
+ def _dealWithPlotConnection(self, create=True):
+ """Manage connection to plot signals
+
+ Note: connection on Item are managed by _addItem and _removeItem methods
+ """
+ connections = [] # List of (signal, slot) to connect/disconnect
+ if self._statsOnVisibleData:
+ connections.append(
+ (self._plotWrapper.sigVisibleDataChanged, self._updateAllStats)
+ )
+
+ if self._displayOnlyActItem:
+ connections.append(
+ (self._plotWrapper.sigCurrentChanged, self._updateCurrentItem)
+ )
+ else:
+ connections += [
+ (self._plotWrapper.sigItemAdded, self._addItem),
+ (self._plotWrapper.sigItemRemoved, self._removeItem),
+ (self._plotWrapper.sigCurrentChanged, self._plotCurrentChanged),
+ ]
+
+ for signal, slot in connections:
+ if create:
+ signal.connect(slot)
+ else:
+ signal.disconnect(slot)
+
+ def _updateItemObserve(self, *args):
+ """Reload table depending on mode"""
+ raise NotImplementedError("Base class")
+
+ def _updateCurrentItem(self, *args):
+ """specific callback for the sigCurrentChanged and with the
+ _displayOnlyActItem option."""
+ raise NotImplementedError("Base class")
+
+ def _updateStats(self, item, data_changed=False, roi_changed=False):
+ """Update displayed information for given plot item
+
+ :param item: The plot item
+ :param bool data_changed: is the item data changed.
+ :param bool roi_changed: is the associated roi changed.
+ """
+ raise NotImplementedError("Base class")
+
+ def _updateAllStats(self):
+ """Update stats for all rows in the table"""
+ raise NotImplementedError("Base class")
+
+ def setDisplayOnlyActiveItem(self, displayOnlyActItem):
+ """Toggle display off all items or only the active/selected one
+
+ :param bool displayOnlyActItem:
+ True if we want to only show active item
+ """
+ self._displayOnlyActItem = displayOnlyActItem
+
+ def setStatsOnVisibleData(self, b):
+ """Toggle computation of statistics on whole data or only visible ones.
+
+ .. warning:: When visible data is activated we will process to a simple
+ filtering of visible data by the user. The filtering is a
+ simple data sub-sampling. No interpolation is made to fit
+ data to boundaries.
+
+ :param bool b: True if we want to apply statistics only on visible data
+ """
+ if self._statsOnVisibleData != b:
+ self._dealWithPlotConnection(create=False)
+ self._statsOnVisibleData = b
+ self._dealWithPlotConnection(create=True)
+ self._updateAllStats()
+
+ def _addItem(self, item):
+ """Add a plot item to the table
+
+ If item is not supported, it is ignored.
+
+ :param item: The plot item
+ :returns: True if the item is added to the widget.
+ :rtype: bool
+ """
+ raise NotImplementedError("Base class")
+
+ def _removeItem(self, item):
+ """Remove table items corresponding to given plot item from the table.
+
+ :param item: The plot item
+ """
+ raise NotImplementedError("Base class")
+
+ def _plotCurrentChanged(self, current):
+ """Handle change of current item and update selection in table
+
+ :param current:
+ """
+ raise NotImplementedError("Base class")
+
+ def clear(self):
+ """clear GUI"""
+ pass
+
+ def _skipPlotItemChangedEvent(self, event):
+ """
+
+ :param ItemChangedtype event: event to filter or not
+ :return: True if we want to ignore this ItemChangedtype
+ :rtype: bool
+ """
+ return event in self.__default_skipped_events
+
+ def setUpdateMode(self, mode):
+ """Set the way to update the displayed statistics.
+
+ :param mode: mode requested for update
+ :type mode: Union[str,UpdateMode]
+ """
+ mode = UpdateMode.from_value(mode)
+ if mode != self._updateMode:
+ self._updateMode = mode
+ self._updateModeHasChanged()
+
+ def getUpdateMode(self):
+ """Returns update mode (See :meth:`setUpdateMode`).
+
+ :return: update mode
+ :rtype: UpdateMode
+ """
+ return self._updateMode
+
+ def _updateModeHasChanged(self):
+ """callback when the update mode has changed"""
+ pass
+
+
+class StatsTable(_StatsWidgetBase, TableWidget):
+ """
+ TableWidget displaying for each items contained by the Plot some
+ information:
+
+ * legend
+ * minimal value
+ * maximal value
+ * standard deviation (std)
+
+ :param QWidget parent: The widget's parent.
+ :param Union[PlotWidget,SceneWidget] plot:
+ :class:`PlotWidget` or :class:`SceneWidget` instance on which to operate
+ """
+
+ _LEGEND_HEADER_DATA = "legend"
+ _KIND_HEADER_DATA = "kind"
+
+ sigUpdateModeChanged = qt.Signal(object)
+ """Signal emitted when the update mode changed"""
+
+ def __init__(self, parent=None, plot=None):
+ TableWidget.__init__(self, parent)
+ _StatsWidgetBase.__init__(
+ self, statsOnVisibleData=False, displayOnlyActItem=False
+ )
+
+ # Init for _displayOnlyActItem == False
+ assert self._displayOnlyActItem is False
+ self.setSelectionBehavior(qt.QAbstractItemView.SelectRows)
+ self.setSelectionMode(qt.QAbstractItemView.SingleSelection)
+ self.currentItemChanged.connect(self._currentItemChanged)
+
+ self.setRowCount(0)
+ self.setColumnCount(2)
+
+ # Init headers
+ headerItem = qt.QTableWidgetItem(self._LEGEND_HEADER_DATA.title())
+ headerItem.setData(qt.Qt.UserRole, self._LEGEND_HEADER_DATA)
+ self.setHorizontalHeaderItem(0, headerItem)
+ headerItem = qt.QTableWidgetItem(self._KIND_HEADER_DATA.title())
+ headerItem.setData(qt.Qt.UserRole, self._KIND_HEADER_DATA)
+ self.setHorizontalHeaderItem(1, headerItem)
+
+ self.setSortingEnabled(True)
+ self.setPlot(plot)
+
+ @contextmanager
+ def _disableSorting(self):
+ """Context manager that disables table sorting
+
+ Previous state is restored when leaving
+ """
+ sorting = self.isSortingEnabled()
+ if sorting:
+ self.setSortingEnabled(False)
+ yield
+ if sorting:
+ self.setSortingEnabled(sorting)
+
+ def setStats(self, statsHandler):
+ """Set which stats to display and the associated formatting.
+
+ :param StatsHandler statsHandler:
+ Set the statistics to be displayed and how to format them using
+ """
+ self._removeAllItems()
+ _StatsWidgetBase.setStats(self, statsHandler)
+
+ self.setRowCount(0)
+ self.setColumnCount(len(self._statsHandler.stats) + 2) # + legend and kind
+
+ for index, stat in enumerate(self._statsHandler.stats.values()):
+ headerItem = qt.QTableWidgetItem(stat.name.capitalize())
+ headerItem.setData(qt.Qt.UserRole, stat.name)
+ if stat.description is not None:
+ headerItem.setToolTip(stat.description)
+ self.setHorizontalHeaderItem(2 + index, headerItem)
+
+ horizontalHeader = self.horizontalHeader()
+ horizontalHeader.setSectionResizeMode(qt.QHeaderView.ResizeToContents)
+
+ self._updateItemObserve()
+
+ def setPlot(self, plot):
+ """Define the plot to interact with
+
+ :param Union[PlotWidget,SceneWidget,None] plot:
+ The plot containing the items on which statistics are applied
+ """
+ _StatsWidgetBase.setPlot(self, plot)
+ self._updateItemObserve()
+
+ def clear(self):
+ """Define the plot to interact with
+
+ :param Union[PlotWidget,SceneWidget,None] plot:
+ The plot containing the items on which statistics are applied
+ """
+ self._removeAllItems()
+
+ def _updateItemObserve(self, *args):
+ """Reload table depending on mode"""
+ self._removeAllItems()
+
+ # Get selected or all items from the plot
+ if self._displayOnlyActItem: # Only selected
+ items = self._plotWrapper.getSelectedItems()
+ else: # All items
+ items = self._plotWrapper.getItems()
+
+ # Add items to the plot
+ for item in items:
+ self._addItem(item)
+
+ def _updateCurrentItem(self, *args):
+ """specific callback for the sigCurrentChanged and with the
+ _displayOnlyActItem option.
+
+ Behavior: create the tableItems if does not exists.
+ If exists, update it only when we are in 'auto' mode"""
+ if self.getUpdateMode() is UpdateMode.MANUAL:
+ # when sigCurrentChanged is giving the current item
+ if len(args) > 0 and isinstance(
+ args[0],
+ (
+ plotitems.Curve,
+ plotitems.Histogram,
+ plotitems.ImageData,
+ plotitems.Scatter,
+ ),
+ ):
+ item = args[0]
+ tableItems = self._itemToTableItems(item)
+ # if the table does not exists yet
+ if len(tableItems) == 0:
+ self._updateItemObserve()
+ else:
+ # in this case no current item
+ self._updateItemObserve(args)
+ else:
+ # auto mode
+ self._updateItemObserve(args)
+
+ def _plotCurrentChanged(self, current):
+ """Handle change of current item and update selection in table
+
+ :param current:
+ """
+ row = self._itemToRow(current)
+ if row is None:
+ if self.currentRow() >= 0:
+ self.setCurrentCell(-1, -1)
+ elif row != self.currentRow():
+ self.setCurrentCell(row, 0)
+
+ def _tableItemToItem(self, tableItem):
+ """Find the plot item corresponding to a table item
+
+ :param QTableWidgetItem tableItem:
+ :rtype: QObject
+ """
+ container = tableItem.data(qt.Qt.UserRole)
+ return container()
+
+ def _itemToRow(self, item):
+ """Find the row corresponding to a plot item
+
+ :param item: The plot item
+ :return: The corresponding row index
+ :rtype: Union[int,None]
+ """
+ for row in range(self.rowCount()):
+ tableItem = self.item(row, 0)
+ if self._tableItemToItem(tableItem) == item:
+ return row
+ return None
+
+ def _itemToTableItems(self, item):
+ """Find all table items corresponding to a plot item
+
+ :param item: The plot item
+ :return: An ordered dict of column name to QTableWidgetItem mapping
+ for the given plot item.
+ :rtype: dict
+ """
+ result = {}
+ row = self._itemToRow(item)
+ if row is not None:
+ for column in range(self.columnCount()):
+ tableItem = self.item(row, column)
+ if self._tableItemToItem(tableItem) != item:
+ _logger.error("Table item/plot item mismatch")
+ else:
+ header = self.horizontalHeaderItem(column)
+ name = header.data(qt.Qt.UserRole)
+ result[name] = tableItem
+ return result
+
+ def _plotItemChanged(self, event):
+ """Handle modifications of the items.
+
+ :param event:
+ """
+ if self.getUpdateMode() is UpdateMode.MANUAL:
+ return
+ if self._skipPlotItemChangedEvent(event) is True:
+ return
+ else:
+ item = self.sender()
+ self._updateStats(item, data_changed=True)
+ # deal with stat items visibility
+ if event is ItemChangedType.VISIBLE:
+ if len(self._itemToTableItems(item).items()) > 0:
+ item_0 = list(self._itemToTableItems(item).values())[0]
+ row_index = item_0.row()
+ self.setRowHidden(row_index, not item.isVisible())
+
+ def _addItem(self, item):
+ """Add a plot item to the table
+
+ If item is not supported, it is ignored.
+
+ :param item: The plot item
+ :returns: True if the item is added to the widget.
+ :rtype: bool
+ """
+ if self._itemToRow(item) is not None:
+ _logger.info("Item already present in the table")
+ self._updateStats(item)
+ return True
+
+ kind = self._plotWrapper.getKind(item)
+ if kind not in statsmdl.BASIC_COMPATIBLE_KINDS:
+ _logger.info("Item has not a supported type: %s", item)
+ return False
+
+ # Prepare table items
+ tableItems = [qt.QTableWidgetItem(), qt.QTableWidgetItem()] # Legend # Kind
+
+ for column in range(2, self.columnCount()):
+ header = self.horizontalHeaderItem(column)
+ name = header.data(qt.Qt.UserRole)
+
+ formatter = self._statsHandler.formatters[name]
+ if formatter:
+ tableItem = formatter.tabWidgetItemClass()
+ else:
+ tableItem = qt.QTableWidgetItem()
+
+ tooltip = self._statsHandler.stats[name].getToolTip(kind=kind)
+ if tooltip is not None:
+ tableItem.setToolTip(tooltip)
+
+ tableItems.append(tableItem)
+
+ # Disable sorting while adding table items
+ with self._disableSorting():
+ # Add a row to the table
+ self.setRowCount(self.rowCount() + 1)
+
+ # Add table items to the last row
+ row = self.rowCount() - 1
+ for column, tableItem in enumerate(tableItems):
+ tableItem.setData(qt.Qt.UserRole, _Container(item))
+ tableItem.setFlags(qt.Qt.ItemIsEnabled | qt.Qt.ItemIsSelectable)
+ self.setItem(row, column, tableItem)
+
+ # Update table items content
+ self._updateStats(item, data_changed=True)
+
+ # Listen for item changes
+ # Using queued connection to avoid issue with sender
+ # being that of the signal calling the signal
+ item.sigItemChanged.connect(self._plotItemChanged, qt.Qt.QueuedConnection)
+
+ return True
+
+ def _removeItem(self, item):
+ """Remove table items corresponding to given plot item from the table.
+
+ :param item: The plot item
+ """
+ row = self._itemToRow(item)
+ if row is None:
+ kind = self._plotWrapper.getKind(item)
+ if kind in statsmdl.BASIC_COMPATIBLE_KINDS:
+ _logger.error("Removing item that is not in table: %s", str(item))
+ return
+ item.sigItemChanged.disconnect(self._plotItemChanged)
+ self.removeRow(row)
+
+ def _removeAllItems(self):
+ """Remove content of the table"""
+ for row in range(self.rowCount()):
+ tableItem = self.item(row, 0)
+ item = self._tableItemToItem(tableItem)
+ item.sigItemChanged.disconnect(self._plotItemChanged)
+ self.clearContents()
+ self.setRowCount(0)
+
+ def _updateStats(self, item, data_changed=False, roi_changed=False):
+ """Update displayed information for given plot item
+
+ :param item: The plot item
+ :param bool data_changed: is the item data changed.
+ :param bool roi_changed: is the associated roi changed.
+ """
+ if item is None:
+ return
+ plot = self.getPlot()
+ if plot is None:
+ _logger.info("Plot not available")
+ return
+
+ row = self._itemToRow(item)
+ if row is None:
+ _logger.error("This item is not in the table: %s", str(item))
+ return
+
+ statsHandler = self.getStatsHandler()
+ if statsHandler is not None:
+ # _updateStats is call when the plot visible area change.
+ # to force stats update we consider roi changed
+ if self._statsOnVisibleData:
+ roi_changed = True
+ else:
+ roi_changed = False
+ stats = statsHandler.calculate(
+ item,
+ plot,
+ self._statsOnVisibleData,
+ data_changed=data_changed,
+ roi_changed=roi_changed,
+ )
+ else:
+ stats = {}
+
+ with self._disableSorting():
+ for name, tableItem in self._itemToTableItems(item).items():
+ if name == self._LEGEND_HEADER_DATA:
+ text = self._plotWrapper.getLabel(item)
+ tableItem.setText(text)
+ elif name == self._KIND_HEADER_DATA:
+ tableItem.setText(self._plotWrapper.getKind(item))
+ else:
+ value = stats.get(name)
+ if value is None:
+ _logger.error("Value not found for: %s", name)
+ tableItem.setText("-")
+ else:
+ tableItem.setText(str(value))
+
+ def _updateAllStats(self, is_request=False):
+ """Update stats for all rows in the table
+
+ :param bool is_request: True if come from a manual request
+ """
+ if self.getUpdateMode() is UpdateMode.MANUAL and not is_request:
+ return
+ with self._disableSorting():
+ for row in range(self.rowCount()):
+ tableItem = self.item(row, 0)
+ item = self._tableItemToItem(tableItem)
+ self._updateStats(item, data_changed=is_request)
+
+ def _currentItemChanged(self, current, previous):
+ """Handle change of selection in table and sync plot selection
+
+ :param QTableWidgetItem current:
+ :param QTableWidgetItem previous:
+ """
+ if current and current.row() >= 0:
+ item = self._tableItemToItem(current)
+ self._plotWrapper.setCurrentItem(item)
+
+ def setDisplayOnlyActiveItem(self, displayOnlyActItem):
+ """Toggle display off all items or only the active/selected one
+
+ :param bool displayOnlyActItem:
+ True if we want to only show active item
+ """
+ if self._displayOnlyActItem == displayOnlyActItem:
+ return
+ self._dealWithPlotConnection(create=False)
+ if not self._displayOnlyActItem:
+ self.currentItemChanged.disconnect(self._currentItemChanged)
+
+ _StatsWidgetBase.setDisplayOnlyActiveItem(self, displayOnlyActItem)
+
+ self._updateItemObserve()
+ self._dealWithPlotConnection(create=True)
+
+ if not self._displayOnlyActItem:
+ self.currentItemChanged.connect(self._currentItemChanged)
+ self.setSelectionMode(qt.QAbstractItemView.SingleSelection)
+ else:
+ self.setSelectionMode(qt.QAbstractItemView.NoSelection)
+
+ def _updateModeHasChanged(self):
+ self.sigUpdateModeChanged.emit(self._updateMode)
+
+
+class UpdateModeWidget(qt.QWidget):
+ """Widget used to select the mode of update"""
+
+ sigUpdateModeChanged = qt.Signal(object)
+ """signal emitted when the mode for update changed"""
+ sigUpdateRequested = qt.Signal()
+ """signal emitted when an manual request for example is activate"""
+
+ def __init__(self, parent=None):
+ qt.QWidget.__init__(self, parent)
+ self.setLayout(qt.QHBoxLayout())
+ self._buttonGrp = qt.QButtonGroup(parent=self)
+ self._buttonGrp.setExclusive(True)
+
+ spacer = qt.QSpacerItem(
+ 20, 20, qt.QSizePolicy.Expanding, qt.QSizePolicy.Minimum
+ )
+ self.layout().addItem(spacer)
+
+ self._autoRB = qt.QRadioButton("auto", parent=self)
+ self.layout().addWidget(self._autoRB)
+ self._buttonGrp.addButton(self._autoRB)
+
+ self._manualRB = qt.QRadioButton("manual", parent=self)
+ self.layout().addWidget(self._manualRB)
+ self._buttonGrp.addButton(self._manualRB)
+ self._manualRB.setChecked(True)
+
+ refresh_icon = icons.getQIcon("view-refresh")
+ self._updatePB = qt.QPushButton(refresh_icon, "", parent=self)
+ self.layout().addWidget(self._updatePB)
+
+ # connect signal / SLOT
+ self._updatePB.clicked.connect(self._updateRequested)
+ self._manualRB.toggled.connect(self._manualButtonToggled)
+ self._autoRB.toggled.connect(self._autoButtonToggled)
+
+ def _manualButtonToggled(self, checked):
+ if checked:
+ self.setUpdateMode(UpdateMode.MANUAL)
+ self.sigUpdateModeChanged.emit(self.getUpdateMode())
+
+ def _autoButtonToggled(self, checked):
+ if checked:
+ self.setUpdateMode(UpdateMode.AUTO)
+ self.sigUpdateModeChanged.emit(self.getUpdateMode())
+
+ def _updateRequested(self):
+ if self.getUpdateMode() is UpdateMode.MANUAL:
+ self.sigUpdateRequested.emit()
+
+ def setUpdateMode(self, mode):
+ """Set the way to update the displayed statistics.
+
+ :param mode: mode requested for update
+ :type mode: Union[str,UpdateMode]
+ """
+ mode = UpdateMode.from_value(mode)
+
+ if mode is UpdateMode.AUTO:
+ if not self._autoRB.isChecked():
+ self._autoRB.setChecked(True)
+ elif mode is UpdateMode.MANUAL:
+ if not self._manualRB.isChecked():
+ self._manualRB.setChecked(True)
+ else:
+ raise ValueError("mode", mode, "is not recognized")
+
+ def getUpdateMode(self):
+ """Returns update mode (See :meth:`setUpdateMode`).
+
+ :return: the active update mode
+ :rtype: UpdateMode
+ """
+ if self._manualRB.isChecked():
+ return UpdateMode.MANUAL
+ elif self._autoRB.isChecked():
+ return UpdateMode.AUTO
+ else:
+ raise RuntimeError("No mode selected")
+
+ def showRadioButtons(self, show):
+ """show / hide the QRadioButtons
+
+ :param bool show: if True make RadioButton visible
+ """
+ self._autoRB.setVisible(show)
+ self._manualRB.setVisible(show)
+
+
+class _OptionsWidget(qt.QToolBar):
+ def __init__(self, parent=None, updateMode=None, displayOnlyActItem=False):
+ assert updateMode is not None
+ qt.QToolBar.__init__(self, parent)
+ self.setIconSize(qt.QSize(16, 16))
+
+ action = qt.QAction(self)
+ action.setIcon(icons.getQIcon("stats-active-items"))
+ action.setText("Active items only")
+ action.setToolTip("Display stats for active items only.")
+ action.setCheckable(True)
+ action.setChecked(displayOnlyActItem)
+ self.__displayActiveItems = action
+
+ action = qt.QAction(self)
+ action.setIcon(icons.getQIcon("stats-whole-items"))
+ action.setText("All items")
+ action.setToolTip("Display stats for all available items.")
+ action.setCheckable(True)
+ self.__displayWholeItems = action
+
+ action = qt.QAction(self)
+ action.setIcon(icons.getQIcon("stats-visible-data"))
+ action.setText("Use the visible data range")
+ action.setToolTip(
+ "Use the visible data range.<br/>"
+ "If activated the data is filtered to only use"
+ "visible data of the plot."
+ "The filtering is a data sub-sampling."
+ "No interpolation is made to fit data to"
+ "boundaries."
+ )
+ action.setCheckable(True)
+ self.__useVisibleData = action
+
+ action = qt.QAction(self)
+ action.setIcon(icons.getQIcon("stats-whole-data"))
+ action.setText("Use the full data range")
+ action.setToolTip("Use the full data range.")
+ action.setCheckable(True)
+ action.setChecked(True)
+ self.__useWholeData = action
+
+ self.addAction(self.__displayWholeItems)
+ self.addAction(self.__displayActiveItems)
+ self.addSeparator()
+ self.addAction(self.__useVisibleData)
+ self.addAction(self.__useWholeData)
+
+ self.itemSelection = qt.QActionGroup(self)
+ self.itemSelection.setExclusive(True)
+ self.itemSelection.addAction(self.__displayActiveItems)
+ self.itemSelection.addAction(self.__displayWholeItems)
+
+ self.dataRangeSelection = qt.QActionGroup(self)
+ self.dataRangeSelection.setExclusive(True)
+ self.dataRangeSelection.addAction(self.__useWholeData)
+ self.dataRangeSelection.addAction(self.__useVisibleData)
+
+ self.__updateStatsAction = qt.QAction(self)
+ self.__updateStatsAction.setIcon(icons.getQIcon("view-refresh"))
+ self.__updateStatsAction.setText("update statistics")
+ self.__updateStatsAction.setToolTip("update statistics")
+ self.__updateStatsAction.setCheckable(False)
+ self._updateStatsSep = self.addSeparator()
+ self.addAction(self.__updateStatsAction)
+
+ self._setUpdateMode(mode=updateMode)
+
+ # expose API
+ self.sigUpdateStats = self.__updateStatsAction.triggered
+
+ def isActiveItemMode(self):
+ return self.itemSelection.checkedAction() is self.__displayActiveItems
+
+ def setDisplayActiveItems(self, only_active):
+ self.__displayActiveItems.setChecked(only_active)
+ self.__displayWholeItems.setChecked(not only_active)
+
+ def isVisibleDataRangeMode(self):
+ return self.dataRangeSelection.checkedAction() is self.__useVisibleData
+
+ def setVisibleDataRangeModeEnabled(self, enabled):
+ """Enable/Disable the visible data range mode
+
+ :param bool enabled: True to allow user to choose
+ stats on visible data
+ """
+ self.__useVisibleData.setEnabled(enabled)
+ if not enabled:
+ self.__useWholeData.setChecked(True)
+
+ def _setUpdateMode(self, mode):
+ self.__updateStatsAction.setVisible(mode == UpdateMode.MANUAL)
+ self._updateStatsSep.setVisible(mode == UpdateMode.MANUAL)
+
+ def getUpdateStatsAction(self):
+ """
+
+ :return: the action for the automatic mode
+ :rtype: QAction
+ """
+ return self.__updateStatsAction
+
+
+class StatsWidget(qt.QWidget):
+ """
+ Widget displaying a set of :class:`Stat` to be displayed on a
+ :class:`StatsTable` and to be apply on items contained in the :class:`Plot`
+ Also contains options to:
+
+ * compute statistics on all the data or on visible data only
+ * show statistics of all items or only the active one
+
+ :param QWidget parent: Qt parent
+ :param Union[PlotWidget,SceneWidget] plot:
+ The plot containing items on which we want statistics.
+ :param StatsHandler stats:
+ Set the statistics to be displayed and how to format them using
+ """
+
+ sigVisibilityChanged = qt.Signal(bool)
+ """Signal emitted when the visibility of this widget changes.
+
+ It Provides the visibility of the widget.
+ """
+
+ NUMBER_FORMAT = "{0:.3f}"
+
+ def __init__(self, parent=None, plot=None, stats=None):
+ qt.QWidget.__init__(self, parent)
+ self.setLayout(qt.QVBoxLayout())
+ self.layout().setContentsMargins(0, 0, 0, 0)
+ self._options = _OptionsWidget(parent=self, updateMode=UpdateMode.MANUAL)
+ self.layout().addWidget(self._options)
+ self._statsTable = StatsTable(parent=self, plot=plot)
+ self._statsTable.setDisplayOnlyActiveItem(self._options.isActiveItemMode())
+ self._options._setUpdateMode(mode=self._statsTable.getUpdateMode())
+ self.setStats(stats)
+
+ self.layout().addWidget(self._statsTable)
+
+ old = self._statsTable.blockSignals(True)
+ self._options.itemSelection.triggered.connect(self._optSelectionChanged)
+ self._options.dataRangeSelection.triggered.connect(self._optDataRangeChanged)
+ self._optDataRangeChanged()
+ self._statsTable.blockSignals(old)
+
+ self._statsTable.sigUpdateModeChanged.connect(self._options._setUpdateMode)
+ callback = functools.partial(
+ self._getStatsTable()._updateAllStats, is_request=True
+ )
+ self._options.sigUpdateStats.connect(callback)
+
+ def _getStatsTable(self):
+ """Returns the :class:`StatsTable` used by this widget.
+
+ :rtype: StatsTable
+ """
+ return self._statsTable
+
+ def showEvent(self, event):
+ self.sigVisibilityChanged.emit(True)
+ qt.QWidget.showEvent(self, event)
+
+ def hideEvent(self, event):
+ self.sigVisibilityChanged.emit(False)
+ qt.QWidget.hideEvent(self, event)
+
+ def _optSelectionChanged(self, action=None):
+ self._getStatsTable().setDisplayOnlyActiveItem(self._options.isActiveItemMode())
+
+ def _optDataRangeChanged(self, action=None):
+ self._getStatsTable().setStatsOnVisibleData(
+ self._options.isVisibleDataRangeMode()
+ )
+
+ # Proxy methods
+
+ @docstring(StatsTable)
+ def setStats(self, statsHandler):
+ return self._getStatsTable().setStats(statsHandler=statsHandler)
+
+ @docstring(StatsTable)
+ def setPlot(self, plot):
+ self._options.setVisibleDataRangeModeEnabled(
+ plot is None or isinstance(plot, PlotWidget)
+ )
+ return self._getStatsTable().setPlot(plot=plot)
+
+ @docstring(StatsTable)
+ def getPlot(self):
+ return self._getStatsTable().getPlot()
+
+ @docstring(StatsTable)
+ def setDisplayOnlyActiveItem(self, displayOnlyActItem):
+ old = self._options.blockSignals(True)
+ # update the options
+ self._options.setDisplayActiveItems(displayOnlyActItem)
+ self._options.blockSignals(old)
+ return self._getStatsTable().setDisplayOnlyActiveItem(
+ displayOnlyActItem=displayOnlyActItem
+ )
+
+ @docstring(StatsTable)
+ def setStatsOnVisibleData(self, b):
+ return self._getStatsTable().setStatsOnVisibleData(b=b)
+
+ @docstring(StatsTable)
+ def getUpdateMode(self):
+ return self._statsTable.getUpdateMode()
+
+ @docstring(StatsTable)
+ def setUpdateMode(self, mode):
+ self._statsTable.setUpdateMode(mode)
+
+
+DEFAULT_STATS = StatsHandler(
+ (
+ (statsmdl.StatMin(), StatFormatter()),
+ statsmdl.StatCoordMin(),
+ (statsmdl.StatMax(), StatFormatter()),
+ statsmdl.StatCoordMax(),
+ statsmdl.StatCOM(),
+ (("mean", numpy.mean), StatFormatter()),
+ (("std", numpy.std), StatFormatter()),
+ )
+)
+
+
+class BasicStatsWidget(StatsWidget):
+ """
+ Widget defining a simple set of :class:`Stat` to be displayed on a
+ :class:`StatsWidget`.
+
+ :param QWidget parent: Qt parent
+ :param PlotWidget plot:
+ The plot containing items on which we want statistics.
+ :param StatsHandler stats:
+ Set the statistics to be displayed and how to format them using
+
+ .. snapshotqt:: img/BasicStatsWidget.png
+ :width: 300px
+ :align: center
+
+ from silx.gui.plot import Plot1D
+ from silx.gui.plot.StatsWidget import BasicStatsWidget
+
+ plot = Plot1D()
+ x = range(100)
+ y = x
+ plot.addCurve(x, y, legend='curve_0')
+ plot.setActiveCurve('curve_0')
+
+ widget = BasicStatsWidget(plot=plot)
+ widget.show()
+ """
+
+ def __init__(self, parent=None, plot=None):
+ StatsWidget.__init__(self, parent=parent, plot=plot, stats=DEFAULT_STATS)
+
+
+class _BaseLineStatsWidget(_StatsWidgetBase, qt.QWidget):
+ """
+ Widget made to display stats into a QLayout with couple (QLabel, QLineEdit)
+ created for each stats.
+ The layout can be defined prior of adding any statistic.
+
+ :param QWidget parent: Qt parent
+ :param Union[PlotWidget,SceneWidget] plot:
+ The plot containing items on which we want statistics.
+ :param str kind: the kind of plotitems we want to display
+ :param StatsHandler stats:
+ Set the statistics to be displayed and how to format them using
+ :param bool statsOnVisibleData: compute statistics for the whole data or
+ only visible ones.
+ """
+
+ sigUpdateModeChanged = qt.Signal(object)
+ """Signal emitted when the update mode changed"""
+
+ def __init__(
+ self, parent=None, plot=None, kind="curve", stats=None, statsOnVisibleData=False
+ ):
+ self._item_kind = kind
+ """The item displayed"""
+ self._statQlineEdit = {}
+ """list of legends actually displayed"""
+ self._n_statistics_per_line = 4
+ """number of statistics displayed per line in the grid layout"""
+ qt.QWidget.__init__(self, parent)
+ _StatsWidgetBase.__init__(
+ self, statsOnVisibleData=statsOnVisibleData, displayOnlyActItem=True
+ )
+ self.setLayout(self._createLayout())
+ self.setPlot(plot)
+ if stats is not None:
+ self.setStats(stats)
+
+ def _addItemForStatistic(self, statistic):
+ assert isinstance(statistic, statsmdl.StatBase)
+ assert statistic.name in self._statsHandler.stats
+
+ self.layout().setSpacing(2)
+ self.layout().setContentsMargins(2, 2, 2, 2)
+
+ if isinstance(self.layout(), qt.QGridLayout):
+ parent = self
+ else:
+ widget = qt.QWidget(parent=self)
+ parent = widget
+
+ qLabel = qt.QLabel(statistic.name + ":", parent=parent)
+ qLineEdit = qt.QLineEdit("", parent=parent)
+ qLineEdit.setReadOnly(True)
+
+ self._addStatsWidgetsToLayout(qLabel=qLabel, qLineEdit=qLineEdit)
+ self._statQlineEdit[statistic.name] = qLineEdit
+
+ def setPlot(self, plot):
+ """Define the plot to interact with
+
+ :param Union[PlotWidget,SceneWidget,None] plot:
+ The plot containing the items on which statistics are applied
+ """
+ _StatsWidgetBase.setPlot(self, plot)
+ self._updateAllStats()
+
+ def _addStatsWidgetsToLayout(self, qLabel, qLineEdit):
+ raise NotImplementedError("Base class")
+
+ def setStats(self, statsHandler):
+ """Set which stats to display and the associated formatting.
+
+ :param StatsHandler statsHandler:
+ Set the statistics to be displayed and how to format them using
+ """
+ _StatsWidgetBase.setStats(self, statsHandler)
+ for statName, stat in list(self._statsHandler.stats.items()):
+ self._addItemForStatistic(stat)
+ self._updateAllStats()
+
+ def _activeItemChanged(self, kind, previous, current):
+ if self.getUpdateMode() is UpdateMode.MANUAL:
+ return
+ if kind == self._item_kind:
+ self._updateAllStats()
+
+ def _updateAllStats(self):
+ plot = self.getPlot()
+ if plot is not None:
+ _items = self._plotWrapper.getSelectedItems()
+
+ def kind_filter(_item):
+ return self._plotWrapper.getKind(_item) == self.getKind()
+
+ items = list(filter(kind_filter, _items))
+ assert len(items) in (0, 1)
+ if len(items) == 1:
+ self._setItem(items[0])
+
+ def setKind(self, kind):
+ """Change the kind of active item to display
+ :param str kind: kind of item to display information for ('curve' ...)
+ """
+ if self._item_kind != kind:
+ self._item_kind = kind
+ self._updateItemObserve()
+
+ def getKind(self):
+ """
+ :return: kind of item we want to compute statistic for
+ :rtype: str
+ """
+ return self._item_kind
+
+ def _setItem(self, item, data_changed=True):
+ if item is None:
+ for stat_name, stat_widget in self._statQlineEdit.items():
+ stat_widget.setText("")
+ elif self._statsHandler is not None and len(self._statsHandler.stats) > 0:
+ plot = self.getPlot()
+ if plot is not None:
+ statsValDict = self._statsHandler.calculate(
+ item, plot, self._statsOnVisibleData, data_changed=data_changed
+ )
+ for statName, statVal in list(statsValDict.items()):
+ self._statQlineEdit[statName].setText(statVal)
+
+ def _updateItemObserve(self, *argv):
+ if self.getUpdateMode() is UpdateMode.MANUAL:
+ return
+ assert self._displayOnlyActItem
+ _items = self._plotWrapper.getSelectedItems()
+
+ def kind_filter(_item):
+ return self._plotWrapper.getKind(_item) == self.getKind()
+
+ items = list(filter(kind_filter, _items))
+ assert len(items) in (0, 1)
+ _item = items[0] if len(items) == 1 else None
+ self._setItem(_item, data_changed=True)
+
+ def _updateCurrentItem(self):
+ self._updateItemObserve()
+
+ def _createLayout(self):
+ """create an instance of the main QLayout"""
+ raise NotImplementedError("Base class")
+
+ def _addItem(self, item):
+ raise NotImplementedError("Display only the active item")
+
+ def _removeItem(self, item):
+ raise NotImplementedError("Display only the active item")
+
+ def _plotCurrentChanged(self, current):
+ raise NotImplementedError("Display only the active item")
+
+ def _updateModeHasChanged(self):
+ self.sigUpdateModeChanged.emit(self._updateMode)
+
+
+class _BasicLineStatsWidget(_BaseLineStatsWidget):
+ def __init__(
+ self,
+ parent=None,
+ plot=None,
+ kind="curve",
+ stats=DEFAULT_STATS,
+ statsOnVisibleData=False,
+ ):
+ _BaseLineStatsWidget.__init__(
+ self,
+ parent=parent,
+ kind=kind,
+ plot=plot,
+ stats=stats,
+ statsOnVisibleData=statsOnVisibleData,
+ )
+
+ def _createLayout(self):
+ return FlowLayout()
+
+ def _addStatsWidgetsToLayout(self, qLabel, qLineEdit):
+ # create a mother widget to make sure both qLabel & qLineEdit will
+ # always be displayed side by side
+ widget = qt.QWidget(parent=self)
+ widget.setLayout(qt.QHBoxLayout())
+ widget.layout().setSpacing(0)
+ widget.layout().setContentsMargins(0, 0, 0, 0)
+
+ widget.layout().addWidget(qLabel)
+ widget.layout().addWidget(qLineEdit)
+
+ self.layout().addWidget(widget)
+
+ def _addOptionsWidget(self, widget):
+ self.layout().addWidget(widget)
+
+
+class BasicLineStatsWidget(qt.QWidget):
+ """
+ Widget defining a simple set of :class:`Stat` to be displayed on a
+ :class:`LineStatsWidget`.
+
+ :param QWidget parent: Qt parent
+ :param Union[PlotWidget,SceneWidget] plot:
+ The plot containing items on which we want statistics.
+ :param str kind: the kind of plotitems we want to display
+ :param StatsHandler stats:
+ Set the statistics to be displayed and how to format them using
+ :param bool statsOnVisibleData: compute statistics for the whole data or
+ only visible ones.
+ """
+
+ def __init__(
+ self,
+ parent=None,
+ plot=None,
+ kind="curve",
+ stats=DEFAULT_STATS,
+ statsOnVisibleData=False,
+ ):
+ qt.QWidget.__init__(self, parent)
+ self.setLayout(qt.QHBoxLayout())
+ self.layout().setSpacing(0)
+ self.layout().setContentsMargins(0, 0, 0, 0)
+ self._lineStatsWidget = _BasicLineStatsWidget(
+ parent=self,
+ plot=plot,
+ kind=kind,
+ stats=stats,
+ statsOnVisibleData=statsOnVisibleData,
+ )
+ self.layout().addWidget(self._lineStatsWidget)
+
+ self._options = UpdateModeWidget()
+ self._options.setUpdateMode(self._lineStatsWidget.getUpdateMode())
+ self._options.showRadioButtons(False)
+ self.layout().addWidget(self._options)
+
+ # connect Signal ? SLOT
+ self._lineStatsWidget.sigUpdateModeChanged.connect(self._options.setUpdateMode)
+ self._options.sigUpdateModeChanged.connect(self._lineStatsWidget.setUpdateMode)
+ self._options.sigUpdateRequested.connect(self._lineStatsWidget._updateAllStats)
+
+ def showControl(self, visible):
+ self._options.setVisible(visible)
+
+ # Proxy methods
+
+ @docstring(_BasicLineStatsWidget)
+ def setUpdateMode(self, mode):
+ self._lineStatsWidget.setUpdateMode(mode=mode)
+
+ @docstring(_BasicLineStatsWidget)
+ def getUpdateMode(self):
+ return self._lineStatsWidget.getUpdateMode()
+
+ @docstring(_BasicLineStatsWidget)
+ def setPlot(self, plot):
+ self._lineStatsWidget.setPlot(plot=plot)
+
+ @docstring(_BasicLineStatsWidget)
+ def setStats(self, statsHandler):
+ self._lineStatsWidget.setStats(statsHandler=statsHandler)
+
+ @docstring(_BasicLineStatsWidget)
+ def setKind(self, kind):
+ self._lineStatsWidget.setKind(kind=kind)
+
+ @docstring(_BasicLineStatsWidget)
+ def getKind(self):
+ return self._lineStatsWidget.getKind()
+
+ @docstring(_BasicLineStatsWidget)
+ def setStatsOnVisibleData(self, b):
+ self._lineStatsWidget.setStatsOnVisibleData(b)
+
+ @docstring(UpdateModeWidget)
+ def showRadioButtons(self, show):
+ self._options.showRadioButtons(show=show)
+
+
+class _BasicGridStatsWidget(_BaseLineStatsWidget):
+ def __init__(
+ self,
+ parent=None,
+ plot=None,
+ kind="curve",
+ stats=DEFAULT_STATS,
+ statsOnVisibleData=False,
+ statsPerLine=4,
+ ):
+ _BaseLineStatsWidget.__init__(
+ self,
+ parent=parent,
+ kind=kind,
+ plot=plot,
+ stats=stats,
+ statsOnVisibleData=statsOnVisibleData,
+ )
+ self._n_statistics_per_line = statsPerLine
+
+ def _addStatsWidgetsToLayout(self, qLabel, qLineEdit):
+ column = len(self._statQlineEdit) % self._n_statistics_per_line
+ row = len(self._statQlineEdit) // self._n_statistics_per_line
+ self.layout().addWidget(qLabel, row, column * 2)
+ self.layout().addWidget(qLineEdit, row, column * 2 + 1)
+
+ def _createLayout(self):
+ return qt.QGridLayout()
+
+
+class BasicGridStatsWidget(qt.QWidget):
+ """
+ pymca design like widget
+
+ :param QWidget parent: Qt parent
+ :param Union[PlotWidget,SceneWidget] plot:
+ The plot containing items on which we want statistics.
+ :param StatsHandler stats:
+ Set the statistics to be displayed and how to format them using
+ :param str kind: the kind of plotitems we want to display
+ :param bool statsOnVisibleData: compute statistics for the whole data or
+ only visible ones.
+ :param int statsPerLine: number of statistic to be displayed per line
+
+ .. snapshotqt:: img/BasicGridStatsWidget.png
+ :width: 600px
+ :align: center
+
+ from silx.gui.plot import Plot1D
+ from silx.gui.plot.StatsWidget import BasicGridStatsWidget
+
+ plot = Plot1D()
+ x = range(100)
+ y = x
+ plot.addCurve(x, y, legend='curve_0')
+ plot.setActiveCurve('curve_0')
+
+ widget = BasicGridStatsWidget(plot=plot, kind='curve')
+ widget.show()
+ """
+
+ def __init__(
+ self,
+ parent=None,
+ plot=None,
+ kind="curve",
+ stats=DEFAULT_STATS,
+ statsOnVisibleData=False,
+ ):
+ qt.QWidget.__init__(self, parent)
+ self.setLayout(qt.QVBoxLayout())
+ self.layout().setSpacing(0)
+ self.layout().setContentsMargins(0, 0, 0, 0)
+
+ self._options = UpdateModeWidget()
+ self._options.showRadioButtons(False)
+ self.layout().addWidget(self._options)
+
+ self._lineStatsWidget = _BasicGridStatsWidget(
+ parent=self,
+ plot=plot,
+ kind=kind,
+ stats=stats,
+ statsOnVisibleData=statsOnVisibleData,
+ )
+ self.layout().addWidget(self._lineStatsWidget)
+
+ # tune options
+ self._options.setUpdateMode(self._lineStatsWidget.getUpdateMode())
+
+ # connect Signal ? SLOT
+ self._lineStatsWidget.sigUpdateModeChanged.connect(self._options.setUpdateMode)
+ self._options.sigUpdateModeChanged.connect(self._lineStatsWidget.setUpdateMode)
+ self._options.sigUpdateRequested.connect(self._lineStatsWidget._updateAllStats)
+
+ def showControl(self, visible):
+ self._options.setVisible(visible)
+
+ @docstring(_BasicGridStatsWidget)
+ def setUpdateMode(self, mode):
+ self._lineStatsWidget.setUpdateMode(mode=mode)
+
+ @docstring(_BasicGridStatsWidget)
+ def getUpdateMode(self):
+ return self._lineStatsWidget.getUpdateMode()
+
+ @docstring(_BasicGridStatsWidget)
+ def setPlot(self, plot):
+ self._lineStatsWidget.setPlot(plot=plot)
+
+ @docstring(_BasicGridStatsWidget)
+ def setStats(self, statsHandler):
+ self._lineStatsWidget.setStats(statsHandler=statsHandler)
+
+ @docstring(_BasicGridStatsWidget)
+ def setKind(self, kind):
+ self._lineStatsWidget.setKind(kind=kind)
+
+ @docstring(_BasicGridStatsWidget)
+ def getKind(self):
+ return self._lineStatsWidget.getKind()
+
+ @docstring(_BasicGridStatsWidget)
+ def setStatsOnVisibleData(self, b):
+ self._lineStatsWidget.setStatsOnVisibleData(b)
+
+ @docstring(UpdateModeWidget)
+ def showRadioButtons(self, show):
+ self._options.showRadioButtons(show=show)
diff --git a/src/silx/gui/plot/_BaseMaskToolsWidget.py b/src/silx/gui/plot/_BaseMaskToolsWidget.py
new file mode 100644
index 0000000..6b98289
--- /dev/null
+++ b/src/silx/gui/plot/_BaseMaskToolsWidget.py
@@ -0,0 +1,1311 @@
+# /*##########################################################################
+#
+# Copyright (c) 2017-2022 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""This module is a collection of base classes used in modules
+:mod:`.MaskToolsWidget` (images) and :mod:`.ScatterMaskToolsWidget`
+"""
+
+__authors__ = ["T. Vincent", "P. Knobel"]
+__license__ = "MIT"
+__date__ = "08/12/2020"
+
+import os
+import weakref
+
+import numpy
+
+from silx.gui import qt, icons
+from silx.gui.widgets.FloatEdit import FloatEdit
+from silx.gui.colors import Colormap
+from silx.gui.colors import rgba
+from .actions.mode import PanModeAction
+
+
+class BaseMask(qt.QObject):
+ """Base class for :class:`ImageMask` and :class:`ScatterMask`
+
+ A mask field with update operations.
+
+ A mask is an array of the same shape as some underlying data. The mask
+ array stores integer values in the range 0-255, to allow for 254 levels
+ of mask (value 0 is reserved for unmasked data).
+
+ The mask is updated using spatial selection methods: data located inside
+ a selected area is masked with a specified mask level.
+
+ """
+
+ sigChanged = qt.Signal()
+ """Signal emitted when the mask has changed"""
+
+ sigStateChanged = qt.Signal()
+ """Signal emitted for each mask commit/undo/redo operation"""
+
+ sigUndoable = qt.Signal(bool)
+ """Signal emitted when undo becomes possible/impossible"""
+
+ sigRedoable = qt.Signal(bool)
+ """Signal emitted when redo becomes possible/impossible"""
+
+ def __init__(self, dataItem=None):
+ self.historyDepth = 10
+ """Maximum number of operation stored in history list for undo"""
+ # Init lists for undo/redo
+ self._history = []
+ self._redo = []
+
+ # Store the mask
+ self._mask = numpy.array((), dtype=numpy.uint8)
+
+ # Store the plot item to be masked
+ self._dataItem = None
+ if dataItem is not None:
+ self.setDataItem(dataItem)
+ self.reset(self.getDataValues().shape)
+ super(BaseMask, self).__init__()
+
+ def setDataItem(self, item):
+ """Set a data item
+
+ :param item: A plot item, subclass of :class:`silx.gui.plot.items.Item`
+ :return:
+ """
+ self._dataItem = item
+
+ def getDataItem(self):
+ """Returns current plot item the mask is on.
+
+ :rtype: Union[~silx.gui.plot.items.Item,None]
+ """
+ return self._dataItem
+
+ def getDataValues(self):
+ """Return data values, as a numpy array with the same shape
+ as the mask.
+
+ This method must be implemented in a subclass, as the way of
+ accessing data depends on the data item passed to :meth:`setDataItem`
+
+ :return: Data values associated with the data item.
+ :rtype: numpy.ndarray
+ """
+ raise NotImplementedError("To be implemented in subclass")
+
+ def _notify(self):
+ """Notify of mask change."""
+ self.sigChanged.emit()
+
+ def getMask(self, copy=True):
+ """Get the current mask as a numpy array.
+
+ :param bool copy: True (default) to get a copy of the mask.
+ If False, the returned array MUST not be modified.
+ :return: The array of the mask with dimension of the data to be masked.
+ :rtype: numpy.ndarray of uint8
+ """
+ return numpy.array(self._mask, copy=copy)
+
+ def setMask(self, mask, copy=True):
+ """Set the mask to a new array.
+
+ :param numpy.ndarray mask: The array to use for the mask.
+ :type mask: numpy.ndarray of uint8, C-contiguous.
+ Array of other types are converted.
+ :param bool copy: True (the default) to copy the array,
+ False to use it as is if possible.
+ """
+ self._mask = numpy.array(mask, copy=copy, order="C", dtype=numpy.uint8)
+ self._notify()
+
+ # History control
+ def resetHistory(self):
+ """Reset history"""
+ self._history = [numpy.array(self._mask, copy=True)]
+ self._redo = []
+ self.sigUndoable.emit(False)
+ self.sigRedoable.emit(False)
+
+ def commit(self):
+ """Append the current mask to history if changed"""
+ if (
+ not self._history
+ or self._redo
+ or not numpy.array_equal(self._mask, self._history[-1])
+ ):
+ if self._redo:
+ self._redo = [] # Reset redo as a new action as been performed
+ self.sigRedoable[bool].emit(False)
+
+ while len(self._history) >= self.historyDepth:
+ self._history.pop(0)
+ self._history.append(numpy.array(self._mask, copy=True))
+
+ if len(self._history) == 2:
+ self.sigUndoable.emit(True)
+ self.sigStateChanged.emit()
+
+ def undo(self):
+ """Restore previous mask if any"""
+ if len(self._history) > 1:
+ self._redo.append(self._history.pop())
+ self._mask = numpy.array(self._history[-1], copy=True)
+ self._notify() # Do not store this change in history
+
+ if len(self._redo) == 1: # First redo
+ self.sigRedoable.emit(True)
+ if len(self._history) == 1: # Last value in history
+ self.sigUndoable.emit(False)
+ self.sigStateChanged.emit()
+
+ def redo(self):
+ """Restore previously undone modification if any"""
+ if self._redo:
+ self._mask = self._redo.pop()
+ self._history.append(numpy.array(self._mask, copy=True))
+ self._notify()
+
+ if not self._redo: # No more redo
+ self.sigRedoable.emit(False)
+ if len(self._history) == 2: # Something to undo
+ self.sigUndoable.emit(True)
+ self.sigStateChanged.emit()
+
+ # Whole mask operations
+
+ def clear(self, level):
+ """Set all values of the given mask level to 0.
+
+ :param int level: Value of the mask to set to 0.
+ """
+ assert 0 < level < 256
+ self._mask[self._mask == level] = 0
+ self._notify()
+
+ def invert(self, level):
+ """Invert mask of the given mask level.
+
+ 0 values become level and level values become 0.
+
+ :param int level: The level to invert.
+ """
+ assert 0 < level < 256
+ masked = self._mask == level
+ self._mask[self._mask == 0] = level
+ self._mask[masked] = 0
+ self._notify()
+
+ def reset(self, shape=None):
+ """Reset the mask to zero and change its shape.
+
+ :param shape: Shape of the new mask with the correct dimensionality
+ with regards to the data dimensionality,
+ or None to have an empty mask
+ :type shape: tuple of int
+ """
+ if shape is None:
+ # assume dimensionality never changes
+ shape = (0,) * len(self._mask.shape) # empty array
+ shapeChanged = shape != self._mask.shape
+ self._mask = numpy.zeros(shape, dtype=numpy.uint8)
+ if shapeChanged:
+ self.resetHistory()
+
+ self._notify()
+
+ # To be implemented
+ def save(self, filename, kind):
+ """Save current mask in a file
+
+ :param str filename: The file where to save to mask
+ :param str kind: The kind of file to save (e.g 'npy')
+ :raise Exception: Raised if the file writing fail
+ """
+ raise NotImplementedError("To be implemented in subclass")
+
+ # update thresholds
+ def updateStencil(self, level, stencil, mask=True):
+ """Mask/Unmask points from boolean mask: all elements that are True
+ in the boolean mask are set to ``level`` (if ``mask=True``) or 0
+ (if ``mask=False``)
+
+ :param int level: Mask level to update.
+ :param stencil: Boolean mask.
+ :type stencil: numpy.array of same dimension as the mask
+ :param bool mask: True to mask (default), False to unmask.
+ """
+ if mask:
+ self._mask[stencil] = level
+ else:
+ self._mask[numpy.logical_and(self._mask == level, stencil)] = 0
+ self._notify()
+
+ def updateBelowThreshold(self, level, threshold, mask=True):
+ """Mask/unmask all points whose values are below a threshold.
+
+ :param int level:
+ :param float threshold: Threshold
+ :param bool mask: True to mask (default), False to unmask.
+ """
+ self.updateStencil(level, self.getDataValues() < threshold, mask)
+
+ def updateBetweenThresholds(self, level, min_, max_, mask=True):
+ """Mask/unmask all points whose values are in a range.
+
+ :param int level:
+ :param float min_: Lower threshold
+ :param float max_: Upper threshold
+ :param bool mask: True to mask (default), False to unmask.
+ """
+ stencil = numpy.logical_and(
+ min_ <= self.getDataValues(), self.getDataValues() <= max_
+ )
+ self.updateStencil(level, stencil, mask)
+
+ def updateAboveThreshold(self, level, threshold, mask=True):
+ """Mask/unmask all points whose values are above a threshold.
+
+ :param int level: Mask level to update.
+ :param float threshold: Threshold.
+ :param bool mask: True to mask (default), False to unmask.
+ """
+ self.updateStencil(level, self.getDataValues() > threshold, mask)
+
+ def updateNotFinite(self, level, mask=True):
+ """Mask/unmask all points whose values are not finite.
+
+ :param int level: Mask level to update.
+ :param bool mask: True to mask (default), False to unmask.
+ """
+ self.updateStencil(
+ level, numpy.logical_not(numpy.isfinite(self.getDataValues())), mask
+ )
+
+ # Drawing operations:
+ def updateRectangle(self, level, row, col, height, width, mask=True):
+ """Mask/Unmask data inside a rectangle, with the given mask level.
+
+ :param int level: Mask level to update, in range 1-255.
+ :param row: Starting row/y of the rectangle
+ :param col: Starting column/x of the rectangle
+ :param height:
+ :param width:
+ :param bool mask: True to mask (default), False to unmask.
+ """
+ raise NotImplementedError("To be implemented in subclass")
+
+ def updatePolygon(self, level, vertices, mask=True):
+ """Mask/Unmask data inside a polygon, with the given mask level.
+
+ :param int level: Mask level to update.
+ :param vertices: Nx2 array of polygon corners as (row, col) / (y, x)
+ :param bool mask: True to mask (default), False to unmask.
+ """
+ raise NotImplementedError("To be implemented in subclass")
+
+ def updatePoints(self, level, rows, cols, mask=True):
+ """Mask/Unmask points with given coordinates.
+
+ :param int level: Mask level to update.
+ :param rows: Rows/ordinates (y) of selected points
+ :type rows: 1D numpy.ndarray
+ :param cols: Columns/abscissa (x) of selected points
+ :type cols: 1D numpy.ndarray
+ :param bool mask: True to mask (default), False to unmask.
+ """
+ raise NotImplementedError("To be implemented in subclass")
+
+ def updateDisk(self, level, crow, ccol, radius, mask=True):
+ """Mask/Unmask data located inside a dick of the given mask level.
+
+ :param int level: Mask level to update.
+ :param crow: Disk center row/ordinate (y).
+ :param ccol: Disk center column/abscissa.
+ :param float radius: Radius of the disk in mask array unit
+ :param bool mask: True to mask (default), False to unmask.
+ """
+ raise NotImplementedError("To be implemented in subclass")
+
+ def updateEllipse(self, level, crow, ccol, radius_r, radius_c, mask=True):
+ """Mask/Unmask a disk of the given mask level.
+
+ :param int level: Mask level to update.
+ :param int crow: Row of the center of the ellipse
+ :param int ccol: Column of the center of the ellipse
+ :param float radius_r: Radius of the ellipse in the row
+ :param float radius_c: Radius of the ellipse in the column
+ :param bool mask: True to mask (default), False to unmask.
+ """
+ raise NotImplementedError("To be implemented in subclass")
+
+ def updateLine(self, level, row0, col0, row1, col1, width, mask=True):
+ """Mask/Unmask a line of the given mask level.
+
+ :param int level: Mask level to update.
+ :param row0: Row/y of the starting point.
+ :param col0: Column/x of the starting point.
+ :param row1: Row/y of the end point.
+ :param col1: Column/x of the end point.
+ :param width: Width of the line in mask array unit.
+ :param bool mask: True to mask (default), False to unmask.
+ """
+ raise NotImplementedError("To be implemented in subclass")
+
+
+class BaseMaskToolsWidget(qt.QWidget):
+ """Base class for :class:`MaskToolsWidget` (image mask) and
+ :class:`scatterMaskToolsWidget`"""
+
+ sigMaskChanged = qt.Signal()
+ _maxLevelNumber = 255
+
+ def __init__(self, parent=None, plot=None, mask=None):
+ """
+
+ :param parent: Parent QWidget
+ :param plot: Plot widget on which to operate
+ :param mask: Instance of subclass of :class:`BaseMask`
+ (e.g. :class:`ImageMask`)
+ """
+ super(BaseMaskToolsWidget, self).__init__(parent)
+ # register if the user as force a color for the corresponding mask level
+ self._defaultColors = numpy.ones((self._maxLevelNumber + 1), dtype=bool)
+ # overlays colors set by the user
+ self._overlayColors = numpy.zeros(
+ (self._maxLevelNumber + 1, 3), dtype=numpy.float32
+ )
+
+ # as parent have to be the first argument of the widget to fit
+ # QtDesigner need but here plot can't be None by default.
+ assert plot is not None
+ self._plotRef = weakref.ref(plot)
+ self._maskName = "__MASK_TOOLS_%d" % id(self) # Legend of the mask
+
+ self._colormap = Colormap(
+ normalization="linear", vmin=0, vmax=self._maxLevelNumber
+ )
+ self._defaultOverlayColor = rgba("gray") # Color of the mask
+ self._setMaskColors(1, 0.5) # Set the colormap LUT
+
+ if not isinstance(mask, BaseMask):
+ raise TypeError("mask is not an instance of BaseMask")
+ self._mask = mask
+
+ self._mask.sigChanged.connect(self._updatePlotMask)
+ self._mask.sigChanged.connect(self._emitSigMaskChanged)
+
+ self._drawingMode = None # Store current drawing mode
+ self._lastPencilPos = None
+ self._multipleMasks = "exclusive"
+
+ self._maskFileDir = qt.QDir.current().absolutePath()
+ self.plot.sigInteractiveModeChanged.connect(self._interactiveModeChanged)
+
+ self._initWidgets()
+
+ def _emitSigMaskChanged(self):
+ """Notify mask changes"""
+ self.sigMaskChanged.emit()
+
+ def getMaskedItem(self):
+ """Returns the item that is currently being masked
+
+ :rtype: Union[~silx.gui.plot.items.Item,None]
+ """
+ return self._mask.getDataItem()
+
+ def getSelectionMask(self, copy=True):
+ """Get the current mask as a numpy array.
+
+ :param bool copy: True (default) to get a copy of the mask.
+ If False, the returned array MUST not be modified.
+ :return: The mask (as an array of uint8) with dimension of
+ the 'active' plot item.
+ If there is no active image or scatter, it returns None.
+ :rtype: Union[numpy.ndarray,None]
+ """
+ mask = self._mask.getMask(copy=copy)
+ return None if mask.size == 0 else mask
+
+ def setSelectionMask(self, mask):
+ """Set the mask: Must be implemented in subclass"""
+ raise NotImplementedError()
+
+ def resetSelectionMask(self):
+ """Reset the mask: Must be implemented in subclass"""
+ raise NotImplementedError()
+
+ def multipleMasks(self):
+ """Return the current mode of multiple masks support.
+
+ See :meth:`setMultipleMasks`
+ """
+ return self._multipleMasks
+
+ def setMultipleMasks(self, mode):
+ """Set the mode of multiple masks support.
+
+ Available modes:
+
+ - 'single': Edit a single level of mask
+ - 'exclusive': Supports to 256 levels of non overlapping masks
+
+ :param str mode: The mode to use
+ """
+ assert mode in ("exclusive", "single")
+ if mode != self._multipleMasks:
+ self._multipleMasks = mode
+ self._levelWidget.setVisible(self._multipleMasks != "single")
+ self._clearAllBtn.setVisible(self._multipleMasks != "single")
+
+ def setMaskFileDirectory(self, path):
+ """Set the default directory to use by load/save GUI tools
+
+ The directory is also updated by the user, if he change the location
+ of the dialog.
+ """
+ self.maskFileDir = path
+
+ def getMaskFileDirectory(self):
+ """Get the default directory used by load/save GUI tools"""
+ return self.maskFileDir
+
+ @property
+ def maskFileDir(self):
+ """The directory from which to load/save mask from/to files."""
+ if not os.path.isdir(self._maskFileDir):
+ self._maskFileDir = qt.QDir.current().absolutePath()
+ return self._maskFileDir
+
+ @maskFileDir.setter
+ def maskFileDir(self, maskFileDir):
+ self._maskFileDir = str(maskFileDir)
+
+ @property
+ def plot(self):
+ """The :class:`.PlotWindow` this widget is attached to."""
+ plot = self._plotRef()
+ if plot is None:
+ raise RuntimeError(
+ "Mask widget attached to a PlotWidget that no longer exists"
+ )
+ return plot
+
+ def setDirection(self, direction=qt.QBoxLayout.LeftToRight):
+ """Set the direction of the layout of the widget
+
+ :param direction: QBoxLayout direction
+ """
+ self.layout().setDirection(direction)
+
+ def _initWidgets(self):
+ """Create widgets"""
+ layout = qt.QBoxLayout(qt.QBoxLayout.LeftToRight)
+ layout.addWidget(self._initMaskGroupBox())
+ layout.addWidget(self._initDrawGroupBox())
+ layout.addWidget(self._initThresholdGroupBox())
+ layout.addWidget(self._initOtherToolsGroupBox())
+ layout.addStretch(1)
+ self.setLayout(layout)
+
+ @staticmethod
+ def _hboxWidget(*widgets, **kwargs):
+ """Place widgets in widget with horizontal layout
+
+ :param widgets: Widgets to position horizontally
+ :param bool stretch: True for trailing stretch (default),
+ False for no trailing stretch
+ :return: A QWidget with a QHBoxLayout
+ """
+ stretch = kwargs.get("stretch", True)
+
+ layout = qt.QHBoxLayout()
+ layout.setContentsMargins(0, 0, 0, 0)
+ for widget in widgets:
+ layout.addWidget(widget)
+ if stretch:
+ layout.addStretch(1)
+ widget = qt.QWidget()
+ widget.setLayout(layout)
+ return widget
+
+ def _initTransparencyWidget(self):
+ """Init the mask transparency widget"""
+ transparencyWidget = qt.QWidget(parent=self)
+ grid = qt.QGridLayout()
+ grid.setContentsMargins(0, 0, 0, 0)
+ self.transparencySlider = qt.QSlider(
+ qt.Qt.Horizontal, parent=transparencyWidget
+ )
+ self.transparencySlider.setRange(3, 10)
+ self.transparencySlider.setValue(8)
+ self.transparencySlider.setToolTip("Set the transparency of the mask display")
+ self.transparencySlider.valueChanged.connect(self._updateColors)
+ grid.addWidget(qt.QLabel("Display:", parent=transparencyWidget), 0, 0)
+ grid.addWidget(self.transparencySlider, 0, 1, 1, 3)
+ grid.addWidget(
+ qt.QLabel("<small><b>Transparent</b></small>", parent=transparencyWidget),
+ 1,
+ 1,
+ )
+ grid.addWidget(
+ qt.QLabel("<small><b>Opaque</b></small>", parent=transparencyWidget), 1, 3
+ )
+ transparencyWidget.setLayout(grid)
+ return transparencyWidget
+
+ def _initMaskGroupBox(self):
+ """Init general mask operation widgets"""
+
+ # Mask level
+ self.levelSpinBox = qt.QSpinBox()
+ self.levelSpinBox.setRange(1, self._maxLevelNumber)
+ self.levelSpinBox.setToolTip(
+ "Choose which mask level is edited.\n"
+ "A mask can have up to 255 non-overlapping levels."
+ )
+ self.levelSpinBox.valueChanged[int].connect(self._updateColors)
+ self._levelWidget = self._hboxWidget(
+ qt.QLabel("Mask level:"), self.levelSpinBox
+ )
+ # Transparency
+ self._transparencyWidget = self._initTransparencyWidget()
+
+ style = qt.QApplication.style()
+
+ def getIcon(*identifiyers):
+ for i in identifiyers:
+ if isinstance(i, str):
+ if qt.QIcon.hasThemeIcon(i):
+ return qt.QIcon.fromTheme(i)
+ elif isinstance(i, qt.QIcon):
+ return i
+ else:
+ return style.standardIcon(i)
+ return qt.QIcon()
+
+ undoAction = qt.QAction(self)
+ undoAction.setText("Undo")
+ icon = getIcon("edit-undo", qt.QStyle.SP_ArrowBack)
+ undoAction.setIcon(icon)
+ undoAction.setShortcut(qt.QKeySequence.Undo)
+ undoAction.setToolTip(
+ "Undo last mask change <b>%s</b>" % undoAction.shortcut().toString()
+ )
+ self._mask.sigUndoable.connect(undoAction.setEnabled)
+ undoAction.triggered.connect(self._mask.undo)
+
+ redoAction = qt.QAction(self)
+ redoAction.setText("Redo")
+ icon = getIcon("edit-redo", qt.QStyle.SP_ArrowForward)
+ redoAction.setIcon(icon)
+ redoAction.setShortcut(qt.QKeySequence.Redo)
+ redoAction.setToolTip(
+ "Redo last undone mask change <b>%s</b>" % redoAction.shortcut().toString()
+ )
+ self._mask.sigRedoable.connect(redoAction.setEnabled)
+ redoAction.triggered.connect(self._mask.redo)
+
+ loadAction = qt.QAction(self)
+ loadAction.setText("Load...")
+ icon = icons.getQIcon("document-open")
+ loadAction.setIcon(icon)
+ loadAction.setToolTip("Load mask from file")
+ loadAction.triggered.connect(self._loadMask)
+
+ saveAction = qt.QAction(self)
+ saveAction.setText("Save...")
+ icon = icons.getQIcon("document-save")
+ saveAction.setIcon(icon)
+ saveAction.setToolTip("Save mask to file")
+ saveAction.triggered.connect(self._saveMask)
+
+ invertAction = qt.QAction(self)
+ invertAction.setText("Invert")
+ icon = icons.getQIcon("mask-invert")
+ invertAction.setIcon(icon)
+ invertAction.setShortcut(qt.QKeySequence(qt.Qt.CTRL | qt.Qt.Key_I))
+ invertAction.setToolTip(
+ "Invert current mask <b>%s</b>" % invertAction.shortcut().toString()
+ )
+ invertAction.triggered.connect(self._handleInvertMask)
+
+ clearAction = qt.QAction(self)
+ clearAction.setText("Clear")
+ icon = icons.getQIcon("mask-clear")
+ clearAction.setIcon(icon)
+ clearAction.setShortcut(qt.QKeySequence.Delete)
+ clearAction.setToolTip(
+ "Clear current mask level <b>%s</b>" % clearAction.shortcut().toString()
+ )
+ clearAction.triggered.connect(self._handleClearMask)
+
+ clearAllAction = qt.QAction(self)
+ clearAllAction.setText("Clear all")
+ icon = icons.getQIcon("mask-clear-all")
+ clearAllAction.setIcon(icon)
+ clearAllAction.setToolTip("Clear all mask levels")
+ clearAllAction.triggered.connect(self.resetSelectionMask)
+
+ # Buttons group
+ margin1 = qt.QWidget(self)
+ margin1.setMinimumWidth(6)
+ margin2 = qt.QWidget(self)
+ margin2.setMinimumWidth(6)
+
+ actions = (
+ loadAction,
+ saveAction,
+ margin1,
+ undoAction,
+ redoAction,
+ margin2,
+ invertAction,
+ clearAction,
+ clearAllAction,
+ )
+ widgets = []
+ for action in actions:
+ if isinstance(action, qt.QWidget):
+ widgets.append(action)
+ continue
+ btn = qt.QToolButton()
+ btn.setDefaultAction(action)
+ widgets.append(btn)
+ if action is clearAllAction:
+ self._clearAllBtn = btn
+ container = self._hboxWidget(*widgets)
+ container.layout().setSpacing(1)
+
+ layout = qt.QVBoxLayout()
+ layout.addWidget(container)
+ layout.addWidget(self._levelWidget)
+ layout.addWidget(self._transparencyWidget)
+ layout.addStretch(1)
+
+ maskGroup = qt.QGroupBox("Mask")
+ maskGroup.setLayout(layout)
+ return maskGroup
+
+ def isMaskInteractionActivated(self):
+ """Returns true if any mask interaction is activated"""
+ return self.drawActionGroup.checkedAction() is not None
+
+ def _initDrawGroupBox(self):
+ """Init drawing tools widgets"""
+ layout = qt.QVBoxLayout()
+
+ self.browseAction = PanModeAction(self.plot, self.plot)
+ self.addAction(self.browseAction)
+
+ # Draw tools
+ self.rectAction = qt.QAction(
+ icons.getQIcon("shape-rectangle"), "Rectangle selection", self
+ )
+ self.rectAction.setToolTip(
+ "Rectangle selection tool: (Un)Mask a rectangular region <b>R</b>"
+ )
+ self.rectAction.setShortcut(qt.QKeySequence(qt.Qt.Key_R))
+ self.rectAction.setCheckable(True)
+ self.rectAction.triggered.connect(self._activeRectMode)
+ self.addAction(self.rectAction)
+
+ self.ellipseAction = qt.QAction(
+ icons.getQIcon("shape-ellipse"), "Circle selection", self
+ )
+ self.ellipseAction.setToolTip(
+ "Rectangle selection tool: (Un)Mask a circle region <b>R</b>"
+ )
+ self.ellipseAction.setShortcut(qt.QKeySequence(qt.Qt.Key_R))
+ self.ellipseAction.setCheckable(True)
+ self.ellipseAction.triggered.connect(self._activeEllipseMode)
+ self.addAction(self.ellipseAction)
+
+ self.polygonAction = qt.QAction(
+ icons.getQIcon("shape-polygon"), "Polygon selection", self
+ )
+ self.polygonAction.setShortcut(qt.QKeySequence(qt.Qt.Key_S))
+ self.polygonAction.setToolTip(
+ "Polygon selection tool: (Un)Mask a polygonal region <b>S</b><br>"
+ "Left-click to place new polygon corners<br>"
+ "Left-click on first corner to close the polygon"
+ )
+ self.polygonAction.setCheckable(True)
+ self.polygonAction.triggered.connect(self._activePolygonMode)
+ self.addAction(self.polygonAction)
+
+ self.pencilAction = qt.QAction(
+ icons.getQIcon("draw-pencil"), "Pencil tool", self
+ )
+ self.pencilAction.setShortcut(qt.QKeySequence(qt.Qt.Key_P))
+ self.pencilAction.setToolTip("Pencil tool: (Un)Mask using a pencil <b>P</b>")
+ self.pencilAction.setCheckable(True)
+ self.pencilAction.triggered.connect(self._activePencilMode)
+ self.addAction(self.pencilAction)
+
+ self.drawActionGroup = qt.QActionGroup(self)
+ self.drawActionGroup.setExclusive(True)
+ self.drawActionGroup.addAction(self.rectAction)
+ self.drawActionGroup.addAction(self.ellipseAction)
+ self.drawActionGroup.addAction(self.polygonAction)
+ self.drawActionGroup.addAction(self.pencilAction)
+
+ actions = (
+ self.browseAction,
+ self.rectAction,
+ self.ellipseAction,
+ self.polygonAction,
+ self.pencilAction,
+ )
+ drawButtons = []
+ for action in actions:
+ btn = qt.QToolButton()
+ btn.setDefaultAction(action)
+ drawButtons.append(btn)
+ container = self._hboxWidget(*drawButtons)
+ layout.addWidget(container)
+
+ # Mask/Unmask radio buttons
+ maskRadioBtn = qt.QRadioButton("Mask")
+ maskRadioBtn.setToolTip(
+ "Drawing masks with current level. Press <b>Ctrl</b> to unmask"
+ )
+ maskRadioBtn.setChecked(True)
+
+ unmaskRadioBtn = qt.QRadioButton("Unmask")
+ unmaskRadioBtn.setToolTip(
+ "Drawing unmasks with current level. Press <b>Ctrl</b> to mask"
+ )
+
+ self.maskStateGroup = qt.QButtonGroup()
+ self.maskStateGroup.addButton(maskRadioBtn, 1)
+ self.maskStateGroup.addButton(unmaskRadioBtn, 0)
+
+ self.maskStateWidget = self._hboxWidget(maskRadioBtn, unmaskRadioBtn)
+ layout.addWidget(self.maskStateWidget)
+
+ self.maskStateWidget.setHidden(True)
+
+ # Pencil settings
+ self.pencilSetting = self._createPencilSettings(None)
+ self.pencilSetting.setVisible(False)
+ layout.addWidget(self.pencilSetting)
+
+ layout.addStretch(1)
+
+ drawGroup = qt.QGroupBox("Draw tools")
+ drawGroup.setLayout(layout)
+ return drawGroup
+
+ def _createPencilSettings(self, parent=None):
+ pencilSetting = qt.QWidget(parent)
+
+ self.pencilSpinBox = qt.QSpinBox(parent=pencilSetting)
+ self.pencilSpinBox.setRange(1, 1024)
+ pencilToolTip = """Set pencil drawing tool size in pixels of the image
+ on which to make the mask."""
+ self.pencilSpinBox.setToolTip(pencilToolTip)
+
+ self.pencilSlider = qt.QSlider(qt.Qt.Horizontal, parent=pencilSetting)
+ self.pencilSlider.setRange(1, 50)
+ self.pencilSlider.setToolTip(pencilToolTip)
+
+ pencilLabel = qt.QLabel("Pencil size:", parent=pencilSetting)
+
+ layout = qt.QGridLayout()
+ layout.addWidget(pencilLabel, 0, 0)
+ layout.addWidget(self.pencilSpinBox, 0, 1)
+ layout.addWidget(self.pencilSlider, 1, 1)
+ pencilSetting.setLayout(layout)
+
+ self.pencilSpinBox.valueChanged.connect(self._pencilWidthChanged)
+ self.pencilSlider.valueChanged.connect(self._pencilWidthChanged)
+
+ return pencilSetting
+
+ def _initThresholdGroupBox(self):
+ """Init thresholding widgets"""
+
+ self.belowThresholdAction = qt.QAction(
+ icons.getQIcon("plot-roi-below"), "Mask below threshold", self
+ )
+ self.belowThresholdAction.setToolTip(
+ "Mask image where values are below given threshold"
+ )
+ self.belowThresholdAction.setCheckable(True)
+ self.belowThresholdAction.setChecked(True)
+
+ self.betweenThresholdAction = qt.QAction(
+ icons.getQIcon("plot-roi-between"), "Mask within range", self
+ )
+ self.betweenThresholdAction.setToolTip(
+ "Mask image where values are within given range"
+ )
+ self.betweenThresholdAction.setCheckable(True)
+
+ self.aboveThresholdAction = qt.QAction(
+ icons.getQIcon("plot-roi-above"), "Mask above threshold", self
+ )
+ self.aboveThresholdAction.setToolTip(
+ "Mask image where values are above given threshold"
+ )
+ self.aboveThresholdAction.setCheckable(True)
+
+ self.thresholdActionGroup = qt.QActionGroup(self)
+ self.thresholdActionGroup.setExclusive(True)
+ self.thresholdActionGroup.addAction(self.belowThresholdAction)
+ self.thresholdActionGroup.addAction(self.betweenThresholdAction)
+ self.thresholdActionGroup.addAction(self.aboveThresholdAction)
+ self.thresholdActionGroup.triggered.connect(self._thresholdActionGroupTriggered)
+
+ self.loadColormapRangeAction = qt.QAction(
+ icons.getQIcon("view-refresh"), "Set min-max from colormap", self
+ )
+ self.loadColormapRangeAction.setToolTip(
+ "Set min and max values from current colormap range"
+ )
+ self.loadColormapRangeAction.setCheckable(False)
+ self.loadColormapRangeAction.triggered.connect(
+ self._loadRangeFromColormapTriggered
+ )
+
+ widgets = []
+ for action in self.thresholdActionGroup.actions():
+ btn = qt.QToolButton()
+ btn.setDefaultAction(action)
+ widgets.append(btn)
+
+ spacer = qt.QWidget(parent=self)
+ spacer.setSizePolicy(qt.QSizePolicy.Expanding, qt.QSizePolicy.Preferred)
+ widgets.append(spacer)
+
+ loadColormapRangeBtn = qt.QToolButton()
+ loadColormapRangeBtn.setDefaultAction(self.loadColormapRangeAction)
+ widgets.append(loadColormapRangeBtn)
+
+ toolBar = self._hboxWidget(*widgets, stretch=False)
+
+ config = qt.QGridLayout()
+ config.setContentsMargins(0, 0, 0, 0)
+
+ self.minLineLabel = qt.QLabel("Min:", self)
+ self.minLineEdit = FloatEdit(self, value=0)
+ config.addWidget(self.minLineLabel, 0, 0)
+ config.addWidget(self.minLineEdit, 0, 1)
+
+ self.maxLineLabel = qt.QLabel("Max:", self)
+ self.maxLineEdit = FloatEdit(self, value=0)
+ config.addWidget(self.maxLineLabel, 1, 0)
+ config.addWidget(self.maxLineEdit, 1, 1)
+
+ self.applyMaskBtn = qt.QPushButton("Apply mask")
+ self.applyMaskBtn.clicked.connect(self._maskBtnClicked)
+
+ layout = qt.QVBoxLayout()
+ layout.addWidget(toolBar)
+ layout.addLayout(config)
+ layout.addWidget(self.applyMaskBtn)
+ layout.addStretch(1)
+
+ self.thresholdGroup = qt.QGroupBox("Threshold")
+ self.thresholdGroup.setLayout(layout)
+
+ # Init widget state
+ self._thresholdActionGroupTriggered(self.belowThresholdAction)
+ return self.thresholdGroup
+
+ # track widget visibility and plot active image changes
+
+ def _initOtherToolsGroupBox(self):
+ layout = qt.QVBoxLayout()
+
+ self.maskNanBtn = qt.QPushButton("Mask not finite values")
+ self.maskNanBtn.setToolTip("Mask Not a Number and infinite values")
+ self.maskNanBtn.clicked.connect(self._maskNotFiniteBtnClicked)
+ layout.addWidget(self.maskNanBtn)
+ layout.addStretch(1)
+
+ self.otherToolGroup = qt.QGroupBox("Other tools")
+ self.otherToolGroup.setLayout(layout)
+ return self.otherToolGroup
+
+ def changeEvent(self, event):
+ """Reset drawing action when disabling widget"""
+ if (
+ event.type() == qt.QEvent.EnabledChange
+ and not self.isEnabled()
+ and self.drawActionGroup.checkedAction()
+ ):
+ # Disable drawing tool by reseting interaction to pan or zoom
+ self.plot.resetInteractiveMode()
+
+ def save(self, filename, kind):
+ """Save current mask in a file
+
+ :param str filename: The file where to save to mask
+ :param str kind: The kind of file to save in 'edf', 'tif', 'npy'
+ :raise Exception: Raised if the process fails
+ """
+ self._mask.save(filename, kind)
+
+ def getCurrentMaskColor(self):
+ """Returns the color of the current selected level.
+
+ :rtype: A tuple or a python array
+ """
+ currentLevel = self.levelSpinBox.value()
+ if self._defaultColors[currentLevel]:
+ return self._defaultOverlayColor
+ else:
+ return self._overlayColors[currentLevel].tolist()
+
+ def _setMaskColors(self, level, alpha):
+ """Set-up the mask colormap to highlight current mask level.
+
+ :param int level: The mask level to highlight
+ :param float alpha: Alpha level of mask in [0., 1.]
+ """
+ assert 0 < level <= self._maxLevelNumber
+
+ colors = numpy.empty((self._maxLevelNumber + 1, 4), dtype=numpy.float32)
+
+ # Set color
+ colors[:, :3] = self._defaultOverlayColor[:3]
+
+ # check if some colors has been directly set by the user
+ mask = numpy.equal(self._defaultColors, False)
+ colors[mask, :3] = self._overlayColors[mask, :3]
+
+ # Set alpha
+ colors[:, -1] = alpha / 2.0
+
+ # Set highlighted level color
+ colors[level, 3] = alpha
+
+ # Set no mask level
+ colors[0] = (0.0, 0.0, 0.0, 0.0)
+
+ self._colormap.setColormapLUT(colors)
+
+ def resetMaskColors(self, level=None):
+ """Reset the mask color at the given level to be defaultColors
+
+ :param level:
+ The index of the mask for which we want to reset the color.
+ If none we will reset color for all masks.
+ """
+ if level is None:
+ self._defaultColors[level] = True
+ else:
+ self._defaultColors[:] = True
+
+ self._updateColors()
+
+ def setMaskColors(self, rgb, level=None):
+ """Set the masks color
+
+ :param rgb: The rgb color
+ :param level:
+ The index of the mask for which we want to change the color.
+ If none set this color for all the masks
+ """
+ rgb = rgba(rgb)[0:3]
+ if level is None:
+ self._overlayColors[:] = rgb
+ self._defaultColors[:] = False
+ else:
+ self._overlayColors[level] = rgb
+ self._defaultColors[level] = False
+
+ self._updateColors()
+
+ def getMaskColors(self):
+ """masks colors getter"""
+ return self._overlayColors
+
+ def _updateColors(self, *args):
+ """Rebuild mask colormap when selected level or transparency change"""
+ self._setMaskColors(
+ self.levelSpinBox.value(),
+ self.transparencySlider.value() / self.transparencySlider.maximum(),
+ )
+ self._updatePlotMask()
+ self._updateInteractiveMode()
+
+ def _pencilWidthChanged(self, width):
+ old = self.pencilSpinBox.blockSignals(True)
+ try:
+ self.pencilSpinBox.setValue(width)
+ finally:
+ self.pencilSpinBox.blockSignals(old)
+
+ old = self.pencilSlider.blockSignals(True)
+ try:
+ self.pencilSlider.setValue(width)
+ finally:
+ self.pencilSlider.blockSignals(old)
+ self._updateInteractiveMode()
+
+ def _updateInteractiveMode(self):
+ """Update the current mode to the same if some cached data have to be
+ updated. It is the case for the color for example.
+ """
+ if self._drawingMode == "rectangle":
+ self._activeRectMode()
+ elif self._drawingMode == "ellipse":
+ self._activeEllipseMode()
+ elif self._drawingMode == "polygon":
+ self._activePolygonMode()
+ elif self._drawingMode == "pencil":
+ self._activePencilMode()
+
+ def _handleClearMask(self):
+ """Handle clear button clicked: reset current level mask"""
+ self._mask.clear(self.levelSpinBox.value())
+ self._mask.commit()
+
+ def _handleInvertMask(self):
+ """Invert the current mask level selection."""
+ self._mask.invert(self.levelSpinBox.value())
+ self._mask.commit()
+
+ # Handle drawing tools UI events
+
+ def _interactiveModeChanged(self, source):
+ """Handle plot interactive mode changed:
+
+ If changed from elsewhere, disable drawing tool
+ """
+ if source is not self:
+ self.pencilAction.setChecked(False)
+ self.rectAction.setChecked(False)
+ self.polygonAction.setChecked(False)
+ self._releaseDrawingMode()
+ self._updateDrawingModeWidgets()
+
+ def _releaseDrawingMode(self):
+ """Release the drawing mode if is was used"""
+ if self._drawingMode is None:
+ return
+ self.plot.sigPlotSignal.disconnect(self._plotDrawEvent)
+ self._drawingMode = None
+
+ def _activeRectMode(self):
+ """Handle rect action mode triggering"""
+ self._releaseDrawingMode()
+ self._drawingMode = "rectangle"
+ self.plot.sigPlotSignal.connect(self._plotDrawEvent)
+ color = self.getCurrentMaskColor()
+ self.plot.setInteractiveMode(
+ "draw", shape="rectangle", source=self, color=color
+ )
+ self._updateDrawingModeWidgets()
+
+ def _activeEllipseMode(self):
+ """Handle circle action mode triggering"""
+ self._releaseDrawingMode()
+ self._drawingMode = "ellipse"
+ self.plot.sigPlotSignal.connect(self._plotDrawEvent)
+ color = self.getCurrentMaskColor()
+ self.plot.setInteractiveMode("draw", shape="ellipse", source=self, color=color)
+ self._updateDrawingModeWidgets()
+
+ def _activePolygonMode(self):
+ """Handle polygon action mode triggering"""
+ self._releaseDrawingMode()
+ self._drawingMode = "polygon"
+ self.plot.sigPlotSignal.connect(self._plotDrawEvent)
+ color = self.getCurrentMaskColor()
+ self.plot.setInteractiveMode("draw", shape="polygon", source=self, color=color)
+ self._updateDrawingModeWidgets()
+
+ def _getPencilWidth(self):
+ """Returns the width of the pencil to use in data coordinates`
+
+ :rtype: float
+ """
+ return self.pencilSpinBox.value()
+
+ def _activePencilMode(self):
+ """Handle pencil action mode triggering"""
+ self._releaseDrawingMode()
+ self._drawingMode = "pencil"
+ self.plot.sigPlotSignal.connect(self._plotDrawEvent)
+ color = self.getCurrentMaskColor()
+ width = self._getPencilWidth()
+ self.plot.setInteractiveMode(
+ "draw", shape="pencil", source=self, color=color, width=width
+ )
+ self._updateDrawingModeWidgets()
+
+ def _updateDrawingModeWidgets(self):
+ self.maskStateWidget.setVisible(self._drawingMode is not None)
+ self.pencilSetting.setVisible(self._drawingMode == "pencil")
+
+ # Handle plot drawing events
+
+ def _isMasking(self):
+ """Returns true if the tool is used for masking, else it is used for
+ unmasking.
+
+ :rtype: bool"""
+ # First draw event, use current modifiers for all draw sequence
+ doMask = self.maskStateGroup.checkedId() == 1
+ if qt.QApplication.keyboardModifiers() & qt.Qt.ControlModifier:
+ doMask = not doMask
+ return doMask
+
+ # Handle threshold UI events
+
+ def _thresholdActionGroupTriggered(self, triggeredAction):
+ """Threshold action group listener."""
+ if triggeredAction is self.belowThresholdAction:
+ self.minLineLabel.setVisible(True)
+ self.maxLineLabel.setVisible(False)
+ self.minLineEdit.setVisible(True)
+ self.maxLineEdit.setVisible(False)
+ self.applyMaskBtn.setText("Mask below")
+ elif triggeredAction is self.betweenThresholdAction:
+ self.minLineLabel.setVisible(True)
+ self.maxLineLabel.setVisible(True)
+ self.minLineEdit.setVisible(True)
+ self.maxLineEdit.setVisible(True)
+ self.applyMaskBtn.setText("Mask between")
+ elif triggeredAction is self.aboveThresholdAction:
+ self.minLineLabel.setVisible(False)
+ self.maxLineLabel.setVisible(True)
+ self.minLineEdit.setVisible(False)
+ self.maxLineEdit.setVisible(True)
+ self.applyMaskBtn.setText("Mask above")
+ self.applyMaskBtn.setToolTip(triggeredAction.toolTip())
+
+ def _maskBtnClicked(self):
+ if self.belowThresholdAction.isChecked():
+ if self.minLineEdit.text():
+ self._mask.updateBelowThreshold(
+ self.levelSpinBox.value(), self.minLineEdit.value()
+ )
+ self._mask.commit()
+
+ elif self.betweenThresholdAction.isChecked():
+ if self.minLineEdit.text() and self.maxLineEdit.text():
+ min_ = self.minLineEdit.value()
+ max_ = self.maxLineEdit.value()
+ self._mask.updateBetweenThresholds(
+ self.levelSpinBox.value(), min_, max_
+ )
+ self._mask.commit()
+
+ elif self.aboveThresholdAction.isChecked():
+ if self.maxLineEdit.text():
+ max_ = float(self.maxLineEdit.value())
+ self._mask.updateAboveThreshold(self.levelSpinBox.value(), max_)
+ self._mask.commit()
+
+ def _maskNotFiniteBtnClicked(self):
+ """Handle not finite mask button clicked: mask NaNs and inf"""
+ self._mask.updateNotFinite(self.levelSpinBox.value())
+ self._mask.commit()
+
+
+class BaseMaskToolsDockWidget(qt.QDockWidget):
+ """Base class for :class:`MaskToolsWidget` and
+ :class:`ScatterMaskToolsWidget`.
+
+ For integration in a :class:`PlotWindow`.
+
+ :param parent: See :class:`QDockWidget`
+ :paran str name: The title of this widget
+ """
+
+ sigMaskChanged = qt.Signal()
+
+ def __init__(self, parent=None, name="Mask", widget=None):
+ super(BaseMaskToolsDockWidget, self).__init__(parent)
+ self.setWindowTitle(name)
+
+ if not isinstance(widget, BaseMaskToolsWidget):
+ raise TypeError("BaseMaskToolsDockWidget requires a MaskToolsWidget")
+ self.setWidget(widget)
+ self.widget().sigMaskChanged.connect(self._emitSigMaskChanged)
+
+ self.layout().setContentsMargins(0, 0, 0, 0)
+ self.dockLocationChanged.connect(self._dockLocationChanged)
+ self.topLevelChanged.connect(self._topLevelChanged)
+
+ def _emitSigMaskChanged(self):
+ """Notify mask changes"""
+ # must be connected to self.widget().sigMaskChanged in child class
+ self.sigMaskChanged.emit()
+
+ def getSelectionMask(self, copy=True):
+ """Get the current mask as a 2D array.
+
+ :param bool copy: True (default) to get a copy of the mask.
+ If False, the returned array MUST not be modified.
+ :return: The array of the mask with dimension of the 'active' image.
+ If there is no active image, an empty array is returned.
+ :rtype: 2D numpy.ndarray of uint8
+ """
+ return self.widget().getSelectionMask(copy=copy)
+
+ def setSelectionMask(self, mask, copy=True):
+ """Set the mask to a new array.
+
+ :param numpy.ndarray mask: The array to use for the mask.
+ :type mask: numpy.ndarray of uint8 of dimension 2, C-contiguous.
+ Array of other types are converted.
+ :param bool copy: True (the default) to copy the array,
+ False to use it as is if possible.
+ :return: None if failed, shape of mask as 2-tuple if successful.
+ The mask can be cropped or padded to fit active image,
+ the returned shape is that of the active image.
+ """
+ return self.widget().setSelectionMask(mask, copy=copy)
+
+ def resetSelectionMask(self):
+ """Reset the mask to an array of zeros with the shape of the
+ current data."""
+ self.widget().resetSelectionMask()
+
+ def toggleViewAction(self):
+ """Returns a checkable action that shows or closes this widget.
+
+ See :class:`QMainWindow`.
+ """
+ action = super(BaseMaskToolsDockWidget, self).toggleViewAction()
+ action.setIcon(icons.getQIcon("image-mask"))
+ action.setToolTip("Display/hide mask tools")
+ return action
+
+ def _dockLocationChanged(self, area):
+ if area in (qt.Qt.LeftDockWidgetArea, qt.Qt.RightDockWidgetArea):
+ direction = qt.QBoxLayout.TopToBottom
+ else:
+ direction = qt.QBoxLayout.LeftToRight
+ self.widget().setDirection(direction)
+
+ def _topLevelChanged(self, topLevel):
+ if topLevel:
+ self.widget().setDirection(qt.QBoxLayout.LeftToRight)
+ self.resize(self.widget().minimumSize())
+ self.adjustSize()
diff --git a/src/silx/gui/plot/__init__.py b/src/silx/gui/plot/__init__.py
new file mode 100644
index 0000000..2a1587f
--- /dev/null
+++ b/src/silx/gui/plot/__init__.py
@@ -0,0 +1,78 @@
+# /*##########################################################################
+#
+# Copyright (c) 2016-2018 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""This package provides a set of Qt widgets for plotting curves and images.
+
+The plotting API is inherited from the `PyMca <http://pymca.sourceforge.net/>`_
+plot API and is mostly compatible with it.
+
+Those widgets supports interaction (e.g., zoom, pan, selections).
+
+List of Qt widgets:
+
+.. currentmodule:: silx.gui.plot
+
+- :mod:`.PlotWidget`: A widget displaying a single plot.
+- :mod:`.PlotWindow`: A :mod:`.PlotWidget` with a configurable set of tools.
+- :class:`.Plot1D`: A widget with tools for curves.
+- :class:`.Plot2D`: A widget with tools for images.
+- :class:`.ScatterView`: A widget with tools for scatter plot.
+- :class:`.ImageView`: A widget with tools for images and a side histogram.
+- :class:`.StackView`: A widget with tools for a stack of images.
+
+By default, those widget are using matplotlib_.
+They can optionally use a faster OpenGL-based rendering (beta feature),
+which is enabled by setting the ``backend`` argument to ``'gl'``
+when creating the widgets (See :class:`.PlotWidget`).
+
+.. note::
+
+ This package depends on matplotlib_.
+ The OpenGL backend further depends on
+ `PyOpenGL <http://pyopengl.sourceforge.net/>`_ and OpenGL >= 2.1.
+
+.. _matplotlib: http://matplotlib.org/
+"""
+
+__authors__ = ["T. Vincent"]
+__license__ = "MIT"
+__date__ = "03/05/2017"
+
+
+from .PlotWidget import PlotWidget # noqa
+from .PlotWindow import PlotWindow, Plot1D, Plot2D # noqa
+from .items.axis import TickMode
+from .ImageView import ImageView # noqa
+from .StackView import StackView # noqa
+from .ScatterView import ScatterView # noqa
+
+__all__ = [
+ "ImageView",
+ "PlotWidget",
+ "PlotWindow",
+ "Plot1D",
+ "Plot2D",
+ "StackView",
+ "ScatterView",
+ "TickMode",
+]
diff --git a/src/silx/gui/plot/_utils/__init__.py b/src/silx/gui/plot/_utils/__init__.py
new file mode 100644
index 0000000..3075007
--- /dev/null
+++ b/src/silx/gui/plot/_utils/__init__.py
@@ -0,0 +1,92 @@
+# /*##########################################################################
+#
+# Copyright (c) 2004-2023 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""Miscellaneous utility functions for the Plot"""
+
+__authors__ = ["T. Vincent"]
+__license__ = "MIT"
+__date__ = "21/03/2017"
+
+
+import numpy
+
+from .panzoom import FLOAT32_SAFE_MIN, FLOAT32_MINPOS, FLOAT32_SAFE_MAX
+from .panzoom import applyZoomToPlot, applyPan, checkAxisLimits, EnabledAxes
+
+
+def addMarginsToLimits(
+ margins, isXLog, isYLog, xMin, xMax, yMin, yMax, y2Min=None, y2Max=None
+):
+ """Returns updated limits by extending them with margins.
+
+ :param margins: The ratio of the margins to add or None for no margins.
+ :type margins: A 4-tuple of floats as
+ (xMinMargin, xMaxMargin, yMinMargin, yMaxMargin)
+
+ :return: The updated limits
+ :rtype: tuple of 4 or 6 floats: Either (xMin, xMax, yMin, yMax) or
+ (xMin, xMax, yMin, yMax, y2Min, y2Max) if y2Min and y2Max
+ are provided.
+ """
+ if margins is not None:
+ xMinMargin, xMaxMargin, yMinMargin, yMaxMargin = margins
+
+ if not isXLog:
+ xRange = xMax - xMin
+ xMin -= xMinMargin * xRange
+ xMax += xMaxMargin * xRange
+
+ elif xMin > 0.0 and xMax > 0.0: # Log scale
+ # Do not apply margins if limits < 0
+ xMinLog, xMaxLog = numpy.log10(xMin), numpy.log10(xMax)
+ xRangeLog = xMaxLog - xMinLog
+ xMin = pow(10.0, xMinLog - xMinMargin * xRangeLog)
+ xMax = pow(10.0, xMaxLog + xMaxMargin * xRangeLog)
+
+ if not isYLog:
+ yRange = yMax - yMin
+ yMin -= yMinMargin * yRange
+ yMax += yMaxMargin * yRange
+ elif yMin > 0.0 and yMax > 0.0: # Log scale
+ # Do not apply margins if limits < 0
+ yMinLog, yMaxLog = numpy.log10(yMin), numpy.log10(yMax)
+ yRangeLog = yMaxLog - yMinLog
+ yMin = pow(10.0, yMinLog - yMinMargin * yRangeLog)
+ yMax = pow(10.0, yMaxLog + yMaxMargin * yRangeLog)
+
+ if y2Min is not None and y2Max is not None:
+ if not isYLog:
+ yRange = y2Max - y2Min
+ y2Min -= yMinMargin * yRange
+ y2Max += yMaxMargin * yRange
+ elif y2Min > 0.0 and y2Max > 0.0: # Log scale
+ # Do not apply margins if limits < 0
+ yMinLog, yMaxLog = numpy.log10(y2Min), numpy.log10(y2Max)
+ yRangeLog = yMaxLog - yMinLog
+ y2Min = pow(10.0, yMinLog - yMinMargin * yRangeLog)
+ y2Max = pow(10.0, yMaxLog + yMaxMargin * yRangeLog)
+
+ if y2Min is None or y2Max is None:
+ return xMin, xMax, yMin, yMax
+ else:
+ return xMin, xMax, yMin, yMax, y2Min, y2Max
diff --git a/src/silx/gui/plot/_utils/dtime_ticklayout.py b/src/silx/gui/plot/_utils/dtime_ticklayout.py
new file mode 100644
index 0000000..ba0fda7
--- /dev/null
+++ b/src/silx/gui/plot/_utils/dtime_ticklayout.py
@@ -0,0 +1,480 @@
+# /*##########################################################################
+#
+# Copyright (c) 2014-2023 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+from __future__ import annotations
+
+"""This module implements date-time labels layout on graph axes."""
+
+__authors__ = ["P. Kenter"]
+__license__ = "MIT"
+__date__ = "04/04/2018"
+
+
+from collections.abc import Sequence
+import datetime as dt
+import enum
+import logging
+import math
+import time
+
+import dateutil.tz
+
+from dateutil.relativedelta import relativedelta
+
+from .ticklayout import niceNumGeneric
+
+_logger = logging.getLogger(__name__)
+
+
+MICROSECONDS_PER_SECOND = 1000000
+SECONDS_PER_MINUTE = 60
+SECONDS_PER_HOUR = 60 * SECONDS_PER_MINUTE
+SECONDS_PER_DAY = 24 * SECONDS_PER_HOUR
+SECONDS_PER_YEAR = 365.25 * SECONDS_PER_DAY
+SECONDS_PER_MONTH_AVERAGE = SECONDS_PER_YEAR / 12 # Seconds per average month
+
+
+# No dt.timezone in Python 2.7 so we use dateutil.tz.tzutc
+_EPOCH = dt.datetime(1970, 1, 1, tzinfo=dateutil.tz.tzutc())
+
+
+def timestamp(dtObj):
+ """Returns POSIX timestamp of a datetime objects.
+
+ If the dtObj object has a timestamp() method (python 3.3), this is
+ used. Otherwise (e.g. python 2.7) it is calculated here.
+
+ The POSIX timestamp is a floating point value of the number of seconds
+ since the start of an epoch (typically 1970-01-01). For details see:
+ https://docs.python.org/3/library/datetime.html#datetime.datetime.timestamp
+
+ :param datetime.datetime dtObj: date-time representation.
+ :return: POSIX timestamp
+ :rtype: float
+ """
+ if hasattr(dtObj, "timestamp"):
+ return dtObj.timestamp()
+ else:
+ # Back ported from Python 3.5
+ if dtObj.tzinfo is None:
+ return (
+ time.mktime(
+ (
+ dtObj.year,
+ dtObj.month,
+ dtObj.day,
+ dtObj.hour,
+ dtObj.minute,
+ dtObj.second,
+ -1,
+ -1,
+ -1,
+ )
+ )
+ + dtObj.microsecond / 1e6
+ )
+ else:
+ return (dtObj - _EPOCH).total_seconds()
+
+
+@enum.unique
+class DtUnit(enum.Enum):
+ YEARS = 0
+ MONTHS = 1
+ DAYS = 2
+ HOURS = 3
+ MINUTES = 4
+ SECONDS = 5
+ MICRO_SECONDS = 6 # a fraction of a second
+
+
+def getDateElement(dateTime, unit):
+ """Picks the date element with the unit from the dateTime
+
+ E.g. getDateElement(datetime(1970, 5, 6), DtUnit.Day) will return 6
+
+ :param datetime dateTime: date/time to pick from
+ :param DtUnit unit: The unit describing the date element.
+ """
+ if unit == DtUnit.YEARS:
+ return dateTime.year
+ elif unit == DtUnit.MONTHS:
+ return dateTime.month
+ elif unit == DtUnit.DAYS:
+ return dateTime.day
+ elif unit == DtUnit.HOURS:
+ return dateTime.hour
+ elif unit == DtUnit.MINUTES:
+ return dateTime.minute
+ elif unit == DtUnit.SECONDS:
+ return dateTime.second
+ elif unit == DtUnit.MICRO_SECONDS:
+ return dateTime.microsecond
+ else:
+ raise ValueError("Unexpected DtUnit: {}".format(unit))
+
+
+def setDateElement(dateTime, value, unit):
+ """Returns a copy of dateTime with the tickStep unit set to value
+
+ :param datetime.datetime: date time object
+ :param int value: value to set
+ :param DtUnit unit: unit
+ :return: datetime.datetime
+ """
+ intValue = int(value)
+ _logger.debug(
+ "setDateElement({}, {} (int={}), {})".format(dateTime, value, intValue, unit)
+ )
+
+ year = dateTime.year
+ month = dateTime.month
+ day = dateTime.day
+ hour = dateTime.hour
+ minute = dateTime.minute
+ second = dateTime.second
+ microsecond = dateTime.microsecond
+
+ if unit == DtUnit.YEARS:
+ year = intValue
+ elif unit == DtUnit.MONTHS:
+ month = intValue
+ elif unit == DtUnit.DAYS:
+ day = intValue
+ elif unit == DtUnit.HOURS:
+ hour = intValue
+ elif unit == DtUnit.MINUTES:
+ minute = intValue
+ elif unit == DtUnit.SECONDS:
+ second = intValue
+ elif unit == DtUnit.MICRO_SECONDS:
+ microsecond = intValue
+ else:
+ raise ValueError("Unexpected DtUnit: {}".format(unit))
+
+ _logger.debug(
+ "creating date time {}".format(
+ (year, month, day, hour, minute, second, microsecond)
+ )
+ )
+
+ return dt.datetime(
+ year, month, day, hour, minute, second, microsecond, tzinfo=dateTime.tzinfo
+ )
+
+
+def roundToElement(dateTime, unit):
+ """Returns a copy of dateTime rounded to given unit
+
+ :param datetime.datetime: date time object
+ :param DtUnit unit: unit
+ :return: datetime.datetime
+ """
+ year = dateTime.year
+ month = dateTime.month
+ day = dateTime.day
+ hour = dateTime.hour
+ minute = dateTime.minute
+ second = dateTime.second
+ microsecond = dateTime.microsecond
+
+ if unit.value < DtUnit.YEARS.value:
+ pass # Never round years
+ if unit.value < DtUnit.MONTHS.value:
+ month = 1
+ if unit.value < DtUnit.DAYS.value:
+ day = 1
+ if unit.value < DtUnit.HOURS.value:
+ hour = 0
+ if unit.value < DtUnit.MINUTES.value:
+ minute = 0
+ if unit.value < DtUnit.SECONDS.value:
+ second = 0
+ if unit.value < DtUnit.MICRO_SECONDS.value:
+ microsecond = 0
+
+ result = dt.datetime(
+ year, month, day, hour, minute, second, microsecond, tzinfo=dateTime.tzinfo
+ )
+
+ return result
+
+
+def addValueToDate(dateTime, value, unit):
+ """Adds a value with unit to a dateTime.
+
+ Uses dateutil.relativedelta.relativedelta from the standard library to do
+ the actual math. This function doesn't allow for fractional month or years,
+ so month and year are truncated to integers before adding.
+
+ :param datetime dateTime: date time
+ :param float value: value to be added
+ :param DtUnit unit: of the value
+ :return:
+ :raises ValueError: unit is unsupported or result is out of datetime bounds
+ """
+ # logger.debug("addValueToDate({}, {}, {})".format(dateTime, value, unit))
+
+ if unit == DtUnit.YEARS:
+ intValue = int(value) # floats not implemented in relativeDelta(years)
+ return dateTime + relativedelta(years=intValue)
+ elif unit == DtUnit.MONTHS:
+ intValue = int(value) # floats not implemented in relativeDelta(mohths)
+ return dateTime + relativedelta(months=intValue)
+ elif unit == DtUnit.DAYS:
+ return dateTime + relativedelta(days=value)
+ elif unit == DtUnit.HOURS:
+ return dateTime + relativedelta(hours=value)
+ elif unit == DtUnit.MINUTES:
+ return dateTime + relativedelta(minutes=value)
+ elif unit == DtUnit.SECONDS:
+ return dateTime + relativedelta(seconds=value)
+ elif unit == DtUnit.MICRO_SECONDS:
+ return dateTime + relativedelta(microseconds=value)
+ else:
+ raise ValueError("Unexpected DtUnit: {}".format(unit))
+
+
+def bestUnit(durationInSeconds):
+ """Gets the best tick spacing given a duration in seconds.
+
+ :param durationInSeconds: time span duration in seconds
+ :return: DtUnit enumeration.
+ """
+
+ # Based on; https://stackoverflow.com/a/2144398/
+ # If the duration is longer than two years the tick spacing will be in
+ # years. Else, if the duration is longer than two months, the spacing will
+ # be in months, Etcetera.
+ #
+ # This factor differs per unit. As a baseline it is 2, but for instance,
+ # for Months this needs to be higher (3>), This because it is impossible to
+ # have partial months so the tick spacing is always at least 1 month. A
+ # duration of two months would result in two ticks, which is too few.
+ # months would then results
+
+ if durationInSeconds > SECONDS_PER_YEAR * 3:
+ return (durationInSeconds / SECONDS_PER_YEAR, DtUnit.YEARS)
+ elif durationInSeconds > SECONDS_PER_MONTH_AVERAGE * 3:
+ return (durationInSeconds / SECONDS_PER_MONTH_AVERAGE, DtUnit.MONTHS)
+ elif durationInSeconds > SECONDS_PER_DAY * 2:
+ return (durationInSeconds / SECONDS_PER_DAY, DtUnit.DAYS)
+ elif durationInSeconds > SECONDS_PER_HOUR * 2:
+ return (durationInSeconds / SECONDS_PER_HOUR, DtUnit.HOURS)
+ elif durationInSeconds > SECONDS_PER_MINUTE * 2:
+ return (durationInSeconds / SECONDS_PER_MINUTE, DtUnit.MINUTES)
+ elif durationInSeconds > 1 * 2:
+ return (durationInSeconds, DtUnit.SECONDS)
+ else:
+ return (durationInSeconds * MICROSECONDS_PER_SECOND, DtUnit.MICRO_SECONDS)
+
+
+NICE_DATE_VALUES = {
+ DtUnit.YEARS: [1, 2, 5, 10],
+ DtUnit.MONTHS: [1, 2, 3, 4, 6, 12],
+ DtUnit.DAYS: [1, 2, 3, 7, 14, 28],
+ DtUnit.HOURS: [1, 2, 3, 4, 6, 12],
+ DtUnit.MINUTES: [1, 2, 3, 5, 10, 15, 30],
+ DtUnit.SECONDS: [1, 2, 3, 5, 10, 15, 30],
+ DtUnit.MICRO_SECONDS: [1.0, 2.0, 3.0, 4.0, 5.0, 10.0], # floats for microsec
+}
+
+
+def bestFormatString(spacing, unit):
+ """Finds the best format string given the spacing and DtUnit.
+
+ If the spacing is a fractional number < 1 the format string will take this
+ into account
+
+ :param spacing: spacing between ticks
+ :param DtUnit unit:
+ :return: Format string for use in strftime
+ :rtype: str
+ """
+ isSmall = spacing < 1
+
+ if unit == DtUnit.YEARS:
+ return "%Y-m" if isSmall else "%Y"
+ elif unit == DtUnit.MONTHS:
+ return "%Y-%m-%d" if isSmall else "%Y-%m"
+ elif unit == DtUnit.DAYS:
+ return "%H:%M" if isSmall else "%Y-%m-%d"
+ elif unit == DtUnit.HOURS:
+ return "%H:%M" if isSmall else "%H:%M"
+ elif unit == DtUnit.MINUTES:
+ return "%H:%M:%S" if isSmall else "%H:%M"
+ elif unit == DtUnit.SECONDS:
+ return "%S.%f" if isSmall else "%H:%M:%S"
+ elif unit == DtUnit.MICRO_SECONDS:
+ return "%S.%f"
+ else:
+ raise ValueError("Unexpected DtUnit: {}".format(unit))
+
+
+def formatDatetimes(
+ datetimes: Sequence[dt.datetime], spacing: int | None, unit: DtUnit | None
+) -> dict[dt.datetime, str]:
+ """Returns formatted string for each datetime according to tick spacing and time unit"""
+ if spacing is None or unit is None:
+ # Locator has no spacing or units yet: Use elaborate fmtString
+ return {
+ datetime: datetime.strftime("Y-%m-%d %H:%M:%S") for datetime in datetimes
+ }
+
+ formatString = bestFormatString(spacing, unit)
+ if unit != DtUnit.MICRO_SECONDS:
+ return {datetime: datetime.strftime(formatString) for datetime in datetimes}
+
+ # For microseconds: Strip leading/trailing zeros
+ texts = tuple(datetime.strftime(formatString) for datetime in datetimes)
+ nzeros = min(len(text) - len(text.rstrip("0")) for text in texts)
+ return {
+ datetime: text[0 if text[0] != "0" else 1 : -min(nzeros, 5)]
+ for datetime, text in zip(datetimes, texts)
+ }
+
+
+def niceDateTimeElement(value, unit, isRound=False):
+ """Uses the Nice Numbers algorithm to determine a nice value.
+
+ The fractions are optimized for the unit of the date element.
+ """
+
+ niceValues = NICE_DATE_VALUES[unit]
+ elemValue = niceNumGeneric(value, niceValues, isRound=isRound)
+
+ if unit == DtUnit.YEARS or unit == DtUnit.MONTHS:
+ elemValue = max(1, int(elemValue))
+
+ return elemValue
+
+
+def findStartDate(dMin, dMax, nTicks):
+ """Rounds a date down to the nearest nice number of ticks"""
+ assert dMax >= dMin, "dMin ({}) should come before dMax ({})".format(dMin, dMax)
+
+ if dMin == dMax:
+ # Fallback when range is smaller than microsecond resolution
+ return dMin, 1, DtUnit.MICRO_SECONDS
+
+ delta = dMax - dMin
+ lengthSec = delta.total_seconds()
+ _logger.debug(
+ "findStartDate: {}, {} (duration = {} sec, {} days)".format(
+ dMin, dMax, lengthSec, lengthSec / SECONDS_PER_DAY
+ )
+ )
+
+ length, unit = bestUnit(lengthSec)
+ niceLength = niceDateTimeElement(length, unit)
+
+ _logger.debug(
+ "Length: {:8.3f} {} (nice = {})".format(length, unit.name, niceLength)
+ )
+
+ niceSpacing = niceDateTimeElement(niceLength / nTicks, unit, isRound=True)
+
+ _logger.debug(
+ "Spacing: {:8.3f} {} (nice = {})".format(
+ niceLength / nTicks, unit.name, niceSpacing
+ )
+ )
+
+ dVal = getDateElement(dMin, unit)
+
+ if unit == DtUnit.MONTHS: # TODO: better rounding?
+ niceVal = math.floor((dVal - 1) / niceSpacing) * niceSpacing + 1
+ elif unit == DtUnit.DAYS:
+ niceVal = math.floor((dVal - 1) / niceSpacing) * niceSpacing + 1
+ else:
+ niceVal = math.floor(dVal / niceSpacing) * niceSpacing
+
+ if unit == DtUnit.YEARS and niceVal <= dt.MINYEAR:
+ niceVal = max(1, niceSpacing)
+
+ _logger.debug(
+ "StartValue: dVal = {}, niceVal: {} ({})".format(dVal, niceVal, unit.name)
+ )
+
+ startDate = roundToElement(dMin, unit)
+ startDate = setDateElement(startDate, niceVal, unit)
+
+ return startDate, niceSpacing, unit
+
+
+def dateRange(dMin, dMax, step, unit, includeFirstBeyond=False):
+ """Generates a range of dates
+
+ :param datetime dMin: start date
+ :param datetime dMax: end date
+ :param int step: the step size
+ :param DtUnit unit: the unit of the step size
+ :param bool includeFirstBeyond: if True the first date later than dMax will
+ be included in the range. If False (the default), the last generated
+ datetime will always be smaller than dMax.
+ :return:
+ """
+ if unit == DtUnit.YEARS or unit == DtUnit.MONTHS or unit == DtUnit.MICRO_SECONDS:
+ # No support for fractional month or year and resolution is microsecond
+ # In those cases, make sure the step is at least 1
+ step = max(1, step)
+ else:
+ assert step > 0, "tickstep is 0"
+
+ dateTime = dMin
+ while dateTime < dMax:
+ yield dateTime
+ try:
+ dateTime = addValueToDate(dateTime, step, unit)
+ except ValueError:
+ return # current dateTime is out of datetime bounds
+
+ if includeFirstBeyond:
+ yield dateTime
+
+
+def calcTicks(dMin, dMax, nTicks):
+ """Returns tick positions.
+
+ :param datetime.datetime dMin: The min value on the axis
+ :param datetime.datetime dMax: The max value on the axis
+ :param int nTicks: The target number of ticks. The actual number of found
+ ticks may differ.
+ :returns: (list of datetimes, DtUnit) tuple
+ """
+ _logger.debug("Calc calcTicks({}, {}, nTicks={})".format(dMin, dMax, nTicks))
+
+ startDate, niceSpacing, unit = findStartDate(dMin, dMax, nTicks)
+
+ result = []
+ for d in dateRange(startDate, dMax, niceSpacing, unit, includeFirstBeyond=True):
+ result.append(d)
+
+ return result, niceSpacing, unit
+
+
+def calcTicksAdaptive(dMin, dMax, axisLength, tickDensity):
+ """Calls calcTicks with a variable number of ticks, depending on axisLength"""
+ # At least 2 ticks
+ nticks = max(2, int(round(tickDensity * axisLength)))
+ return calcTicks(dMin, dMax, nticks)
diff --git a/src/silx/gui/plot/_utils/panzoom.py b/src/silx/gui/plot/_utils/panzoom.py
new file mode 100644
index 0000000..cac591d
--- /dev/null
+++ b/src/silx/gui/plot/_utils/panzoom.py
@@ -0,0 +1,366 @@
+# /*##########################################################################
+#
+# Copyright (c) 2004-2023 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""Functions to apply pan and zoom on a Plot"""
+
+from __future__ import annotations
+
+__authors__ = ["T. Vincent", "V. Valls"]
+__license__ = "MIT"
+__date__ = "08/08/2017"
+
+
+import logging
+import math
+from typing import NamedTuple
+import numpy
+
+
+_logger = logging.getLogger(__name__)
+
+
+# Float 32 info ###############################################################
+# Using min/max value below limits of float32
+# so operation with such value (e.g., max - min) do not overflow
+
+FLOAT32_SAFE_MIN = -1e37
+FLOAT32_MINPOS = numpy.finfo(numpy.float32).tiny
+FLOAT32_SAFE_MAX = 1e37
+# TODO double support
+
+
+def checkAxisLimits(vmin: float, vmax: float, isLog: bool = False, name: str = ""):
+ """Makes sure axis range is not empty and within supported range.
+
+ :param vmin: Min axis value
+ :param vmax: Max axis value
+ :return: (min, max) making sure min < max
+ :rtype: 2-tuple of float
+ """
+ min_ = FLOAT32_MINPOS if isLog else FLOAT32_SAFE_MIN
+ vmax = numpy.clip(vmax, min_, FLOAT32_SAFE_MAX)
+ vmin = numpy.clip(vmin, min_, FLOAT32_SAFE_MAX)
+
+ if vmax < vmin:
+ _logger.debug("%s axis: max < min, inverting limits.", name)
+ vmin, vmax = vmax, vmin
+ elif vmax == vmin:
+ _logger.debug("%s axis: max == min, expanding limits.", name)
+ if vmin == 0.0:
+ vmin, vmax = -0.1, 0.1
+ elif vmin < 0:
+ vmax *= 0.9
+ vmin = max(vmin * 1.1, FLOAT32_SAFE_MIN) # Clip to range
+ else: # vmin > 0
+ vmax = min(vmin * 1.1, FLOAT32_SAFE_MAX) # Clip to range
+ vmin *= 0.9
+
+ return vmin, vmax
+
+
+def scale1DRange(
+ min_: float, max_: float, center: float, scale: float, isLog: bool
+) -> tuple[float, float]:
+ """Scale a 1D range given a scale factor and an center point.
+
+ Keeps the values in a smaller range than float32.
+
+ :param min_: The current min value of the range.
+ :param max_: The current max value of the range.
+ :param center: The center of the zoom (i.e., invariant point).
+ :param scale: The scale to use for zoom
+ :param isLog: Whether using log scale or not.
+ :return: The zoomed range (min, max)
+ """
+ if isLog:
+ # Min and center can be < 0 when
+ # autoscale is off and switch to log scale
+ # max_ < 0 should not happen
+ min_ = numpy.log10(min_) if min_ > 0.0 else FLOAT32_MINPOS
+ center = numpy.log10(center) if center > 0.0 else FLOAT32_MINPOS
+ max_ = numpy.log10(max_) if max_ > 0.0 else FLOAT32_MINPOS
+
+ if min_ == max_:
+ return min_, max_
+
+ offset = (center - min_) / (max_ - min_)
+ range_ = (max_ - min_) / scale
+ newMin = center - offset * range_
+ newMax = center + (1.0 - offset) * range_
+
+ if isLog:
+ # No overflow as exponent is log10 of a float32
+ newMin = pow(10.0, newMin)
+ newMax = pow(10.0, newMax)
+ newMin = numpy.clip(newMin, FLOAT32_MINPOS, FLOAT32_SAFE_MAX)
+ newMax = numpy.clip(newMax, FLOAT32_MINPOS, FLOAT32_SAFE_MAX)
+ else:
+ newMin = numpy.clip(newMin, FLOAT32_SAFE_MIN, FLOAT32_SAFE_MAX)
+ newMax = numpy.clip(newMax, FLOAT32_SAFE_MIN, FLOAT32_SAFE_MAX)
+ return newMin, newMax
+
+
+class EnabledAxes(NamedTuple):
+ """Toggle zoom for each axis"""
+
+ xaxis: bool = True
+ yaxis: bool = True
+ y2axis: bool = True
+
+ def isDisabled(self) -> bool:
+ """True only if all axes are disabled"""
+ return not (self.xaxis or self.yaxis or self.y2axis)
+
+
+def applyZoomToPlot(
+ plot,
+ scale: float,
+ center: tuple[float, float] = None,
+ enabled: EnabledAxes = EnabledAxes(),
+):
+ """Zoom in/out plot given a scale and a center point.
+
+ :param plot: The plot on which to apply zoom.
+ :param scale: Scale factor of zoom.
+ :param center: (x, y) coords in pixel coordinates of the zoom center.
+ :param enabled: Toggle zoom for each axis independently
+ """
+ xMin, xMax = plot.getXAxis().getLimits()
+ yMin, yMax = plot.getYAxis().getLimits()
+ y2Min, y2Max = plot.getYAxis(axis="right").getLimits()
+
+ if center is None:
+ left, top, width, height = plot.getPlotBoundsInPixels()
+ cx, cy = left + width // 2, top + height // 2
+ else:
+ cx, cy = center
+
+ dataCenterPos = plot.pixelToData(cx, cy)
+ assert dataCenterPos is not None
+
+ if enabled.xaxis:
+ xMin, xMax = scale1DRange(
+ xMin, xMax, dataCenterPos[0], scale, plot.getXAxis()._isLogarithmic()
+ )
+
+ if enabled.yaxis:
+ yMin, yMax = scale1DRange(
+ yMin, yMax, dataCenterPos[1], scale, plot.getYAxis()._isLogarithmic()
+ )
+
+ if enabled.y2axis:
+ dataPos = plot.pixelToData(cx, cy, axis="right")
+ assert dataPos is not None
+ y2Center = dataPos[1]
+ y2Min, y2Max = scale1DRange(
+ y2Min, y2Max, y2Center, scale, plot.getYAxis()._isLogarithmic()
+ )
+
+ plot.setLimits(xMin, xMax, yMin, yMax, y2Min, y2Max)
+
+
+def applyPan(min_, max_, panFactor, isLog10):
+ """Returns a new range with applied panning.
+
+ Moves the range according to panFactor.
+ If isLog10 is True, converts to log10 before moving.
+
+ :param float min_: Min value of the data range to pan.
+ :param float max_: Max value of the data range to pan.
+ Must be >= min.
+ :param float panFactor: Signed proportion of the range to use for pan.
+ :param bool isLog10: True if log10 scale, False if linear scale.
+ :return: New min and max value with pan applied.
+ :rtype: 2-tuple of float.
+ """
+ if isLog10 and min_ > 0.0:
+ # Negative range and log scale can happen with matplotlib
+ logMin, logMax = math.log10(min_), math.log10(max_)
+ logOffset = panFactor * (logMax - logMin)
+ newMin = pow(10.0, logMin + logOffset)
+ newMax = pow(10.0, logMax + logOffset)
+
+ # Takes care of out-of-range values
+ if newMin > 0.0 and newMax < float("inf"):
+ min_, max_ = newMin, newMax
+
+ else:
+ offset = panFactor * (max_ - min_)
+ newMin, newMax = min_ + offset, max_ + offset
+
+ # Takes care of out-of-range values
+ if newMin > -float("inf") and newMax < float("inf"):
+ min_, max_ = newMin, newMax
+ return min_, max_
+
+
+class _Unset(object):
+ """To be able to have distinction between None and unset"""
+
+ pass
+
+
+class ViewConstraints(object):
+ """
+ Store constraints applied on the view box and compute the resulting view box.
+ """
+
+ def __init__(self):
+ self._min = [None, None]
+ self._max = [None, None]
+ self._minRange = [None, None]
+ self._maxRange = [None, None]
+
+ def update(
+ self,
+ xMin=_Unset,
+ xMax=_Unset,
+ yMin=_Unset,
+ yMax=_Unset,
+ minXRange=_Unset,
+ maxXRange=_Unset,
+ minYRange=_Unset,
+ maxYRange=_Unset,
+ ):
+ """
+ Update the constraints managed by the object
+
+ The constraints are the same as the ones provided by PyQtGraph.
+
+ :param float xMin: Minimum allowed x-axis value.
+ (default do not change the stat, None remove the constraint)
+ :param float xMax: Maximum allowed x-axis value.
+ (default do not change the stat, None remove the constraint)
+ :param float yMin: Minimum allowed y-axis value.
+ (default do not change the stat, None remove the constraint)
+ :param float yMax: Maximum allowed y-axis value.
+ (default do not change the stat, None remove the constraint)
+ :param float minXRange: Minimum allowed left-to-right span across the
+ view (default do not change the stat, None remove the constraint)
+ :param float maxXRange: Maximum allowed left-to-right span across the
+ view (default do not change the stat, None remove the constraint)
+ :param float minYRange: Minimum allowed top-to-bottom span across the
+ view (default do not change the stat, None remove the constraint)
+ :param float maxYRange: Maximum allowed top-to-bottom span across the
+ view (default do not change the stat, None remove the constraint)
+ :return: True if the constraints was changed
+ """
+ updated = False
+
+ minRange = [minXRange, minYRange]
+ maxRange = [maxXRange, maxYRange]
+ minPos = [xMin, yMin]
+ maxPos = [xMax, yMax]
+
+ for axis in range(2):
+ value = minPos[axis]
+ if value is not _Unset and value != self._min[axis]:
+ self._min[axis] = value
+ updated = True
+
+ value = maxPos[axis]
+ if value is not _Unset and value != self._max[axis]:
+ self._max[axis] = value
+ updated = True
+
+ value = minRange[axis]
+ if value is not _Unset and value != self._minRange[axis]:
+ self._minRange[axis] = value
+ updated = True
+
+ value = maxRange[axis]
+ if value is not _Unset and value != self._maxRange[axis]:
+ self._maxRange[axis] = value
+ updated = True
+
+ # Sanity checks
+
+ for axis in range(2):
+ if (
+ self._maxRange[axis] is not None
+ and self._min[axis] is not None
+ and self._max[axis] is not None
+ ):
+ # max range cannot be larger than bounds
+ diff = self._max[axis] - self._min[axis]
+ self._maxRange[axis] = min(self._maxRange[axis], diff)
+ updated = True
+
+ return updated
+
+ def normalize(self, xMin, xMax, yMin, yMax, allow_scaling=True):
+ """Normalize a view range defined by x and y corners using predefined
+ containts.
+
+ :param float xMin: Min position of the x-axis
+ :param float xMax: Max position of the x-axis
+ :param float yMin: Min position of the y-axis
+ :param float yMax: Max position of the y-axis
+ :param bool allow_scaling: Allow or not to apply scaling for the
+ normalization. Used according to the interaction mode.
+ :return: A normalized tuple of (xMin, xMax, yMin, yMax)
+ """
+ viewRange = [[xMin, xMax], [yMin, yMax]]
+
+ for axis in range(2):
+ # clamp xRange and yRange
+ if allow_scaling:
+ diff = viewRange[axis][1] - viewRange[axis][0]
+ delta = None
+ if self._maxRange[axis] is not None and diff > self._maxRange[axis]:
+ delta = self._maxRange[axis] - diff
+ elif self._minRange[axis] is not None and diff < self._minRange[axis]:
+ delta = self._minRange[axis] - diff
+ if delta is not None:
+ viewRange[axis][0] -= delta * 0.5
+ viewRange[axis][1] += delta * 0.5
+
+ # clamp min and max positions
+ outMin = (
+ self._min[axis] is not None and viewRange[axis][0] < self._min[axis]
+ )
+ outMax = (
+ self._max[axis] is not None and viewRange[axis][1] > self._max[axis]
+ )
+
+ if outMin and outMax:
+ if allow_scaling:
+ # we can clamp both sides
+ viewRange[axis][0] = self._min[axis]
+ viewRange[axis][1] = self._max[axis]
+ else:
+ # center the result
+ delta = viewRange[axis][1] - viewRange[axis][0]
+ mid = self._min[axis] + self._max[axis] - self._min[axis]
+ viewRange[axis][0] = mid - delta
+ viewRange[axis][1] = mid + delta
+ elif outMin:
+ delta = self._min[axis] - viewRange[axis][0]
+ viewRange[axis][0] += delta
+ viewRange[axis][1] += delta
+ elif outMax:
+ delta = self._max[axis] - viewRange[axis][1]
+ viewRange[axis][0] += delta
+ viewRange[axis][1] += delta
+
+ return viewRange[0][0], viewRange[0][1], viewRange[1][0], viewRange[1][1]
diff --git a/src/silx/gui/plot/_utils/test/__init__.py b/src/silx/gui/plot/_utils/test/__init__.py
new file mode 100644
index 0000000..78821ec
--- /dev/null
+++ b/src/silx/gui/plot/_utils/test/__init__.py
@@ -0,0 +1,23 @@
+# /*##########################################################################
+#
+# Copyright (c) 2016-2018 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
diff --git a/src/silx/gui/plot/_utils/test/test_dtime_ticklayout.py b/src/silx/gui/plot/_utils/test/test_dtime_ticklayout.py
new file mode 100644
index 0000000..adcb9c9
--- /dev/null
+++ b/src/silx/gui/plot/_utils/test/test_dtime_ticklayout.py
@@ -0,0 +1,87 @@
+# /*##########################################################################
+#
+# Copyright (c) 2015-2022 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+
+__authors__ = ["P. Kenter"]
+__license__ = "MIT"
+__date__ = "06/04/2018"
+
+
+import datetime as dt
+import pytest
+
+
+from silx.gui.plot._utils.dtime_ticklayout import calcTicks, DtUnit, SECONDS_PER_YEAR
+
+
+def testSmallMonthlySpacing():
+ """Tests a range that did result in a spacing of less than 1 month.
+ It is impossible to add fractional month so the unit must be in days
+ """
+ from dateutil import parser
+
+ d1 = parser.parse("2017-01-03 13:15:06.000044")
+ d2 = parser.parse("2017-03-08 09:16:16.307584")
+ _ticks, _units, spacing = calcTicks(d1, d2, nTicks=4)
+
+ assert spacing == DtUnit.DAYS
+
+
+def testNoCrash():
+ """Creates many combinations of and number-of-ticks and end-dates;
+ tests that it doesn't give an exception and returns a reasonable number
+ of ticks.
+ """
+ d1 = dt.datetime(2017, 1, 3, 13, 15, 6, 44)
+
+ value = 100e-6 # Start at 100 micro sec range.
+
+ while value <= 200 * SECONDS_PER_YEAR:
+ d2 = d1 + dt.timedelta(microseconds=value * 1e6) # end date range
+
+ for numTicks in range(2, 12):
+ ticks, _, _ = calcTicks(d1, d2, numTicks)
+
+ margin = 2.5
+ assert (
+ numTicks / margin <= len(ticks) <= numTicks * margin
+ ), "Condition {} <= {} <= {} failed for # ticks={} and d2={}:".format(
+ numTicks / margin, len(ticks), numTicks * margin, numTicks, d2
+ )
+
+ value = value * 1.5 # let date period grow exponentially
+
+
+@pytest.mark.parametrize(
+ "dMin, dMax",
+ [
+ (dt.datetime(1, 1, 1), dt.datetime(400, 1, 1)),
+ (dt.datetime(4000, 1, 1), dt.datetime(9999, 1, 1)),
+ (dt.datetime(1, 1, 1), dt.datetime(9999, 12, 23)),
+ ],
+)
+def testCalcTicksOutOfBoundTicks(dMin, dMax):
+ """Test tick generation with values leading to out-of-bound ticks"""
+ ticks, _, unit = calcTicks(dMin, dMax, nTicks=5)
+ assert len(ticks) != 0
+ assert unit == DtUnit.YEARS
diff --git a/src/silx/gui/plot/_utils/test/test_ticklayout.py b/src/silx/gui/plot/_utils/test/test_ticklayout.py
new file mode 100644
index 0000000..1413563
--- /dev/null
+++ b/src/silx/gui/plot/_utils/test/test_ticklayout.py
@@ -0,0 +1,77 @@
+# /*##########################################################################
+#
+# Copyright (c) 2015-2017 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+
+__authors__ = ["T. Vincent"]
+__license__ = "MIT"
+__date__ = "17/01/2018"
+
+
+import numpy
+
+from silx.utils.testutils import ParametricTestCase
+
+from silx.gui.plot._utils import ticklayout
+
+
+class TestTickLayout(ParametricTestCase):
+ """Test ticks layout algorithms"""
+
+ def testTicks(self):
+ """Test of :func:`ticks`"""
+ tests = { # (vmin, vmax): ref_ticks
+ (1.0, 1.0): (1.0,),
+ (0.5, 10.5): (2.0, 4.0, 6.0, 8.0, 10.0),
+ (0.001, 0.005): (0.001, 0.002, 0.003, 0.004, 0.005),
+ }
+
+ for (vmin, vmax), ref_ticks in tests.items():
+ with self.subTest(vmin=vmin, vmax=vmax):
+ ticks, labels = ticklayout.ticks(vmin, vmax)
+ self.assertTrue(numpy.allclose(ticks, ref_ticks))
+
+ def testNiceNumbers(self):
+ """Minimalistic tests of :func:`niceNumbers`"""
+ tests = { # (vmin, vmax): ref_ticks
+ (0.5, 10.5): (0.0, 12.0, 2.0, 0),
+ (10000.0, 10000.5): (10000.0, 10000.5, 0.1, 1),
+ (0.001, 0.005): (0.001, 0.005, 0.001, 3),
+ }
+
+ for (vmin, vmax), ref_ticks in tests.items():
+ with self.subTest(vmin=vmin, vmax=vmax):
+ ticks = ticklayout.niceNumbers(vmin, vmax)
+ self.assertEqual(ticks, ref_ticks)
+
+ def testNiceNumbersLog(self):
+ """Minimalistic tests of :func:`niceNumbersForLog10`"""
+ tests = { # (log10(min), log10(max): ref_ticks
+ (0.0, 3.0): (0, 3, 1, 0),
+ (-3.0, 3): (-3, 3, 1, 0),
+ (-32.0, 0.0): (-36, 0, 6, 0),
+ }
+
+ for (vmin, vmax), ref_ticks in tests.items():
+ with self.subTest(vmin=vmin, vmax=vmax):
+ ticks = ticklayout.niceNumbersForLog10(vmin, vmax)
+ self.assertEqual(ticks, ref_ticks)
diff --git a/src/silx/gui/plot/_utils/ticklayout.py b/src/silx/gui/plot/_utils/ticklayout.py
new file mode 100644
index 0000000..3678270
--- /dev/null
+++ b/src/silx/gui/plot/_utils/ticklayout.py
@@ -0,0 +1,266 @@
+# /*##########################################################################
+#
+# Copyright (c) 2014-2018 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""This module implements labels layout on graph axes."""
+
+__authors__ = ["T. Vincent"]
+__license__ = "MIT"
+__date__ = "18/10/2016"
+
+
+import math
+
+
+# utils #######################################################################
+
+
+def numberOfDigits(tickSpacing):
+ """Returns the number of digits to display for text label.
+
+ :param float tickSpacing: Step between ticks in data space.
+ :return: Number of digits to show for labels.
+ :rtype: int
+ """
+ nfrac = int(-math.floor(math.log10(tickSpacing)))
+ if nfrac < 0:
+ nfrac = 0
+ return nfrac
+
+
+# Nice Numbers ################################################################
+
+# This is the original niceNum implementation. For the date time ticks a more
+# generic implementation was needed.
+#
+# def _niceNum(value, isRound=False):
+# expvalue = math.floor(math.log10(value))
+# frac = value/pow(10., expvalue)
+# if isRound:
+# if frac < 1.5:
+# nicefrac = 1.
+# elif frac < 3.: # In niceNumGeneric this is (2+5)/2 = 3.5
+# nicefrac = 2.
+# elif frac < 7.:
+# nicefrac = 5. # In niceNumGeneric this is (5+10)/2 = 7.5
+# else:
+# nicefrac = 10.
+# else:
+# if frac <= 1.:
+# nicefrac = 1.
+# elif frac <= 2.:
+# nicefrac = 2.
+# elif frac <= 5.:
+# nicefrac = 5.
+# else:
+# nicefrac = 10.
+# return nicefrac * pow(10., expvalue)
+
+
+def niceNumGeneric(value, niceFractions=None, isRound=False):
+ """A more generic implementation of the _niceNum function
+
+ Allows the user to specify the fractions instead of using a hardcoded
+ list of [1, 2, 5, 10.0].
+ """
+ if value == 0:
+ return value
+
+ if niceFractions is None: # Use default values
+ niceFractions = 1.0, 2.0, 5.0, 10.0
+ roundFractions = (1.5, 3.0, 7.0, 10.0) if isRound else niceFractions
+
+ else:
+ roundFractions = list(niceFractions)
+ if isRound:
+ # Take the average with the next element. The last remains the same.
+ for i in range(len(roundFractions) - 1):
+ roundFractions[i] = (niceFractions[i] + niceFractions[i + 1]) / 2
+
+ highest = niceFractions[-1]
+ value = float(value)
+
+ expvalue = math.floor(math.log(value, highest))
+ frac = value / pow(highest, expvalue)
+
+ for niceFrac, roundFrac in zip(niceFractions, roundFractions):
+ if frac <= roundFrac:
+ return niceFrac * pow(highest, expvalue)
+
+ # should not come here
+ assert False, "should not come here"
+
+
+def niceNumbers(vMin, vMax, nTicks=5):
+ """Returns tick positions.
+
+ This function implements graph labels layout using nice numbers
+ by Paul Heckbert from "Graphics Gems", Academic Press, 1990.
+ See `C code <http://tog.acm.org/resources/GraphicsGems/gems/Label.c>`_.
+
+ :param float vMin: The min value on the axis
+ :param float vMax: The max value on the axis
+ :param int nTicks: The number of ticks to position
+ :returns: min, max, increment value of tick positions and
+ number of fractional digit to show
+ :rtype: tuple
+ """
+ vrange = niceNumGeneric(vMax - vMin, isRound=False)
+ spacing = niceNumGeneric(vrange / nTicks, isRound=True)
+ graphmin = math.floor(vMin / spacing) * spacing
+ graphmax = math.ceil(vMax / spacing) * spacing
+ nfrac = numberOfDigits(spacing)
+ return graphmin, graphmax, spacing, nfrac
+
+
+def _frange(start, stop, step):
+ """range for float (including stop)."""
+ assert step >= 0.0
+ while start <= stop:
+ yield start
+ start += step
+
+
+def ticks(vMin, vMax, nbTicks=5):
+ """Returns tick positions and labels using nice numbers algorithm.
+
+ This enforces ticks to be within [vMin, vMax] range.
+ It returns at least 1 tick (when vMin == vMax).
+
+ :param float vMin: The min value on the axis
+ :param float vMax: The max value on the axis
+ :param int nbTicks: The number of ticks to position
+ :returns: tick positions and corresponding text labels
+ :rtype: 2-tuple: list of float, list of string
+ """
+ assert vMin <= vMax
+ if vMin == vMax:
+ positions = [vMin]
+ nfrac = 0
+
+ else:
+ start, end, step, nfrac = niceNumbers(vMin, vMax, nbTicks)
+ positions = [t for t in _frange(start, end, step) if vMin <= t <= vMax]
+
+ # Makes sure there is at least 2 ticks
+ if len(positions) < 2:
+ positions = [vMin, vMax]
+ nfrac = numberOfDigits(vMax - vMin)
+
+ # Generate labels
+ format_ = "%g" if nfrac == 0 else "%.{}f".format(nfrac)
+ labels = [format_ % tick for tick in positions]
+ return positions, labels
+
+
+def niceNumbersAdaptative(vMin, vMax, axisLength, tickDensity):
+ """Returns tick positions using :func:`niceNumbers` and a
+ density of ticks.
+
+ axisLength and tickDensity are based on the same unit (e.g., pixel).
+
+ :param float vMin: The min value on the axis
+ :param float vMax: The max value on the axis
+ :param float axisLength: The length of the axis.
+ :param float tickDensity: The density of ticks along the axis.
+ :returns: min, max, increment value of tick positions and
+ number of fractional digit to show
+ :rtype: tuple
+ """
+ # At least 2 ticks
+ nticks = max(2, int(round(tickDensity * axisLength)))
+ tickmin, tickmax, step, nfrac = niceNumbers(vMin, vMax, nticks)
+
+ return tickmin, tickmax, step, nfrac
+
+
+# Nice Numbers for log scale ##################################################
+
+
+def niceNumbersForLog10(minLog, maxLog, nTicks=5):
+ """Return tick positions for logarithmic scale
+
+ :param float minLog: log10 of the min value on the axis
+ :param float maxLog: log10 of the max value on the axis
+ :param int nTicks: The number of ticks to position
+ :returns: log10 of min, max, increment value of tick positions and
+ number of fractional digit to show
+ :rtype: tuple of int
+ """
+ graphminlog = math.floor(minLog)
+ graphmaxlog = math.ceil(maxLog)
+ rangelog = graphmaxlog - graphminlog
+
+ if rangelog <= nTicks:
+ spacing = 1.0
+ else:
+ spacing = math.floor(rangelog / nTicks)
+
+ graphminlog = math.floor(graphminlog / spacing) * spacing
+ graphmaxlog = math.ceil(graphmaxlog / spacing) * spacing
+
+ nfrac = numberOfDigits(spacing)
+
+ return int(graphminlog), int(graphmaxlog), int(spacing), nfrac
+
+
+def niceNumbersAdaptativeForLog10(vMin, vMax, axisLength, tickDensity):
+ """Returns tick positions using :func:`niceNumbers` and a
+ density of ticks.
+
+ axisLength and tickDensity are based on the same unit (e.g., pixel).
+
+ :param float vMin: The min value on the axis
+ :param float vMax: The max value on the axis
+ :param float axisLength: The length of the axis.
+ :param float tickDensity: The density of ticks along the axis.
+ :returns: log10 of min, max, increment value of tick positions and
+ number of fractional digit to show
+ :rtype: tuple
+ """
+ # At least 2 ticks
+ nticks = max(2, int(round(tickDensity * axisLength)))
+ tickmin, tickmax, step, nfrac = niceNumbersForLog10(vMin, vMax, nticks)
+
+ return tickmin, tickmax, step, nfrac
+
+
+def computeLogSubTicks(ticks, lowBound, highBound):
+ """Return the sub ticks for the log scale for all given ticks if subtick
+ is in [lowBound, highBound]
+
+ :param ticks: log10 of the ticks
+ :param lowBound: the lower boundary of ticks
+ :param highBound: the higher boundary of ticks
+ :return: all the sub ticks contained in ticks (log10)
+ """
+ if len(ticks) < 1:
+ return []
+
+ res = []
+ for logPos in ticks:
+ dataOrigPos = logPos
+ for index in range(2, 10):
+ dataPos = dataOrigPos * index
+ if lowBound <= dataPos <= highBound:
+ res.append(dataPos)
+ return res
diff --git a/src/silx/gui/plot/actions/PlotAction.py b/src/silx/gui/plot/actions/PlotAction.py
new file mode 100644
index 0000000..9341bdd
--- /dev/null
+++ b/src/silx/gui/plot/actions/PlotAction.py
@@ -0,0 +1,84 @@
+# /*##########################################################################
+#
+# Copyright (c) 2004-2017 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""
+The class :class:`.PlotAction` help the creation of a qt.QAction associated
+with a :class:`.PlotWidget`.
+"""
+
+__authors__ = ["V.A. Sole", "T. Vincent", "P. Knobel"]
+__license__ = "MIT"
+__date__ = "03/01/2018"
+
+
+from typing import Callable, Optional, Union
+import weakref
+from silx.gui import icons
+from silx.gui import qt
+from silx.gui.plot import PlotWidget
+
+
+class PlotAction(qt.QAction):
+ """Base class for QAction that operates on a PlotWidget.
+
+ :param plot: :class:`.PlotWidget` instance on which to operate.
+ :param icon: QIcon or name of icon to use
+ :param text: The name of this action to be used for menu label
+ :param tooltip: The text of the tooltip
+ :param triggered: The callback to connect to the action's triggered
+ signal. None for no callback (default)
+ :param checkable: True for checkable action, False otherwise (default)
+ :param parent: See :class:`QAction`.
+ """
+
+ def __init__(
+ self,
+ plot: PlotWidget,
+ icon: Union[str, qt.QIcon],
+ text: str,
+ tooltip: Optional[str] = None,
+ triggered: Optional[Callable] = None,
+ checkable: bool = False,
+ parent: Optional[qt.QObject] = None,
+ ):
+ assert plot is not None
+ self._plotRef = weakref.ref(plot)
+
+ if not isinstance(icon, qt.QIcon):
+ # Try with icon as a string and load corresponding icon
+ icon = icons.getQIcon(icon)
+
+ super(PlotAction, self).__init__(icon, text, parent)
+
+ if tooltip is not None:
+ self.setToolTip(tooltip)
+
+ self.setCheckable(checkable)
+
+ if triggered is not None:
+ self.triggered[bool].connect(triggered)
+
+ @property
+ def plot(self):
+ """The :class:`.PlotWidget` this action group is controlling."""
+ return self._plotRef()
diff --git a/src/silx/gui/plot/actions/PlotToolAction.py b/src/silx/gui/plot/actions/PlotToolAction.py
new file mode 100644
index 0000000..479d7c2
--- /dev/null
+++ b/src/silx/gui/plot/actions/PlotToolAction.py
@@ -0,0 +1,156 @@
+# /*##########################################################################
+#
+# Copyright (c) 2004-2020 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""
+The class :class:`.PlotToolAction` help the creation of a qt.QAction associating
+a tool window with a :class:`.PlotWidget`.
+"""
+
+__authors__ = ["V.A. Sole", "T. Vincent", "P. Knobel"]
+__license__ = "MIT"
+__date__ = "10/10/2018"
+
+
+import weakref
+
+from .PlotAction import PlotAction
+from silx.gui import qt
+
+
+class PlotToolAction(PlotAction):
+ """Base class for QAction that maintain a tool window operating on a
+ PlotWidget."""
+
+ def __init__(
+ self,
+ plot,
+ icon,
+ text,
+ tooltip=None,
+ triggered=None,
+ checkable=False,
+ parent=None,
+ ):
+ PlotAction.__init__(
+ self,
+ plot=plot,
+ icon=icon,
+ text=text,
+ tooltip=tooltip,
+ triggered=self._triggered,
+ parent=parent,
+ checkable=True,
+ )
+ self._previousGeometry = None
+ self._toolWindow = None
+
+ def _triggered(self, checked):
+ """Update the plot of the histogram visibility status
+
+ :param bool checked: status of the action button
+ """
+ self._setToolWindowVisible(checked)
+
+ def _setToolWindowVisible(self, visible):
+ """Set the tool window visible or hidden."""
+ tool = self._getToolWindow()
+ if tool.isVisible() == visible:
+ # Nothing to do
+ return
+
+ if visible:
+ self._connectPlot(tool)
+ tool.show()
+ if self._previousGeometry is not None:
+ # Restore the geometry
+ tool.setGeometry(self._previousGeometry)
+ else:
+ self._disconnectPlot(tool)
+ # Save the geometry
+ self._previousGeometry = tool.geometry()
+ tool.hide()
+
+ def _connectPlot(self, window):
+ """Called if the tool is visible and have to be updated according to
+ event of the plot.
+
+ :param qt.QWidget window: The tool window
+ """
+ pass
+
+ def _disconnectPlot(self, window):
+ """Called if the tool is not visible and dont have anymore to be updated
+ according to event of the plot.
+
+ :param qt.QWidget window: The tool window
+ """
+ pass
+
+ def _isWindowInUse(self):
+ """Returns true if the tool window is currently in use."""
+ if not self.isChecked():
+ return False
+ return self._toolWindow is not None
+
+ def _ownerVisibilityChanged(self, isVisible):
+ """Called when the visibility of the parent of the tool window changes
+
+ :param bool isVisible: True if the parent became visible
+ """
+ if self._isWindowInUse():
+ self._setToolWindowVisible(isVisible)
+
+ def eventFilter(self, qobject, event):
+ """Observe when the close event is emitted then
+ simply uncheck the action button
+
+ :param qobject: the object observe
+ :param event: the event received by qobject
+ """
+ if event.type() == qt.QEvent.Close:
+ if self._toolWindow is not None:
+ window = self._toolWindow()
+ self._previousGeometry = window.geometry()
+ window.hide()
+ self.setChecked(False)
+
+ return PlotAction.eventFilter(self, qobject, event)
+
+ def _getToolWindow(self):
+ """Returns the window containing the tool.
+
+ It uses lazy loading to create this tool..
+ """
+ if self._toolWindow is None:
+ window = self._createToolWindow()
+ if self._previousGeometry is not None:
+ window.setGeometry(self._previousGeometry)
+ window.installEventFilter(self)
+ plot = self.plot
+ plot.sigVisibilityChanged.connect(self._ownerVisibilityChanged)
+ self._toolWindow = weakref.ref(window)
+ return self._toolWindow()
+
+ def _createToolWindow(self):
+ """Create the tool window managing the plot."""
+ raise NotImplementedError()
diff --git a/src/silx/gui/plot/actions/__init__.py b/src/silx/gui/plot/actions/__init__.py
new file mode 100644
index 0000000..3e606c6
--- /dev/null
+++ b/src/silx/gui/plot/actions/__init__.py
@@ -0,0 +1,41 @@
+# /*##########################################################################
+#
+# Copyright (c) 2017-2018 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""This package provides a set of QAction to use with
+:class:`~silx.gui.plot.PlotWidget`
+
+Those actions are useful to add menu items or toolbar items
+that interact with a :class:`~silx.gui.plot.PlotWidget`.
+
+It provides a base class used to define new plot actions:
+:class:`~silx.gui.plot.actions.PlotAction`.
+"""
+
+__authors__ = ["H. Payno"]
+__license__ = "MIT"
+__date__ = "16/08/2017"
+
+from .PlotAction import PlotAction
+from . import control
+from . import mode
+from . import io
diff --git a/src/silx/gui/plot/actions/control.py b/src/silx/gui/plot/actions/control.py
new file mode 100755
index 0000000..c21d235
--- /dev/null
+++ b/src/silx/gui/plot/actions/control.py
@@ -0,0 +1,760 @@
+# /*##########################################################################
+#
+# Copyright (c) 2004-2023 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""
+:mod:`silx.gui.plot.actions.control` provides a set of QAction relative to control
+of a :class:`.PlotWidget`.
+
+The following QAction are available:
+
+- :class:`ColormapAction`
+- :class:`CrosshairAction`
+- :class:`CurveStyleAction`
+- :class:`GridAction`
+- :class:`KeepAspectRatioAction`
+- :class:`PanWithArrowKeysAction`
+- :class:`ResetZoomAction`
+- :class:`ShowAxisAction`
+- :class:`XAxisLogarithmicAction`
+- :class:`XAxisAutoScaleAction`
+- :class:`YAxisInvertedAction`
+- :class:`YAxisLogarithmicAction`
+- :class:`YAxisAutoScaleAction`
+- :class:`ZoomBackAction`
+- :class:`ZoomInAction`
+- :class:`ZoomOutAction`
+"""
+
+__authors__ = ["V.A. Sole", "T. Vincent", "P. Knobel"]
+__license__ = "MIT"
+__date__ = "27/11/2020"
+
+from . import PlotAction
+import logging
+from silx.gui.plot import items
+from silx.gui.plot._utils import applyZoomToPlot as _applyZoomToPlot
+from silx.gui import qt
+from silx.gui import icons
+from silx.utils.deprecation import deprecated
+
+_logger = logging.getLogger(__name__)
+
+
+class ResetZoomAction(PlotAction):
+ """QAction controlling reset zoom on a :class:`.PlotWidget`.
+
+ :param plot: :class:`.PlotWidget` instance on which to operate
+ :param parent: See :class:`QAction`
+ """
+
+ def __init__(self, plot, parent=None):
+ super(ResetZoomAction, self).__init__(
+ plot,
+ icon="zoom-original",
+ text="Reset Zoom",
+ tooltip="Auto-scale the graph",
+ triggered=self._actionTriggered,
+ checkable=False,
+ parent=parent,
+ )
+ self._autoscaleChanged(True)
+ plot.getXAxis().sigAutoScaleChanged.connect(self._autoscaleChanged)
+ plot.getYAxis().sigAutoScaleChanged.connect(self._autoscaleChanged)
+
+ def _autoscaleChanged(self, enabled):
+ xAxis = self.plot.getXAxis()
+ yAxis = self.plot.getYAxis()
+ self.setEnabled(xAxis.isAutoScale() or yAxis.isAutoScale())
+
+ if xAxis.isAutoScale() and yAxis.isAutoScale():
+ tooltip = "Auto-scale the graph"
+ elif xAxis.isAutoScale(): # And not Y axis
+ tooltip = "Auto-scale the x-axis of the graph only"
+ elif yAxis.isAutoScale(): # And not X axis
+ tooltip = "Auto-scale the y-axis of the graph only"
+ else: # no axis in autoscale
+ tooltip = "Auto-scale the graph"
+ self.setToolTip(tooltip)
+
+ def _actionTriggered(self, checked=False):
+ self.plot.resetZoom()
+
+
+class ZoomBackAction(PlotAction):
+ """QAction performing a zoom-back in :class:`.PlotWidget` limits history.
+
+ :param plot: :class:`.PlotWidget` instance on which to operate
+ :param parent: See :class:`QAction`
+ """
+
+ def __init__(self, plot, parent=None):
+ super(ZoomBackAction, self).__init__(
+ plot,
+ icon="zoom-back",
+ text="Zoom Back",
+ tooltip="Zoom back the plot",
+ triggered=self._actionTriggered,
+ checkable=False,
+ parent=parent,
+ )
+ self.setShortcutContext(qt.Qt.WidgetShortcut)
+
+ def _actionTriggered(self, checked=False):
+ self.plot.getLimitsHistory().pop()
+
+
+class ZoomInAction(PlotAction):
+ """QAction performing a zoom-in on a :class:`.PlotWidget`.
+
+ :param plot: :class:`.PlotWidget` instance on which to operate
+ :param parent: See :class:`QAction`
+ """
+
+ def __init__(self, plot, parent=None):
+ super(ZoomInAction, self).__init__(
+ plot,
+ icon="zoom-in",
+ text="Zoom In",
+ tooltip="Zoom in the plot",
+ triggered=self._actionTriggered,
+ checkable=False,
+ parent=parent,
+ )
+ self.setShortcut(qt.QKeySequence.ZoomIn)
+ self.setShortcutContext(qt.Qt.WidgetShortcut)
+
+ def _actionTriggered(self, checked=False):
+ _applyZoomToPlot(self.plot, 1.1)
+
+
+class ZoomOutAction(PlotAction):
+ """QAction performing a zoom-out on a :class:`.PlotWidget`.
+
+ :param plot: :class:`.PlotWidget` instance on which to operate
+ :param parent: See :class:`QAction`
+ """
+
+ def __init__(self, plot, parent=None):
+ super(ZoomOutAction, self).__init__(
+ plot,
+ icon="zoom-out",
+ text="Zoom Out",
+ tooltip="Zoom out the plot",
+ triggered=self._actionTriggered,
+ checkable=False,
+ parent=parent,
+ )
+ self.setShortcut(qt.QKeySequence.ZoomOut)
+ self.setShortcutContext(qt.Qt.WidgetShortcut)
+
+ def _actionTriggered(self, checked=False):
+ _applyZoomToPlot(self.plot, 1.0 / 1.1)
+
+
+class XAxisAutoScaleAction(PlotAction):
+ """QAction controlling X axis autoscale on a :class:`.PlotWidget`.
+
+ :param plot: :class:`.PlotWidget` instance on which to operate
+ :param parent: See :class:`QAction`
+ """
+
+ def __init__(self, plot, parent=None):
+ super(XAxisAutoScaleAction, self).__init__(
+ plot,
+ icon="plot-xauto",
+ text="X Autoscale",
+ tooltip="Enable x-axis auto-scale when checked.\n"
+ "If unchecked, x-axis does not change when reseting zoom.",
+ triggered=self._actionTriggered,
+ checkable=True,
+ parent=parent,
+ )
+ self.setChecked(plot.getXAxis().isAutoScale())
+ plot.getXAxis().sigAutoScaleChanged.connect(self.setChecked)
+
+ def _actionTriggered(self, checked=False):
+ self.plot.getXAxis().setAutoScale(checked)
+ if checked:
+ self.plot.resetZoom()
+
+
+class YAxisAutoScaleAction(PlotAction):
+ """QAction controlling Y axis autoscale on a :class:`.PlotWidget`.
+
+ :param plot: :class:`.PlotWidget` instance on which to operate
+ :param parent: See :class:`QAction`
+ """
+
+ def __init__(self, plot, parent=None):
+ super(YAxisAutoScaleAction, self).__init__(
+ plot,
+ icon="plot-yauto",
+ text="Y Autoscale",
+ tooltip="Enable y-axis auto-scale when checked.\n"
+ "If unchecked, y-axis does not change when reseting zoom.",
+ triggered=self._actionTriggered,
+ checkable=True,
+ parent=parent,
+ )
+ self.setChecked(plot.getYAxis().isAutoScale())
+ plot.getYAxis().sigAutoScaleChanged.connect(self.setChecked)
+
+ def _actionTriggered(self, checked=False):
+ self.plot.getYAxis().setAutoScale(checked)
+ if checked:
+ self.plot.resetZoom()
+
+
+class XAxisLogarithmicAction(PlotAction):
+ """QAction controlling X axis log scale on a :class:`.PlotWidget`.
+
+ :param plot: :class:`.PlotWidget` instance on which to operate
+ :param parent: See :class:`QAction`
+ """
+
+ def __init__(self, plot, parent=None):
+ super(XAxisLogarithmicAction, self).__init__(
+ plot,
+ icon="plot-xlog",
+ text="X Log. scale",
+ tooltip="Logarithmic x-axis when checked",
+ triggered=self._actionTriggered,
+ checkable=True,
+ parent=parent,
+ )
+ self.axis = plot.getXAxis()
+ self.setChecked(self.axis.getScale() == self.axis.LOGARITHMIC)
+ self.axis.sigScaleChanged.connect(self._setCheckedIfLogScale)
+
+ def _setCheckedIfLogScale(self, scale):
+ self.setChecked(scale == self.axis.LOGARITHMIC)
+
+ def _actionTriggered(self, checked=False):
+ scale = self.axis.LOGARITHMIC if checked else self.axis.LINEAR
+ self.axis.setScale(scale)
+
+
+class YAxisLogarithmicAction(PlotAction):
+ """QAction controlling Y axis log scale on a :class:`.PlotWidget`.
+
+ :param plot: :class:`.PlotWidget` instance on which to operate
+ :param parent: See :class:`QAction`
+ """
+
+ def __init__(self, plot, parent=None):
+ super(YAxisLogarithmicAction, self).__init__(
+ plot,
+ icon="plot-ylog",
+ text="Y Log. scale",
+ tooltip="Logarithmic y-axis when checked",
+ triggered=self._actionTriggered,
+ checkable=True,
+ parent=parent,
+ )
+ self.axis = plot.getYAxis()
+ self.setChecked(self.axis.getScale() == self.axis.LOGARITHMIC)
+ self.axis.sigScaleChanged.connect(self._setCheckedIfLogScale)
+
+ def _setCheckedIfLogScale(self, scale):
+ self.setChecked(scale == self.axis.LOGARITHMIC)
+
+ def _actionTriggered(self, checked=False):
+ scale = self.axis.LOGARITHMIC if checked else self.axis.LINEAR
+ self.axis.setScale(scale)
+
+
+class GridAction(PlotAction):
+ """QAction controlling grid mode on a :class:`.PlotWidget`.
+
+ :param plot: :class:`.PlotWidget` instance on which to operate
+ :param str gridMode: The grid mode to use in 'both', 'major'.
+ See :meth:`.PlotWidget.setGraphGrid`
+ :param parent: See :class:`QAction`
+ """
+
+ def __init__(self, plot, gridMode="both", parent=None):
+ assert gridMode in ("both", "major")
+ self._gridMode = gridMode
+
+ super(GridAction, self).__init__(
+ plot,
+ icon="plot-grid",
+ text="Grid",
+ tooltip="Toggle grid (on/off)",
+ triggered=self._actionTriggered,
+ checkable=True,
+ parent=parent,
+ )
+ self.setChecked(plot.getGraphGrid() is not None)
+ plot.sigSetGraphGrid.connect(self._gridChanged)
+
+ def _gridChanged(self, which):
+ """Slot listening for PlotWidget grid mode change."""
+ self.setChecked(which != "None")
+
+ def _actionTriggered(self, checked=False):
+ self.plot.setGraphGrid(self._gridMode if checked else None)
+
+
+class CurveStyleAction(PlotAction):
+ """QAction controlling curve style on a :class:`.PlotWidget`.
+
+ It changes the default line and markers style which updates all
+ curves on the plot.
+
+ :param plot: :class:`.PlotWidget` instance on which to operate
+ :param parent: See :class:`QAction`
+ """
+
+ def __init__(self, plot, parent=None):
+ super(CurveStyleAction, self).__init__(
+ plot,
+ icon="plot-toggle-points",
+ text="Curve style",
+ tooltip="Change curve line and markers style",
+ triggered=self._actionTriggered,
+ checkable=False,
+ parent=parent,
+ )
+
+ def _actionTriggered(self, checked=False):
+ currentState = (self.plot.isDefaultPlotLines(), self.plot.isDefaultPlotPoints())
+
+ if currentState == (False, False):
+ newState = True, False
+ else:
+ # line only, line and symbol, symbol only
+ states = (True, False), (True, True), (False, True)
+ newState = states[(states.index(currentState) + 1) % 3]
+
+ self.plot.setDefaultPlotLines(newState[0])
+ self.plot.setDefaultPlotPoints(newState[1])
+
+
+class ColormapAction(PlotAction):
+ """QAction opening a ColormapDialog to update the colormap.
+
+ Both the active image colormap and the default colormap are updated.
+
+ :param plot: :class:`.PlotWidget` instance on which to operate
+ :param parent: See :class:`QAction`
+ """
+
+ def __init__(self, plot, parent=None):
+ self._dialog = None # To store an instance of ColormapDialog
+ super(ColormapAction, self).__init__(
+ plot,
+ icon="colormap",
+ text="Colormap",
+ tooltip="Change colormap",
+ triggered=self._actionTriggered,
+ checkable=True,
+ parent=parent,
+ )
+ self.plot.sigActiveImageChanged.connect(self._updateColormap)
+ self.plot.sigActiveScatterChanged.connect(self._updateColormap)
+
+ def setColormapDialog(self, dialog):
+ """Set a specific colormap dialog instead of using the default one."""
+ assert dialog is not None
+ if self._dialog is not None:
+ self._dialog.visibleChanged.disconnect(self._dialogVisibleChanged)
+
+ self._dialog = dialog
+ self._dialog.visibleChanged.connect(
+ self._dialogVisibleChanged, qt.Qt.UniqueConnection
+ )
+ self.setChecked(self._dialog.isVisible())
+
+ @deprecated(replacement="setColormapDialog", since_version="2.0")
+ def setColorDialog(self, colorDialog):
+ self.setColormapDialog(colorDialog)
+
+ def getColormapDialog(self):
+ if self._dialog is None:
+ self._dialog = self._createDialog(self.plot)
+ self._dialog.visibleChanged.connect(self._dialogVisibleChanged)
+ return self._dialog
+
+ @staticmethod
+ def _createDialog(parent):
+ """Create the dialog if not already existing
+
+ :parent QWidget parent: Parent of the new colormap
+ :rtype: ColormapDialog
+ """
+ from silx.gui.dialog.ColormapDialog import ColormapDialog
+
+ dialog = ColormapDialog(parent=parent)
+ dialog.setModal(False)
+ return dialog
+
+ def _actionTriggered(self, checked=False):
+ """Create a cmap dialog and update active image and default cmap."""
+ dialog = self.getColormapDialog()
+ # Run the dialog listening to colormap change
+ if checked is True:
+ self._updateColormap()
+ dialog.show()
+ else:
+ dialog.hide()
+
+ def _dialogVisibleChanged(self, isVisible):
+ self.setChecked(isVisible)
+
+ def _updateColormap(self):
+ if self._dialog is None:
+ return
+ image = self.plot.getActiveImage()
+
+ if isinstance(image, items.ColormapMixIn):
+ # Set dialog from active image
+ colormap = image.getColormap()
+ # Set histogram and range if any
+ self._dialog.setItem(image)
+
+ else:
+ # No active image or active image is RGBA,
+ # Check for active scatter plot
+ scatter = self.plot.getActiveScatter()
+ if scatter is not None:
+ colormap = scatter.getColormap()
+ self._dialog.setItem(scatter)
+
+ else:
+ # No active data image nor scatter,
+ # set dialog from default info
+ colormap = self.plot.getDefaultColormap()
+ # Reset histogram and range if any
+ self._dialog.setData(None)
+
+ self._dialog.setColormap(colormap)
+
+
+class ColorBarAction(PlotAction):
+ """QAction opening the ColorBarWidget of the specified plot.
+
+ :param plot: :class:`.PlotWidget` instance on which to operate
+ :param parent: See :class:`QAction`
+ """
+
+ def __init__(self, plot, parent=None):
+ self._dialog = None # To store an instance of ColorBar
+ super(ColorBarAction, self).__init__(
+ plot,
+ icon="colorbar",
+ text="Colorbar",
+ tooltip="Show/Hide the colorbar",
+ triggered=self._actionTriggered,
+ checkable=True,
+ parent=parent,
+ )
+ colorBarWidget = self.plot.getColorBarWidget()
+ old = self.blockSignals(True)
+ self.setChecked(colorBarWidget.isVisibleTo(self.plot))
+ self.blockSignals(old)
+ colorBarWidget.sigVisibleChanged.connect(self._widgetVisibleChanged)
+
+ def _widgetVisibleChanged(self, isVisible):
+ """Callback when the colorbar `visible` property change."""
+ if self.isChecked() == isVisible:
+ return
+ self.setChecked(isVisible)
+
+ def _actionTriggered(self, checked=False):
+ """Create a cmap dialog and update active image and default cmap."""
+ colorBarWidget = self.plot.getColorBarWidget()
+ if not colorBarWidget.isHidden() == checked:
+ return
+ self.plot.getColorBarWidget().setVisible(checked)
+
+
+class KeepAspectRatioAction(PlotAction):
+ """QAction controlling aspect ratio on a :class:`.PlotWidget`.
+
+ :param plot: :class:`.PlotWidget` instance on which to operate
+ :param parent: See :class:`QAction`
+ """
+
+ def __init__(self, plot, parent=None):
+ # Uses two images for checked/unchecked states
+ self._states = {
+ False: (icons.getQIcon("shape-circle-solid"), "Keep data aspect ratio"),
+ True: (
+ icons.getQIcon("shape-ellipse-solid"),
+ "Do no keep data aspect ratio",
+ ),
+ }
+
+ icon, tooltip = self._states[plot.isKeepDataAspectRatio()]
+ super(KeepAspectRatioAction, self).__init__(
+ plot,
+ icon=icon,
+ text="Toggle keep aspect ratio",
+ tooltip=tooltip,
+ triggered=self._actionTriggered,
+ checkable=False,
+ parent=parent,
+ )
+ plot.sigSetKeepDataAspectRatio.connect(self._keepDataAspectRatioChanged)
+
+ def _keepDataAspectRatioChanged(self, aspectRatio):
+ """Handle Plot set keep aspect ratio signal"""
+ icon, tooltip = self._states[aspectRatio]
+ self.setIcon(icon)
+ self.setToolTip(tooltip)
+
+ def _actionTriggered(self, checked=False):
+ # This will trigger _keepDataAspectRatioChanged
+ self.plot.setKeepDataAspectRatio(not self.plot.isKeepDataAspectRatio())
+
+
+class YAxisInvertedAction(PlotAction):
+ """QAction controlling Y orientation on a :class:`.PlotWidget`.
+
+ :param plot: :class:`.PlotWidget` instance on which to operate
+ :param parent: See :class:`QAction`
+ """
+
+ def __init__(self, plot, parent=None):
+ # Uses two images for checked/unchecked states
+ self._states = {
+ False: (icons.getQIcon("plot-ydown"), "Orient Y axis downward"),
+ True: (icons.getQIcon("plot-yup"), "Orient Y axis upward"),
+ }
+
+ icon, tooltip = self._states[plot.getYAxis().isInverted()]
+ super(YAxisInvertedAction, self).__init__(
+ plot,
+ icon=icon,
+ text="Invert Y Axis",
+ tooltip=tooltip,
+ triggered=self._actionTriggered,
+ checkable=False,
+ parent=parent,
+ )
+ plot.getYAxis().sigInvertedChanged.connect(self._yAxisInvertedChanged)
+
+ def _yAxisInvertedChanged(self, inverted):
+ """Handle Plot set y axis inverted signal"""
+ icon, tooltip = self._states[inverted]
+ self.setIcon(icon)
+ self.setToolTip(tooltip)
+
+ def _actionTriggered(self, checked=False):
+ # This will trigger _yAxisInvertedChanged
+ yAxis = self.plot.getYAxis()
+ yAxis.setInverted(not yAxis.isInverted())
+
+
+class CrosshairAction(PlotAction):
+ """QAction toggling crosshair cursor on a :class:`.PlotWidget`.
+
+ :param plot: :class:`.PlotWidget` instance on which to operate
+ :param str color: Color to use to draw the crosshair
+ :param int linewidth: Width of the crosshair cursor
+ :param str linestyle: Style of line. See :meth:`.Plot.setGraphCursor`
+ :param parent: See :class:`QAction`
+ """
+
+ def __init__(self, plot, color="black", linewidth=1, linestyle="-", parent=None):
+ self.color = color
+ """Color used to draw the crosshair (str)."""
+
+ self.linewidth = linewidth
+ """Width of the crosshair cursor (int)."""
+
+ self.linestyle = linestyle
+ """Style of line of the cursor (str)."""
+
+ super(CrosshairAction, self).__init__(
+ plot,
+ icon="crosshair",
+ text="Crosshair Cursor",
+ tooltip="Enable crosshair cursor when checked",
+ triggered=self._actionTriggered,
+ checkable=True,
+ parent=parent,
+ )
+ self.setChecked(plot.getGraphCursor() is not None)
+ plot.sigSetGraphCursor.connect(self.setChecked)
+
+ def _actionTriggered(self, checked=False):
+ self.plot.setGraphCursor(
+ checked,
+ color=self.color,
+ linestyle=self.linestyle,
+ linewidth=self.linewidth,
+ )
+
+
+class PanWithArrowKeysAction(PlotAction):
+ """QAction toggling pan with arrow keys on a :class:`.PlotWidget`.
+
+ :param plot: :class:`.PlotWidget` instance on which to operate
+ :param parent: See :class:`QAction`
+ """
+
+ def __init__(self, plot, parent=None):
+ super(PanWithArrowKeysAction, self).__init__(
+ plot,
+ icon="arrow-keys",
+ text="Pan with arrow keys",
+ tooltip="Enable pan with arrow keys when checked",
+ triggered=self._actionTriggered,
+ checkable=True,
+ parent=parent,
+ )
+ self.setChecked(plot.isPanWithArrowKeys())
+ plot.sigSetPanWithArrowKeys.connect(self.setChecked)
+
+ def _actionTriggered(self, checked=False):
+ self.plot.setPanWithArrowKeys(checked)
+
+
+class ShowAxisAction(PlotAction):
+ """QAction controlling axis visibility on a :class:`.PlotWidget`.
+
+ :param plot: :class:`.PlotWidget` instance on which to operate
+ :param parent: See :class:`QAction`
+ """
+
+ def __init__(self, plot, parent=None):
+ tooltip = "Show plot axis when checked, otherwise hide them"
+ PlotAction.__init__(
+ self,
+ plot,
+ icon="axis",
+ text="show axis",
+ tooltip=tooltip,
+ triggered=self._actionTriggered,
+ checkable=True,
+ parent=parent,
+ )
+ self.setChecked(self.plot.isAxesDisplayed())
+ plot._sigAxesVisibilityChanged.connect(self.setChecked)
+
+ def _actionTriggered(self, checked=False):
+ self.plot.setAxesDisplayed(checked)
+
+
+class ClosePolygonInteractionAction(PlotAction):
+ """QAction controlling closure of a polygon in draw interaction mode
+ if the :class:`.PlotWidget`.
+
+ :param plot: :class:`.PlotWidget` instance on which to operate
+ :param parent: See :class:`QAction`
+ """
+
+ def __init__(self, plot, parent=None):
+ tooltip = "Close the current polygon drawn"
+ PlotAction.__init__(
+ self,
+ plot,
+ icon="add-shape-polygon",
+ text="Close the polygon",
+ tooltip=tooltip,
+ triggered=self._actionTriggered,
+ checkable=True,
+ parent=parent,
+ )
+ self.plot.sigInteractiveModeChanged.connect(self._modeChanged)
+ self._modeChanged(None)
+
+ def _modeChanged(self, source):
+ mode = self.plot.getInteractiveMode()
+ enabled = "shape" in mode and mode["shape"] == "polygon"
+ self.setEnabled(enabled)
+
+ def _actionTriggered(self, checked=False):
+ self.plot.interaction()._validate()
+
+
+class OpenGLAction(PlotAction):
+ """QAction controlling rendering of a :class:`.PlotWidget`.
+
+ For now it can enable or not the OpenGL backend.
+
+ :param plot: :class:`.PlotWidget` instance on which to operate
+ :param parent: See :class:`QAction`
+ """
+
+ def __init__(self, plot, parent=None):
+ # Uses two images for checked/unchecked states
+ self._states = {
+ "opengl": (
+ icons.getQIcon("backend-opengl"),
+ "OpenGL rendering (fast)\nClick to disable OpenGL",
+ ),
+ "matplotlib": (
+ icons.getQIcon("backend-opengl"),
+ "Matplotlib rendering (safe)\nClick to enable OpenGL",
+ ),
+ "unknown": (icons.getQIcon("backend-opengl"), "Custom rendering"),
+ }
+
+ name = self._getBackendName(plot)
+ icon, tooltip = self._states[name]
+ super(OpenGLAction, self).__init__(
+ plot,
+ icon=icon,
+ text="Enable/disable OpenGL rendering",
+ tooltip=tooltip,
+ triggered=self._actionTriggered,
+ checkable=True,
+ parent=parent,
+ )
+ plot.sigBackendChanged.connect(self._backendUpdated)
+
+ def _backendUpdated(self):
+ name = self._getBackendName(self.plot)
+ icon, tooltip = self._states[name]
+ self.setIcon(icon)
+ self.setToolTip(tooltip)
+ self.setChecked(name == "opengl")
+
+ def _getBackendName(self, plot):
+ backend = plot.getBackend()
+ name = type(backend).__name__.lower()
+ if "opengl" in name:
+ return "opengl"
+ elif "matplotlib" in name:
+ return "matplotlib"
+ else:
+ return "unknown"
+
+ def _actionTriggered(self, checked=False):
+ plot = self.plot
+ name = self._getBackendName(self.plot)
+ if name != "opengl":
+ from silx.gui.utils import glutils
+
+ result = glutils.isOpenGLAvailable()
+ if not result:
+ qt.QMessageBox.critical(
+ plot, "OpenGL rendering is not available", result.error
+ )
+ return
+ plot.setBackend("opengl")
+ else:
+ plot.setBackend("matplotlib")
diff --git a/src/silx/gui/plot/actions/fit.py b/src/silx/gui/plot/actions/fit.py
new file mode 100644
index 0000000..ae8835a
--- /dev/null
+++ b/src/silx/gui/plot/actions/fit.py
@@ -0,0 +1,442 @@
+# /*##########################################################################
+#
+# Copyright (c) 2004-2023 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""
+:mod:`silx.gui.plot.actions.fit` module provides actions relative to fit.
+
+The following QAction are available:
+
+- :class:`.FitAction`
+
+.. autoclass:`.FitAction`
+"""
+
+__authors__ = ["V.A. Sole", "T. Vincent", "P. Knobel"]
+__license__ = "MIT"
+__date__ = "10/10/2018"
+
+import logging
+import sys
+import weakref
+import numpy
+
+from .PlotToolAction import PlotToolAction
+from .. import items
+from silx.gui import qt
+from silx.gui.plot.ItemsSelectionDialog import ItemsSelectionDialog
+
+_logger = logging.getLogger(__name__)
+
+
+def _getUniqueCurveOrHistogram(plot):
+ """Returns unique :class:`Curve` or :class:`Histogram` in a `PlotWidget`.
+
+ If there is an active curve, returns it, else return curve or histogram
+ only if alone in the plot.
+
+ :param PlotWidget plot:
+ :rtype: Union[None,~silx.gui.plot.items.Curve,~silx.gui.plot.items.Histogram]
+ """
+ curve = plot.getActiveCurve()
+ if curve is not None:
+ return curve
+
+ visibleItems = [item for item in plot.getItems() if item.isVisible()]
+ histograms = [item for item in visibleItems if isinstance(item, items.Histogram)]
+ curves = [item for item in visibleItems if isinstance(item, items.Curve)]
+
+ if len(histograms) == 1 and len(curves) == 0:
+ return histograms[0]
+ elif len(curves) == 1 and len(histograms) == 0:
+ return curves[0]
+ else:
+ return None
+
+
+class _FitItemSelector(qt.QObject):
+ """
+ :class:`PlotWidget` observer that emits signal when fit selection changes.
+
+ Track active curve or unique curve or histogram.
+ """
+
+ sigCurrentItemChanged = qt.Signal(object)
+ """Signal emitted when the item to fit has changed"""
+
+ def __init__(self, *args, **kwargs):
+ super().__init__(*args, **kwargs)
+ self.__plotWidgetRef = None
+ self.__currentItem = None
+
+ def getCurrentItem(self):
+ """Return currently selected item
+
+ :rtype: Union[Item,None]
+ """
+ return self.__currentItem
+
+ def getPlotWidget(self):
+ """Return currently attached :class:`PlotWidget`
+
+ :rtype: Union[PlotWidget,None]
+ """
+ return None if self.__plotWidgetRef is None else self.__plotWidgetRef()
+
+ def setPlotWidget(self, plotWidget):
+ """Set the :class:`PlotWidget` for which to track changes
+
+ :param Union[PlotWidget,None] plotWidget:
+ The :class:`PlotWidget` to observe
+ """
+ # disconnect from previous plot
+ previousPlotWidget = self.getPlotWidget()
+ if previousPlotWidget is not None:
+ previousPlotWidget.sigItemAdded.disconnect(self.__plotWidgetUpdated)
+ previousPlotWidget.sigItemRemoved.disconnect(self.__plotWidgetUpdated)
+ previousPlotWidget.sigActiveCurveChanged.disconnect(
+ self.__plotWidgetUpdated
+ )
+
+ if plotWidget is None:
+ self.__plotWidgetRef = None
+ self.__setCurrentItem(None)
+ return
+ self.__plotWidgetRef = weakref.ref(plotWidget, self.__plotDeleted)
+
+ # connect to new plot
+ plotWidget.sigItemAdded.connect(self.__plotWidgetUpdated)
+ plotWidget.sigItemRemoved.connect(self.__plotWidgetUpdated)
+ plotWidget.sigActiveCurveChanged.connect(self.__plotWidgetUpdated)
+ self.__plotWidgetUpdated()
+
+ def __plotDeleted(self):
+ """Handle deletion of PlotWidget"""
+ self.__setCurrentItem(None)
+
+ def __plotWidgetUpdated(self, *args, **kwargs):
+ """Handle updates of PlotWidget content"""
+ plotWidget = self.getPlotWidget()
+ if plotWidget is None:
+ return
+ self.__setCurrentItem(_getUniqueCurveOrHistogram(plotWidget))
+
+ def __setCurrentItem(self, item):
+ """Handle change of current item"""
+ if sys.is_finalizing():
+ return
+
+ previousItem = self.getCurrentItem()
+ if item != previousItem:
+ if previousItem is not None:
+ previousItem.sigItemChanged.disconnect(self.__itemUpdated)
+
+ self.__currentItem = item
+
+ if self.__currentItem is not None:
+ self.__currentItem.sigItemChanged.connect(self.__itemUpdated)
+ self.sigCurrentItemChanged.emit(self.__currentItem)
+
+ def __itemUpdated(self, event):
+ """Handle change on current item"""
+ if event == items.ItemChangedType.DATA:
+ self.sigCurrentItemChanged.emit(self.__currentItem)
+
+
+class FitAction(PlotToolAction):
+ """QAction to open a :class:`FitWidget` and set its data to the
+ active curve if any, or to the first curve.
+
+ :param plot: :class:`.PlotWidget` instance on which to operate
+ :param parent: See :class:`QAction`
+ """
+
+ def __init__(self, plot, parent=None):
+ self.__item = None
+ self.__activeCurveSynchroEnabled = False
+ self.__range = 0, 1
+ self.__rangeAutoUpdate = False
+ self.__x, self.__y = None, None # Data to fit
+ self.__curveParams = {} # Store curve parameters to use for fit result
+ self.__legend = None
+
+ super(FitAction, self).__init__(
+ plot,
+ icon="math-fit",
+ text="Fit curve",
+ tooltip="Open a fit dialog",
+ parent=parent,
+ )
+
+ self.__fitItemSelector = _FitItemSelector()
+ self.__fitItemSelector.sigCurrentItemChanged.connect(self._setFittedItem)
+
+ def _createToolWindow(self):
+ # import done here rather than at module level to avoid circular import
+ # FitWidget -> BackgroundWidget -> PlotWindow -> actions -> fit -> FitWidget
+ from ...fit.FitWidget import FitWidget
+
+ window = FitWidget(parent=self.plot)
+ window.setWindowFlags(qt.Qt.Dialog)
+ window.sigFitWidgetSignal.connect(self.handle_signal)
+ return window
+
+ def _connectPlot(self, window):
+ if self.isXRangeUpdatedOnZoom():
+ self.__setAutoXRangeEnabled(True)
+ else:
+ plot = self.plot
+ if plot is None:
+ _logger.error("No associated PlotWidget")
+ return
+ self._setXRange(*plot.getXAxis().getLimits())
+
+ if self.isFittedItemUpdatedFromActiveCurve():
+ self.__setFittedItemAutoUpdateEnabled(True)
+ else:
+ # Wait for the next iteration, else the plot is not yet initialized
+ # No curve available
+ qt.QTimer.singleShot(10, self._initFit)
+
+ def _disconnectPlot(self, window):
+ if self.isXRangeUpdatedOnZoom():
+ self.__setAutoXRangeEnabled(False)
+
+ if self.isFittedItemUpdatedFromActiveCurve():
+ self.__setFittedItemAutoUpdateEnabled(False)
+
+ def _initFit(self):
+ plot = self.plot
+ if plot is None:
+ _logger.error("No associated PlotWidget")
+ return
+
+ item = _getUniqueCurveOrHistogram(plot)
+ if item is None:
+ # ambiguous case, we need to ask which plot item to fit
+ isd = ItemsSelectionDialog(parent=plot, plot=plot)
+ isd.setWindowTitle("Select item to be fitted")
+ isd.setItemsSelectionMode(qt.QTableWidget.SingleSelection)
+ isd.setAvailableKinds(["curve", "histogram"])
+ isd.selectAllKinds()
+
+ if not isd.exec(): # Cancel
+ self._getToolWindow().setVisible(False)
+ else:
+ selectedItems = isd.getSelectedItems()
+ item = selectedItems[0] if len(selectedItems) == 1 else None
+
+ self._setXRange(*plot.getXAxis().getLimits())
+ self._setFittedItem(item)
+
+ def __updateFitWidget(self):
+ """Update the data/range used by the FitWidget"""
+ fitWidget = self._getToolWindow()
+
+ item = self._getFittedItem()
+ xdata = self.getXData(copy=False)
+ ydata = self.getYData(copy=False)
+ if item is None or xdata is None or ydata is None:
+ fitWidget.setData(y=None)
+ fitWidget.setWindowTitle("No curve selected")
+
+ else:
+ xmin, xmax = self.getXRange()
+ fitWidget.setData(xdata, ydata, xmin=xmin, xmax=xmax)
+ fitWidget.setWindowTitle(
+ "Fitting " + item.getName() + " on x range %f-%f" % (xmin, xmax)
+ )
+
+ # X Range management
+
+ def getXRange(self):
+ """Returns the range on the X axis on which to perform the fit."""
+ return self.__range
+
+ def _setXRange(self, xmin, xmax):
+ """Set the range on which the fit is done.
+
+ :param float xmin:
+ :param float xmax:
+ """
+ range_ = float(xmin), float(xmax)
+ if self.__range != range_:
+ self.__range = range_
+ self.__updateFitWidget()
+
+ def __setAutoXRangeEnabled(self, enabled):
+ """Implement the change of update mode of the X range.
+
+ :param bool enabled:
+ """
+ plot = self.plot
+ if plot is None:
+ _logger.error("No associated PlotWidget")
+ return
+
+ if enabled:
+ self._setXRange(*plot.getXAxis().getLimits())
+ plot.getXAxis().sigLimitsChanged.connect(self._setXRange)
+ else:
+ plot.getXAxis().sigLimitsChanged.disconnect(self._setXRange)
+
+ def setXRangeUpdatedOnZoom(self, enabled):
+ """Set whether or not to update the X range on zoom change.
+
+ :param bool enabled:
+ """
+ if enabled != self.__rangeAutoUpdate:
+ self.__rangeAutoUpdate = enabled
+ if self._getToolWindow().isVisible():
+ self.__setAutoXRangeEnabled(enabled)
+
+ def isXRangeUpdatedOnZoom(self):
+ """Returns the current mode of fitted data X range update.
+
+ :rtype: bool
+ """
+ return self.__rangeAutoUpdate
+
+ # Fitted item update
+
+ def getXData(self, copy=True):
+ """Returns the X data used for the fit or None if undefined.
+
+ :param bool copy:
+ True to get a copy of the data, False to get the internal data.
+ :rtype: Union[numpy.ndarray,None]
+ """
+ return None if self.__x is None else numpy.array(self.__x, copy=copy)
+
+ def getYData(self, copy=True):
+ """Returns the Y data used for the fit or None if undefined.
+
+ :param bool copy:
+ True to get a copy of the data, False to get the internal data.
+ :rtype: Union[numpy.ndarray,None]
+ """
+ return None if self.__y is None else numpy.array(self.__y, copy=copy)
+
+ def _getFittedItem(self):
+ """Returns the current item used for the fit
+
+ :rtype: Union[~silx.gui.plot.items.Curve,~silx.gui.plot.items.Histogram,None]
+ """
+ return self.__item
+
+ def _setFittedItem(self, item):
+ """Set the curve to use for fitting.
+
+ :param Union[~silx.gui.plot.items.Curve,~silx.gui.plot.items.Histogram,None] item:
+ """
+ plot = self.plot
+ if plot is None:
+ _logger.error("No associated PlotWidget")
+
+ if plot is None or item is None:
+ self.__item = None
+ self.__curveParams = {}
+ self.__updateFitWidget()
+ return
+
+ axis = item.getYAxis() if isinstance(item, items.YAxisMixIn) else "left"
+ self.__curveParams = {
+ "yaxis": axis,
+ "xlabel": plot.getXAxis().getLabel(),
+ "ylabel": plot.getYAxis(axis).getLabel(),
+ }
+ self.__legend = item.getName()
+
+ if isinstance(item, items.Histogram):
+ bin_edges = item.getBinEdgesData(copy=False)
+ # take the middle coordinate between adjacent bin edges
+ self.__x = (bin_edges[1:] + bin_edges[:-1]) / 2
+ self.__y = item.getValueData(copy=False)
+ # else take the active curve, or else the unique curve
+ elif isinstance(item, items.Curve):
+ self.__x = item.getXData(copy=False)
+ self.__y = item.getYData(copy=False)
+
+ self.__item = item
+ self.__updateFitWidget()
+
+ def __setFittedItemAutoUpdateEnabled(self, enabled):
+ """Implement the change of fitted item update mode
+
+ :param bool enabled:
+ """
+ plot = self.plot
+ if plot is None:
+ _logger.error("No associated PlotWidget")
+ return
+
+ self.__fitItemSelector.setPlotWidget(self.plot if enabled else None)
+
+ def setFittedItemUpdatedFromActiveCurve(self, enabled):
+ """Toggle fitted data synchronization with plot active curve.
+
+ :param bool enabled:
+ """
+ enabled = bool(enabled)
+ if enabled != self.__activeCurveSynchroEnabled:
+ self.__activeCurveSynchroEnabled = enabled
+ if self._getToolWindow().isVisible():
+ self.__setFittedItemAutoUpdateEnabled(enabled)
+
+ def isFittedItemUpdatedFromActiveCurve(self):
+ """Returns True if fitted data is synchronized with plot.
+
+ :rtype: bool
+ """
+ return self.__activeCurveSynchroEnabled
+
+ # Handle fit completed
+
+ def handle_signal(self, ddict):
+ xdata = self.getXData(copy=False)
+ if xdata is None:
+ _logger.error("No reference data to display fit result for")
+ return
+
+ xmin, xmax = self.getXRange()
+ x_fit = xdata[xmin <= xdata]
+ x_fit = x_fit[x_fit <= xmax]
+ fit_legend = "Fit <%s>" % self.__legend
+ fit_curve = self.plot.getCurve(fit_legend)
+
+ if ddict["event"] == "FitFinished":
+ fit_widget = self._getToolWindow()
+ if fit_widget is None:
+ return
+ y_fit = fit_widget.fitmanager.gendata()
+ if fit_curve is None:
+ self.plot.addCurve(
+ x_fit, y_fit, fit_legend, resetzoom=False, **self.__curveParams
+ )
+ else:
+ fit_curve.setData(x_fit, y_fit)
+ fit_curve.setVisible(True)
+ fit_curve.setYAxis(self.__curveParams.get("yaxis", "left"))
+
+ if ddict["event"] in ["FitStarted", "FitFailed"]:
+ if fit_curve is not None:
+ fit_curve.setVisible(False)
diff --git a/src/silx/gui/plot/actions/histogram.py b/src/silx/gui/plot/actions/histogram.py
new file mode 100644
index 0000000..39c669b
--- /dev/null
+++ b/src/silx/gui/plot/actions/histogram.py
@@ -0,0 +1,548 @@
+# /*##########################################################################
+#
+# Copyright (c) 2004-2023 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""
+:mod:`silx.gui.plot.actions.histogram` provides actions relative to histograms
+for :class:`.PlotWidget`.
+
+The following QAction are available:
+
+- :class:`PixelIntensitiesHistoAction`
+"""
+
+__authors__ = ["V.A. Sole", "T. Vincent", "P. Knobel"]
+__date__ = "07/11/2023"
+__license__ = "MIT"
+
+from typing import Optional, Tuple
+import numpy
+import logging
+import weakref
+
+from .PlotToolAction import PlotToolAction
+
+from silx.math.histogram import Histogramnd
+from silx.math.combo import min_max
+from silx.gui import qt
+from silx.gui.plot import items
+from silx.gui.widgets.ElidedLabel import ElidedLabel
+from silx.gui.widgets.RangeSlider import RangeSlider
+
+_logger = logging.getLogger(__name__)
+
+
+class _ElidedLabel(ElidedLabel):
+ """QLabel with a default size larger than what is displayed."""
+
+ def __init__(self, *args, **kwargs):
+ super().__init__(*args, **kwargs)
+ self.setTextInteractionFlags(qt.Qt.TextSelectableByMouse)
+
+ def sizeHint(self):
+ hint = super().sizeHint()
+ nbchar = max(len(self.text()), 12)
+ width = self.fontMetrics().boundingRect("#" * nbchar).width()
+ return qt.QSize(max(hint.width(), width), hint.height())
+
+
+class _StatWidget(qt.QWidget):
+ """Widget displaying a name and a value
+
+ :param parent:
+ :param name:
+ """
+
+ def __init__(self, parent=None, name: str = ""):
+ super().__init__(parent)
+ layout = qt.QHBoxLayout(self)
+ layout.setContentsMargins(0, 0, 0, 0)
+
+ keyWidget = qt.QLabel(parent=self)
+ keyWidget.setText("<b>" + name.capitalize() + ":<b>")
+ layout.addWidget(keyWidget)
+ self.__valueWidget = _ElidedLabel(parent=self)
+ self.__valueWidget.setText("-")
+ self.__valueWidget.setTextInteractionFlags(
+ qt.Qt.TextSelectableByMouse | qt.Qt.TextSelectableByKeyboard
+ )
+ layout.addWidget(self.__valueWidget)
+
+ def setValue(self, value: Optional[float]):
+ """Set the displayed value
+
+ :param value:
+ """
+ self.__valueWidget.setText("-" if value is None else "{:.5g}".format(value))
+
+
+class _IntEdit(qt.QLineEdit):
+ """QLineEdit for integers with a default value and update on validation.
+
+ :param QWidget parent:
+ """
+
+ sigValueChanged = qt.Signal(int)
+ """Signal emitted when the value has changed (on editing finished)"""
+
+ def __init__(self, parent=None):
+ super().__init__(parent)
+ self.__value = None
+ self.setAlignment(qt.Qt.AlignRight)
+ validator = qt.QIntValidator()
+ self.setValidator(validator)
+ validator.bottomChanged.connect(self.__updateSize)
+ validator.topChanged.connect(self.__updateSize)
+ self.__updateSize()
+
+ self.textEdited.connect(self.__textEdited)
+
+ def __updateSize(self, *args):
+ """Update widget's maximum size according to bounds"""
+ bottom, top = self.getRange()
+ nbchar = max(len(str(bottom)), len(str(top)))
+ font = self.font()
+ font.setStyle(qt.QFont.StyleItalic)
+ fontMetrics = qt.QFontMetrics(font)
+ self.setMaximumWidth(fontMetrics.boundingRect("0" * (nbchar + 1)).width())
+ self.setMaxLength(nbchar)
+
+ def __textEdited(self, _):
+ if self.font().style() != qt.QFont.StyleItalic:
+ font = self.font()
+ font.setStyle(qt.QFont.StyleItalic)
+ self.setFont(font)
+
+ # Use events rather than editingFinished to also trigger with empty text
+
+ def focusOutEvent(self, event):
+ self.__commitValue()
+ return super().focusOutEvent(event)
+
+ def keyPressEvent(self, event):
+ if event.key() in (qt.Qt.Key_Enter, qt.Qt.Key_Return):
+ self.__commitValue()
+ return super().keyPressEvent(event)
+
+ def __commitValue(self):
+ """Update the value returned by :meth:`getValue`"""
+ value = self.getCurrentValue()
+ if value is None:
+ value = self.getDefaultValue()
+ if value is None:
+ return # No value, keep previous one
+
+ if self.font().style() != qt.QFont.StyleNormal:
+ font = self.font()
+ font.setStyle(qt.QFont.StyleNormal)
+ self.setFont(font)
+
+ if value != self.__value:
+ self.__value = value
+ self.sigValueChanged.emit(value)
+
+ def getValue(self) -> Optional[int]:
+ """Return current value (None if never set)."""
+ return self.__value
+
+ def setRange(self, bottom: int, top: int):
+ """Set the range of valid values"""
+ self.validator().setRange(bottom, top)
+
+ def getRange(self) -> Tuple[int, int]:
+ """Returns the current range of valid values
+
+ :returns: (bottom, top)
+ """
+ return self.validator().bottom(), self.validator().top()
+
+ def __validate(self, value: int, extend_range: bool):
+ """Ensure value is in range
+
+ :param int value:
+ :param bool extend_range:
+ True to extend range if needed.
+ False to clip value if needed.
+ """
+ if extend_range:
+ bottom, top = self.getRange()
+ self.setRange(min(value, bottom), max(value, top))
+ return numpy.clip(value, *self.getRange())
+
+ def setDefaultValue(self, value: int, extend_range: bool = False):
+ """Set default value when QLineEdit is empty
+
+ :param int value:
+ :param bool extend_range:
+ True to extend range if needed.
+ False to clip value if needed
+ """
+ self.setPlaceholderText(str(self.__validate(value, extend_range)))
+ if self.getCurrentValue() is None:
+ self.__commitValue()
+
+ def getDefaultValue(self) -> Optional[int]:
+ """Return the default value or the bottom one if not set"""
+ try:
+ return int(self.placeholderText())
+ except ValueError:
+ return None
+
+ def setCurrentValue(self, value: int, extend_range: bool = False):
+ """Set the currently displayed value
+
+ :param int value:
+ :param bool extend_range:
+ True to extend range if needed.
+ False to clip value if needed
+ """
+ self.setText(str(self.__validate(value, extend_range)))
+ self.__commitValue()
+
+ def getCurrentValue(self) -> Optional[int]:
+ """Returns the displayed value or None if not correct"""
+ try:
+ return int(self.text())
+ except ValueError:
+ return None
+
+
+class HistogramWidget(qt.QWidget):
+ """Widget displaying a histogram and some statistic indicators"""
+
+ _SUPPORTED_ITEM_CLASS = items.ImageBase, items.Scatter
+
+ def __init__(self, *args, **kwargs):
+ super().__init__(*args, **kwargs)
+ self.setWindowTitle("Histogram")
+
+ self.__itemRef = None # weakref on the item to track
+
+ layout = qt.QVBoxLayout(self)
+ layout.setContentsMargins(0, 0, 0, 0)
+ layout.setSpacing(0)
+
+ # Plot
+ # Lazy import to avoid circular dependencies
+ from silx.gui.plot.PlotWindow import Plot1D
+
+ self.__plot = Plot1D(self)
+ layout.addWidget(self.__plot)
+
+ self.__plot.setDataMargins(0.1, 0.1, 0.1, 0.1)
+ self.__plot.getXAxis().setLabel("Value")
+ self.__plot.getYAxis().setLabel("Count")
+ posInfo = self.__plot.getPositionInfoWidget()
+ posInfo.setSnappingMode(posInfo.SNAPPING_CURVE)
+
+ # Histogram controls
+ controlsWidget = qt.QWidget(self)
+ layout.addWidget(controlsWidget)
+ controlsLayout = qt.QHBoxLayout(controlsWidget)
+ controlsLayout.setContentsMargins(4, 4, 4, 4)
+
+ controlsLayout.addWidget(qt.QLabel("<b>Histogram:<b>"))
+ controlsLayout.addWidget(qt.QLabel("N. bins:"))
+ self.__nbinsLineEdit = _IntEdit(self)
+ self.__nbinsLineEdit.setRange(2, 9999)
+ self.__nbinsLineEdit.sigValueChanged.connect(self.__updateHistogramFromControls)
+ controlsLayout.addWidget(self.__nbinsLineEdit)
+ self.__rangeLabel = qt.QLabel("Range:")
+ controlsLayout.addWidget(self.__rangeLabel)
+ self.__rangeSlider = RangeSlider(parent=self)
+ self.__rangeSlider.sigValueChanged.connect(self.__updateHistogramFromControls)
+ self.__rangeSlider.sigValueChanged.connect(self.__rangeChanged)
+ controlsLayout.addWidget(self.__rangeSlider)
+ self.__weightCheckBox = qt.QCheckBox(self)
+ self.__weightCheckBox.setText("Use weights")
+ self.__weightCheckBox.clicked.connect(self.__weightChanged)
+ controlsLayout.addWidget(self.__weightCheckBox)
+ controlsLayout.addStretch(1)
+
+ # Stats display
+ statsWidget = qt.QWidget(self)
+ layout.addWidget(statsWidget)
+ statsLayout = qt.QHBoxLayout(statsWidget)
+ statsLayout.setContentsMargins(4, 4, 4, 4)
+
+ self.__statsWidgets = dict(
+ (name, _StatWidget(parent=statsWidget, name=name))
+ for name in ("min", "max", "mean", "std", "sum")
+ )
+
+ for widget in self.__statsWidgets.values():
+ statsLayout.addWidget(widget)
+ statsLayout.addStretch(1)
+
+ def getPlotWidget(self):
+ """Returns :class:`PlotWidget` use to display the histogram"""
+ return self.__plot
+
+ def resetZoom(self):
+ """Reset PlotWidget zoom"""
+ self.getPlotWidget().resetZoom()
+
+ def reset(self):
+ """Clear displayed information"""
+ self.getPlotWidget().clear()
+ self.setStatistics()
+
+ def getItem(self) -> Optional[items.Item]:
+ """Returns item used to display histogram and statistics."""
+ return None if self.__itemRef is None else self.__itemRef()
+
+ def setItem(self, item: Optional[items.Item]):
+ """Set item from which to display histogram and statistics.
+
+ :param item:
+ """
+ previous = self.getItem()
+ if previous is not None:
+ previous.sigItemChanged.disconnect(self.__itemChanged)
+
+ self.__itemRef = None if item is None else weakref.ref(item)
+ if item is not None:
+ if isinstance(item, self._SUPPORTED_ITEM_CLASS):
+ # Only listen signal for supported items
+ item.sigItemChanged.connect(self.__itemChanged)
+ self._updateFromItem()
+
+ def __itemChanged(self, event):
+ """Handle update of the item"""
+ if event in (items.ItemChangedType.DATA, items.ItemChangedType.MASK):
+ self._updateFromItem()
+
+ def __updateHistogramFromControls(self, *args):
+ """Handle udates coming from histogram control widgets"""
+
+ hist = self.getHistogram(copy=False)
+ if hist is not None:
+ count, edges = hist
+ if (
+ len(count) == self.__nbinsLineEdit.getValue()
+ and (edges[0], edges[-1]) == self.__rangeSlider.getValues()
+ ):
+ return # Nothing has changed
+
+ self._updateFromItem()
+
+ def __rangeChanged(self, first, second):
+ """Handle change of histogram range from the range slider"""
+ tooltip = "Histogram range:\n[%g, %g]" % (first, second)
+ self.__rangeSlider.setToolTip(tooltip)
+ self.__rangeLabel.setToolTip(tooltip)
+
+ def __weightChanged(self, value):
+ self._updateFromItem()
+
+ def _updateFromItem(self):
+ """Update histogram and stats from the item"""
+ item = self.getItem()
+
+ if item is None:
+ self.reset()
+ return
+
+ if not isinstance(item, self._SUPPORTED_ITEM_CLASS):
+ _logger.error("Unsupported item", item)
+ self.reset()
+ return
+
+ # Compute histogram and stats
+ array = item.getValueData(copy=False)
+
+ if array.size == 0:
+ self.reset()
+ return
+
+ xmin, xmax = min_max(array, min_positive=False, finite=True)
+ if xmin is None or xmax is None: # All not finite data
+ self.reset()
+ return
+ guessed_nbins = min(1024, int(numpy.sqrt(array.size)))
+
+ # bad hack: get 256 bins in the case we have a B&W
+ if numpy.issubdtype(array.dtype, numpy.integer):
+ if guessed_nbins > xmax - xmin:
+ guessed_nbins = xmax - xmin
+ guessed_nbins = max(2, guessed_nbins)
+
+ # Set default nbins
+ self.__nbinsLineEdit.setDefaultValue(guessed_nbins, extend_range=True)
+ # Set slider range: do not keep the range value, but the relative pos.
+ previousPositions = self.__rangeSlider.getPositions()
+ if xmin == xmax: # Enlarge range is none
+ if xmin == 0:
+ range_ = -0.01, 0.01
+ else:
+ range_ = sorted((xmin * 0.99, xmin * 1.01))
+ else:
+ range_ = xmin, xmax
+
+ self.__rangeSlider.setRange(*range_)
+ self.__rangeSlider.setPositions(*previousPositions)
+
+ data = array.ravel().astype(numpy.float32)
+ histogram = Histogramnd(
+ data,
+ n_bins=max(2, self.__nbinsLineEdit.getValue()),
+ histo_range=self.__rangeSlider.getValues(),
+ weights=data,
+ )
+ if len(histogram.edges) != 1:
+ _logger.error("Error while computing the histogram")
+ self.reset()
+ return
+
+ if self.__weightCheckBox.isChecked():
+ self.setHistogram(histogram.weighted_histo, histogram.edges[0])
+ self.__plot.getYAxis().setLabel("Count * Value")
+ else:
+ self.setHistogram(histogram.histo, histogram.edges[0])
+ self.__plot.getYAxis().setLabel("Count")
+ self.resetZoom()
+ self.setStatistics(
+ min_=xmin,
+ max_=xmax,
+ mean=numpy.nanmean(array),
+ std=numpy.nanstd(array),
+ sum_=numpy.nansum(array),
+ )
+
+ def setHistogram(self, histogram, edges):
+ """Set displayed histogram
+
+ :param histogram: Bin values (N)
+ :param edges: Bin edges (N+1)
+ """
+ # Only useful if setHistogram is called directly
+ # TODO
+ # nbins = len(histogram)
+ # if nbins != self.__nbinsLineEdit.getDefaultValue():
+ # self.__nbinsLineEdit.setValue(nbins, extend_range=True)
+ # self.__rangeSlider.setValues(edges[0], edges[-1])
+
+ self.getPlotWidget().addHistogram(
+ histogram=histogram,
+ edges=edges,
+ legend="histogram",
+ fill=True,
+ color="#66aad7",
+ resetzoom=False,
+ )
+
+ def getHistogram(self, copy: bool = True):
+ """Returns currently displayed histogram.
+
+ :param copy: True to get a copy,
+ False to get internal representation (Do not modify!)
+ :return: (histogram, edges) or None
+ """
+ for item in self.getPlotWidget().getItems():
+ if item.getName() == "histogram":
+ return (item.getValueData(copy=copy), item.getBinEdgesData(copy=copy))
+ else:
+ return None
+
+ def setStatistics(
+ self,
+ min_: Optional[float] = None,
+ max_: Optional[float] = None,
+ mean: Optional[float] = None,
+ std: Optional[float] = None,
+ sum_: Optional[float] = None,
+ ):
+ """Set displayed statistic indicators."""
+ self.__statsWidgets["min"].setValue(min_)
+ self.__statsWidgets["max"].setValue(max_)
+ self.__statsWidgets["mean"].setValue(mean)
+ self.__statsWidgets["std"].setValue(std)
+ self.__statsWidgets["sum"].setValue(sum_)
+
+
+class PixelIntensitiesHistoAction(PlotToolAction):
+ """QAction to plot the pixels intensities diagram
+
+ :param plot: :class:`.PlotWidget` instance on which to operate
+ :param parent: See :class:`QAction`
+ """
+
+ def __init__(self, plot, parent=None):
+ PlotToolAction.__init__(
+ self,
+ plot,
+ icon="pixel-intensities",
+ text="pixels intensity",
+ tooltip="Compute image intensity distribution",
+ parent=parent,
+ )
+
+ def _connectPlot(self, window):
+ plot = self.plot
+ if plot is not None:
+ selection = plot.selection()
+ selection.sigSelectedItemsChanged.connect(self._selectedItemsChanged)
+ self._updateSelectedItem()
+
+ PlotToolAction._connectPlot(self, window)
+
+ def _disconnectPlot(self, window):
+ plot = self.plot
+ if plot is not None:
+ selection = self.plot.selection()
+ selection.sigSelectedItemsChanged.disconnect(self._selectedItemsChanged)
+
+ PlotToolAction._disconnectPlot(self, window)
+ self.getHistogramWidget().setItem(None)
+
+ def _updateSelectedItem(self):
+ """Synchronises selected item with plot widget."""
+ plot = self.plot
+ if plot is not None:
+ selected = plot.selection().getSelectedItems()
+ # Give priority to image over scatter
+ for klass in (items.ImageBase, items.Scatter):
+ for item in selected:
+ if isinstance(item, klass):
+ # Found a matching item, use it
+ self.getHistogramWidget().setItem(item)
+ return
+ self.getHistogramWidget().setItem(None)
+
+ def _selectedItemsChanged(self):
+ if self._isWindowInUse():
+ self._updateSelectedItem()
+
+ def getHistogramWidget(self):
+ """Returns the widget displaying the histogram"""
+ return self._getToolWindow()
+
+ def _createToolWindow(self):
+ return HistogramWidget(self.plot, qt.Qt.Window)
+
+ def getHistogram(self) -> Optional[numpy.ndarray]:
+ """Return the last computed histogram
+
+ :return: the histogram displayed in the HistogramWidget
+ """
+ histogram = self.getHistogramWidget().getHistogram()
+ return None if histogram is None else histogram[0]
diff --git a/src/silx/gui/plot/actions/io.py b/src/silx/gui/plot/actions/io.py
new file mode 100644
index 0000000..1ff95f3
--- /dev/null
+++ b/src/silx/gui/plot/actions/io.py
@@ -0,0 +1,874 @@
+# /*##########################################################################
+#
+# Copyright (c) 2004-2023 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""
+:mod:`silx.gui.plot.actions.io` provides a set of QAction relative of inputs
+and outputs for a :class:`.PlotWidget`.
+
+The following QAction are available:
+
+- :class:`CopyAction`
+- :class:`PrintAction`
+- :class:`SaveAction`
+"""
+
+__authors__ = ["V.A. Sole", "T. Vincent", "P. Knobel"]
+__license__ = "MIT"
+__date__ = "25/09/2020"
+
+from io import BytesIO
+import logging
+import sys
+import os.path
+import traceback
+import numpy
+from fabio.TiffIO import TiffIO
+from fabio.edfimage import EdfImage
+
+from silx.gui import qt, printer
+from silx.gui.dialog.GroupDialog import GroupDialog
+from silx.io.utils import save1D, savespec, NEXUS_HDF5_EXT
+from silx.io.nxdata import save_NXdata
+
+from . import PlotAction
+from ...utils.image import convertArrayToQImage
+
+
+_logger = logging.getLogger(__name__)
+
+_NEXUS_HDF5_EXT_STR = " ".join(["*" + ext for ext in NEXUS_HDF5_EXT])
+
+
+def selectOutputGroup(h5filename):
+ """Open a dialog to prompt the user to select a group in
+ which to output data.
+
+ :param str h5filename: name of an existing HDF5 file
+ :rtype: str
+ :return: Name of output group, or None if the dialog was cancelled
+ """
+ dialog = GroupDialog()
+ dialog.addFile(h5filename)
+ dialog.setWindowTitle("Select an output group")
+ if not dialog.exec():
+ return None
+ return dialog.getSelectedDataUrl().data_path()
+
+
+class SaveAction(PlotAction):
+ """QAction for saving Plot content.
+
+ It opens a Save as... dialog.
+
+ :param plot: :class:`.PlotWidget` instance on which to operate.
+ :param parent: See :class:`QAction`.
+ """
+
+ SNAPSHOT_FILTER_SVG = "Plot Snapshot as SVG (*.svg)"
+ SNAPSHOT_FILTER_PNG = "Plot Snapshot as PNG (*.png)"
+
+ DEFAULT_ALL_FILTERS = (SNAPSHOT_FILTER_PNG, SNAPSHOT_FILTER_SVG)
+
+ # Dict of curve filters with CSV-like format
+ # Using ordered dict to guarantee filters order
+ # Note: '%.18e' is numpy.savetxt default format
+ CURVE_FILTERS_TXT = dict(
+ (
+ (
+ "Curve as Raw ASCII (*.txt)",
+ {"fmt": "%.18e", "delimiter": " ", "header": False},
+ ),
+ (
+ 'Curve as ";"-separated CSV (*.csv)',
+ {"fmt": "%.18e", "delimiter": ";", "header": True},
+ ),
+ (
+ 'Curve as ","-separated CSV (*.csv)',
+ {"fmt": "%.18e", "delimiter": ",", "header": True},
+ ),
+ (
+ "Curve as tab-separated CSV (*.csv)",
+ {"fmt": "%.18e", "delimiter": "\t", "header": True},
+ ),
+ (
+ "Curve as OMNIC CSV (*.csv)",
+ {"fmt": "%.7E", "delimiter": ",", "header": False},
+ ),
+ (
+ "Curve as SpecFile (*.dat)",
+ {"fmt": "%.10g", "delimiter": "", "header": False},
+ ),
+ )
+ )
+
+ CURVE_FILTER_NPY = "Curve as NumPy binary file (*.npy)"
+
+ CURVE_FILTER_NXDATA = "Curve as NXdata (%s)" % _NEXUS_HDF5_EXT_STR
+
+ DEFAULT_CURVE_FILTERS = list(CURVE_FILTERS_TXT.keys()) + [
+ CURVE_FILTER_NPY,
+ CURVE_FILTER_NXDATA,
+ ]
+
+ DEFAULT_ALL_CURVES_FILTERS = ("All curves as SpecFile (*.dat)",)
+
+ IMAGE_FILTER_EDF = "Image data as EDF (*.edf)"
+ IMAGE_FILTER_TIFF = "Image data as TIFF (*.tif)"
+ IMAGE_FILTER_NUMPY = "Image data as NumPy binary file (*.npy)"
+ IMAGE_FILTER_ASCII = "Image data as ASCII (*.dat)"
+ IMAGE_FILTER_CSV_COMMA = "Image data as ,-separated CSV (*.csv)"
+ IMAGE_FILTER_CSV_SEMICOLON = "Image data as ;-separated CSV (*.csv)"
+ IMAGE_FILTER_CSV_TAB = "Image data as tab-separated CSV (*.csv)"
+ IMAGE_FILTER_RGB_PNG = "Image as PNG (*.png)"
+ IMAGE_FILTER_NXDATA = "Image as NXdata (%s)" % _NEXUS_HDF5_EXT_STR
+
+ DEFAULT_IMAGE_FILTERS = (
+ IMAGE_FILTER_EDF,
+ IMAGE_FILTER_TIFF,
+ IMAGE_FILTER_NUMPY,
+ IMAGE_FILTER_ASCII,
+ IMAGE_FILTER_CSV_COMMA,
+ IMAGE_FILTER_CSV_SEMICOLON,
+ IMAGE_FILTER_CSV_TAB,
+ IMAGE_FILTER_RGB_PNG,
+ IMAGE_FILTER_NXDATA,
+ )
+
+ SCATTER_FILTER_NXDATA = "Scatter as NXdata (%s)" % _NEXUS_HDF5_EXT_STR
+ DEFAULT_SCATTER_FILTERS = (SCATTER_FILTER_NXDATA,)
+
+ # filters for which we don't want an "overwrite existing file" warning
+ DEFAULT_APPEND_FILTERS = (
+ CURVE_FILTER_NXDATA,
+ IMAGE_FILTER_NXDATA,
+ SCATTER_FILTER_NXDATA,
+ )
+
+ def __init__(self, plot, parent=None):
+ self._filters = {
+ "all": {},
+ "curve": {},
+ "curves": {},
+ "image": {},
+ "scatter": {},
+ }
+
+ self._appendFilters = list(self.DEFAULT_APPEND_FILTERS)
+
+ # Initialize filters
+ for nameFilter in self.DEFAULT_ALL_FILTERS:
+ self.setFileFilter(
+ dataKind="all", nameFilter=nameFilter, func=self._saveSnapshot
+ )
+
+ for nameFilter in self.DEFAULT_CURVE_FILTERS:
+ self.setFileFilter(
+ dataKind="curve", nameFilter=nameFilter, func=self._saveCurve
+ )
+
+ for nameFilter in self.DEFAULT_ALL_CURVES_FILTERS:
+ self.setFileFilter(
+ dataKind="curves", nameFilter=nameFilter, func=self._saveCurves
+ )
+
+ for nameFilter in self.DEFAULT_IMAGE_FILTERS:
+ self.setFileFilter(
+ dataKind="image", nameFilter=nameFilter, func=self._saveImage
+ )
+
+ for nameFilter in self.DEFAULT_SCATTER_FILTERS:
+ self.setFileFilter(
+ dataKind="scatter", nameFilter=nameFilter, func=self._saveScatter
+ )
+
+ super(SaveAction, self).__init__(
+ plot,
+ icon="document-save",
+ text="Save as...",
+ tooltip="Save curve/image/plot snapshot dialog",
+ triggered=self._actionTriggered,
+ checkable=False,
+ parent=parent,
+ )
+ self.setShortcut(qt.QKeySequence.Save)
+ self.setShortcutContext(qt.Qt.WidgetShortcut)
+
+ @staticmethod
+ def _errorMessage(informativeText="", parent=None):
+ """Display an error message."""
+ # TODO issue with QMessageBox size fixed and too small
+ msg = qt.QMessageBox(parent)
+ msg.setIcon(qt.QMessageBox.Critical)
+ msg.setInformativeText(informativeText + " " + str(sys.exc_info()[1]))
+ msg.setDetailedText(traceback.format_exc())
+ msg.exec()
+
+ def _saveSnapshot(self, plot, filename, nameFilter):
+ """Save a snapshot of the :class:`PlotWindow` widget.
+
+ :param str filename: The name of the file to write
+ :param str nameFilter: The selected name filter
+ :return: False if format is not supported or save failed,
+ True otherwise.
+ """
+ if nameFilter == self.SNAPSHOT_FILTER_PNG:
+ fileFormat = "png"
+ elif nameFilter == self.SNAPSHOT_FILTER_SVG:
+ fileFormat = "svg"
+ else: # Format not supported
+ _logger.error("Saving plot snapshot failed: format not supported")
+ return False
+
+ plot.saveGraph(filename, fileFormat=fileFormat)
+ return True
+
+ def _getAxesLabels(self, item):
+ # If curve has no associated label, get the default from the plot
+ xlabel = item.getXLabel() or self.plot.getXAxis().getLabel()
+ ylabel = item.getYLabel() or self.plot.getYAxis().getLabel()
+ return xlabel, ylabel
+
+ def _get1dData(self, item):
+ "provide xdata, [ydata], xlabel, [ylabel] and manages error bars"
+ xlabel, ylabel = self._getAxesLabels(item)
+ x_data = item.getXData(copy=False)
+ y_data = item.getYData(copy=False)
+ x_err = item.getXErrorData(copy=False)
+ y_err = item.getYErrorData(copy=False)
+ labels = [ylabel]
+ data = [y_data]
+
+ if x_err is not None:
+ if numpy.isscalar(x_err):
+ data.append(numpy.zeros_like(y_data) + x_err)
+ labels.append(xlabel + "_errors")
+ elif x_err.ndim == 1:
+ data.append(x_err)
+ labels.append(xlabel + "_errors")
+ elif x_err.ndim == 2:
+ data.append(x_err[0])
+ labels.append(xlabel + "_errors_below")
+ data.append(x_err[1])
+ labels.append(xlabel + "_errors_above")
+
+ if y_err is not None:
+ if numpy.isscalar(y_err):
+ data.append(numpy.zeros_like(y_data) + y_err)
+ labels.append(ylabel + "_errors")
+ elif y_err.ndim == 1:
+ data.append(y_err)
+ labels.append(ylabel + "_errors")
+ elif y_err.ndim == 2:
+ data.append(y_err[0])
+ labels.append(ylabel + "_errors_below")
+ data.append(y_err[1])
+ labels.append(ylabel + "_errors_above")
+ return x_data, data, xlabel, labels
+
+ @staticmethod
+ def _selectWriteableOutputGroup(filename, parent):
+ if (
+ os.path.exists(filename)
+ and os.path.isfile(filename)
+ and os.access(filename, os.W_OK)
+ ):
+ entryPath = selectOutputGroup(filename)
+ if entryPath is None:
+ _logger.info("Save operation cancelled")
+ return None
+ return entryPath
+ elif not os.path.exists(filename):
+ # create new entry in new file
+ return "/entry"
+ else:
+ SaveAction._errorMessage("Save failed (file access issue)\n", parent=parent)
+ return None
+
+ def _saveCurveAsNXdata(self, curve, filename):
+ entryPath = self._selectWriteableOutputGroup(filename, parent=self.plot)
+ if entryPath is None:
+ return False
+
+ xlabel, ylabel = self._getAxesLabels(curve)
+
+ return save_NXdata(
+ filename,
+ nxentry_name=entryPath,
+ signal=curve.getYData(copy=False),
+ axes=[curve.getXData(copy=False)],
+ signal_name="y",
+ axes_names=["x"],
+ signal_long_name=ylabel,
+ axes_long_names=[xlabel],
+ signal_errors=curve.getYErrorData(copy=False),
+ axes_errors=[curve.getXErrorData(copy=True)],
+ title=self.plot.getGraphTitle(),
+ )
+
+ def _saveCurve(self, plot, filename, nameFilter):
+ """Save a curve from the plot.
+
+ :param str filename: The name of the file to write
+ :param str nameFilter: The selected name filter
+ :return: False if format is not supported or save failed,
+ True otherwise.
+ """
+ if nameFilter not in self.DEFAULT_CURVE_FILTERS:
+ return False
+
+ # Check if a curve is to be saved
+ curve = plot.getActiveCurve()
+ # before calling _saveCurve, if there is no selected curve, we
+ # make sure there is only one curve on the graph
+ if curve is None:
+ curves = plot.getAllCurves()
+ if not curves:
+ self._errorMessage("No curve to be saved", parent=self.plot)
+ return False
+ curve = curves[0]
+
+ if nameFilter in self.CURVE_FILTERS_TXT:
+ filter_ = self.CURVE_FILTERS_TXT[nameFilter]
+ fmt = filter_["fmt"]
+ csvdelim = filter_["delimiter"]
+ autoheader = filter_["header"]
+ else:
+ # .npy or nxdata
+ fmt, csvdelim, autoheader = ("", "", False)
+
+ if nameFilter == self.CURVE_FILTER_NXDATA:
+ return self._saveCurveAsNXdata(curve, filename)
+
+ xdata, data, xlabel, labels = self._get1dData(curve)
+
+ try:
+ save1D(
+ filename,
+ xdata,
+ data,
+ xlabel,
+ labels,
+ fmt=fmt,
+ csvdelim=csvdelim,
+ autoheader=autoheader,
+ )
+ except IOError:
+ self._errorMessage("Save failed\n", parent=self.plot)
+ return False
+
+ return True
+
+ def _saveCurves(self, plot, filename, nameFilter):
+ """Save all curves from the plot.
+
+ :param str filename: The name of the file to write
+ :param str nameFilter: The selected name filter
+ :return: False if format is not supported or save failed,
+ True otherwise.
+ """
+ if nameFilter not in self.DEFAULT_ALL_CURVES_FILTERS:
+ return False
+
+ curves = plot.getAllCurves()
+ if not curves:
+ self._errorMessage("No curves to be saved", parent=self.plot)
+ return False
+
+ curve = curves[0]
+ scanno = 1
+ try:
+ xdata, data, xlabel, labels = self._get1dData(curve)
+
+ specfile = savespec(
+ filename,
+ xdata,
+ data,
+ xlabel,
+ labels,
+ fmt="%.7g",
+ scan_number=1,
+ mode="w",
+ write_file_header=True,
+ close_file=False,
+ )
+ except IOError:
+ self._errorMessage("Save failed\n", parent=self.plot)
+ return False
+
+ for curve in curves[1:]:
+ try:
+ scanno += 1
+ xdata, data, xlabel, labels = self._get1dData(curve)
+ specfile = savespec(
+ specfile,
+ xdata,
+ data,
+ xlabel,
+ labels,
+ fmt="%.7g",
+ scan_number=scanno,
+ write_file_header=False,
+ close_file=False,
+ )
+ except IOError:
+ self._errorMessage("Save failed\n", parent=self.plot)
+ return False
+ specfile.close()
+
+ return True
+
+ def _saveImage(self, plot, filename, nameFilter):
+ """Save an image from the plot.
+
+ :param str filename: The name of the file to write
+ :param str nameFilter: The selected name filter
+ :return: False if format is not supported or save failed,
+ True otherwise.
+ """
+ if nameFilter not in self.DEFAULT_IMAGE_FILTERS:
+ return False
+
+ image = plot.getActiveImage()
+ if image is None:
+ qt.QMessageBox.warning(plot, "No Data", "No image to be saved")
+ return False
+
+ data = image.getData(copy=False)
+
+ # TODO Use silx.io for writing files
+ if nameFilter == self.IMAGE_FILTER_EDF:
+ EdfImage(data=data, header={}).write(filename)
+ return True
+
+ elif nameFilter == self.IMAGE_FILTER_TIFF:
+ tiffFile = TiffIO(filename, mode="w")
+ tiffFile.writeImage(data, software="silx")
+ return True
+
+ elif nameFilter == self.IMAGE_FILTER_NUMPY:
+ try:
+ numpy.save(filename, data)
+ except IOError:
+ self._errorMessage("Save failed\n", parent=self.plot)
+ return False
+ return True
+
+ elif nameFilter == self.IMAGE_FILTER_NXDATA:
+ entryPath = self._selectWriteableOutputGroup(filename, parent=self.plot)
+ if entryPath is None:
+ return False
+ xorigin, yorigin = image.getOrigin()
+ xscale, yscale = image.getScale()
+ xaxis = xorigin + xscale * numpy.arange(data.shape[1])
+ yaxis = yorigin + yscale * numpy.arange(data.shape[0])
+ xlabel, ylabel = self._getAxesLabels(image)
+ interpretation = "image" if len(data.shape) == 2 else "rgba-image"
+
+ return save_NXdata(
+ filename,
+ nxentry_name=entryPath,
+ signal=data,
+ axes=[yaxis, xaxis],
+ signal_name="image",
+ axes_names=["y", "x"],
+ axes_long_names=[ylabel, xlabel],
+ title=plot.getGraphTitle(),
+ interpretation=interpretation,
+ )
+
+ elif nameFilter in (
+ self.IMAGE_FILTER_ASCII,
+ self.IMAGE_FILTER_CSV_COMMA,
+ self.IMAGE_FILTER_CSV_SEMICOLON,
+ self.IMAGE_FILTER_CSV_TAB,
+ ):
+ csvdelim, filetype = {
+ self.IMAGE_FILTER_ASCII: (" ", "txt"),
+ self.IMAGE_FILTER_CSV_COMMA: (",", "csv"),
+ self.IMAGE_FILTER_CSV_SEMICOLON: (";", "csv"),
+ self.IMAGE_FILTER_CSV_TAB: ("\t", "csv"),
+ }[nameFilter]
+
+ height, width = data.shape
+ rows, cols = numpy.mgrid[0:height, 0:width]
+ try:
+ save1D(
+ filename,
+ rows.ravel(),
+ (cols.ravel(), data.ravel()),
+ filetype=filetype,
+ xlabel="row",
+ ylabels=["column", "value"],
+ csvdelim=csvdelim,
+ autoheader=True,
+ )
+
+ except IOError:
+ self._errorMessage("Save failed\n", parent=self.plot)
+ return False
+ return True
+
+ elif nameFilter == self.IMAGE_FILTER_RGB_PNG:
+ # Get displayed image
+ rgbaImage = image.getRgbaImageData(copy=False)
+ # Convert RGB QImage
+ qimage = convertArrayToQImage(rgbaImage[:, :, :3])
+
+ if qimage.save(filename, "PNG"):
+ return True
+ else:
+ _logger.error("Failed to save image as %s", filename)
+ qt.QMessageBox.critical(
+ self.parent(), "Save image as", "Failed to save image"
+ )
+
+ return False
+
+ def _saveScatter(self, plot, filename, nameFilter):
+ """Save an image from the plot.
+
+ :param str filename: The name of the file to write
+ :param str nameFilter: The selected name filter
+ :return: False if format is not supported or save failed,
+ True otherwise.
+ """
+ if nameFilter not in self.DEFAULT_SCATTER_FILTERS:
+ return False
+
+ if nameFilter == self.SCATTER_FILTER_NXDATA:
+ entryPath = self._selectWriteableOutputGroup(filename, parent=self.plot)
+ if entryPath is None:
+ return False
+ scatter = plot.getScatter()
+
+ x = scatter.getXData(copy=False)
+ y = scatter.getYData(copy=False)
+ z = scatter.getValueData(copy=False)
+
+ xerror = scatter.getXErrorData(copy=False)
+ if isinstance(xerror, float):
+ xerror = xerror * numpy.ones(x.shape, dtype=numpy.float32)
+
+ yerror = scatter.getYErrorData(copy=False)
+ if isinstance(yerror, float):
+ yerror = yerror * numpy.ones(x.shape, dtype=numpy.float32)
+
+ xlabel = plot.getGraphXLabel()
+ ylabel = plot.getGraphYLabel()
+
+ return save_NXdata(
+ filename,
+ nxentry_name=entryPath,
+ signal=z,
+ axes=[x, y],
+ signal_name="values",
+ axes_names=["x", "y"],
+ axes_long_names=[xlabel, ylabel],
+ axes_errors=[xerror, yerror],
+ title=plot.getGraphTitle(),
+ )
+
+ def setFileFilter(self, dataKind, nameFilter, func, index=None, appendToFile=False):
+ """Set a name filter to add/replace a file format support
+
+ :param str dataKind:
+ The kind of data for which the provided filter is valid.
+ One of: 'all', 'curve', 'curves', 'image', 'scatter'
+ :param str nameFilter: The name filter in the QFileDialog.
+ See :meth:`QFileDialog.setNameFilters`.
+ :param callable func: The function to call to perform saving.
+ Expected signature is:
+ bool func(PlotWidget plot, str filename, str nameFilter)
+ :param bool appendToFile: True to append the data into the selected
+ file.
+ :param integer index: Index of the filter in the final list (or None)
+ """
+ assert dataKind in ("all", "curve", "curves", "image", "scatter")
+
+ if appendToFile:
+ self._appendFilters.append(nameFilter)
+
+ # first append or replace the new filter to prevent colissions
+ self._filters[dataKind][nameFilter] = func
+ if index is None:
+ # we are already done
+ return
+
+ # get the current ordered list of keys
+ keyList = list(self._filters[dataKind].keys())
+
+ # deal with negative indices
+ if index < 0:
+ index = len(keyList) + index
+ if index < 0:
+ index = 0
+
+ if index >= len(keyList):
+ # nothing to be done, already at the end
+ txt = "Requested index %d impossible, already at the end" % index
+ _logger.info(txt)
+ return
+
+ # get the new ordered list
+ oldIndex = keyList.index(nameFilter)
+ del keyList[oldIndex]
+ keyList.insert(index, nameFilter)
+
+ # build the new filters
+ newFilters = {}
+ for key in keyList:
+ newFilters[key] = self._filters[dataKind][key]
+
+ # and update the filters
+ self._filters[dataKind] = newFilters
+ return
+
+ def getFileFilters(self, dataKind):
+ """Returns the nameFilter and associated function for a kind of data.
+
+ :param str dataKind:
+ The kind of data for which the provided filter is valid.
+ On of: 'all', 'curve', 'curves', 'image', 'scatter'
+ :return: {nameFilter: function} associations.
+ :rtype: dict
+ """
+ assert dataKind in ("all", "curve", "curves", "image", "scatter")
+
+ return self._filters[dataKind].copy()
+
+ def _actionTriggered(self, checked=False):
+ """Handle save action."""
+ # Set-up filters
+ filters = {}
+
+ # Add image filters if there is an active image
+ if self.plot.getActiveImage() is not None:
+ filters.update(self._filters["image"].items())
+
+ # Add curve filters if there is a curve to save
+ if self.plot.getActiveCurve() is not None or len(self.plot.getAllCurves()) == 1:
+ filters.update(self._filters["curve"].items())
+ if len(self.plot.getAllCurves()) >= 1:
+ filters.update(self._filters["curves"].items())
+
+ # Add scatter filters if there is a scatter
+ # todo: CSV
+ if self.plot.getScatter() is not None:
+ filters.update(self._filters["scatter"].items())
+
+ filters.update(self._filters["all"].items())
+
+ # Create and run File dialog
+ dialog = qt.QFileDialog(self.plot)
+ dialog.setOption(qt.QFileDialog.DontUseNativeDialog)
+ dialog.setWindowTitle("Output File Selection")
+ dialog.setModal(1)
+ dialog.setNameFilters(list(filters.keys()))
+
+ dialog.setFileMode(qt.QFileDialog.AnyFile)
+ dialog.setAcceptMode(qt.QFileDialog.AcceptSave)
+
+ def onFilterSelection(filt_):
+ # disable overwrite confirmation for NXdata types,
+ # because we append the data to existing files
+ if filt_ in self._appendFilters:
+ dialog.setOption(qt.QFileDialog.DontConfirmOverwrite)
+ else:
+ dialog.setOption(qt.QFileDialog.DontConfirmOverwrite, False)
+
+ dialog.filterSelected.connect(onFilterSelection)
+
+ if not dialog.exec():
+ return False
+
+ nameFilter = dialog.selectedNameFilter()
+ filename = dialog.selectedFiles()[0]
+ dialog.close()
+
+ if "(" in nameFilter and ")" == nameFilter.strip()[-1]:
+ # Check for correct file extension
+ # Extract file extensions as .something
+ extensions = [
+ ext[ext.find(".") :]
+ for ext in nameFilter[nameFilter.find("(") + 1 : -1].split()
+ ]
+ for ext in extensions:
+ if (
+ len(filename) > len(ext)
+ and filename[-len(ext) :].lower() == ext.lower()
+ ):
+ break
+ else: # filename has no extension supported in nameFilter, add one
+ if len(extensions) >= 1:
+ filename += extensions[0]
+
+ # Handle save
+ func = filters.get(nameFilter, None)
+ if func is not None:
+ return func(self.plot, filename, nameFilter)
+ else:
+ _logger.error("Unsupported file filter: %s", nameFilter)
+ return False
+
+
+def _plotAsPNG(plot):
+ """Save a :class:`Plot` as PNG and return the payload.
+
+ :param plot: The :class:`Plot` to save
+ """
+ pngFile = BytesIO()
+ plot.saveGraph(pngFile, fileFormat="png")
+ pngFile.flush()
+ pngFile.seek(0)
+ data = pngFile.read()
+ pngFile.close()
+ return data
+
+
+class PrintAction(PlotAction):
+ """QAction for printing the plot.
+
+ It opens a Print dialog.
+
+ Current implementation print a bitmap of the plot area and not vector
+ graphics, so printing quality is not great.
+
+ :param plot: :class:`.PlotWidget` instance on which to operate.
+ :param parent: See :class:`QAction`.
+ """
+
+ def __init__(self, plot, parent=None):
+ super(PrintAction, self).__init__(
+ plot,
+ icon="document-print",
+ text="Print...",
+ tooltip="Open print dialog",
+ triggered=self.printPlot,
+ checkable=False,
+ parent=parent,
+ )
+ self.setShortcut(qt.QKeySequence.Print)
+ self.setShortcutContext(qt.Qt.WidgetShortcut)
+
+ def getPrinter(self):
+ """The QPrinter instance used by the PrintAction.
+
+ :rtype: QPrinter
+ """
+ return printer.getDefaultPrinter()
+
+ def printPlotAsWidget(self):
+ """Open the print dialog and print the plot.
+
+ Use :meth:`QWidget.render` to print the plot
+
+ :return: True if successful
+ """
+ dialog = qt.QPrintDialog(self.getPrinter(), self.plot)
+ dialog.setWindowTitle("Print Plot")
+ if not dialog.exec():
+ return False
+
+ # Print a snapshot of the plot widget at the top of the page
+ widget = self.plot.centralWidget()
+
+ painter = qt.QPainter()
+ if not painter.begin(self.getPrinter()):
+ return False
+
+ pageRect = self.getPrinter().pageRect(qt.QPrinter.DevicePixel)
+ xScale = pageRect.width() / widget.width()
+ yScale = pageRect.height() / widget.height()
+ scale = min(xScale, yScale)
+
+ painter.translate(pageRect.width() / 2.0, 0.0)
+ painter.scale(scale, scale)
+ painter.translate(-widget.width() / 2.0, 0.0)
+ widget.render(painter)
+ painter.end()
+
+ return True
+
+ def printPlot(self):
+ """Open the print dialog and print the plot.
+
+ Use :meth:`Plot.saveGraph` to print the plot.
+
+ :return: True if successful
+ """
+ # Init printer and start printer dialog
+ dialog = qt.QPrintDialog(self.getPrinter(), self.plot)
+ dialog.setWindowTitle("Print Plot")
+ if not dialog.exec():
+ return False
+
+ # Save Plot as PNG and make a pixmap from it with default dpi
+ pngData = _plotAsPNG(self.plot)
+
+ pixmap = qt.QPixmap()
+ pixmap.loadFromData(pngData, "png")
+
+ pageRect = self.getPrinter().pageRect(qt.QPrinter.DevicePixel)
+ xScale = pageRect.width() / pixmap.width()
+ yScale = pageRect.height() / pixmap.height()
+ scale = min(xScale, yScale)
+
+ # Draw pixmap with painter
+ painter = qt.QPainter()
+ if not painter.begin(self.getPrinter()):
+ return False
+
+ painter.drawPixmap(
+ 0, 0, pixmap.width() * scale, pixmap.height() * scale, pixmap
+ )
+ painter.end()
+
+ return True
+
+
+class CopyAction(PlotAction):
+ """QAction to copy :class:`.PlotWidget` content to clipboard.
+
+ :param plot: :class:`.PlotWidget` instance on which to operate
+ :param parent: See :class:`QAction`
+ """
+
+ def __init__(self, plot, parent=None):
+ super(CopyAction, self).__init__(
+ plot,
+ icon="edit-copy",
+ text="Copy plot",
+ tooltip="Copy a snapshot of the plot into the clipboard",
+ triggered=self.copyPlot,
+ checkable=False,
+ parent=parent,
+ )
+ self.setShortcut(qt.QKeySequence.Copy)
+ self.setShortcutContext(qt.Qt.WidgetShortcut)
+
+ def copyPlot(self):
+ """Copy plot content to the clipboard as a bitmap."""
+ # Save Plot as PNG and make a QImage from it with default dpi
+ pngData = _plotAsPNG(self.plot)
+ image = qt.QImage.fromData(pngData, "png")
+ qt.QApplication.clipboard().setImage(image)
diff --git a/src/silx/gui/plot/actions/medfilt.py b/src/silx/gui/plot/actions/medfilt.py
new file mode 100644
index 0000000..a335499
--- /dev/null
+++ b/src/silx/gui/plot/actions/medfilt.py
@@ -0,0 +1,140 @@
+# /*##########################################################################
+#
+# Copyright (c) 2004-2020 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""
+:mod:`silx.gui.plot.actions.medfilt` provides a set of QAction to apply filter
+on data contained in a :class:`.PlotWidget`.
+
+The following QAction are available:
+
+- :class:`MedianFilterAction`
+- :class:`MedianFilter1DAction`
+- :class:`MedianFilter2DAction`
+
+"""
+
+__authors__ = ["V.A. Sole", "T. Vincent", "P. Knobel"]
+__license__ = "MIT"
+
+__date__ = "10/10/2018"
+
+from .PlotToolAction import PlotToolAction
+from silx.gui.widgets.MedianFilterDialog import MedianFilterDialog
+from silx.math.medianfilter import medfilt2d
+import logging
+
+_logger = logging.getLogger(__name__)
+
+
+class MedianFilterAction(PlotToolAction):
+ """QAction to plot the pixels intensities diagram
+
+ :param plot: :class:`.PlotWidget` instance on which to operate
+ :param parent: See :class:`QAction`
+ """
+
+ def __init__(self, plot, parent=None):
+ PlotToolAction.__init__(
+ self,
+ plot,
+ icon="median-filter",
+ text="median filter",
+ tooltip="Apply a median filter on the image",
+ parent=parent,
+ )
+ self._originalImage = None
+ self._legend = None
+ self._filteredImage = None
+
+ def _createToolWindow(self):
+ popup = MedianFilterDialog(parent=self.plot)
+ popup.sigFilterOptChanged.connect(self._updateFilter)
+ return popup
+
+ def _connectPlot(self, window):
+ PlotToolAction._connectPlot(self, window)
+ self.plot.sigActiveImageChanged.connect(self._updateActiveImage)
+ self._updateActiveImage()
+
+ def _disconnectPlot(self, window):
+ PlotToolAction._disconnectPlot(self, window)
+ self.plot.sigActiveImageChanged.disconnect(self._updateActiveImage)
+
+ def _updateActiveImage(self):
+ """Set _activeImageLegend and _originalImage from the active image"""
+ self._activeImageLegend = self.plot.getActiveImage(just_legend=True)
+ if self._activeImageLegend is None:
+ self._originalImage = None
+ self._legend = None
+ else:
+ self._originalImage = self.plot.getImage(self._activeImageLegend).getData(
+ copy=False
+ )
+ self._legend = self.plot.getImage(self._activeImageLegend).getName()
+
+ def _updateFilter(self, kernelWidth, conditional=False):
+ if self._originalImage is None:
+ return
+
+ self.plot.sigActiveImageChanged.disconnect(self._updateActiveImage)
+ filteredImage = self._computeFilteredImage(kernelWidth, conditional)
+ self.plot.addImage(data=filteredImage, legend=self._legend, replace=True)
+ self.plot.sigActiveImageChanged.connect(self._updateActiveImage)
+
+ def _computeFilteredImage(self, kernelWidth, conditional):
+ raise NotImplementedError("MedianFilterAction is a an abstract class")
+
+ def getFilteredImage(self):
+ """
+ :return: the image with the median filter apply on"""
+ return self._filteredImage
+
+
+class MedianFilter1DAction(MedianFilterAction):
+ """Define the MedianFilterAction for 1D
+
+ :param plot: :class:`.PlotWidget` instance on which to operate
+ :param parent: See :class:`QAction`
+ """
+
+ def __init__(self, plot, parent=None):
+ MedianFilterAction.__init__(self, plot, parent=parent)
+
+ def _computeFilteredImage(self, kernelWidth, conditional):
+ assert self.plot is not None
+ return medfilt2d(self._originalImage, (kernelWidth, 1), conditional)
+
+
+class MedianFilter2DAction(MedianFilterAction):
+ """Define the MedianFilterAction for 2D
+
+ :param plot: :class:`.PlotWidget` instance on which to operate
+ :param parent: See :class:`QAction`
+ """
+
+ def __init__(self, plot, parent=None):
+ MedianFilterAction.__init__(self, plot, parent=parent)
+
+ def _computeFilteredImage(self, kernelWidth, conditional):
+ assert self.plot is not None
+ return medfilt2d(self._originalImage, (kernelWidth, kernelWidth), conditional)
diff --git a/src/silx/gui/plot/actions/mode.py b/src/silx/gui/plot/actions/mode.py
new file mode 100644
index 0000000..511a8df
--- /dev/null
+++ b/src/silx/gui/plot/actions/mode.py
@@ -0,0 +1,139 @@
+# /*##########################################################################
+#
+# Copyright (c) 2004-2023 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""
+:mod:`silx.gui.plot.actions.mode` provides a set of QAction relative to mouse
+mode of a :class:`.PlotWidget`.
+
+The following QAction are available:
+
+- :class:`ZoomModeAction`
+- :class:`PanModeAction`
+"""
+
+__authors__ = ["V. Valls"]
+__license__ = "MIT"
+__date__ = "16/08/2017"
+
+
+from silx.gui import qt
+
+from ..tools.menus import ZoomEnabledAxesMenu
+from . import PlotAction
+
+
+class ZoomModeAction(PlotAction):
+ """QAction controlling the zoom mode of a :class:`.PlotWidget`.
+
+ :param plot: :class:`.PlotWidget` instance on which to operate
+ :param parent: See :class:`QAction`
+ """
+
+ def __init__(self, plot, parent=None):
+ super(ZoomModeAction, self).__init__(
+ plot,
+ icon="zoom",
+ text="Zoom mode",
+ tooltip="Zoom-in on mouse selection",
+ triggered=self._actionTriggered,
+ checkable=True,
+ parent=parent,
+ )
+
+ self.__menu = ZoomEnabledAxesMenu(self.plot, self.plot)
+
+ # Listen to interaction configuration change
+ self.plot.interaction().sigChanged.connect(self._interactionChanged)
+ # Init the state
+ self._interactionChanged()
+
+ def isAxesMenuEnabled(self) -> bool:
+ """Returns whether the axes selection menu is enabled or not (default: False)"""
+ return self.menu() is self.__menu
+
+ def setAxesMenuEnabled(self, enabled: bool):
+ """Toggle the availability of the axes selection menu (default: False)"""
+ if enabled == self.isAxesMenuEnabled():
+ return
+
+ self.setMenu(self.__menu if enabled else None)
+
+ # Update associated QToolButton's popupMode if any, this is not done at least with Qt5
+ parent = self.parent()
+ if not isinstance(parent, qt.QToolBar):
+ return
+ widget = parent.widgetForAction(self)
+ if not isinstance(widget, qt.QToolButton):
+ return
+ widget.setPopupMode(
+ qt.QToolButton.MenuButtonPopup if enabled else qt.QToolButton.DelayedPopup
+ )
+ widget.update()
+
+ def _interactionChanged(self):
+ plot = self.plot
+ if plot is None:
+ return
+
+ self.setChecked(plot.getInteractiveMode()["mode"] == "zoom")
+
+ def _actionTriggered(self, checked=False):
+ plot = self.plot
+ if plot is None:
+ return
+
+ plot.setInteractiveMode("zoom", source=self)
+
+
+class PanModeAction(PlotAction):
+ """QAction controlling the pan mode of a :class:`.PlotWidget`.
+
+ :param plot: :class:`.PlotWidget` instance on which to operate
+ :param parent: See :class:`QAction`
+ """
+
+ def __init__(self, plot, parent=None):
+ super(PanModeAction, self).__init__(
+ plot,
+ icon="pan",
+ text="Pan mode",
+ tooltip="Pan the view",
+ triggered=self._actionTriggered,
+ checkable=True,
+ parent=parent,
+ )
+ # Listen to mode change
+ self.plot.sigInteractiveModeChanged.connect(self._modeChanged)
+ # Init the state
+ self._modeChanged(None)
+
+ def _modeChanged(self, source):
+ modeDict = self.plot.getInteractiveMode()
+ old = self.blockSignals(True)
+ self.setChecked(modeDict["mode"] == "pan")
+ self.blockSignals(old)
+
+ def _actionTriggered(self, checked=False):
+ plot = self.plot
+ if plot is not None:
+ plot.setInteractiveMode("pan", source=self)
diff --git a/src/silx/gui/plot/backends/BackendBase.py b/src/silx/gui/plot/backends/BackendBase.py
new file mode 100755
index 0000000..8d70286
--- /dev/null
+++ b/src/silx/gui/plot/backends/BackendBase.py
@@ -0,0 +1,606 @@
+# /*##########################################################################
+#
+# Copyright (c) 2004-2023 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ############################################################################*/
+"""Base class for Plot backends.
+
+It documents the Plot backend API.
+
+This API is a simplified version of PyMca PlotBackend API.
+"""
+
+from __future__ import annotations
+
+
+__authors__ = ["V.A. Sole", "T. Vincent"]
+__license__ = "MIT"
+__date__ = "21/12/2018"
+
+from collections.abc import Callable
+import weakref
+from silx.gui.colors import RGBAColorType
+
+from ... import qt
+
+
+# Names for setCursor
+CURSOR_DEFAULT = "default"
+CURSOR_POINTING = "pointing"
+CURSOR_SIZE_HOR = "size horizontal"
+CURSOR_SIZE_VER = "size vertical"
+CURSOR_SIZE_ALL = "size all"
+
+
+class BackendBase(object):
+ """Class defining the API a backend of the Plot should provide."""
+
+ def __init__(self, plot, parent=None):
+ """Init.
+
+ :param Plot plot: The Plot this backend is attached to
+ :param parent: The parent widget of the plot widget.
+ """
+ self.__xLimits = 1.0, 100.0
+ self.__yLimits = {"left": (1.0, 100.0), "right": (1.0, 100.0)}
+ self.__yAxisInverted = False
+ self.__keepDataAspectRatio = False
+ self.__xAxisTimeSeries = False
+ self._xAxisTimeZone = None
+ # Store a weakref to get access to the plot state.
+ self._setPlot(plot)
+
+ @property
+ def _plot(self):
+ """The plot this backend is attached to."""
+ if self._plotRef is None:
+ raise RuntimeError("This backend is not attached to a Plot")
+
+ plot = self._plotRef()
+ if plot is None:
+ raise RuntimeError("This backend is no more attached to a Plot")
+ return plot
+
+ def _setPlot(self, plot):
+ """Allow to set plot after init.
+
+ Use with caution, basically **immediately** after init.
+ """
+ self._plotRef = weakref.ref(plot)
+
+ # Add methods
+
+ def addCurve(
+ self,
+ x,
+ y,
+ color,
+ gapcolor,
+ symbol,
+ linewidth,
+ linestyle,
+ yaxis,
+ xerror,
+ yerror,
+ fill,
+ alpha,
+ symbolsize,
+ baseline,
+ ):
+ """Add a 1D curve given by x an y to the graph.
+
+ :param numpy.ndarray x: The data corresponding to the x axis
+ :param numpy.ndarray y: The data corresponding to the y axis
+ :param color: color(s) to be used
+ :type color: string ("#RRGGBB") or (npoints, 4) unsigned byte array or
+ one of the predefined color names defined in colors.py
+ :param Union[str, None] gapcolor:
+ color used to fill dashed line gaps.
+ :param str symbol: Symbol to be drawn at each (x, y) position::
+
+ - ' ' or '' no symbol
+ - 'o' circle
+ - '.' point
+ - ',' pixel
+ - '+' cross
+ - 'x' x-cross
+ - 'd' diamond
+ - 's' square
+
+ :param float linewidth: The width of the curve in pixels
+ :param linestyle: Type of line::
+
+ - ' ' or '' no line
+ - '-' solid line
+ - '--' dashed line
+ - '-.' dash-dot line
+ - ':' dotted line
+ - (offset, (dash pattern))
+
+ :param str yaxis: The Y axis this curve belongs to in: 'left', 'right'
+ :param xerror: Values with the uncertainties on the x values
+ :type xerror: numpy.ndarray or None
+ :param yerror: Values with the uncertainties on the y values
+ :type yerror: numpy.ndarray or None
+ :param bool fill: True to fill the curve, False otherwise
+ :param float alpha: Curve opacity, as a float in [0., 1.]
+ :param float symbolsize: Size of the symbol (if any) drawn
+ at each (x, y) position.
+ :returns: The handle used by the backend to univocally access the curve
+ """
+ return object()
+
+ def addImage(self, data, origin, scale, colormap, alpha):
+ """Add an image to the plot.
+
+ :param numpy.ndarray data: (nrows, ncolumns) data or
+ (nrows, ncolumns, RGBA) ubyte array
+ :param origin: (origin X, origin Y) of the data.
+ Default: (0., 0.)
+ :type origin: 2-tuple of float
+ :param scale: (scale X, scale Y) of the data.
+ Default: (1., 1.)
+ :type scale: 2-tuple of float
+ :param ~silx.gui.colors.Colormap colormap: Colormap object to use.
+ Ignored if data is RGB(A).
+ :param float alpha: Opacity of the image, as a float in range [0, 1].
+ :returns: The handle used by the backend to univocally access the image
+ """
+ return object()
+
+ def addTriangles(self, x, y, triangles, color, alpha):
+ """Add a set of triangles.
+
+ :param numpy.ndarray x: The data corresponding to the x axis
+ :param numpy.ndarray y: The data corresponding to the y axis
+ :param numpy.ndarray triangles: The indices to make triangles
+ as a (Ntriangle, 3) array
+ :param numpy.ndarray color: color(s) as (npoints, 4) array
+ :param float alpha: Opacity as a float in [0., 1.]
+ :returns: The triangles' unique identifier used by the backend
+ """
+ return object()
+
+ def addShape(
+ self, x, y, shape, color, fill, overlay, linestyle, linewidth, gapcolor
+ ):
+ """Add an item (i.e. a shape) to the plot.
+
+ :param numpy.ndarray x: The X coords of the points of the shape
+ :param numpy.ndarray y: The Y coords of the points of the shape
+ :param str shape: Type of item to be drawn in
+ hline, polygon, rectangle, vline, polylines
+ :param str color: Color of the item
+ :param bool fill: True to fill the shape
+ :param bool overlay: True if item is an overlay, False otherwise
+ :param linestyle: Style of the line.
+ Only relevant for line markers where X or Y is None.
+ Value in:
+
+ - ' ' no line
+ - '-' solid line
+ - '--' dashed line
+ - '-.' dash-dot line
+ - ':' dotted line
+ - (offset, (dash pattern))
+ :param float linewidth: Width of the line.
+ Only relevant for line markers where X or Y is None.
+ :param str gapcolor: Background color of the line, e.g., 'blue', 'b',
+ '#FF0000'. It is used to draw dotted line using a second color.
+ :returns: The handle used by the backend to univocally access the item
+ """
+ return object()
+
+ def addMarker(
+ self,
+ x: float | None,
+ y: float | None,
+ text: str | None,
+ color: str,
+ symbol: str | None,
+ linestyle: str | tuple[float, tuple[float, ...] | None],
+ linewidth: float,
+ constraint: Callable[[float, float], tuple[float, float]] | None,
+ yaxis: str,
+ font: qt.QFont,
+ bgcolor: RGBAColorType | None,
+ ) -> object:
+ """Add a point, vertical line or horizontal line marker to the plot.
+
+ :param x: Horizontal position of the marker in graph coordinates.
+ If None, the marker is a horizontal line.
+ :param y: Vertical position of the marker in graph coordinates.
+ If None, the marker is a vertical line.
+ :param text: Text associated to the marker (or None for no text)
+ :param color: Color to be used for instance 'blue', 'b', '#FF0000'
+ :param bgcolor: Text background color to be used for instance 'blue', 'b', '#FF0000'
+ :param symbol: Symbol representing the marker.
+ Only relevant for point markers where X and Y are not None.
+ Value in:
+
+ - 'o' circle
+ - '.' point
+ - ',' pixel
+ - '+' cross
+ - 'x' x-cross
+ - 'd' diamond
+ - 's' square
+ :param linestyle: Style of the line.
+ Only relevant for line markers where X or Y is None.
+ Value in:
+
+ - ' ' no line
+ - '-' solid line
+ - '--' dashed line
+ - '-.' dash-dot line
+ - ':' dotted line
+ - (offset, (dash pattern))
+ :param linewidth: Width of the line.
+ Only relevant for line markers where X or Y is None.
+ :param constraint: A function filtering marker displacement by
+ dragging operations or None for no filter.
+ This function is called each time a marker is moved.
+ It takes the coordinates of the current cursor position in the plot
+ as input and that returns the filtered coordinates.
+ :param yaxis: The Y axis this marker belongs to in: 'left', 'right'
+ :param font: QFont to use to render text
+ :return: Handle used by the backend to univocally access the marker
+ """
+ return object()
+
+ # Remove methods
+
+ def remove(self, item):
+ """Remove an existing item from the plot.
+
+ :param item: A backend specific item handle returned by a add* method
+ """
+ pass
+
+ # Interaction methods
+
+ def setGraphCursorShape(self, cursor):
+ """Set the cursor shape.
+
+ To override in interactive backends.
+
+ :param str cursor: Name of the cursor shape or None
+ """
+ pass
+
+ def setGraphCursor(self, flag, color, linewidth, linestyle):
+ """Toggle the display of a crosshair cursor and set its attributes.
+
+ To override in interactive backends.
+
+ :param bool flag: Toggle the display of a crosshair cursor.
+ :param color: The color to use for the crosshair.
+ :type color: A string (either a predefined color name in colors.py
+ or "#RRGGBB")) or a 4 columns unsigned byte array.
+ :param int linewidth: The width of the lines of the crosshair.
+ :param linestyle: Type of line::
+
+ - ' ' no line
+ - '-' solid line
+ - '--' dashed line
+ - '-.' dash-dot line
+ - ':' dotted line
+ - (offset, (dash pattern))
+
+ :type linestyle: None, one of the predefined styles or (offset, (dash pattern)).
+ """
+ pass
+
+ def getItemsFromBackToFront(self, condition=None):
+ """Returns the list of plot items order as rendered by the backend.
+
+ This is the order used for rendering.
+ By default, it takes into account overlays, z value and order of addition of items,
+ but backends can override it.
+
+ :param callable condition:
+ Callable taking an item as input and returning False for items to skip.
+ If None (default), no item is skipped.
+ :rtype: List[~silx.gui.plot.items.Item]
+ """
+ # Sort items: Overlays first, then others
+ # and in each category ordered by z and then by order of addition
+ # as content keeps this order.
+ content = self._plot.getItems()
+ if condition is not None:
+ content = [item for item in content if condition(item)]
+
+ return sorted(
+ content, key=lambda i: ((1 if i.isOverlay() else 0), i.getZValue())
+ )
+
+ def pickItem(self, x, y, item):
+ """Return picked indices if any, or None.
+
+ :param float x: The x pixel coord where to pick.
+ :param float y: The y pixel coord where to pick.
+ :param item: A backend item created with add* methods.
+ :return: None if item was not picked, else returns
+ picked indices information.
+ :rtype: Union[None,List]
+ """
+ return None
+
+ # Update curve
+
+ def setCurveColor(self, curve, color):
+ """Set the color of a curve.
+
+ :param curve: The curve handle
+ :param str color: The color to use.
+ """
+ pass
+
+ # Misc.
+
+ def getWidgetHandle(self):
+ """Return the widget this backend is drawing to."""
+ return None
+
+ def postRedisplay(self):
+ """Trigger backend update and repaint."""
+ self.replot()
+
+ def replot(self):
+ """Redraw the plot."""
+ with self._plot._paintContext():
+ pass
+
+ def saveGraph(self, fileName, fileFormat, dpi):
+ """Save the graph to a file (or a StringIO)
+
+ At least "png", "svg" are supported.
+
+ :param fileName: Destination
+ :type fileName: String or StringIO or BytesIO
+ :param str fileFormat: String specifying the format
+ :param int dpi: The resolution to use or None.
+ """
+ pass
+
+ # Graph labels
+
+ def setGraphTitle(self, title):
+ """Set the main title of the plot.
+
+ :param str title: Title associated to the plot
+ """
+ pass
+
+ def setGraphXLabel(self, label):
+ """Set the X axis label.
+
+ :param str label: label associated to the plot bottom X axis
+ """
+ pass
+
+ def setGraphYLabel(self, label, axis):
+ """Set the left Y axis label.
+
+ :param str label: label associated to the plot left Y axis
+ :param str axis: The axis for which to get the limits: left or right
+ """
+ pass
+
+ # Graph limits
+
+ def setLimits(self, xmin, xmax, ymin, ymax, y2min=None, y2max=None):
+ """Set the limits of the X and Y axes at once.
+
+ :param float xmin: minimum bottom axis value
+ :param float xmax: maximum bottom axis value
+ :param float ymin: minimum left axis value
+ :param float ymax: maximum left axis value
+ :param float y2min: minimum right axis value
+ :param float y2max: maximum right axis value
+ """
+ self.__xLimits = xmin, xmax
+ self.__yLimits["left"] = ymin, ymax
+ if y2min is not None and y2max is not None:
+ self.__yLimits["right"] = y2min, y2max
+
+ def getGraphXLimits(self):
+ """Get the graph X (bottom) limits.
+
+ :return: Minimum and maximum values of the X axis
+ """
+ return self.__xLimits
+
+ def setGraphXLimits(self, xmin, xmax):
+ """Set the limits of X axis.
+
+ :param float xmin: minimum bottom axis value
+ :param float xmax: maximum bottom axis value
+ """
+ self.__xLimits = xmin, xmax
+
+ def getGraphYLimits(self, axis):
+ """Get the graph Y (left) limits.
+
+ :param str axis: The axis for which to get the limits: left or right
+ :return: Minimum and maximum values of the Y axis
+ """
+ return self.__yLimits[axis]
+
+ def setGraphYLimits(self, ymin, ymax, axis):
+ """Set the limits of the Y axis.
+
+ :param float ymin: minimum left axis value
+ :param float ymax: maximum left axis value
+ :param str axis: The axis for which to get the limits: left or right
+ """
+ self.__yLimits[axis] = ymin, ymax
+
+ # Graph axes
+
+ def getXAxisTimeZone(self):
+ """Returns tzinfo that is used if the X-Axis plots date-times.
+
+ None means the datetimes are interpreted as local time.
+
+ :rtype: datetime.tzinfo of None.
+ """
+ return self._xAxisTimeZone
+
+ def setXAxisTimeZone(self, tz):
+ """Sets tzinfo that is used if the X-Axis plots date-times.
+
+ Use None to let the datetimes be interpreted as local time.
+
+ :rtype: datetime.tzinfo of None.
+ """
+ self._xAxisTimeZone = tz
+
+ def isXAxisTimeSeries(self):
+ """Return True if the X-axis scale shows datetime objects.
+
+ :rtype: bool
+ """
+ return self.__xAxisTimeSeries
+
+ def setXAxisTimeSeries(self, isTimeSeries):
+ """Set whether the X-axis is a time series
+
+ :param bool flag: True to switch to time series, False for regular axis.
+ """
+ self.__xAxisTimeSeries = bool(isTimeSeries)
+
+ def setXAxisLogarithmic(self, flag):
+ """Set the X axis scale between linear and log.
+
+ :param bool flag: If True, the bottom axis will use a log scale
+ """
+ pass
+
+ def setYAxisLogarithmic(self, flag):
+ """Set the Y axis scale between linear and log.
+
+ :param bool flag: If True, the left axis will use a log scale
+ """
+ pass
+
+ def setYAxisInverted(self, flag):
+ """Invert the Y axis.
+
+ :param bool flag: If True, put the vertical axis origin on the top
+ """
+ self.__yAxisInverted = bool(flag)
+
+ def isYAxisInverted(self):
+ """Return True if left Y axis is inverted, False otherwise."""
+ return self.__yAxisInverted
+
+ def isYRightAxisVisible(self) -> bool:
+ """Return True if the Y axis on the right side of the plot is visible"""
+ return False
+
+ def isKeepDataAspectRatio(self):
+ """Returns whether the plot is keeping data aspect ratio or not."""
+ return self.__keepDataAspectRatio
+
+ def setKeepDataAspectRatio(self, flag):
+ """Set whether to keep data aspect ratio or not.
+
+ :param flag: True to respect data aspect ratio
+ :type flag: Boolean, default True
+ """
+ self.__keepDataAspectRatio = bool(flag)
+
+ def setGraphGrid(self, which):
+ """Set grid.
+
+ :param which: None to disable grid, 'major' for major grid,
+ 'both' for major and minor grid
+ """
+ pass
+
+ # Data <-> Pixel coordinates conversion
+
+ def dataToPixel(self, x, y, axis):
+ """Convert a position in data space to a position in pixels
+ in the widget.
+
+ :param x: The X coordinate in data space.
+ :type x: float or sequence of float
+ :param y: The Y coordinate in data space.
+ :type y: float or sequence of float
+ :param str axis: The Y axis to use for the conversion
+ ('left' or 'right').
+ :returns: The corresponding position in pixels or
+ None if the data position is not in the displayed area.
+ :rtype: A tuple of 2 floats: (xPixel, yPixel) or None.
+ """
+ raise NotImplementedError()
+
+ def pixelToData(self, x, y, axis):
+ """Convert a position in pixels in the widget to a position in
+ the data space.
+
+ :param float x: The X coordinate in pixels.
+ :param float y: The Y coordinate in pixels.
+ :param str axis: The Y axis to use for the conversion
+ ('left' or 'right').
+ :returns: The corresponding position in data space or
+ None if the pixel position is not in the plot area.
+ :rtype: A tuple of 2 floats: (xData, yData) or None.
+ """
+ raise NotImplementedError()
+
+ def getPlotBoundsInPixels(self):
+ """Plot area bounds in widget coordinates in pixels.
+
+ :return: bounds as a 4-tuple of int: (left, top, width, height)
+ """
+ raise NotImplementedError()
+
+ def setAxesMargins(self, left: float, top: float, right: float, bottom: float):
+ """Set the size of plot margins as ratios.
+
+ Values are expected in [0., 1.]
+
+ :param float left:
+ :param float top:
+ :param float right:
+ :param float bottom:
+ """
+ pass
+
+ def setForegroundColors(self, foregroundColor, gridColor):
+ """Set foreground and grid colors used to display this widget.
+
+ :param List[float] foregroundColor: RGBA foreground color of the widget
+ :param List[float] gridColor: RGBA grid color of the data view
+ """
+ pass
+
+ def setBackgroundColors(self, backgroundColor, dataBackgroundColor):
+ """Set background colors used to display this widget.
+
+ :param List[float] backgroundColor: RGBA background color of the widget
+ :param Union[Tuple[float],None] dataBackgroundColor:
+ RGBA background color of the data view
+ """
+ pass
diff --git a/src/silx/gui/plot/backends/BackendMatplotlib.py b/src/silx/gui/plot/backends/BackendMatplotlib.py
new file mode 100755
index 0000000..facb63c
--- /dev/null
+++ b/src/silx/gui/plot/backends/BackendMatplotlib.py
@@ -0,0 +1,1726 @@
+# /*##########################################################################
+#
+# Copyright (c) 2004-2023 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""Matplotlib Plot backend."""
+
+from __future__ import annotations
+
+__authors__ = ["V.A. Sole", "T. Vincent, H. Payno"]
+__license__ = "MIT"
+__date__ = "21/12/2018"
+
+
+import logging
+import datetime as dt
+from typing import Tuple, Union
+import numpy
+
+from packaging.version import Version
+
+
+_logger = logging.getLogger(__name__)
+
+
+from ... import qt
+
+# First of all init matplotlib and set its backend
+from ...utils.matplotlib import (
+ DefaultTickFormatter,
+ FigureCanvasQTAgg,
+ qFontToFontProperties,
+)
+import matplotlib
+from matplotlib.container import Container
+from matplotlib.figure import Figure
+from matplotlib.patches import Rectangle, Polygon
+from matplotlib.image import AxesImage
+from matplotlib.backend_bases import MouseEvent
+from matplotlib.lines import Line2D
+from matplotlib.text import Text
+from matplotlib.collections import PathCollection, LineCollection
+from matplotlib.ticker import Formatter, Locator
+from matplotlib.tri import Triangulation
+from matplotlib.collections import TriMesh
+from matplotlib import path as mpath
+
+from . import BackendBase
+from .. import items
+from .._utils import FLOAT32_MINPOS
+from .._utils.dtime_ticklayout import (
+ calcTicks,
+ formatDatetimes,
+ timestamp,
+)
+from ...qt import inspect as qt_inspect
+from .... import config
+from silx.gui.colors import RGBAColorType
+
+_PATCH_LINESTYLE = {
+ "-": "solid",
+ "--": "dashed",
+ "-.": "dashdot",
+ ":": "dotted",
+ "": "solid",
+ None: "solid",
+}
+"""Patches do not uses the same matplotlib syntax"""
+
+_MARKER_PATHS = {}
+"""Store cached extra marker paths"""
+
+_SPECIAL_MARKERS = {
+ "tickleft": 0,
+ "tickright": 1,
+ "tickup": 2,
+ "tickdown": 3,
+ "caretleft": 4,
+ "caretright": 5,
+ "caretup": 6,
+ "caretdown": 7,
+}
+
+
+def normalize_linestyle(linestyle):
+ """Normalize known old-style linestyle, else return the provided value."""
+ return _PATCH_LINESTYLE.get(linestyle, linestyle)
+
+
+def get_path_from_symbol(symbol):
+ """Get the path representation of a symbol, else None if
+ it is not provided.
+
+ :param str symbol: Symbol description used by silx
+ :rtype: Union[None,matplotlib.path.Path]
+ """
+ if symbol == "\u2665":
+ path = _MARKER_PATHS.get(symbol, None)
+ if path is not None:
+ return path
+ vertices = numpy.array(
+ [
+ [0, -99],
+ [31, -73],
+ [47, -55],
+ [55, -46],
+ [63, -37],
+ [94, -2],
+ [94, 33],
+ [94, 69],
+ [71, 89],
+ [47, 89],
+ [24, 89],
+ [8, 74],
+ [0, 58],
+ [-8, 74],
+ [-24, 89],
+ [-47, 89],
+ [-71, 89],
+ [-94, 69],
+ [-94, 33],
+ [-94, -2],
+ [-63, -37],
+ [-55, -46],
+ [-47, -55],
+ [-31, -73],
+ [0, -99],
+ [0, -99],
+ ]
+ )
+ codes = [mpath.Path.CURVE4] * len(vertices)
+ codes[0] = mpath.Path.MOVETO
+ codes[-1] = mpath.Path.CLOSEPOLY
+ path = mpath.Path(vertices, codes)
+ _MARKER_PATHS[symbol] = path
+ return path
+ return None
+
+
+class NiceDateLocator(Locator):
+ """
+ Matplotlib Locator that uses Nice Numbers algorithm (adapted to dates)
+ to find the tick locations. This results in the same number behaviour
+ as when using the silx Open GL backend.
+
+ Expects the data to be posix timestampes (i.e. seconds since 1970)
+ """
+
+ def __init__(self, numTicks=5, tz=None):
+ """
+ :param numTicks: target number of ticks
+ :param datetime.tzinfo tz: optional time zone. None is local time.
+ """
+ super(NiceDateLocator, self).__init__()
+ self.numTicks = numTicks
+
+ self._spacing = None
+ self._unit = None
+ self.tz = tz
+
+ @property
+ def spacing(self):
+ """The current spacing. Will be updated when new tick value are made"""
+ return self._spacing
+
+ @property
+ def unit(self):
+ """The current DtUnit. Will be updated when new tick value are made"""
+ return self._unit
+
+ def __call__(self):
+ """Return the locations of the ticks"""
+ vmin, vmax = self.axis.get_view_interval()
+ return self.tick_values(vmin, vmax)
+
+ def tick_values(self, vmin, vmax):
+ """Calculates tick values"""
+ if vmax < vmin:
+ vmin, vmax = vmax, vmin
+
+ # vmin and vmax should be timestamps (i.e. seconds since 1 Jan 1970)
+ try:
+ dtMin = dt.datetime.fromtimestamp(vmin, tz=self.tz)
+ dtMax = dt.datetime.fromtimestamp(vmax, tz=self.tz)
+ except ValueError:
+ _logger.warning("Data range cannot be displayed with time axis")
+ return []
+
+ dtTicks, self._spacing, self._unit = calcTicks(dtMin, dtMax, self.numTicks)
+
+ # Convert datetime back to time stamps.
+ ticks = [timestamp(dtTick) for dtTick in dtTicks]
+ return ticks
+
+
+class NiceAutoDateFormatter(Formatter):
+ """
+ Matplotlib FuncFormatter that is linked to a NiceDateLocator and gives the
+ best possible formats given the locators current spacing an date unit.
+ """
+
+ def __init__(self, locator, tz=None):
+ """
+ :param niceDateLocator: a NiceDateLocator object
+ :param datetime.tzinfo tz: optional time zone. None is local time.
+ """
+ super(NiceAutoDateFormatter, self).__init__()
+ self.locator = locator
+ self.tz = tz
+
+ def __call__(self, x, pos=None):
+ """Return the format for tick val *x* at position *pos*
+ Expects x to be a POSIX timestamp (seconds since 1 Jan 1970)
+ """
+ datetime = dt.datetime.fromtimestamp(x, tz=self.tz)
+ return formatDatetimes(
+ [datetime],
+ self.locator.spacing,
+ self.locator.unit,
+ )[datetime]
+
+ def format_ticks(self, values):
+ return tuple(
+ formatDatetimes(
+ [dt.datetime.fromtimestamp(value, tz=self.tz) for value in values],
+ self.locator.spacing,
+ self.locator.unit,
+ ).values()
+ )
+
+
+class _PickableContainer(Container):
+ """Artists container with a :meth:`contains` method"""
+
+ def __init__(self, *args, **kwargs):
+ Container.__init__(self, *args, **kwargs)
+ self.__zorder = None
+
+ @property
+ def axes(self):
+ """Mimin Artist.axes"""
+ for child in self.get_children():
+ if hasattr(child, "axes"):
+ return child.axes
+ return None
+
+ def draw(self, *args, **kwargs):
+ """artist-like draw to broadcast draw to children"""
+ for child in self.get_children():
+ child.draw(*args, **kwargs)
+
+ def get_zorder(self):
+ """Mimic Artist.get_zorder"""
+ return self.__zorder
+
+ def set_zorder(self, z):
+ """Mimic Artist.set_zorder to broadcast to children"""
+ if z != self.__zorder:
+ self.__zorder = z
+ for child in self.get_children():
+ child.set_zorder(z)
+
+ def contains(self, mouseevent):
+ """Mimic Artist.contains, and call it on all children.
+
+ :param mouseevent:
+ :return: Picking status and associated information as a dict
+ :rtype: (bool,dict)
+ """
+ # Goes through children from front to back and return first picked one.
+ for child in reversed(self.get_children()):
+ picked, info = child.contains(mouseevent)
+ if picked:
+ return picked, info
+ return False, {}
+
+
+class _TextWithOffset(Text):
+ """Text object which can be displayed at a specific position
+ of the plot, but with a pixel offset"""
+
+ def __init__(self, *args, **kwargs):
+ Text.__init__(self, *args, **kwargs)
+ self.pixel_offset = (0, 0)
+ self.__cache = None
+
+ def draw(self, renderer):
+ self.__cache = None
+ return Text.draw(self, renderer)
+
+ def __get_xy(self):
+ if self.__cache is not None:
+ return self.__cache
+
+ align = self.get_horizontalalignment()
+ if align == "left":
+ xoffset = self.pixel_offset[0]
+ elif align == "right":
+ xoffset = -self.pixel_offset[0]
+ else:
+ xoffset = 0
+
+ align = self.get_verticalalignment()
+ if align == "top":
+ yoffset = -self.pixel_offset[1]
+ elif align == "bottom":
+ yoffset = self.pixel_offset[1]
+ else:
+ yoffset = 0
+
+ trans = self.get_transform()
+ x = super(_TextWithOffset, self).convert_xunits(self._x)
+ y = super(_TextWithOffset, self).convert_xunits(self._y)
+ pos = x, y
+
+ try:
+ invtrans = trans.inverted()
+ except numpy.linalg.LinAlgError:
+ # Cannot inverse transform, fallback: pos without offset
+ self.__cache = None
+ return pos
+
+ proj = trans.transform_point(pos)
+ proj = proj + numpy.array((xoffset, yoffset))
+ pos = invtrans.transform_point(proj)
+ self.__cache = pos
+ return pos
+
+ def convert_xunits(self, x):
+ """Return the pixel position of the annotated point."""
+ return self.__get_xy()[0]
+
+ def convert_yunits(self, y):
+ """Return the pixel position of the annotated point."""
+ return self.__get_xy()[1]
+
+
+class _MarkerContainer(_PickableContainer):
+ """Marker artists container supporting draw/remove and text position update
+
+ :param artists:
+ Iterable with either one Line2D or a Line2D and a Text.
+ The use of an iterable if enforced by Container being
+ a subclass of tuple that defines a specific __new__.
+ :param x: X coordinate of the marker (None for horizontal lines)
+ :param y: Y coordinate of the marker (None for vertical lines)
+ """
+
+ def __init__(self, artists, symbol, x, y, yAxis):
+ self.line = artists[0]
+ self.text = artists[1] if len(artists) > 1 else None
+ self.symbol = symbol
+ self.x = x
+ self.y = y
+ self.yAxis = yAxis
+
+ _PickableContainer.__init__(self, artists)
+
+ def draw(self, *args, **kwargs):
+ """artist-like draw to broadcast draw to line and text"""
+ self.line.draw(*args, **kwargs)
+ if self.text is not None:
+ self.text.draw(*args, **kwargs)
+
+ def updateMarkerText(self, xmin, xmax, ymin, ymax, yinverted):
+ """Update marker text position and visibility according to plot limits
+
+ :param xmin: X axis lower limit
+ :param xmax: X axis upper limit
+ :param ymin: Y axis lower limit
+ :param ymax: Y axis upper limit
+ :param yinverted: True if the y axis is inverted
+ """
+ if self.text is not None:
+ visible = (self.x is None or xmin <= self.x <= xmax) and (
+ self.y is None or ymin <= self.y <= ymax
+ )
+ self.text.set_visible(visible)
+
+ if self.x is not None and self.y is not None:
+ if self.symbol is None:
+ valign = "baseline"
+ else:
+ if yinverted:
+ valign = "bottom"
+ else:
+ valign = "top"
+ self.text.set_verticalalignment(valign)
+
+ elif self.y is None: # vertical line
+ # Always display it on top
+ center = (ymax + ymin) * 0.5
+ pos = (ymax - ymin) * 0.5 * 0.99
+ if yinverted:
+ pos = -pos
+ self.text.set_y(center + pos)
+
+ elif self.x is None: # Horizontal line
+ delta = abs(xmax - xmin)
+ if xmin > xmax:
+ xmax = xmin
+ xmax -= 0.005 * delta
+ self.text.set_x(xmax)
+
+ def contains(self, mouseevent):
+ """Mimic Artist.contains, and call it on the line Artist.
+
+ :param mouseevent:
+ :return: Picking status and associated information as a dict
+ :rtype: (bool,dict)
+ """
+ return self.line.contains(mouseevent)
+
+
+class SecondEdgeColorPatchMixIn:
+ """Mix-in class to add a second color for patches with dashed lines"""
+
+ def __init__(self, *args, **kwargs):
+ super().__init__(*args, **kwargs)
+ self._second_edgecolor = None
+
+ def set_second_edgecolor(self, color):
+ """Set the second color used to fill dashed edges"""
+ self._second_edgecolor = color
+
+ def get_second_edgecolor(self):
+ """Returns the second color used to fill dashed edges"""
+ return self._second_edgecolor
+
+ def draw(self, renderer):
+ linestyle = self.get_linestyle()
+ if linestyle == "solid" or self.get_second_edgecolor() is None:
+ super().draw(renderer)
+ return
+
+ edgecolor = self.get_edgecolor()
+ hatch = self.get_hatch()
+
+ self.set_linestyle("solid")
+ self.set_edgecolor(self.get_second_edgecolor())
+ self.set_hatch(None)
+ super().draw(renderer)
+
+ self.set_linestyle(linestyle)
+ self.set_edgecolor(edgecolor)
+ self.set_hatch(hatch)
+ super().draw(renderer)
+
+
+class Rectangle2EdgeColor(SecondEdgeColorPatchMixIn, Rectangle):
+ """Rectangle patch with a second edge color for dashed line"""
+
+
+class Polygon2EdgeColor(SecondEdgeColorPatchMixIn, Polygon):
+ """Polygon patch with a second edge color for dashed line"""
+
+
+class Image(AxesImage):
+ """An AxesImage with a fast path for uint8 RGBA images.
+
+ :param List[float] silx_origin: (ox, oy) Offset of the image.
+ :param List[float] silx_scale: (sx, sy) Scale of the image.
+ """
+
+ def __init__(self, *args, silx_origin=(0.0, 0.0), silx_scale=(1.0, 1.0), **kwargs):
+ super().__init__(*args, **kwargs)
+ self.__silx_origin = silx_origin
+ self.__silx_scale = silx_scale
+
+ def contains(self, mouseevent):
+ """Overridden to fill 'ind' with row and column"""
+ inside, info = super().contains(mouseevent)
+ if inside:
+ x, y = mouseevent.xdata, mouseevent.ydata
+ ox, oy = self.__silx_origin
+ sx, sy = self.__silx_scale
+ height, width = self.get_size()
+ column = numpy.clip(int((x - ox) / sx), 0, width - 1)
+ row = numpy.clip(int((y - oy) / sy), 0, height - 1)
+ info["ind"] = (row,), (column,)
+ return inside, info
+
+ def set_data(self, A):
+ """Overridden to add a fast path for RGBA unit8 images"""
+ A = numpy.array(A, copy=False)
+ if A.ndim != 3 or A.shape[2] != 4 or A.dtype != numpy.uint8:
+ super(Image, self).set_data(A)
+ else:
+ # Call AxesImage.set_data with small data to set attributes
+ super(Image, self).set_data(numpy.zeros((2, 2, 4), dtype=A.dtype))
+ self._A = A # Override stored data
+
+
+class BackendMatplotlib(BackendBase.BackendBase):
+ """Base class for Matplotlib backend without a FigureCanvas.
+
+ For interactive on screen plot, see :class:`BackendMatplotlibQt`.
+
+ See :class:`BackendBase.BackendBase` for public API documentation.
+ """
+
+ def __init__(self, plot, parent=None):
+ super(BackendMatplotlib, self).__init__(plot, parent)
+
+ # matplotlib is handling keep aspect ratio at draw time
+ # When keep aspect ratio is on, and one changes the limits and
+ # ask them *before* next draw has been performed he will get the
+ # limits without applying keep aspect ratio.
+ # This attribute is used to ensure consistent values returned
+ # when getting the limits at the expense of a replot
+ self._dirtyLimits = True
+ self._axesDisplayed = True
+ self._matplotlibVersion = Version(matplotlib.__version__)
+
+ self.fig = Figure(
+ tight_layout=config._MPL_TIGHT_LAYOUT,
+ )
+ self.fig.set_facecolor("w")
+
+ if config._MPL_TIGHT_LAYOUT:
+ self.ax = self.fig.add_subplot(label="left")
+ else:
+ self.ax = self.fig.add_axes([0.15, 0.15, 0.75, 0.75], label="left")
+ self.ax2 = self.ax.twinx()
+ self.ax2.set_label("right")
+ # Make sure background of Axes is displayed
+ self.ax2.patch.set_visible(False)
+ self.ax.patch.set_visible(True)
+
+ # Set axis zorder=0.5 so grid is displayed at 0.5
+ self.ax.set_axisbelow(True)
+
+ # Configure axes tick label formatter
+ for axis in (self.ax.yaxis, self.ax.xaxis, self.ax2.yaxis, self.ax2.xaxis):
+ axis.set_major_formatter(DefaultTickFormatter())
+
+ self.ax2.set_autoscaley_on(True)
+
+ # this works but the figure color is left
+ if self._matplotlibVersion < Version("2"):
+ self.ax.set_axis_bgcolor("none")
+ else:
+ self.ax.set_facecolor("none")
+ self.fig.sca(self.ax)
+
+ self._background = None
+
+ self._colormaps = {}
+
+ self._graphCursor = tuple()
+
+ self._enableAxis("right", False)
+ self._isXAxisTimeSeries = False
+
+ def getItemsFromBackToFront(self, condition=None):
+ """Order as BackendBase + take into account matplotlib Axes structure"""
+
+ def axesOrder(item):
+ if item.isOverlay():
+ return 2
+ elif isinstance(item, items.YAxisMixIn) and item.getYAxis() == "right":
+ return 1
+ else:
+ return 0
+
+ return sorted(
+ BackendBase.BackendBase.getItemsFromBackToFront(self, condition=condition),
+ key=axesOrder,
+ )
+
+ def _overlayItems(self):
+ """Generator of backend renderer for overlay items"""
+ for item in self._plot.getItems():
+ if (
+ item.isOverlay()
+ and item.isVisible()
+ and item._backendRenderer is not None
+ ):
+ yield item._backendRenderer
+
+ def _hasOverlays(self):
+ """Returns whether there is an overlay layer or not.
+
+ The overlay layers contains overlay items and the crosshair.
+
+ :rtype: bool
+ """
+ if self._graphCursor:
+ return True # There is the crosshair
+
+ for item in self._overlayItems():
+ return True # There is at least one overlay item
+ return False
+
+ # Add methods
+
+ def _getMarkerFromSymbol(self, symbol):
+ """Returns a marker that can be displayed by matplotlib.
+
+ :param str symbol: A symbol description used by silx
+ :rtype: Union[str,int,matplotlib.path.Path]
+ """
+ path = get_path_from_symbol(symbol)
+ if path is not None:
+ return path
+ num = _SPECIAL_MARKERS.get(symbol, None)
+ if num is not None:
+ return num
+ # This symbol must be supported by matplotlib
+ return symbol
+
+ def addCurve(
+ self,
+ x,
+ y,
+ color,
+ gapcolor,
+ symbol,
+ linewidth,
+ linestyle,
+ yaxis,
+ xerror,
+ yerror,
+ fill,
+ alpha,
+ symbolsize,
+ baseline,
+ ):
+ for parameter in (
+ x,
+ y,
+ color,
+ symbol,
+ linewidth,
+ linestyle,
+ yaxis,
+ fill,
+ alpha,
+ symbolsize,
+ ):
+ assert parameter is not None
+ assert yaxis in ("left", "right")
+
+ if len(color) == 4 and type(color[3]) in [type(1), numpy.uint8, numpy.int8]:
+ color = numpy.array(color, dtype=numpy.float64) / 255.0
+
+ if yaxis == "right":
+ axes = self.ax2
+ self._enableAxis("right", True)
+ else:
+ axes = self.ax
+
+ pickradius = 3
+
+ artists = [] # All the artists composing the curve
+
+ # First add errorbars if any so they are behind the curve
+ if xerror is not None or yerror is not None:
+ if hasattr(color, "dtype") and len(color) == len(x):
+ errorbarColor = "k"
+ else:
+ errorbarColor = color
+
+ # Nx1 error array deprecated in matplotlib >=3.1 (removed in 3.3)
+ if (
+ isinstance(xerror, numpy.ndarray)
+ and xerror.ndim == 2
+ and xerror.shape[1] == 1
+ ):
+ xerror = numpy.ravel(xerror)
+ if (
+ isinstance(yerror, numpy.ndarray)
+ and yerror.ndim == 2
+ and yerror.shape[1] == 1
+ ):
+ yerror = numpy.ravel(yerror)
+
+ errorbars = axes.errorbar(
+ x, y, xerr=xerror, yerr=yerror, linestyle=" ", color=errorbarColor
+ )
+ artists += list(errorbars.get_children())
+
+ if hasattr(color, "dtype") and len(color) == len(x):
+ # scatter plot
+ if color.dtype not in [numpy.float32, numpy.float64]:
+ actualColor = color / 255.0
+ else:
+ actualColor = color
+
+ if linestyle not in ["", " ", None]:
+ # scatter plot with an actual line ...
+ # we need to assign a color ...
+ curveList = axes.plot(
+ x,
+ y,
+ linestyle=linestyle,
+ color=actualColor[0],
+ linewidth=linewidth,
+ picker=True,
+ pickradius=pickradius,
+ marker=None,
+ )
+ artists += list(curveList)
+
+ marker = self._getMarkerFromSymbol(symbol)
+ scatter = axes.scatter(
+ x,
+ y,
+ color=actualColor,
+ marker=marker,
+ picker=True,
+ pickradius=pickradius,
+ s=symbolsize**2,
+ )
+ artists.append(scatter)
+
+ if fill:
+ if baseline is None:
+ _baseline = FLOAT32_MINPOS
+ else:
+ _baseline = baseline
+ artists.append(
+ axes.fill_between(
+ x, _baseline, y, facecolor=actualColor[0], linestyle=""
+ )
+ )
+
+ else: # Curve
+ curveList = axes.plot(
+ x,
+ y,
+ linestyle=linestyle,
+ color=color,
+ linewidth=linewidth,
+ marker=symbol,
+ picker=True,
+ pickradius=pickradius,
+ markersize=symbolsize,
+ )
+
+ if gapcolor is not None and self._matplotlibVersion >= Version("3.6.0"):
+ for line2d in curveList:
+ line2d.set_gapcolor(gapcolor)
+ artists += list(curveList)
+
+ if fill:
+ if baseline is None:
+ _baseline = FLOAT32_MINPOS
+ else:
+ _baseline = baseline
+ artists.append(axes.fill_between(x, _baseline, y, facecolor=color))
+
+ for artist in artists:
+ if alpha < 1:
+ artist.set_alpha(alpha)
+
+ return _PickableContainer(artists)
+
+ def addImage(self, data, origin, scale, colormap, alpha):
+ # Non-uniform image
+ # http://wiki.scipy.org/Cookbook/Histograms
+ # Non-linear axes
+ # http://stackoverflow.com/questions/11488800/non-linear-axes-for-imshow-in-matplotlib
+ for parameter in (data, origin, scale):
+ assert parameter is not None
+
+ origin = float(origin[0]), float(origin[1])
+ scale = float(scale[0]), float(scale[1])
+ height, width = data.shape[0:2]
+
+ # All image are shown as RGBA image
+ image = Image(
+ self.ax,
+ interpolation="nearest",
+ picker=True,
+ origin="lower",
+ silx_origin=origin,
+ silx_scale=scale,
+ )
+
+ if alpha < 1:
+ image.set_alpha(alpha)
+
+ # Set image extent
+ xmin = origin[0]
+ xmax = xmin + scale[0] * width
+ if scale[0] < 0.0:
+ xmin, xmax = xmax, xmin
+
+ ymin = origin[1]
+ ymax = ymin + scale[1] * height
+ if scale[1] < 0.0:
+ ymin, ymax = ymax, ymin
+
+ image.set_extent((xmin, xmax, ymin, ymax))
+
+ # Set image data
+ if scale[0] < 0.0 or scale[1] < 0.0:
+ # For negative scale, step by -1
+ xstep = 1 if scale[0] >= 0.0 else -1
+ ystep = 1 if scale[1] >= 0.0 else -1
+ data = data[::ystep, ::xstep]
+
+ if data.ndim == 2: # Data image, convert to RGBA image
+ data = colormap.applyToData(data)
+ elif data.dtype == numpy.uint16:
+ # Normalize uint16 data to have a similar behavior as opengl backend
+ data = data.astype(numpy.float32)
+ data /= 65535
+
+ image.set_data(data)
+ self.ax.add_artist(image)
+ return image
+
+ def addTriangles(self, x, y, triangles, color, alpha):
+ for parameter in (x, y, triangles, color, alpha):
+ assert parameter is not None
+
+ color = numpy.array(color, copy=False)
+ assert color.ndim == 2 and len(color) == len(x)
+
+ if color.dtype not in [numpy.float32, numpy.float64]:
+ color = color.astype(numpy.float32) / 255.0
+
+ collection = TriMesh(
+ Triangulation(x, y, triangles), alpha=alpha, pickradius=0
+ ) # 0 enables picking on filled triangle
+ collection.set_color(color)
+ self.ax.add_collection(collection)
+
+ return collection
+
+ def addShape(
+ self, x, y, shape, color, fill, overlay, linestyle, linewidth, gapcolor
+ ):
+ if gapcolor is not None and shape not in (
+ "rectangle",
+ "polygon",
+ "polylines",
+ ):
+ _logger.warning(
+ "gapcolor not implemented for %s with matplotlib backend", shape
+ )
+ xView = numpy.array(x, copy=False)
+ yView = numpy.array(y, copy=False)
+
+ linestyle = normalize_linestyle(linestyle)
+
+ if shape == "line":
+ item = self.ax.plot(
+ x, y, color=color, linestyle=linestyle, linewidth=linewidth, marker=None
+ )[0]
+
+ elif shape == "hline":
+ if hasattr(y, "__len__"):
+ y = y[-1]
+ item = self.ax.axhline(
+ y, color=color, linestyle=linestyle, linewidth=linewidth
+ )
+
+ elif shape == "vline":
+ if hasattr(x, "__len__"):
+ x = x[-1]
+ item = self.ax.axvline(
+ x, color=color, linestyle=linestyle, linewidth=linewidth
+ )
+
+ elif shape == "rectangle":
+ xMin = numpy.nanmin(xView)
+ xMax = numpy.nanmax(xView)
+ yMin = numpy.nanmin(yView)
+ yMax = numpy.nanmax(yView)
+ w = xMax - xMin
+ h = yMax - yMin
+ item = Rectangle2EdgeColor(
+ xy=(xMin, yMin),
+ width=w,
+ height=h,
+ fill=False,
+ color=color,
+ linestyle=linestyle,
+ linewidth=linewidth,
+ )
+ item.set_second_edgecolor(gapcolor)
+
+ if fill:
+ item.set_hatch(".")
+
+ self.ax.add_patch(item)
+
+ elif shape in ("polygon", "polylines"):
+ points = numpy.array((xView, yView)).T
+ if shape == "polygon":
+ closed = True
+ else: # shape == 'polylines'
+ closed = numpy.all(numpy.equal(points[0], points[-1]))
+ item = Polygon2EdgeColor(
+ points,
+ closed=closed,
+ fill=False,
+ color=color,
+ linestyle=linestyle,
+ linewidth=linewidth,
+ )
+ item.set_second_edgecolor(gapcolor)
+
+ if fill and shape == "polygon":
+ item.set_hatch("/")
+
+ self.ax.add_patch(item)
+
+ else:
+ raise NotImplementedError("Unsupported item shape %s" % shape)
+
+ if overlay:
+ item.set_animated(True)
+
+ return item
+
+ def addMarker(
+ self,
+ x,
+ y,
+ text,
+ color,
+ symbol,
+ linestyle,
+ linewidth,
+ constraint,
+ yaxis,
+ font,
+ bgcolor: RGBAColorType | None,
+ ):
+ textArtist = None
+ fontProperties = None if font is None else qFontToFontProperties(font)
+
+ xmin, xmax = self.getGraphXLimits()
+ ymin, ymax = self.getGraphYLimits(axis=yaxis)
+
+ if yaxis == "left":
+ ax = self.ax
+ elif yaxis == "right":
+ ax = self.ax2
+ else:
+ assert False
+
+ if bgcolor is None:
+ bgcolor = "none"
+
+ marker = self._getMarkerFromSymbol(symbol)
+ if x is not None and y is not None:
+ line = ax.plot(
+ x, y, linestyle=" ", color=color, marker=marker, markersize=10.0
+ )[-1]
+
+ if text is not None:
+ textArtist = _TextWithOffset(
+ x,
+ y,
+ text,
+ color=color,
+ backgroundcolor=bgcolor,
+ horizontalalignment="left",
+ fontproperties=fontProperties,
+ )
+ if symbol is not None:
+ textArtist.pixel_offset = 10, 3
+ elif x is not None:
+ line = ax.axvline(x, color=color, linewidth=linewidth, linestyle=linestyle)
+ if text is not None:
+ # Y position will be updated in updateMarkerText call
+ textArtist = _TextWithOffset(
+ x,
+ 1.0,
+ text,
+ color=color,
+ backgroundcolor=bgcolor,
+ horizontalalignment="left",
+ verticalalignment="top",
+ fontproperties=fontProperties,
+ )
+ textArtist.pixel_offset = 5, 3
+ elif y is not None:
+ line = ax.axhline(y, color=color, linewidth=linewidth, linestyle=linestyle)
+
+ if text is not None:
+ # X position will be updated in updateMarkerText call
+ textArtist = _TextWithOffset(
+ 1.0,
+ y,
+ text,
+ color=color,
+ backgroundcolor=bgcolor,
+ horizontalalignment="right",
+ verticalalignment="top",
+ fontproperties=fontProperties,
+ )
+ textArtist.pixel_offset = 5, 3
+ else:
+ raise RuntimeError("A marker must at least have one coordinate")
+
+ line.set_picker(True)
+ line.set_pickradius(5)
+
+ # All markers are overlays
+ line.set_animated(True)
+ if textArtist is not None:
+ ax.add_artist(textArtist)
+ textArtist.set_animated(True)
+
+ artists = [line] if textArtist is None else [line, textArtist]
+ container = _MarkerContainer(artists, symbol, x, y, yaxis)
+ container.updateMarkerText(xmin, xmax, ymin, ymax, self.isYAxisInverted())
+
+ return container
+
+ def _updateMarkers(self):
+ xmin, xmax = self.ax.get_xbound()
+ ymin1, ymax1 = self.ax.get_ybound()
+ ymin2, ymax2 = self.ax2.get_ybound()
+ yinverted = self.isYAxisInverted()
+ for item in self._overlayItems():
+ if isinstance(item, _MarkerContainer):
+ if item.yAxis == "left":
+ item.updateMarkerText(xmin, xmax, ymin1, ymax1, yinverted)
+ else:
+ item.updateMarkerText(xmin, xmax, ymin2, ymax2, yinverted)
+
+ # Remove methods
+
+ def remove(self, item):
+ try:
+ item.remove()
+ except ValueError:
+ pass # Already removed e.g., in set[X|Y]AxisLogarithmic
+
+ # Interaction methods
+
+ def setGraphCursor(self, flag, color, linewidth, linestyle):
+ if flag:
+ lineh = self.ax.axhline(
+ self.ax.get_ybound()[0],
+ visible=False,
+ color=color,
+ linewidth=linewidth,
+ linestyle=linestyle,
+ )
+ lineh.set_animated(True)
+
+ linev = self.ax.axvline(
+ self.ax.get_xbound()[0],
+ visible=False,
+ color=color,
+ linewidth=linewidth,
+ linestyle=linestyle,
+ )
+ linev.set_animated(True)
+
+ self._graphCursor = lineh, linev
+ else:
+ if self._graphCursor:
+ lineh, linev = self._graphCursor
+ lineh.remove()
+ linev.remove()
+ self._graphCursor = tuple()
+
+ # Active curve
+
+ def setCurveColor(self, curve, color):
+ # Store Line2D and PathCollection
+ for artist in curve.get_children():
+ if isinstance(artist, (Line2D, LineCollection)):
+ artist.set_color(color)
+ elif isinstance(artist, PathCollection):
+ artist.set_facecolors(color)
+ artist.set_edgecolors(color)
+ else:
+ _logger.warning("setActiveCurve ignoring artist %s", str(artist))
+
+ # Misc.
+
+ def getWidgetHandle(self):
+ return self.fig.canvas
+
+ def _enableAxis(self, axis, flag=True):
+ """Show/hide Y axis
+
+ :param str axis: Axis name: 'left' or 'right'
+ :param bool flag: Default, True
+ """
+ assert axis in ("right", "left")
+ axes = self.ax2 if axis == "right" else self.ax
+ axes.get_yaxis().set_visible(flag)
+
+ def replot(self):
+ """Do not perform rendering.
+
+ Override in subclass to actually draw something.
+ """
+ with self._plot._paintContext():
+ self._replot()
+
+ def _replot(self):
+ """Call from subclass :meth:`replot` to handle updates"""
+ # TODO images, markers? scatter plot? move in remove?
+ # Right Y axis only support curve for now
+ # Hide right Y axis if no line is present
+ self._dirtyLimits = False
+ if not self.ax2.lines:
+ self._enableAxis("right", False)
+
+ def _drawOverlays(self):
+ """Draw overlays if any."""
+
+ def condition(item):
+ return (
+ item.isVisible()
+ and item._backendRenderer is not None
+ and item.isOverlay()
+ )
+
+ for item in self.getItemsFromBackToFront(condition=condition):
+ if isinstance(item, items.YAxisMixIn) and item.getYAxis() == "right":
+ axes = self.ax2
+ else:
+ axes = self.ax
+ axes.draw_artist(item._backendRenderer)
+
+ for item in self._graphCursor:
+ self.ax.draw_artist(item)
+
+ def updateZOrder(self):
+ """Reorder all items with z order from 0 to 1"""
+ items = self.getItemsFromBackToFront(
+ lambda item: item.isVisible() and item._backendRenderer is not None
+ )
+ count = len(items)
+ for index, item in enumerate(items):
+ if item.getZValue() < 0.5:
+ # Make sure matplotlib z order is below the grid (with z=0.5)
+ zorder = 0.5 * index / count
+ else: # Make sure matplotlib z order is above the grid (> 0.5)
+ zorder = 1.0 + index / count
+ if zorder != item._backendRenderer.get_zorder():
+ item._backendRenderer.set_zorder(zorder)
+
+ def saveGraph(self, fileName, fileFormat, dpi):
+ self.updateZOrder()
+
+ # fileName can be also a StringIO or file instance
+ if dpi is not None:
+ self.fig.savefig(fileName, format=fileFormat, dpi=dpi)
+ else:
+ self.fig.savefig(fileName, format=fileFormat)
+ self._plot._setDirtyPlot()
+
+ # Graph labels
+
+ def setGraphTitle(self, title):
+ self.ax.set_title(title)
+
+ def setGraphXLabel(self, label):
+ self.ax.set_xlabel(label)
+
+ def setGraphYLabel(self, label, axis):
+ axes = self.ax if axis == "left" else self.ax2
+ axes.set_ylabel(label)
+
+ # Graph limits
+
+ def setLimits(self, xmin, xmax, ymin, ymax, y2min=None, y2max=None):
+ # Let matplotlib taking care of keep aspect ratio if any
+ self._dirtyLimits = True
+ self.ax.set_xlim(min(xmin, xmax), max(xmin, xmax))
+
+ if y2min is not None and y2max is not None:
+ if not self.isYAxisInverted():
+ self.ax2.set_ylim(min(y2min, y2max), max(y2min, y2max))
+ else:
+ self.ax2.set_ylim(max(y2min, y2max), min(y2min, y2max))
+
+ if not self.isYAxisInverted():
+ self.ax.set_ylim(min(ymin, ymax), max(ymin, ymax))
+ else:
+ self.ax.set_ylim(max(ymin, ymax), min(ymin, ymax))
+
+ self._updateMarkers()
+
+ def getGraphXLimits(self):
+ if self._dirtyLimits and self.isKeepDataAspectRatio():
+ self.ax.apply_aspect()
+ self.ax2.apply_aspect()
+ self._dirtyLimits = False
+ return self.ax.get_xbound()
+
+ def setGraphXLimits(self, xmin, xmax):
+ self._dirtyLimits = True
+ self.ax.set_xlim(min(xmin, xmax), max(xmin, xmax))
+ self._updateMarkers()
+
+ def getGraphYLimits(self, axis):
+ assert axis in ("left", "right")
+ ax = self.ax2 if axis == "right" else self.ax
+
+ if not ax.get_visible():
+ return None
+
+ if self._dirtyLimits and self.isKeepDataAspectRatio():
+ self.ax.apply_aspect()
+ self.ax2.apply_aspect()
+ self._dirtyLimits = False
+
+ return ax.get_ybound()
+
+ def setGraphYLimits(self, ymin, ymax, axis):
+ ax = self.ax2 if axis == "right" else self.ax
+ if ymax < ymin:
+ ymin, ymax = ymax, ymin
+ self._dirtyLimits = True
+
+ if self.isKeepDataAspectRatio():
+ # matplotlib keeps limits of shared axis when keeping aspect ratio
+ # So x limits are kept when changing y limits....
+ # Change x limits first by taking into account aspect ratio
+ # and then change y limits.. so matplotlib does not need
+ # to make change (to y) to keep aspect ratio
+ xmin, xmax = ax.get_xbound()
+ curYMin, curYMax = ax.get_ybound()
+
+ newXRange = (xmax - xmin) * (ymax - ymin) / (curYMax - curYMin)
+ xcenter = 0.5 * (xmin + xmax)
+ ax.set_xlim(xcenter - 0.5 * newXRange, xcenter + 0.5 * newXRange)
+
+ if not self.isYAxisInverted():
+ ax.set_ylim(ymin, ymax)
+ else:
+ ax.set_ylim(ymax, ymin)
+
+ self._updateMarkers()
+
+ # Graph axes
+
+ def __initXAxisFormatterAndLocator(self):
+ if self.ax.xaxis.get_scale() != "linear":
+ return # Do not override formatter and locator
+
+ if not self.isXAxisTimeSeries():
+ self.ax.xaxis.set_major_formatter(DefaultTickFormatter())
+ return
+
+ # We can't use a matplotlib.dates.DateFormatter because it expects
+ # the data to be in datetimes. Silx works internally with
+ # timestamps (floats).
+ locator = NiceDateLocator(tz=self.getXAxisTimeZone())
+ self.ax.xaxis.set_major_locator(locator)
+ self.ax.xaxis.set_major_formatter(
+ NiceAutoDateFormatter(locator, tz=self.getXAxisTimeZone())
+ )
+
+ def setXAxisTimeZone(self, tz):
+ super(BackendMatplotlib, self).setXAxisTimeZone(tz)
+
+ # Make new formatter and locator with the time zone.
+ self.setXAxisTimeSeries(self.isXAxisTimeSeries())
+
+ def isXAxisTimeSeries(self):
+ return self._isXAxisTimeSeries
+
+ def setXAxisTimeSeries(self, isTimeSeries):
+ self._isXAxisTimeSeries = isTimeSeries
+ self.__initXAxisFormatterAndLocator()
+
+ def setXAxisLogarithmic(self, flag):
+ # Workaround for matplotlib 2.1.0 when one tries to set an axis
+ # to log scale with both limits <= 0
+ # In this case a draw with positive limits is needed first
+ if flag and self._matplotlibVersion >= Version("2.1.0"):
+ xlim = self.ax.get_xlim()
+ if xlim[0] <= 0 and xlim[1] <= 0:
+ self.ax.set_xlim(1, 10)
+ self.draw()
+
+ xscale = "log" if flag else "linear"
+ self.ax2.set_xscale(xscale)
+ self.ax.set_xscale(xscale)
+ self.__initXAxisFormatterAndLocator()
+
+ def setYAxisLogarithmic(self, flag):
+ # Workaround for matplotlib 2.0 issue with negative bounds
+ # before switching to log scale
+ if flag and self._matplotlibVersion >= Version("2.0.0"):
+ redraw = False
+ for axis, dataRangeIndex in ((self.ax, 1), (self.ax2, 2)):
+ ylim = axis.get_ylim()
+ if ylim[0] <= 0 or ylim[1] <= 0:
+ dataRange = self._plot.getDataRange()[dataRangeIndex]
+ if dataRange is None:
+ dataRange = 1, 100 # Fallback
+ axis.set_ylim(*dataRange)
+ redraw = True
+ if redraw:
+ self.draw()
+
+ if flag:
+ self.ax2.set_yscale("log")
+ self.ax.set_yscale("log")
+ return
+
+ self.ax2.set_yscale("linear")
+ self.ax2.yaxis.set_major_formatter(DefaultTickFormatter())
+ self.ax.set_yscale("linear")
+ self.ax.yaxis.set_major_formatter(DefaultTickFormatter())
+
+ def setYAxisInverted(self, flag):
+ if self.ax.yaxis_inverted() != bool(flag):
+ self.ax.invert_yaxis()
+ self._updateMarkers()
+
+ def isYAxisInverted(self):
+ return self.ax.yaxis_inverted()
+
+ def isYRightAxisVisible(self):
+ return self.ax2.yaxis.get_visible()
+
+ def isKeepDataAspectRatio(self):
+ return self.ax.get_aspect() in (1.0, "equal")
+
+ def setKeepDataAspectRatio(self, flag):
+ self.ax.set_aspect(1.0 if flag else "auto")
+ self.ax2.set_aspect(1.0 if flag else "auto")
+
+ def setGraphGrid(self, which):
+ self.ax.grid(False, which="both") # Disable all grid first
+ if which is not None:
+ self.ax.grid(True, which=which)
+
+ # Data <-> Pixel coordinates conversion
+
+ def _getDevicePixelRatio(self) -> float:
+ """Compatibility wrapper for devicePixelRatioF"""
+ return 1.0
+
+ def _mplToQtPosition(
+ self, x: Union[float, numpy.ndarray], y: Union[float, numpy.ndarray]
+ ) -> Tuple[Union[float, numpy.ndarray], Union[float, numpy.ndarray]]:
+ """Convert matplotlib "display" space coord to Qt widget logical pixel"""
+ ratio = self._getDevicePixelRatio()
+ # Convert from matplotlib origin (bottom) to Qt origin (top)
+ # and apply device pixel ratio
+ return x / ratio, (self.fig.get_window_extent().height - y) / ratio
+
+ def _qtToMplPosition(self, x: float, y: float) -> Tuple[float, float]:
+ """Convert Qt widget logical pixel to matplotlib "display" space coord"""
+ ratio = self._getDevicePixelRatio()
+ # Apply device pixel ration and
+ # convert from Qt origin (top) to matplotlib origin (bottom)
+ return x * ratio, self.fig.get_window_extent().height - (y * ratio)
+
+ def dataToPixel(self, x, y, axis):
+ ax = self.ax2 if axis == "right" else self.ax
+ points = numpy.transpose((x, y))
+ displayPos = ax.transData.transform(points).transpose()
+ return self._mplToQtPosition(*displayPos)
+
+ def pixelToData(self, x, y, axis):
+ ax = self.ax2 if axis == "right" else self.ax
+ displayPos = self._qtToMplPosition(x, y)
+ return tuple(ax.transData.inverted().transform_point(displayPos))
+
+ def getPlotBoundsInPixels(self):
+ bbox = self.ax.get_window_extent()
+ # Warning this is not returning int...
+ ratio = self._getDevicePixelRatio()
+ return tuple(
+ int(value / ratio)
+ for value in (
+ bbox.xmin,
+ self.fig.get_window_extent().height - bbox.ymax,
+ bbox.width,
+ bbox.height,
+ )
+ )
+
+ def setAxesMargins(self, left: float, top: float, right: float, bottom: float):
+ width, height = 1.0 - left - right, 1.0 - top - bottom
+ position = left, bottom, width, height
+
+ istight = config._MPL_TIGHT_LAYOUT and (left, top, right, bottom) != (
+ 0,
+ 0,
+ 0,
+ 0,
+ )
+ if self._matplotlibVersion >= Version("3.6"):
+ self.fig.set_layout_engine("tight" if istight else None)
+ else:
+ self.fig.set_tight_layout(True if istight else None)
+
+ # Toggle display of axes and viewbox rect
+ isFrameOn = position != (0.0, 0.0, 1.0, 1.0)
+ self.ax.set_frame_on(isFrameOn)
+ self.ax2.set_frame_on(isFrameOn)
+
+ self.ax.set_position(position)
+ self.ax2.set_position(position)
+
+ self._synchronizeBackgroundColors()
+ self._synchronizeForegroundColors()
+ self._plot._setDirtyPlot()
+
+ def _synchronizeBackgroundColors(self):
+ backgroundColor = self._plot.getBackgroundColor().getRgbF()
+
+ dataBackgroundColor = self._plot.getDataBackgroundColor()
+ if dataBackgroundColor.isValid():
+ dataBackgroundColor = dataBackgroundColor.getRgbF()
+ else:
+ dataBackgroundColor = backgroundColor
+
+ if self.ax.get_frame_on():
+ self.fig.patch.set_facecolor(backgroundColor)
+ if self._matplotlibVersion < Version("2"):
+ self.ax.set_axis_bgcolor(dataBackgroundColor)
+ else:
+ self.ax.set_facecolor(dataBackgroundColor)
+ else:
+ self.fig.patch.set_facecolor(dataBackgroundColor)
+
+ def _synchronizeForegroundColors(self):
+ foregroundColor = self._plot.getForegroundColor().getRgbF()
+
+ gridColor = self._plot.getGridColor()
+ if gridColor.isValid():
+ gridColor = gridColor.getRgbF()
+ else:
+ gridColor = foregroundColor
+
+ for axes in (self.ax, self.ax2):
+ if axes.get_frame_on():
+ axes.spines["bottom"].set_color(foregroundColor)
+ axes.spines["top"].set_color(foregroundColor)
+ axes.spines["right"].set_color(foregroundColor)
+ axes.spines["left"].set_color(foregroundColor)
+ axes.tick_params(axis="x", colors=foregroundColor)
+ axes.tick_params(axis="y", colors=foregroundColor)
+ axes.yaxis.label.set_color(foregroundColor)
+ axes.xaxis.label.set_color(foregroundColor)
+ axes.title.set_color(foregroundColor)
+
+ for line in axes.get_xgridlines():
+ line.set_color(gridColor)
+
+ for line in axes.get_ygridlines():
+ line.set_color(gridColor)
+ # axes.grid().set_markeredgecolor(gridColor)
+
+ def setBackgroundColors(self, backgroundColor, dataBackgroundColor):
+ self._synchronizeBackgroundColors()
+
+ def setForegroundColors(self, foregroundColor, gridColor):
+ self._synchronizeForegroundColors()
+
+
+class BackendMatplotlibQt(BackendMatplotlib, FigureCanvasQTAgg):
+ """QWidget matplotlib backend using a QtAgg canvas.
+
+ It adds fast overlay drawing and mouse event management.
+ """
+
+ _sigPostRedisplay = qt.Signal()
+ """Signal handling automatic asynchronous replot"""
+
+ def __init__(self, plot, parent=None):
+ BackendMatplotlib.__init__(self, plot, parent)
+ FigureCanvasQTAgg.__init__(self, self.fig)
+ self.setParent(parent)
+
+ self._limitsBeforeResize = None
+
+ FigureCanvasQTAgg.setSizePolicy(
+ self, qt.QSizePolicy.Expanding, qt.QSizePolicy.Expanding
+ )
+ FigureCanvasQTAgg.updateGeometry(self)
+
+ # Make postRedisplay asynchronous using Qt signal
+ self._sigPostRedisplay.connect(self.__deferredReplot, qt.Qt.QueuedConnection)
+
+ self._picked = None
+
+ self.mpl_connect("button_press_event", self._onMousePress)
+ self.mpl_connect("button_release_event", self._onMouseRelease)
+ self.mpl_connect("motion_notify_event", self._onMouseMove)
+ self.mpl_connect("scroll_event", self._onMouseWheel)
+
+ def postRedisplay(self):
+ self._sigPostRedisplay.emit()
+
+ def __deferredReplot(self):
+ # Since this is deferred, makes sure it is still needed
+ plot = self._plotRef()
+ if plot is not None and plot._getDirtyPlot() and plot.getBackend() is self:
+ self.replot()
+
+ def _getDevicePixelRatio(self) -> float:
+ """Compatibility wrapper for devicePixelRatioF"""
+ if hasattr(self, "devicePixelRatioF"):
+ ratio = self.devicePixelRatioF()
+ else: # Qt < 5.6 compatibility
+ ratio = float(self.devicePixelRatio())
+ # Safety net: avoid returning 0
+ return ratio if ratio != 0.0 else 1.0
+
+ # Mouse event forwarding
+
+ _MPL_TO_PLOT_BUTTONS = {1: "left", 2: "middle", 3: "right"}
+
+ def _onMousePress(self, event):
+ button = self._MPL_TO_PLOT_BUTTONS.get(event.button, None)
+ if button is not None:
+ x, y = self._mplToQtPosition(event.x, event.y)
+ self._plot.onMousePress(int(x), int(y), button)
+
+ def _onMouseMove(self, event):
+ x, y = self._mplToQtPosition(event.x, event.y)
+ if self._graphCursor:
+ position = self._plot.pixelToData(x, y, axis="left", check=True)
+ lineh, linev = self._graphCursor
+ if position is not None:
+ linev.set_visible(True)
+ linev.set_xdata((position[0], position[0]))
+ lineh.set_visible(True)
+ lineh.set_ydata((position[1], position[1]))
+ self._plot._setDirtyPlot(overlayOnly=True)
+ elif lineh.get_visible():
+ lineh.set_visible(False)
+ linev.set_visible(False)
+ self._plot._setDirtyPlot(overlayOnly=True)
+ # onMouseMove must trigger replot if dirty flag is raised
+
+ self._plot.onMouseMove(int(x), int(y))
+
+ def _onMouseRelease(self, event):
+ button = self._MPL_TO_PLOT_BUTTONS.get(event.button, None)
+ if button is not None:
+ x, y = self._mplToQtPosition(event.x, event.y)
+ self._plot.onMouseRelease(int(x), int(y), button)
+
+ def _onMouseWheel(self, event):
+ x, y = self._mplToQtPosition(event.x, event.y)
+ self._plot.onMouseWheel(int(x), int(y), event.step)
+
+ def leaveEvent(self, event):
+ """QWidget event handler"""
+ try:
+ plot = self._plot
+ except RuntimeError:
+ pass
+ else:
+ plot.onMouseLeaveWidget()
+
+ # picking
+
+ def pickItem(self, x, y, item):
+ xDisplay, yDisplay = self._qtToMplPosition(x, y)
+ mouseEvent = MouseEvent(
+ "button_press_event", self, int(xDisplay), int(yDisplay)
+ )
+ # Override axes and data position with the axes
+ mouseEvent.inaxes = item.axes
+ mouseEvent.xdata, mouseEvent.ydata = self.pixelToData(
+ x, y, axis="left" if item.axes is self.ax else "right"
+ )
+ picked, info = item.contains(mouseEvent)
+
+ if not picked:
+ return None
+
+ elif isinstance(item, TriMesh):
+ # Convert selected triangle to data point indices
+ triangulation = item._triangulation
+ indices = triangulation.get_masked_triangles()[info["ind"][0]]
+
+ # Sort picked triangle points by distance to mouse
+ # from furthest to closest to put closest point last
+ # This is to be somewhat consistent with last scatter point
+ # being the top one.
+ xdata, ydata = self.pixelToData(x, y, axis="left")
+ dists = (triangulation.x[indices] - xdata) ** 2 + (
+ triangulation.y[indices] - ydata
+ ) ** 2
+ return indices[numpy.flip(numpy.argsort(dists), axis=0)]
+
+ else: # Returns indices if any
+ return info.get("ind", ())
+
+ # replot control
+
+ def resizeEvent(self, event):
+ # Store current limits
+ self._limitsBeforeResize = (
+ self.ax.get_xbound(),
+ self.ax.get_ybound(),
+ self.ax2.get_ybound(),
+ )
+
+ FigureCanvasQTAgg.resizeEvent(self, event)
+ if self.isKeepDataAspectRatio() or self._hasOverlays():
+ # This is needed with matplotlib 1.5.x and 2.0.x
+ self._plot._setDirtyPlot()
+
+ def draw(self):
+ """Overload draw
+
+ It performs a full redraw (including overlays) of the plot.
+ It also resets background and emit limits changed signal.
+
+ This is directly called by matplotlib for widget resize.
+ """
+ if self.size().isEmpty():
+ return # Skip rendering of 0-sized canvas
+
+ self.updateZOrder()
+
+ if not qt_inspect.isValid(self):
+ _logger.info("draw requested but widget no longer exists")
+ return
+
+ # Starting with mpl 2.1.0, toggling autoscale raises a ValueError
+ # in some situations. See #1081, #1136, #1163,
+ if self._matplotlibVersion >= Version("2.0.0"):
+ try:
+ FigureCanvasQTAgg.draw(self)
+ except ValueError as err:
+ _logger.debug(
+ "ValueError caught while calling FigureCanvasQTAgg.draw: " "'%s'",
+ err,
+ )
+ else:
+ FigureCanvasQTAgg.draw(self)
+
+ if self._hasOverlays():
+ # Save background
+ self._background = self.copy_from_bbox(self.fig.bbox)
+ else:
+ self._background = None # Reset background
+
+ # Check if limits changed due to a resize of the widget
+ if self._limitsBeforeResize is not None:
+ xLimits, yLimits, yRightLimits = self._limitsBeforeResize
+ self._limitsBeforeResize = None
+
+ if xLimits != self.ax.get_xbound() or yLimits != self.ax.get_ybound():
+ self._updateMarkers()
+
+ if xLimits != self.ax.get_xbound():
+ self._plot.getXAxis()._emitLimitsChanged()
+ if yLimits != self.ax.get_ybound():
+ self._plot.getYAxis(axis="left")._emitLimitsChanged()
+ if yRightLimits != self.ax2.get_ybound():
+ self._plot.getYAxis(axis="right")._emitLimitsChanged()
+
+ self._drawOverlays()
+
+ def replot(self):
+ if not qt_inspect.isValid(self):
+ _logger.info("replot requested but widget no longer exists")
+ return
+
+ with self._plot._paintContext():
+ BackendMatplotlib._replot(self)
+
+ dirtyFlag = self._plot._getDirtyPlot()
+
+ if dirtyFlag == "overlay":
+ # Only redraw overlays using fast rendering path
+ if self._background is None:
+ self._background = self.copy_from_bbox(self.fig.bbox)
+ self.restore_region(self._background)
+ self._drawOverlays()
+ self.blit(self.fig.bbox)
+
+ elif dirtyFlag: # Need full redraw
+ self.draw()
+
+ # Workaround issue of rendering overlays with some matplotlib versions
+ if Version("1.5") <= self._matplotlibVersion < Version(
+ "2.1"
+ ) and not hasattr(self, "_firstReplot"):
+ self._firstReplot = False
+ if self._hasOverlays():
+ qt.QTimer.singleShot(0, self.draw) # Request async draw
+
+ # cursor
+
+ _QT_CURSORS = {
+ BackendBase.CURSOR_DEFAULT: qt.Qt.ArrowCursor,
+ BackendBase.CURSOR_POINTING: qt.Qt.PointingHandCursor,
+ BackendBase.CURSOR_SIZE_HOR: qt.Qt.SizeHorCursor,
+ BackendBase.CURSOR_SIZE_VER: qt.Qt.SizeVerCursor,
+ BackendBase.CURSOR_SIZE_ALL: qt.Qt.SizeAllCursor,
+ }
+
+ def setGraphCursorShape(self, cursor):
+ if cursor is None:
+ FigureCanvasQTAgg.unsetCursor(self)
+ else:
+ cursor = self._QT_CURSORS[cursor]
+ FigureCanvasQTAgg.setCursor(self, qt.QCursor(cursor))
diff --git a/src/silx/gui/plot/backends/BackendOpenGL.py b/src/silx/gui/plot/backends/BackendOpenGL.py
new file mode 100755
index 0000000..370f14b
--- /dev/null
+++ b/src/silx/gui/plot/backends/BackendOpenGL.py
@@ -0,0 +1,1660 @@
+# /*##########################################################################
+#
+# Copyright (c) 2014-2023 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ############################################################################*/
+"""OpenGL Plot backend."""
+
+from __future__ import annotations
+
+__authors__ = ["T. Vincent"]
+__license__ = "MIT"
+__date__ = "21/12/2018"
+
+import logging
+import weakref
+
+import numpy
+
+from .. import items
+from .._utils import FLOAT32_MINPOS
+from . import BackendBase
+from ... import colors
+from ... import qt
+
+from ..._glutils import gl
+from ... import _glutils as glu
+from . import glutils
+from .glutils.PlotImageFile import saveImageToFile
+from silx.gui.colors import RGBAColorType
+
+_logger = logging.getLogger(__name__)
+
+
+# TODO idea: BackendQtMixIn class to share code between mpl and gl
+# TODO check if OpenGL is available
+# TODO make an off-screen mesa backend
+
+# Content #####################################################################
+
+
+class _ShapeItem(dict):
+ def __init__(
+ self,
+ x,
+ y,
+ shape,
+ color,
+ fill,
+ overlay,
+ linewidth,
+ dashoffset,
+ dashpattern,
+ gapcolor,
+ ):
+ super(_ShapeItem, self).__init__()
+
+ if shape not in ("polygon", "rectangle", "line", "vline", "hline", "polylines"):
+ raise NotImplementedError("Unsupported shape {0}".format(shape))
+
+ x = numpy.array(x, copy=False)
+ y = numpy.array(y, copy=False)
+
+ if shape == "rectangle":
+ xMin, xMax = x
+ x = numpy.array((xMin, xMin, xMax, xMax))
+ yMin, yMax = y
+ y = numpy.array((yMin, yMax, yMax, yMin))
+
+ # Ignore fill for polylines to mimic matplotlib
+ fill = fill if shape != "polylines" else False
+
+ self.update(
+ {
+ "shape": shape,
+ "color": colors.rgba(color),
+ "fill": "hatch" if fill else None,
+ "x": x,
+ "y": y,
+ "linewidth": linewidth,
+ "dashoffset": dashoffset,
+ "dashpattern": dashpattern,
+ "gapcolor": gapcolor,
+ }
+ )
+
+
+class _MarkerItem(dict):
+ def __init__(
+ self,
+ x,
+ y,
+ text,
+ color,
+ symbol,
+ linewidth,
+ dashoffset,
+ dashpattern,
+ constraint,
+ yaxis,
+ font,
+ bgcolor,
+ ):
+ super(_MarkerItem, self).__init__()
+
+ if symbol is None:
+ symbol = "+"
+
+ # Apply constraint to provided position
+ isConstraint = constraint is not None and x is not None and y is not None
+ if isConstraint:
+ x, y = constraint(x, y)
+
+ self.update(
+ {
+ "x": x,
+ "y": y,
+ "text": text,
+ "color": colors.rgba(color),
+ "constraint": constraint if isConstraint else None,
+ "symbol": symbol,
+ "linewidth": linewidth,
+ "dashoffset": dashoffset,
+ "dashpattern": dashpattern,
+ "yaxis": yaxis,
+ "font": font,
+ "bgcolor": bgcolor,
+ }
+ )
+
+
+# shaders #####################################################################
+
+_baseVertShd = """
+ attribute vec2 position;
+ uniform mat4 matrix;
+ uniform bvec2 isLog;
+
+ const float oneOverLog10 = 0.43429448190325176;
+
+ void main(void) {
+ vec2 posTransformed = position;
+ if (isLog.x) {
+ posTransformed.x = oneOverLog10 * log(position.x);
+ }
+ if (isLog.y) {
+ posTransformed.y = oneOverLog10 * log(position.y);
+ }
+ gl_Position = matrix * vec4(posTransformed, 0.0, 1.0);
+ }
+ """
+
+_baseFragShd = """
+ uniform vec4 color;
+ uniform int hatchStep;
+ uniform float tickLen;
+
+ void main(void) {
+ if (tickLen != 0.) {
+ if (mod((gl_FragCoord.x + gl_FragCoord.y) / tickLen, 2.) < 1.) {
+ gl_FragColor = color;
+ } else {
+ discard;
+ }
+ } else if (hatchStep == 0 ||
+ mod(gl_FragCoord.x - gl_FragCoord.y, float(hatchStep)) == 0.) {
+ gl_FragColor = color;
+ } else {
+ discard;
+ }
+ }
+ """
+
+_texVertShd = """
+ attribute vec2 position;
+ attribute vec2 texCoords;
+ uniform mat4 matrix;
+
+ varying vec2 coords;
+
+ void main(void) {
+ gl_Position = matrix * vec4(position, 0.0, 1.0);
+ coords = texCoords;
+ }
+ """
+
+_texFragShd = """
+ uniform sampler2D tex;
+
+ varying vec2 coords;
+
+ void main(void) {
+ gl_FragColor = texture2D(tex, coords);
+ gl_FragColor.a = 1.0;
+ }
+ """
+
+# BackendOpenGL ###############################################################
+
+
+class BackendOpenGL(BackendBase.BackendBase, glu.OpenGLWidget):
+ """OpenGL-based Plot backend.
+
+ WARNINGS:
+ Unless stated otherwise, this API is NOT thread-safe and MUST be
+ called from the main thread.
+ When numpy arrays are passed as arguments to the API (through
+ :func:`addCurve` and :func:`addImage`), they are copied only if
+ required.
+ So, the caller should not modify these arrays afterwards.
+ """
+
+ _TEXT_MARKER_PADDING = 4
+
+ def __init__(self, plot, parent=None, f=qt.Qt.Widget):
+ glu.OpenGLWidget.__init__(
+ self,
+ parent,
+ alphaBufferSize=8,
+ depthBufferSize=0,
+ stencilBufferSize=0,
+ version=(2, 1),
+ f=f,
+ )
+ BackendBase.BackendBase.__init__(self, plot, parent)
+
+ self._defaultFont: qt.QFont = None
+ self.__isOpenGLValid = False
+
+ self._backgroundColor = 1.0, 1.0, 1.0, 1.0
+ self._dataBackgroundColor = 1.0, 1.0, 1.0, 1.0
+
+ self.matScreenProj = glutils.mat4Identity()
+
+ self._progBase = glu.Program(_baseVertShd, _baseFragShd, attrib0="position")
+ self._progTex = glu.Program(_texVertShd, _texFragShd, attrib0="position")
+ self._plotFBOs = weakref.WeakKeyDictionary()
+
+ self._keepDataAspectRatio = False
+
+ self._crosshairCursor = None
+ self._mousePosInPixels = None
+
+ self._glGarbageCollector = []
+
+ self._plotFrame = glutils.GLPlotFrame2D(
+ foregroundColor=(0.0, 0.0, 0.0, 1.0),
+ gridColor=(0.7, 0.7, 0.7, 1.0),
+ marginRatios=(0.15, 0.1, 0.1, 0.15),
+ font=self.getDefaultFont(),
+ )
+ self._plotFrame.size = ( # Init size with size int
+ int(self.getDevicePixelRatio() * 640),
+ int(self.getDevicePixelRatio() * 480),
+ )
+
+ self.setAutoFillBackground(False)
+ self.setMouseTracking(True)
+
+ # QWidget
+
+ _MOUSE_BTNS = {
+ qt.Qt.LeftButton: "left",
+ qt.Qt.RightButton: "right",
+ qt.Qt.MiddleButton: "middle",
+ }
+
+ def sizeHint(self):
+ return qt.QSize(8 * 80, 6 * 80) # Mimic MatplotlibBackend
+
+ def mousePressEvent(self, event):
+ if event.button() not in self._MOUSE_BTNS:
+ return super(BackendOpenGL, self).mousePressEvent(event)
+ x, y = qt.getMouseEventPosition(event)
+ self._plot.onMousePress(x, y, self._MOUSE_BTNS[event.button()])
+ event.accept()
+
+ def mouseMoveEvent(self, event):
+ qtPos = qt.getMouseEventPosition(event)
+
+ previousMousePosInPixels = self._mousePosInPixels
+ if qtPos == self._mouseInPlotArea(*qtPos):
+ devicePixelRatio = self.getDevicePixelRatio()
+ devicePos = qtPos[0] * devicePixelRatio, qtPos[1] * devicePixelRatio
+ self._mousePosInPixels = devicePos # Mouse in plot area
+ else:
+ self._mousePosInPixels = None # Mouse outside plot area
+
+ if (
+ self._crosshairCursor is not None
+ and previousMousePosInPixels != self._mousePosInPixels
+ ):
+ # Avoid replot when cursor remains outside plot area
+ self._plot._setDirtyPlot(overlayOnly=True)
+
+ self._plot.onMouseMove(*qtPos)
+ event.accept()
+
+ def mouseReleaseEvent(self, event):
+ if event.button() not in self._MOUSE_BTNS:
+ return super(BackendOpenGL, self).mouseReleaseEvent(event)
+ x, y = qt.getMouseEventPosition(event)
+ self._plot.onMouseRelease(x, y, self._MOUSE_BTNS[event.button()])
+ event.accept()
+
+ def wheelEvent(self, event):
+ delta = event.angleDelta().y()
+ angleInDegrees = delta / 8.0
+ x, y = qt.getMouseEventPosition(event)
+ self._plot.onMouseWheel(x, y, angleInDegrees)
+ event.accept()
+
+ def leaveEvent(self, _):
+ self._plot.onMouseLeaveWidget()
+
+ # OpenGLWidget API
+
+ def initializeGL(self):
+ self.__isOpenGLValid = gl.testGL()
+ if not self.__isOpenGLValid:
+ return
+
+ gl.glClearStencil(0)
+
+ gl.glEnable(gl.GL_BLEND)
+ # gl.glBlendFunc(gl.GL_SRC_ALPHA, gl.GL_ONE_MINUS_SRC_ALPHA)
+ gl.glBlendFuncSeparate(
+ gl.GL_SRC_ALPHA, gl.GL_ONE_MINUS_SRC_ALPHA, gl.GL_ONE, gl.GL_ONE
+ )
+
+ # For lines
+ gl.glHint(gl.GL_LINE_SMOOTH_HINT, gl.GL_NICEST)
+
+ # For points
+ gl.glEnable(gl.GL_VERTEX_PROGRAM_POINT_SIZE) # OpenGL 2
+ gl.glEnable(gl.GL_POINT_SPRITE) # OpenGL 2
+ # gl.glEnable(gl.GL_PROGRAM_POINT_SIZE)
+
+ def _paintDirectGL(self):
+ self._renderPlotAreaGL()
+ self._plotFrame.render()
+ self._renderOverlayGL()
+
+ def _paintFBOGL(self):
+ context = glu.Context.getCurrent()
+ plotFBOTex = self._plotFBOs.get(context)
+ if self._plot._getDirtyPlot() or self._plotFrame.isDirty or plotFBOTex is None:
+ self._plotVertices = (
+ # Vertex coordinates
+ numpy.array(
+ ((-1.0, -1.0), (1.0, -1.0), (-1.0, 1.0), (1.0, 1.0)),
+ dtype=numpy.float32,
+ ),
+ # Texture coordinates
+ numpy.array(
+ ((0.0, 0.0), (1.0, 0.0), (0.0, 1.0), (1.0, 1.0)),
+ dtype=numpy.float32,
+ ),
+ )
+ if (
+ plotFBOTex is None
+ or plotFBOTex.shape[1] != self._plotFrame.size[0]
+ or plotFBOTex.shape[0] != self._plotFrame.size[1]
+ ):
+ if plotFBOTex is not None:
+ plotFBOTex.discard()
+ plotFBOTex = glu.FramebufferTexture(
+ gl.GL_RGBA,
+ shape=(self._plotFrame.size[1], self._plotFrame.size[0]),
+ minFilter=gl.GL_NEAREST,
+ magFilter=gl.GL_NEAREST,
+ wrap=(gl.GL_CLAMP_TO_EDGE, gl.GL_CLAMP_TO_EDGE),
+ )
+ self._plotFBOs[context] = plotFBOTex
+
+ with plotFBOTex:
+ gl.glClearColor(*self._backgroundColor)
+ gl.glClear(gl.GL_COLOR_BUFFER_BIT | gl.GL_STENCIL_BUFFER_BIT)
+ self._renderPlotAreaGL()
+ self._plotFrame.render()
+
+ # Render plot in screen coords
+ gl.glViewport(0, 0, self._plotFrame.size[0], self._plotFrame.size[1])
+
+ self._progTex.use()
+ texUnit = 0
+
+ gl.glUniform1i(self._progTex.uniforms["tex"], texUnit)
+ gl.glUniformMatrix4fv(
+ self._progTex.uniforms["matrix"],
+ 1,
+ gl.GL_TRUE,
+ glutils.mat4Identity().astype(numpy.float32),
+ )
+
+ gl.glEnableVertexAttribArray(self._progTex.attributes["position"])
+ gl.glVertexAttribPointer(
+ self._progTex.attributes["position"],
+ 2,
+ gl.GL_FLOAT,
+ gl.GL_FALSE,
+ 0,
+ self._plotVertices[0],
+ )
+
+ gl.glEnableVertexAttribArray(self._progTex.attributes["texCoords"])
+ gl.glVertexAttribPointer(
+ self._progTex.attributes["texCoords"],
+ 2,
+ gl.GL_FLOAT,
+ gl.GL_FALSE,
+ 0,
+ self._plotVertices[1],
+ )
+
+ with plotFBOTex.texture:
+ gl.glDrawArrays(gl.GL_TRIANGLE_STRIP, 0, len(self._plotVertices[0]))
+
+ self._renderOverlayGL()
+
+ def paintGL(self):
+ if not self.__isOpenGLValid:
+ return
+
+ plot = self._plotRef()
+ if plot is None:
+ return
+
+ with plot._paintContext():
+ with glu.Context.current(self.context()):
+ # Release OpenGL resources
+ for item in self._glGarbageCollector:
+ item.discard()
+ self._glGarbageCollector = []
+
+ gl.glClearColor(*self._backgroundColor)
+ gl.glClear(gl.GL_COLOR_BUFFER_BIT | gl.GL_STENCIL_BUFFER_BIT)
+
+ # Check if window is large enough
+ if self._plotFrame.plotSize <= (2, 2):
+ return
+
+ # Sync plot frame with window
+ self._plotFrame.devicePixelRatio = self.getDevicePixelRatio()
+ self._plotFrame.dotsPerInch = self.getDotsPerInch()
+ # self._paintDirectGL()
+ self._paintFBOGL()
+
+ def _renderItems(self, overlay=False):
+ """Render items according to :class:`PlotWidget` order
+
+ Note: Scissor test should already be set.
+
+ :param bool overlay:
+ False (the default) to render item that are not overlays.
+ True to render items that are overlays.
+ """
+ # Values that are often used
+ plotWidth, plotHeight = self._plotFrame.plotSize
+ isXLog = self._plotFrame.xAxis.isLog
+ isYLog = self._plotFrame.yAxis.isLog
+ isYInverted = self._plotFrame.isYAxisInverted
+
+ # Used by marker rendering
+ labels = []
+ pixelOffset = 3
+
+ context = glutils.RenderContext(
+ isXLog=isXLog,
+ isYLog=isYLog,
+ dpi=self.getDotsPerInch(),
+ plotFrame=self._plotFrame,
+ )
+
+ for plotItem in self.getItemsFromBackToFront(
+ condition=lambda i: i.isVisible() and i.isOverlay() == overlay
+ ):
+ if plotItem._backendRenderer is None:
+ continue
+
+ item = plotItem._backendRenderer
+
+ if isinstance(item, glutils.GLPlotItem): # Render data items
+ gl.glViewport(
+ self._plotFrame.margins.left,
+ self._plotFrame.margins.bottom,
+ plotWidth,
+ plotHeight,
+ )
+ # Set matrix
+ if item.yaxis == "right":
+ context.matrix = self._plotFrame.transformedDataY2ProjMat
+ else:
+ context.matrix = self._plotFrame.transformedDataProjMat
+ item.render(context)
+
+ elif isinstance(item, _ShapeItem): # Render shape items
+ gl.glViewport(0, 0, self._plotFrame.size[0], self._plotFrame.size[1])
+
+ if (isXLog and numpy.min(item["x"]) < FLOAT32_MINPOS) or (
+ isYLog and numpy.min(item["y"]) < FLOAT32_MINPOS
+ ):
+ # Ignore items <= 0. on log axes
+ continue
+
+ if item["shape"] == "hline":
+ width = self._plotFrame.size[0]
+ _, yPixel = self._plotFrame.dataToPixel(
+ 0.5 * sum(self._plotFrame.dataRanges[0]), item["y"], axis="left"
+ )
+ subShapes = [
+ numpy.array(
+ ((0.0, yPixel), (width, yPixel)), dtype=numpy.float32
+ )
+ ]
+
+ elif item["shape"] == "vline":
+ xPixel, _ = self._plotFrame.dataToPixel(
+ item["x"], 0.5 * sum(self._plotFrame.dataRanges[1]), axis="left"
+ )
+ height = self._plotFrame.size[1]
+ subShapes = [
+ numpy.array(
+ ((xPixel, 0), (xPixel, height)), dtype=numpy.float32
+ )
+ ]
+
+ else:
+ # Split sub-shapes at not finite values
+ splits = numpy.nonzero(
+ numpy.logical_not(
+ numpy.logical_and(
+ numpy.isfinite(item["x"]), numpy.isfinite(item["y"])
+ )
+ )
+ )[0]
+ splits = numpy.concatenate(([-1], splits, [len(item["x"])]))
+ subShapes = []
+ for begin, end in zip(splits[:-1] + 1, splits[1:]):
+ if end > begin:
+ subShapes.append(
+ numpy.array(
+ [
+ self._plotFrame.dataToPixel(x, y, axis="left")
+ for (x, y) in zip(
+ item["x"][begin:end], item["y"][begin:end]
+ )
+ ]
+ )
+ )
+
+ for points in subShapes: # Draw each sub-shape
+ # Draw the fill
+ if item["fill"] is not None and item["shape"] not in (
+ "hline",
+ "vline",
+ ):
+ self._progBase.use()
+ gl.glUniformMatrix4fv(
+ self._progBase.uniforms["matrix"],
+ 1,
+ gl.GL_TRUE,
+ self.matScreenProj.astype(numpy.float32),
+ )
+ gl.glUniform2i(self._progBase.uniforms["isLog"], False, False)
+ gl.glUniform1f(self._progBase.uniforms["tickLen"], 0.0)
+
+ shape2D = glutils.FilledShape2D(
+ points, style=item["fill"], color=item["color"]
+ )
+ shape2D.render(
+ posAttrib=self._progBase.attributes["position"],
+ colorUnif=self._progBase.uniforms["color"],
+ hatchStepUnif=self._progBase.uniforms["hatchStep"],
+ )
+
+ # Draw the stroke
+ if item["dashpattern"] is not None:
+ if item["shape"] != "polylines":
+ # close the polyline
+ points = numpy.append(
+ points, numpy.atleast_2d(points[0]), axis=0
+ )
+
+ lines = glutils.GLLines2D(
+ points[:, 0],
+ points[:, 1],
+ color=item["color"],
+ gapColor=item["gapcolor"],
+ width=item["linewidth"],
+ dashOffset=item["dashoffset"],
+ dashPattern=item["dashpattern"],
+ )
+ context.matrix = self.matScreenProj
+ lines.render(context)
+
+ elif isinstance(item, _MarkerItem):
+ gl.glViewport(0, 0, self._plotFrame.size[0], self._plotFrame.size[1])
+
+ xCoord, yCoord, yAxis = item["x"], item["y"], item["yaxis"]
+
+ if (isXLog and xCoord is not None and xCoord <= 0) or (
+ isYLog and yCoord is not None and yCoord <= 0
+ ):
+ # Do not render markers with negative coords on log axis
+ continue
+
+ color = item["color"]
+ bgColor = item["bgcolor"]
+ if xCoord is None or yCoord is None:
+ if xCoord is None: # Horizontal line in data space
+ pixelPos = self._plotFrame.dataToPixel(
+ 0.5 * sum(self._plotFrame.dataRanges[0]), yCoord, axis=yAxis
+ )
+
+ if item["text"] is not None:
+ x = (
+ self._plotFrame.size[0]
+ - self._plotFrame.margins.right
+ - pixelOffset
+ )
+ y = pixelPos[1] - pixelOffset
+ label = glutils.Text2D(
+ item["text"],
+ item["font"],
+ x,
+ y,
+ color=color,
+ bgColor=bgColor,
+ align=glutils.RIGHT,
+ valign=glutils.BOTTOM,
+ devicePixelRatio=self.getDevicePixelRatio(),
+ padding=self._TEXT_MARKER_PADDING,
+ )
+ labels.append(label)
+
+ width = self._plotFrame.size[0]
+ lines = glutils.GLLines2D(
+ (0, width),
+ (pixelPos[1], pixelPos[1]),
+ color=color,
+ width=item["linewidth"],
+ dashOffset=item["dashoffset"],
+ dashPattern=item["dashpattern"],
+ )
+ context.matrix = self.matScreenProj
+ lines.render(context)
+
+ else: # yCoord is None: vertical line in data space
+ yRange = self._plotFrame.dataRanges[1 if yAxis == "left" else 2]
+ pixelPos = self._plotFrame.dataToPixel(
+ xCoord, 0.5 * sum(yRange), axis=yAxis
+ )
+
+ if item["text"] is not None:
+ x = pixelPos[0] + pixelOffset
+ y = self._plotFrame.margins.top + pixelOffset
+ label = glutils.Text2D(
+ item["text"],
+ item["font"],
+ x,
+ y,
+ color=color,
+ bgColor=bgColor,
+ align=glutils.LEFT,
+ valign=glutils.TOP,
+ devicePixelRatio=self.getDevicePixelRatio(),
+ padding=self._TEXT_MARKER_PADDING,
+ )
+ labels.append(label)
+
+ height = self._plotFrame.size[1]
+ lines = glutils.GLLines2D(
+ (pixelPos[0], pixelPos[0]),
+ (0, height),
+ color=color,
+ width=item["linewidth"],
+ dashOffset=item["dashoffset"],
+ dashPattern=item["dashpattern"],
+ )
+ context.matrix = self.matScreenProj
+ lines.render(context)
+
+ else:
+ xmin, xmax = self._plot.getXAxis().getLimits()
+ ymin, ymax = self._plot.getYAxis(axis=yAxis).getLimits()
+ if not xmin < xCoord < xmax or not ymin < yCoord < ymax:
+ # Do not render markers outside visible plot area
+ continue
+ pixelPos = self._plotFrame.dataToPixel(xCoord, yCoord, axis=yAxis)
+
+ if isYInverted:
+ valign = glutils.BOTTOM
+ vPixelOffset = -pixelOffset
+ else:
+ valign = glutils.TOP
+ vPixelOffset = pixelOffset
+
+ if item["text"] is not None:
+ x = pixelPos[0] + pixelOffset
+ y = pixelPos[1] + vPixelOffset
+ label = glutils.Text2D(
+ item["text"],
+ item["font"],
+ x,
+ y,
+ color=color,
+ bgColor=bgColor,
+ align=glutils.LEFT,
+ valign=valign,
+ devicePixelRatio=self.getDevicePixelRatio(),
+ padding=self._TEXT_MARKER_PADDING,
+ )
+ labels.append(label)
+
+ # For now simple implementation: using a curve for each marker
+ # Should pack all markers to a single set of points
+ marker = glutils.Points2D(
+ (pixelPos[0],),
+ (pixelPos[1],),
+ marker=item["symbol"],
+ color=color,
+ size=11,
+ )
+ context.matrix = self.matScreenProj
+ marker.render(context)
+
+ else:
+ _logger.error("Unsupported item: %s", str(item))
+ continue
+
+ # Render marker labels
+ gl.glViewport(0, 0, self._plotFrame.size[0], self._plotFrame.size[1])
+ for label in labels:
+ label.render(self.matScreenProj, self._plotFrame.dotsPerInch)
+
+ def _renderOverlayGL(self):
+ """Render overlay layer: overlay items and crosshair."""
+ plotWidth, plotHeight = self._plotFrame.plotSize
+
+ # Scissor to plot area
+ gl.glScissor(
+ self._plotFrame.margins.left,
+ self._plotFrame.margins.bottom,
+ plotWidth,
+ plotHeight,
+ )
+ gl.glEnable(gl.GL_SCISSOR_TEST)
+
+ self._renderItems(overlay=True)
+
+ # Render crosshair cursor
+ if self._crosshairCursor is not None and self._mousePosInPixels is not None:
+ self._progBase.use()
+ gl.glUniform2i(self._progBase.uniforms["isLog"], False, False)
+ gl.glUniform1f(self._progBase.uniforms["tickLen"], 0.0)
+ posAttrib = self._progBase.attributes["position"]
+ matrixUnif = self._progBase.uniforms["matrix"]
+ colorUnif = self._progBase.uniforms["color"]
+ hatchStepUnif = self._progBase.uniforms["hatchStep"]
+
+ gl.glViewport(0, 0, self._plotFrame.size[0], self._plotFrame.size[1])
+
+ gl.glUniformMatrix4fv(
+ matrixUnif, 1, gl.GL_TRUE, self.matScreenProj.astype(numpy.float32)
+ )
+
+ color, lineWidth = self._crosshairCursor
+ gl.glUniform4f(colorUnif, *color)
+ gl.glUniform1i(hatchStepUnif, 0)
+
+ xPixel, yPixel = self._mousePosInPixels
+ xPixel, yPixel = xPixel + 0.5, yPixel + 0.5
+ vertices = numpy.array(
+ (
+ (0.0, yPixel),
+ (self._plotFrame.size[0], yPixel),
+ (xPixel, 0.0),
+ (xPixel, self._plotFrame.size[1]),
+ ),
+ dtype=numpy.float32,
+ )
+
+ gl.glEnableVertexAttribArray(posAttrib)
+ gl.glVertexAttribPointer(
+ posAttrib, 2, gl.GL_FLOAT, gl.GL_FALSE, 0, vertices
+ )
+ gl.glLineWidth(lineWidth)
+ gl.glDrawArrays(gl.GL_LINES, 0, len(vertices))
+
+ gl.glDisable(gl.GL_SCISSOR_TEST)
+
+ def _renderPlotAreaGL(self):
+ """Render base layer of plot area.
+
+ It renders the background, grid and items except overlays
+ """
+ plotWidth, plotHeight = self._plotFrame.plotSize
+
+ gl.glScissor(
+ self._plotFrame.margins.left,
+ self._plotFrame.margins.bottom,
+ plotWidth,
+ plotHeight,
+ )
+ gl.glEnable(gl.GL_SCISSOR_TEST)
+
+ if self._dataBackgroundColor != self._backgroundColor:
+ gl.glClearColor(*self._dataBackgroundColor)
+ gl.glClear(gl.GL_COLOR_BUFFER_BIT)
+
+ self._plotFrame.renderGrid()
+
+ # Matrix
+ trBounds = self._plotFrame.transformedDataRanges
+ if trBounds.x[0] != trBounds.x[1] and trBounds.y[0] != trBounds.y[1]:
+ # Do rendering of items
+ self._renderItems(overlay=False)
+
+ gl.glDisable(gl.GL_SCISSOR_TEST)
+
+ def resizeGL(self, width, height):
+ if width == 0 or height == 0: # Do not resize
+ return
+
+ self._plotFrame.size = (
+ int(self.getDevicePixelRatio() * width),
+ int(self.getDevicePixelRatio() * height),
+ )
+
+ self.matScreenProj = glutils.mat4Ortho(
+ 0, self._plotFrame.size[0], self._plotFrame.size[1], 0, 1, -1
+ )
+
+ # Store current ranges
+ previousXRange = self.getGraphXLimits()
+ previousYRange = self.getGraphYLimits(axis="left")
+ previousYRightRange = self.getGraphYLimits(axis="right")
+
+ (xMin, xMax), (yMin, yMax), (y2Min, y2Max) = self._plotFrame.dataRanges
+ self.setLimits(xMin, xMax, yMin, yMax, y2Min, y2Max)
+
+ # If plot range has changed, then emit signal
+ if previousXRange != self.getGraphXLimits():
+ self._plot.getXAxis()._emitLimitsChanged()
+ if previousYRange != self.getGraphYLimits(axis="left"):
+ self._plot.getYAxis(axis="left")._emitLimitsChanged()
+ if previousYRightRange != self.getGraphYLimits(axis="right"):
+ self._plot.getYAxis(axis="right")._emitLimitsChanged()
+
+ # Add methods
+
+ @staticmethod
+ def _castArrayTo(v):
+ """Returns best floating type to cast the array to.
+
+ :param numpy.ndarray v: Array to cast
+ :rtype: numpy.dtype
+ :raise ValueError: If dtype is not supported
+ """
+ if numpy.issubdtype(v.dtype, numpy.floating):
+ return numpy.float32 if v.itemsize <= 4 else numpy.float64
+ elif numpy.issubdtype(v.dtype, numpy.integer):
+ return numpy.float32 if v.itemsize <= 2 else numpy.float64
+ else:
+ raise ValueError("Unsupported data type")
+
+ _DASH_PATTERNS = {
+ "": (0.0, None),
+ " ": (0.0, None),
+ "-": (0.0, ()),
+ "--": (0.0, (3.7, 1.6, 3.7, 1.6)),
+ "-.": (0.0, (6.4, 1.6, 1, 1.6)),
+ ":": (0.0, (1, 1.65, 1, 1.65)),
+ None: (0.0, None),
+ }
+ """Convert from linestyle to (offset, (dash pattern))
+
+ Note: dash pattern internal convention differs from matplotlib:
+ - None: no line at all
+ - (): "solid" line
+ """
+
+ def _lineStyleToDashOffsetPattern(
+ self, style
+ ) -> tuple[float, tuple[float, float, float, float] | tuple[()] | None]:
+ """Convert a linestyle to its corresponding offset and dash pattern"""
+ if style is None or isinstance(style, str):
+ return self._DASH_PATTERNS[style]
+
+ # (offset, (dash pattern)) case
+ offset, pattern = style
+ if pattern is None:
+ # Convert from matplotlib to internal representation of solid
+ pattern = ()
+ if len(pattern) == 2:
+ pattern = pattern * 2
+ return float(offset), tuple(float(v) for v in pattern)
+
+ def addCurve(
+ self,
+ x,
+ y,
+ color,
+ gapcolor,
+ symbol,
+ linewidth,
+ linestyle,
+ yaxis,
+ xerror,
+ yerror,
+ fill,
+ alpha,
+ symbolsize,
+ baseline,
+ ):
+ for parameter in (
+ x,
+ y,
+ color,
+ symbol,
+ linewidth,
+ linestyle,
+ yaxis,
+ fill,
+ symbolsize,
+ ):
+ assert parameter is not None
+ assert yaxis in ("left", "right")
+
+ # Convert input data
+ x = numpy.array(x, copy=False)
+ y = numpy.array(y, copy=False)
+
+ # Check if float32 is enough
+ if (
+ self._castArrayTo(x) is numpy.float32
+ and self._castArrayTo(y) is numpy.float32
+ ):
+ dtype = numpy.float32
+ else:
+ dtype = numpy.float64
+
+ x = numpy.array(x, dtype=dtype, copy=False, order="C")
+ y = numpy.array(y, dtype=dtype, copy=False, order="C")
+
+ # Convert errors to float32
+ if xerror is not None:
+ xerror = numpy.array(xerror, dtype=numpy.float32, copy=False, order="C")
+ if yerror is not None:
+ yerror = numpy.array(yerror, dtype=numpy.float32, copy=False, order="C")
+
+ # Handle axes log scale: convert data
+
+ if self._plotFrame.xAxis.isLog:
+ logX = numpy.log10(x)
+
+ if xerror is not None:
+ # Transform xerror so that
+ # log10(x) +/- xerror' = log10(x +/- xerror)
+ if hasattr(xerror, "shape") and len(xerror.shape) == 2:
+ xErrorMinus, xErrorPlus = xerror[0], xerror[1]
+ else:
+ xErrorMinus, xErrorPlus = xerror, xerror
+ with numpy.errstate(divide="ignore", invalid="ignore"):
+ # Ignore divide by zero, invalid value encountered in log10
+ xErrorMinus = logX - numpy.log10(x - xErrorMinus)
+ xErrorPlus = numpy.log10(x + xErrorPlus) - logX
+ xerror = numpy.array((xErrorMinus, xErrorPlus), dtype=numpy.float32)
+
+ x = logX
+
+ isYLog = (yaxis == "left" and self._plotFrame.yAxis.isLog) or (
+ yaxis == "right" and self._plotFrame.y2Axis.isLog
+ )
+
+ if isYLog:
+ logY = numpy.log10(y)
+
+ if yerror is not None:
+ # Transform yerror so that
+ # log10(y) +/- yerror' = log10(y +/- yerror)
+ if hasattr(yerror, "shape") and len(yerror.shape) == 2:
+ yErrorMinus, yErrorPlus = yerror[0], yerror[1]
+ else:
+ yErrorMinus, yErrorPlus = yerror, yerror
+ with numpy.errstate(divide="ignore", invalid="ignore"):
+ # Ignore divide by zero, invalid value encountered in log10
+ yErrorMinus = logY - numpy.log10(y - yErrorMinus)
+ yErrorPlus = numpy.log10(y + yErrorPlus) - logY
+ yerror = numpy.array((yErrorMinus, yErrorPlus), dtype=numpy.float32)
+
+ y = logY
+
+ # TODO check if need more filtering of error (e.g., clip to positive)
+
+ # TODO check and improve this
+ if len(color) == 4 and type(color[3]) in [type(1), numpy.uint8, numpy.int8]:
+ color = numpy.array(color, dtype=numpy.float32) / 255.0
+
+ if isinstance(color, numpy.ndarray) and color.ndim == 2:
+ colorArray = color
+ color = None
+ else:
+ colorArray = None
+ color = colors.rgba(color)
+
+ if alpha < 1.0: # Apply image transparency
+ if colorArray is not None and colorArray.shape[1] == 4:
+ # multiply alpha channel
+ colorArray[:, 3] = colorArray[:, 3] * alpha
+ if color is not None:
+ color = color[0], color[1], color[2], color[3] * alpha
+
+ fillColor = None
+ if fill is True:
+ fillColor = color
+
+ dashoffset, dashpattern = self._lineStyleToDashOffsetPattern(linestyle)
+ curve = glutils.GLPlotCurve2D(
+ x,
+ y,
+ colorArray,
+ xError=xerror,
+ yError=yerror,
+ lineColor=color,
+ lineGapColor=gapcolor,
+ lineWidth=linewidth,
+ lineDashOffset=dashoffset,
+ lineDashPattern=dashpattern,
+ marker=symbol,
+ markerColor=color,
+ markerSize=symbolsize,
+ fillColor=fillColor,
+ baseline=baseline,
+ isYLog=isYLog,
+ )
+ curve.yaxis = "left" if yaxis is None else yaxis
+
+ if yaxis == "right":
+ self._plotFrame.isY2Axis = True
+
+ return curve
+
+ def addImage(self, data, origin, scale, colormap, alpha):
+ for parameter in (data, origin, scale):
+ assert parameter is not None
+
+ if data.ndim == 2:
+ # Ensure array is contiguous and eventually convert its type
+ dtypes = [
+ dtype
+ for dtype in (numpy.float32, numpy.float16, numpy.uint8, numpy.uint16)
+ if glu.isSupportedGLType(dtype)
+ ]
+ if data.dtype in dtypes:
+ data = numpy.array(data, copy=False, order="C")
+ else:
+ _logger.info("addImage: Convert %s data to float32", str(data.dtype))
+ data = numpy.array(data, dtype=numpy.float32, order="C")
+
+ normalization = colormap.getNormalization()
+ if normalization in glutils.GLPlotColormap.SUPPORTED_NORMALIZATIONS:
+ # Fast path applying colormap on the GPU
+ cmapRange = colormap.getColormapRange(data=data)
+ colormapLut = colormap.getNColors(nbColors=256)
+ gamma = colormap.getGammaNormalizationParameter()
+ nanColor = colors.rgba(colormap.getNaNColor())
+
+ image = glutils.GLPlotColormap(
+ data,
+ origin,
+ scale,
+ colormapLut,
+ normalization,
+ gamma,
+ cmapRange,
+ alpha,
+ nanColor,
+ )
+
+ else: # Fallback applying colormap on CPU
+ rgba = colormap.applyToData(data)
+ image = glutils.GLPlotRGBAImage(rgba, origin, scale, alpha)
+
+ elif len(data.shape) == 3:
+ # For RGB, RGBA data
+ assert data.shape[2] in (3, 4)
+
+ if numpy.issubdtype(data.dtype, numpy.floating):
+ data = numpy.array(data, dtype=numpy.float32, copy=False)
+ elif data.dtype in [numpy.uint8, numpy.uint16]:
+ pass
+ elif numpy.issubdtype(data.dtype, numpy.integer):
+ data = numpy.array(data, dtype=numpy.uint8, copy=False)
+ else:
+ raise ValueError("Unsupported data type")
+
+ image = glutils.GLPlotRGBAImage(data, origin, scale, alpha)
+
+ else:
+ raise RuntimeError("Unsupported data shape {0}".format(data.shape))
+
+ # TODO is this needed?
+ if self._plotFrame.xAxis.isLog and image.xMin <= 0.0:
+ raise RuntimeError("Cannot add image with X <= 0 with X axis log scale")
+ if self._plotFrame.yAxis.isLog and image.yMin <= 0.0:
+ raise RuntimeError("Cannot add image with Y <= 0 with Y axis log scale")
+
+ return image
+
+ def addTriangles(self, x, y, triangles, color, alpha):
+ # Handle axes log scale: convert data
+ if self._plotFrame.xAxis.isLog:
+ x = numpy.log10(x)
+ if self._plotFrame.yAxis.isLog:
+ y = numpy.log10(y)
+
+ triangles = glutils.GLPlotTriangles(x, y, color, triangles, alpha)
+
+ return triangles
+
+ def addShape(
+ self, x, y, shape, color, fill, overlay, linestyle, linewidth, gapcolor
+ ):
+ x = numpy.array(x, copy=False)
+ y = numpy.array(y, copy=False)
+
+ # TODO is this needed?
+ if self._plotFrame.xAxis.isLog and x.min() <= 0.0:
+ raise RuntimeError("Cannot add item with X <= 0 with X axis log scale")
+ if self._plotFrame.yAxis.isLog and y.min() <= 0.0:
+ raise RuntimeError("Cannot add item with Y <= 0 with Y axis log scale")
+
+ dashoffset, dashpattern = self._lineStyleToDashOffsetPattern(linestyle)
+ return _ShapeItem(
+ x,
+ y,
+ shape,
+ color,
+ fill,
+ overlay,
+ linewidth,
+ dashoffset,
+ dashpattern,
+ gapcolor,
+ )
+
+ def getDefaultFont(self):
+ """Returns the default font, used by raw markers and axes labels"""
+ if self._defaultFont is None:
+ from matplotlib.font_manager import findfont, FontProperties
+
+ font_filename = findfont(FontProperties(family=["sans-serif"]))
+ _logger.debug("Load font from mpl: %s", font_filename)
+ id = qt.QFontDatabase.addApplicationFont(font_filename)
+ family = qt.QFontDatabase.applicationFontFamilies(id)[0]
+ font = qt.QFont(family, 10, qt.QFont.Normal, False)
+ font.setStyleStrategy(qt.QFont.PreferAntialias)
+ self._defaultFont = font
+ return self._defaultFont
+
+ def addMarker(
+ self,
+ x,
+ y,
+ text,
+ color,
+ symbol,
+ linestyle,
+ linewidth,
+ constraint,
+ yaxis,
+ font,
+ bgcolor: RGBAColorType | None,
+ ):
+ if font is None:
+ font = self.getDefaultFont()
+
+ dashoffset, dashpattern = self._lineStyleToDashOffsetPattern(linestyle)
+ return _MarkerItem(
+ x,
+ y,
+ text,
+ color,
+ symbol,
+ linewidth,
+ dashoffset,
+ dashpattern,
+ constraint,
+ yaxis,
+ font,
+ bgcolor,
+ )
+
+ # Remove methods
+
+ def remove(self, item):
+ if isinstance(item, glutils.GLPlotItem):
+ if item.yaxis == "right":
+ # Check if some curves remains on the right Y axis
+ y2AxisItems = (
+ item
+ for item in self._plot.getItems()
+ if isinstance(item, items.YAxisMixIn) and item.getYAxis() == "right"
+ )
+ self._plotFrame.isY2Axis = next(y2AxisItems, None) is not None
+
+ if item.isInitialized():
+ self._glGarbageCollector.append(item)
+
+ elif isinstance(item, (_MarkerItem, _ShapeItem)):
+ pass # No-op
+
+ else:
+ _logger.error("Unsupported item: %s", str(item))
+
+ # Interaction methods
+
+ _QT_CURSORS = {
+ BackendBase.CURSOR_DEFAULT: qt.Qt.ArrowCursor,
+ BackendBase.CURSOR_POINTING: qt.Qt.PointingHandCursor,
+ BackendBase.CURSOR_SIZE_HOR: qt.Qt.SizeHorCursor,
+ BackendBase.CURSOR_SIZE_VER: qt.Qt.SizeVerCursor,
+ BackendBase.CURSOR_SIZE_ALL: qt.Qt.SizeAllCursor,
+ }
+
+ def setGraphCursorShape(self, cursor):
+ if cursor is None:
+ super(BackendOpenGL, self).unsetCursor()
+ else:
+ cursor = self._QT_CURSORS[cursor]
+ super(BackendOpenGL, self).setCursor(qt.QCursor(cursor))
+
+ def setGraphCursor(self, flag, color, linewidth, linestyle):
+ if linestyle != "-":
+ _logger.warning("BackendOpenGL.setGraphCursor linestyle parameter ignored")
+
+ if flag:
+ color = colors.rgba(color)
+ crosshairCursor = color, linewidth
+ else:
+ crosshairCursor = None
+
+ if crosshairCursor != self._crosshairCursor:
+ self._crosshairCursor = crosshairCursor
+
+ _PICK_OFFSET = 3 # Offset in pixel used for picking
+
+ def _mouseInPlotArea(self, x, y):
+ """Returns closest visible position in the plot.
+
+ This is performed in Qt widget pixel, not device pixel.
+
+ :param float x: X coordinate in Qt widget pixel
+ :param float y: Y coordinate in Qt widget pixel
+ :return: (x, y) closest point in the plot.
+ :rtype: List[float]
+ """
+ left, top, width, height = self.getPlotBoundsInPixels()
+ return (
+ numpy.clip(x, left, left + width - 1), # TODO -1?
+ numpy.clip(y, top, top + height - 1),
+ )
+
+ def __pickCurves(self, item, x, y):
+ """Perform picking on a curve item.
+
+ :param GLPlotCurve2D item:
+ :param float x: X position of the mouse in widget coordinates
+ :param float y: Y position of the mouse in widget coordinates
+ :return: List of indices of picked points or None if not picked
+ :rtype: Union[List[int],None]
+ """
+ offset = self._PICK_OFFSET
+ if item.marker is not None:
+ # Convert markerSize from points to qt pixels
+ qtDpi = self.getDotsPerInch() / self.getDevicePixelRatio()
+ size = item.markerSize / 72.0 * qtDpi
+ offset = max(size / 2.0, offset)
+ if item.lineDashPattern is not None:
+ # Convert line width from points to qt pixels
+ qtDpi = self.getDotsPerInch() / self.getDevicePixelRatio()
+ lineWidth = item.lineWidth / 72.0 * qtDpi
+ offset = max(lineWidth / 2.0, offset)
+
+ inAreaPos = self._mouseInPlotArea(x - offset, y - offset)
+ dataPos = self._plot.pixelToData(
+ inAreaPos[0], inAreaPos[1], axis=item.yaxis, check=True
+ )
+ if dataPos is None:
+ return None
+ xPick0, yPick0 = dataPos
+
+ inAreaPos = self._mouseInPlotArea(x + offset, y + offset)
+ dataPos = self._plot.pixelToData(
+ inAreaPos[0], inAreaPos[1], axis=item.yaxis, check=True
+ )
+ if dataPos is None:
+ return None
+ xPick1, yPick1 = dataPos
+
+ if xPick0 < xPick1:
+ xPickMin, xPickMax = xPick0, xPick1
+ else:
+ xPickMin, xPickMax = xPick1, xPick0
+
+ if yPick0 < yPick1:
+ yPickMin, yPickMax = yPick0, yPick1
+ else:
+ yPickMin, yPickMax = yPick1, yPick0
+
+ # Apply log scale if axis is log
+ if self._plotFrame.xAxis.isLog:
+ xPickMin = numpy.log10(xPickMin)
+ xPickMax = numpy.log10(xPickMax)
+
+ if (item.yaxis == "left" and self._plotFrame.yAxis.isLog) or (
+ item.yaxis == "right" and self._plotFrame.y2Axis.isLog
+ ):
+ yPickMin = numpy.log10(yPickMin)
+ yPickMax = numpy.log10(yPickMax)
+
+ return item.pick(xPickMin, yPickMin, xPickMax, yPickMax)
+
+ def pickItem(self, x, y, item):
+ # Picking is performed in Qt widget pixels not device pixels
+ dataPos = self._plot.pixelToData(x, y, axis="left", check=True)
+ if dataPos is None:
+ return None # Outside plot area
+
+ if item is None:
+ _logger.error("No item provided for picking")
+ return None
+
+ # Pick markers
+ if isinstance(item, _MarkerItem):
+ yaxis = item["yaxis"]
+ pixelPos = self._plot.dataToPixel(
+ item["x"], item["y"], axis=yaxis, check=False
+ )
+ if pixelPos is None:
+ return None # negative coord on a log axis
+
+ if item["x"] is None: # Horizontal line
+ pt1 = self._plot.pixelToData(
+ x, y - self._PICK_OFFSET, axis=yaxis, check=False
+ )
+ pt2 = self._plot.pixelToData(
+ x, y + self._PICK_OFFSET, axis=yaxis, check=False
+ )
+ isPicked = min(pt1[1], pt2[1]) <= item["y"] <= max(pt1[1], pt2[1])
+
+ elif item["y"] is None: # Vertical line
+ pt1 = self._plot.pixelToData(
+ x - self._PICK_OFFSET, y, axis=yaxis, check=False
+ )
+ pt2 = self._plot.pixelToData(
+ x + self._PICK_OFFSET, y, axis=yaxis, check=False
+ )
+ isPicked = min(pt1[0], pt2[0]) <= item["x"] <= max(pt1[0], pt2[0])
+
+ else:
+ isPicked = (
+ numpy.fabs(x - pixelPos[0]) <= self._PICK_OFFSET
+ and numpy.fabs(y - pixelPos[1]) <= self._PICK_OFFSET
+ )
+
+ return (0,) if isPicked else None
+
+ # Pick image, curve, triangles
+ elif isinstance(item, glutils.GLPlotItem):
+ if isinstance(item, glutils.GLPlotCurve2D):
+ return self.__pickCurves(item, x, y)
+ else:
+ return item.pick(*dataPos) # Might be None
+
+ # Update curve
+
+ def setCurveColor(self, curve, color):
+ pass # TODO
+
+ # Misc.
+
+ def getWidgetHandle(self):
+ return self
+
+ def postRedisplay(self):
+ self.update()
+
+ def replot(self):
+ self.update() # async redraw
+
+ def saveGraph(self, fileName, fileFormat, dpi):
+ if dpi is not None:
+ _logger.warning("saveGraph ignores dpi parameter")
+
+ if fileFormat not in ["png", "ppm", "svg", "tif", "tiff"]:
+ raise NotImplementedError("Unsupported format: %s" % fileFormat)
+
+ if not self.isValid():
+ _logger.error("OpenGL 2.1 not available, cannot save OpenGL image")
+ width, height = self._plotFrame.size
+ data = numpy.zeros((height, width, 3), dtype=numpy.uint8)
+ else:
+ self.makeCurrent()
+
+ data = numpy.empty(
+ (self._plotFrame.size[1], self._plotFrame.size[0], 3),
+ dtype=numpy.uint8,
+ order="C",
+ )
+
+ context = self.context()
+ framebufferTexture = self._plotFBOs.get(context)
+ if framebufferTexture is None:
+ # Fallback, supports direct rendering mode: _paintDirectGL
+ # might have issues as it can read on-screen framebuffer
+ fboName = self.defaultFramebufferObject()
+ width, height = self._plotFrame.size
+ else:
+ fboName = framebufferTexture.name
+ height, width = framebufferTexture.shape
+
+ previousFramebuffer = gl.glGetInteger(gl.GL_FRAMEBUFFER_BINDING)
+ gl.glBindFramebuffer(gl.GL_FRAMEBUFFER, fboName)
+ gl.glPixelStorei(gl.GL_PACK_ALIGNMENT, 1)
+ gl.glReadPixels(0, 0, width, height, gl.GL_RGB, gl.GL_UNSIGNED_BYTE, data)
+ gl.glBindFramebuffer(gl.GL_FRAMEBUFFER, previousFramebuffer)
+
+ # glReadPixels gives bottom to top,
+ # while images are stored as top to bottom
+ data = numpy.flipud(data)
+
+ # fileName is either a file-like object or a str
+ saveImageToFile(data, fileName, fileFormat)
+
+ # Graph labels
+
+ def setGraphTitle(self, title):
+ self._plotFrame.title = title
+
+ def setGraphXLabel(self, label):
+ self._plotFrame.xAxis.title = label
+
+ def setGraphYLabel(self, label, axis):
+ if axis == "left":
+ self._plotFrame.yAxis.title = label
+ else: # right axis
+ self._plotFrame.y2Axis.title = label
+
+ # Graph limits
+
+ def _setDataRanges(self, xlim=None, ylim=None, y2lim=None):
+ """Set the visible range of data in the plot frame.
+
+ This clips the ranges to possible values (takes care of float32
+ range + positive range for log).
+ This also takes care of non-orthogonal axes.
+
+ This should be moved to PlotFrame.
+ """
+ # Update axes range with a clipped range if too wide
+ self._plotFrame.setDataRanges(xlim, ylim, y2lim)
+
+ def _ensureAspectRatio(self, keepDim=None):
+ """Update plot bounds in order to keep aspect ratio.
+
+ Warning: keepDim on right Y axis is not implemented !
+
+ :param str keepDim: The dimension to maintain: 'x', 'y' or None.
+ If None (the default), the dimension with the largest range.
+ """
+ plotWidth, plotHeight = self._plotFrame.plotSize
+ if plotWidth <= 2 or plotHeight <= 2:
+ return
+
+ if keepDim is None:
+ ranges = self._plot.getDataRange()
+ if (
+ ranges.y is not None
+ and ranges.x is not None
+ and (ranges.y[1] - ranges.y[0]) != 0.0
+ ):
+ dataRatio = (ranges.x[1] - ranges.x[0]) / float(
+ ranges.y[1] - ranges.y[0]
+ )
+ plotRatio = plotWidth / float(plotHeight) # Test != 0 before
+
+ keepDim = "x" if dataRatio > plotRatio else "y"
+ else: # Limit case
+ keepDim = "x"
+
+ (xMin, xMax), (yMin, yMax), (y2Min, y2Max) = self._plotFrame.dataRanges
+ if keepDim == "y":
+ dataW = (yMax - yMin) * plotWidth / float(plotHeight)
+ xCenter = 0.5 * (xMin + xMax)
+ xMin = xCenter - 0.5 * dataW
+ xMax = xCenter + 0.5 * dataW
+ elif keepDim == "x":
+ dataH = (xMax - xMin) * plotHeight / float(plotWidth)
+ yCenter = 0.5 * (yMin + yMax)
+ yMin = yCenter - 0.5 * dataH
+ yMax = yCenter + 0.5 * dataH
+ y2Center = 0.5 * (y2Min + y2Max)
+ y2Min = y2Center - 0.5 * dataH
+ y2Max = y2Center + 0.5 * dataH
+ else:
+ raise RuntimeError("Unsupported dimension to keep: %s" % keepDim)
+
+ # Update plot frame bounds
+ self._setDataRanges(xlim=(xMin, xMax), ylim=(yMin, yMax), y2lim=(y2Min, y2Max))
+
+ def _setPlotBounds(self, xRange=None, yRange=None, y2Range=None, keepDim=None):
+ # Update axes range with a clipped range if too wide
+ self._setDataRanges(xlim=xRange, ylim=yRange, y2lim=y2Range)
+
+ # Keep data aspect ratio
+ if self.isKeepDataAspectRatio():
+ self._ensureAspectRatio(keepDim)
+
+ def setLimits(self, xmin, xmax, ymin, ymax, y2min=None, y2max=None):
+ assert xmin < xmax
+ assert ymin < ymax
+
+ if y2min is None or y2max is None:
+ y2Range = None
+ else:
+ assert y2min < y2max
+ y2Range = y2min, y2max
+ self._setPlotBounds((xmin, xmax), (ymin, ymax), y2Range)
+
+ def getGraphXLimits(self):
+ return self._plotFrame.dataRanges.x
+
+ def setGraphXLimits(self, xmin, xmax):
+ assert xmin < xmax
+ self._setPlotBounds(xRange=(xmin, xmax), keepDim="x")
+
+ def getGraphYLimits(self, axis):
+ assert axis in ("left", "right")
+ if axis == "left":
+ return self._plotFrame.dataRanges.y
+ else:
+ return self._plotFrame.dataRanges.y2
+
+ def setGraphYLimits(self, ymin, ymax, axis):
+ assert ymin < ymax
+ assert axis in ("left", "right")
+
+ if axis == "left":
+ self._setPlotBounds(yRange=(ymin, ymax), keepDim="y")
+ else:
+ self._setPlotBounds(y2Range=(ymin, ymax), keepDim="y")
+
+ # Graph axes
+
+ def getXAxisTimeZone(self):
+ return self._plotFrame.xAxis.timeZone
+
+ def setXAxisTimeZone(self, tz):
+ self._plotFrame.xAxis.timeZone = tz
+
+ def isXAxisTimeSeries(self):
+ return self._plotFrame.xAxis.isTimeSeries
+
+ def setXAxisTimeSeries(self, isTimeSeries):
+ self._plotFrame.xAxis.isTimeSeries = isTimeSeries
+
+ def setXAxisLogarithmic(self, flag):
+ if flag != self._plotFrame.xAxis.isLog:
+ if flag and self._keepDataAspectRatio:
+ _logger.warning("KeepDataAspectRatio is ignored with log axes")
+
+ self._plotFrame.xAxis.isLog = flag
+
+ def setYAxisLogarithmic(self, flag):
+ if flag != self._plotFrame.yAxis.isLog or flag != self._plotFrame.y2Axis.isLog:
+ if flag and self._keepDataAspectRatio:
+ _logger.warning("KeepDataAspectRatio is ignored with log axes")
+
+ self._plotFrame.yAxis.isLog = flag
+ self._plotFrame.y2Axis.isLog = flag
+
+ def setYAxisInverted(self, flag):
+ if flag != self._plotFrame.isYAxisInverted:
+ self._plotFrame.isYAxisInverted = flag
+
+ def isYAxisInverted(self):
+ return self._plotFrame.isYAxisInverted
+
+ def isYRightAxisVisible(self):
+ return self._plotFrame.isY2Axis
+
+ def isKeepDataAspectRatio(self):
+ if self._plotFrame.xAxis.isLog or self._plotFrame.yAxis.isLog:
+ return False
+ else:
+ return self._keepDataAspectRatio
+
+ def setKeepDataAspectRatio(self, flag):
+ if flag and (self._plotFrame.xAxis.isLog or self._plotFrame.yAxis.isLog):
+ _logger.warning("KeepDataAspectRatio is ignored with log axes")
+
+ self._keepDataAspectRatio = flag
+
+ def setGraphGrid(self, which):
+ assert which in (None, "major", "both")
+ self._plotFrame.grid = which is not None # TODO True grid support
+
+ # Data <-> Pixel coordinates conversion
+
+ def dataToPixel(self, x, y, axis):
+ result = self._plotFrame.dataToPixel(x, y, axis)
+ if result is None:
+ return None
+ else:
+ devicePixelRatio = self.getDevicePixelRatio()
+ return tuple(value / devicePixelRatio for value in result)
+
+ def pixelToData(self, x, y, axis):
+ devicePixelRatio = self.getDevicePixelRatio()
+ return self._plotFrame.pixelToData(
+ x * devicePixelRatio, y * devicePixelRatio, axis
+ )
+
+ def getPlotBoundsInPixels(self):
+ devicePixelRatio = self.getDevicePixelRatio()
+ return tuple(
+ int(value / devicePixelRatio)
+ for value in self._plotFrame.plotOrigin + self._plotFrame.plotSize
+ )
+
+ def setAxesMargins(self, left: float, top: float, right: float, bottom: float):
+ self._plotFrame.marginRatios = left, top, right, bottom
+
+ def setForegroundColors(self, foregroundColor, gridColor):
+ self._plotFrame.foregroundColor = foregroundColor
+ self._plotFrame.gridColor = gridColor
+
+ def setBackgroundColors(self, backgroundColor, dataBackgroundColor):
+ self._backgroundColor = backgroundColor
+ self._dataBackgroundColor = dataBackgroundColor
diff --git a/src/silx/gui/plot/backends/__init__.py b/src/silx/gui/plot/backends/__init__.py
new file mode 100644
index 0000000..d75a943
--- /dev/null
+++ b/src/silx/gui/plot/backends/__init__.py
@@ -0,0 +1,28 @@
+# /*##########################################################################
+#
+# Copyright (c) 2017 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""This package implements the backend of the Plot."""
+
+__authors__ = ["T. Vincent"]
+__license__ = "MIT"
+__date__ = "21/03/2017"
diff --git a/src/silx/gui/plot/backends/glutils/GLPlotCurve.py b/src/silx/gui/plot/backends/glutils/GLPlotCurve.py
new file mode 100644
index 0000000..26442d7
--- /dev/null
+++ b/src/silx/gui/plot/backends/glutils/GLPlotCurve.py
@@ -0,0 +1,1494 @@
+# /*##########################################################################
+#
+# Copyright (c) 2014-2023 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ############################################################################*/
+"""
+This module provides classes to render 2D lines and scatter plots
+"""
+
+__authors__ = ["T. Vincent"]
+__license__ = "MIT"
+__date__ = "03/04/2017"
+
+
+import math
+import logging
+
+import numpy
+
+from silx.math.combo import min_max
+
+from ...._glutils import gl
+from ...._glutils import Program, vertexBuffer, VertexBufferAttrib
+from .GLSupport import buildFillMaskIndices, mat4Identity, mat4Translate
+from .GLPlotImage import GLPlotItem
+
+
+_logger = logging.getLogger(__name__)
+
+
+_MPL_NONES = None, "None", "", " "
+"""Possible values for None"""
+
+
+def _notNaNSlices(array, length=1):
+ """Returns slices of none NaN values in the array.
+
+ :param numpy.ndarray array: 1D array from which to get slices
+ :param int length: Slices shorter than length gets discarded
+ :return: Array of (start, end) slice indices
+ :rtype: numpy.ndarray
+ """
+ isnan = numpy.isnan(numpy.array(array, copy=False).reshape(-1))
+ notnan = numpy.logical_not(isnan)
+ start = numpy.where(numpy.logical_and(isnan[:-1], notnan[1:]))[0] + 1
+ if notnan[0]:
+ start = numpy.append(0, start)
+ end = numpy.where(numpy.logical_and(notnan[:-1], isnan[1:]))[0] + 1
+ if notnan[-1]:
+ end = numpy.append(end, len(array))
+ slices = numpy.transpose((start, end))
+ if length > 1:
+ # discard slices with less than length values
+ slices = slices[numpy.diff(slices, axis=1).ravel() >= length]
+ return slices
+
+
+# fill ########################################################################
+
+
+class _Fill2D(object):
+ """Object rendering curve filling as polygons
+
+ :param numpy.ndarray xData: X coordinates of points
+ :param numpy.ndarray yData: Y coordinates of points
+ :param float baseline: Y value of the 'bottom' of the fill.
+ 0 for linear Y scale, -38 for log Y scale
+ :param List[float] color: RGBA color as 4 float in [0, 1]
+ :param List[float] offset: Translation of coordinates (ox, oy)
+ """
+
+ _PROGRAM = Program(
+ vertexShader="""
+ #version 120
+
+ uniform mat4 matrix;
+ attribute float xPos;
+ attribute float yPos;
+
+ void main(void) {
+ gl_Position = matrix * vec4(xPos, yPos, 0.0, 1.0);
+ }
+ """,
+ fragmentShader="""
+ #version 120
+
+ uniform vec4 color;
+
+ void main(void) {
+ gl_FragColor = color;
+ }
+ """,
+ attrib0="xPos",
+ )
+
+ def __init__(
+ self,
+ xData=None,
+ yData=None,
+ baseline=0,
+ color=(0.0, 0.0, 0.0, 1.0),
+ offset=(0.0, 0.0),
+ ):
+ self.xData = xData
+ self.yData = yData
+ self._xFillVboData = None
+ self._yFillVboData = None
+ self.color = color
+ self.offset = offset
+
+ # Offset baseline
+ self.baseline = baseline - self.offset[1]
+
+ def prepare(self):
+ """Rendering preparation: build indices and bounding box vertices"""
+ if (
+ self._xFillVboData is None
+ and self.xData is not None
+ and self.yData is not None
+ ):
+ # Get slices of not NaN values longer than 1 element
+ isnan = numpy.logical_or(numpy.isnan(self.xData), numpy.isnan(self.yData))
+ notnan = numpy.logical_not(isnan)
+ start = numpy.where(numpy.logical_and(isnan[:-1], notnan[1:]))[0] + 1
+ if notnan[0]:
+ start = numpy.append(0, start)
+ end = numpy.where(numpy.logical_and(notnan[:-1], isnan[1:]))[0] + 1
+ if notnan[-1]:
+ end = numpy.append(end, len(isnan))
+ slices = numpy.transpose((start, end))
+ # discard slices with less than length values
+ slices = slices[numpy.diff(slices, axis=1).reshape(-1) >= 2]
+
+ # Number of points: slice + 2 * leading and trailing points
+ # Twice leading and trailing points to produce degenerated triangles
+ nbPoints = numpy.sum(numpy.diff(slices, axis=1)) * 2 + 4 * len(slices)
+ points = numpy.empty((nbPoints, 2), dtype=numpy.float32)
+
+ offset = 0
+ # invert baseline for filling
+ new_y_data = numpy.append(self.yData, self.baseline)
+ for start, end in slices:
+ # Duplicate first point for connecting degenerated triangle
+ points[offset : offset + 2] = self.xData[start], new_y_data[start]
+
+ # 2nd point of the polygon is last point
+ points[offset + 2] = self.xData[start], self.baseline[start]
+
+ indices = numpy.append(
+ numpy.arange(start, end),
+ numpy.arange(
+ len(self.xData) + end - 1, len(self.xData) + start - 1, -1
+ ),
+ )
+ indices = indices[buildFillMaskIndices(len(indices))]
+
+ points[offset + 3 : offset + 3 + len(indices), 0] = self.xData[
+ indices % len(self.xData)
+ ]
+ points[offset + 3 : offset + 3 + len(indices), 1] = new_y_data[indices]
+
+ # Duplicate last point for connecting degenerated triangle
+ points[offset + 3 + len(indices)] = points[
+ offset + 3 + len(indices) - 1
+ ]
+
+ offset += len(indices) + 4
+
+ self._xFillVboData, self._yFillVboData = vertexBuffer(points.T)
+
+ def render(self, context):
+ """Perform rendering
+
+ :param RenderContext context:
+ """
+ self.prepare()
+
+ if self._xFillVboData is None:
+ return # Nothing to display
+
+ self._PROGRAM.use()
+
+ gl.glUniformMatrix4fv(
+ self._PROGRAM.uniforms["matrix"],
+ 1,
+ gl.GL_TRUE,
+ numpy.dot(context.matrix, mat4Translate(*self.offset)).astype(
+ numpy.float32
+ ),
+ )
+
+ gl.glUniform4f(self._PROGRAM.uniforms["color"], *self.color)
+
+ xPosAttrib = self._PROGRAM.attributes["xPos"]
+ yPosAttrib = self._PROGRAM.attributes["yPos"]
+
+ gl.glEnableVertexAttribArray(xPosAttrib)
+ self._xFillVboData.setVertexAttrib(xPosAttrib)
+
+ gl.glEnableVertexAttribArray(yPosAttrib)
+ self._yFillVboData.setVertexAttrib(yPosAttrib)
+
+ # Prepare fill mask
+ gl.glEnable(gl.GL_STENCIL_TEST)
+ gl.glStencilMask(1)
+ gl.glStencilFunc(gl.GL_ALWAYS, 1, 1)
+ gl.glStencilOp(gl.GL_INVERT, gl.GL_INVERT, gl.GL_INVERT)
+ gl.glColorMask(gl.GL_FALSE, gl.GL_FALSE, gl.GL_FALSE, gl.GL_FALSE)
+ gl.glDepthMask(gl.GL_FALSE)
+
+ gl.glDrawArrays(gl.GL_TRIANGLE_STRIP, 0, self._xFillVboData.size)
+
+ gl.glStencilFunc(gl.GL_EQUAL, 1, 1)
+ # Reset stencil while drawing
+ gl.glStencilOp(gl.GL_ZERO, gl.GL_ZERO, gl.GL_ZERO)
+ gl.glColorMask(gl.GL_TRUE, gl.GL_TRUE, gl.GL_TRUE, gl.GL_TRUE)
+ gl.glDepthMask(gl.GL_TRUE)
+
+ # Draw directly in NDC
+ gl.glUniformMatrix4fv(
+ self._PROGRAM.uniforms["matrix"],
+ 1,
+ gl.GL_TRUE,
+ mat4Identity().astype(numpy.float32),
+ )
+
+ # NDC vertices
+ gl.glVertexAttribPointer(
+ xPosAttrib,
+ 1,
+ gl.GL_FLOAT,
+ gl.GL_FALSE,
+ 0,
+ numpy.array((-1.0, -1.0, 1.0, 1.0), dtype=numpy.float32),
+ )
+ gl.glVertexAttribPointer(
+ yPosAttrib,
+ 1,
+ gl.GL_FLOAT,
+ gl.GL_FALSE,
+ 0,
+ numpy.array((-1.0, 1.0, -1.0, 1.0), dtype=numpy.float32),
+ )
+
+ gl.glDrawArrays(gl.GL_TRIANGLE_STRIP, 0, 4)
+
+ gl.glDisable(gl.GL_STENCIL_TEST)
+
+ def discard(self):
+ """Release VBOs"""
+ if self.isInitialized():
+ self._xFillVboData.vbo.discard()
+
+ self._xFillVboData = None
+ self._yFillVboData = None
+
+ def isInitialized(self):
+ return self._xFillVboData is not None
+
+
+# line ########################################################################
+
+
+class GLLines2D(object):
+ """Object rendering curve as a polyline
+
+ :param xVboData: X coordinates VBO
+ :param yVboData: Y coordinates VBO
+ :param colorVboData: VBO of colors
+ :param distVboData: VBO of distance along the polyline
+ :param List[float] color: RGBA color as 4 float in [0, 1]
+ :param float width: Line width
+ :param List[float] dashPattern:
+ "unscaled" dash pattern as 4 lengths in points (dash1, gap1, dash2, gap2).
+ This pattern is scaled with the line width.
+ Set to () to draw solid lines (default), and to None to disable rendering.
+ :param float dashOffset: The offset in points the patterns starts at.
+ The offset is scaled with the line width.
+ :param drawMode: OpenGL drawing mode
+ :param List[float] offset: Translation of coordinates (ox, oy)
+ """
+
+ _SOLID_PROGRAM = Program(
+ vertexShader="""
+ #version 120
+
+ uniform mat4 matrix;
+ attribute float xPos;
+ attribute float yPos;
+ attribute vec4 color;
+
+ varying vec4 vColor;
+
+ void main(void) {
+ gl_Position = matrix * vec4(xPos, yPos, 0., 1.) ;
+ vColor = color;
+ }
+ """,
+ fragmentShader="""
+ #version 120
+
+ varying vec4 vColor;
+
+ void main(void) {
+ gl_FragColor = vColor;
+ }
+ """,
+ attrib0="xPos",
+ )
+
+ # Limitation: Dash using an estimate of distance in screen coord
+ # to avoid computing distance when viewport is resized
+ # results in inequal dashes when viewport aspect ratio is far from 1
+ _DASH_PROGRAM = Program(
+ vertexShader="""
+ #version 120
+
+ uniform mat4 matrix;
+ uniform float distanceScale;
+ attribute float xPos;
+ attribute float yPos;
+ attribute vec4 color;
+ attribute float distance;
+
+ varying float vDist;
+ varying vec4 vColor;
+
+ void main(void) {
+ gl_Position = matrix * vec4(xPos, yPos, 0., 1.);
+ vDist = distance * distanceScale;
+ vColor = color;
+ }
+ """,
+ fragmentShader="""
+ #version 120
+
+ /* Dashes: [0, x], [y, z]
+ Dash period: w */
+ uniform vec4 dash;
+ uniform float dashOffset;
+ uniform vec4 gapColor;
+
+ varying float vDist;
+ varying vec4 vColor;
+
+ void main(void) {
+ float dist = mod(vDist + dashOffset, dash.w);
+ if ((dist > dash.x && dist < dash.y) || dist > dash.z) {
+ if (gapColor.a == 0.) {
+ discard; // Discard full transparent bg color
+ } else {
+ gl_FragColor = gapColor;
+ }
+ } else {
+ gl_FragColor = vColor;
+ }
+ }
+ """,
+ attrib0="xPos",
+ )
+
+ def __init__(
+ self,
+ xVboData=None,
+ yVboData=None,
+ colorVboData=None,
+ distVboData=None,
+ color=(0.0, 0.0, 0.0, 1.0),
+ gapColor=None,
+ width=1,
+ dashOffset=0.0,
+ dashPattern=(),
+ drawMode=None,
+ offset=(0.0, 0.0),
+ ):
+ if xVboData is not None and not isinstance(xVboData, VertexBufferAttrib):
+ xVboData = numpy.array(xVboData, copy=False, dtype=numpy.float32)
+ self.xVboData = xVboData
+
+ if yVboData is not None and not isinstance(yVboData, VertexBufferAttrib):
+ yVboData = numpy.array(yVboData, copy=False, dtype=numpy.float32)
+ self.yVboData = yVboData
+
+ # Compute distances if not given while providing numpy array coordinates
+ if (
+ isinstance(self.xVboData, numpy.ndarray)
+ and isinstance(self.yVboData, numpy.ndarray)
+ and distVboData is None
+ ):
+ distVboData = distancesFromArrays(self.xVboData, self.yVboData)
+
+ if distVboData is not None and not isinstance(distVboData, VertexBufferAttrib):
+ distVboData = numpy.array(distVboData, copy=False, dtype=numpy.float32)
+ self.distVboData = distVboData
+
+ if colorVboData is not None:
+ assert isinstance(colorVboData, VertexBufferAttrib)
+ self.colorVboData = colorVboData
+ self.useColorVboData = colorVboData is not None
+
+ self.color = color
+ self.gapColor = gapColor
+ self.width = width
+ self.dashPattern = dashPattern
+ self.dashOffset = dashOffset
+ self.offset = offset
+
+ self._drawMode = drawMode if drawMode is not None else gl.GL_LINE_STRIP
+
+ @classmethod
+ def init(cls):
+ """OpenGL context initialization"""
+ gl.glHint(gl.GL_LINE_SMOOTH_HINT, gl.GL_NICEST)
+
+ def render(self, context):
+ """Perform rendering
+
+ :param RenderContext context:
+ """
+ if self.dashPattern is None: # Nothing to display
+ return
+
+ if self.dashPattern == (): # No dash: solid line
+ program = self._SOLID_PROGRAM
+ program.use()
+
+ else: # Dashed line defined by 4 control points
+ program = self._DASH_PROGRAM
+ program.use()
+
+ # Scale pattern by width, convert from lengths in points to offsets in pixels
+ scale = self.width / 72.0 * context.dpi
+ dashOffsets = tuple(
+ offset * scale for offset in numpy.cumsum(self.dashPattern)
+ )
+ gl.glUniform4f(program.uniforms["dash"], *dashOffsets)
+ gl.glUniform1f(program.uniforms["dashOffset"], self.dashOffset * scale)
+
+ if self.gapColor is None:
+ # Use fully transparent color which gets discarded in shader
+ gapColor = (0.0, 0.0, 0.0, 0.0)
+ else:
+ gapColor = self.gapColor
+ gl.glUniform4f(program.uniforms["gapColor"], *gapColor)
+
+ viewWidth = gl.glGetFloatv(gl.GL_VIEWPORT)[2]
+ xNDCPerData = (
+ numpy.dot(context.matrix, [1.0, 0.0, 0.0, 1.0])[0]
+ - numpy.dot(context.matrix, [0.0, 0.0, 0.0, 1.0])[0]
+ )
+ xPixelPerData = 0.5 * viewWidth * xNDCPerData
+ gl.glUniform1f(program.uniforms["distanceScale"], xPixelPerData)
+
+ distAttrib = program.attributes["distance"]
+ gl.glEnableVertexAttribArray(distAttrib)
+ if isinstance(self.distVboData, VertexBufferAttrib):
+ self.distVboData.setVertexAttrib(distAttrib)
+ else:
+ gl.glVertexAttribPointer(
+ distAttrib, 1, gl.GL_FLOAT, False, 0, self.distVboData
+ )
+
+ gl.glEnable(gl.GL_LINE_SMOOTH)
+
+ matrix = numpy.dot(context.matrix, mat4Translate(*self.offset)).astype(
+ numpy.float32
+ )
+ gl.glUniformMatrix4fv(program.uniforms["matrix"], 1, gl.GL_TRUE, matrix)
+
+ colorAttrib = program.attributes["color"]
+ if self.useColorVboData and self.colorVboData is not None:
+ gl.glEnableVertexAttribArray(colorAttrib)
+ self.colorVboData.setVertexAttrib(colorAttrib)
+ else:
+ gl.glDisableVertexAttribArray(colorAttrib)
+ gl.glVertexAttrib4f(colorAttrib, *self.color)
+
+ xPosAttrib = program.attributes["xPos"]
+ gl.glEnableVertexAttribArray(xPosAttrib)
+ if isinstance(self.xVboData, VertexBufferAttrib):
+ self.xVboData.setVertexAttrib(xPosAttrib)
+ else:
+ gl.glVertexAttribPointer(
+ xPosAttrib, 1, gl.GL_FLOAT, False, 0, self.xVboData
+ )
+
+ yPosAttrib = program.attributes["yPos"]
+ gl.glEnableVertexAttribArray(yPosAttrib)
+ if isinstance(self.yVboData, VertexBufferAttrib):
+ self.yVboData.setVertexAttrib(yPosAttrib)
+ else:
+ gl.glVertexAttribPointer(
+ yPosAttrib, 1, gl.GL_FLOAT, False, 0, self.yVboData
+ )
+
+ gl.glLineWidth(self.width / 72.0 * context.dpi)
+ gl.glDrawArrays(self._drawMode, 0, self.xVboData.size)
+
+ gl.glDisable(gl.GL_LINE_SMOOTH)
+
+
+def distancesFromArrays(xData, yData, ratio: float = 1.0):
+ """Returns distances between each points
+
+ :param numpy.ndarray xData: X coordinate of points
+ :param numpy.ndarray yData: Y coordinate of points
+ :param ratio: Y/X pixel per data resolution ratio
+ :rtype: numpy.ndarray
+ """
+ # Split array into sub-shapes at not finite points
+ splits = numpy.nonzero(
+ numpy.logical_not(
+ numpy.logical_and(numpy.isfinite(xData), numpy.isfinite(yData))
+ )
+ )[0]
+ splits = numpy.concatenate(([-1], splits, [len(xData) - 1]))
+
+ # Compute distance independently for each sub-shapes,
+ # putting not finite points as last points of sub-shapes
+ distances = []
+ for begin, end in zip(splits[:-1] + 1, splits[1:] + 1):
+ if begin == end: # Empty shape
+ continue
+ elif end - begin == 1: # Single element
+ distances.append(numpy.array([0], dtype=numpy.float32))
+ else:
+ deltas = numpy.dstack(
+ (
+ numpy.ediff1d(xData[begin:end], to_begin=numpy.float32(0.0)),
+ numpy.ediff1d(
+ yData[begin:end] * ratio, to_begin=numpy.float32(0.0)
+ ),
+ )
+ )[0]
+ distances.append(numpy.cumsum(numpy.sqrt(numpy.sum(deltas**2, axis=1))))
+ return numpy.concatenate(distances)
+
+
+# points ######################################################################
+
+DIAMOND, CIRCLE, SQUARE, PLUS, X_MARKER, POINT, PIXEL, ASTERISK = (
+ "d",
+ "o",
+ "s",
+ "+",
+ "x",
+ ".",
+ ",",
+ "*",
+)
+
+H_LINE, V_LINE, HEART = "_", "|", "\u2665"
+
+TICK_LEFT = "tickleft"
+TICK_RIGHT = "tickright"
+TICK_UP = "tickup"
+TICK_DOWN = "tickdown"
+CARET_LEFT = "caretleft"
+CARET_RIGHT = "caretright"
+CARET_UP = "caretup"
+CARET_DOWN = "caretdown"
+
+
+class Points2D(object):
+ """Object rendering curve markers
+
+ :param xVboData: X coordinates VBO
+ :param yVboData: Y coordinates VBO
+ :param colorVboData: VBO of colors
+ :param str marker: Kind of symbol to use, see :attr:`MARKERS`.
+ :param List[float] color: RGBA color as 4 float in [0, 1]
+ :param float size: Marker size
+ :param List[float] offset: Translation of coordinates (ox, oy)
+ """
+
+ MARKERS = (
+ DIAMOND,
+ CIRCLE,
+ SQUARE,
+ PLUS,
+ X_MARKER,
+ POINT,
+ PIXEL,
+ ASTERISK,
+ H_LINE,
+ V_LINE,
+ HEART,
+ TICK_LEFT,
+ TICK_RIGHT,
+ TICK_UP,
+ TICK_DOWN,
+ CARET_LEFT,
+ CARET_RIGHT,
+ CARET_UP,
+ CARET_DOWN,
+ )
+ """List of supported markers"""
+
+ _VERTEX_SHADER = """
+ #version 120
+
+ uniform mat4 matrix;
+ uniform int transform;
+ uniform float size;
+ attribute float xPos;
+ attribute float yPos;
+ attribute vec4 color;
+
+ varying vec4 vColor;
+
+ void main(void) {
+ gl_Position = matrix * vec4(xPos, yPos, 0., 1.);
+ vColor = color;
+ gl_PointSize = size;
+ }
+ """
+
+ _FRAGMENT_SHADER_SYMBOLS = {
+ DIAMOND: """
+ float alphaSymbol(vec2 coord, float size) {
+ vec2 centerCoord = abs(coord - vec2(0.5, 0.5));
+ float f = centerCoord.x + centerCoord.y;
+ return clamp(size * (0.5 - f), 0.0, 1.0);
+ }
+ """,
+ CIRCLE: """
+ float alphaSymbol(vec2 coord, float size) {
+ float radius = 0.5;
+ float r = distance(coord, vec2(0.5, 0.5));
+ return clamp(size * (radius - r), 0.0, 1.0);
+ }
+ """,
+ SQUARE: """
+ float alphaSymbol(vec2 coord, float size) {
+ return 1.0;
+ }
+ """,
+ PLUS: """
+ float alphaSymbol(vec2 coord, float size) {
+ vec2 d = abs(size * (coord - vec2(0.5, 0.5)));
+ return local_smoothstep(1.5, 0.5, min(d.x, d.y));
+ }
+ """,
+ X_MARKER: """
+ float alphaSymbol(vec2 coord, float size) {
+ vec2 pos = floor(size * coord) + 0.5;
+ vec2 d_x = abs(pos.x + vec2(- pos.y, pos.y - size));
+ return local_smoothstep(1.5, 0.5, min(d_x.x, d_x.y));
+ }
+ """,
+ ASTERISK: """
+ float alphaSymbol(vec2 coord, float size) {
+ /* Combining +, x and circle */
+ vec2 d_plus = abs(size * (coord - vec2(0.5, 0.5)));
+ vec2 pos = floor(size * coord) + 0.5;
+ vec2 d_x = abs(pos.x + vec2(- pos.y, pos.y - size));
+ if (min(d_plus.x, d_plus.y) < 0.5) {
+ return 1.0;
+ } else if (min(d_x.x, d_x.y) <= 0.5) {
+ float r = distance(coord, vec2(0.5, 0.5));
+ return clamp(size * (0.5 - r), 0.0, 1.0);
+ } else {
+ return 0.0;
+ }
+ }
+ """,
+ H_LINE: """
+ float alphaSymbol(vec2 coord, float size) {
+ float d = abs(size * (coord.y - 0.5));
+ return local_smoothstep(1.5, 0.5, d);
+ }
+ """,
+ V_LINE: """
+ float alphaSymbol(vec2 coord, float size) {
+ float d = abs(size * (coord.x - 0.5));
+ return local_smoothstep(1.5, 0.5, d);
+ }
+ """,
+ HEART: """
+ float alphaSymbol(vec2 coord, float size) {
+ coord = (coord - 0.5) * 2.;
+ coord *= 0.75;
+ coord.y += 0.25;
+ float a = atan(coord.x,-coord.y)/3.141593;
+ float r = length(coord);
+ float h = abs(a);
+ float d = (13.0*h - 22.0*h*h + 10.0*h*h*h)/(6.0-5.0*h);
+ float res = clamp(r-d, 0., 1.);
+ // antialiasing
+ res = local_smoothstep(0.1, 0.001, res);
+ return res;
+ }
+ """,
+ TICK_LEFT: """
+ float alphaSymbol(vec2 coord, float size) {
+ coord = size * (coord - 0.5);
+ float dy = abs(coord.y);
+ if (coord.x > 0.5) {
+ return 0.0;
+ }
+ return local_smoothstep(1.5, 0.5, dy);
+ }
+ """,
+ TICK_RIGHT: """
+ float alphaSymbol(vec2 coord, float size) {
+ coord = size * (coord - 0.5);
+ float dy = abs(coord.y);
+ if (coord.x < -0.5) {
+ return 0.0;
+ }
+ return local_smoothstep(1.5, 0.5, dy);
+ }
+ """,
+ TICK_UP: """
+ float alphaSymbol(vec2 coord, float size) {
+ coord = size * (coord - 0.5);
+ float dx = abs(coord.x);
+ if (coord.y > 0.5) {
+ return 0.0;
+ }
+ return local_smoothstep(1.5, 0.5, dx);
+ }
+ """,
+ TICK_DOWN: """
+ float alphaSymbol(vec2 coord, float size) {
+ coord = size * (coord - 0.5);
+ float dx = abs(coord.x);
+ if (coord.y < -0.5) {
+ return 0.0;
+ }
+ return local_smoothstep(1.5, 0.5, dx);
+ }
+ """,
+ CARET_LEFT: """
+ float alphaSymbol(vec2 coord, float size) {
+ coord = size * (coord - 0.5);
+ float d = abs(coord.x) - abs(coord.y);
+ if (d >= -0.1 && coord.x > 0.5) {
+ return local_smoothstep(-0.1, 0.1, d);
+ } else {
+ return 0.0;
+ }
+ }
+ """,
+ CARET_RIGHT: """
+ float alphaSymbol(vec2 coord, float size) {
+ coord = size * (coord - 0.5);
+ float d = abs(coord.x) - abs(coord.y);
+ if (d >= -0.1 && coord.x < 0.5) {
+ return local_smoothstep(-0.1, 0.1, d);
+ } else {
+ return 0.0;
+ }
+ }
+ """,
+ CARET_UP: """
+ float alphaSymbol(vec2 coord, float size) {
+ coord = size * (coord - 0.5);
+ float d = abs(coord.y) - abs(coord.x);
+ if (d >= -0.1 && coord.y > 0.5) {
+ return local_smoothstep(-0.1, 0.1, d);
+ } else {
+ return 0.0;
+ }
+ }
+ """,
+ CARET_DOWN: """
+ float alphaSymbol(vec2 coord, float size) {
+ coord = size * (coord - 0.5);
+ float d = abs(coord.y) - abs(coord.x);
+ if (d >= -0.1 && coord.y < 0.5) {
+ return local_smoothstep(-0.1, 0.1, d);
+ } else {
+ return 0.0;
+ }
+ }
+ """,
+ }
+
+ _FRAGMENT_SHADER_TEMPLATE = """
+ #version 120
+
+ uniform float size;
+
+ varying vec4 vColor;
+
+ /* smoothstep function implementation to support GLSL 1.20 */
+ float local_smoothstep(float edge0, float edge1, float x) {
+ float t;
+ t = clamp((x - edge0) / (edge1 - edge0), 0.0, 1.0);
+ return t * t * (3.0 - 2.0 * t);
+ }
+
+ %s
+
+ void main(void) {
+ float alpha = alphaSymbol(gl_PointCoord, size);
+ if (alpha <= 0.0) {
+ discard;
+ } else {
+ gl_FragColor = vec4(vColor.rgb, alpha * clamp(vColor.a, 0.0, 1.0));
+ }
+ }
+ """
+
+ _PROGRAMS = {}
+
+ def __init__(
+ self,
+ xVboData=None,
+ yVboData=None,
+ colorVboData=None,
+ marker=SQUARE,
+ color=(0.0, 0.0, 0.0, 1.0),
+ size=7,
+ offset=(0.0, 0.0),
+ ):
+ self.color = color
+ self._marker = None
+ self.marker = marker
+ self.size = size
+ self.offset = offset
+
+ if xVboData is not None and not isinstance(xVboData, VertexBufferAttrib):
+ xVboData = numpy.array(xVboData, copy=False, dtype=numpy.float32)
+ self.xVboData = xVboData
+
+ if yVboData is not None and not isinstance(yVboData, VertexBufferAttrib):
+ yVboData = numpy.array(yVboData, copy=False, dtype=numpy.float32)
+ self.yVboData = yVboData
+
+ if colorVboData is not None:
+ assert isinstance(colorVboData, VertexBufferAttrib)
+ self.colorVboData = colorVboData
+ self.useColorVboData = colorVboData is not None
+
+ @property
+ def marker(self):
+ """Symbol used to display markers (str)"""
+ return self._marker
+
+ @marker.setter
+ def marker(self, marker):
+ if marker in _MPL_NONES:
+ self._marker = None
+ else:
+ assert marker in self.MARKERS
+ self._marker = marker
+
+ @classmethod
+ def _getProgram(cls, marker):
+ """On-demand shader program creation."""
+ if marker == PIXEL:
+ marker = SQUARE
+ elif marker == POINT:
+ marker = CIRCLE
+
+ if marker not in cls._PROGRAMS:
+ cls._PROGRAMS[marker] = Program(
+ vertexShader=cls._VERTEX_SHADER,
+ fragmentShader=(
+ cls._FRAGMENT_SHADER_TEMPLATE % cls._FRAGMENT_SHADER_SYMBOLS[marker]
+ ),
+ attrib0="xPos",
+ )
+
+ return cls._PROGRAMS[marker]
+
+ @classmethod
+ def init(cls):
+ """OpenGL context initialization"""
+ version = gl.getVersion()
+ majorVersion = version[0]
+ assert majorVersion >= 2
+ gl.glEnable(gl.GL_VERTEX_PROGRAM_POINT_SIZE) # OpenGL 2
+ gl.glEnable(gl.GL_POINT_SPRITE) # OpenGL 2
+ if majorVersion >= 3: # OpenGL 3
+ gl.glEnable(gl.GL_PROGRAM_POINT_SIZE)
+
+ def render(self, context):
+ """Perform rendering
+
+ :param RenderContext context:
+ """
+ if self.marker is None:
+ return
+
+ program = self._getProgram(self.marker)
+ program.use()
+
+ matrix = numpy.dot(context.matrix, mat4Translate(*self.offset)).astype(
+ numpy.float32
+ )
+ gl.glUniformMatrix4fv(program.uniforms["matrix"], 1, gl.GL_TRUE, matrix)
+
+ if self.marker == PIXEL:
+ size = 1
+ elif self.marker == POINT:
+ size = math.ceil(0.5 * self.size) + 1 # Mimic Matplotlib point
+ else:
+ size = self.size
+ size = size / 72.0 * context.dpi
+
+ if self.marker in (
+ PLUS,
+ H_LINE,
+ V_LINE,
+ TICK_LEFT,
+ TICK_RIGHT,
+ TICK_UP,
+ TICK_DOWN,
+ ):
+ # Convert to nearest odd number
+ size = size // 2 * 2 + 1.0
+
+ gl.glUniform1f(program.uniforms["size"], size)
+ # gl.glPointSize(self.size)
+
+ cAttrib = program.attributes["color"]
+ if self.useColorVboData and self.colorVboData is not None:
+ gl.glEnableVertexAttribArray(cAttrib)
+ self.colorVboData.setVertexAttrib(cAttrib)
+ else:
+ gl.glDisableVertexAttribArray(cAttrib)
+ gl.glVertexAttrib4f(cAttrib, *self.color)
+
+ xPosAttrib = program.attributes["xPos"]
+ gl.glEnableVertexAttribArray(xPosAttrib)
+ if isinstance(self.xVboData, VertexBufferAttrib):
+ self.xVboData.setVertexAttrib(xPosAttrib)
+ else:
+ gl.glVertexAttribPointer(
+ xPosAttrib, 1, gl.GL_FLOAT, False, 0, self.xVboData
+ )
+
+ yPosAttrib = program.attributes["yPos"]
+ gl.glEnableVertexAttribArray(yPosAttrib)
+ if isinstance(self.yVboData, VertexBufferAttrib):
+ self.yVboData.setVertexAttrib(yPosAttrib)
+ else:
+ gl.glVertexAttribPointer(
+ yPosAttrib, 1, gl.GL_FLOAT, False, 0, self.yVboData
+ )
+
+ gl.glDrawArrays(gl.GL_POINTS, 0, self.xVboData.size)
+
+
+# error bars ##################################################################
+
+
+class _ErrorBars(object):
+ """Display errors bars.
+
+ This is using its own VBO as opposed to fill/points/lines.
+ There is no picking on error bars.
+
+ It uses 2 vertices per error bars and uses :class:`GLLines2D` to
+ render error bars and :class:`_Points2D` to render the ends.
+
+ :param numpy.ndarray xData: X coordinates of the data.
+ :param numpy.ndarray yData: Y coordinates of the data.
+ :param xError: The absolute error on the X axis.
+ :type xError: A float, or a numpy.ndarray of float32.
+ If it is an array, it can either be a 1D array of
+ same length as the data or a 2D array with 2 rows
+ of same length as the data: row 0 for negative errors,
+ row 1 for positive errors.
+ :param yError: The absolute error on the Y axis.
+ :type yError: A float, or a numpy.ndarray of float32. See xError.
+ :param float xMin: The min X value already computed by GLPlotCurve2D.
+ :param float yMin: The min Y value already computed by GLPlotCurve2D.
+ :param List[float] color: RGBA color as 4 float in [0, 1]
+ :param List[float] offset: Translation of coordinates (ox, oy)
+ """
+
+ def __init__(
+ self,
+ xData,
+ yData,
+ xError,
+ yError,
+ xMin,
+ yMin,
+ color=(0.0, 0.0, 0.0, 1.0),
+ offset=(0.0, 0.0),
+ ):
+ self._attribs = None
+ self._xMin, self._yMin = xMin, yMin
+ self.offset = offset
+
+ if xError is not None or yError is not None:
+ self._xData = numpy.array(xData, order="C", dtype=numpy.float32, copy=False)
+ self._yData = numpy.array(yData, order="C", dtype=numpy.float32, copy=False)
+
+ # This also works if xError, yError is a float/int
+ self._xError = numpy.array(
+ xError, order="C", dtype=numpy.float32, copy=False
+ )
+ self._yError = numpy.array(
+ yError, order="C", dtype=numpy.float32, copy=False
+ )
+ else:
+ self._xData, self._yData = None, None
+ self._xError, self._yError = None, None
+
+ self._lines = GLLines2D(
+ None, None, color=color, drawMode=gl.GL_LINES, offset=offset
+ )
+ self._xErrPoints = Points2D(
+ None, None, color=color, marker=V_LINE, offset=offset
+ )
+ self._yErrPoints = Points2D(
+ None, None, color=color, marker=H_LINE, offset=offset
+ )
+
+ def _buildVertices(self):
+ """Generates error bars vertices"""
+ nbLinesPerDataPts = (0 if self._xError is None else 2) + (
+ 0 if self._yError is None else 2
+ )
+
+ nbDataPts = len(self._xData)
+
+ # interleave coord+error, coord-error.
+ # xError vertices first if any, then yError vertices if any.
+ xCoords = numpy.empty(nbDataPts * nbLinesPerDataPts * 2, dtype=numpy.float32)
+ yCoords = numpy.empty(nbDataPts * nbLinesPerDataPts * 2, dtype=numpy.float32)
+
+ if self._xError is not None: # errors on the X axis
+ if len(self._xError.shape) == 2:
+ xErrorMinus, xErrorPlus = self._xError[0], self._xError[1]
+ else:
+ # numpy arrays of len 1 or len(xData)
+ xErrorMinus, xErrorPlus = self._xError, self._xError
+
+ # Interleave vertices for xError
+ endXError = 4 * nbDataPts
+ with numpy.errstate(invalid="ignore"):
+ xCoords[0 : endXError - 3 : 4] = self._xData + xErrorPlus
+ xCoords[1 : endXError - 2 : 4] = self._xData
+ xCoords[2 : endXError - 1 : 4] = self._xData
+ with numpy.errstate(invalid="ignore"):
+ xCoords[3:endXError:4] = self._xData - xErrorMinus
+
+ yCoords[0 : endXError - 3 : 4] = self._yData
+ yCoords[1 : endXError - 2 : 4] = self._yData
+ yCoords[2 : endXError - 1 : 4] = self._yData
+ yCoords[3:endXError:4] = self._yData
+
+ else:
+ endXError = 0
+
+ if self._yError is not None: # errors on the Y axis
+ if len(self._yError.shape) == 2:
+ yErrorMinus, yErrorPlus = self._yError[0], self._yError[1]
+ else:
+ # numpy arrays of len 1 or len(yData)
+ yErrorMinus, yErrorPlus = self._yError, self._yError
+
+ # Interleave vertices for yError
+ xCoords[endXError::4] = self._xData
+ xCoords[endXError + 1 :: 4] = self._xData
+ xCoords[endXError + 2 :: 4] = self._xData
+ xCoords[endXError + 3 :: 4] = self._xData
+
+ with numpy.errstate(invalid="ignore"):
+ yCoords[endXError::4] = self._yData + yErrorPlus
+ yCoords[endXError + 1 :: 4] = self._yData
+ yCoords[endXError + 2 :: 4] = self._yData
+ with numpy.errstate(invalid="ignore"):
+ yCoords[endXError + 3 :: 4] = self._yData - yErrorMinus
+
+ return xCoords, yCoords
+
+ def prepare(self):
+ """Rendering preparation: build indices and bounding box vertices"""
+ if self._xData is None:
+ return
+
+ if self._attribs is None:
+ xCoords, yCoords = self._buildVertices()
+
+ xAttrib, yAttrib = vertexBuffer((xCoords, yCoords))
+ self._attribs = xAttrib, yAttrib
+
+ self._lines.xVboData = xAttrib
+ self._lines.yVboData = yAttrib
+
+ # Set xError points using the same VBO as lines
+ self._xErrPoints.xVboData = xAttrib.copy()
+ self._xErrPoints.xVboData.size //= 2
+ self._xErrPoints.yVboData = yAttrib.copy()
+ self._xErrPoints.yVboData.size //= 2
+
+ # Set yError points using the same VBO as lines
+ self._yErrPoints.xVboData = xAttrib.copy()
+ self._yErrPoints.xVboData.size //= 2
+ self._yErrPoints.xVboData.offset += xAttrib.itemsize * xAttrib.size // 2
+ self._yErrPoints.yVboData = yAttrib.copy()
+ self._yErrPoints.yVboData.size //= 2
+ self._yErrPoints.yVboData.offset += yAttrib.itemsize * yAttrib.size // 2
+
+ def render(self, context):
+ """Perform rendering
+
+ :param RenderContext context:
+ """
+ self.prepare()
+
+ if self._attribs is not None:
+ self._lines.render(context)
+ self._xErrPoints.render(context)
+ self._yErrPoints.render(context)
+
+ def discard(self):
+ """Release VBOs"""
+ if self.isInitialized():
+ self._lines.xVboData, self._lines.yVboData = None, None
+ self._xErrPoints.xVboData, self._xErrPoints.yVboData = None, None
+ self._yErrPoints.xVboData, self._yErrPoints.yVboData = None, None
+ self._attribs[0].vbo.discard()
+ self._attribs = None
+
+ def isInitialized(self):
+ return self._attribs is not None
+
+
+# curves ######################################################################
+
+
+def _proxyProperty(*componentsAttributes):
+ """Create a property to access an attribute of attribute(s).
+ Useful for composition.
+ Supports multiple components this way:
+ getter returns the first found, setter sets all
+ """
+
+ def getter(self):
+ for compName, attrName in componentsAttributes:
+ try:
+ component = getattr(self, compName)
+ except AttributeError:
+ pass
+ else:
+ return getattr(component, attrName)
+
+ def setter(self, value):
+ for compName, attrName in componentsAttributes:
+ component = getattr(self, compName)
+ setattr(component, attrName, value)
+
+ return property(getter, setter)
+
+
+class GLPlotCurve2D(GLPlotItem):
+ def __init__(
+ self,
+ xData,
+ yData,
+ colorData=None,
+ xError=None,
+ yError=None,
+ lineColor=(0.0, 0.0, 0.0, 1.0),
+ lineGapColor=None,
+ lineWidth=1,
+ lineDashOffset=0.0,
+ lineDashPattern=(),
+ marker=SQUARE,
+ markerColor=(0.0, 0.0, 0.0, 1.0),
+ markerSize=7,
+ fillColor=None,
+ baseline=None,
+ isYLog=False,
+ ):
+ super().__init__()
+ self._ratio = None
+ self.colorData = colorData
+
+ # Compute x bounds
+ if xError is None:
+ self.xMin, self.xMax = min_max(xData, min_positive=False)
+ else:
+ # Takes the error into account
+ if hasattr(xError, "shape") and len(xError.shape) == 2:
+ xErrorMinus, xErrorPlus = xError[0], xError[1]
+ else:
+ xErrorMinus, xErrorPlus = xError, xError
+ self.xMin = numpy.nanmin(xData - xErrorMinus)
+ self.xMax = numpy.nanmax(xData + xErrorPlus)
+
+ # Compute y bounds
+ if yError is None:
+ self.yMin, self.yMax = min_max(yData, min_positive=False)
+ else:
+ # Takes the error into account
+ if hasattr(yError, "shape") and len(yError.shape) == 2:
+ yErrorMinus, yErrorPlus = yError[0], yError[1]
+ else:
+ yErrorMinus, yErrorPlus = yError, yError
+ self.yMin = numpy.nanmin(yData - yErrorMinus)
+ self.yMax = numpy.nanmax(yData + yErrorPlus)
+
+ # Handle data offset
+ if xData.itemsize > 4 or yData.itemsize > 4: # Use normalization
+ # offset data, do not offset error as it is relative
+ self.offset = self.xMin, self.yMin
+ with numpy.errstate(invalid="ignore"):
+ self.xData = (xData - self.offset[0]).astype(numpy.float32)
+ self.yData = (yData - self.offset[1]).astype(numpy.float32)
+
+ else: # float32
+ self.offset = 0.0, 0.0
+ self.xData = xData
+ self.yData = yData
+ if fillColor is not None:
+
+ def deduce_baseline(baseline):
+ if baseline is None:
+ _baseline = 0
+ else:
+ _baseline = baseline
+ if not isinstance(_baseline, numpy.ndarray):
+ _baseline = numpy.repeat(_baseline, len(self.xData))
+ if isYLog is True:
+ with numpy.errstate(divide="ignore", invalid="ignore"):
+ log_val = numpy.log10(_baseline)
+ _baseline = numpy.where(_baseline > 0.0, log_val, -38)
+ return _baseline
+
+ _baseline = deduce_baseline(baseline)
+
+ # Use different baseline depending of Y log scale
+ self.fill = _Fill2D(
+ self.xData,
+ self.yData,
+ baseline=_baseline,
+ color=fillColor,
+ offset=self.offset,
+ )
+ else:
+ self.fill = None
+
+ self._errorBars = _ErrorBars(
+ self.xData,
+ self.yData,
+ xError,
+ yError,
+ self.xMin,
+ self.yMin,
+ offset=self.offset,
+ )
+
+ self.lines = GLLines2D()
+ self.lines.color = lineColor
+ self.lines.gapColor = lineGapColor
+ self.lines.width = lineWidth
+ self.lines.dashOffset = lineDashOffset
+ self.lines.dashPattern = lineDashPattern
+ self.lines.offset = self.offset
+
+ self.points = Points2D()
+ self.points.marker = marker
+ self.points.color = markerColor
+ self.points.size = markerSize
+ self.points.offset = self.offset
+
+ xVboData = _proxyProperty(("lines", "xVboData"), ("points", "xVboData"))
+
+ yVboData = _proxyProperty(("lines", "yVboData"), ("points", "yVboData"))
+
+ colorVboData = _proxyProperty(("lines", "colorVboData"), ("points", "colorVboData"))
+
+ useColorVboData = _proxyProperty(
+ ("lines", "useColorVboData"), ("points", "useColorVboData")
+ )
+
+ distVboData = _proxyProperty(("lines", "distVboData"))
+
+ lineColor = _proxyProperty(("lines", "color"))
+
+ lineGapColor = _proxyProperty(("lines", "gapColor"))
+
+ lineWidth = _proxyProperty(("lines", "width"))
+
+ lineDashOffset = _proxyProperty(("lines", "dashOffset"))
+
+ lineDashPattern = _proxyProperty(("lines", "dashPattern"))
+
+ marker = _proxyProperty(("points", "marker"))
+
+ markerColor = _proxyProperty(("points", "color"))
+
+ markerSize = _proxyProperty(("points", "size"))
+
+ @classmethod
+ def init(cls):
+ """OpenGL context initialization"""
+ GLLines2D.init()
+ Points2D.init()
+
+ def prepare(self):
+ """Rendering preparation: build indices and bounding box vertices"""
+ if self.xVboData is None:
+ xAttrib, yAttrib, cAttrib, dAttrib = None, None, None, None
+ if self.lineDashPattern:
+ dists = distancesFromArrays(self.xData, self.yData, self._ratio)
+ if self.colorData is None:
+ xAttrib, yAttrib, dAttrib = vertexBuffer(
+ (self.xData, self.yData, dists)
+ )
+ else:
+ xAttrib, yAttrib, cAttrib, dAttrib = vertexBuffer(
+ (self.xData, self.yData, self.colorData, dists)
+ )
+ elif self.colorData is None:
+ xAttrib, yAttrib = vertexBuffer((self.xData, self.yData))
+ else:
+ xAttrib, yAttrib, cAttrib = vertexBuffer(
+ (self.xData, self.yData, self.colorData)
+ )
+
+ self.xVboData = xAttrib
+ self.yVboData = yAttrib
+ self.distVboData = dAttrib
+
+ if cAttrib is not None and self.colorData.dtype.kind == "u":
+ cAttrib.normalization = True # Normalize uint to [0, 1]
+ self.colorVboData = cAttrib
+ self.useColorVboData = cAttrib is not None
+
+ def render(self, context):
+ """Perform rendering
+
+ :param RenderContext context: Rendering information
+ """
+ if self.lineDashPattern:
+ visibleRanges = context.plotFrame.transformedDataRanges
+ xLimits = visibleRanges.x
+ yLimits = visibleRanges.y if self.yaxis == "left" else visibleRanges.y2
+ width, height = context.plotFrame.plotSize
+ ratio = (height * (xLimits[1] - xLimits[0])) / (
+ width * (yLimits[1] - yLimits[0])
+ )
+ if (
+ self._ratio is None or abs(1.0 - ratio / self._ratio) > 0.05
+ ): # Tolerate 5% difference
+ # Rebuild curve buffers to update distances
+ self._ratio = ratio
+ self.discard()
+
+ self.prepare()
+ if self.fill is not None:
+ self.fill.render(context)
+ self._errorBars.render(context)
+ self.lines.render(context)
+ self.points.render(context)
+
+ def discard(self):
+ """Release VBOs"""
+ if self.xVboData is not None:
+ self.xVboData.vbo.discard()
+
+ self.xVboData = None
+ self.yVboData = None
+ self.colorVboData = None
+ self.distVboData = None
+
+ self._errorBars.discard()
+ if self.fill is not None:
+ self.fill.discard()
+
+ def isInitialized(self):
+ return (
+ self.xVboData is not None
+ or self._errorBars.isInitialized()
+ or (self.fill is not None and self.fill.isInitialized())
+ )
+
+ def pick(self, xPickMin, yPickMin, xPickMax, yPickMax):
+ """Perform picking on the curve according to its rendering.
+
+ The picking area is [xPickMin, xPickMax], [yPickMin, yPickMax].
+
+ In case a segment between 2 points with indices i, i+1 is picked,
+ only its lower index end point (i.e., i) is added to the result.
+ In case an end point with index i is picked it is added to the result,
+ and the segment [i-1, i] is not tested for picking.
+
+ :return: The indices of the picked data
+ :rtype: Union[List[int],None]
+ """
+ if (
+ (self.marker is None and self.lineDashPattern is None)
+ or self.xMin > xPickMax
+ or xPickMin > self.xMax
+ or self.yMin > yPickMax
+ or yPickMin > self.yMax
+ ):
+ return None
+
+ # offset picking bounds
+ xPickMin = xPickMin - self.offset[0]
+ xPickMax = xPickMax - self.offset[0]
+ yPickMin = yPickMin - self.offset[1]
+ yPickMax = yPickMax - self.offset[1]
+
+ if self.lineDashPattern is not None:
+ # Using Cohen-Sutherland algorithm for line clipping
+ with numpy.errstate(invalid="ignore"): # Ignore NaN comparison warnings
+ codes = (
+ ((self.yData > yPickMax) << 3)
+ | ((self.yData < yPickMin) << 2)
+ | ((self.xData > xPickMax) << 1)
+ | (self.xData < xPickMin)
+ )
+
+ notNaN = numpy.logical_not(
+ numpy.logical_or(numpy.isnan(self.xData), numpy.isnan(self.yData))
+ )
+
+ # Add all points that are inside the picking area
+ indices = numpy.nonzero(numpy.logical_and(codes == 0, notNaN))[0].tolist()
+
+ # Segment that might cross the area with no end point inside it
+ segToTestIdx = numpy.nonzero(
+ (codes[:-1] != 0) & (codes[1:] != 0) & ((codes[:-1] & codes[1:]) == 0)
+ )[0]
+
+ TOP, BOTTOM, RIGHT, LEFT = (1 << 3), (1 << 2), (1 << 1), (1 << 0)
+
+ for index in segToTestIdx:
+ if index not in indices:
+ x0, y0 = self.xData[index], self.yData[index]
+ x1, y1 = self.xData[index + 1], self.yData[index + 1]
+ code1 = codes[index + 1]
+
+ # check for crossing with horizontal bounds
+ # y0 == y1 is a never event:
+ # => pt0 and pt1 in same vertical area are not in segToTest
+ if code1 & TOP:
+ x = x0 + (x1 - x0) * (yPickMax - y0) / (y1 - y0)
+ elif code1 & BOTTOM:
+ x = x0 + (x1 - x0) * (yPickMin - y0) / (y1 - y0)
+ else:
+ x = None # No horizontal bounds intersection test
+
+ if x is not None and xPickMin <= x <= xPickMax:
+ # Intersection
+ indices.append(index)
+
+ else:
+ # check for crossing with vertical bounds
+ # x0 == x1 is a never event (see remark for y)
+ if code1 & RIGHT:
+ y = y0 + (y1 - y0) * (xPickMax - x0) / (x1 - x0)
+ elif code1 & LEFT:
+ y = y0 + (y1 - y0) * (xPickMin - x0) / (x1 - x0)
+ else:
+ y = None # No vertical bounds intersection test
+
+ if y is not None and yPickMin <= y <= yPickMax:
+ # Intersection
+ indices.append(index)
+
+ indices.sort()
+
+ else:
+ with numpy.errstate(invalid="ignore"): # Ignore NaN comparison warnings
+ indices = numpy.nonzero(
+ (self.xData >= xPickMin)
+ & (self.xData <= xPickMax)
+ & (self.yData >= yPickMin)
+ & (self.yData <= yPickMax)
+ )[0].tolist()
+
+ return tuple(indices) if len(indices) > 0 else None
diff --git a/src/silx/gui/plot/backends/glutils/GLPlotFrame.py b/src/silx/gui/plot/backends/glutils/GLPlotFrame.py
new file mode 100644
index 0000000..42cfa50
--- /dev/null
+++ b/src/silx/gui/plot/backends/glutils/GLPlotFrame.py
@@ -0,0 +1,1399 @@
+# /*##########################################################################
+#
+# Copyright (c) 2014-2023 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ############################################################################*/
+"""
+This modules provides the rendering of plot titles, axes and grid.
+"""
+
+from __future__ import annotations
+
+__authors__ = ["T. Vincent"]
+__license__ = "MIT"
+__date__ = "03/04/2017"
+
+
+# TODO
+# keep aspect ratio managed here?
+# smarter dirty flag handling?
+
+import datetime as dt
+import math
+import weakref
+import logging
+import numbers
+from typing import Optional, Union
+from collections import namedtuple
+
+import numpy
+
+from .... import qt
+from ...._glutils import gl, Program
+from ....utils.matplotlib import DefaultTickFormatter
+from ..._utils import checkAxisLimits, FLOAT32_MINPOS
+from .GLSupport import mat4Ortho
+from .GLText import Text2D, CENTER, BOTTOM, TOP, LEFT, RIGHT, ROTATE_270
+from ..._utils.ticklayout import niceNumbersAdaptative, niceNumbersForLog10
+from ..._utils.dtime_ticklayout import (
+ DtUnit,
+ bestUnit,
+ calcTicksAdaptive,
+ formatDatetimes,
+)
+from ..._utils.dtime_ticklayout import timestamp
+
+_logger = logging.getLogger(__name__)
+
+
+# PlotAxis ####################################################################
+
+
+class PlotAxis(object):
+ """Represents a 1D axis of the plot.
+ This class is intended to be used with :class:`GLPlotFrame`.
+ """
+
+ def __init__(
+ self,
+ plotFrame,
+ tickLength=(0.0, 0.0),
+ foregroundColor=(0.0, 0.0, 0.0, 1.0),
+ labelAlign=CENTER,
+ labelVAlign=CENTER,
+ titleAlign=CENTER,
+ titleVAlign=CENTER,
+ orderOffsetAlign=CENTER,
+ orderOffsetVAlign=CENTER,
+ titleRotate=0,
+ titleOffset=(0.0, 0.0),
+ font: qt.QFont | None = None,
+ ):
+ self._tickFormatter = DefaultTickFormatter()
+ self._ticks = None
+ self._orderAndOffsetText = ""
+
+ self._plotFrameRef = weakref.ref(plotFrame)
+
+ self._isDateTime = False
+ self._timeZone = None
+ self._isLog = False
+ self._dataRange = 1.0, 100.0
+ self._displayCoords = (0.0, 0.0), (1.0, 0.0)
+ self._title = ""
+
+ self._tickLength = tickLength
+ self._foregroundColor = foregroundColor
+ self._labelAlign = labelAlign
+ self._labelVAlign = labelVAlign
+ self._orderOffetAnchor = (1.0, 0.0)
+ self._orderOffsetAlign = orderOffsetAlign
+ self._orderOffsetVAlign = orderOffsetVAlign
+ self._titleAlign = titleAlign
+ self._titleVAlign = titleVAlign
+ self._titleRotate = titleRotate
+ self._titleOffset = titleOffset
+ self._font = font
+
+ @property
+ def dataRange(self):
+ """The range of the data represented on the axis as a tuple
+ of 2 floats: (min, max)."""
+ return self._dataRange
+
+ @property
+ def font(self) -> qt.QFont:
+ if self._font is None:
+ return qt.QApplication.instance().font()
+ return self._font
+
+ @dataRange.setter
+ def dataRange(self, dataRange):
+ assert len(dataRange) == 2
+ assert dataRange[0] <= dataRange[1]
+ dataRange = float(dataRange[0]), float(dataRange[1])
+
+ if dataRange != self._dataRange:
+ self._dataRange = dataRange
+ self._dirtyTicks()
+
+ @property
+ def isLog(self):
+ """Whether the axis is using a log10 scale or not as a bool."""
+ return self._isLog
+
+ @isLog.setter
+ def isLog(self, isLog):
+ isLog = bool(isLog)
+ if isLog != self._isLog:
+ self._isLog = isLog
+ self._dirtyTicks()
+
+ @property
+ def timeZone(self):
+ """Returnss datetime.tzinfo that is used if this axis plots date times."""
+ return self._timeZone
+
+ @timeZone.setter
+ def timeZone(self, tz):
+ """Sets dateetime.tzinfo that is used if this axis plots date times."""
+ self._timeZone = tz
+ self._dirtyTicks()
+
+ @property
+ def isTimeSeries(self):
+ """Whether the axis is showing floats as datetime objects"""
+ return self._isDateTime
+
+ @isTimeSeries.setter
+ def isTimeSeries(self, isTimeSeries):
+ isTimeSeries = bool(isTimeSeries)
+ if isTimeSeries != self._isDateTime:
+ self._isDateTime = isTimeSeries
+ self._dirtyTicks()
+
+ @property
+ def displayCoords(self):
+ """The coordinates of the start and end points of the axis
+ in display space (i.e., in pixels) as a tuple of 2 tuples of
+ 2 floats: ((x0, y0), (x1, y1)).
+ """
+ return self._displayCoords
+
+ @displayCoords.setter
+ def displayCoords(self, displayCoords):
+ assert len(displayCoords) == 2
+ assert len(displayCoords[0]) == 2
+ assert len(displayCoords[1]) == 2
+ displayCoords = tuple(displayCoords[0]), tuple(displayCoords[1])
+ if displayCoords != self._displayCoords:
+ self._displayCoords = displayCoords
+ self._dirtyTicks()
+
+ @property
+ def devicePixelRatio(self):
+ """Returns the ratio between qt pixels and device pixels."""
+ plotFrame = self._plotFrameRef()
+ return plotFrame.devicePixelRatio if plotFrame is not None else 1.0
+
+ @property
+ def dotsPerInch(self):
+ """Returns the screen DPI"""
+ plotFrame = self._plotFrameRef()
+ return plotFrame.dotsPerInch if plotFrame is not None else 92
+
+ @property
+ def title(self):
+ """The text label associated with this axis as a str in latin-1."""
+ return self._title
+
+ @title.setter
+ def title(self, title):
+ if title != self._title:
+ self._title = title
+ self._dirtyPlotFrame()
+
+ @property
+ def orderOffetAnchor(self) -> tuple[float, float]:
+ """Anchor position for the tick order&offset text"""
+ return self._orderOffetAnchor
+
+ @orderOffetAnchor.setter
+ def orderOffetAnchor(self, position: tuple[float, float]):
+ if position != self._orderOffetAnchor:
+ self._orderOffetAnchor = position
+ self._dirtyTicks()
+
+ @property
+ def titleOffset(self):
+ """Title offset in pixels (x: int, y: int)"""
+ return self._titleOffset
+
+ @titleOffset.setter
+ def titleOffset(self, offset):
+ if offset != self._titleOffset:
+ self._titleOffset = offset
+ self._dirtyTicks()
+
+ @property
+ def foregroundColor(self):
+ """Color used for frame and labels"""
+ return self._foregroundColor
+
+ @foregroundColor.setter
+ def foregroundColor(self, color):
+ """Color used for frame and labels"""
+ assert len(color) == 4, "foregroundColor must have length 4, got {}".format(
+ len(self._foregroundColor)
+ )
+ if self._foregroundColor != color:
+ self._foregroundColor = color
+ self._dirtyTicks()
+
+ @property
+ def ticks(self):
+ """Ticks as tuples: ((x, y) in display, dataPos, textLabel)."""
+ if self._ticks is None:
+ self._ticks = tuple(self._ticksGenerator())
+ return self._ticks
+
+ def getVerticesAndLabels(self):
+ """Create the list of vertices for axis and associated text labels.
+
+ :returns: A tuple: List of 2D line vertices, List of Text2D labels.
+ """
+ vertices = list(self.displayCoords) # Add start and end points
+ labels = []
+
+ xTickLength, yTickLength = self._tickLength
+ xTickLength *= self.devicePixelRatio
+ yTickLength *= self.devicePixelRatio
+ for (xPixel, yPixel), dataPos, text in self.ticks:
+ if text is None:
+ tickScale = 0.5
+ else:
+ tickScale = 1.0
+
+ label = Text2D(
+ text=text,
+ font=self.font,
+ color=self._foregroundColor,
+ x=xPixel - xTickLength,
+ y=yPixel - yTickLength,
+ align=self._labelAlign,
+ valign=self._labelVAlign,
+ devicePixelRatio=self.devicePixelRatio,
+ )
+ labels.append(label)
+
+ vertices.append((xPixel, yPixel))
+ vertices.append(
+ (xPixel + tickScale * xTickLength, yPixel + tickScale * yTickLength)
+ )
+
+ (x0, y0), (x1, y1) = self.displayCoords
+ xAxisCenter = 0.5 * (x0 + x1)
+ yAxisCenter = 0.5 * (y0 + y1)
+
+ xOffset, yOffset = self.titleOffset
+
+ # Adaptative title positioning:
+ # tickNorm = math.sqrt(xTickLength ** 2 + yTickLength ** 2)
+ # xOffset = -tickLabelsSize[0] * xTickLength / tickNorm
+ # xOffset -= 3 * xTickLength
+ # yOffset = -tickLabelsSize[1] * yTickLength / tickNorm
+ # yOffset -= 3 * yTickLength
+
+ axisTitle = Text2D(
+ text=self.title,
+ font=self.font,
+ color=self._foregroundColor,
+ x=xAxisCenter + xOffset,
+ y=yAxisCenter + yOffset,
+ align=self._titleAlign,
+ valign=self._titleVAlign,
+ rotate=self._titleRotate,
+ devicePixelRatio=self.devicePixelRatio,
+ )
+ labels.append(axisTitle)
+
+ if self._orderAndOffsetText:
+ xOrderOffset, yOrderOffet = self.orderOffetAnchor
+ labels.append(
+ Text2D(
+ text=self._orderAndOffsetText,
+ font=self.font,
+ color=self._foregroundColor,
+ x=xOrderOffset,
+ y=yOrderOffet,
+ align=self._orderOffsetAlign,
+ valign=self._orderOffsetVAlign,
+ devicePixelRatio=self.devicePixelRatio,
+ )
+ )
+ return vertices, labels
+
+ def _dirtyPlotFrame(self):
+ """Dirty parent GLPlotFrame"""
+ plotFrame = self._plotFrameRef()
+ if plotFrame is not None:
+ plotFrame._dirty()
+
+ def _dirtyTicks(self):
+ """Mark ticks as dirty and notify listener (i.e., background)."""
+ self._ticks = None
+ self._dirtyPlotFrame()
+
+ @staticmethod
+ def _frange(start, stop, step):
+ """range for float (including stop)."""
+ while start <= stop:
+ yield start
+ start += step
+
+ def _ticksGenerator(self):
+ """Generator of ticks as tuples:
+ ((x, y) in display, dataPos, textLabel).
+ """
+ self._orderAndOffsetText = ""
+
+ dataMin, dataMax = self.dataRange
+ if self.isLog and dataMin <= 0.0:
+ _logger.warning("Getting ticks while isLog=True and dataRange[0]<=0.")
+ dataMin = 1.0
+ if dataMax < dataMin:
+ dataMax = 1.0
+
+ if dataMin != dataMax: # data range is not null
+ (x0, y0), (x1, y1) = self.displayCoords
+
+ if self.isLog:
+ if self.isTimeSeries:
+ _logger.warning("Time series not implemented for log-scale")
+
+ logMin, logMax = math.log10(dataMin), math.log10(dataMax)
+ tickMin, tickMax, step, _ = niceNumbersForLog10(logMin, logMax)
+
+ xScale = (x1 - x0) / (logMax - logMin)
+ yScale = (y1 - y0) / (logMax - logMin)
+
+ for logPos in self._frange(tickMin, tickMax, step):
+ if logMin <= logPos <= logMax:
+ dataPos = 10**logPos
+ xPixel = x0 + (logPos - logMin) * xScale
+ yPixel = y0 + (logPos - logMin) * yScale
+ text = "1e%+03d" % logPos
+ yield ((xPixel, yPixel), dataPos, text)
+
+ if step == 1:
+ ticks = list(self._frange(tickMin, tickMax, step))[:-1]
+ for logPos in ticks:
+ dataOrigPos = 10**logPos
+ for index in range(2, 10):
+ dataPos = dataOrigPos * index
+ if dataMin <= dataPos <= dataMax:
+ logSubPos = math.log10(dataPos)
+ xPixel = x0 + (logSubPos - logMin) * xScale
+ yPixel = y0 + (logSubPos - logMin) * yScale
+ yield ((xPixel, yPixel), dataPos, None)
+
+ else:
+ xScale = (x1 - x0) / (dataMax - dataMin)
+ yScale = (y1 - y0) / (dataMax - dataMin)
+
+ nbPixels = (
+ math.sqrt(pow(x1 - x0, 2) + pow(y1 - y0, 2)) / self.devicePixelRatio
+ )
+
+ # Density of 1.3 label per 92 pixels
+ # i.e., 1.3 label per inch on a 92 dpi screen
+ tickDensity = 1.3 * self.devicePixelRatio / self.dotsPerInch
+
+ if not self.isTimeSeries:
+ tickMin, tickMax, step, _ = niceNumbersAdaptative(
+ dataMin, dataMax, nbPixels, tickDensity
+ )
+
+ visibleTickPositions = [
+ pos
+ for pos in self._frange(tickMin, tickMax, step)
+ if dataMin <= pos <= dataMax
+ ]
+ self._tickFormatter.axis.set_view_interval(dataMin, dataMax)
+ self._tickFormatter.axis.set_data_interval(dataMin, dataMax)
+ texts = self._tickFormatter.format_ticks(visibleTickPositions)
+ self._orderAndOffsetText = self._tickFormatter.get_offset()
+
+ for dataPos, text in zip(visibleTickPositions, texts):
+ xPixel = x0 + (dataPos - dataMin) * xScale
+ yPixel = y0 + (dataPos - dataMin) * yScale
+ yield ((xPixel, yPixel), dataPos, text)
+
+ else:
+ # Time series
+ try:
+ dtMin = dt.datetime.fromtimestamp(dataMin, tz=self.timeZone)
+ dtMax = dt.datetime.fromtimestamp(dataMax, tz=self.timeZone)
+ except ValueError:
+ _logger.warning("Data range cannot be displayed with time axis")
+ return # Range is out of bound of the datetime
+
+ if bestUnit(
+ (dtMax - dtMin).total_seconds() == DtUnit.MICRO_SECONDS
+ ):
+ # Special case for micro seconds: Reduce tick density
+ tickDensity = 1.0 * self.devicePixelRatio / self.dotsPerInch
+
+ tickDateTimes, spacing, unit = calcTicksAdaptive(
+ dtMin, dtMax, nbPixels, tickDensity
+ )
+ visibleDatetimes = tuple(
+ dt for dt in tickDateTimes if dtMin <= dt <= dtMax
+ )
+ ticks = formatDatetimes(visibleDatetimes, spacing, unit)
+
+ for tickDateTime, text in ticks.items():
+ dataPos = timestamp(tickDateTime)
+ xPixel = x0 + (dataPos - dataMin) * xScale
+ yPixel = y0 + (dataPos - dataMin) * yScale
+ yield ((xPixel, yPixel), dataPos, text)
+
+
+# GLPlotFrame #################################################################
+
+
+class GLPlotFrame(object):
+ """Base class for rendering a 2D frame surrounded by axes."""
+
+ _TICK_LENGTH_IN_PIXELS = 5
+ _LINE_WIDTH = 1
+
+ _SHADERS = {
+ "vertex": """
+ attribute vec2 position;
+ uniform mat4 matrix;
+
+ void main(void) {
+ gl_Position = matrix * vec4(position, 0.0, 1.0);
+ }
+ """,
+ "fragment": """
+ uniform vec4 color;
+ uniform float tickFactor; /* = 1./tickLength or 0. for solid line */
+
+ void main(void) {
+ if (mod(tickFactor * (gl_FragCoord.x + gl_FragCoord.y), 2.) < 1.) {
+ gl_FragColor = color;
+ } else {
+ discard;
+ }
+ }
+ """,
+ }
+
+ _Margins = namedtuple("Margins", ("left", "right", "top", "bottom"))
+
+ # Margins used when plot frame is not displayed
+ _NoDisplayMargins = _Margins(0, 0, 0, 0)
+
+ def __init__(self, marginRatios, foregroundColor, gridColor, font: qt.QFont):
+ """
+ :param List[float] marginRatios:
+ The ratios of margins around plot area for axis and labels.
+ (left, top, right, bottom) as float in [0., 1.]
+ :param foregroundColor: color used for the frame and labels.
+ :type foregroundColor: tuple with RGBA values ranging from 0.0 to 1.0
+ :param gridColor: color used for grid lines.
+ :type gridColor: tuple RGBA with RGBA values ranging from 0.0 to 1.0
+ :param font: Font used by the axes label
+ """
+ self._renderResources = None
+
+ self.__marginRatios = marginRatios
+ self.__marginsCache = None
+
+ self._foregroundColor = foregroundColor
+ self._gridColor = gridColor
+
+ self.axes = [] # List of PlotAxis to be updated by subclasses
+
+ self._grid = False
+ self._size = 0.0, 0.0
+ self._title = ""
+ self._font: qt.QFont = font
+
+ self._devicePixelRatio = 1.0
+ self._dpi = 92
+
+ @property
+ def isDirty(self):
+ """True if it need to refresh graphic rendering, False otherwise."""
+ return self._renderResources is None
+
+ GRID_NONE = 0
+ GRID_MAIN_TICKS = 1
+ GRID_SUB_TICKS = 2
+ GRID_ALL_TICKS = GRID_MAIN_TICKS + GRID_SUB_TICKS
+
+ @property
+ def foregroundColor(self):
+ """Color used for frame and labels"""
+ return self._foregroundColor
+
+ @foregroundColor.setter
+ def foregroundColor(self, color):
+ """Color used for frame and labels"""
+ assert len(color) == 4, "foregroundColor must have length 4, got {}".format(
+ len(self._foregroundColor)
+ )
+ if self._foregroundColor != color:
+ self._foregroundColor = color
+ for axis in self.axes:
+ axis.foregroundColor = color
+ self._dirty()
+
+ @property
+ def gridColor(self):
+ """Color used for frame and labels"""
+ return self._gridColor
+
+ @gridColor.setter
+ def gridColor(self, color):
+ """Color used for frame and labels"""
+ assert len(color) == 4, "gridColor must have length 4, got {}".format(
+ len(self._gridColor)
+ )
+ if self._gridColor != color:
+ self._gridColor = color
+ self._dirty()
+
+ @property
+ def marginRatios(self):
+ """Plot margin ratios: (left, top, right, bottom) as 4 float in [0, 1]."""
+ return self.__marginRatios
+
+ @marginRatios.setter
+ def marginRatios(self, ratios):
+ ratios = tuple(float(v) for v in ratios)
+ assert len(ratios) == 4
+ for value in ratios:
+ assert 0.0 <= value <= 1.0
+ assert ratios[0] + ratios[2] < 1.0
+ assert ratios[1] + ratios[3] < 1.0
+
+ if self.__marginRatios != ratios:
+ self.__marginRatios = ratios
+ self.__marginsCache = None # Clear cached margins
+ self._dirty()
+
+ @property
+ def margins(self):
+ """Margins in pixels around the plot."""
+ if self.__marginsCache is None:
+ width, height = self.size
+ left, top, right, bottom = self.marginRatios
+ self.__marginsCache = self._Margins(
+ left=int(left * width),
+ right=int(right * width),
+ top=int(top * height),
+ bottom=int(bottom * height),
+ )
+ return self.__marginsCache
+
+ @property
+ def devicePixelRatio(self):
+ return self._devicePixelRatio
+
+ @devicePixelRatio.setter
+ def devicePixelRatio(self, ratio):
+ if ratio != self._devicePixelRatio:
+ self._devicePixelRatio = ratio
+ self._dirty()
+
+ @property
+ def dotsPerInch(self):
+ return self._dpi
+
+ @dotsPerInch.setter
+ def dotsPerInch(self, dpi):
+ if dpi != self._dpi:
+ self._dpi = dpi
+ self._dirty()
+
+ @property
+ def grid(self):
+ """Grid display mode:
+ - 0: No grid.
+ - 1: Grid on main ticks.
+ - 2: Grid on sub-ticks for log scale axes.
+ - 3: Grid on main and sub ticks."""
+ return self._grid
+
+ @grid.setter
+ def grid(self, grid):
+ assert grid in (
+ self.GRID_NONE,
+ self.GRID_MAIN_TICKS,
+ self.GRID_SUB_TICKS,
+ self.GRID_ALL_TICKS,
+ )
+ if grid != self._grid:
+ self._grid = grid
+ self._dirty()
+
+ @property
+ def size(self):
+ """Size in device pixels of the plot area including margins."""
+ return self._size
+
+ @size.setter
+ def size(self, size):
+ assert len(size) == 2
+ size = tuple(size)
+ if size != self._size:
+ self._size = size
+ self.__marginsCache = None # Clear cached margins
+ self._dirty()
+
+ @property
+ def plotOrigin(self):
+ """Plot area origin (left, top) in widget coordinates in pixels."""
+ return self.margins.left, self.margins.top
+
+ @property
+ def plotSize(self):
+ """Plot area size (width, height) in pixels."""
+ w, h = self.size
+ w -= self.margins.left + self.margins.right
+ h -= self.margins.top + self.margins.bottom
+ return w, h
+
+ @property
+ def title(self):
+ """Main title as a str in latin-1."""
+ return self._title
+
+ @title.setter
+ def title(self, title):
+ if title != self._title:
+ self._title = title
+ self._dirty()
+
+ # In-place update
+ # if self._renderResources is not None:
+ # self._renderResources[-1][-1].text = title
+
+ def _dirty(self):
+ # When Text2D require discard we need to handle it
+ self._renderResources = None
+
+ def _buildGridVertices(self):
+ if self._grid == self.GRID_NONE:
+ return []
+
+ elif self._grid == self.GRID_MAIN_TICKS:
+
+ def test(text):
+ return text is not None
+
+ elif self._grid == self.GRID_SUB_TICKS:
+
+ def test(text):
+ return text is None
+
+ elif self._grid == self.GRID_ALL_TICKS:
+
+ def test(_):
+ return True
+
+ else:
+ logging.warning("Wrong grid mode: %d" % self._grid)
+ return []
+
+ return self._buildGridVerticesWithTest(test)
+
+ def _buildGridVerticesWithTest(self, test):
+ """Override in subclass to generate grid vertices"""
+ return []
+
+ def _buildVerticesAndLabels(self):
+ # To fill with copy of axes lists
+ vertices = []
+ labels = []
+
+ for axis in self.axes:
+ axisVertices, axisLabels = axis.getVerticesAndLabels()
+ vertices += axisVertices
+ labels += axisLabels
+
+ vertices = numpy.array(vertices, dtype=numpy.float32)
+
+ # Add main title
+ xTitle = (self.size[0] + self.margins.left - self.margins.right) // 2
+ yTitle = self.margins.top - self._TICK_LENGTH_IN_PIXELS
+ labels.append(
+ Text2D(
+ text=self.title,
+ font=self._font,
+ color=self._foregroundColor,
+ x=xTitle,
+ y=yTitle,
+ align=CENTER,
+ valign=BOTTOM,
+ devicePixelRatio=self.devicePixelRatio,
+ )
+ )
+
+ # grid
+ gridVertices = numpy.array(self._buildGridVertices(), dtype=numpy.float32)
+
+ self._renderResources = (vertices, gridVertices, labels)
+
+ _program = Program(_SHADERS["vertex"], _SHADERS["fragment"], attrib0="position")
+
+ def render(self):
+ if self.margins == self._NoDisplayMargins:
+ return
+
+ if self._renderResources is None:
+ self._buildVerticesAndLabels()
+ vertices, gridVertices, labels = self._renderResources
+
+ width, height = self.size
+ matProj = mat4Ortho(0, width, height, 0, 1, -1)
+
+ gl.glViewport(0, 0, width, height)
+
+ prog = self._program
+ prog.use()
+
+ gl.glLineWidth(self._LINE_WIDTH)
+
+ gl.glUniformMatrix4fv(
+ prog.uniforms["matrix"], 1, gl.GL_TRUE, matProj.astype(numpy.float32)
+ )
+ gl.glUniform4f(prog.uniforms["color"], *self._foregroundColor)
+ gl.glUniform1f(prog.uniforms["tickFactor"], 0.0)
+
+ gl.glEnableVertexAttribArray(prog.attributes["position"])
+ gl.glVertexAttribPointer(
+ prog.attributes["position"], 2, gl.GL_FLOAT, gl.GL_FALSE, 0, vertices
+ )
+
+ gl.glDrawArrays(gl.GL_LINES, 0, len(vertices))
+
+ for label in labels:
+ label.render(matProj, self.dotsPerInch)
+
+ def renderGrid(self):
+ if self._grid == self.GRID_NONE:
+ return
+
+ if self._renderResources is None:
+ self._buildVerticesAndLabels()
+ vertices, gridVertices, labels = self._renderResources
+
+ width, height = self.size
+ matProj = mat4Ortho(0, width, height, 0, 1, -1)
+
+ gl.glViewport(0, 0, width, height)
+
+ prog = self._program
+ prog.use()
+
+ gl.glLineWidth(self._LINE_WIDTH)
+ gl.glUniformMatrix4fv(
+ prog.uniforms["matrix"], 1, gl.GL_TRUE, matProj.astype(numpy.float32)
+ )
+ gl.glUniform4f(prog.uniforms["color"], *self._gridColor)
+ gl.glUniform1f(prog.uniforms["tickFactor"], 0.0) # 1/2.) # 1/tickLen
+
+ gl.glEnableVertexAttribArray(prog.attributes["position"])
+ gl.glVertexAttribPointer(
+ prog.attributes["position"], 2, gl.GL_FLOAT, gl.GL_FALSE, 0, gridVertices
+ )
+
+ gl.glDrawArrays(gl.GL_LINES, 0, len(gridVertices))
+
+
+# GLPlotFrame2D ###############################################################
+
+
+class GLPlotFrame2D(GLPlotFrame):
+ def __init__(self, marginRatios, foregroundColor, gridColor, font: qt.QFont):
+ """
+ :param List[float] marginRatios:
+ The ratios of margins around plot area for axis and labels.
+ (left, top, right, bottom) as float in [0., 1.]
+ :param foregroundColor: color used for the frame and labels.
+ :type foregroundColor: tuple with RGBA values ranging from 0.0 to 1.0
+ :param gridColor: color used for grid lines.
+ :type gridColor: tuple RGBA with RGBA values ranging from 0.0 to 1.0
+ :param font: Font used by the axes label
+ """
+ super(GLPlotFrame2D, self).__init__(
+ marginRatios, foregroundColor, gridColor, font
+ )
+ self._font = font
+
+ self.axes.append(
+ PlotAxis(
+ self,
+ tickLength=(0.0, -5.0),
+ foregroundColor=self._foregroundColor,
+ labelAlign=CENTER,
+ labelVAlign=TOP,
+ orderOffsetAlign=RIGHT,
+ orderOffsetVAlign=TOP,
+ titleAlign=CENTER,
+ titleVAlign=TOP,
+ titleRotate=0,
+ font=self._font,
+ )
+ )
+
+ self._x2AxisCoords = ()
+
+ self.axes.append(
+ PlotAxis(
+ self,
+ tickLength=(5.0, 0.0),
+ foregroundColor=self._foregroundColor,
+ labelAlign=RIGHT,
+ labelVAlign=CENTER,
+ orderOffsetAlign=LEFT,
+ orderOffsetVAlign=BOTTOM,
+ titleAlign=CENTER,
+ titleVAlign=BOTTOM,
+ titleRotate=ROTATE_270,
+ font=self._font,
+ )
+ )
+
+ self._y2Axis = PlotAxis(
+ self,
+ tickLength=(-5.0, 0.0),
+ foregroundColor=self._foregroundColor,
+ labelAlign=LEFT,
+ labelVAlign=CENTER,
+ orderOffsetAlign=RIGHT,
+ orderOffsetVAlign=BOTTOM,
+ titleAlign=CENTER,
+ titleVAlign=TOP,
+ titleRotate=ROTATE_270,
+ font=self._font,
+ )
+
+ self._isYAxisInverted = False
+
+ self._dataRanges = {"x": (1.0, 100.0), "y": (1.0, 100.0), "y2": (1.0, 100.0)}
+
+ self._baseVectors = (1.0, 0.0), (0.0, 1.0)
+
+ self._transformedDataRanges = None
+ self._transformedDataProjMat = None
+ self._transformedDataY2ProjMat = None
+
+ def _dirty(self):
+ super(GLPlotFrame2D, self)._dirty()
+ self._transformedDataRanges = None
+ self._transformedDataProjMat = None
+ self._transformedDataY2ProjMat = None
+
+ @property
+ def isDirty(self):
+ """True if it need to refresh graphic rendering, False otherwise."""
+ return (
+ super(GLPlotFrame2D, self).isDirty
+ or self._transformedDataRanges is None
+ or self._transformedDataProjMat is None
+ or self._transformedDataY2ProjMat is None
+ )
+
+ @property
+ def xAxis(self):
+ return self.axes[0]
+
+ @property
+ def yAxis(self):
+ return self.axes[1]
+
+ @property
+ def y2Axis(self):
+ return self._y2Axis
+
+ @property
+ def isY2Axis(self):
+ """Whether to display the left Y axis or not."""
+ return len(self.axes) == 3
+
+ @isY2Axis.setter
+ def isY2Axis(self, isY2Axis):
+ if isY2Axis != self.isY2Axis:
+ if isY2Axis:
+ self.axes.append(self._y2Axis)
+ else:
+ self.axes = self.axes[:2]
+
+ self._dirty()
+
+ @property
+ def isYAxisInverted(self):
+ """Whether Y axes are inverted or not as a bool."""
+ return self._isYAxisInverted
+
+ @isYAxisInverted.setter
+ def isYAxisInverted(self, value):
+ value = bool(value)
+ if value != self._isYAxisInverted:
+ self._isYAxisInverted = value
+ self._dirty()
+
+ DEFAULT_BASE_VECTORS = (1.0, 0.0), (0.0, 1.0)
+ """Values of baseVectors for orthogonal axes."""
+
+ @property
+ def baseVectors(self):
+ """Coordinates of the X and Y axes in the orthogonal plot coords.
+
+ Raises ValueError if corresponding matrix is singular.
+
+ 2 tuples of 2 floats: (xx, xy), (yx, yy)
+ """
+ return self._baseVectors
+
+ @baseVectors.setter
+ def baseVectors(self, baseVectors):
+ self._dirty()
+
+ (xx, xy), (yx, yy) = baseVectors
+ vectors = (float(xx), float(xy)), (float(yx), float(yy))
+
+ det = vectors[0][0] * vectors[1][1] - vectors[1][0] * vectors[0][1]
+ if det == 0.0:
+ raise ValueError("Singular matrix for base vectors: " + str(vectors))
+
+ if vectors != self._baseVectors:
+ self._baseVectors = vectors
+ self._dirty()
+
+ def _updateTitleOffset(self):
+ """Update axes title offset according to margins"""
+ margins = self.margins
+ self.xAxis.titleOffset = 0, margins.bottom // 2
+ self.yAxis.titleOffset = -3 * margins.left // 4, 0
+ self.y2Axis.titleOffset = 3 * margins.right // 4, 0
+
+ # Override size and marginRatios setters to update titleOffsets
+ @GLPlotFrame.size.setter
+ def size(self, size):
+ GLPlotFrame.size.fset(self, size)
+ self._updateTitleOffset()
+
+ @GLPlotFrame.marginRatios.setter
+ def marginRatios(self, ratios):
+ GLPlotFrame.marginRatios.fset(self, ratios)
+ self._updateTitleOffset()
+
+ @property
+ def dataRanges(self):
+ """Ranges of data visible in the plot on x, y and y2 axes.
+
+ This is different to the axes range when axes are not orthogonal.
+
+ Type: ((xMin, xMax), (yMin, yMax), (y2Min, y2Max))
+ """
+ return self._DataRanges(
+ self._dataRanges["x"], self._dataRanges["y"], self._dataRanges["y2"]
+ )
+
+ def setDataRanges(self, x=None, y=None, y2=None):
+ """Set data range over each axes.
+
+ The provided ranges are clipped to possible values
+ (i.e., 32 float range + positive range for log scale).
+
+ :param x: (min, max) data range over X axis
+ :param y: (min, max) data range over Y axis
+ :param y2: (min, max) data range over Y2 axis
+ """
+ if x is not None:
+ self._dataRanges["x"] = checkAxisLimits(
+ x[0], x[1], self.xAxis.isLog, name="x"
+ )
+
+ if y is not None:
+ self._dataRanges["y"] = checkAxisLimits(
+ y[0], y[1], self.yAxis.isLog, name="y"
+ )
+
+ if y2 is not None:
+ self._dataRanges["y2"] = checkAxisLimits(
+ y2[0], y2[1], self.y2Axis.isLog, name="y2"
+ )
+
+ self.xAxis.dataRange = self._dataRanges["x"]
+ self.yAxis.dataRange = self._dataRanges["y"]
+ self.y2Axis.dataRange = self._dataRanges["y2"]
+
+ _DataRanges = namedtuple("dataRanges", ("x", "y", "y2"))
+
+ @property
+ def transformedDataRanges(self):
+ """Bounds of the displayed area in transformed data coordinates
+ (i.e., log scale applied if any as well as skew)
+
+ 3-tuple of 2-tuple (min, max) for each axis: x, y, y2.
+ """
+ if self._transformedDataRanges is None:
+ (xMin, xMax), (yMin, yMax), (y2Min, y2Max) = self.dataRanges
+
+ if self.xAxis.isLog:
+ try:
+ xMin = math.log10(xMin)
+ except ValueError:
+ _logger.info("xMin: warning log10(%f)", xMin)
+ xMin = 0.0
+ try:
+ xMax = math.log10(xMax)
+ except ValueError:
+ _logger.info("xMax: warning log10(%f)", xMax)
+ xMax = 0.0
+
+ if self.yAxis.isLog:
+ try:
+ yMin = math.log10(yMin)
+ except ValueError:
+ _logger.info("yMin: warning log10(%f)", yMin)
+ yMin = 0.0
+ try:
+ yMax = math.log10(yMax)
+ except ValueError:
+ _logger.info("yMax: warning log10(%f)", yMax)
+ yMax = 0.0
+
+ try:
+ y2Min = math.log10(y2Min)
+ except ValueError:
+ _logger.info("yMin: warning log10(%f)", y2Min)
+ y2Min = 0.0
+ try:
+ y2Max = math.log10(y2Max)
+ except ValueError:
+ _logger.info("yMax: warning log10(%f)", y2Max)
+ y2Max = 0.0
+
+ self._transformedDataRanges = self._DataRanges(
+ (xMin, xMax), (yMin, yMax), (y2Min, y2Max)
+ )
+
+ return self._transformedDataRanges
+
+ @property
+ def transformedDataProjMat(self):
+ """Orthographic projection matrix for rendering transformed data
+
+ :type: numpy.matrix
+ """
+ if self._transformedDataProjMat is None:
+ xMin, xMax = self.transformedDataRanges.x
+ yMin, yMax = self.transformedDataRanges.y
+
+ if self.isYAxisInverted:
+ mat = mat4Ortho(xMin, xMax, yMax, yMin, 1, -1)
+ else:
+ mat = mat4Ortho(xMin, xMax, yMin, yMax, 1, -1)
+ self._transformedDataProjMat = mat
+
+ return self._transformedDataProjMat
+
+ @property
+ def transformedDataY2ProjMat(self):
+ """Orthographic projection matrix for rendering transformed data
+ for the 2nd Y axis
+
+ :type: numpy.matrix
+ """
+ if self._transformedDataY2ProjMat is None:
+ xMin, xMax = self.transformedDataRanges.x
+ y2Min, y2Max = self.transformedDataRanges.y2
+
+ if self.isYAxisInverted:
+ mat = mat4Ortho(xMin, xMax, y2Max, y2Min, 1, -1)
+ else:
+ mat = mat4Ortho(xMin, xMax, y2Min, y2Max, 1, -1)
+ self._transformedDataY2ProjMat = mat
+
+ return self._transformedDataY2ProjMat
+
+ @staticmethod
+ def __applyLog(
+ data: Union[float, numpy.ndarray], isLog: bool
+ ) -> Optional[Union[float, numpy.ndarray]]:
+ """Apply log to data filtering out"""
+ if not isLog:
+ return data
+
+ if isinstance(data, numbers.Real):
+ return None if data < FLOAT32_MINPOS else math.log10(data)
+
+ isBelowMin = data < FLOAT32_MINPOS
+ if numpy.any(isBelowMin):
+ data = numpy.array(data, copy=True, dtype=numpy.float64)
+ data[isBelowMin] = numpy.nan
+
+ with numpy.errstate(divide="ignore"):
+ return numpy.log10(data)
+
+ def dataToPixel(self, x, y, axis="left"):
+ """Convert data coordinate to widget pixel coordinate."""
+ assert axis in ("left", "right")
+
+ trBounds = self.transformedDataRanges
+
+ xDataTr = self.__applyLog(x, self.xAxis.isLog)
+ if xDataTr is None:
+ return None
+
+ yDataTr = self.__applyLog(y, self.yAxis.isLog)
+ if yDataTr is None:
+ return None
+
+ # Non-orthogonal axes
+ if self.baseVectors != self.DEFAULT_BASE_VECTORS:
+ (xx, xy), (yx, yy) = self.baseVectors
+ skew_mat = numpy.array(((xx, yx), (xy, yy)))
+
+ coords = numpy.dot(skew_mat, numpy.array((xDataTr, yDataTr)))
+ xDataTr, yDataTr = coords
+
+ plotWidth, plotHeight = self.plotSize
+
+ xPixel = self.margins.left + plotWidth * (xDataTr - trBounds.x[0]) / (
+ trBounds.x[1] - trBounds.x[0]
+ )
+
+ usedAxis = trBounds.y if axis == "left" else trBounds.y2
+ yOffset = plotHeight * (yDataTr - usedAxis[0]) / (usedAxis[1] - usedAxis[0])
+
+ if self.isYAxisInverted:
+ yPixel = self.margins.top + yOffset
+ else:
+ yPixel = self.size[1] - self.margins.bottom - yOffset
+
+ return (
+ int(xPixel)
+ if isinstance(xPixel, numbers.Real)
+ else xPixel.astype(numpy.int64),
+ int(yPixel)
+ if isinstance(yPixel, numbers.Real)
+ else yPixel.astype(numpy.int64),
+ )
+
+ def pixelToData(self, x, y, axis="left"):
+ """Convert pixel position to data coordinates.
+
+ :param float x: X coord
+ :param float y: Y coord
+ :param str axis: Y axis to use in ('left', 'right')
+ :return: (x, y) position in data coords
+ """
+ assert axis in ("left", "right")
+
+ plotWidth, plotHeight = self.plotSize
+
+ trBounds = self.transformedDataRanges
+
+ xData = (x - self.margins.left + 0.5) / float(plotWidth)
+ xData = trBounds.x[0] + xData * (trBounds.x[1] - trBounds.x[0])
+
+ usedAxis = trBounds.y if axis == "left" else trBounds.y2
+ if self.isYAxisInverted:
+ yData = (y - self.margins.top + 0.5) / float(plotHeight)
+ yData = usedAxis[0] + yData * (usedAxis[1] - usedAxis[0])
+ else:
+ yData = self.size[1] - self.margins.bottom - y - 0.5
+ yData /= float(plotHeight)
+ yData = usedAxis[0] + yData * (usedAxis[1] - usedAxis[0])
+
+ # non-orthogonal axis
+ if self.baseVectors != self.DEFAULT_BASE_VECTORS:
+ (xx, xy), (yx, yy) = self.baseVectors
+ skew_mat = numpy.array(((xx, yx), (xy, yy)))
+ skew_mat = numpy.linalg.inv(skew_mat)
+
+ coords = numpy.dot(skew_mat, numpy.array((xData, yData)))
+ xData, yData = coords
+
+ if self.xAxis.isLog:
+ xData = pow(10, xData)
+ if self.yAxis.isLog:
+ yData = pow(10, yData)
+
+ return xData, yData
+
+ def _buildGridVerticesWithTest(self, test):
+ vertices = []
+
+ if self.baseVectors == self.DEFAULT_BASE_VECTORS:
+ for axis in self.axes:
+ for (xPixel, yPixel), data, text in axis.ticks:
+ if test(text):
+ vertices.append((xPixel, yPixel))
+ if axis == self.xAxis:
+ vertices.append((xPixel, self.margins.top))
+ elif axis == self.yAxis:
+ vertices.append((self.size[0] - self.margins.right, yPixel))
+ else: # axis == self.y2Axis
+ vertices.append((self.margins.left, yPixel))
+
+ else:
+ # Get plot corners in data coords
+ plotLeft, plotTop = self.plotOrigin
+ plotWidth, plotHeight = self.plotSize
+
+ corners = [
+ (plotLeft, plotTop),
+ (plotLeft, plotTop + plotHeight),
+ (plotLeft + plotWidth, plotTop + plotHeight),
+ (plotLeft + plotWidth, plotTop),
+ ]
+
+ for axis in self.axes:
+ if axis == self.xAxis:
+ cornersInData = numpy.array(
+ [self.pixelToData(x, y) for (x, y) in corners]
+ )
+ borders = (
+ (cornersInData[0], cornersInData[3]), # top
+ (cornersInData[1], cornersInData[0]), # left
+ (cornersInData[3], cornersInData[2]),
+ ) # right
+
+ for (xPixel, yPixel), data, text in axis.ticks:
+ if test(text):
+ for (x0, y0), (x1, y1) in borders:
+ if min(x0, x1) <= data < max(x0, x1):
+ yIntersect = (data - x0) * (y1 - y0) / (
+ x1 - x0
+ ) + y0
+
+ pixelPos = self.dataToPixel(data, yIntersect)
+ if pixelPos is not None:
+ vertices.append((xPixel, yPixel))
+ vertices.append(pixelPos)
+ break # Stop at first intersection
+
+ else: # y or y2 axes
+ if axis == self.yAxis:
+ axis_name = "left"
+ cornersInData = numpy.array(
+ [self.pixelToData(x, y) for (x, y) in corners]
+ )
+ borders = (
+ (cornersInData[3], cornersInData[2]), # right
+ (cornersInData[0], cornersInData[3]), # top
+ (cornersInData[2], cornersInData[1]),
+ ) # bottom
+
+ else: # axis == self.y2Axis
+ axis_name = "right"
+ corners = numpy.array(
+ [self.pixelToData(x, y, axis="right") for (x, y) in corners]
+ )
+ borders = (
+ (cornersInData[1], cornersInData[0]), # left
+ (cornersInData[0], cornersInData[3]), # top
+ (cornersInData[2], cornersInData[1]),
+ ) # bottom
+
+ for (xPixel, yPixel), data, text in axis.ticks:
+ if test(text):
+ for (x0, y0), (x1, y1) in borders:
+ if min(y0, y1) <= data < max(y0, y1):
+ xIntersect = (data - y0) * (x1 - x0) / (
+ y1 - y0
+ ) + x0
+
+ pixelPos = self.dataToPixel(
+ xIntersect, data, axis=axis_name
+ )
+ if pixelPos is not None:
+ vertices.append((xPixel, yPixel))
+ vertices.append(pixelPos)
+ break # Stop at first intersection
+
+ return vertices
+
+ def _buildVerticesAndLabels(self):
+ width, height = self.size
+
+ xCoords = (self.margins.left - 0.5, width - self.margins.right + 0.5)
+ yCoords = (height - self.margins.bottom + 0.5, self.margins.top - 0.5)
+
+ self.axes[0].displayCoords = (
+ (xCoords[0], yCoords[0]),
+ (xCoords[1], yCoords[0]),
+ )
+
+ self._x2AxisCoords = ((xCoords[0], yCoords[1]), (xCoords[1], yCoords[1]))
+
+ # Set order&offset anchor **before** handling Y axis inversion
+ fontPixelSize = self._font.pixelSize()
+ if fontPixelSize == -1:
+ fontPixelSize = self._font.pointSizeF() / 72.0 * self.dotsPerInch
+
+ self.axes[0].orderOffetAnchor = (
+ xCoords[1],
+ yCoords[0] + fontPixelSize * 1.2,
+ )
+ self.axes[1].orderOffetAnchor = (
+ xCoords[0],
+ yCoords[1] - 4 * self.devicePixelRatio,
+ )
+ self._y2Axis.orderOffetAnchor = (
+ xCoords[1],
+ yCoords[1] - 4 * self.devicePixelRatio,
+ )
+
+ if self.isYAxisInverted:
+ # Y axes are inverted, axes coordinates are inverted
+ yCoords = yCoords[1], yCoords[0]
+
+ self.axes[1].displayCoords = (
+ (xCoords[0], yCoords[0]),
+ (xCoords[0], yCoords[1]),
+ )
+
+ self._y2Axis.displayCoords = (
+ (xCoords[1], yCoords[0]),
+ (xCoords[1], yCoords[1]),
+ )
+
+ super(GLPlotFrame2D, self)._buildVerticesAndLabels()
+
+ vertices, gridVertices, labels = self._renderResources
+
+ # Adds vertices for borders without axis
+ extraVertices = []
+ extraVertices += self._x2AxisCoords
+ if not self.isY2Axis:
+ extraVertices += self._y2Axis.displayCoords
+
+ extraVertices = numpy.array(extraVertices, copy=False, dtype=numpy.float32)
+ vertices = numpy.append(vertices, extraVertices, axis=0)
+
+ self._renderResources = (vertices, gridVertices, labels)
+
+ @property
+ def foregroundColor(self):
+ """Color used for frame and labels"""
+ return self._foregroundColor
+
+ @foregroundColor.setter
+ def foregroundColor(self, color):
+ """Color used for frame and labels"""
+ assert len(color) == 4, "foregroundColor must have length 4, got {}".format(
+ len(self._foregroundColor)
+ )
+ if self._foregroundColor != color:
+ self._y2Axis.foregroundColor = color
+ GLPlotFrame.foregroundColor.fset(self, color) # call parent property
diff --git a/src/silx/gui/plot/backends/glutils/GLPlotImage.py b/src/silx/gui/plot/backends/glutils/GLPlotImage.py
new file mode 100644
index 0000000..0973c47
--- /dev/null
+++ b/src/silx/gui/plot/backends/glutils/GLPlotImage.py
@@ -0,0 +1,789 @@
+# /*##########################################################################
+#
+# Copyright (c) 2014-2023 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ############################################################################*/
+"""
+This module provides a class to render 2D array as a colormap or RGB(A) image
+"""
+
+__authors__ = ["T. Vincent"]
+__license__ = "MIT"
+__date__ = "03/04/2017"
+
+
+import math
+import numpy
+
+from ...._glutils import gl, Program, Texture
+from ..._utils import FLOAT32_MINPOS
+from .GLSupport import mat4Translate, mat4Scale
+from .GLTexture import Image
+from .GLPlotItem import GLPlotItem
+
+
+class _GLPlotData2D(GLPlotItem):
+ def __init__(self, data, origin, scale):
+ super().__init__()
+ self.data = data
+ assert len(origin) == 2
+ self.origin = tuple(origin)
+ assert len(scale) == 2
+ self.scale = tuple(scale)
+
+ def pick(self, x, y):
+ if self.xMin <= x <= self.xMax and self.yMin <= y <= self.yMax:
+ ox, oy = self.origin
+ sx, sy = self.scale
+ col = int((x - ox) / sx)
+ row = int((y - oy) / sy)
+ return (row,), (col,)
+ else:
+ return None
+
+ @property
+ def xMin(self):
+ ox, sx = self.origin[0], self.scale[0]
+ return ox if sx >= 0.0 else ox + sx * self.data.shape[1]
+
+ @property
+ def yMin(self):
+ oy, sy = self.origin[1], self.scale[1]
+ return oy if sy >= 0.0 else oy + sy * self.data.shape[0]
+
+ @property
+ def xMax(self):
+ ox, sx = self.origin[0], self.scale[0]
+ return ox + sx * self.data.shape[1] if sx >= 0.0 else ox
+
+ @property
+ def yMax(self):
+ oy, sy = self.origin[1], self.scale[1]
+ return oy + sy * self.data.shape[0] if sy >= 0.0 else oy
+
+
+class GLPlotColormap(_GLPlotData2D):
+ _SHADERS = {
+ "linear": {
+ "vertex": """
+ #version 120
+
+ uniform mat4 matrix;
+ attribute vec2 texCoords;
+ attribute vec2 position;
+
+ varying vec2 coords;
+
+ void main(void) {
+ coords = texCoords;
+ gl_Position = matrix * vec4(position, 0.0, 1.0);
+ }
+ """,
+ "fragTransform": """
+ vec2 textureCoords(void) {
+ return coords;
+ }
+ """,
+ },
+ "log": {
+ "vertex": """
+ #version 120
+
+ attribute vec2 position;
+ uniform mat4 matrix;
+ uniform mat4 matOffset;
+ uniform bvec2 isLog;
+
+ varying vec2 coords;
+
+ const float oneOverLog10 = 0.43429448190325176;
+
+ void main(void) {
+ vec4 dataPos = matOffset * vec4(position, 0.0, 1.0);
+ if (isLog.x) {
+ dataPos.x = oneOverLog10 * log(dataPos.x);
+ }
+ if (isLog.y) {
+ dataPos.y = oneOverLog10 * log(dataPos.y);
+ }
+ coords = dataPos.xy;
+ gl_Position = matrix * dataPos;
+ }
+ """,
+ "fragTransform": """
+ uniform bvec2 isLog;
+ uniform vec2 bounds_oneOverRange;
+ uniform vec2 bounds_originOverRange;
+
+ vec2 textureCoords(void) {
+ vec2 pos = coords;
+ if (isLog.x) {
+ pos.x = pow(10., coords.x);
+ }
+ if (isLog.y) {
+ pos.y = pow(10., coords.y);
+ }
+ return pos * bounds_oneOverRange - bounds_originOverRange;
+ // TODO texture coords in range different from [0, 1]
+ }
+ """,
+ },
+ "fragment": """
+ #version 120
+
+ /* isnan declaration for compatibility with GLSL 1.20 */
+ bool isnan(float value) {
+ return (value != value);
+ }
+
+ uniform sampler2D data;
+ uniform float data_scale;
+ uniform sampler2D cmap_texture;
+ uniform int cmap_normalization;
+ uniform float cmap_parameter;
+ uniform float cmap_min;
+ uniform float cmap_oneOverRange;
+ uniform float alpha;
+ uniform vec4 nancolor;
+
+ varying vec2 coords;
+
+ %s
+
+ const float oneOverLog10 = 0.43429448190325176;
+
+ void main(void) {
+ float raw_data = texture2D(data, textureCoords()).r * data_scale;
+ float value = 0.;
+ if (cmap_normalization == 1) { /*Logarithm mapping*/
+ if (raw_data > 0.) {
+ value = clamp(cmap_oneOverRange *
+ (oneOverLog10 * log(raw_data) - cmap_min),
+ 0., 1.);
+ } else {
+ value = 0.;
+ }
+ } else if (cmap_normalization == 2) { /*Square root mapping*/
+ if (raw_data >= 0.) {
+ value = clamp(cmap_oneOverRange * (sqrt(raw_data) - cmap_min),
+ 0., 1.);
+ } else {
+ value = 0.;
+ }
+ } else if (cmap_normalization == 3) { /*Gamma correction mapping*/
+ value = pow(
+ clamp(cmap_oneOverRange * (raw_data - cmap_min), 0., 1.),
+ cmap_parameter);
+ } else if (cmap_normalization == 4) { /* arcsinh mapping */
+ /* asinh = log(x + sqrt(x*x + 1) for compatibility with GLSL 1.20 */
+ value = clamp(cmap_oneOverRange * (log(raw_data + sqrt(raw_data*raw_data + 1.0)) - cmap_min), 0., 1.);
+ } else { /*Linear mapping and fallback*/
+ value = clamp(cmap_oneOverRange * (raw_data - cmap_min), 0., 1.);
+ }
+
+ if (isnan(raw_data)) {
+ gl_FragColor = nancolor;
+ } else {
+ gl_FragColor = texture2D(cmap_texture, vec2(value, 0.5));
+ }
+ gl_FragColor.a *= alpha;
+ }
+ """,
+ }
+
+ _DATA_TEX_UNIT = 0
+ _CMAP_TEX_UNIT = 1
+
+ _INTERNAL_FORMATS = {
+ numpy.dtype(numpy.float32): gl.GL_R32F,
+ numpy.dtype(numpy.float16): gl.GL_R16F,
+ # Use normalized integer for unsigned int formats
+ numpy.dtype(numpy.uint16): gl.GL_R16,
+ numpy.dtype(numpy.uint8): gl.GL_R8,
+ }
+
+ _linearProgram = Program(
+ _SHADERS["linear"]["vertex"],
+ _SHADERS["fragment"] % _SHADERS["linear"]["fragTransform"],
+ attrib0="position",
+ )
+
+ _logProgram = Program(
+ _SHADERS["log"]["vertex"],
+ _SHADERS["fragment"] % _SHADERS["log"]["fragTransform"],
+ attrib0="position",
+ )
+
+ SUPPORTED_NORMALIZATIONS = "linear", "log", "sqrt", "gamma", "arcsinh"
+
+ def __init__(
+ self,
+ data,
+ origin,
+ scale,
+ colormap,
+ normalization="linear",
+ gamma=0.0,
+ cmapRange=None,
+ alpha=1.0,
+ nancolor=(1.0, 1.0, 1.0, 0.0),
+ ):
+ """Create a 2D colormap
+
+ :param data: The 2D scalar data array to display
+ :type data: numpy.ndarray with 2 dimensions (dtype=numpy.float32)
+ :param origin: (x, y) coordinates of the origin of the data array
+ :type origin: 2-tuple of floats.
+ :param scale: (sx, sy) scale factors of the data array.
+ This is the size of a data pixel in plot data space.
+ :type scale: 2-tuple of floats.
+ :param str colormap: Name of the colormap to use
+ TODO: Accept a 1D scalar array as the colormap
+ :param str normalization: The colormap normalization.
+ One of: 'linear', 'log', 'sqrt', 'gamma'
+ ;param float gamma: The gamma parameter (for 'gamma' normalization)
+ :param cmapRange: The range of colormap or None for autoscale colormap
+ For logarithmic colormap, the range is in the untransformed data
+ TODO: check consistency with matplotlib
+ :type cmapRange: (float, float) or None
+ :param float alpha: Opacity from 0 (transparent) to 1 (opaque)
+ :param nancolor: RGBA color for Not-A-Number values
+ :type nancolor: 4-tuple of float in [0., 1.]
+ """
+ assert data.dtype in self._INTERNAL_FORMATS
+ assert normalization in self.SUPPORTED_NORMALIZATIONS
+
+ super(GLPlotColormap, self).__init__(data, origin, scale)
+ self.colormap = numpy.array(colormap, copy=False)
+ self.normalization = normalization
+ self.gamma = gamma
+ self._cmapRange = (1.0, 10.0) # Colormap range
+ self.cmapRange = cmapRange # Update _cmapRange
+ self._alpha = numpy.clip(alpha, 0.0, 1.0)
+ self._nancolor = numpy.clip(nancolor, 0.0, 1.0)
+
+ self._cmap_texture = None
+ self._texture = None
+ self._textureIsDirty = False
+
+ def discard(self):
+ if self._cmap_texture is not None:
+ self._cmap_texture.discard()
+ self._cmap_texture = None
+
+ if self._texture is not None:
+ self._texture.discard()
+ self._texture = None
+ self._textureIsDirty = False
+
+ def isInitialized(self):
+ return self._cmap_texture is not None or self._texture is not None
+
+ @property
+ def cmapRange(self):
+ if self.normalization == "log":
+ assert self._cmapRange[0] > 0.0 and self._cmapRange[1] > 0.0
+ elif self.normalization == "sqrt":
+ assert self._cmapRange[0] >= 0.0 and self._cmapRange[1] >= 0.0
+ return self._cmapRange
+
+ @cmapRange.setter
+ def cmapRange(self, cmapRange):
+ assert len(cmapRange) == 2
+ assert cmapRange[0] <= cmapRange[1]
+ self._cmapRange = float(cmapRange[0]), float(cmapRange[1])
+
+ @property
+ def alpha(self):
+ return self._alpha
+
+ def updateData(self, data):
+ assert data.dtype in self._INTERNAL_FORMATS
+ oldData = self.data
+ self.data = data
+
+ if self._texture is not None:
+ if self.data.shape != oldData.shape or self.data.dtype != oldData.dtype:
+ self.discard()
+ else:
+ self._textureIsDirty = True
+
+ def prepare(self):
+ if self._cmap_texture is None:
+ # TODO share cmap texture accross Images
+ # put all cmaps in one texture
+ colormap = numpy.empty(
+ (16, 256, self.colormap.shape[1]), dtype=self.colormap.dtype
+ )
+ colormap[:] = self.colormap
+ format_ = gl.GL_RGBA if colormap.shape[-1] == 4 else gl.GL_RGB
+ self._cmap_texture = Texture(
+ internalFormat=format_,
+ data=colormap,
+ format_=format_,
+ texUnit=self._CMAP_TEX_UNIT,
+ minFilter=gl.GL_NEAREST,
+ magFilter=gl.GL_NEAREST,
+ wrap=(gl.GL_CLAMP_TO_EDGE, gl.GL_CLAMP_TO_EDGE),
+ )
+ self._cmap_texture.prepare()
+
+ if self._texture is None:
+ internalFormat = self._INTERNAL_FORMATS[self.data.dtype]
+
+ self._texture = Image(
+ internalFormat,
+ self.data,
+ format_=gl.GL_RED,
+ texUnit=self._DATA_TEX_UNIT,
+ )
+ elif self._textureIsDirty:
+ self._textureIsDirty = True
+ self._texture.updateAll(format_=gl.GL_RED, data=self.data)
+
+ def _setCMap(self, prog):
+ dataMin, dataMax = self.cmapRange # If log, it is stricly positive
+ param = 0.0
+
+ if self.data.dtype in (numpy.uint16, numpy.uint8):
+ # Using unsigned int as normalized integer in OpenGL
+ # So revert normalization in the shader
+ dataScale = float(numpy.iinfo(self.data.dtype).max)
+ else:
+ dataScale = 1.0
+
+ if self.normalization == "log":
+ dataMin = math.log10(dataMin)
+ dataMax = math.log10(dataMax)
+ normID = 1
+ elif self.normalization == "sqrt":
+ dataMin = math.sqrt(dataMin)
+ dataMax = math.sqrt(dataMax)
+ normID = 2
+ elif self.normalization == "gamma":
+ # Keep dataMin, dataMax as is
+ param = self.gamma
+ normID = 3
+ elif self.normalization == "arcsinh":
+ dataMin = numpy.arcsinh(dataMin)
+ dataMax = numpy.arcsinh(dataMax)
+ normID = 4
+ else: # Linear and fallback
+ normID = 0
+
+ gl.glUniform1f(prog.uniforms["data_scale"], dataScale)
+ gl.glUniform1i(prog.uniforms["cmap_texture"], self._cmap_texture.texUnit)
+ gl.glUniform1i(prog.uniforms["cmap_normalization"], normID)
+ gl.glUniform1f(prog.uniforms["cmap_parameter"], param)
+ gl.glUniform1f(prog.uniforms["cmap_min"], dataMin)
+ if dataMax > dataMin:
+ oneOverRange = 1.0 / (dataMax - dataMin)
+ else:
+ oneOverRange = 0.0 # Fall-back
+ gl.glUniform1f(prog.uniforms["cmap_oneOverRange"], oneOverRange)
+
+ gl.glUniform4f(prog.uniforms["nancolor"], *self._nancolor)
+
+ self._cmap_texture.bind()
+
+ def _renderLinear(self, context):
+ """Perform rendering when both axes have linear scales
+
+ :param RenderContext context: Rendering information
+ """
+ self.prepare()
+
+ prog = self._linearProgram
+ prog.use()
+
+ gl.glUniform1i(prog.uniforms["data"], self._DATA_TEX_UNIT)
+
+ mat = numpy.dot(
+ numpy.dot(context.matrix, mat4Translate(*self.origin)),
+ mat4Scale(*self.scale),
+ )
+ gl.glUniformMatrix4fv(
+ prog.uniforms["matrix"], 1, gl.GL_TRUE, mat.astype(numpy.float32)
+ )
+
+ gl.glUniform1f(prog.uniforms["alpha"], self.alpha)
+
+ self._setCMap(prog)
+
+ self._texture.render(
+ prog.attributes["position"],
+ prog.attributes["texCoords"],
+ self._DATA_TEX_UNIT,
+ )
+
+ def _renderLog10(self, context):
+ """Perform rendering when one axis has log scale
+
+ :param RenderContext context: Rendering information
+ """
+ xMin, yMin = self.xMin, self.yMin
+ if (context.isXLog and xMin < FLOAT32_MINPOS) or (
+ context.isYLog and yMin < FLOAT32_MINPOS
+ ):
+ # Do not render images that are partly or totally <= 0
+ return
+
+ self.prepare()
+
+ prog = self._logProgram
+ prog.use()
+
+ ox, oy = self.origin
+
+ gl.glUniform1i(prog.uniforms["data"], self._DATA_TEX_UNIT)
+
+ gl.glUniformMatrix4fv(
+ prog.uniforms["matrix"], 1, gl.GL_TRUE, context.matrix.astype(numpy.float32)
+ )
+ mat = numpy.dot(mat4Translate(ox, oy), mat4Scale(*self.scale))
+ gl.glUniformMatrix4fv(
+ prog.uniforms["matOffset"], 1, gl.GL_TRUE, mat.astype(numpy.float32)
+ )
+
+ gl.glUniform2i(prog.uniforms["isLog"], context.isXLog, context.isYLog)
+
+ ex = ox + self.scale[0] * self.data.shape[1]
+ ey = oy + self.scale[1] * self.data.shape[0]
+
+ xOneOverRange = 1.0 / (ex - ox)
+ yOneOverRange = 1.0 / (ey - oy)
+ gl.glUniform2f(
+ prog.uniforms["bounds_originOverRange"],
+ ox * xOneOverRange,
+ oy * yOneOverRange,
+ )
+ gl.glUniform2f(
+ prog.uniforms["bounds_oneOverRange"], xOneOverRange, yOneOverRange
+ )
+
+ gl.glUniform1f(prog.uniforms["alpha"], self.alpha)
+
+ self._setCMap(prog)
+
+ try:
+ tiles = self._texture.tiles
+ except AttributeError:
+ raise RuntimeError("No texture, discard has already been called")
+ if len(tiles) > 1:
+ raise NotImplementedError(
+ "Image over multiple textures not supported with log scale"
+ )
+
+ texture, vertices, info = tiles[0]
+
+ texture.bind(self._DATA_TEX_UNIT)
+
+ posAttrib = prog.attributes["position"]
+ stride = vertices.shape[-1] * vertices.itemsize
+ gl.glEnableVertexAttribArray(posAttrib)
+ gl.glVertexAttribPointer(
+ posAttrib, 2, gl.GL_FLOAT, gl.GL_FALSE, stride, vertices
+ )
+
+ gl.glDrawArrays(gl.GL_TRIANGLE_STRIP, 0, len(vertices))
+
+ def render(self, context):
+ """Perform rendering
+
+ :param RenderContext context: Rendering information
+ """
+ if any((context.isXLog, context.isYLog)):
+ self._renderLog10(context)
+ else:
+ self._renderLinear(context)
+
+ # Unbind colormap texture
+ gl.glActiveTexture(gl.GL_TEXTURE0 + self._cmap_texture.texUnit)
+ gl.glBindTexture(self._cmap_texture.target, 0)
+
+
+# image #######################################################################
+
+
+class GLPlotRGBAImage(_GLPlotData2D):
+ _SHADERS = {
+ "linear": {
+ "vertex": """
+ #version 120
+
+ attribute vec2 position;
+ attribute vec2 texCoords;
+ uniform mat4 matrix;
+
+ varying vec2 coords;
+
+ void main(void) {
+ gl_Position = matrix * vec4(position, 0.0, 1.0);
+ coords = texCoords;
+ }
+ """,
+ "fragment": """
+ #version 120
+
+ uniform sampler2D tex;
+ uniform float alpha;
+
+ varying vec2 coords;
+
+ void main(void) {
+ gl_FragColor = texture2D(tex, coords);
+ gl_FragColor.a *= alpha;
+ }
+ """,
+ },
+ "log": {
+ "vertex": """
+ #version 120
+
+ attribute vec2 position;
+ uniform mat4 matrix;
+ uniform mat4 matOffset;
+ uniform bvec2 isLog;
+
+ varying vec2 coords;
+
+ const float oneOverLog10 = 0.43429448190325176;
+
+ void main(void) {
+ vec4 dataPos = matOffset * vec4(position, 0.0, 1.0);
+ if (isLog.x) {
+ dataPos.x = oneOverLog10 * log(dataPos.x);
+ }
+ if (isLog.y) {
+ dataPos.y = oneOverLog10 * log(dataPos.y);
+ }
+ coords = dataPos.xy;
+ gl_Position = matrix * dataPos;
+ }
+ """,
+ "fragment": """
+ #version 120
+
+ uniform sampler2D tex;
+ uniform bvec2 isLog;
+ uniform vec2 bounds_oneOverRange;
+ uniform vec2 bounds_originOverRange;
+ uniform float alpha;
+
+ varying vec2 coords;
+
+ vec2 textureCoords(void) {
+ vec2 pos = coords;
+ if (isLog.x) {
+ pos.x = pow(10., coords.x);
+ }
+ if (isLog.y) {
+ pos.y = pow(10., coords.y);
+ }
+ return pos * bounds_oneOverRange - bounds_originOverRange;
+ // TODO texture coords in range different from [0, 1]
+ }
+
+ void main(void) {
+ gl_FragColor = texture2D(tex, textureCoords());
+ gl_FragColor.a *= alpha;
+ }
+ """,
+ },
+ }
+
+ _DATA_TEX_UNIT = 0
+
+ _SUPPORTED_DTYPES = (
+ numpy.dtype(numpy.float32),
+ numpy.dtype(numpy.uint8),
+ numpy.dtype(numpy.uint16),
+ )
+
+ _linearProgram = Program(
+ _SHADERS["linear"]["vertex"], _SHADERS["linear"]["fragment"], attrib0="position"
+ )
+
+ _logProgram = Program(
+ _SHADERS["log"]["vertex"], _SHADERS["log"]["fragment"], attrib0="position"
+ )
+
+ def __init__(self, data, origin, scale, alpha):
+ """Create a 2D RGB(A) image from data
+
+ :param data: The 2D image data array to display
+ :type data: numpy.ndarray with 3 dimensions
+ (dtype=numpy.uint8 or numpy.float32)
+ :param origin: (x, y) coordinates of the origin of the data array
+ :type origin: 2-tuple of floats.
+ :param scale: (sx, sy) scale factors of the data array.
+ This is the size of a data pixel in plot data space.
+ :type scale: 2-tuple of floats.
+ :param float alpha: Opacity from 0 (transparent) to 1 (opaque)
+ """
+ assert data.dtype in self._SUPPORTED_DTYPES
+ super(GLPlotRGBAImage, self).__init__(data, origin, scale)
+ self._texture = None
+ self._textureIsDirty = False
+ self._alpha = numpy.clip(alpha, 0.0, 1.0)
+
+ @property
+ def alpha(self):
+ return self._alpha
+
+ def discard(self):
+ if self.isInitialized():
+ self._texture.discard()
+ self._texture = None
+ self._textureIsDirty = False
+
+ def isInitialized(self):
+ return self._texture is not None
+
+ def updateData(self, data):
+ assert data.dtype in self._SUPPORTED_DTYPES
+ oldData = self.data
+ self.data = data
+
+ if self._texture is not None:
+ if self.data.shape != oldData.shape:
+ self.discard()
+ else:
+ self._textureIsDirty = True
+
+ def prepare(self):
+ if self._texture is None:
+ formatName = "GL_RGBA" if self.data.shape[2] == 4 else "GL_RGB"
+ format_ = getattr(gl, formatName)
+
+ if self.data.dtype == numpy.uint16:
+ formatName += "16" # Use sized internal format for uint16
+ internalFormat = getattr(gl, formatName)
+
+ self._texture = Image(
+ internalFormat, self.data, format_=format_, texUnit=self._DATA_TEX_UNIT
+ )
+ elif self._textureIsDirty:
+ self._textureIsDirty = False
+
+ # We should check that internal format is the same
+ format_ = gl.GL_RGBA if self.data.shape[2] == 4 else gl.GL_RGB
+ self._texture.updateAll(format_=format_, data=self.data)
+
+ def _renderLinear(self, context):
+ """Perform rendering with both axes having linear scales
+
+ :param RenderContext context: Rendering information
+ """
+ self.prepare()
+
+ prog = self._linearProgram
+ prog.use()
+
+ gl.glUniform1i(prog.uniforms["tex"], self._DATA_TEX_UNIT)
+
+ mat = numpy.dot(
+ numpy.dot(context.matrix, mat4Translate(*self.origin)),
+ mat4Scale(*self.scale),
+ )
+ gl.glUniformMatrix4fv(
+ prog.uniforms["matrix"], 1, gl.GL_TRUE, mat.astype(numpy.float32)
+ )
+
+ gl.glUniform1f(prog.uniforms["alpha"], self.alpha)
+
+ self._texture.render(
+ prog.attributes["position"],
+ prog.attributes["texCoords"],
+ self._DATA_TEX_UNIT,
+ )
+
+ def _renderLog(self, context):
+ """Perform rendering with axes having log scale
+
+ :param RenderContext context: Rendering information
+ """
+ self.prepare()
+
+ prog = self._logProgram
+ prog.use()
+
+ ox, oy = self.origin
+
+ gl.glUniform1i(prog.uniforms["tex"], self._DATA_TEX_UNIT)
+
+ gl.glUniformMatrix4fv(
+ prog.uniforms["matrix"], 1, gl.GL_TRUE, context.matrix.astype(numpy.float32)
+ )
+ mat = numpy.dot(mat4Translate(ox, oy), mat4Scale(*self.scale))
+ gl.glUniformMatrix4fv(
+ prog.uniforms["matOffset"], 1, gl.GL_TRUE, mat.astype(numpy.float32)
+ )
+
+ gl.glUniform2i(prog.uniforms["isLog"], context.isXLog, context.isYLog)
+
+ gl.glUniform1f(prog.uniforms["alpha"], self.alpha)
+
+ ex = ox + self.scale[0] * self.data.shape[1]
+ ey = oy + self.scale[1] * self.data.shape[0]
+
+ xOneOverRange = 1.0 / (ex - ox)
+ yOneOverRange = 1.0 / (ey - oy)
+ gl.glUniform2f(
+ prog.uniforms["bounds_originOverRange"],
+ ox * xOneOverRange,
+ oy * yOneOverRange,
+ )
+ gl.glUniform2f(
+ prog.uniforms["bounds_oneOverRange"], xOneOverRange, yOneOverRange
+ )
+
+ try:
+ tiles = self._texture.tiles
+ except AttributeError:
+ raise RuntimeError("No texture, discard has already been called")
+ if len(tiles) > 1:
+ raise NotImplementedError(
+ "Image over multiple textures not supported with log scale"
+ )
+
+ texture, vertices, info = tiles[0]
+
+ texture.bind(self._DATA_TEX_UNIT)
+
+ posAttrib = prog.attributes["position"]
+ stride = vertices.shape[-1] * vertices.itemsize
+ gl.glEnableVertexAttribArray(posAttrib)
+ gl.glVertexAttribPointer(
+ posAttrib, 2, gl.GL_FLOAT, gl.GL_FALSE, stride, vertices
+ )
+
+ gl.glDrawArrays(gl.GL_TRIANGLE_STRIP, 0, len(vertices))
+
+ def render(self, context):
+ """Perform rendering
+
+ :param RenderContext context: Rendering information
+ """
+ if any((context.isXLog, context.isYLog)):
+ self._renderLog(context)
+ else:
+ self._renderLinear(context)
diff --git a/src/silx/gui/plot/backends/glutils/GLPlotItem.py b/src/silx/gui/plot/backends/glutils/GLPlotItem.py
new file mode 100644
index 0000000..0287ad5
--- /dev/null
+++ b/src/silx/gui/plot/backends/glutils/GLPlotItem.py
@@ -0,0 +1,105 @@
+# /*##########################################################################
+#
+# Copyright (c) 2020-2022 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ############################################################################*/
+"""
+This module provides a base class for PlotWidget OpenGL backend primitives
+"""
+
+__authors__ = ["T. Vincent"]
+__license__ = "MIT"
+__date__ = "02/07/2020"
+
+
+class RenderContext:
+ """Context with which to perform OpenGL rendering.
+
+ :param numpy.ndarray matrix: 4x4 transform matrix to use for rendering
+ :param bool isXLog: Whether X axis is log scale or not
+ :param bool isYLog: Whether Y axis is log scale or not
+ :param float dpi: Number of device pixels per inch
+ """
+
+ def __init__(
+ self, matrix=None, isXLog=False, isYLog=False, dpi=96.0, plotFrame=None
+ ):
+ self.matrix = matrix
+ """Current transformation matrix"""
+
+ self.__isXLog = isXLog
+ self.__isYLog = isYLog
+ self.__dpi = dpi
+ self.__plotFrame = plotFrame
+
+ @property
+ def isXLog(self):
+ """True if X axis is using log scale"""
+ return self.__isXLog
+
+ @property
+ def isYLog(self):
+ """True if Y axis is using log scale"""
+ return self.__isYLog
+
+ @property
+ def dpi(self):
+ """Number of device pixels per inch"""
+ return self.__dpi
+
+ @property
+ def plotFrame(self):
+ """Current PlotFrame"""
+ return self.__plotFrame
+
+
+class GLPlotItem:
+ """Base class for primitives used in the PlotWidget OpenGL backend"""
+
+ def __init__(self):
+ self.yaxis = "left"
+ "YAxis this item is attached to (either 'left' or 'right')"
+
+ def pick(self, x, y):
+ """Perform picking at given position.
+
+ :param float x: X coordinate in plot data frame of reference
+ :param float y: Y coordinate in plot data frame of reference
+ :returns:
+ Result of picking as a list of indices or None if nothing picked
+ :rtype: Union[List[int],None]
+ """
+ return None
+
+ def render(self, context):
+ """Performs OpenGL rendering of the item.
+
+ :param RenderContext context: Rendering context information
+ """
+ pass
+
+ def discard(self):
+ """Discards OpenGL resources this item has created."""
+ pass
+
+ def isInitialized(self) -> bool:
+ """Returns True if resources where initialized and requires `discard`."""
+ return True
diff --git a/src/silx/gui/plot/backends/glutils/GLPlotTriangles.py b/src/silx/gui/plot/backends/glutils/GLPlotTriangles.py
new file mode 100644
index 0000000..e8a8e4a
--- /dev/null
+++ b/src/silx/gui/plot/backends/glutils/GLPlotTriangles.py
@@ -0,0 +1,203 @@
+# /*##########################################################################
+#
+# Copyright (c) 2019-2021 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ############################################################################*/
+"""
+This module provides a class to render a set of 2D triangles
+"""
+
+__authors__ = ["T. Vincent"]
+__license__ = "MIT"
+__date__ = "03/04/2017"
+
+
+import ctypes
+
+import numpy
+
+from .....math.combo import min_max
+from .... import _glutils as glutils
+from ...._glutils import gl
+from .GLPlotItem import GLPlotItem
+
+
+class GLPlotTriangles(GLPlotItem):
+ """Handle rendering of a set of colored triangles"""
+
+ _PROGRAM = glutils.Program(
+ vertexShader="""
+ #version 120
+
+ uniform mat4 matrix;
+ attribute float xPos;
+ attribute float yPos;
+ attribute vec4 color;
+
+ varying vec4 vColor;
+
+ void main(void) {
+ gl_Position = matrix * vec4(xPos, yPos, 0.0, 1.0);
+ vColor = color;
+ }
+ """,
+ fragmentShader="""
+ #version 120
+
+ uniform float alpha;
+ varying vec4 vColor;
+
+ void main(void) {
+ gl_FragColor = vColor;
+ gl_FragColor.a *= alpha;
+ }
+ """,
+ attrib0="xPos",
+ )
+
+ def __init__(self, x, y, color, triangles, alpha=1.0):
+ """
+
+ :param numpy.ndarray x: X coordinates of triangle corners
+ :param numpy.ndarray y: Y coordinates of triangle corners
+ :param numpy.ndarray color: color for each point
+ :param numpy.ndarray triangles: (N, 3) array of indices of triangles
+ :param float alpha: Opacity in [0, 1]
+ """
+ super().__init__()
+ # Check and convert input data
+ x = numpy.ravel(numpy.array(x, dtype=numpy.float32))
+ y = numpy.ravel(numpy.array(y, dtype=numpy.float32))
+ color = numpy.array(color, copy=False)
+ # Cast to uint32
+ triangles = numpy.array(triangles, copy=False, dtype=numpy.uint32)
+
+ assert x.size == y.size
+ assert x.size == len(color)
+ assert color.ndim == 2 and color.shape[1] in (3, 4)
+ if numpy.issubdtype(color.dtype, numpy.floating):
+ color = numpy.array(color, dtype=numpy.float32, copy=False)
+ elif numpy.issubdtype(color.dtype, numpy.integer):
+ color = numpy.array(color, dtype=numpy.uint8, copy=False)
+ else:
+ raise ValueError("Unsupported color type")
+ assert triangles.ndim == 2 and triangles.shape[1] == 3
+
+ self.__x_y_color = x, y, color
+ self.xMin, self.xMax = min_max(x, finite=True)
+ self.yMin, self.yMax = min_max(y, finite=True)
+ self.__triangles = triangles
+ self.__alpha = numpy.clip(float(alpha), 0.0, 1.0)
+ self.__vbos = None
+ self.__indicesVbo = None
+ self.__picking_triangles = None
+
+ def pick(self, x, y):
+ """Perform picking
+
+ :param float x: X coordinates in plot data frame
+ :param float y: Y coordinates in plot data frame
+ :return: List of picked data point indices
+ :rtype: Union[List[int],None]
+ """
+ if x < self.xMin or x > self.xMax or y < self.yMin or y > self.yMax:
+ return None
+
+ xPts, yPts = self.__x_y_color[:2]
+ if self.__picking_triangles is None:
+ self.__picking_triangles = numpy.zeros(
+ self.__triangles.shape + (3,), dtype=numpy.float32
+ )
+ self.__picking_triangles[:, :, 0] = xPts[self.__triangles]
+ self.__picking_triangles[:, :, 1] = yPts[self.__triangles]
+
+ segment = numpy.array(((x, y, -1), (x, y, 1)), dtype=numpy.float32)
+ # Picked triangle indices
+ indices = glutils.segmentTrianglesIntersection(
+ segment, self.__picking_triangles
+ )[0]
+ # Point indices
+ indices = numpy.unique(numpy.ravel(self.__triangles[indices]))
+
+ # Sorted from furthest to closest point
+ dists = (xPts[indices] - x) ** 2 + (yPts[indices] - y) ** 2
+ indices = indices[numpy.flip(numpy.argsort(dists), axis=0)]
+
+ return tuple(indices) if len(indices) > 0 else None
+
+ def discard(self):
+ """Release resources on the GPU"""
+ if self.isInitialized():
+ self.__vbos[0].vbo.discard()
+ self.__vbos = None
+ self.__indicesVbo.discard()
+ self.__indicesVbo = None
+
+ def isInitialized(self):
+ return self.__vbos is not None
+
+ def prepare(self):
+ """Allocate resources on the GPU"""
+ if self.__vbos is None:
+ self.__vbos = glutils.vertexBuffer(self.__x_y_color)
+ # Normalization is need for color
+ self.__vbos[-1].normalization = True
+
+ if self.__indicesVbo is None:
+ self.__indicesVbo = glutils.VertexBuffer(
+ numpy.ravel(self.__triangles),
+ usage=gl.GL_STATIC_DRAW,
+ target=gl.GL_ELEMENT_ARRAY_BUFFER,
+ )
+
+ def render(self, context):
+ """Perform rendering
+
+ :param RenderContext context: Rendering information
+ """
+ self.prepare()
+
+ if self.__vbos is None or self.__indicesVbo is None:
+ return # Nothing to display
+
+ self._PROGRAM.use()
+
+ gl.glUniformMatrix4fv(
+ self._PROGRAM.uniforms["matrix"],
+ 1,
+ gl.GL_TRUE,
+ context.matrix.astype(numpy.float32),
+ )
+
+ gl.glUniform1f(self._PROGRAM.uniforms["alpha"], self.__alpha)
+
+ for index, name in enumerate(("xPos", "yPos", "color")):
+ attr = self._PROGRAM.attributes[name]
+ gl.glEnableVertexAttribArray(attr)
+ self.__vbos[index].setVertexAttrib(attr)
+
+ with self.__indicesVbo:
+ gl.glDrawElements(
+ gl.GL_TRIANGLES,
+ self.__triangles.size,
+ glutils.numpyToGLType(self.__triangles.dtype),
+ ctypes.c_void_p(0),
+ )
diff --git a/src/silx/gui/plot/backends/glutils/GLSupport.py b/src/silx/gui/plot/backends/glutils/GLSupport.py
new file mode 100644
index 0000000..c9afda0
--- /dev/null
+++ b/src/silx/gui/plot/backends/glutils/GLSupport.py
@@ -0,0 +1,174 @@
+# /*##########################################################################
+#
+# Copyright (c) 2014-2018 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ############################################################################*/
+"""
+This module provides convenient classes and functions for OpenGL rendering.
+"""
+
+__authors__ = ["T. Vincent"]
+__license__ = "MIT"
+__date__ = "03/04/2017"
+
+
+import numpy
+
+from ...._glutils import gl
+
+
+def buildFillMaskIndices(nIndices, dtype=None):
+ """Returns triangle strip indices for rendering a filled polygon mask
+
+ :param int nIndices: Number of points
+ :param Union[numpy.dtype,None] dtype:
+ If specified the dtype of the returned indices array
+ :return: 1D array of indices constructing a triangle strip
+ :rtype: numpy.ndarray
+ """
+ if dtype is None:
+ if nIndices <= numpy.iinfo(numpy.uint16).max + 1:
+ dtype = numpy.uint16
+ else:
+ dtype = numpy.uint32
+
+ lastIndex = nIndices - 1
+ splitIndex = lastIndex // 2 + 1
+ indices = numpy.empty(nIndices, dtype=dtype)
+ indices[::2] = numpy.arange(0, splitIndex, step=1, dtype=dtype)
+ indices[1::2] = numpy.arange(lastIndex, splitIndex - 1, step=-1, dtype=dtype)
+ return indices
+
+
+class FilledShape2D(object):
+ _NO_HATCH = 0
+ _HATCH_STEP = 20
+
+ def __init__(self, points, style="solid", color=(0.0, 0.0, 0.0, 1.0)):
+ self.vertices = numpy.array(points, dtype=numpy.float32, copy=False)
+ self._indices = buildFillMaskIndices(len(self.vertices))
+
+ tVertex = numpy.transpose(self.vertices)
+ xMin, xMax = min(tVertex[0]), max(tVertex[0])
+ yMin, yMax = min(tVertex[1]), max(tVertex[1])
+ self.bboxVertices = numpy.array(
+ ((xMin, yMin), (xMin, yMax), (xMax, yMin), (xMax, yMax)),
+ dtype=numpy.float32,
+ )
+ self._xMin, self._xMax = xMin, xMax
+ self._yMin, self._yMax = yMin, yMax
+
+ self.style = style
+ self.color = color
+
+ def render(self, posAttrib, colorUnif, hatchStepUnif):
+ assert self.style in ("hatch", "solid")
+ gl.glUniform4f(colorUnif, *self.color)
+ step = self._HATCH_STEP if self.style == "hatch" else self._NO_HATCH
+ gl.glUniform1i(hatchStepUnif, step)
+
+ # Prepare fill mask
+ gl.glEnableVertexAttribArray(posAttrib)
+ gl.glVertexAttribPointer(
+ posAttrib, 2, gl.GL_FLOAT, gl.GL_FALSE, 0, self.vertices
+ )
+
+ gl.glEnable(gl.GL_STENCIL_TEST)
+ gl.glStencilMask(1)
+ gl.glStencilFunc(gl.GL_ALWAYS, 1, 1)
+ gl.glStencilOp(gl.GL_INVERT, gl.GL_INVERT, gl.GL_INVERT)
+ gl.glColorMask(gl.GL_FALSE, gl.GL_FALSE, gl.GL_FALSE, gl.GL_FALSE)
+ gl.glDepthMask(gl.GL_FALSE)
+
+ gl.glDrawElements(
+ gl.GL_TRIANGLE_STRIP,
+ len(self._indices),
+ gl.GL_UNSIGNED_SHORT,
+ self._indices,
+ )
+
+ gl.glStencilFunc(gl.GL_EQUAL, 1, 1)
+ # Reset stencil while drawing
+ gl.glStencilOp(gl.GL_ZERO, gl.GL_ZERO, gl.GL_ZERO)
+ gl.glColorMask(gl.GL_TRUE, gl.GL_TRUE, gl.GL_TRUE, gl.GL_TRUE)
+ gl.glDepthMask(gl.GL_TRUE)
+
+ gl.glVertexAttribPointer(
+ posAttrib, 2, gl.GL_FLOAT, gl.GL_FALSE, 0, self.bboxVertices
+ )
+ gl.glDrawArrays(gl.GL_TRIANGLE_STRIP, 0, len(self.bboxVertices))
+
+ gl.glDisable(gl.GL_STENCIL_TEST)
+
+
+# matrix ######################################################################
+
+
+def mat4Ortho(left, right, bottom, top, near, far):
+ """Orthographic projection matrix (row-major)"""
+ return numpy.array(
+ (
+ (2.0 / (right - left), 0.0, 0.0, -(right + left) / float(right - left)),
+ (0.0, 2.0 / (top - bottom), 0.0, -(top + bottom) / float(top - bottom)),
+ (0.0, 0.0, -2.0 / (far - near), -(far + near) / float(far - near)),
+ (0.0, 0.0, 0.0, 1.0),
+ ),
+ dtype=numpy.float64,
+ )
+
+
+def mat4Translate(x=0.0, y=0.0, z=0.0):
+ """Translation matrix (row-major)"""
+ return numpy.array(
+ (
+ (1.0, 0.0, 0.0, x),
+ (0.0, 1.0, 0.0, y),
+ (0.0, 0.0, 1.0, z),
+ (0.0, 0.0, 0.0, 1.0),
+ ),
+ dtype=numpy.float64,
+ )
+
+
+def mat4Scale(sx=1.0, sy=1.0, sz=1.0):
+ """Scale matrix (row-major)"""
+ return numpy.array(
+ (
+ (sx, 0.0, 0.0, 0.0),
+ (0.0, sy, 0.0, 0.0),
+ (0.0, 0.0, sz, 0.0),
+ (0.0, 0.0, 0.0, 1.0),
+ ),
+ dtype=numpy.float64,
+ )
+
+
+def mat4Identity():
+ """Identity matrix"""
+ return numpy.array(
+ (
+ (1.0, 0.0, 0.0, 0.0),
+ (0.0, 1.0, 0.0, 0.0),
+ (0.0, 0.0, 1.0, 0.0),
+ (0.0, 0.0, 0.0, 1.0),
+ ),
+ dtype=numpy.float64,
+ )
diff --git a/src/silx/gui/plot/backends/glutils/GLText.py b/src/silx/gui/plot/backends/glutils/GLText.py
new file mode 100644
index 0000000..15d7a70
--- /dev/null
+++ b/src/silx/gui/plot/backends/glutils/GLText.py
@@ -0,0 +1,297 @@
+# /*##########################################################################
+#
+# Copyright (c) 2014-2023 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ############################################################################*/
+"""
+This module provides minimalistic text support for OpenGL.
+It provides Latin-1 (ISO8859-1) characters for one monospace font at one size.
+"""
+
+from __future__ import annotations
+
+__authors__ = ["T. Vincent"]
+__license__ = "MIT"
+__date__ = "03/04/2017"
+
+
+from collections import OrderedDict
+import weakref
+
+import numpy
+
+from .... import qt
+from ...._glutils import font, gl, Context, Program, Texture
+from .GLSupport import mat4Translate
+from silx.gui.colors import RGBAColorType
+
+
+class _Cache:
+ """LRU (Least Recent Used) cache.
+
+ :param int maxsize: Maximum number of (key, value) pairs in the cache
+ :param callable callback:
+ Called when a (key, value) pair is removed from the cache.
+ It must take 2 arguments: key and value.
+ """
+
+ def __init__(self, maxsize=128, callback=None):
+ self._maxsize = int(maxsize)
+ self._callback = callback
+ self._cache = OrderedDict() # Needed for popitem(last=False)
+
+ def __contains__(self, item):
+ return item in self._cache
+
+ def __getitem__(self, key):
+ if key in self._cache:
+ # Remove/add key from ordered dict to store last access info
+ value = self._cache.pop(key)
+ self._cache[key] = value
+ return value
+ else:
+ raise KeyError
+
+ def __setitem__(self, key, value):
+ """Add a key, value pair to the cache.
+
+ :param key: The key to set
+ :param value: The corresponding value
+ """
+ if key not in self._cache and len(self._cache) >= self._maxsize:
+ removedKey, removedValue = self._cache.popitem(last=False)
+ if self._callback is not None:
+ self._callback(removedKey, removedValue)
+ self._cache[key] = value
+
+
+# Text2D ######################################################################
+
+LEFT, CENTER, RIGHT = "left", "center", "right"
+TOP, BASELINE, BOTTOM = "top", "baseline", "bottom"
+ROTATE_90, ROTATE_180, ROTATE_270 = 90, 180, 270
+
+
+class Text2D:
+ _SHADERS = {
+ "vertex": """
+ #version 120
+
+ attribute vec2 position;
+ attribute vec2 texCoords;
+ uniform mat4 matrix;
+
+ varying vec2 vCoords;
+
+ void main(void) {
+ gl_Position = matrix * vec4(position, 0.0, 1.0);
+ vCoords = texCoords;
+ }
+ """,
+ "fragment": """
+ #version 120
+
+ uniform sampler2D texText;
+ uniform vec4 color;
+ uniform vec4 bgColor;
+
+ varying vec2 vCoords;
+
+ void main(void) {
+ if (vCoords.x < 0.0 || vCoords.x > 1.0 || vCoords.y < 0.0 || vCoords.y > 1.0) {
+ gl_FragColor = bgColor;
+ } else {
+ gl_FragColor = mix(bgColor, color, texture2D(texText, vCoords).r);
+ }
+ }
+ """,
+ }
+
+ _program = Program(_SHADERS["vertex"], _SHADERS["fragment"], attrib0="position")
+
+ # Discard texture objects when removed from the cache
+ _textures = weakref.WeakKeyDictionary()
+ """Cache already created textures"""
+
+ def __init__(
+ self,
+ text: str,
+ font: qt.QFont,
+ x: float = 0.0,
+ y: float = 0.0,
+ color: tuple[float, float, float, float] = (0.0, 0.0, 0.0, 1.0),
+ bgColor: RGBAColorType | None = None,
+ align: str = LEFT,
+ valign: str = BASELINE,
+ rotate: float = 0.0,
+ devicePixelRatio: float = 1.0,
+ padding: int = 0,
+ ):
+ self.devicePixelRatio = devicePixelRatio
+ self.font = font
+ self._vertices = None
+ self._text = text
+ self._padding = padding
+ self.x = x
+ self.y = y
+ self.color = color
+ self.bgColor = bgColor
+
+ if align not in (LEFT, CENTER, RIGHT):
+ raise ValueError("Horizontal alignment not supported: {0}".format(align))
+ self._align = align
+
+ if valign not in (TOP, CENTER, BASELINE, BOTTOM):
+ raise ValueError("Vertical alignment not supported: {0}".format(valign))
+ self._valign = valign
+
+ self._rotate = numpy.radians(rotate)
+
+ def _getTexture(self, dotsPerInch: float) -> tuple[Texture, int]:
+ # Retrieve/initialize texture cache for current context
+ key = self.text, self.font.key(), dotsPerInch
+
+ context = Context.getCurrent()
+ if context not in self._textures:
+ self._textures[context] = _Cache(
+ callback=lambda key, value: value[0].discard()
+ )
+ textures = self._textures[context]
+
+ if key not in textures:
+ image, offset = font.rasterText(self.text, self.font, dotsPerInch)
+
+ texture = Texture(
+ gl.GL_RED,
+ data=image,
+ minFilter=gl.GL_NEAREST,
+ magFilter=gl.GL_NEAREST,
+ wrap=(gl.GL_CLAMP_TO_EDGE, gl.GL_CLAMP_TO_EDGE),
+ )
+ texture.prepare()
+ textures[key] = texture, offset
+
+ return textures[key]
+
+ @property
+ def text(self) -> str:
+ return self._text
+
+ @property
+ def padding(self) -> int:
+ return self._padding
+
+ def getVertices(self, offset: int, shape: tuple[int, int]) -> numpy.ndarray:
+ height, width = shape
+
+ if self._align == LEFT:
+ xOrig = 0
+ elif self._align == RIGHT:
+ xOrig = -width
+ else: # CENTER
+ xOrig = -width // 2
+
+ if self._valign == BASELINE:
+ yOrig = -offset
+ elif self._valign == TOP:
+ yOrig = 0
+ elif self._valign == BOTTOM:
+ yOrig = -height
+ else: # CENTER
+ yOrig = -height // 2
+
+ vertices = numpy.array(
+ (
+ (xOrig, yOrig),
+ (xOrig + width, yOrig),
+ (xOrig, yOrig + height),
+ (xOrig + width, yOrig + height),
+ ),
+ dtype=numpy.float32,
+ )
+
+ cos, sin = numpy.cos(self._rotate), numpy.sin(self._rotate)
+ vertices = numpy.ascontiguousarray(
+ numpy.transpose(
+ numpy.array(
+ (
+ cos * vertices[:, 0] - sin * vertices[:, 1],
+ sin * vertices[:, 0] + cos * vertices[:, 1],
+ ),
+ dtype=numpy.float32,
+ )
+ )
+ )
+
+ return vertices
+
+ def render(self, matrix: numpy.ndarray, dotsPerInch: float):
+ if not self.text.strip():
+ return
+
+ prog = self._program
+ prog.use()
+
+ texUnit = 0
+ texture, offset = self._getTexture(dotsPerInch)
+
+ gl.glUniform1i(prog.uniforms["texText"], texUnit)
+
+ mat = numpy.dot(matrix, mat4Translate(int(self.x), int(self.y)))
+ gl.glUniformMatrix4fv(
+ prog.uniforms["matrix"], 1, gl.GL_TRUE, mat.astype(numpy.float32)
+ )
+
+ gl.glUniform4f(prog.uniforms["color"], *self.color)
+ if self.bgColor is not None:
+ bgColor = self.bgColor
+ else:
+ bgColor = self.color[0], self.color[1], self.color[2], 0.0
+ gl.glUniform4f(prog.uniforms["bgColor"], *bgColor)
+
+ paddingOffset = max(0, int(self.padding * self.devicePixelRatio))
+ height, width = texture.shape
+ vertices = self.getVertices(
+ offset, (height + 2 * paddingOffset, width + 2 * paddingOffset)
+ )
+
+ posAttrib = prog.attributes["position"]
+ gl.glEnableVertexAttribArray(posAttrib)
+ gl.glVertexAttribPointer(posAttrib, 2, gl.GL_FLOAT, gl.GL_FALSE, 0, vertices)
+
+ xoffset = paddingOffset / width
+ yoffset = paddingOffset / height
+ texCoords = numpy.array(
+ (
+ (-xoffset, -yoffset),
+ (1.0 + xoffset, -yoffset),
+ (-xoffset, 1.0 + yoffset),
+ (1.0 + xoffset, 1.0 + yoffset),
+ ),
+ dtype=numpy.float32,
+ ).ravel()
+
+ texAttrib = prog.attributes["texCoords"]
+ gl.glEnableVertexAttribArray(texAttrib)
+ gl.glVertexAttribPointer(texAttrib, 2, gl.GL_FLOAT, gl.GL_FALSE, 0, texCoords)
+
+ with texture:
+ gl.glDrawArrays(gl.GL_TRIANGLE_STRIP, 0, 4)
diff --git a/src/silx/gui/plot/backends/glutils/GLTexture.py b/src/silx/gui/plot/backends/glutils/GLTexture.py
new file mode 100644
index 0000000..cbbe7ac
--- /dev/null
+++ b/src/silx/gui/plot/backends/glutils/GLTexture.py
@@ -0,0 +1,269 @@
+# /*##########################################################################
+#
+# Copyright (c) 2014-2020 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ############################################################################*/
+"""This module provides classes wrapping OpenGL texture."""
+
+__authors__ = ["T. Vincent"]
+__license__ = "MIT"
+__date__ = "03/04/2017"
+
+
+from ctypes import c_void_p
+import logging
+
+import numpy
+
+from ...._glutils import gl, Texture, numpyToGLType
+
+
+_logger = logging.getLogger(__name__)
+
+
+def _checkTexture2D(internalFormat, shape, format_=None, type_=gl.GL_FLOAT, border=0):
+ """Check if texture size with provided parameters is supported
+
+ :rtype: bool
+ """
+ height, width = shape
+ gl.glTexImage2D(
+ gl.GL_PROXY_TEXTURE_2D,
+ 0,
+ internalFormat,
+ width,
+ height,
+ border,
+ format_ or internalFormat,
+ type_,
+ c_void_p(0),
+ )
+ width = gl.glGetTexLevelParameteriv(gl.GL_PROXY_TEXTURE_2D, 0, gl.GL_TEXTURE_WIDTH)
+ return bool(width)
+
+
+MIN_TEXTURE_SIZE = 64
+
+
+def _getMaxSquareTexture2DSize(
+ internalFormat=gl.GL_RGBA, format_=None, type_=gl.GL_FLOAT, border=0
+):
+ """Returns a supported size for a corresponding square texture
+
+ :returns: GL_MAX_TEXTURE_SIZE or a smaller supported size (not optimal)
+ :rtype: int
+ """
+ # Is this useful?
+ maxTexSize = gl.glGetIntegerv(gl.GL_MAX_TEXTURE_SIZE)
+ while maxTexSize > MIN_TEXTURE_SIZE and not _checkTexture2D(
+ internalFormat, (maxTexSize, maxTexSize), format_, type_, border
+ ):
+ maxTexSize //= 2
+ return max(MIN_TEXTURE_SIZE, maxTexSize)
+
+
+class Image(object):
+ """Image of any size eventually using multiple textures or larger texture"""
+
+ _WRAP = (gl.GL_CLAMP_TO_EDGE, gl.GL_CLAMP_TO_EDGE)
+ _MIN_FILTER = gl.GL_NEAREST
+ _MAG_FILTER = gl.GL_NEAREST
+
+ def __init__(self, internalFormat, data, format_=None, texUnit=0):
+ self.internalFormat = internalFormat
+ self.height, self.width = data.shape[0:2]
+ type_ = numpyToGLType(data.dtype)
+
+ if _checkTexture2D(internalFormat, data.shape[0:2], format_, type_):
+ texture = Texture(
+ internalFormat,
+ data,
+ format_,
+ texUnit=texUnit,
+ minFilter=self._MIN_FILTER,
+ magFilter=self._MAG_FILTER,
+ wrap=self._WRAP,
+ )
+ texture.prepare()
+ vertices = numpy.array(
+ (
+ (0.0, 0.0, 0.0, 0.0),
+ (self.width, 0.0, 1.0, 0.0),
+ (0.0, self.height, 0.0, 1.0),
+ (self.width, self.height, 1.0, 1.0),
+ ),
+ dtype=numpy.float32,
+ )
+ self.tiles = (
+ (
+ texture,
+ vertices,
+ {
+ "xOrigData": 0,
+ "yOrigData": 0,
+ "wData": self.width,
+ "hData": self.height,
+ },
+ ),
+ )
+
+ else:
+ # Handle dimension too large: make tiles
+ maxTexSize = _getMaxSquareTexture2DSize(internalFormat, format_, type_)
+
+ nCols = (self.width + maxTexSize - 1) // maxTexSize
+ colWidths = [self.width // nCols] * nCols
+ colWidths[-1] += self.width % nCols
+
+ nRows = (self.height + maxTexSize - 1) // maxTexSize
+ rowHeights = [self.height // nRows] * nRows
+ rowHeights[-1] += self.height % nRows
+
+ tiles = []
+ yOrig = 0
+ for hData in rowHeights:
+ xOrig = 0
+ for wData in colWidths:
+ if (
+ hData < MIN_TEXTURE_SIZE or wData < MIN_TEXTURE_SIZE
+ ) and not _checkTexture2D(
+ internalFormat, (hData, wData), format_, type_
+ ):
+ # Ensure texture size is at least MIN_TEXTURE_SIZE
+ tH = max(hData, MIN_TEXTURE_SIZE)
+ tW = max(wData, MIN_TEXTURE_SIZE)
+
+ uMax, vMax = float(wData) / tW, float(hData) / tH
+
+ # TODO issue with type_ and alignment
+ texture = Texture(
+ internalFormat,
+ data=None,
+ format_=format_,
+ shape=(tH, tW),
+ texUnit=texUnit,
+ minFilter=self._MIN_FILTER,
+ magFilter=self._MAG_FILTER,
+ wrap=self._WRAP,
+ )
+ # TODO handle unpack
+ texture.update(
+ format_, data[yOrig : yOrig + hData, xOrig : xOrig + wData]
+ )
+ # texture.update(format_, type_, data,
+ # width=wData, height=hData,
+ # unpackRowLength=width,
+ # unpackSkipPixels=xOrig,
+ # unpackSkipRows=yOrig)
+ else:
+ uMax, vMax = 1, 1
+ # TODO issue with type_ and unpacking tiles
+ # TODO idea to handle unpack: use array strides
+ # As it is now, it will make a copy
+ texture = Texture(
+ internalFormat,
+ data[yOrig : yOrig + hData, xOrig : xOrig + wData],
+ format_,
+ texUnit=texUnit,
+ minFilter=self._MIN_FILTER,
+ magFilter=self._MAG_FILTER,
+ wrap=self._WRAP,
+ )
+ # TODO
+ # unpackRowLength=width,
+ # unpackSkipPixels=xOrig,
+ # unpackSkipRows=yOrig)
+ vertices = numpy.array(
+ (
+ (xOrig, yOrig, 0.0, 0.0),
+ (xOrig + wData, yOrig, uMax, 0.0),
+ (xOrig, yOrig + hData, 0.0, vMax),
+ (xOrig + wData, yOrig + hData, uMax, vMax),
+ ),
+ dtype=numpy.float32,
+ )
+ texture.prepare()
+ tiles.append(
+ (
+ texture,
+ vertices,
+ {
+ "xOrigData": xOrig,
+ "yOrigData": yOrig,
+ "wData": wData,
+ "hData": hData,
+ },
+ )
+ )
+ xOrig += wData
+ yOrig += hData
+ self.tiles = tuple(tiles)
+
+ def discard(self):
+ for texture, vertices, _ in self.tiles:
+ texture.discard()
+ del self.tiles
+
+ def updateAll(self, format_, data, texUnit=0):
+ if not hasattr(self, "tiles"):
+ raise RuntimeError("No texture, discard has already been called")
+
+ assert data.shape[:2] == (self.height, self.width)
+ if len(self.tiles) == 1:
+ self.tiles[0][0].update(format_, data, texUnit=texUnit)
+ else:
+ for texture, _, info in self.tiles:
+ yOrig, xOrig = info["yOrigData"], info["xOrigData"]
+ height, width = info["hData"], info["wData"]
+ texture.update(
+ format_,
+ data[yOrig : yOrig + height, xOrig : xOrig + width],
+ texUnit=texUnit,
+ )
+ texture.prepare()
+ # TODO check
+ # width=info['wData'], height=info['hData'],
+ # texUnit=texUnit, unpackAlign=unpackAlign,
+ # unpackRowLength=self.width,
+ # unpackSkipPixels=info['xOrigData'],
+ # unpackSkipRows=info['yOrigData'])
+
+ def render(self, posAttrib, texAttrib, texUnit=0):
+ try:
+ tiles = self.tiles
+ except AttributeError:
+ raise RuntimeError("No texture, discard has already been called")
+
+ for texture, vertices, _ in tiles:
+ texture.bind(texUnit)
+
+ stride = vertices.shape[-1] * vertices.itemsize
+ gl.glEnableVertexAttribArray(posAttrib)
+ gl.glVertexAttribPointer(
+ posAttrib, 2, gl.GL_FLOAT, gl.GL_FALSE, stride, vertices
+ )
+
+ texCoordsPtr = c_void_p(vertices.ctypes.data + 2 * vertices.itemsize)
+ gl.glEnableVertexAttribArray(texAttrib)
+ gl.glVertexAttribPointer(
+ texAttrib, 2, gl.GL_FLOAT, gl.GL_FALSE, stride, texCoordsPtr
+ )
+ gl.glDrawArrays(gl.GL_TRIANGLE_STRIP, 0, len(vertices))
diff --git a/src/silx/gui/plot/backends/glutils/PlotImageFile.py b/src/silx/gui/plot/backends/glutils/PlotImageFile.py
new file mode 100644
index 0000000..1622122
--- /dev/null
+++ b/src/silx/gui/plot/backends/glutils/PlotImageFile.py
@@ -0,0 +1,159 @@
+# /*##########################################################################
+#
+# Copyright (c) 2014-2023 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ############################################################################*/
+"""Function to save an image to a file."""
+
+__authors__ = ["T. Vincent"]
+__license__ = "MIT"
+__date__ = "03/04/2017"
+
+
+import base64
+import struct
+import zlib
+
+from fabio.TiffIO import TiffIO
+
+
+# Image writer ################################################################
+
+
+def convertRGBDataToPNG(data):
+ """Convert a RGB bitmap to PNG.
+
+ It only supports RGB bitmap with one byte per channel stored as a 3D array.
+ See `Definitive Guide <http://www.libpng.org/pub/png/book/>`_ and
+ `Specification <http://www.libpng.org/pub/png/spec/1.2/>`_ for details.
+
+ :param data: A 3D array (h, w, rgb) storing an RGB image
+ :type data: numpy.ndarray of unsigned bytes
+ :returns: The PNG encoded data
+ :rtype: bytes
+ """
+ height, width = data.shape[0], data.shape[1]
+ depth = 8 # 8 bit per channel
+ colorType = 2 # 'truecolor' = RGB
+ interlace = 0 # No
+
+ IHDRdata = struct.pack(
+ ">ccccIIBBBBB",
+ b"I",
+ b"H",
+ b"D",
+ b"R",
+ width,
+ height,
+ depth,
+ colorType,
+ 0,
+ 0,
+ interlace,
+ )
+
+ # Add filter 'None' before each scanline
+ preparedData = b"\x00" + b"\x00".join(line.tobytes() for line in data)
+ compressedData = zlib.compress(preparedData, 8)
+
+ IDATdata = struct.pack("cccc", b"I", b"D", b"A", b"T")
+ IDATdata += compressedData
+
+ return b"".join(
+ [
+ b"\x89PNG\r\n\x1a\n", # PNG signature
+ # IHDR chunk: Image Header
+ struct.pack(">I", 13), # length
+ IHDRdata,
+ struct.pack(">I", zlib.crc32(IHDRdata) & 0xFFFFFFFF), # CRC
+ # IDAT chunk: Payload
+ struct.pack(">I", len(compressedData)),
+ IDATdata,
+ struct.pack(">I", zlib.crc32(IDATdata) & 0xFFFFFFFF), # CRC
+ b"\x00\x00\x00\x00IEND\xaeB`\x82", # IEND chunk: footer
+ ]
+ )
+
+
+def saveImageToFile(data, fileNameOrObj, fileFormat):
+ """Save a RGB image to a file.
+
+ :param data: A 3D array (h, w, 3) storing an RGB image.
+ :type data: numpy.ndarray with of unsigned bytes.
+ :param fileNameOrObj: Filename or object to use to write the image.
+ :type fileNameOrObj: A str or a 'file-like' object with a 'write' method.
+ :param str fileFormat: The type of the file in: 'png', 'ppm', 'svg', 'tiff'.
+ """
+ assert len(data.shape) == 3
+ assert data.shape[2] == 3
+ assert fileFormat in ("png", "ppm", "svg", "tif", "tiff")
+
+ if not hasattr(fileNameOrObj, "write"):
+ if fileFormat in ("png", "ppm", "tiff"):
+ # Open in binary mode
+ fileObj = open(fileNameOrObj, "wb")
+ else:
+ fileObj = open(fileNameOrObj, "w", newline="")
+ else: # Use as a file-like object
+ fileObj = fileNameOrObj
+
+ if fileFormat == "svg":
+ height, width = data.shape[:2]
+ base64Data = base64.b64encode(convertRGBDataToPNG(data))
+
+ fileObj.write('<?xml version="1.0" encoding="UTF-8" standalone="no"?>\n')
+ fileObj.write('<!DOCTYPE svg PUBLIC "-//W3C//DTD SVG 1.1//EN"\n')
+ fileObj.write(' "http://www.w3.org/Graphics/SVG/1.1/DTD/svg11.dtd">\n')
+ fileObj.write('<svg xmlns:xlink="http://www.w3.org/1999/xlink"\n')
+ fileObj.write(' xmlns="http://www.w3.org/2000/svg"\n')
+ fileObj.write(' version="1.1"\n')
+ fileObj.write(' width="%d"\n' % width)
+ fileObj.write(' height="%d">\n' % height)
+ fileObj.write(' <image xlink:href="data:image/png;base64,')
+ fileObj.write(base64Data.decode("ascii"))
+ fileObj.write('"\n')
+ fileObj.write(' x="0"\n')
+ fileObj.write(' y="0"\n')
+ fileObj.write(' width="%d"\n' % width)
+ fileObj.write(' height="%d"\n' % height)
+ fileObj.write(' id="image" />\n')
+ fileObj.write("</svg>")
+
+ elif fileFormat == "ppm":
+ height, width = data.shape[:2]
+
+ fileObj.write(b"P6\n")
+ fileObj.write(b"%d %d\n" % (width, height))
+ fileObj.write(b"255\n")
+ fileObj.write(data.tobytes())
+
+ elif fileFormat == "png":
+ fileObj.write(convertRGBDataToPNG(data))
+
+ elif fileFormat in ("tif", "tiff"):
+ if fileObj == fileNameOrObj:
+ raise NotImplementedError("Save TIFF to a file-like object not implemented")
+
+ tif = TiffIO(fileNameOrObj, mode="wb+")
+ tif.writeImage(data, info={"Title": "OpenGL Plot Snapshot"})
+
+ if fileObj != fileNameOrObj:
+ fileObj.close()
diff --git a/src/silx/gui/plot/backends/glutils/__init__.py b/src/silx/gui/plot/backends/glutils/__init__.py
new file mode 100644
index 0000000..bc15b78
--- /dev/null
+++ b/src/silx/gui/plot/backends/glutils/__init__.py
@@ -0,0 +1,45 @@
+# /*##########################################################################
+#
+# Copyright (c) 2014-2020 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ############################################################################*/
+"""This module provides convenient classes for the OpenGL rendering backend.
+"""
+
+__authors__ = ["T. Vincent"]
+__license__ = "MIT"
+__date__ = "03/04/2017"
+
+
+import logging
+
+
+_logger = logging.getLogger(__name__)
+
+
+from .GLPlotCurve import * # noqa
+from .GLPlotFrame import * # noqa
+from .GLPlotImage import * # noqa
+from .GLPlotItem import GLPlotItem, RenderContext # noqa
+from .GLPlotTriangles import GLPlotTriangles # noqa
+from .GLSupport import * # noqa
+from .GLText import * # noqa
+from .GLTexture import * # noqa
diff --git a/src/silx/gui/plot/items/__init__.py b/src/silx/gui/plot/items/__init__.py
new file mode 100644
index 0000000..bbb4220
--- /dev/null
+++ b/src/silx/gui/plot/items/__init__.py
@@ -0,0 +1,80 @@
+# /*##########################################################################
+#
+# Copyright (c) 2017-2022 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""This package provides classes that describes :class:`.PlotWidget` content.
+
+Instances of those classes are returned by :class:`.PlotWidget` methods that give
+access to its content such as :meth:`.PlotWidget.getCurve`, :meth:`.PlotWidget.getImage`.
+"""
+
+__authors__ = ["T. Vincent"]
+__license__ = "MIT"
+__date__ = "22/06/2017"
+
+from .core import (
+ Item,
+ DataItem, # noqa
+ LabelsMixIn,
+ DraggableMixIn,
+ ColormapMixIn,
+ LineGapColorMixIn, # noqa
+ SymbolMixIn,
+ ColorMixIn,
+ YAxisMixIn,
+ FillMixIn, # noqa
+ AlphaMixIn,
+ LineMixIn,
+ ScatterVisualizationMixIn, # noqa
+ ComplexMixIn,
+ ItemChangedType,
+ PointsBase,
+) # noqa
+from .complex import ImageComplexData # noqa
+from .curve import Curve, CurveStyle # noqa
+from .histogram import Histogram # noqa
+from .image import (
+ ImageBase,
+ ImageData,
+ ImageDataBase,
+ ImageRgba,
+ ImageStack,
+ MaskImageData,
+) # noqa
+from .image_aggregated import ImageDataAggregated # noqa
+from .shape import Line, Shape, BoundingRect, XAxisExtent, YAxisExtent # noqa
+from .scatter import Scatter # noqa
+from .marker import MarkerBase, Marker, XMarker, YMarker # noqa
+from .axis import Axis, XAxis, YAxis, YRightAxis
+
+DATA_ITEMS = (
+ ImageComplexData,
+ Curve,
+ Histogram,
+ ImageBase,
+ Scatter,
+ BoundingRect,
+ XAxisExtent,
+ YAxisExtent,
+)
+"""Classes of items representing data and to consider to compute data bounds.
+"""
diff --git a/src/silx/gui/plot/items/_arc_roi.py b/src/silx/gui/plot/items/_arc_roi.py
new file mode 100644
index 0000000..658573a
--- /dev/null
+++ b/src/silx/gui/plot/items/_arc_roi.py
@@ -0,0 +1,992 @@
+# /*##########################################################################
+#
+# Copyright (c) 2018-2023 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""This module provides Arc ROI item for the :class:`~silx.gui.plot.PlotWidget`.
+"""
+
+__authors__ = ["V. Valls"]
+__license__ = "MIT"
+__date__ = "28/06/2018"
+
+import logging
+import numpy
+import enum
+from typing import Tuple
+
+from ... import utils
+from .. import items
+from ...colors import rgba
+from ....utils.proxy import docstring
+from ._roi_base import HandleBasedROI
+from ._roi_base import InteractionModeMixIn
+from ._roi_base import RoiInteractionMode
+
+
+logger = logging.getLogger(__name__)
+
+
+class _ArcGeometry:
+ """
+ Non-mutable object to store the geometry of the arc ROI.
+
+ The aim is is to switch between consistent state without dealing with
+ intermediate values.
+ """
+
+ def __init__(
+ self,
+ center,
+ startPoint,
+ endPoint,
+ radius,
+ weight,
+ startAngle,
+ endAngle,
+ closed=False,
+ ):
+ """Constructor for a consistent arc geometry.
+
+ There is also specific class method to create different kind of arc
+ geometry.
+ """
+ self.center = center
+ self.startPoint = startPoint
+ self.endPoint = endPoint
+ self.radius = radius
+ self.weight = weight
+ self.startAngle = startAngle
+ self.endAngle = endAngle
+ self._closed = closed
+
+ @classmethod
+ def createEmpty(cls):
+ """Create an arc geometry from an empty shape"""
+ zero = numpy.array([0, 0])
+ return cls(zero, zero.copy(), zero.copy(), 0, 0, 0, 0)
+
+ @classmethod
+ def createRect(cls, startPoint, endPoint, weight):
+ """Create an arc geometry from a definition of a rectangle"""
+ return cls(None, startPoint, endPoint, None, weight, None, None, False)
+
+ @classmethod
+ def createCircle(
+ cls, center, startPoint, endPoint, radius, weight, startAngle, endAngle
+ ):
+ """Create an arc geometry from a definition of a circle"""
+ return cls(
+ center, startPoint, endPoint, radius, weight, startAngle, endAngle, True
+ )
+
+ def withWeight(self, weight):
+ """Return a new geometry based on this object, with a specific weight"""
+ return _ArcGeometry(
+ self.center,
+ self.startPoint,
+ self.endPoint,
+ self.radius,
+ weight,
+ self.startAngle,
+ self.endAngle,
+ self._closed,
+ )
+
+ def withRadius(self, radius):
+ """Return a new geometry based on this object, with a specific radius.
+
+ The weight and the center is conserved.
+ """
+ startPoint = (
+ self.center + (self.startPoint - self.center) / self.radius * radius
+ )
+ endPoint = self.center + (self.endPoint - self.center) / self.radius * radius
+ return _ArcGeometry(
+ self.center,
+ startPoint,
+ endPoint,
+ radius,
+ self.weight,
+ self.startAngle,
+ self.endAngle,
+ self._closed,
+ )
+
+ def withStartAngle(self, startAngle):
+ """Return a new geometry based on this object, with a specific start angle"""
+ vector = numpy.array([numpy.cos(startAngle), numpy.sin(startAngle)])
+ startPoint = self.center + vector * self.radius
+
+ # Never add more than 180 to maintain coherency
+ deltaAngle = startAngle - self.startAngle
+ if deltaAngle > numpy.pi:
+ deltaAngle -= numpy.pi * 2
+ elif deltaAngle < -numpy.pi:
+ deltaAngle += numpy.pi * 2
+
+ startAngle = self.startAngle + deltaAngle
+ return _ArcGeometry(
+ self.center,
+ startPoint,
+ self.endPoint,
+ self.radius,
+ self.weight,
+ startAngle,
+ self.endAngle,
+ self._closed,
+ )
+
+ def withEndAngle(self, endAngle):
+ """Return a new geometry based on this object, with a specific end angle"""
+ vector = numpy.array([numpy.cos(endAngle), numpy.sin(endAngle)])
+ endPoint = self.center + vector * self.radius
+
+ # Never add more than 180 to maintain coherency
+ deltaAngle = endAngle - self.endAngle
+ if deltaAngle > numpy.pi:
+ deltaAngle -= numpy.pi * 2
+ elif deltaAngle < -numpy.pi:
+ deltaAngle += numpy.pi * 2
+
+ endAngle = self.endAngle + deltaAngle
+ return _ArcGeometry(
+ self.center,
+ self.startPoint,
+ endPoint,
+ self.radius,
+ self.weight,
+ self.startAngle,
+ endAngle,
+ self._closed,
+ )
+
+ def translated(self, dx, dy):
+ """Return the translated geometry by dx, dy"""
+ delta = numpy.array([dx, dy])
+ center = None if self.center is None else self.center + delta
+ startPoint = None if self.startPoint is None else self.startPoint + delta
+ endPoint = None if self.endPoint is None else self.endPoint + delta
+ return _ArcGeometry(
+ center,
+ startPoint,
+ endPoint,
+ self.radius,
+ self.weight,
+ self.startAngle,
+ self.endAngle,
+ self._closed,
+ )
+
+ def getKind(self):
+ """Returns the kind of shape defined"""
+ if self.center is None:
+ return "rect"
+ elif numpy.isnan(self.startAngle):
+ return "point"
+ elif self.isClosed():
+ if self.weight <= 0 or self.weight * 0.5 >= self.radius:
+ return "circle"
+ else:
+ return "donut"
+ else:
+ if self.weight * 0.5 < self.radius:
+ return "arc"
+ else:
+ return "camembert"
+
+ def isClosed(self):
+ """Returns True if the geometry is a circle like"""
+ if self._closed is not None:
+ return self._closed
+ delta = numpy.abs(self.endAngle - self.startAngle)
+ self._closed = numpy.isclose(delta, numpy.pi * 2)
+ return self._closed
+
+ def __str__(self):
+ return str(
+ (
+ self.center,
+ self.startPoint,
+ self.endPoint,
+ self.radius,
+ self.weight,
+ self.startAngle,
+ self.endAngle,
+ self._closed,
+ )
+ )
+
+
+class ArcROI(HandleBasedROI, items.LineMixIn, InteractionModeMixIn):
+ """A ROI identifying an arc of a circle with a width.
+
+ This ROI provides
+ - 3 handle to control the curvature
+ - 1 handle to control the weight
+ - 1 anchor to translate the shape.
+ """
+
+ ICON = "add-shape-arc"
+ NAME = "arc ROI"
+ SHORT_NAME = "arc"
+ """Metadata for this kind of ROI"""
+
+ _plotShape = "line"
+ """Plot shape which is used for the first interaction"""
+
+ ThreePointMode = RoiInteractionMode(
+ "3 points", "Provides 3 points to define the main radius circle"
+ )
+ PolarMode = RoiInteractionMode(
+ "Polar", "Provides anchors to edit the ROI in polar coords"
+ )
+ # FIXME: MoveMode was designed cause there is too much anchors
+ # FIXME: It would be good replace it by a dnd on the shape
+ MoveMode = RoiInteractionMode(
+ "Translation", "Provides anchors to only move the ROI"
+ )
+
+ class Role(enum.Enum):
+ """Identify a set of roles which can be used for now to reach some positions"""
+
+ START = 0
+ """Location of the anchor at the start of the arc"""
+ STOP = 1
+ """Location of the anchor at the stop of the arc"""
+ MIDDLE = 2
+ """Location of the anchor at the middle of the arc"""
+ CENTER = 3
+ """Location of the center of the circle"""
+
+ def __init__(self, parent=None):
+ HandleBasedROI.__init__(self, parent=parent)
+ items.LineMixIn.__init__(self)
+ InteractionModeMixIn.__init__(self)
+
+ self._geometry = _ArcGeometry.createEmpty()
+ self._handleLabel = self.addLabelHandle()
+
+ self._handleStart = self.addHandle()
+ self._handleMid = self.addHandle()
+ self._handleEnd = self.addHandle()
+ self._handleWeight = self.addHandle()
+ self._handleWeight._setConstraint(self._arcCurvatureMarkerConstraint)
+ self._handleMove = self.addTranslateHandle()
+
+ shape = items.Shape("polygon")
+ shape.setPoints([[0, 0], [0, 0]])
+ shape.setColor(rgba(self.getColor()))
+ shape.setFill(False)
+ shape.setOverlay(True)
+ shape.setLineStyle(self.getLineStyle())
+ shape.setLineWidth(self.getLineWidth())
+ self.__shape = shape
+ self.addItem(shape)
+
+ self._initInteractionMode(self.ThreePointMode)
+ self._interactiveModeUpdated(self.ThreePointMode)
+
+ def availableInteractionModes(self):
+ """Returns the list of available interaction modes
+
+ :rtype: List[RoiInteractionMode]
+ """
+ return [self.ThreePointMode, self.PolarMode, self.MoveMode]
+
+ def _interactiveModeUpdated(self, modeId):
+ """Set the interaction mode.
+
+ :param RoiInteractionMode modeId:
+ """
+ if modeId is self.ThreePointMode:
+ self._handleStart.setVisible(True)
+ self._handleEnd.setVisible(True)
+ self._handleWeight.setVisible(True)
+ self._handleStart.setSymbol("s")
+ self._handleMid.setSymbol("s")
+ self._handleEnd.setSymbol("s")
+ self._handleWeight.setSymbol("d")
+ self._handleMove.setSymbol("+")
+ elif modeId is self.PolarMode:
+ self._handleStart.setVisible(True)
+ self._handleEnd.setVisible(True)
+ self._handleWeight.setVisible(True)
+ self._handleStart.setSymbol("o")
+ self._handleMid.setSymbol("o")
+ self._handleEnd.setSymbol("o")
+ self._handleWeight.setSymbol("d")
+ self._handleMove.setSymbol("+")
+ elif modeId is self.MoveMode:
+ self._handleStart.setVisible(False)
+ self._handleEnd.setVisible(False)
+ self._handleWeight.setVisible(False)
+ self._handleMid.setSymbol("+")
+ self._handleMove.setSymbol("+")
+ else:
+ assert False
+ if self._geometry.isClosed():
+ if modeId != self.MoveMode:
+ self._handleStart.setSymbol("x")
+ self._handleEnd.setSymbol("x")
+ self._updateHandles()
+
+ def _updated(self, event=None, checkVisibility=True):
+ if event == items.ItemChangedType.VISIBLE:
+ self._updateItemProperty(event, self, self.__shape)
+ super(ArcROI, self)._updated(event, checkVisibility)
+
+ def _updatedStyle(self, event, style):
+ super(ArcROI, self)._updatedStyle(event, style)
+ self.__shape.setColor(style.getColor())
+ self.__shape.setLineStyle(style.getLineStyle())
+ self.__shape.setLineWidth(style.getLineWidth())
+
+ def setFirstShapePoints(self, points):
+ """Initialize the ROI using the points from the first interaction.
+
+ This interaction is constrained by the plot API and only supports few
+ shapes.
+ """
+ # The first shape is a line
+ point0 = points[0]
+ point1 = points[1]
+
+ # Compute a non collinear point for the curvature
+ center = (point1 + point0) * 0.5
+ normal = point1 - center
+ normal = numpy.array((normal[1], -normal[0]))
+ defaultCurvature = numpy.pi / 5.0
+ weightCoef = 0.20
+ mid = center - normal * defaultCurvature
+ distance = numpy.linalg.norm(point0 - point1)
+ weight = distance * weightCoef
+
+ geometry = self._createGeometryFromControlPoints(point0, mid, point1, weight)
+ self._geometry = geometry
+ self._updateHandles()
+
+ def _updateText(self, text):
+ self._handleLabel.setText(text)
+
+ def _updateMidHandle(self):
+ """Keep the same geometry, but update the location of the control
+ points.
+
+ So calling this function do not trigger sigRegionChanged.
+ """
+ geometry = self._geometry
+
+ if geometry.isClosed():
+ start = numpy.array(self._handleStart.getPosition())
+ midPos = geometry.center + geometry.center - start
+ else:
+ if geometry.center is None:
+ midPos = geometry.startPoint * 0.5 + geometry.endPoint * 0.5
+ else:
+ midAngle = geometry.startAngle * 0.5 + geometry.endAngle * 0.5
+ vector = numpy.array([numpy.cos(midAngle), numpy.sin(midAngle)])
+ midPos = geometry.center + geometry.radius * vector
+
+ with utils.blockSignals(self._handleMid):
+ self._handleMid.setPosition(*midPos)
+
+ def _updateWeightHandle(self):
+ geometry = self._geometry
+ if geometry.center is None:
+ # rectangle
+ center = (geometry.startPoint + geometry.endPoint) * 0.5
+ normal = geometry.endPoint - geometry.startPoint
+ normal = numpy.array((normal[1], -normal[0]))
+ distance = numpy.linalg.norm(normal)
+ if distance != 0:
+ normal = normal / distance
+ weightPos = center + normal * geometry.weight * 0.5
+ else:
+ if geometry.isClosed():
+ midAngle = geometry.startAngle + numpy.pi * 0.5
+ elif geometry.center is not None:
+ midAngle = (geometry.startAngle + geometry.endAngle) * 0.5
+ vector = numpy.array([numpy.cos(midAngle), numpy.sin(midAngle)])
+ weightPos = (
+ geometry.center + (geometry.radius + geometry.weight * 0.5) * vector
+ )
+
+ with utils.blockSignals(self._handleWeight):
+ self._handleWeight.setPosition(*weightPos)
+
+ def _getWeightFromHandle(self, weightPos):
+ geometry = self._geometry
+ if geometry.center is None:
+ # rectangle
+ center = (geometry.startPoint + geometry.endPoint) * 0.5
+ return numpy.linalg.norm(center - weightPos) * 2
+ else:
+ distance = numpy.linalg.norm(geometry.center - weightPos)
+ return abs(distance - geometry.radius) * 2
+
+ def _updateHandles(self):
+ geometry = self._geometry
+ with utils.blockSignals(self._handleStart):
+ self._handleStart.setPosition(*geometry.startPoint)
+ with utils.blockSignals(self._handleEnd):
+ self._handleEnd.setPosition(*geometry.endPoint)
+
+ self._updateMidHandle()
+ self._updateWeightHandle()
+ self._updateShape()
+
+ def _updateCurvature(
+ self, start, mid, end, updateCurveHandles, checkClosed=False, updateStart=False
+ ):
+ """Update the curvature using 3 control points in the curve
+
+ :param bool updateCurveHandles: If False curve handles are already at
+ the right location
+ """
+ if checkClosed:
+ closed = self._isCloseInPixel(start, end)
+ else:
+ closed = self._geometry.isClosed()
+ if closed:
+ if updateStart:
+ start = end
+ else:
+ end = start
+
+ if updateCurveHandles:
+ with utils.blockSignals(self._handleStart):
+ self._handleStart.setPosition(*start)
+ with utils.blockSignals(self._handleMid):
+ self._handleMid.setPosition(*mid)
+ with utils.blockSignals(self._handleEnd):
+ self._handleEnd.setPosition(*end)
+
+ weight = self._geometry.weight
+ geometry = self._createGeometryFromControlPoints(
+ start, mid, end, weight, closed=closed
+ )
+ self._geometry = geometry
+
+ self._updateWeightHandle()
+ self._updateShape()
+
+ def _updateCloseInAngle(self, geometry, updateStart):
+ azim = numpy.abs(geometry.endAngle - geometry.startAngle)
+ if numpy.pi < azim < 3 * numpy.pi:
+ closed = self._isCloseInPixel(geometry.startPoint, geometry.endPoint)
+ geometry._closed = closed
+ if closed:
+ sign = 1 if geometry.startAngle < geometry.endAngle else -1
+ if updateStart:
+ geometry.startPoint = geometry.endPoint
+ geometry.startAngle = geometry.endAngle - sign * 2 * numpy.pi
+ else:
+ geometry.endPoint = geometry.startPoint
+ geometry.endAngle = geometry.startAngle + sign * 2 * numpy.pi
+
+ def handleDragUpdated(self, handle, origin, previous, current):
+ modeId = self.getInteractionMode()
+ if handle is self._handleStart:
+ if modeId is self.ThreePointMode:
+ mid = numpy.array(self._handleMid.getPosition())
+ end = numpy.array(self._handleEnd.getPosition())
+ self._updateCurvature(
+ current,
+ mid,
+ end,
+ checkClosed=True,
+ updateStart=True,
+ updateCurveHandles=False,
+ )
+ elif modeId is self.PolarMode:
+ v = current - self._geometry.center
+ startAngle = numpy.angle(complex(v[0], v[1]))
+ geometry = self._geometry.withStartAngle(startAngle)
+ self._updateCloseInAngle(geometry, updateStart=True)
+ self._geometry = geometry
+ self._updateHandles()
+ elif handle is self._handleMid:
+ if modeId is self.ThreePointMode:
+ if self._geometry.isClosed():
+ radius = numpy.linalg.norm(self._geometry.center - current)
+ self._geometry = self._geometry.withRadius(radius)
+ self._updateHandles()
+ else:
+ start = numpy.array(self._handleStart.getPosition())
+ end = numpy.array(self._handleEnd.getPosition())
+ self._updateCurvature(start, current, end, updateCurveHandles=False)
+ elif modeId is self.PolarMode:
+ radius = numpy.linalg.norm(self._geometry.center - current)
+ self._geometry = self._geometry.withRadius(radius)
+ self._updateHandles()
+ elif modeId is self.MoveMode:
+ delta = current - previous
+ self.translate(*delta)
+ elif handle is self._handleEnd:
+ if modeId is self.ThreePointMode:
+ start = numpy.array(self._handleStart.getPosition())
+ mid = numpy.array(self._handleMid.getPosition())
+ self._updateCurvature(
+ start,
+ mid,
+ current,
+ checkClosed=True,
+ updateStart=False,
+ updateCurveHandles=False,
+ )
+ elif modeId is self.PolarMode:
+ v = current - self._geometry.center
+ endAngle = numpy.angle(complex(v[0], v[1]))
+ geometry = self._geometry.withEndAngle(endAngle)
+ self._updateCloseInAngle(geometry, updateStart=False)
+ self._geometry = geometry
+ self._updateHandles()
+ elif handle is self._handleWeight:
+ weight = self._getWeightFromHandle(current)
+ self._geometry = self._geometry.withWeight(weight)
+ self._updateShape()
+ elif handle is self._handleMove:
+ delta = current - previous
+ self.translate(*delta)
+
+ def _isCloseInPixel(self, point1, point2):
+ manager = self.parent()
+ if manager is None:
+ return False
+ plot = manager.parent()
+ if plot is None:
+ return False
+ point1 = plot.dataToPixel(*point1)
+ if point1 is None:
+ return False
+ point2 = plot.dataToPixel(*point2)
+ if point2 is None:
+ return False
+ return abs(point1[0] - point2[0]) + abs(point1[1] - point2[1]) < 15
+
+ def _normalizeGeometry(self):
+ """Keep the same phisical geometry, but with normalized parameters."""
+ geometry = self._geometry
+ if geometry.weight * 0.5 >= geometry.radius:
+ radius = (geometry.weight * 0.5 + geometry.radius) * 0.5
+ geometry = geometry.withRadius(radius)
+ geometry = geometry.withWeight(radius * 2)
+ self._geometry = geometry
+ return True
+ return False
+
+ def handleDragFinished(self, handle, origin, current):
+ modeId = self.getInteractionMode()
+ if handle in [self._handleStart, self._handleMid, self._handleEnd]:
+ if modeId is self.ThreePointMode:
+ self._normalizeGeometry()
+ self._updateHandles()
+
+ if self._geometry.isClosed():
+ if modeId is self.MoveMode:
+ self._handleStart.setSymbol("")
+ self._handleEnd.setSymbol("")
+ else:
+ self._handleStart.setSymbol("x")
+ self._handleEnd.setSymbol("x")
+ else:
+ if modeId is self.ThreePointMode:
+ self._handleStart.setSymbol("s")
+ self._handleEnd.setSymbol("s")
+ elif modeId is self.PolarMode:
+ self._handleStart.setSymbol("o")
+ self._handleEnd.setSymbol("o")
+ if modeId is self.MoveMode:
+ self._handleStart.setSymbol("")
+ self._handleEnd.setSymbol("")
+
+ def _createGeometryFromControlPoints(self, start, mid, end, weight, closed=None):
+ """Returns the geometry of the object"""
+ if closed or (closed is None and numpy.allclose(start, end)):
+ # Special arc: It's a closed circle
+ center = (start + mid) * 0.5
+ radius = numpy.linalg.norm(start - center)
+ v = start - center
+ startAngle = numpy.angle(complex(v[0], v[1]))
+ endAngle = startAngle + numpy.pi * 2.0
+ return _ArcGeometry.createCircle(
+ center, start, end, radius, weight, startAngle, endAngle
+ )
+
+ elif numpy.linalg.norm(numpy.cross(mid - start, end - start)) < 1e-5:
+ # Degenerated arc, it's a rectangle
+ return _ArcGeometry.createRect(start, end, weight)
+ else:
+ center, radius = self._circleEquation(start, mid, end)
+ v = start - center
+ startAngle = numpy.angle(complex(v[0], v[1]))
+ v = mid - center
+ midAngle = numpy.angle(complex(v[0], v[1]))
+ v = end - center
+ endAngle = numpy.angle(complex(v[0], v[1]))
+
+ # Is it clockwise or anticlockwise
+ relativeMid = (endAngle - midAngle + 2 * numpy.pi) % (2 * numpy.pi)
+ relativeEnd = (endAngle - startAngle + 2 * numpy.pi) % (2 * numpy.pi)
+ if relativeMid < relativeEnd:
+ if endAngle < startAngle:
+ endAngle += 2 * numpy.pi
+ else:
+ if endAngle > startAngle:
+ endAngle -= 2 * numpy.pi
+
+ return _ArcGeometry(
+ center, start, end, radius, weight, startAngle, endAngle
+ )
+
+ def _createShapeFromGeometry(self, geometry):
+ kind = geometry.getKind()
+ if kind == "rect":
+ # It is not an arc
+ # but we can display it as an intermediate shape
+ normal = geometry.endPoint - geometry.startPoint
+ normal = numpy.array((normal[1], -normal[0]))
+ distance = numpy.linalg.norm(normal)
+ if distance != 0:
+ normal /= distance
+ points = numpy.array(
+ [
+ geometry.startPoint + normal * geometry.weight * 0.5,
+ geometry.endPoint + normal * geometry.weight * 0.5,
+ geometry.endPoint - normal * geometry.weight * 0.5,
+ geometry.startPoint - normal * geometry.weight * 0.5,
+ ]
+ )
+ elif kind == "point":
+ # It is not an arc
+ # but we can display it as an intermediate shape
+ # NOTE: At least 2 points are expected
+ points = numpy.array([geometry.startPoint, geometry.startPoint])
+ elif kind == "circle":
+ outerRadius = geometry.radius + geometry.weight * 0.5
+ angles = numpy.linspace(0, 2 * numpy.pi, num=50)
+ # It's a circle
+ points = []
+ numpy.append(angles, angles[-1])
+ for angle in angles:
+ direction = numpy.array([numpy.cos(angle), numpy.sin(angle)])
+ points.append(geometry.center + direction * outerRadius)
+ points = numpy.array(points)
+ elif kind == "donut":
+ innerRadius = geometry.radius - geometry.weight * 0.5
+ outerRadius = geometry.radius + geometry.weight * 0.5
+ angles = numpy.linspace(0, 2 * numpy.pi, num=50)
+ # It's a donut
+ points = []
+ # NOTE: NaN value allow to create 2 separated circle shapes
+ # using a single plot item. It's a kind of cheat
+ points.append(numpy.array([float("nan"), float("nan")]))
+ for angle in angles:
+ direction = numpy.array([numpy.cos(angle), numpy.sin(angle)])
+ points.insert(0, geometry.center + direction * innerRadius)
+ points.append(geometry.center + direction * outerRadius)
+ points.append(numpy.array([float("nan"), float("nan")]))
+ points = numpy.array(points)
+ else:
+ innerRadius = geometry.radius - geometry.weight * 0.5
+ outerRadius = geometry.radius + geometry.weight * 0.5
+
+ sign = numpy.sign(geometry.endAngle - geometry.startAngle)
+ delta = min(0.1, abs(geometry.startAngle - geometry.endAngle) / 100) * sign
+
+ if geometry.startAngle == geometry.endAngle:
+ # Degenerated, it's a line (single radius)
+ angle = geometry.startAngle
+ direction = numpy.array([numpy.cos(angle), numpy.sin(angle)])
+ points = []
+ points.append(geometry.center + direction * innerRadius)
+ points.append(geometry.center + direction * outerRadius)
+ return numpy.array(points)
+
+ angles = numpy.arange(geometry.startAngle, geometry.endAngle, delta)
+ if angles[-1] != geometry.endAngle:
+ angles = numpy.append(angles, geometry.endAngle)
+
+ if kind == "camembert":
+ # It's a part of camembert
+ points = []
+ points.append(geometry.center)
+ points.append(geometry.startPoint)
+ for angle in angles:
+ direction = numpy.array([numpy.cos(angle), numpy.sin(angle)])
+ points.append(geometry.center + direction * outerRadius)
+ points.append(geometry.endPoint)
+ points.append(geometry.center)
+ elif kind == "arc":
+ # It's a part of donut
+ points = []
+ points.append(geometry.startPoint)
+ for angle in angles:
+ direction = numpy.array([numpy.cos(angle), numpy.sin(angle)])
+ points.insert(0, geometry.center + direction * innerRadius)
+ points.append(geometry.center + direction * outerRadius)
+ points.insert(0, geometry.endPoint)
+ points.append(geometry.endPoint)
+ else:
+ assert False
+
+ points = numpy.array(points)
+
+ return points
+
+ def _updateShape(self):
+ geometry = self._geometry
+ points = self._createShapeFromGeometry(geometry)
+ self.__shape.setPoints(points)
+
+ index = numpy.nanargmin(points[:, 1])
+ pos = points[index]
+ with utils.blockSignals(self._handleLabel):
+ self._handleLabel.setPosition(pos[0], pos[1])
+
+ if geometry.center is None:
+ movePos = geometry.startPoint * 0.34 + geometry.endPoint * 0.66
+ else:
+ movePos = geometry.center
+
+ with utils.blockSignals(self._handleMove):
+ self._handleMove.setPosition(*movePos)
+
+ self.sigRegionChanged.emit()
+
+ def getGeometry(self):
+ """Returns a tuple containing the geometry of this ROI
+
+ It is a symmetric function of :meth:`setGeometry`.
+
+ If `startAngle` is smaller than `endAngle` the rotation is clockwise,
+ else the rotation is anticlockwise.
+
+ :rtype: Tuple[numpy.ndarray,float,float,float,float]
+ :raise ValueError: In case the ROI can't be represented as section of
+ a circle
+ """
+ geometry = self._geometry
+ if geometry.center is None:
+ raise ValueError("This ROI can't be represented as a section of circle")
+ return (
+ geometry.center,
+ self.getInnerRadius(),
+ self.getOuterRadius(),
+ geometry.startAngle,
+ geometry.endAngle,
+ )
+
+ def getPosition(self, role: Role = Role.CENTER) -> Tuple[float, float]:
+ """Returns a position by it's role.
+
+ By default returns the center of the circle of the arc ROI.
+ """
+ if role == self.Role.START:
+ return self._handleStart.getPosition()
+ if role == self.Role.STOP:
+ return self._handleEnd.getPosition()
+ if role == self.Role.MIDDLE:
+ return self._handleMid.getPosition()
+ if role == self.Role.CENTER:
+ p = self.getCenter()
+ return p[0], p[1]
+ raise ValueError(f"{role} is not supported")
+
+ def isClosed(self):
+ """Returns true if the arc is a closed shape, like a circle or a donut.
+
+ :rtype: bool
+ """
+ return self._geometry.isClosed()
+
+ def getCenter(self):
+ """Returns the center of the circle used to draw arcs of this ROI.
+
+ This center is usually outside the the shape itself.
+
+ :rtype: numpy.ndarray
+ """
+ return self._geometry.center
+
+ def getStartAngle(self):
+ """Returns the angle of the start of the section of this ROI (in radian).
+
+ If `startAngle` is smaller than `endAngle` the rotation is clockwise,
+ else the rotation is anticlockwise.
+
+ :rtype: float
+ """
+ return self._geometry.startAngle
+
+ def getEndAngle(self):
+ """Returns the angle of the end of the section of this ROI (in radian).
+
+ If `startAngle` is smaller than `endAngle` the rotation is clockwise,
+ else the rotation is anticlockwise.
+
+ :rtype: float
+ """
+ return self._geometry.endAngle
+
+ def getInnerRadius(self):
+ """Returns the radius of the smaller arc used to draw this ROI.
+
+ :rtype: float
+ """
+ geometry = self._geometry
+ radius = geometry.radius - geometry.weight * 0.5
+ if radius < 0:
+ radius = 0
+ return radius
+
+ def getOuterRadius(self):
+ """Returns the radius of the bigger arc used to draw this ROI.
+
+ :rtype: float
+ """
+ geometry = self._geometry
+ radius = geometry.radius + geometry.weight * 0.5
+ return radius
+
+ def setGeometry(self, center, innerRadius, outerRadius, startAngle, endAngle):
+ """
+ Set the geometry of this arc.
+
+ :param numpy.ndarray center: Center of the circle.
+ :param float innerRadius: Radius of the smaller arc of the section.
+ :param float outerRadius: Weight of the bigger arc of the section.
+ It have to be bigger than `innerRadius`
+ :param float startAngle: Location of the start of the section (in radian)
+ :param float endAngle: Location of the end of the section (in radian).
+ If `startAngle` is smaller than `endAngle` the rotation is clockwise,
+ else the rotation is anticlockwise.
+ """
+ if innerRadius > outerRadius:
+ logger.error("inner radius larger than outer radius")
+ innerRadius, outerRadius = outerRadius, innerRadius
+ center = numpy.array(center)
+ radius = (innerRadius + outerRadius) * 0.5
+ weight = outerRadius - innerRadius
+
+ vector = numpy.array([numpy.cos(startAngle), numpy.sin(startAngle)])
+ startPoint = center + vector * radius
+ vector = numpy.array([numpy.cos(endAngle), numpy.sin(endAngle)])
+ endPoint = center + vector * radius
+
+ geometry = _ArcGeometry(
+ center,
+ startPoint,
+ endPoint,
+ radius,
+ weight,
+ startAngle,
+ endAngle,
+ closed=None,
+ )
+ self._geometry = geometry
+ self._updateHandles()
+
+ @docstring(HandleBasedROI)
+ def contains(self, position):
+ # first check distance, fastest
+ center = self.getCenter()
+ distance = numpy.sqrt(
+ (position[1] - center[1]) ** 2 + ((position[0] - center[0])) ** 2
+ )
+ is_in_distance = self.getInnerRadius() <= distance <= self.getOuterRadius()
+ if not is_in_distance:
+ return False
+ rel_pos = position[1] - center[1], position[0] - center[0]
+ angle = numpy.arctan2(*rel_pos)
+ # angle is inside [-pi, pi]
+
+ # Normalize the start angle between [-pi, pi]
+ # with a positive angle range
+ start_angle = self.getStartAngle()
+ end_angle = self.getEndAngle()
+ azim_range = end_angle - start_angle
+ if azim_range < 0:
+ start_angle = end_angle
+ azim_range = -azim_range
+ start_angle = numpy.mod(start_angle + numpy.pi, 2 * numpy.pi) - numpy.pi
+
+ if angle < start_angle:
+ angle += 2 * numpy.pi
+ return start_angle <= angle <= start_angle + azim_range
+
+ def translate(self, x, y):
+ self._geometry = self._geometry.translated(x, y)
+ self._updateHandles()
+
+ def _arcCurvatureMarkerConstraint(self, x, y):
+ """Curvature marker remains on perpendicular bisector"""
+ geometry = self._geometry
+ if geometry.center is None:
+ center = (geometry.startPoint + geometry.endPoint) * 0.5
+ vector = geometry.startPoint - geometry.endPoint
+ vector = numpy.array((vector[1], -vector[0]))
+ vdist = numpy.linalg.norm(vector)
+ if vdist != 0:
+ normal = numpy.array((vector[1], -vector[0])) / vdist
+ else:
+ normal = numpy.array((0, 0))
+ else:
+ if geometry.isClosed():
+ midAngle = geometry.startAngle + numpy.pi * 0.5
+ else:
+ midAngle = (geometry.startAngle + geometry.endAngle) * 0.5
+ normal = numpy.array([numpy.cos(midAngle), numpy.sin(midAngle)])
+ center = geometry.center
+ dist = numpy.dot(normal, (numpy.array((x, y)) - center))
+ dist = numpy.clip(dist, geometry.radius, geometry.radius * 2)
+ x, y = center + dist * normal
+ return x, y
+
+ @staticmethod
+ def _circleEquation(pt1, pt2, pt3):
+ """Circle equation from 3 (x, y) points
+
+ :return: Position of the center of the circle and the radius
+ :rtype: Tuple[Tuple[float,float],float]
+ """
+ x, y, z = complex(*pt1), complex(*pt2), complex(*pt3)
+ w = z - x
+ w /= y - x
+ c = (x - y) * (w - abs(w) ** 2) / 2j / w.imag - x
+ return numpy.array((-c.real, -c.imag)), abs(c + x)
+
+ def __str__(self):
+ try:
+ center, innerRadius, outerRadius, startAngle, endAngle = self.getGeometry()
+ params = (
+ center[0],
+ center[1],
+ innerRadius,
+ outerRadius,
+ startAngle,
+ endAngle,
+ )
+ params = "center: %f %f; radius: %f %f; angles: %f %f" % params
+ except ValueError:
+ params = "invalid"
+ return "%s(%s)" % (self.__class__.__name__, params)
diff --git a/src/silx/gui/plot/items/_band_roi.py b/src/silx/gui/plot/items/_band_roi.py
new file mode 100644
index 0000000..0d2ad4e
--- /dev/null
+++ b/src/silx/gui/plot/items/_band_roi.py
@@ -0,0 +1,376 @@
+# /*##########################################################################
+#
+# Copyright (c) 2022 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""Rectangular ROI that can be rotated"""
+
+import functools
+import logging
+from typing import Iterable, List, NamedTuple, Optional, Sequence, Tuple
+import numpy
+
+from ... import qt, utils
+from .. import items
+from ...colors import rgba
+from silx.image.shapes import Polygon
+from ....utils.proxy import docstring
+from ._roi_base import _RegionOfInterestBase
+from ._roi_base import HandleBasedROI
+from ._roi_base import InteractionModeMixIn
+from ._roi_base import RoiInteractionMode
+
+
+logger = logging.getLogger(__name__)
+
+
+class Point(NamedTuple):
+ x: float
+ y: float
+
+
+class BandGeometry(NamedTuple):
+ begin: Point
+ end: Point
+ width: float
+
+ @staticmethod
+ def create(
+ begin: Sequence[float] = (0.0, 0.0),
+ end: Sequence[float] = (0.0, 0.0),
+ width: Optional[float] = None,
+ ):
+ begin = Point(float(begin[0]), float(begin[1]))
+ end = Point(float(end[0]), float(end[1]))
+ if width is None:
+ width = 0.1 * numpy.linalg.norm(numpy.array(end) - begin)
+ return BandGeometry(begin, end, max(0.0, float(width)))
+
+ @property
+ @functools.lru_cache()
+ def normal(self) -> Point:
+ vector = numpy.array(self.end) - self.begin
+ length = numpy.linalg.norm(vector)
+ if length == 0:
+ return Point(0.0, 0.0)
+ return Point(-vector[1] / length, vector[0] / length)
+
+ @property
+ @functools.lru_cache()
+ def center(self) -> Point:
+ return Point(*(0.5 * (numpy.array(self.begin) + self.end)))
+
+ @property
+ @functools.lru_cache()
+ def corners(self) -> Tuple[Point, Point, Point, Point]:
+ """Returns a 4-uple of (x,y) position in float"""
+ offset = 0.5 * self.width * numpy.array(self.normal)
+ return tuple(
+ map(
+ lambda p: Point(*p),
+ (
+ self.begin - offset,
+ self.begin + offset,
+ self.end + offset,
+ self.end - offset,
+ ),
+ )
+ )
+
+ @property
+ @functools.lru_cache()
+ def slope(self) -> float:
+ """Slope of the line (begin, end), infinity for a vertical line"""
+ if self.begin.x == self.end.x:
+ return float("inf")
+ return (self.end.y - self.begin.y) / (self.end.x - self.begin.x)
+
+ @property
+ @functools.lru_cache()
+ def intercept(self) -> float:
+ """Intercept of the line (begin, end) or value of x for vertical line"""
+ if self.begin.x == self.end.x:
+ return self.begin.x
+ return self.begin.y - self.slope * self.begin.x
+
+ @property
+ @functools.lru_cache()
+ def edgesIntercept(self) -> Tuple[float, float]:
+ """Intercepts of lines describing band edges"""
+ offset = 0.5 * self.width * numpy.array(self.normal)
+ if self.begin.x == self.end.x:
+ return self.begin.x - offset[0], self.begin.x + offset[0]
+ return (
+ self.begin.y - offset[1] - self.slope * (self.begin.x - offset[0]),
+ self.begin.y + offset[1] - self.slope * (self.begin.x + offset[0]),
+ )
+
+ def contains(self, position: Sequence[float]) -> bool:
+ return Polygon(self.corners).is_inside(*position)
+
+
+class BandROI(HandleBasedROI, items.LineMixIn, InteractionModeMixIn):
+ """A ROI identifying a line in a 2D plot.
+
+ This ROI provides 1 anchor for each boundary of the line, plus an center
+ in the center to translate the full ROI.
+ """
+
+ ICON = "add-shape-rotated-rectangle"
+ NAME = "band ROI"
+ SHORT_NAME = "band"
+ """Metadata for this kind of ROI"""
+
+ _plotShape = "line"
+ """Plot shape which is used for the first interaction"""
+
+ BoundedMode = RoiInteractionMode("Bounded", "Band is bounded on both sides")
+ """Interaction mode for a rectangular band ROI"""
+
+ UnboundedMode = RoiInteractionMode("Unbounded", "Band is unbounded on both sides")
+ """Interaction mode for unlimited band ROI """
+
+ def __init__(self, parent=None):
+ HandleBasedROI.__init__(self, parent=parent)
+ items.LineMixIn.__init__(self)
+ self.__availableInteractionModes = set((self.BoundedMode, self.UnboundedMode))
+ InteractionModeMixIn.__init__(self)
+
+ self.__handleBegin = self.addHandle()
+ self.__handleEnd = self.addHandle()
+ self.__handleCenter = self.addTranslateHandle()
+ self.__handleLabel = self.addLabelHandle()
+ self.__handleWidthUp = self.addHandle()
+ self.__handleWidthUp._setConstraint(self.__handleWidthUpConstraint)
+ self.__handleWidthUp.setSymbol("d")
+ self.__handleWidthDown = self.addHandle()
+ self.__handleWidthDown._setConstraint(self.__handleWidthDownConstraint)
+ self.__handleWidthDown.setSymbol("d")
+
+ self.__geometry = BandGeometry.create()
+
+ self.__lineUp = items.Line()
+ self.__lineUp.setVisible(False)
+ self.__lineMiddle = items.Line()
+ self.__lineMiddle.setLineWidth(1)
+ self.__lineMiddle.setVisible(False)
+ self.__lineDown = items.Line()
+ self.__lineDown.setVisible(False)
+
+ self.__shape = items.Shape("polygon")
+ self.__shape.setPoints(self.__geometry.corners)
+ self.__shape.setFill(False)
+
+ for item in (self.__lineUp, self.__lineMiddle, self.__lineDown, self.__shape):
+ item.setColor(rgba(self.getColor()))
+ item.setOverlay(True)
+ item.setLineStyle(self.getLineStyle())
+ if item != self.__lineMiddle:
+ item.setLineWidth(self.getLineWidth())
+ self.addItem(item)
+
+ self._initInteractionMode(self.BoundedMode)
+ self._interactiveModeUpdated(self.BoundedMode)
+
+ def availableInteractionModes(self) -> List[RoiInteractionMode]:
+ """Returns the list of available interaction modes"""
+ return list(self.__availableInteractionModes)
+
+ def setAvailableInteractionModes(self, modes: Iterable[RoiInteractionMode]) -> None:
+ """Allows to restrict interaction modes of the ROI.
+
+ :param modes: Subset of BandROI interaction modes:
+ :attr:`BoundedMode` and :attr:`UnboundedMode`.
+ """
+ modes = set(modes)
+ if not modes <= set((self.BoundedMode, self.UnboundedMode)):
+ raise ValueError("Unsupported interaction modes")
+ self.__availableInteractionModes = set(modes)
+ if self.getInteractionMode() not in self.__availableInteractionModes:
+ self.setInteractionMode(self.availableInteractionModes()[0])
+
+ def _interactiveModeUpdated(self, modeId: RoiInteractionMode):
+ """Set the interaction mode."""
+ if modeId is self.BoundedMode:
+ self.__lineDown.setVisible(False)
+ self.__lineMiddle.setVisible(False)
+ self.__lineUp.setVisible(False)
+ self.__shape.setVisible(True)
+ elif modeId is self.UnboundedMode:
+ self.__lineDown.setVisible(True)
+ self.__lineMiddle.setVisible(True)
+ self.__lineUp.setVisible(True)
+ self.__shape.setVisible(False)
+ else:
+ raise RuntimeError("Unsupported interactive mode")
+
+ def _updated(self, event=None, checkVisibility=True):
+ if event == items.ItemChangedType.VISIBLE:
+ if self.isVisible():
+ self._interactiveModeUpdated(self.getInteractionMode())
+ else:
+ self.__lineDown.setVisible(False)
+ self.__lineMiddle.setVisible(False)
+ self.__lineUp.setVisible(False)
+ self.__shape.setVisible(False)
+ super()._updated(event, checkVisibility)
+
+ def _updatedStyle(self, event, style):
+ super()._updatedStyle(event, style)
+ for item in (self.__lineUp, self.__lineMiddle, self.__lineDown, self.__shape):
+ item.setColor(style.getColor())
+ item.setLineStyle(style.getLineStyle())
+ if item != self.__lineMiddle:
+ item.setLineWidth(style.getLineWidth())
+
+ def setFirstShapePoints(self, points):
+ assert len(points) == 2
+ self.setGeometry(*points)
+
+ def _updateText(self, text):
+ self.__handleLabel.setText(text)
+
+ def getGeometry(self) -> BandGeometry:
+ """Returns the geometric description of the ROI"""
+ return self.__geometry
+
+ def setGeometry(
+ self,
+ begin: Sequence[float],
+ end: Sequence[float],
+ width: Optional[float] = None,
+ ):
+ """Set the geometry of the ROI
+
+ :param begin: Starting point as (x, y)
+ :paran end: Closing point as (x, y)
+ :param width: Width of the ROI
+ """
+ geometry = BandGeometry.create(begin, end, width)
+ if self.__geometry == geometry:
+ return
+
+ self.__geometry = geometry
+
+ with utils.blockSignals(self.__handleBegin):
+ self.__handleBegin.setPosition(*geometry.begin)
+ with utils.blockSignals(self.__handleEnd):
+ self.__handleEnd.setPosition(*geometry.end)
+ with utils.blockSignals(self.__handleCenter):
+ self.__handleCenter.setPosition(*geometry.center)
+ with utils.blockSignals(self.__handleLabel):
+ lowerCorner = geometry.corners[numpy.array(geometry.corners)[:, 1].argmin()]
+ self.__handleLabel.setPosition(*lowerCorner)
+
+ delta = 0.5 * geometry.width * numpy.array(geometry.normal)
+ with utils.blockSignals(self.__handleWidthUp):
+ self.__handleWidthUp.setPosition(*(geometry.center + delta))
+ with utils.blockSignals(self.__handleWidthDown):
+ self.__handleWidthDown.setPosition(*(geometry.center - delta))
+
+ self.__lineDown.setSlope(geometry.slope)
+ self.__lineDown.setIntercept(geometry.edgesIntercept[0])
+ self.__lineMiddle.setSlope(geometry.slope)
+ self.__lineMiddle.setIntercept(geometry.intercept)
+ self.__lineUp.setSlope(geometry.slope)
+ self.__lineUp.setIntercept(geometry.edgesIntercept[1])
+ self.__shape.setPoints(geometry.corners)
+ self.sigRegionChanged.emit()
+
+ def __updateGeometry(
+ self,
+ begin: Optional[Sequence[float]] = None,
+ end: Optional[Sequence[float]] = None,
+ width: Optional[float] = None,
+ ):
+ geometry = self.getGeometry()
+ self.setGeometry(
+ geometry.begin if begin is None else begin,
+ geometry.end if end is None else end,
+ geometry.width if width is None else width,
+ )
+
+ @staticmethod
+ def __snap(
+ point: Tuple[float, float], fixed: Tuple[float, float]
+ ) -> Tuple[float, float]:
+ """Snap point so that vector [point, fixed] snap to direction 0, 45 or 90 degrees
+
+ :return: the snapped point position.
+ """
+ vector = point[0] - fixed[0], point[1] - fixed[1]
+ angle = numpy.arctan2(vector[1], vector[0])
+ snapAngle = numpy.pi / 4 * numpy.round(angle / (numpy.pi / 4))
+ length = numpy.linalg.norm(vector)
+ return (
+ fixed[0] + length * numpy.cos(snapAngle),
+ fixed[1] + length * numpy.sin(snapAngle),
+ )
+
+ def handleDragUpdated(self, handle, origin, previous, current):
+ geometry = self.getGeometry()
+ if handle is self.__handleBegin:
+ if qt.QApplication.keyboardModifiers() & qt.Qt.ShiftModifier:
+ self.__updateGeometry(begin=self.__snap(current, geometry.end))
+ return
+ self.__updateGeometry(begin=current)
+ return
+ if handle is self.__handleEnd:
+ if qt.QApplication.keyboardModifiers() & qt.Qt.ShiftModifier:
+ self.__updateGeometry(end=self.__snap(current, geometry.begin))
+ return
+ self.__updateGeometry(end=current)
+ return
+ if handle is self.__handleCenter:
+ delta = current - previous
+ self.__updateGeometry(geometry.begin + delta, geometry.end + delta)
+ return
+ if handle in (self.__handleWidthUp, self.__handleWidthDown):
+ offset = numpy.dot(geometry.normal, current - previous)
+ if handle is self.__handleWidthDown:
+ offset *= -1
+ self.__updateGeometry(
+ geometry.begin,
+ geometry.end,
+ geometry.width + 2 * offset,
+ )
+
+ def __handleWidthUpConstraint(self, x: float, y: float) -> Tuple[float, float]:
+ geometry = self.getGeometry()
+ offset = max(
+ 0, numpy.dot(geometry.normal, numpy.array((x, y)) - geometry.center)
+ )
+ return tuple(geometry.center + offset * numpy.array(geometry.normal))
+
+ def __handleWidthDownConstraint(self, x: float, y: float) -> Tuple[float, float]:
+ geometry = self.getGeometry()
+ offset = max(
+ 0, -numpy.dot(geometry.normal, numpy.array((x, y)) - geometry.center)
+ )
+ return tuple(geometry.center - offset * numpy.array(geometry.normal))
+
+ @docstring(_RegionOfInterestBase)
+ def contains(self, position):
+ return self.getGeometry().contains(position)
+
+ def __str__(self):
+ begin, end, width = self.getGeometry()
+ return f"{self.__class__.__name__}(begin=({begin[0]:g}, {begin[1]:g}), end=({end[0]:g}, {end[1]:g}), width={width:g})"
diff --git a/src/silx/gui/plot/items/_pick.py b/src/silx/gui/plot/items/_pick.py
new file mode 100644
index 0000000..631a30a
--- /dev/null
+++ b/src/silx/gui/plot/items/_pick.py
@@ -0,0 +1,71 @@
+# /*##########################################################################
+#
+# Copyright (c) 2019-2020 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""This module provides classes supporting item picking."""
+
+__authors__ = ["T. Vincent"]
+__license__ = "MIT"
+__date__ = "04/06/2019"
+
+import numpy
+
+
+class PickingResult(object):
+ """Class to access picking information in a :class:`PlotWidget`"""
+
+ def __init__(self, item, indices=None):
+ """Init
+
+ :param item: The picked item
+ :param numpy.ndarray indices: Array-like of indices of picked data.
+ Either 1D or 2D with dim0: data dimension and dim1: indices.
+ No copy is made.
+ """
+ self._item = item
+
+ if indices is None or len(indices) == 0:
+ self._indices = None
+ else:
+ # Indices is set to None if indices array is empty
+ indices = numpy.array(indices, copy=False, dtype=numpy.int64)
+ self._indices = None if indices.size == 0 else indices
+
+ def getItem(self):
+ """Returns the item this results corresponds to."""
+ return self._item
+
+ def getIndices(self, copy=True):
+ """Returns indices of picked data.
+
+ If data is 1D, it returns a numpy.ndarray, otherwise
+ it returns a tuple with as many numpy.ndarray as there are
+ dimensions in the data.
+
+ :param bool copy: True (default) to get a copy,
+ False to return internal arrays
+ :rtype: Union[None,numpy.ndarray,List[numpy.ndarray]]
+ """
+ if self._indices is None:
+ return None
+ indices = numpy.array(self._indices, copy=copy)
+ return indices if indices.ndim == 1 else tuple(indices)
diff --git a/src/silx/gui/plot/items/_roi_base.py b/src/silx/gui/plot/items/_roi_base.py
new file mode 100644
index 0000000..43c5381
--- /dev/null
+++ b/src/silx/gui/plot/items/_roi_base.py
@@ -0,0 +1,874 @@
+# /*##########################################################################
+#
+# Copyright (c) 2018-2023 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""This module provides base components to create ROI item for
+the :class:`~silx.gui.plot.PlotWidget`.
+
+.. inheritance-diagram::
+ silx.gui.plot.items.roi
+ :parts: 1
+"""
+
+__authors__ = ["T. Vincent"]
+__license__ = "MIT"
+__date__ = "28/06/2018"
+
+
+import logging
+import numpy
+import weakref
+import functools
+from typing import Optional
+
+from ....utils.weakref import WeakList
+from ... import qt
+from .. import items
+from ..items import core
+from ...colors import rgba
+
+
+logger = logging.getLogger(__name__)
+
+
+class _RegionOfInterestBase(qt.QObject):
+ """Base class of 1D and 2D region of interest
+
+ :param QObject parent: See QObject
+ :param str name: The name of the ROI
+ """
+
+ sigAboutToBeRemoved = qt.Signal()
+ """Signal emitted just before this ROI is removed from its manager."""
+
+ sigItemChanged = qt.Signal(object)
+ """Signal emitted when item has changed.
+
+ It provides a flag describing which property of the item has changed.
+ See :class:`ItemChangedType` for flags description.
+ """
+
+ def __init__(self, parent=None):
+ qt.QObject.__init__(self)
+ if parent is not None:
+ self.setParent(parent)
+ self.__name = ""
+
+ def getName(self):
+ """Returns the name of the ROI
+
+ :return: name of the region of interest
+ :rtype: str
+ """
+ return self.__name
+
+ def setName(self, name):
+ """Set the name of the ROI
+
+ :param str name: name of the region of interest
+ """
+ name = str(name)
+ if self.__name != name:
+ self.__name = name
+ self._updated(items.ItemChangedType.NAME)
+
+ def _updated(self, event=None, checkVisibility=True):
+ """Implement Item mix-in update method by updating the plot items
+
+ See :class:`~silx.gui.plot.items.Item._updated`
+ """
+ self.sigItemChanged.emit(event)
+
+ def contains(self, position):
+ """Returns True if the `position` is in this ROI.
+
+ :param tuple[float,float] position: position to check
+ :return: True if the value / point is consider to be in the region of
+ interest.
+ :rtype: bool
+ """
+ return False # Override in subclass to perform actual test
+
+
+class RoiInteractionMode(object):
+ """Description of an interaction mode.
+
+ An interaction mode provide a specific kind of interaction for a ROI.
+ A ROI can implement many interaction.
+ """
+
+ def __init__(self, label, description=None):
+ self._label = label
+ self._description = description
+
+ @property
+ def label(self):
+ """Short name"""
+ return self._label
+
+ @property
+ def description(self):
+ """Longer description of the interaction mode"""
+ return self._description
+
+
+class InteractionModeMixIn(object):
+ """Mix in feature which can be implemented by a ROI object.
+
+ This provides user interaction to switch between different
+ interaction mode to edit the ROI.
+
+ This ROI modes have to be described using `RoiInteractionMode`,
+ and taken into account during interation with handles.
+ """
+
+ sigInteractionModeChanged = qt.Signal(object)
+
+ def __init__(self):
+ self.__modeId = None
+
+ def _initInteractionMode(self, modeId):
+ """Set the mode without updating anything.
+
+ Must be one of the returned :meth:`availableInteractionModes`.
+
+ :param RoiInteractionMode modeId: Mode to use
+ """
+ self.__modeId = modeId
+
+ def availableInteractionModes(self):
+ """Returns the list of available interaction modes
+
+ Must be implemented when inherited to provide all available modes.
+
+ :rtype: List[RoiInteractionMode]
+ """
+ raise NotImplementedError()
+
+ def setInteractionMode(self, modeId):
+ """Set the interaction mode.
+
+ :param RoiInteractionMode modeId: Mode to use
+ """
+ self.__modeId = modeId
+ self._interactiveModeUpdated(modeId)
+ self.sigInteractionModeChanged.emit(modeId)
+
+ def _interactiveModeUpdated(self, modeId):
+ """Called directly after an update of the mode.
+
+ The signal `sigInteractionModeChanged` is triggered after this
+ call.
+
+ Must be implemented when inherited to take care of the change.
+ """
+ raise NotImplementedError()
+
+ def getInteractionMode(self):
+ """Returns the interaction mode.
+
+ Must be one of the returned :meth:`availableInteractionModes`.
+
+ :rtype: RoiInteractionMode
+ """
+ return self.__modeId
+
+ def createMenuForInteractionMode(self, parent: qt.QWidget) -> qt.QMenu:
+ """Create a menu providing access to the different interaction modes"""
+ availableModes = self.availableInteractionModes()
+ currentMode = self.getInteractionMode()
+ submenu = qt.QMenu(parent)
+ modeGroup = qt.QActionGroup(parent)
+ modeGroup.setExclusive(True)
+ for mode in availableModes:
+ action = qt.QAction(parent)
+ action.setText(mode.label)
+ action.setToolTip(mode.description)
+ action.setCheckable(True)
+ if mode is currentMode:
+ action.setChecked(True)
+ else:
+ callback = functools.partial(self.setInteractionMode, mode)
+ action.triggered.connect(callback)
+ modeGroup.addAction(action)
+ submenu.addAction(action)
+ submenu.setTitle("Interaction mode")
+ return submenu
+
+
+class RegionOfInterest(_RegionOfInterestBase, core.HighlightedMixIn):
+ """Object describing a region of interest in a plot.
+
+ :param QObject parent:
+ The RegionOfInterestManager that created this object
+ """
+
+ _DEFAULT_LINEWIDTH = 1.0
+ """Default line width of the curve"""
+
+ _DEFAULT_LINESTYLE = "-"
+ """Default line style of the curve"""
+
+ _DEFAULT_HIGHLIGHT_STYLE = items.CurveStyle(linewidth=2)
+ """Default highlight style of the item"""
+
+ ICON, NAME, SHORT_NAME = None, None, None
+ """Metadata to describe the ROI in labels, tooltips and widgets
+
+ Should be set by inherited classes to custom the ROI manager widget.
+ """
+
+ sigRegionChanged = qt.Signal()
+ """Signal emitted everytime the shape or position of the ROI changes"""
+
+ sigEditingStarted = qt.Signal()
+ """Signal emitted when the user start editing the roi"""
+
+ sigEditingFinished = qt.Signal()
+ """Signal emitted when the region edition is finished. During edition
+ sigEditionChanged will be emitted several times and
+ sigRegionEditionFinished only at end"""
+
+ def __init__(self, parent=None):
+ # Avoid circular dependency
+ from ..tools import roi as roi_tools
+
+ assert parent is None or isinstance(parent, roi_tools.RegionOfInterestManager)
+ # Must be done before _RegionOfInterestBase.__init__
+ self._child = WeakList()
+ _RegionOfInterestBase.__init__(self, parent)
+ core.HighlightedMixIn.__init__(self)
+ self.__text = None
+ self._color = rgba("red")
+ self._editable = False
+ self._selectable = False
+ self._focusProxy = None
+ self._visible = True
+
+ def _connectToPlot(self, plot):
+ """Called after connection to a plot"""
+ for item in self.getItems():
+ # This hack is needed to avoid reentrant call from _disconnectFromPlot
+ # to the ROI manager. It also speed up the item tests in _itemRemoved
+ item._roiGroup = True
+ plot.addItem(item)
+
+ def _disconnectFromPlot(self, plot):
+ """Called before disconnection from a plot"""
+ for item in self.getItems():
+ # The item could be already be removed by the plot
+ if item.getPlot() is not None:
+ del item._roiGroup
+ plot.removeItem(item)
+
+ def _setItemName(self, item):
+ """Helper to generate a unique id to a plot item"""
+ legend = "__ROI-%d__%d" % (id(self), id(item))
+ item.setName(legend)
+
+ def setParent(self, parent):
+ """Set the parent of the RegionOfInterest
+
+ :param Union[None,RegionOfInterestManager] parent: The new parent
+ """
+ # Avoid circular dependency
+ from ..tools import roi as roi_tools
+
+ if parent is not None and not isinstance(
+ parent, roi_tools.RegionOfInterestManager
+ ):
+ raise ValueError("Unsupported parent")
+
+ previousParent = self.parent()
+ if previousParent is not None:
+ previousPlot = previousParent.parent()
+ if previousPlot is not None:
+ self._disconnectFromPlot(previousPlot)
+ super(RegionOfInterest, self).setParent(parent)
+ if parent is not None:
+ plot = parent.parent()
+ if plot is not None:
+ self._connectToPlot(plot)
+
+ def addItem(self, item):
+ """Add an item to the set of this ROI children.
+
+ This item will be added and removed to the plot used by the ROI.
+
+ If the ROI is already part of a plot, the item will also be added to
+ the plot.
+
+ It the item do not have a name already, a unique one is generated to
+ avoid item collision in the plot.
+
+ :param silx.gui.plot.items.Item item: A plot item
+ """
+ assert item is not None
+ self._child.append(item)
+ if item.getName() == "":
+ self._setItemName(item)
+ manager = self.parent()
+ if manager is not None:
+ plot = manager.parent()
+ if plot is not None:
+ item._roiGroup = True
+ plot.addItem(item)
+
+ def removeItem(self, item):
+ """Remove an item from this ROI children.
+
+ If the item is part of a plot it will be removed too.
+
+ :param silx.gui.plot.items.Item item: A plot item
+ """
+ assert item is not None
+ self._child.remove(item)
+ plot = item.getPlot()
+ if plot is not None:
+ del item._roiGroup
+ plot.removeItem(item)
+
+ def getItems(self):
+ """Returns the list of PlotWidget items of this RegionOfInterest.
+
+ :rtype: List[~silx.gui.plot.items.Item]
+ """
+ return tuple(self._child)
+
+ @classmethod
+ def _getShortName(cls):
+ """Return an human readable kind of ROI
+
+ :rtype: str
+ """
+ if hasattr(cls, "SHORT_NAME"):
+ name = cls.SHORT_NAME
+ if name is None:
+ name = cls.__name__
+ return name
+
+ def getColor(self):
+ """Returns the color of this ROI
+
+ :rtype: QColor
+ """
+ return qt.QColor.fromRgbF(*self._color)
+
+ def setColor(self, color):
+ """Set the color used for this ROI.
+
+ :param color: The color to use for ROI shape as
+ either a color name, a QColor, a list of uint8 or float in [0, 1].
+ """
+ color = rgba(color)
+ if color != self._color:
+ self._color = color
+ self._updated(items.ItemChangedType.COLOR)
+
+ def isEditable(self):
+ """Returns whether the ROI is editable by the user or not.
+
+ :rtype: bool
+ """
+ return self._editable
+
+ def setEditable(self, editable):
+ """Set whether the ROI can be changed interactively.
+
+ :param bool editable: True to allow edition by the user,
+ False to disable.
+ """
+ editable = bool(editable)
+ if self._editable != editable:
+ self._editable = editable
+ self._updated(items.ItemChangedType.EDITABLE)
+
+ def isSelectable(self):
+ """Returns whether the ROI is selectable by the user or not.
+
+ :rtype: bool
+ """
+ return self._selectable
+
+ def setSelectable(self, selectable):
+ """Set whether the ROI can be selected interactively.
+
+ :param bool selectable: True to allow selection by the user,
+ False to disable.
+ """
+ selectable = bool(selectable)
+ if self._selectable != selectable:
+ self._selectable = selectable
+ self._updated(items.ItemChangedType.SELECTABLE)
+
+ def getFocusProxy(self):
+ """Returns the ROI which have to be selected when this ROI is selected,
+ else None if no proxy specified.
+
+ :rtype: RegionOfInterest
+ """
+ proxy = self._focusProxy
+ if proxy is None:
+ return None
+ proxy = proxy()
+ if proxy is None:
+ self._focusProxy = None
+ return proxy
+
+ def setFocusProxy(self, roi):
+ """Set the real ROI which will be selected when this ROI is selected,
+ else None to remove the proxy already specified.
+
+ :param RegionOfInterest roi: A ROI
+ """
+ if roi is not None:
+ self._focusProxy = weakref.ref(roi)
+ else:
+ self._focusProxy = None
+
+ def isVisible(self):
+ """Returns whether the ROI is visible in the plot.
+
+ .. note::
+ This does not take into account whether or not the plot
+ widget itself is visible (unlike :meth:`QWidget.isVisible` which
+ checks the visibility of all its parent widgets up to the window)
+
+ :rtype: bool
+ """
+ return self._visible
+
+ def setVisible(self, visible):
+ """Set whether the plot items associated with this ROI are
+ visible in the plot.
+
+ :param bool visible: True to show the ROI in the plot, False to
+ hide it.
+ """
+ visible = bool(visible)
+ if self._visible != visible:
+ self._visible = visible
+ self._updated(items.ItemChangedType.VISIBLE)
+
+ def getText(self) -> str:
+ """Returns the currently displayed text for this ROI"""
+ return self.getName() if self.__text is None else self.__text
+
+ def setText(self, text: Optional[str] = None) -> None:
+ """Set the displayed text for this ROI.
+
+ If None (the default), the ROI name is used.
+ """
+ if self.__text != text:
+ self.__text = text
+ self._updated(items.ItemChangedType.TEXT)
+
+ def _updateText(self, text: str) -> None:
+ """Update the text displayed by this ROI
+
+ Override in subclass to custom text display
+ """
+ pass
+
+ @classmethod
+ def showFirstInteractionShape(cls):
+ """Returns True if the shape created by the first interaction and
+ managed by the plot have to be visible.
+
+ :rtype: bool
+ """
+ return False
+
+ @classmethod
+ def getFirstInteractionShape(cls):
+ """Returns the shape kind which will be used by the very first
+ interaction with the plot.
+
+ This interactions are hardcoded inside the plot
+
+ :rtype: str
+ """
+ return cls._plotShape
+
+ def setFirstShapePoints(self, points):
+ """Initialize the ROI using the points from the first interaction.
+
+ This interaction is constrained by the plot API and only supports few
+ shapes.
+ """
+ raise NotImplementedError()
+
+ def creationStarted(self):
+ """Called when the ROI creation interaction was started."""
+ pass
+
+ def creationFinalized(self):
+ """Called when the ROI creation interaction was finalized."""
+ pass
+
+ def _updateItemProperty(self, event, source, destination):
+ """Update the item property of a destination from an item source.
+
+ :param items.ItemChangedType event: Property type to update
+ :param silx.gui.plot.items.Item source: The reference for the data
+ :param event Union[Item,List[Item]] destination: The item(s) to update
+ """
+ if not isinstance(destination, (list, tuple)):
+ destination = [destination]
+ if event == items.ItemChangedType.NAME:
+ value = source.getName()
+ for d in destination:
+ d.setName(value)
+ elif event == items.ItemChangedType.EDITABLE:
+ value = source.isEditable()
+ for d in destination:
+ d.setEditable(value)
+ elif event == items.ItemChangedType.SELECTABLE:
+ value = source.isSelectable()
+ for d in destination:
+ d._setSelectable(value)
+ elif event == items.ItemChangedType.COLOR:
+ value = rgba(source.getColor())
+ for d in destination:
+ d.setColor(value)
+ elif event == items.ItemChangedType.LINE_STYLE:
+ value = self.getLineStyle()
+ for d in destination:
+ d.setLineStyle(value)
+ elif event == items.ItemChangedType.LINE_WIDTH:
+ value = self.getLineWidth()
+ for d in destination:
+ d.setLineWidth(value)
+ elif event == items.ItemChangedType.SYMBOL:
+ value = self.getSymbol()
+ for d in destination:
+ d.setSymbol(value)
+ elif event == items.ItemChangedType.SYMBOL_SIZE:
+ value = self.getSymbolSize()
+ for d in destination:
+ d.setSymbolSize(value)
+ elif event == items.ItemChangedType.VISIBLE:
+ value = self.isVisible()
+ for d in destination:
+ d.setVisible(value)
+ else:
+ assert False
+
+ def _updated(self, event=None, checkVisibility=True):
+ if event == items.ItemChangedType.TEXT:
+ self._updateText(self.getText())
+ elif event == items.ItemChangedType.HIGHLIGHTED:
+ for item in self.getItems():
+ zoffset = 1000 if self.isHighlighted() else 0
+ item.setZValue(item._DEFAULT_Z_LAYER + zoffset)
+
+ style = self.getCurrentStyle()
+ self._updatedStyle(event, style)
+ else:
+ styleEvents = [
+ items.ItemChangedType.COLOR,
+ items.ItemChangedType.LINE_STYLE,
+ items.ItemChangedType.LINE_WIDTH,
+ items.ItemChangedType.SYMBOL,
+ items.ItemChangedType.SYMBOL_SIZE,
+ ]
+ if self.isHighlighted():
+ styleEvents.append(items.ItemChangedType.HIGHLIGHTED_STYLE)
+
+ if event in styleEvents:
+ style = self.getCurrentStyle()
+ self._updatedStyle(event, style)
+
+ super(RegionOfInterest, self)._updated(event, checkVisibility)
+
+ # Displayed text has changed, send a text event
+ if event == items.ItemChangedType.NAME and self.__text is None:
+ self._updated(items.ItemChangedType.TEXT, checkVisibility)
+
+ def _updatedStyle(self, event, style: items.CurveStyle):
+ """Called when the current displayed style of the ROI was changed.
+
+ :param event: The event responsible of the change of the style
+ :param items.CurveStyle style: The current style
+ """
+ pass
+
+ def getCurrentStyle(self) -> items.CurveStyle:
+ """Returns the current curve style.
+
+ Curve style depends on curve highlighting
+
+ :rtype: CurveStyle
+ """
+ baseColor = rgba(self.getColor())
+ if isinstance(self, core.LineMixIn):
+ baseLinestyle = self.getLineStyle()
+ baseLinewidth = self.getLineWidth()
+ else:
+ baseLinestyle = self._DEFAULT_LINESTYLE
+ baseLinewidth = self._DEFAULT_LINEWIDTH
+ if isinstance(self, core.SymbolMixIn):
+ baseSymbol = self.getSymbol()
+ baseSymbolsize = self.getSymbolSize()
+ else:
+ baseSymbol = "o"
+ baseSymbolsize = 1
+
+ if self.isHighlighted():
+ style = self.getHighlightedStyle()
+ color = style.getColor()
+ linestyle = style.getLineStyle()
+ linewidth = style.getLineWidth()
+ symbol = style.getSymbol()
+ symbolsize = style.getSymbolSize()
+
+ return items.CurveStyle(
+ color=baseColor if color is None else color,
+ linestyle=baseLinestyle if linestyle is None else linestyle,
+ linewidth=baseLinewidth if linewidth is None else linewidth,
+ symbol=baseSymbol if symbol is None else symbol,
+ symbolsize=baseSymbolsize if symbolsize is None else symbolsize,
+ )
+ else:
+ return items.CurveStyle(
+ color=baseColor,
+ linestyle=baseLinestyle,
+ linewidth=baseLinewidth,
+ symbol=baseSymbol,
+ symbolsize=baseSymbolsize,
+ )
+
+ def _editingStarted(self):
+ assert self._editable is True
+ self.sigEditingStarted.emit()
+
+ def _editingFinished(self):
+ self.sigEditingFinished.emit()
+
+ def populateContextMenu(self, menu: qt.QMenu):
+ """Populate a menu used as a context menu"""
+ pass
+
+
+class HandleBasedROI(RegionOfInterest):
+ """Manage a ROI based on a set of handles"""
+
+ def __init__(self, parent=None):
+ RegionOfInterest.__init__(self, parent=parent)
+ self._handles = []
+ self._posOrigin = None
+ self._posPrevious = None
+
+ def addUserHandle(self, item=None):
+ """
+ Add a new free handle to the ROI.
+
+ This handle do nothing. It have to be managed by the ROI
+ implementing this class.
+
+ :param Union[None,silx.gui.plot.items.Marker] item: The new marker to
+ add, else None to create a default marker.
+ :rtype: silx.gui.plot.items.Marker
+ """
+ return self.addHandle(item, role="user")
+
+ def addLabelHandle(self, item=None):
+ """
+ Add a new label handle to the ROI.
+
+ This handle is not draggable nor selectable.
+
+ It is displayed without symbol, but it is always visible anyway
+ the ROI is editable, in order to display text.
+
+ :param Union[None,silx.gui.plot.items.Marker] item: The new marker to
+ add, else None to create a default marker.
+ :rtype: silx.gui.plot.items.Marker
+ """
+ return self.addHandle(item, role="label")
+
+ def addTranslateHandle(self, item=None):
+ """
+ Add a new translate handle to the ROI.
+
+ Dragging translate handles affect the position position of the ROI
+ but not the shape itself.
+
+ :param Union[None,silx.gui.plot.items.Marker] item: The new marker to
+ add, else None to create a default marker.
+ :rtype: silx.gui.plot.items.Marker
+ """
+ return self.addHandle(item, role="translate")
+
+ def addHandle(self, item=None, role="default"):
+ """
+ Add a new handle to the ROI.
+
+ Dragging handles while affect the position or the shape of the
+ ROI.
+
+ :param Union[None,silx.gui.plot.items.Marker] item: The new marker to
+ add, else None to create a default marker.
+ :rtype: silx.gui.plot.items.Marker
+ """
+ if item is None:
+ item = items.Marker()
+ color = rgba(self.getColor())
+ color = self._computeHandleColor(color)
+ item.setColor(color)
+ if role == "default":
+ item.setSymbol("s")
+ elif role == "user":
+ pass
+ elif role == "translate":
+ item.setSymbol("+")
+ elif role == "label":
+ item.setSymbol("")
+
+ if role == "user":
+ pass
+ elif role == "label":
+ item._setSelectable(False)
+ item._setDraggable(False)
+ item.setVisible(True)
+ else:
+ self.__updateEditable(item, self.isEditable(), remove=False)
+ item._setSelectable(False)
+
+ self._handles.append((item, role))
+ self.addItem(item)
+ return item
+
+ def removeHandle(self, handle):
+ data = [d for d in self._handles if d[0] is handle][0]
+ self._handles.remove(data)
+ role = data[1]
+ if role not in ["user", "label"]:
+ if self.isEditable():
+ self.__updateEditable(handle, False)
+ self.removeItem(handle)
+
+ def getHandles(self):
+ """Returns the list of handles of this HandleBasedROI.
+
+ :rtype: List[~silx.gui.plot.items.Marker]
+ """
+ return tuple(data[0] for data in self._handles)
+
+ def _updated(self, event=None, checkVisibility=True):
+ """Implement Item mix-in update method by updating the plot items
+
+ See :class:`~silx.gui.plot.items.Item._updated`
+ """
+ if event == items.ItemChangedType.VISIBLE:
+ for item, role in self._handles:
+ visible = self.isVisible()
+ editionVisible = visible and self.isEditable()
+ if role not in ["user", "label"]:
+ item.setVisible(editionVisible)
+ else:
+ item.setVisible(visible)
+ elif event == items.ItemChangedType.EDITABLE:
+ for item, role in self._handles:
+ editable = self.isEditable()
+ if role not in ["user", "label"]:
+ self.__updateEditable(item, editable)
+ super(HandleBasedROI, self)._updated(event, checkVisibility)
+
+ def _updatedStyle(self, event, style):
+ super(HandleBasedROI, self)._updatedStyle(event, style)
+
+ # Update color of shape items in the plot
+ color = rgba(self.getColor())
+ handleColor = self._computeHandleColor(color)
+ for item, role in self._handles:
+ if role == "user":
+ pass
+ elif role == "label":
+ item.setColor(color)
+ else:
+ item.setColor(handleColor)
+
+ def __updateEditable(self, handle, editable, remove=True):
+ # NOTE: visibility change emit a position update event
+ handle.setVisible(editable and self.isVisible())
+ handle._setDraggable(editable)
+ if editable:
+ handle.sigDragStarted.connect(self._handleEditingStarted)
+ handle.sigItemChanged.connect(self._handleEditingUpdated)
+ handle.sigDragFinished.connect(self._handleEditingFinished)
+ else:
+ if remove:
+ handle.sigDragStarted.disconnect(self._handleEditingStarted)
+ handle.sigItemChanged.disconnect(self._handleEditingUpdated)
+ handle.sigDragFinished.disconnect(self._handleEditingFinished)
+
+ def _handleEditingStarted(self):
+ super(HandleBasedROI, self)._editingStarted()
+ handle = self.sender()
+ self._posOrigin = numpy.array(handle.getPosition())
+ self._posPrevious = numpy.array(self._posOrigin)
+ self.handleDragStarted(handle, self._posOrigin)
+
+ def _handleEditingUpdated(self):
+ if self._posOrigin is None:
+ # Avoid to handle events when visibility change
+ return
+ handle = self.sender()
+ current = numpy.array(handle.getPosition())
+ self.handleDragUpdated(handle, self._posOrigin, self._posPrevious, current)
+ self._posPrevious = current
+
+ def _handleEditingFinished(self):
+ handle = self.sender()
+ current = numpy.array(handle.getPosition())
+ self.handleDragFinished(handle, self._posOrigin, current)
+ self._posPrevious = None
+ self._posOrigin = None
+ super(HandleBasedROI, self)._editingFinished()
+
+ def isHandleBeingDragged(self):
+ """Returns True if one of the handles is currently being dragged.
+
+ :rtype: bool
+ """
+ return self._posOrigin is not None
+
+ def handleDragStarted(self, handle, origin):
+ """Called when an handler drag started"""
+ pass
+
+ def handleDragUpdated(self, handle, origin, previous, current):
+ """Called when an handle drag position changed"""
+ pass
+
+ def handleDragFinished(self, handle, origin, current):
+ """Called when an handle drag finished"""
+ pass
+
+ def _computeHandleColor(self, color):
+ """Returns the anchor color from the base ROI color
+
+ :param Union[numpy.array,Tuple,List]: color
+ :rtype: Union[numpy.array,Tuple,List]
+ """
+ return color[:3] + (0.5,)
diff --git a/src/silx/gui/plot/items/axis.py b/src/silx/gui/plot/items/axis.py
new file mode 100644
index 0000000..1ae1ef1
--- /dev/null
+++ b/src/silx/gui/plot/items/axis.py
@@ -0,0 +1,581 @@
+# /*##########################################################################
+#
+# Copyright (c) 2017-2023 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""This module provides the class for axes of the :class:`PlotWidget`.
+"""
+
+from __future__ import annotations
+
+__authors__ = ["V. Valls"]
+__license__ = "MIT"
+__date__ = "22/11/2018"
+
+import datetime as dt
+import enum
+from typing import Optional
+
+import dateutil.tz
+
+from ....utils.proxy import docstring
+from ... import qt
+from .. import _utils
+
+
+class TickMode(enum.Enum):
+ """Determines if ticks are regular number or datetimes."""
+
+ DEFAULT = 0 # Ticks are regular numbers
+ TIME_SERIES = 1 # Ticks are datetime objects
+
+
+class Axis(qt.QObject):
+ """This class describes and controls a plot axis.
+
+ Note: This is an abstract class.
+ """
+
+ # States are half-stored on the backend of the plot, and half-stored on this
+ # object.
+ # TODO It would be good to store all the states of an axis in this object.
+ # i.e. vmin and vmax
+
+ LINEAR = "linear"
+ """Constant defining a linear scale"""
+
+ LOGARITHMIC = "log"
+ """Constant defining a logarithmic scale"""
+
+ _SCALES = set([LINEAR, LOGARITHMIC])
+
+ sigInvertedChanged = qt.Signal(bool)
+ """Signal emitted when axis orientation has changed"""
+
+ sigScaleChanged = qt.Signal(str)
+ """Signal emitted when axis scale has changed"""
+
+ _sigLogarithmicChanged = qt.Signal(bool)
+ """Signal emitted when axis scale has changed to or from logarithmic"""
+
+ sigAutoScaleChanged = qt.Signal(bool)
+ """Signal emitted when axis autoscale has changed"""
+
+ sigLimitsChanged = qt.Signal(float, float)
+ """Signal emitted when axis limits have changed"""
+
+ def __init__(self, plot):
+ """Constructor
+
+ :param silx.gui.plot.PlotWidget.PlotWidget plot: Parent plot of this
+ axis
+ """
+ qt.QObject.__init__(self, parent=plot)
+ self._scale = self.LINEAR
+ self._isAutoScale = True
+ # Store default labels provided to setGraph[X|Y]Label
+ self._defaultLabel = ""
+ # Store currently displayed labels
+ # Current label can differ from input one with active curve handling
+ self._currentLabel = ""
+
+ def _getPlot(self):
+ """Returns the PlotWidget this Axis belongs to.
+
+ :rtype: PlotWidget
+ """
+ plot = self.parent()
+ if plot is None:
+ raise RuntimeError("Axis no longer attached to a PlotWidget")
+ return plot
+
+ def _getBackend(self):
+ """Returns the backend
+
+ :rtype: BackendBase
+ """
+ return self._getPlot()._backend
+
+ def getLimits(self):
+ """Get the limits of this axis.
+
+ :return: Minimum and maximum values of this axis as tuple
+ """
+ return self._internalGetLimits()
+
+ def setLimits(self, vmin, vmax):
+ """Set this axis limits.
+
+ :param float vmin: minimum axis value
+ :param float vmax: maximum axis value
+ """
+ vmin, vmax = self._checkLimits(vmin, vmax)
+ if self.getLimits() == (vmin, vmax):
+ return
+
+ self._internalSetLimits(vmin, vmax)
+ self._getPlot()._setDirtyPlot()
+
+ self._emitLimitsChanged()
+
+ def _emitLimitsChanged(self):
+ """Emit axis sigLimitsChanged and PlotWidget limitsChanged event"""
+ vmin, vmax = self.getLimits()
+ self.sigLimitsChanged.emit(vmin, vmax)
+ self._getPlot()._notifyLimitsChanged(emitSignal=False)
+
+ def _checkLimits(self, vmin, vmax):
+ """Makes sure axis range is not empty and within supported range.
+
+ :param float vmin: Min axis value
+ :param float vmax: Max axis value
+ :return: (min, max) making sure min < max
+ :rtype: 2-tuple of float
+ """
+ return _utils.checkAxisLimits(
+ vmin, vmax, isLog=self._isLogarithmic(), name=self._defaultLabel
+ )
+
+ def _getDataRange(self) -> Optional[tuple[float, float]]:
+ """Returns the range of data items over this axis as (vmin, vmax)"""
+ raise NotImplementedError()
+
+ def isInverted(self):
+ """Return True if the axis is inverted (top to bottom for the y-axis),
+ False otherwise. It is always False for the X axis.
+
+ :rtype: bool
+ """
+ return False
+
+ def setInverted(self, isInverted):
+ """Set the axis orientation.
+
+ This is only available for the Y axis.
+
+ :param bool flag: True for Y axis going from top to bottom,
+ False for Y axis going from bottom to top
+ """
+ if isInverted == self.isInverted():
+ return
+ raise NotImplementedError()
+
+ def isVisible(self) -> bool:
+ """Returns whether the axis is displayed or not"""
+ return True
+
+ def getLabel(self):
+ """Return the current displayed label of this axis.
+
+ :param str axis: The Y axis for which to get the label (left or right)
+ :rtype: str
+ """
+ return self._currentLabel
+
+ def setLabel(self, label):
+ """Set the label displayed on the plot for this axis.
+
+ The provided label can be temporarily replaced by the label of the
+ active curve if any.
+
+ :param str label: The axis label
+ """
+ self._defaultLabel = label
+ self._setCurrentLabel(label)
+ self._getPlot()._setDirtyPlot()
+
+ def _setCurrentLabel(self, label):
+ """Define the label currently displayed.
+
+ If the label is None or empty the default label is used.
+
+ :param str label: Currently displayed label
+ """
+ if label is None or label == "":
+ label = self._defaultLabel
+ if label is None:
+ label = ""
+ self._currentLabel = label
+ self._internalSetCurrentLabel(label)
+
+ def getScale(self):
+ """Return the name of the scale used by this axis.
+
+ :rtype: str
+ """
+ return self._scale
+
+ def setScale(self, scale):
+ """Set the scale to be used by this axis.
+
+ :param str scale: Name of the scale ("log", or "linear")
+ """
+ assert scale in self._SCALES
+ if self._scale == scale:
+ return
+
+ # For the backward compatibility signal
+ emitLog = self._scale == self.LOGARITHMIC or scale == self.LOGARITHMIC
+
+ self._scale = scale
+
+ vmin, vmax = self.getLimits()
+
+ # TODO hackish way of forcing update of curves and images
+ plot = self._getPlot()
+ for item in plot.getItems():
+ item._updated()
+ plot._invalidateDataRange()
+
+ if scale == self.LOGARITHMIC:
+ self._internalSetLogarithmic(True)
+ if vmin <= 0:
+ dataRange = self._getDataRange()
+ if dataRange is None:
+ self.setLimits(1.0, 100.0)
+ else:
+ if vmax > 0 and dataRange[0] < vmax:
+ self.setLimits(dataRange[0], vmax)
+ else:
+ self.setLimits(*dataRange)
+ elif scale == self.LINEAR:
+ self._internalSetLogarithmic(False)
+ else:
+ raise ValueError("Scale %s unsupported" % scale)
+
+ self.sigScaleChanged.emit(self._scale)
+ if emitLog:
+ self._sigLogarithmicChanged.emit(self._scale == self.LOGARITHMIC)
+
+ def _isLogarithmic(self):
+ """Return True if this axis scale is logarithmic, False if linear.
+
+ :rtype: bool
+ """
+ return self._scale == self.LOGARITHMIC
+
+ def _setLogarithmic(self, flag):
+ """Set the scale of this axes (either linear or logarithmic).
+
+ :param bool flag: True to use a logarithmic scale, False for linear.
+ """
+ flag = bool(flag)
+ self.setScale(self.LOGARITHMIC if flag else self.LINEAR)
+
+ def getTimeZone(self):
+ """Sets tzinfo that is used if this axis plots date times.
+
+ None means the datetimes are interpreted as local time.
+
+ :rtype: datetime.tzinfo of None.
+ """
+ raise NotImplementedError()
+
+ def setTimeZone(self, tz):
+ """Sets tzinfo that is used if this axis' tickMode is TIME_SERIES
+
+ The tz must be a descendant of the datetime.tzinfo class, "UTC" or None.
+ Use None to let the datetimes be interpreted as local time.
+ Use the string "UTC" to let the date datetimes be in UTC time.
+
+ :param tz: datetime.tzinfo, "UTC" or None.
+ """
+ raise NotImplementedError()
+
+ def getTickMode(self):
+ """Determines if axis ticks are number or datetimes.
+
+ :rtype: TickMode enum.
+ """
+ raise NotImplementedError()
+
+ def setTickMode(self, tickMode):
+ """Determines if axis ticks are number or datetimes.
+
+ :param TickMode tickMode: tick mode enum.
+ """
+ raise NotImplementedError()
+
+ def isAutoScale(self):
+ """Return True if axis is automatically adjusting its limits.
+
+ :rtype: bool
+ """
+ return self._isAutoScale
+
+ def setAutoScale(self, flag=True):
+ """Set the axis limits adjusting behavior of :meth:`resetZoom`.
+
+ :param bool flag: True to resize limits automatically,
+ False to disable it.
+ """
+ self._isAutoScale = bool(flag)
+ self.sigAutoScaleChanged.emit(self._isAutoScale)
+
+ def _setLimitsConstraints(self, minPos=None, maxPos=None):
+ raise NotImplementedError()
+
+ def setLimitsConstraints(self, minPos=None, maxPos=None):
+ """
+ Set a constraint on the position of the axes.
+
+ :param float minPos: Minimum allowed axis value.
+ :param float maxPos: Maximum allowed axis value.
+ :return: True if the constaints was updated
+ :rtype: bool
+ """
+ updated = self._setLimitsConstraints(minPos, maxPos)
+ if updated:
+ plot = self._getPlot()
+ xMin, xMax = plot.getXAxis().getLimits()
+ yMin, yMax = plot.getYAxis().getLimits()
+ y2Min, y2Max = plot.getYAxis("right").getLimits()
+ plot.setLimits(xMin, xMax, yMin, yMax, y2Min, y2Max)
+ return updated
+
+ def _setRangeConstraints(self, minRange=None, maxRange=None):
+ raise NotImplementedError()
+
+ def setRangeConstraints(self, minRange=None, maxRange=None):
+ """
+ Set a constraint on the position of the axes.
+
+ :param float minRange: Minimum allowed left-to-right span across the
+ view
+ :param float maxRange: Maximum allowed left-to-right span across the
+ view
+ :return: True if the constaints was updated
+ :rtype: bool
+ """
+ updated = self._setRangeConstraints(minRange, maxRange)
+ if updated:
+ plot = self._getPlot()
+ xMin, xMax = plot.getXAxis().getLimits()
+ yMin, yMax = plot.getYAxis().getLimits()
+ y2Min, y2Max = plot.getYAxis("right").getLimits()
+ plot.setLimits(xMin, xMax, yMin, yMax, y2Min, y2Max)
+ return updated
+
+
+class XAxis(Axis):
+ """Axis class defining primitives for the X axis"""
+
+ # TODO With some changes on the backend, it will be able to remove all this
+ # specialised implementations (prefixel by '_internal')
+
+ def getTimeZone(self):
+ return self._getBackend().getXAxisTimeZone()
+
+ def setTimeZone(self, tz):
+ if isinstance(tz, str) and tz.upper() == "UTC":
+ tz = dateutil.tz.tzutc()
+ elif not (tz is None or isinstance(tz, dt.tzinfo)):
+ raise TypeError("tz must be a dt.tzinfo object, None or 'UTC'.")
+
+ self._getBackend().setXAxisTimeZone(tz)
+ self._getPlot()._setDirtyPlot()
+
+ def getTickMode(self):
+ if self._getBackend().isXAxisTimeSeries():
+ return TickMode.TIME_SERIES
+ else:
+ return TickMode.DEFAULT
+
+ def setTickMode(self, tickMode):
+ if tickMode == TickMode.DEFAULT:
+ self._getBackend().setXAxisTimeSeries(False)
+ elif tickMode == TickMode.TIME_SERIES:
+ self._getBackend().setXAxisTimeSeries(True)
+ else:
+ raise ValueError("Unexpected TickMode: {}".format(tickMode))
+
+ def _internalSetCurrentLabel(self, label):
+ self._getBackend().setGraphXLabel(label)
+
+ def _internalGetLimits(self):
+ return self._getBackend().getGraphXLimits()
+
+ def _internalSetLimits(self, xmin, xmax):
+ self._getBackend().setGraphXLimits(xmin, xmax)
+
+ def _internalSetLogarithmic(self, flag):
+ self._getBackend().setXAxisLogarithmic(flag)
+
+ def _setLimitsConstraints(self, minPos=None, maxPos=None):
+ constrains = self._getPlot()._getViewConstraints()
+ updated = constrains.update(xMin=minPos, xMax=maxPos)
+ return updated
+
+ def _setRangeConstraints(self, minRange=None, maxRange=None):
+ constrains = self._getPlot()._getViewConstraints()
+ updated = constrains.update(minXRange=minRange, maxXRange=maxRange)
+ return updated
+
+ @docstring(Axis)
+ def _getDataRange(self) -> Optional[tuple[float, float]]:
+ ranges = self._getPlot().getDataRange()
+ return ranges.x
+
+
+class YAxis(Axis):
+ """Axis class defining primitives for the Y axis"""
+
+ # TODO With some changes on the backend, it will be able to remove all this
+ # specialised implementations (prefixel by '_internal')
+
+ def _internalSetCurrentLabel(self, label):
+ self._getBackend().setGraphYLabel(label, axis="left")
+
+ def _internalGetLimits(self):
+ return self._getBackend().getGraphYLimits(axis="left")
+
+ def _internalSetLimits(self, ymin, ymax):
+ self._getBackend().setGraphYLimits(ymin, ymax, axis="left")
+
+ def _internalSetLogarithmic(self, flag):
+ self._getBackend().setYAxisLogarithmic(flag)
+
+ def setInverted(self, flag=True):
+ """Set the axis orientation.
+
+ This is only available for the Y axis.
+
+ :param bool flag: True for Y axis going from top to bottom,
+ False for Y axis going from bottom to top
+ """
+ flag = bool(flag)
+ if self.isInverted() == flag:
+ return
+ self._getBackend().setYAxisInverted(flag)
+ self._getPlot()._setDirtyPlot()
+ self.sigInvertedChanged.emit(flag)
+
+ def isInverted(self):
+ """Return True if the axis is inverted (top to bottom for the y-axis),
+ False otherwise. It is always False for the X axis.
+
+ :rtype: bool
+ """
+ return self._getBackend().isYAxisInverted()
+
+ def _setLimitsConstraints(self, minPos=None, maxPos=None):
+ constrains = self._getPlot()._getViewConstraints()
+ updated = constrains.update(yMin=minPos, yMax=maxPos)
+ return updated
+
+ def _setRangeConstraints(self, minRange=None, maxRange=None):
+ constrains = self._getPlot()._getViewConstraints()
+ updated = constrains.update(minYRange=minRange, maxYRange=maxRange)
+ return updated
+
+ @docstring(Axis)
+ def _getDataRange(self) -> Optional[tuple[float, float]]:
+ ranges = self._getPlot().getDataRange()
+ return ranges.y
+
+
+class YRightAxis(Axis):
+ """Proxy axis for the secondary Y axes. It manages it own label and limit
+ but share the some state like scale and direction with the main axis."""
+
+ # TODO With some changes on the backend, it will be able to remove all this
+ # specialised implementations (prefixel by '_internal')
+
+ def __init__(self, plot, mainAxis):
+ """Constructor
+
+ :param silx.gui.plot.PlotWidget.PlotWidget plot: Parent plot of this
+ axis
+ :param Axis mainAxis: Axis which sharing state with this axis
+ """
+ Axis.__init__(self, plot)
+ self.__mainAxis = mainAxis
+ self.__mainAxis.sigInvertedChanged.connect(self.sigInvertedChanged.emit)
+ self.__mainAxis.sigScaleChanged.connect(self.sigScaleChanged.emit)
+ self.__mainAxis._sigLogarithmicChanged.connect(self._sigLogarithmicChanged.emit)
+ self.__mainAxis.sigAutoScaleChanged.connect(self.sigAutoScaleChanged.emit)
+
+ def _internalSetCurrentLabel(self, label):
+ self._getBackend().setGraphYLabel(label, axis="right")
+
+ def _internalGetLimits(self):
+ return self._getBackend().getGraphYLimits(axis="right")
+
+ def _internalSetLimits(self, ymin, ymax):
+ self._getBackend().setGraphYLimits(ymin, ymax, axis="right")
+
+ def setInverted(self, flag=True):
+ """Set the Y axis orientation.
+
+ :param bool flag: True for Y axis going from top to bottom,
+ False for Y axis going from bottom to top
+ """
+ return self.__mainAxis.setInverted(flag)
+
+ def isInverted(self):
+ """Return True if Y axis goes from top to bottom, False otherwise."""
+ return self.__mainAxis.isInverted()
+
+ def isVisible(self) -> bool:
+ """Returns whether the axis is displayed or not"""
+ return self._getBackend().isYRightAxisVisible()
+
+ def getScale(self):
+ """Return the name of the scale used by this axis.
+
+ :rtype: str
+ """
+ return self.__mainAxis.getScale()
+
+ def setScale(self, scale):
+ """Set the scale to be used by this axis.
+
+ :param str scale: Name of the scale ("log", or "linear")
+ """
+ self.__mainAxis.setScale(scale)
+
+ def _isLogarithmic(self):
+ """Return True if Y axis scale is logarithmic, False if linear."""
+ return self.__mainAxis._isLogarithmic()
+
+ def _setLogarithmic(self, flag):
+ """Set the Y axes scale (either linear or logarithmic).
+
+ :param bool flag: True to use a logarithmic scale, False for linear.
+ """
+ return self.__mainAxis._setLogarithmic(flag)
+
+ def isAutoScale(self):
+ """Return True if Y axes are automatically adjusting its limits."""
+ return self.__mainAxis.isAutoScale()
+
+ def setAutoScale(self, flag=True):
+ """Set the Y axis limits adjusting behavior of :meth:`PlotWidget.resetZoom`.
+
+ :param bool flag: True to resize limits automatically,
+ False to disable it.
+ """
+ return self.__mainAxis.setAutoScale(flag)
+
+ @docstring(Axis)
+ def _getDataRange(self) -> Optional[tuple[float, float]]:
+ ranges = self._getPlot().getDataRange()
+ return ranges.y2
diff --git a/src/silx/gui/plot/items/complex.py b/src/silx/gui/plot/items/complex.py
new file mode 100644
index 0000000..d10767f
--- /dev/null
+++ b/src/silx/gui/plot/items/complex.py
@@ -0,0 +1,382 @@
+# /*##########################################################################
+#
+# Copyright (c) 2017-2023 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""This module provides the :class:`ImageComplexData` of the :class:`Plot`.
+"""
+
+__authors__ = ["Vincent Favre-Nicolin", "T. Vincent"]
+__license__ = "MIT"
+__date__ = "14/06/2018"
+
+
+import logging
+
+import numpy
+
+from ....utils.proxy import docstring
+from ...colors import Colormap
+from .core import ColormapMixIn, ComplexMixIn, ItemChangedType
+from .image import ImageBase
+
+
+_logger = logging.getLogger(__name__)
+
+
+# Complex colormap functions
+
+
+def _phase2rgb(colormap, data):
+ """Creates RGBA image with colour-coded phase.
+
+ :param Colormap colormap: The colormap to use
+ :param numpy.ndarray data: The data to convert
+ :return: Array of RGBA colors
+ :rtype: numpy.ndarray
+ """
+ if data.size == 0:
+ return numpy.zeros((0, 0, 4), dtype=numpy.uint8)
+
+ phase = numpy.angle(data)
+ return colormap.applyToData(phase)
+
+
+def _complex2rgbalog(phaseColormap, data, amin=0.0, dlogs=2, smax=None):
+ """Returns RGBA colors: colour-coded phases and log10(amplitude) in alpha.
+
+ :param Colormap phaseColormap: Colormap to use for the phase
+ :param numpy.ndarray data: the complex data array to convert to RGBA
+ :param float amin: the minimum value for the alpha channel
+ :param float dlogs: amplitude range displayed, in log10 units
+ :param float smax:
+ if specified, all values above max will be displayed with an alpha=1
+ """
+ if data.size == 0:
+ return numpy.zeros((0, 0, 4), dtype=numpy.uint8)
+
+ rgba = _phase2rgb(phaseColormap, data)
+ sabs = numpy.absolute(data)
+ if smax is not None:
+ sabs[sabs > smax] = smax
+ a = numpy.log10(sabs + 1e-20)
+ a -= a.max() - dlogs # display dlogs orders of magnitude
+ rgba[..., 3] = 255 * (amin + a / dlogs * (1 - amin) * (a > 0))
+ return rgba
+
+
+def _complex2rgbalin(phaseColormap, data, gamma=1.0, smax=None):
+ """Returns RGBA colors: colour-coded phase and linear amplitude in alpha.
+
+ :param Colormap phaseColormap: Colormap to use for the phase
+ :param numpy.ndarray data:
+ :param float gamma: Optional exponent gamma applied to the amplitude
+ :param float smax:
+ """
+ if data.size == 0:
+ return numpy.zeros((0, 0, 4), dtype=numpy.uint8)
+
+ rgba = _phase2rgb(phaseColormap, data)
+ a = numpy.absolute(data)
+ if smax is not None:
+ a[a > smax] = smax
+ a /= a.max()
+ rgba[..., 3] = 255 * a**gamma
+ return rgba
+
+
+class ImageComplexData(ImageBase, ColormapMixIn, ComplexMixIn):
+ """Specific plot item to force colormap when using complex colormap.
+
+ This is returning the specific colormap when displaying
+ colored phase + amplitude.
+ """
+
+ _SUPPORTED_COMPLEX_MODES = (
+ ComplexMixIn.ComplexMode.ABSOLUTE,
+ ComplexMixIn.ComplexMode.PHASE,
+ ComplexMixIn.ComplexMode.REAL,
+ ComplexMixIn.ComplexMode.IMAGINARY,
+ ComplexMixIn.ComplexMode.AMPLITUDE_PHASE,
+ ComplexMixIn.ComplexMode.LOG10_AMPLITUDE_PHASE,
+ ComplexMixIn.ComplexMode.SQUARE_AMPLITUDE,
+ )
+ """Overrides supported ComplexMode"""
+
+ def __init__(self):
+ ImageBase.__init__(self, numpy.zeros((0, 0), dtype=numpy.complex64))
+ ColormapMixIn.__init__(self)
+ ComplexMixIn.__init__(self)
+ self._dataByModesCache = {}
+ self._amplitudeRangeInfo = None, 2
+
+ # Use default from ColormapMixIn
+ colormap = super(ImageComplexData, self).getColormap()
+
+ phaseColormap = Colormap(name="hsv", vmin=-numpy.pi, vmax=numpy.pi)
+
+ self._colormaps = { # Default colormaps for all modes
+ self.ComplexMode.ABSOLUTE: colormap,
+ self.ComplexMode.PHASE: phaseColormap,
+ self.ComplexMode.REAL: colormap,
+ self.ComplexMode.IMAGINARY: colormap,
+ self.ComplexMode.AMPLITUDE_PHASE: phaseColormap,
+ self.ComplexMode.LOG10_AMPLITUDE_PHASE: phaseColormap,
+ self.ComplexMode.SQUARE_AMPLITUDE: colormap,
+ }
+
+ def _addBackendRenderer(self, backend):
+ """Update backend renderer"""
+ plot = self.getPlot()
+ assert plot is not None
+ if not self._isPlotLinear(plot):
+ # Do not render with non linear scales
+ return None
+
+ mode = self.getComplexMode()
+ if mode in (
+ self.ComplexMode.AMPLITUDE_PHASE,
+ self.ComplexMode.LOG10_AMPLITUDE_PHASE,
+ ):
+ # For those modes, compute RGBA image here
+ colormap = None
+ data = self.getRgbaImageData(copy=False)
+ else:
+ colormap = self.getColormap()
+ if colormap.isAutoscale():
+ # Avoid backend to compute autoscale: use item cache
+ colormap = colormap.copy()
+ colormap.setVRange(*colormap.getColormapRange(self))
+
+ data = self.getData(copy=False)
+
+ if data.size == 0:
+ return None # No data to display
+
+ return backend.addImage(
+ data,
+ origin=self.getOrigin(),
+ scale=self.getScale(),
+ colormap=colormap,
+ alpha=self.getAlpha(),
+ )
+
+ @docstring(ComplexMixIn)
+ def setComplexMode(self, mode):
+ changed = super(ImageComplexData, self).setComplexMode(mode)
+ if changed:
+ self._valueDataChanged()
+
+ # Backward compatibility
+ self._updated(ItemChangedType.VISUALIZATION_MODE)
+
+ # Update ColormapMixIn colormap
+ colormap = self._colormaps[self.getComplexMode()]
+ if colormap is not super(ImageComplexData, self).getColormap():
+ super(ImageComplexData, self).setColormap(colormap)
+
+ # Send data updated as value returned by getData has changed
+ self._updated(ItemChangedType.DATA)
+ return changed
+
+ def _setAmplitudeRangeInfo(self, max_=None, delta=2):
+ """Set the amplitude range to display for 'log10_amplitude_phase' mode.
+
+ :param max_: Max of the amplitude range.
+ If None it autoscales to data max.
+ :param float delta: Delta range in log10 to display
+ """
+ self._amplitudeRangeInfo = max_, float(delta)
+ self._updated(ItemChangedType.VISUALIZATION_MODE)
+
+ def _getAmplitudeRangeInfo(self):
+ """Returns the amplitude range to use for 'log10_amplitude_phase' mode.
+
+ :return: (max, delta), if max is None, then it autoscales to data max
+ :rtype: 2-tuple"""
+ return self._amplitudeRangeInfo
+
+ def setColormap(self, colormap, mode=None):
+ """Set the colormap for this specific mode.
+
+ :param ~silx.gui.colors.Colormap colormap: The colormap
+ :param Union[ComplexMode,str] mode:
+ If specified, set the colormap of this specific mode.
+ Default: current mode.
+ """
+ if mode is None:
+ mode = self.getComplexMode()
+ else:
+ mode = self.ComplexMode.from_value(mode)
+
+ self._colormaps[mode] = colormap
+ if mode is self.getComplexMode():
+ super(ImageComplexData, self).setColormap(colormap)
+ else:
+ self._updated(ItemChangedType.COLORMAP)
+
+ def getColormap(self, mode=None):
+ """Get the colormap for the (current) mode.
+
+ :param Union[ComplexMode,str] mode:
+ If specified, get the colormap of this specific mode.
+ Default: current mode.
+ :rtype: ~silx.gui.colors.Colormap
+ """
+ if mode is None:
+ mode = self.getComplexMode()
+ else:
+ mode = self.ComplexMode.from_value(mode)
+
+ return self._colormaps[mode]
+
+ def setData(self, data, copy=True):
+ """Set the image complex data
+
+ :param numpy.ndarray data: 2D array of complex with 2 dimensions (h, w)
+ :param bool copy: True (Default) to get a copy,
+ False to use internal representation (do not modify!)
+ """
+ data = numpy.array(data, copy=copy)
+ assert data.ndim == 2
+ if not numpy.issubdtype(data.dtype, numpy.complexfloating):
+ _logger.warning(
+ "Image is not complex, converting it to complex to plot it."
+ )
+ data = numpy.array(data, dtype=numpy.complex64)
+
+ # Compute current mode data and set colormap data
+ mode = self.getComplexMode()
+ dataForMode = self.__convertComplexData(data, self.getComplexMode())
+ self._dataByModesCache = {mode: dataForMode}
+
+ super().setData(data)
+
+ def _updated(self, event=None, checkVisibility=True):
+ # Synchronizes colormapped data if changed
+ # ItemChangedType.COMPLEX_MODE triggers ItemChangedType.DATA
+ # No need to handle it twice.
+ if event in (ItemChangedType.DATA, ItemChangedType.MASK):
+ # Color-mapped data is NOT the `getValueData` for some modes
+ if self.getComplexMode() in (
+ self.ComplexMode.AMPLITUDE_PHASE,
+ self.ComplexMode.LOG10_AMPLITUDE_PHASE,
+ ):
+ data = self.getData(copy=False, mode=self.ComplexMode.PHASE)
+ mask = self.getMaskData(copy=False)
+ if mask is not None:
+ data = numpy.copy(data)
+ data[mask != 0] = numpy.nan
+ else:
+ data = self.getValueData(copy=False)
+ self._setColormappedData(data, copy=False)
+ super()._updated(event=event, checkVisibility=checkVisibility)
+
+ def getComplexData(self, copy=True):
+ """Returns the image complex data
+
+ :param bool copy: True (Default) to get a copy,
+ False to use internal representation (do not modify!)
+ :rtype: numpy.ndarray of complex
+ """
+ return super().getData(copy=copy)
+
+ def __convertComplexData(self, data, mode):
+ """Convert complex data to given mode.
+
+ :param numpy.ndarray data:
+ :param Union[ComplexMode,str] mode:
+ :rtype: numpy.ndarray of float
+ """
+ if mode is self.ComplexMode.PHASE:
+ return numpy.angle(data)
+ elif mode is self.ComplexMode.REAL:
+ return numpy.real(data)
+ elif mode is self.ComplexMode.IMAGINARY:
+ return numpy.imag(data)
+ elif mode in (
+ self.ComplexMode.ABSOLUTE,
+ self.ComplexMode.LOG10_AMPLITUDE_PHASE,
+ self.ComplexMode.AMPLITUDE_PHASE,
+ ):
+ return numpy.absolute(data)
+ elif mode is self.ComplexMode.SQUARE_AMPLITUDE:
+ return numpy.absolute(data) ** 2
+ else:
+ _logger.error(
+ "Unsupported conversion mode: %s, fallback to absolute", str(mode)
+ )
+ return numpy.absolute(data)
+
+ def getData(self, copy=True, mode=None):
+ """Returns the image data corresponding to (current) mode.
+
+ The returned data is always floats, to get the complex data, use
+ :meth:`getComplexData`.
+
+ :param bool copy: True (Default) to get a copy,
+ False to use internal representation (do not modify!)
+ :param Union[ComplexMode,str] mode:
+ If specified, get data corresponding to the mode.
+ Default: Current mode.
+ :rtype: numpy.ndarray of float
+ """
+ if mode is None:
+ mode = self.getComplexMode()
+ else:
+ mode = self.ComplexMode.from_value(mode)
+
+ if mode not in self._dataByModesCache:
+ self._dataByModesCache[mode] = self.__convertComplexData(
+ self.getComplexData(copy=False), mode
+ )
+
+ return numpy.array(self._dataByModesCache[mode], copy=copy)
+
+ def getRgbaImageData(self, copy=True, mode=None):
+ """Get the displayed RGB(A) image for (current) mode
+
+ :param bool copy: Ignored for this class
+ :param Union[ComplexMode,str] mode:
+ If specified, get data corresponding to the mode.
+ Default: Current mode.
+ :rtype: numpy.ndarray of uint8 of shape (height, width, 4)
+ """
+ if mode is None:
+ mode = self.getComplexMode()
+ else:
+ mode = self.ComplexMode.from_value(mode)
+
+ colormap = self.getColormap(mode=mode)
+ if mode is self.ComplexMode.AMPLITUDE_PHASE:
+ data = self.getComplexData(copy=False)
+ return _complex2rgbalin(colormap, data)
+ elif mode is self.ComplexMode.LOG10_AMPLITUDE_PHASE:
+ data = self.getComplexData(copy=False)
+ max_, delta = self._getAmplitudeRangeInfo()
+ return _complex2rgbalog(colormap, data, dlogs=delta, smax=max_)
+ else:
+ data = self.getData(copy=False, mode=mode)
+ return colormap.applyToData(data)
+
+ # Backward compatibility
+
+ Mode = ComplexMixIn.ComplexMode
diff --git a/src/silx/gui/plot/items/core.py b/src/silx/gui/plot/items/core.py
new file mode 100644
index 0000000..7d754a7
--- /dev/null
+++ b/src/silx/gui/plot/items/core.py
@@ -0,0 +1,1832 @@
+# /*##########################################################################
+#
+# Copyright (c) 2017-2023 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""This module provides the base class for items of the :class:`Plot`.
+"""
+from __future__ import annotations
+
+
+__authors__ = ["T. Vincent"]
+__license__ = "MIT"
+__date__ = "08/12/2020"
+
+from collections import abc
+from copy import deepcopy
+import logging
+import enum
+from typing import Optional, Tuple, Union
+import weakref
+
+import numpy
+
+from ....utils.proxy import docstring
+from ....utils.enum import Enum as _Enum
+from ....math.combo import min_max
+from ... import qt
+from ... import colors
+from ...colors import Colormap, _Colormappable
+from ._pick import PickingResult
+
+from silx import config
+
+_logger = logging.getLogger(__name__)
+
+
+@enum.unique
+class ItemChangedType(enum.Enum):
+ """Type of modification provided by :attr:`Item.sigItemChanged` signal."""
+
+ # Private setters and setInfo are not emitting sigItemChanged signal.
+ # Signals to consider:
+ # COLORMAP_SET emitted when setColormap is called but not forward colormap object signal
+ # CURRENT_COLOR_CHANGED emitted current color changed because highlight changed,
+ # highlighted color changed or color changed depending on hightlight state.
+
+ VISIBLE = "visibleChanged"
+ """Item's visibility changed flag."""
+
+ ZVALUE = "zValueChanged"
+ """Item's Z value changed flag."""
+
+ COLORMAP = (
+ "colormapChanged" # Emitted when set + forward events from the colormap object
+ )
+ """Item's colormap changed flag.
+
+ This is emitted both when setting a new colormap and
+ when the current colormap object is updated.
+ """
+
+ SYMBOL = "symbolChanged"
+ """Item's symbol changed flag."""
+
+ SYMBOL_SIZE = "symbolSizeChanged"
+ """Item's symbol size changed flag."""
+
+ LINE_WIDTH = "lineWidthChanged"
+ """Item's line width changed flag."""
+
+ LINE_STYLE = "lineStyleChanged"
+ """Item's line style changed flag."""
+
+ COLOR = "colorChanged"
+ """Item's color changed flag."""
+
+ LINE_BG_COLOR = "lineBgColorChanged" # Deprecated, use LINE_GAP_COLOR
+
+ LINE_GAP_COLOR = "lineGapColorChanged"
+ """Item's dashed line gap color changed flag."""
+
+ YAXIS = "yAxisChanged"
+ """Item's Y axis binding changed flag."""
+
+ FILL = "fillChanged"
+ """Item's fill changed flag."""
+
+ ALPHA = "alphaChanged"
+ """Item's transparency alpha changed flag."""
+
+ DATA = "dataChanged"
+ """Item's data changed flag"""
+
+ MASK = "maskChanged"
+ """Item's mask changed flag"""
+
+ HIGHLIGHTED = "highlightedChanged"
+ """Item's highlight state changed flag."""
+
+ HIGHLIGHTED_COLOR = "highlightedColorChanged"
+ """Deprecated, use HIGHLIGHTED_STYLE instead."""
+
+ HIGHLIGHTED_STYLE = "highlightedStyleChanged"
+ """Item's highlighted style changed flag."""
+
+ SCALE = "scaleChanged"
+ """Item's scale changed flag."""
+
+ TEXT = "textChanged"
+ """Item's text changed flag."""
+
+ POSITION = "positionChanged"
+ """Item's position changed flag.
+
+ This is emitted when a marker position changed and
+ when an image origin changed.
+ """
+
+ OVERLAY = "overlayChanged"
+ """Item's overlay state changed flag."""
+
+ VISUALIZATION_MODE = "visualizationModeChanged"
+ """Item's visualization mode changed flag."""
+
+ COMPLEX_MODE = "complexModeChanged"
+ """Item's complex data visualization mode changed flag."""
+
+ NAME = "nameChanged"
+ """Item's name changed flag."""
+
+ EDITABLE = "editableChanged"
+ """Item's editable state changed flags."""
+
+ SELECTABLE = "selectableChanged"
+ """Item's selectable state changed flags."""
+
+ FONT = "fontChanged"
+ """Item's text font changed flag."""
+
+ BACKGROUND_COLOR = "backgroundColorChanged"
+ """Item's text background color changed flag."""
+
+
+class Item(qt.QObject):
+ """Description of an item of the plot"""
+
+ _DEFAULT_Z_LAYER = 0
+ """Default layer for overlay rendering"""
+
+ _DEFAULT_SELECTABLE = False
+ """Default selectable state of items"""
+
+ sigItemChanged = qt.Signal(object)
+ """Signal emitted when the item has changed.
+
+ It provides a flag describing which property of the item has changed.
+ See :class:`ItemChangedType` for flags description.
+ """
+
+ _sigVisibleBoundsChanged = qt.Signal()
+ """Signal emitted when the visible extent of the item in the plot has changed.
+
+ This signal is emitted only if visible extent tracking is enabled
+ (see :meth:`_setVisibleBoundsTracking`).
+ """
+
+ def __init__(self):
+ qt.QObject.__init__(self)
+ self._dirty = True
+ self._plotRef = None
+ self._visible = True
+ self._selectable = self._DEFAULT_SELECTABLE
+ self._z = self._DEFAULT_Z_LAYER
+ self._info = None
+ self._xlabel = None
+ self._ylabel = None
+ self.__name = ""
+
+ self.__visibleBoundsTracking = False
+ self.__previousVisibleBounds = None
+
+ self._backendRenderer = None
+
+ def getPlot(self):
+ """Returns the ~silx.gui.plot.PlotWidget this item belongs to.
+
+ :rtype: Union[~silx.gui.plot.PlotWidget,None]
+ """
+ return None if self._plotRef is None else self._plotRef()
+
+ def _setPlot(self, plot):
+ """Set the plot this item belongs to.
+
+ WARNING: This should only be called from the Plot.
+
+ :param Union[~silx.gui.plot.PlotWidget,None] plot: The Plot instance.
+ """
+ if plot is not None and self._plotRef is not None:
+ raise RuntimeError("Trying to add a node at two places.")
+ self.__disconnectFromPlotWidget()
+ self._plotRef = None if plot is None else weakref.ref(plot)
+ self.__connectToPlotWidget()
+ self._updated()
+
+ def getBounds(self): # TODO return a Bounds object rather than a tuple
+ """Returns the bounding box of this item in data coordinates
+
+ :returns: (xmin, xmax, ymin, ymax) or None
+ :rtype: 4-tuple of float or None
+ """
+ return self._getBounds()
+
+ def _getBounds(self):
+ """:meth:`getBounds` implementation to override by sub-class"""
+ return None
+
+ def isVisible(self):
+ """True if item is visible, False otherwise
+
+ :rtype: bool
+ """
+ return self._visible
+
+ def setVisible(self, visible):
+ """Set visibility of item.
+
+ :param bool visible: True to display it, False otherwise
+ """
+ visible = bool(visible)
+ if visible != self._visible:
+ self._visible = visible
+ # When visibility has changed, always mark as dirty
+ self._updated(ItemChangedType.VISIBLE, checkVisibility=False)
+ if visible:
+ self._visibleBoundsChanged()
+
+ def isOverlay(self):
+ """Return true if item is drawn as an overlay.
+
+ :rtype: bool
+ """
+ return False
+
+ def getName(self):
+ """Returns the name of the item which is used as legend.
+
+ :rtype: str
+ """
+ return self.__name
+
+ def setName(self, name):
+ """Set the name of the item which is used as legend.
+
+ :param str name: New name of the item
+ :raises RuntimeError: If item belongs to a PlotWidget.
+ """
+ name = str(name)
+ if self.__name != name:
+ if self.getPlot() is not None:
+ raise RuntimeError("Cannot change name while item is in a PlotWidget")
+
+ self.__name = name
+ self._updated(ItemChangedType.NAME)
+
+ def getLegend(self): # Replaced by getName for API consistency
+ return self.getName()
+
+ def isSelectable(self):
+ """Returns true if item is selectable (bool)"""
+ return self._selectable
+
+ def _setSelectable(self, selectable): # TODO support update
+ """Set whether item is selectable or not.
+
+ This is private for now as change is not handled.
+
+ :param bool selectable: True to make item selectable
+ """
+ self._selectable = bool(selectable)
+
+ def getZValue(self):
+ """Returns the layer on which to draw this item (int)"""
+ return self._z
+
+ def setZValue(self, z):
+ z = int(z) if z is not None else self._DEFAULT_Z_LAYER
+ if z != self._z:
+ self._z = z
+ self._updated(ItemChangedType.ZVALUE)
+
+ def getInfo(self, copy=True):
+ """Returns the info associated to this item
+
+ :param bool copy: True to get a deepcopy, False otherwise.
+ """
+ return deepcopy(self._info) if copy else self._info
+
+ def setInfo(self, info, copy=True):
+ if copy:
+ info = deepcopy(info)
+ self._info = info
+
+ def getVisibleBounds(self) -> Optional[Tuple[float, float, float, float]]:
+ """Returns visible bounds of the item bounding box in the plot area.
+
+ :returns:
+ (xmin, xmax, ymin, ymax) in data coordinates of the visible area or
+ None if item is not visible in the plot area.
+ :rtype: Union[List[float],None]
+ """
+ plot = self.getPlot()
+ bounds = self.getBounds()
+ if plot is None or bounds is None or not self.isVisible():
+ return None
+
+ xmin, xmax = numpy.clip(bounds[:2], *plot.getXAxis().getLimits())
+ ymin, ymax = numpy.clip(
+ bounds[2:], *plot.getYAxis(self.__getYAxis()).getLimits()
+ )
+
+ if xmin == xmax or ymin == ymax: # Outside the plot area
+ return None
+ else:
+ return xmin, xmax, ymin, ymax
+
+ def _isVisibleBoundsTracking(self) -> bool:
+ """Returns True if visible bounds changes are tracked.
+
+ When enabled, :attr:`_sigVisibleBoundsChanged` is emitted upon changes.
+ :rtype: bool
+ """
+ return self.__visibleBoundsTracking
+
+ def _setVisibleBoundsTracking(self, enable: bool) -> None:
+ """Set whether or not to track visible bounds changes.
+
+ :param bool enable:
+ """
+ if enable != self.__visibleBoundsTracking:
+ self.__disconnectFromPlotWidget()
+ self.__previousVisibleBounds = None
+ self.__visibleBoundsTracking = enable
+ self.__connectToPlotWidget()
+
+ def __getYAxis(self) -> str:
+ """Returns current Y axis ('left' or 'right')"""
+ return self.getYAxis() if isinstance(self, YAxisMixIn) else "left"
+
+ def __connectToPlotWidget(self) -> None:
+ """Connect to PlotWidget signals and install event filter"""
+ if not self._isVisibleBoundsTracking():
+ return
+
+ plot = self.getPlot()
+ if plot is not None:
+ for axis in (plot.getXAxis(), plot.getYAxis(self.__getYAxis())):
+ axis.sigLimitsChanged.connect(self._visibleBoundsChanged)
+
+ plot.installEventFilter(self)
+
+ self._visibleBoundsChanged()
+
+ def __disconnectFromPlotWidget(self) -> None:
+ """Disconnect from PlotWidget signals and remove event filter"""
+ if not self._isVisibleBoundsTracking():
+ return
+
+ plot = self.getPlot()
+ if plot is not None:
+ for axis in (plot.getXAxis(), plot.getYAxis(self.__getYAxis())):
+ axis.sigLimitsChanged.disconnect(self._visibleBoundsChanged)
+
+ plot.removeEventFilter(self)
+
+ def _visibleBoundsChanged(self, *args) -> None:
+ """Check if visible extent actually changed and emit signal"""
+ if not self._isVisibleBoundsTracking():
+ return # No visible extent tracking
+
+ plot = self.getPlot()
+ if plot is None or not plot.isVisible():
+ return # No plot or plot not visible
+
+ extent = self.getVisibleBounds()
+ if extent != self.__previousVisibleBounds:
+ self.__previousVisibleBounds = extent
+ self._sigVisibleBoundsChanged.emit()
+
+ def eventFilter(self, watched, event):
+ """Event filter to handle PlotWidget show events"""
+ if watched is self.getPlot() and event.type() == qt.QEvent.Show:
+ self._visibleBoundsChanged()
+ return super().eventFilter(watched, event)
+
+ def _updated(self, event=None, checkVisibility=True):
+ """Mark the item as dirty (i.e., needing update).
+
+ This also triggers Plot.replot.
+
+ :param event: The event to send to :attr:`sigItemChanged` signal.
+ :param bool checkVisibility: True to only mark as dirty if visible,
+ False to always mark as dirty.
+ """
+ if not checkVisibility or self.isVisible():
+ if not self._dirty:
+ self._dirty = True
+ # TODO: send event instead of explicit call
+ plot = self.getPlot()
+ if plot is not None:
+ plot._itemRequiresUpdate(self)
+ if event is not None:
+ self.sigItemChanged.emit(event)
+
+ def _update(self, backend):
+ """Called by Plot to update the backend for this item.
+
+ This is meant to be called asynchronously from _updated.
+ This optimizes the number of call to _update.
+
+ :param backend: The backend to update
+ """
+ if self._dirty:
+ # Remove previous renderer from backend if any
+ self._removeBackendRenderer(backend)
+
+ # If not visible, do not add renderer to backend
+ if self.isVisible():
+ self._backendRenderer = self._addBackendRenderer(backend)
+
+ self._dirty = False
+
+ def _addBackendRenderer(self, backend):
+ """Override in subclass to add specific backend renderer.
+
+ :param BackendBase backend: The backend to update
+ :return: The renderer handle to store or None if no renderer in backend
+ """
+ return None
+
+ def _removeBackendRenderer(self, backend):
+ """Override in subclass to remove specific backend renderer.
+
+ :param BackendBase backend: The backend to update
+ """
+ if self._backendRenderer is not None:
+ backend.remove(self._backendRenderer)
+ self._backendRenderer = None
+
+ def pick(self, x, y):
+ """Run picking test on this item
+
+ :param float x: The x pixel coord where to pick.
+ :param float y: The y pixel coord where to pick.
+ :return: None if not picked, else the picked position information
+ :rtype: Union[None,PickingResult]
+ """
+ if not self.isVisible() or self._backendRenderer is None:
+ return None
+ plot = self.getPlot()
+ if plot is None:
+ return None
+
+ indices = plot._backend.pickItem(x, y, self._backendRenderer)
+ if indices is None:
+ return None
+ else:
+ return PickingResult(self, indices)
+
+
+class DataItem(Item):
+ """Item with a data extent in the plot"""
+
+ def _boundsChanged(self, checkVisibility: bool = True) -> None:
+ """Call this method in subclass when data bounds has changed.
+
+ :param bool checkVisibility:
+ """
+ if not checkVisibility or self.isVisible():
+ if self.isVisible():
+ self._visibleBoundsChanged()
+
+ # TODO hackish data range implementation
+ plot = self.getPlot()
+ if plot is not None:
+ plot._invalidateDataRange()
+
+ @docstring(Item)
+ def setVisible(self, visible: bool):
+ if visible != self.isVisible():
+ self._boundsChanged(checkVisibility=False)
+ super().setVisible(visible)
+
+
+# Mix-in classes ##############################################################
+
+
+class ItemMixInBase(object):
+ """Base class for Item mix-in"""
+
+ def _updated(self, event=None, checkVisibility=True):
+ """This is implemented in :class:`Item`.
+
+ Mark the item as dirty (i.e., needing update).
+ This also triggers Plot.replot.
+
+ :param event: The event to send to :attr:`sigItemChanged` signal.
+ :param bool checkVisibility: True to only mark as dirty if visible,
+ False to always mark as dirty.
+ """
+ raise RuntimeError("Issue with Mix-In class inheritance order")
+
+
+class LabelsMixIn(ItemMixInBase):
+ """Mix-in class for items with x and y labels
+
+ Setters are private, otherwise it needs to check the plot
+ current active curve and access the internal current labels.
+ """
+
+ def __init__(self):
+ self._xlabel = None
+ self._ylabel = None
+
+ def getXLabel(self):
+ """Return the X axis label associated to this curve
+
+ :rtype: str or None
+ """
+ return self._xlabel
+
+ def _setXLabel(self, label):
+ """Set the X axis label associated with this curve
+
+ :param str label: The X axis label
+ """
+ self._xlabel = str(label)
+
+ def getYLabel(self):
+ """Return the Y axis label associated to this curve
+
+ :rtype: str or None
+ """
+ return self._ylabel
+
+ def _setYLabel(self, label):
+ """Set the Y axis label associated with this curve
+
+ :param str label: The Y axis label
+ """
+ self._ylabel = str(label)
+
+
+class DraggableMixIn(ItemMixInBase):
+ """Mix-in class for draggable items"""
+
+ def __init__(self):
+ self._draggable = False
+
+ def isDraggable(self):
+ """Returns true if image is draggable
+
+ :rtype: bool
+ """
+ return self._draggable
+
+ def _setDraggable(self, draggable): # TODO support update
+ """Set if image is draggable or not.
+
+ This is private for not as it does not support update.
+
+ :param bool draggable:
+ """
+ self._draggable = bool(draggable)
+
+ def drag(self, from_, to):
+ """Perform a drag of the item.
+
+ :param List[float] from_: (x, y) previous position in data coordinates
+ :param List[float] to: (x, y) current position in data coordinates
+ """
+ raise NotImplementedError("Must be implemented in subclass")
+
+
+class ColormapMixIn(_Colormappable, ItemMixInBase):
+ """Mix-in class for items with colormap"""
+
+ def __init__(self):
+ self._colormap = Colormap()
+ self._colormap.sigChanged.connect(self._colormapChanged)
+ self.__data = None
+ self.__cacheColormapRange = {} # Store {normalization: range}
+
+ def getColormap(self):
+ """Return the used colormap"""
+ return self._colormap
+
+ def setColormap(self, colormap):
+ """Set the colormap of this item
+
+ :param silx.gui.colors.Colormap colormap: colormap description
+ """
+ if self._colormap is colormap:
+ return
+ if isinstance(colormap, dict):
+ colormap = Colormap._fromDict(colormap)
+
+ if self._colormap is not None:
+ self._colormap.sigChanged.disconnect(self._colormapChanged)
+ self._colormap = colormap
+ if self._colormap is not None:
+ self._colormap.sigChanged.connect(self._colormapChanged)
+ self._colormapChanged()
+
+ def _colormapChanged(self):
+ """Handle updates of the colormap"""
+ self._updated(ItemChangedType.COLORMAP)
+
+ def _setColormappedData(
+ self, data, copy=True, min_=None, minPositive=None, max_=None
+ ):
+ """Set the data used to compute the colormapped display.
+
+ It also resets the cache of data ranges.
+
+ This method MUST be called by inheriting classes when data is updated.
+
+ :param Union[None,numpy.ndarray] data:
+ :param Union[None,float] min_: Minimum value of the data
+ :param Union[None,float] minPositive:
+ Minimum of strictly positive values of the data
+ :param Union[None,float] max_: Maximum value of the data
+ """
+ self.__data = None if data is None else numpy.array(data, copy=copy)
+ self.__cacheColormapRange = {} # Reset cache
+
+ # Fill-up colormap range cache if values are provided
+ if max_ is not None and numpy.isfinite(max_):
+ if min_ is not None and numpy.isfinite(min_):
+ self.__cacheColormapRange[Colormap.LINEAR, Colormap.MINMAX] = min_, max_
+ if minPositive is not None and numpy.isfinite(minPositive):
+ self.__cacheColormapRange[Colormap.LOGARITHM, Colormap.MINMAX] = (
+ minPositive,
+ max_,
+ )
+
+ colormap = self.getColormap()
+ if None in (colormap.getVMin(), colormap.getVMax()):
+ self._colormapChanged()
+
+ def getColormappedData(self, copy=True):
+ """Returns the data used to compute the displayed colors
+
+ :param bool copy: True to get a copy,
+ False to get internal data (do not modify!).
+ :rtype: Union[None,numpy.ndarray]
+ """
+ if self.__data is None:
+ return None
+ else:
+ return numpy.array(self.__data, copy=copy)
+
+ def _getColormapAutoscaleRange(self, colormap=None):
+ """Returns the autoscale range for current data and colormap.
+
+ :param Union[None,~silx.gui.colors.Colormap] colormap:
+ The colormap for which to compute the autoscale range.
+ If None, the default, the colormap of the item is used
+ :return: (vmin, vmax) range (vmin and /or vmax might be `None`)
+ """
+ if colormap is None:
+ colormap = self.getColormap()
+
+ data = self.getColormappedData(copy=False)
+ if colormap is None or data is None:
+ return None, None
+
+ normalization = colormap.getNormalization()
+ autoscaleMode = colormap.getAutoscaleMode()
+ key = normalization, autoscaleMode
+ vRange = self.__cacheColormapRange.get(key, None)
+ if vRange is None:
+ vRange = colormap._computeAutoscaleRange(data)
+ self.__cacheColormapRange[key] = vRange
+ return vRange
+
+
+class SymbolMixIn(ItemMixInBase):
+ """Mix-in class for items with symbol type"""
+
+ _DEFAULT_SYMBOL = None
+ """Default marker of the item"""
+
+ _DEFAULT_SYMBOL_SIZE = config.DEFAULT_PLOT_SYMBOL_SIZE
+ """Default marker size of the item"""
+
+ _SUPPORTED_SYMBOLS = dict(
+ (
+ ("o", "Circle"),
+ ("d", "Diamond"),
+ ("s", "Square"),
+ ("+", "Plus"),
+ ("x", "Cross"),
+ (".", "Point"),
+ (",", "Pixel"),
+ ("|", "Vertical line"),
+ ("_", "Horizontal line"),
+ ("tickleft", "Tick left"),
+ ("tickright", "Tick right"),
+ ("tickup", "Tick up"),
+ ("tickdown", "Tick down"),
+ ("caretleft", "Caret left"),
+ ("caretright", "Caret right"),
+ ("caretup", "Caret up"),
+ ("caretdown", "Caret down"),
+ ("\u2665", "Heart"),
+ ("", "None"),
+ )
+ )
+ """Dict of supported symbols"""
+
+ def __init__(self):
+ if self._DEFAULT_SYMBOL is None: # Use default from config
+ self._symbol = config.DEFAULT_PLOT_SYMBOL
+ else:
+ self._symbol = self._DEFAULT_SYMBOL
+
+ if self._DEFAULT_SYMBOL_SIZE is None: # Use default from config
+ self._symbol_size = config.DEFAULT_PLOT_SYMBOL_SIZE
+ else:
+ self._symbol_size = self._DEFAULT_SYMBOL_SIZE
+
+ @classmethod
+ def getSupportedSymbols(cls):
+ """Returns the list of supported symbol names.
+
+ :rtype: tuple of str
+ """
+ return tuple(cls._SUPPORTED_SYMBOLS.keys())
+
+ @classmethod
+ def getSupportedSymbolNames(cls):
+ """Returns the list of supported symbol human-readable names.
+
+ :rtype: tuple of str
+ """
+ return tuple(cls._SUPPORTED_SYMBOLS.values())
+
+ def getSymbolName(self, symbol=None):
+ """Returns human-readable name for a symbol.
+
+ :param str symbol: The symbol from which to get the name.
+ Default: current symbol.
+ :rtype: str
+ :raise KeyError: if symbol is not in :meth:`getSupportedSymbols`.
+ """
+ if symbol is None:
+ symbol = self.getSymbol()
+ return self._SUPPORTED_SYMBOLS[symbol]
+
+ def getSymbol(self):
+ """Return the point marker type.
+
+ Marker type::
+
+ - 'o' circle
+ - '.' point
+ - ',' pixel
+ - '+' cross
+ - 'x' x-cross
+ - 'd' diamond
+ - 's' square
+
+ :rtype: str
+ """
+ return self._symbol
+
+ def setSymbol(self, symbol):
+ """Set the marker type
+
+ See :meth:`getSymbol`.
+
+ :param str symbol: Marker type or marker name
+ """
+ if symbol is None:
+ symbol = self._DEFAULT_SYMBOL
+
+ elif symbol not in self.getSupportedSymbols():
+ for symbolCode, name in self._SUPPORTED_SYMBOLS.items():
+ if name.lower() == symbol.lower():
+ symbol = symbolCode
+ break
+ else:
+ raise ValueError("Unsupported symbol %s" % str(symbol))
+
+ if symbol != self._symbol:
+ self._symbol = symbol
+ self._updated(ItemChangedType.SYMBOL)
+
+ def getSymbolSize(self):
+ """Return the point marker size in points.
+
+ :rtype: float
+ """
+ return self._symbol_size
+
+ def setSymbolSize(self, size):
+ """Set the point marker size in points.
+
+ See :meth:`getSymbolSize`.
+
+ :param str symbol: Marker type
+ """
+ if size is None:
+ size = self._DEFAULT_SYMBOL_SIZE
+ if size != self._symbol_size:
+ self._symbol_size = size
+ self._updated(ItemChangedType.SYMBOL_SIZE)
+
+
+LineStyleType = Union[
+ str,
+ Tuple[Union[float, int], None],
+ Tuple[Union[float, int], Tuple[Union[float, int], Union[float, int]]],
+ Tuple[Union[float, int], Tuple[Union[float, int], Union[float, int], Union[float, int], Union[float, int]]],
+]
+"""Type for :class:`LineMixIn`'s line style"""
+
+
+class LineMixIn(ItemMixInBase):
+ """Mix-in class for item with line"""
+
+ _DEFAULT_LINEWIDTH: float = 1.0
+ """Default line width"""
+
+ _DEFAULT_LINESTYLE: LineStyleType = "-"
+ """Default line style"""
+
+ _SUPPORTED_LINESTYLE = "", " ", "-", "--", "-.", ":", None
+ """Supported line styles"""
+
+ def __init__(self):
+ self._linewidth: float = self._DEFAULT_LINEWIDTH
+ self._linestyle: LineStyleType = self._DEFAULT_LINESTYLE
+
+ @classmethod
+ def getSupportedLineStyles(cls) -> tuple[str | None]:
+ """Returns list of supported constant line styles."""
+ return cls._SUPPORTED_LINESTYLE
+
+ def getLineWidth(self) -> float:
+ """Return the curve line width in pixels"""
+ return self._linewidth
+
+ def setLineWidth(self, width: float):
+ """Set the width in pixel of the curve line
+
+ See :meth:`getLineWidth`.
+ """
+ width = float(width)
+ if width != self._linewidth:
+ self._linewidth = width
+ self._updated(ItemChangedType.LINE_WIDTH)
+
+ @classmethod
+ def isValidLineStyle(cls, style: LineStyleType | None) -> bool:
+ """Returns True for valid styles"""
+ if style is None or style in cls.getSupportedLineStyles():
+ return True
+ if not isinstance(style, tuple):
+ return False
+ if (
+ len(style) == 2
+ and isinstance(style[0], (float, int))
+ and (
+ style[1] is None
+ or style[1] == ()
+ or (
+ isinstance(style[1], tuple)
+ and len(style[1]) in (2, 4)
+ and all(map(lambda item: isinstance(item, (float, int)), style[1]))
+ )
+ )
+ ):
+ return True
+ return False
+
+ def getLineStyle(self) -> LineStyleType:
+ """Return the type of the line
+
+ Type of line::
+
+ - ' ' no line
+ - '-' solid line
+ - '--' dashed line
+ - '-.' dash-dot line
+ - ':' dotted line
+ - (offset, (dash pattern))
+ """
+ return self._linestyle
+
+ def setLineStyle(self, style: LineStyleType | None):
+ """Set the style of the curve line.
+
+ See :meth:`getLineStyle`.
+
+ :param style: Line style
+ """
+ if not self.isValidLineStyle(style):
+ raise ValueError(f"No a valid line style: {style}")
+ if style is None:
+ style = self._DEFAULT_LINESTYLE
+ if style != self._linestyle:
+ self._linestyle = style
+ self._updated(ItemChangedType.LINE_STYLE)
+
+
+class ColorMixIn(ItemMixInBase):
+ """Mix-in class for item with color"""
+
+ _DEFAULT_COLOR = (0.0, 0.0, 0.0, 1.0)
+ """Default color of the item"""
+
+ def __init__(self):
+ self._color = self._DEFAULT_COLOR
+
+ def getColor(self):
+ """Returns the RGBA color of the item
+
+ :rtype: 4-tuple of float in [0, 1] or array of colors
+ """
+ return self._color
+
+ def setColor(self, color, copy=True):
+ """Set item color
+
+ :param color: color(s) to be used
+ :type color: str ("#RRGGBB") or (npoints, 4) unsigned byte array or
+ one of the predefined color names defined in colors.py
+ :param bool copy: True (Default) to get a copy,
+ False to use internal representation (do not modify!)
+ """
+ if isinstance(color, str):
+ color = colors.rgba(color)
+ elif isinstance(color, qt.QColor):
+ color = colors.rgba(color)
+ else:
+ color = numpy.array(color, copy=copy)
+ # TODO more checks + improve color array support
+ if color.ndim == 1: # Single RGBA color
+ color = colors.rgba(color)
+ else: # Array of colors
+ assert color.ndim == 2
+
+ self._color = color
+ self._updated(ItemChangedType.COLOR)
+
+
+class LineGapColorMixIn(ItemMixInBase):
+ """Mix-in class for dashed line gap color"""
+
+ _DEFAULT_LINE_GAP_COLOR = None
+ """Default dashed line gap color of the item"""
+
+ def __init__(self):
+ self.__lineGapColor = self._DEFAULT_LINE_GAP_COLOR
+
+ def getLineGapColor(self):
+ """Returns the RGBA color of dashed line gap of the item
+
+ :rtype: 4-tuple of float in [0, 1] or None
+ """
+ return self.__lineGapColor
+
+ def setLineGapColor(self, color):
+ """Set dashed line gap color
+
+ It supports:
+ - color names: e.g., 'green'
+ - color codes: '#RRGGBB' and '#RRGGBBAA'
+ - indexed color names: e.g., 'C0'
+ - RGB(A) sequence of uint8 in [0, 255] or float in [0, 1]
+ - QColor
+
+ :param color: line background color to be used
+ :type color: Union[str, List[int], List[float], QColor, None]
+ """
+ self.__lineGapColor = None if color is None else colors.rgba(color)
+ self._updated(ItemChangedType.LINE_GAP_COLOR)
+
+
+class YAxisMixIn(ItemMixInBase):
+ """Mix-in class for item with yaxis"""
+
+ _DEFAULT_YAXIS = "left"
+ """Default Y axis the item belongs to"""
+
+ def __init__(self):
+ self._yaxis = self._DEFAULT_YAXIS
+
+ def getYAxis(self):
+ """Returns the Y axis this curve belongs to.
+
+ Either 'left' or 'right'.
+
+ :rtype: str
+ """
+ return self._yaxis
+
+ def setYAxis(self, yaxis):
+ """Set the Y axis this curve belongs to.
+
+ :param str yaxis: 'left' or 'right'
+ """
+ yaxis = str(yaxis)
+ assert yaxis in ("left", "right")
+ if yaxis != self._yaxis:
+ self._yaxis = yaxis
+ # Handle data extent changed for DataItem
+ if isinstance(self, DataItem):
+ self._boundsChanged()
+
+ # Handle visible extent changed
+ if self._isVisibleBoundsTracking():
+ # Switch Y axis signal connection
+ plot = self.getPlot()
+ if plot is not None:
+ previousYAxis = "left" if self.getXAxis() == "right" else "right"
+ plot.getYAxis(previousYAxis).sigLimitsChanged.disconnect(
+ self._visibleBoundsChanged
+ )
+ plot.getYAxis(self.getYAxis()).sigLimitsChanged.connect(
+ self._visibleBoundsChanged
+ )
+ self._visibleBoundsChanged()
+
+ self._updated(ItemChangedType.YAXIS)
+
+
+class FillMixIn(ItemMixInBase):
+ """Mix-in class for item with fill"""
+
+ def __init__(self):
+ self._fill = False
+
+ def isFill(self):
+ """Returns whether the item is filled or not.
+
+ :rtype: bool
+ """
+ return self._fill
+
+ def setFill(self, fill):
+ """Set whether to fill the item or not.
+
+ :param bool fill:
+ """
+ fill = bool(fill)
+ if fill != self._fill:
+ self._fill = fill
+ self._updated(ItemChangedType.FILL)
+
+
+class AlphaMixIn(ItemMixInBase):
+ """Mix-in class for item with opacity"""
+
+ def __init__(self):
+ self._alpha = 1.0
+
+ def getAlpha(self):
+ """Returns the opacity of the item
+
+ :rtype: float in [0, 1.]
+ """
+ return self._alpha
+
+ def setAlpha(self, alpha):
+ """Set the opacity of the item
+
+ .. note::
+
+ If the colormap already has some transparency, this alpha
+ adds additional transparency. The alpha channel of the colormap
+ is multiplied by this value.
+
+ :param alpha: Opacity of the item, between 0 (full transparency)
+ and 1. (full opacity)
+ :type alpha: float
+ """
+ alpha = float(alpha)
+ alpha = max(0.0, min(alpha, 1.0)) # Clip alpha to [0., 1.] range
+ if alpha != self._alpha:
+ self._alpha = alpha
+ self._updated(ItemChangedType.ALPHA)
+
+
+class ComplexMixIn(ItemMixInBase):
+ """Mix-in class for complex data mode"""
+
+ _SUPPORTED_COMPLEX_MODES = None
+ """Override to only support a subset of all ComplexMode"""
+
+ class ComplexMode(_Enum):
+ """Identify available display mode for complex"""
+
+ NONE = "none"
+ ABSOLUTE = "amplitude"
+ PHASE = "phase"
+ REAL = "real"
+ IMAGINARY = "imaginary"
+ AMPLITUDE_PHASE = "amplitude_phase"
+ LOG10_AMPLITUDE_PHASE = "log10_amplitude_phase"
+ SQUARE_AMPLITUDE = "square_amplitude"
+
+ def __init__(self):
+ self.__complex_mode = self.ComplexMode.ABSOLUTE
+
+ def getComplexMode(self):
+ """Returns the current complex visualization mode.
+
+ :rtype: ComplexMode
+ """
+ return self.__complex_mode
+
+ def setComplexMode(self, mode):
+ """Set the complex visualization mode.
+
+ :param ComplexMode mode: The visualization mode in:
+ 'real', 'imaginary', 'phase', 'amplitude'
+ :return: True if value was set, False if is was already set
+ :rtype: bool
+ """
+ mode = self.ComplexMode.from_value(mode)
+ assert mode in self.supportedComplexModes()
+
+ if mode != self.__complex_mode:
+ self.__complex_mode = mode
+ self._updated(ItemChangedType.COMPLEX_MODE)
+ return True
+ else:
+ return False
+
+ def _convertComplexData(self, data, mode=None):
+ """Convert complex data to the specific mode.
+
+ :param Union[ComplexMode,None] mode:
+ The kind of value to compute.
+ If None (the default), the current complex mode is used.
+ :return: The converted dataset
+ :rtype: Union[numpy.ndarray[float],None]
+ """
+ if data is None:
+ return None
+
+ if mode is None:
+ mode = self.getComplexMode()
+
+ if mode is self.ComplexMode.REAL:
+ return numpy.real(data)
+ elif mode is self.ComplexMode.IMAGINARY:
+ return numpy.imag(data)
+ elif mode is self.ComplexMode.ABSOLUTE:
+ return numpy.absolute(data)
+ elif mode is self.ComplexMode.PHASE:
+ return numpy.angle(data)
+ elif mode is self.ComplexMode.SQUARE_AMPLITUDE:
+ return numpy.absolute(data) ** 2
+ else:
+ raise ValueError("Unsupported conversion mode: %s", str(mode))
+
+ @classmethod
+ def supportedComplexModes(cls):
+ """Returns the list of supported complex visualization modes.
+
+ See :class:`ComplexMode` and :meth:`setComplexMode`.
+
+ :rtype: List[ComplexMode]
+ """
+ if cls._SUPPORTED_COMPLEX_MODES is None:
+ return cls.ComplexMode.members()
+ else:
+ return cls._SUPPORTED_COMPLEX_MODES
+
+
+class ScatterVisualizationMixIn(ItemMixInBase):
+ """Mix-in class for scatter plot visualization modes"""
+
+ _SUPPORTED_SCATTER_VISUALIZATION = None
+ """Allows to override supported Visualizations"""
+
+ @enum.unique
+ class Visualization(_Enum):
+ """Different modes of scatter plot visualizations"""
+
+ POINTS = "points"
+ """Display scatter plot as a point cloud"""
+
+ LINES = "lines"
+ """Display scatter plot as a wireframe.
+
+ This is based on Delaunay triangulation
+ """
+
+ SOLID = "solid"
+ """Display scatter plot as a set of filled triangles.
+
+ This is based on Delaunay triangulation
+ """
+
+ REGULAR_GRID = "regular_grid"
+ """Display scatter plot as an image.
+
+ It expects the points to be the intersection of a regular grid,
+ and the order of points following that of an image.
+ First line, then second one, and always in the same direction
+ (either all lines from left to right or all from right to left).
+ """
+
+ IRREGULAR_GRID = "irregular_grid"
+ """Display scatter plot as contiguous quadrilaterals.
+
+ It expects the points to be the intersection of an irregular grid,
+ and the order of points following that of an image.
+ First line, then second one, and always in the same direction
+ (either all lines from left to right or all from right to left).
+ """
+
+ BINNED_STATISTIC = "binned_statistic"
+ """Display scatter plot as 2D binned statistic (i.e., generalized histogram).
+ """
+
+ @enum.unique
+ class VisualizationParameter(_Enum):
+ """Different parameter names for scatter plot visualizations"""
+
+ GRID_MAJOR_ORDER = "grid_major_order"
+ """The major order of points in the regular grid.
+
+ Either 'row' (row-major, fast X) or 'column' (column-major, fast Y).
+ """
+
+ GRID_BOUNDS = "grid_bounds"
+ """The expected range in data coordinates of the regular grid.
+
+ A 2-tuple of 2-tuple: (begin (x, y), end (x, y)).
+ This provides the data coordinates of the first point and the expected
+ last on.
+ As for `GRID_SHAPE`, this can be wider than the current data.
+ """
+
+ GRID_SHAPE = "grid_shape"
+ """The expected size of the regular grid (height, width).
+
+ The given shape can be wider than the number of points,
+ in which case the grid is not fully filled.
+ """
+
+ BINNED_STATISTIC_SHAPE = "binned_statistic_shape"
+ """The number of bins in each dimension (height, width).
+ """
+
+ BINNED_STATISTIC_FUNCTION = "binned_statistic_function"
+ """The reduction function to apply to each bin (str).
+
+ Available reduction functions are: 'mean' (default), 'count', 'sum'.
+ """
+
+ DATA_BOUNDS_HINT = "data_bounds_hint"
+ """The expected bounds of the data in data coordinates.
+
+ A 2-tuple of 2-tuple: ((ymin, ymax), (xmin, xmax)).
+ This provides a hint for the data ranges in both dimensions.
+ It is eventually enlarged with actually data ranges.
+
+ WARNING: dimension 0 i.e., Y first.
+ """
+
+ _SUPPORTED_VISUALIZATION_PARAMETER_VALUES = {
+ VisualizationParameter.GRID_MAJOR_ORDER: ("row", "column"),
+ VisualizationParameter.BINNED_STATISTIC_FUNCTION: ("mean", "count", "sum"),
+ }
+ """Supported visualization parameter values.
+
+ Defined for parameters with a set of acceptable values.
+ """
+
+ def __init__(self):
+ self.__visualization = self.Visualization.POINTS
+ self.__parameters = dict( # Init parameters to None
+ (parameter, None) for parameter in self.VisualizationParameter
+ )
+ self.__parameters[
+ self.VisualizationParameter.BINNED_STATISTIC_FUNCTION
+ ] = "mean"
+
+ @classmethod
+ def supportedVisualizations(cls):
+ """Returns the list of supported scatter visualization modes.
+
+ See :meth:`setVisualization`
+
+ :rtype: List[Visualization]
+ """
+ if cls._SUPPORTED_SCATTER_VISUALIZATION is None:
+ return cls.Visualization.members()
+ else:
+ return cls._SUPPORTED_SCATTER_VISUALIZATION
+
+ @classmethod
+ def supportedVisualizationParameterValues(cls, parameter):
+ """Returns the list of supported scatter visualization modes.
+
+ See :meth:`VisualizationParameters`
+
+ :param VisualizationParameter parameter:
+ This parameter for which to retrieve the supported values.
+ :returns: tuple of supported of values or None if not defined.
+ """
+ parameter = cls.VisualizationParameter(parameter)
+ return cls._SUPPORTED_VISUALIZATION_PARAMETER_VALUES.get(parameter, None)
+
+ def setVisualization(self, mode):
+ """Set the scatter plot visualization mode to use.
+
+ See :class:`Visualization` for all possible values,
+ and :meth:`supportedVisualizations` for supported ones.
+
+ :param Union[str,Visualization] mode:
+ The visualization mode to use.
+ :return: True if value was set, False if is was already set
+ :rtype: bool
+ """
+ mode = self.Visualization.from_value(mode)
+ assert mode in self.supportedVisualizations()
+
+ if mode != self.__visualization:
+ self.__visualization = mode
+
+ self._updated(ItemChangedType.VISUALIZATION_MODE)
+ return True
+ else:
+ return False
+
+ def getVisualization(self):
+ """Returns the scatter plot visualization mode in use.
+
+ :rtype: Visualization
+ """
+ return self.__visualization
+
+ def setVisualizationParameter(self, parameter, value=None):
+ """Set the given visualization parameter.
+
+ :param Union[str,VisualizationParameter] parameter:
+ The name of the parameter to set
+ :param value: The value to use for this parameter
+ Set to None to automatically set the parameter
+ :raises ValueError: If parameter is not supported
+ :return: True if parameter was set, False if is was already set
+ :rtype: bool
+ :raise ValueError: If value is not supported
+ """
+ parameter = self.VisualizationParameter.from_value(parameter)
+
+ if self.__parameters[parameter] != value:
+ validValues = self.supportedVisualizationParameterValues(parameter)
+ if validValues is not None and value not in validValues:
+ raise ValueError("Unsupported parameter value: %s" % str(value))
+
+ self.__parameters[parameter] = value
+ self._updated(ItemChangedType.VISUALIZATION_MODE)
+ return True
+ return False
+
+ def getVisualizationParameter(self, parameter):
+ """Returns the value of the given visualization parameter.
+
+ This method returns the parameter as set by
+ :meth:`setVisualizationParameter`.
+
+ :param parameter: The name of the parameter to retrieve
+ :returns: The value previously set or None if automatically set
+ :raises ValueError: If parameter is not supported
+ """
+ if parameter not in self.VisualizationParameter:
+ raise ValueError("parameter not supported: %s", parameter)
+
+ return self.__parameters[parameter]
+
+ def getCurrentVisualizationParameter(self, parameter):
+ """Returns the current value of the given visualization parameter.
+
+ If the parameter was set by :meth:`setVisualizationParameter` to
+ a value that is not None, this value is returned;
+ else the current value that is automatically computed is returned.
+
+ :param parameter: The name of the parameter to retrieve
+ :returns: The current value (either set or automatically computed)
+ :raises ValueError: If parameter is not supported
+ """
+ # Override in subclass to provide automatically computed parameters
+ return self.getVisualizationParameter(parameter)
+
+
+class PointsBase(DataItem, SymbolMixIn, AlphaMixIn):
+ """Base class for :class:`Curve` and :class:`Scatter`"""
+
+ # note: _logFilterData must be overloaded if you overload
+ # getData to change its signature
+
+ _DEFAULT_Z_LAYER = 1
+ """Default overlay layer for points,
+ on top of images."""
+
+ def __init__(self):
+ DataItem.__init__(self)
+ SymbolMixIn.__init__(self)
+ AlphaMixIn.__init__(self)
+ self._x = ()
+ self._y = ()
+ self._xerror = None
+ self._yerror = None
+
+ # Store filtered data for x > 0 and/or y > 0
+ self._filteredCache = {}
+ self._clippedCache = {}
+
+ # Store bounds depending on axes filtering >0:
+ # key is (isXPositiveFilter, isYPositiveFilter)
+ self._boundsCache = {}
+
+ @staticmethod
+ def _logFilterError(value, error):
+ """Filter/convert error values if they go <= 0.
+
+ Replace error leading to negative values by nan
+
+ :param numpy.ndarray value: 1D array of values
+ :param numpy.ndarray error:
+ Array of errors: scalar, N, Nx1 or 2xN or None.
+ :return: Filtered error so error bars are never negative
+ """
+ if error is not None:
+ # Convert Nx1 to N
+ if error.ndim == 2 and error.shape[1] == 1 and len(value) != 1:
+ error = numpy.ravel(error)
+
+ # Supports error being scalar, N or 2xN array
+ valueMinusError = value - numpy.atleast_2d(error)[0]
+ errorClipped = numpy.isnan(valueMinusError)
+ mask = numpy.logical_not(errorClipped)
+ errorClipped[mask] = valueMinusError[mask] <= 0
+
+ if numpy.any(errorClipped): # Need filtering
+ # expand errorbars to 2xN
+ if error.size == 1: # Scalar
+ error = numpy.full((2, len(value)), error, dtype=numpy.float64)
+
+ elif error.ndim == 1: # N array
+ newError = numpy.empty((2, len(value)), dtype=numpy.float64)
+ newError[0, :] = error
+ newError[1, :] = error
+ error = newError
+
+ elif error.size == 2 * len(value): # 2xN array
+ error = numpy.array(error, copy=True, dtype=numpy.float64)
+
+ else:
+ _logger.error("Unhandled error array")
+ return error
+
+ error[0, errorClipped] = numpy.nan
+
+ return error
+
+ def _getClippingBoolArray(self, xPositive, yPositive):
+ """Compute a boolean array to filter out points with negative
+ coordinates on log axes.
+
+ :param bool xPositive: True to filter arrays according to X coords.
+ :param bool yPositive: True to filter arrays according to Y coords.
+ :rtype: boolean numpy.ndarray
+ """
+ assert xPositive or yPositive
+ if (xPositive, yPositive) not in self._clippedCache:
+ xclipped, yclipped = False, False
+
+ if xPositive:
+ x = self.getXData(copy=False)
+ with numpy.errstate(invalid="ignore"): # Ignore NaN warnings
+ xclipped = x <= 0
+
+ if yPositive:
+ y = self.getYData(copy=False)
+ with numpy.errstate(invalid="ignore"): # Ignore NaN warnings
+ yclipped = y <= 0
+
+ self._clippedCache[(xPositive, yPositive)] = numpy.logical_or(
+ xclipped, yclipped
+ )
+ return self._clippedCache[(xPositive, yPositive)]
+
+ def _logFilterData(self, xPositive, yPositive):
+ """Filter out values with x or y <= 0 on log axes
+
+ :param bool xPositive: True to filter arrays according to X coords.
+ :param bool yPositive: True to filter arrays according to Y coords.
+ :return: The filter arrays or unchanged object if filtering not needed
+ :rtype: (x, y, xerror, yerror)
+ """
+ x = self.getXData(copy=False)
+ y = self.getYData(copy=False)
+ xerror = self.getXErrorData(copy=False)
+ yerror = self.getYErrorData(copy=False)
+
+ if xPositive or yPositive:
+ clipped = self._getClippingBoolArray(xPositive, yPositive)
+
+ if numpy.any(clipped):
+ # copy to keep original array and convert to float
+ x = numpy.array(x, copy=True, dtype=numpy.float64)
+ x[clipped] = numpy.nan
+ y = numpy.array(y, copy=True, dtype=numpy.float64)
+ y[clipped] = numpy.nan
+
+ if xPositive and xerror is not None:
+ xerror = self._logFilterError(x, xerror)
+
+ if yPositive and yerror is not None:
+ yerror = self._logFilterError(y, yerror)
+
+ return x, y, xerror, yerror
+
+ @staticmethod
+ def __minMaxDataWithError(
+ data: numpy.ndarray,
+ error: Optional[Union[float, numpy.ndarray]],
+ positiveOnly: bool,
+ ) -> Tuple[float]:
+ if error is None:
+ min_, max_ = min_max(data, finite=True)
+ return min_, max_
+
+ # float, 1D or 2D array
+ dataMinusError = data - numpy.atleast_2d(error)[0]
+ dataMinusError = dataMinusError[numpy.isfinite(dataMinusError)]
+ if positiveOnly:
+ dataMinusError = dataMinusError[dataMinusError > 0]
+ min_ = numpy.nan if dataMinusError.size == 0 else numpy.min(dataMinusError)
+
+ dataPlusError = data + numpy.atleast_2d(error)[-1]
+ dataPlusError = dataPlusError[numpy.isfinite(dataPlusError)]
+ if positiveOnly:
+ dataPlusError = dataPlusError[dataPlusError > 0]
+ max_ = numpy.nan if dataPlusError.size == 0 else numpy.max(dataPlusError)
+
+ return min_, max_
+
+ def _getBounds(self):
+ if self.getXData(copy=False).size == 0: # Empty data
+ return None
+
+ plot = self.getPlot()
+ if plot is not None:
+ xPositive = plot.getXAxis()._isLogarithmic()
+ yPositive = plot.getYAxis()._isLogarithmic()
+ else:
+ xPositive = False
+ yPositive = False
+
+ if (xPositive, yPositive) not in self._boundsCache:
+ # use the getData class method because instance method can be
+ # overloaded to return additional arrays
+ data = PointsBase.getData(self, copy=False, displayed=True)
+ if len(data) == 5:
+ # hack to avoid duplicating caching mechanism in Scatter
+ # (happens when cached data is used, caching done using
+ # Scatter._logFilterData)
+ x, y, xerror, yerror = data[0], data[1], data[3], data[4]
+ else:
+ x, y, xerror, yerror = data
+
+ xmin, xmax = self.__minMaxDataWithError(x, xerror, xPositive)
+ ymin, ymax = self.__minMaxDataWithError(y, yerror, yPositive)
+
+ self._boundsCache[(xPositive, yPositive)] = tuple(
+ [
+ (bound if bound is not None else numpy.nan)
+ for bound in (xmin, xmax, ymin, ymax)
+ ]
+ )
+ return self._boundsCache[(xPositive, yPositive)]
+
+ def _getCachedData(self):
+ """Return cached filtered data if applicable,
+ i.e. if any axis is in log scale.
+ Return None if caching is not applicable."""
+ plot = self.getPlot()
+ if plot is not None:
+ xPositive = plot.getXAxis()._isLogarithmic()
+ yPositive = plot.getYAxis()._isLogarithmic()
+ if xPositive or yPositive:
+ # At least one axis has log scale, filter data
+ if (xPositive, yPositive) not in self._filteredCache:
+ self._filteredCache[(xPositive, yPositive)] = self._logFilterData(
+ xPositive, yPositive
+ )
+ return self._filteredCache[(xPositive, yPositive)]
+ return None
+
+ def getData(self, copy=True, displayed=False):
+ """Returns the x, y values of the curve points and xerror, yerror
+
+ :param bool copy: True (Default) to get a copy,
+ False to use internal representation (do not modify!)
+ :param bool displayed: True to only get curve points that are displayed
+ in the plot. Default: False
+ Note: If plot has log scale, negative points
+ are not displayed.
+ :returns: (x, y, xerror, yerror)
+ :rtype: 4-tuple of numpy.ndarray
+ """
+ if displayed: # filter data according to plot state
+ cached_data = self._getCachedData()
+ if cached_data is not None:
+ return cached_data
+
+ return (
+ self.getXData(copy),
+ self.getYData(copy),
+ self.getXErrorData(copy),
+ self.getYErrorData(copy),
+ )
+
+ def getXData(self, copy=True):
+ """Returns the x coordinates of the data points
+
+ :param copy: True (Default) to get a copy,
+ False to use internal representation (do not modify!)
+ :rtype: numpy.ndarray
+ """
+ return numpy.array(self._x, copy=copy)
+
+ def getYData(self, copy=True):
+ """Returns the y coordinates of the data points
+
+ :param copy: True (Default) to get a copy,
+ False to use internal representation (do not modify!)
+ :rtype: numpy.ndarray
+ """
+ return numpy.array(self._y, copy=copy)
+
+ def getXErrorData(self, copy=True):
+ """Returns the x error of the points
+
+ :param copy: True (Default) to get a copy,
+ False to use internal representation (do not modify!)
+ :rtype: numpy.ndarray, float or None
+ """
+ if isinstance(self._xerror, numpy.ndarray):
+ return numpy.array(self._xerror, copy=copy)
+ else:
+ return self._xerror # float or None
+
+ def getYErrorData(self, copy=True):
+ """Returns the y error of the points
+
+ :param copy: True (Default) to get a copy,
+ False to use internal representation (do not modify!)
+ :rtype: numpy.ndarray, float or None
+ """
+ if isinstance(self._yerror, numpy.ndarray):
+ return numpy.array(self._yerror, copy=copy)
+ else:
+ return self._yerror # float or None
+
+ def setData(self, x, y, xerror=None, yerror=None, copy=True):
+ """Set the data of the curve.
+
+ :param numpy.ndarray x: The data corresponding to the x coordinates.
+ :param numpy.ndarray y: The data corresponding to the y coordinates.
+ :param xerror: Values with the uncertainties on the x values
+ :type xerror: A float, or a numpy.ndarray of float32.
+ If it is an array, it can either be a 1D array of
+ same length as the data or a 2D array with 2 rows
+ of same length as the data: row 0 for lower errors,
+ row 1 for upper errors.
+ :param yerror: Values with the uncertainties on the y values.
+ :type yerror: A float, or a numpy.ndarray of float32. See xerror.
+ :param bool copy: True make a copy of the data (default),
+ False to use provided arrays.
+ """
+ x = numpy.array(x, copy=copy)
+ y = numpy.array(y, copy=copy)
+ assert len(x) == len(y)
+ assert x.ndim == y.ndim == 1
+
+ # Convert complex data
+ if numpy.iscomplexobj(x):
+ _logger.warning("Converting x data to absolute value to plot it.")
+ x = numpy.absolute(x)
+ if numpy.iscomplexobj(y):
+ _logger.warning("Converting y data to absolute value to plot it.")
+ y = numpy.absolute(y)
+
+ if xerror is not None:
+ if isinstance(xerror, abc.Iterable):
+ xerror = numpy.array(xerror, copy=copy)
+ if numpy.iscomplexobj(xerror):
+ _logger.warning(
+ "Converting xerror data to absolute value to plot it."
+ )
+ xerror = numpy.absolute(xerror)
+ else:
+ xerror = float(xerror)
+ if yerror is not None:
+ if isinstance(yerror, abc.Iterable):
+ yerror = numpy.array(yerror, copy=copy)
+ if numpy.iscomplexobj(yerror):
+ _logger.warning(
+ "Converting yerror data to absolute value to plot it."
+ )
+ yerror = numpy.absolute(yerror)
+ else:
+ yerror = float(yerror)
+ # TODO checks on xerror, yerror
+ self._x, self._y = x, y
+ self._xerror, self._yerror = xerror, yerror
+
+ self._boundsCache = {} # Reset cached bounds
+ self._filteredCache = {} # Reset cached filtered data
+ self._clippedCache = {} # Reset cached clipped bool array
+
+ self._boundsChanged()
+ self._updated(ItemChangedType.DATA)
+
+
+class BaselineMixIn(object):
+ """Base class for Baseline mix-in"""
+
+ def __init__(self, baseline=None):
+ self._baseline = baseline
+
+ def _setBaseline(self, baseline):
+ """
+ Set baseline value
+
+ :param baseline: baseline value(s)
+ :type: Union[None,float,numpy.ndarray]
+ """
+ if isinstance(baseline, abc.Iterable):
+ baseline = numpy.array(baseline)
+ self._baseline = baseline
+
+ def getBaseline(self, copy=True):
+ """
+
+ :param bool copy:
+ :return: histogram baseline
+ :rtype: Union[None,float,numpy.ndarray]
+ """
+ if isinstance(self._baseline, numpy.ndarray):
+ return numpy.array(self._baseline, copy=True)
+ else:
+ return self._baseline
+
+
+class _Style:
+ """Object which store styles"""
+
+
+class HighlightedMixIn(ItemMixInBase):
+ def __init__(self):
+ self._highlightStyle = self._DEFAULT_HIGHLIGHT_STYLE
+ self._highlighted = False
+
+ def isHighlighted(self):
+ """Returns True if curve is highlighted.
+
+ :rtype: bool
+ """
+ return self._highlighted
+
+ def setHighlighted(self, highlighted):
+ """Set the highlight state of the curve
+
+ :param bool highlighted:
+ """
+ highlighted = bool(highlighted)
+ if highlighted != self._highlighted:
+ self._highlighted = highlighted
+ # TODO inefficient: better to use backend's setCurveColor
+ self._updated(ItemChangedType.HIGHLIGHTED)
+
+ def getHighlightedStyle(self):
+ """Returns the highlighted style in use
+
+ :rtype: CurveStyle
+ """
+ return self._highlightStyle
+
+ def setHighlightedStyle(self, style):
+ """Set the style to use for highlighting
+
+ :param CurveStyle style: New style to use
+ """
+ previous = self.getHighlightedStyle()
+ if style != previous:
+ assert isinstance(style, _Style)
+ self._highlightStyle = style
+ self._updated(ItemChangedType.HIGHLIGHTED_STYLE)
+
+ # Backward compatibility event
+ if previous.getColor() != style.getColor():
+ self._updated(ItemChangedType.HIGHLIGHTED_COLOR)
diff --git a/src/silx/gui/plot/items/curve.py b/src/silx/gui/plot/items/curve.py
new file mode 100644
index 0000000..e8d0d52
--- /dev/null
+++ b/src/silx/gui/plot/items/curve.py
@@ -0,0 +1,347 @@
+# /*##########################################################################
+#
+# Copyright (c) 2017-2023 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""This module provides the :class:`Curve` item of the :class:`Plot`.
+"""
+from __future__ import annotations
+
+__authors__ = ["T. Vincent"]
+__license__ = "MIT"
+__date__ = "24/04/2018"
+
+
+import logging
+
+import numpy
+
+from ....utils.deprecation import deprecated_warning
+from ... import colors
+from .core import (
+ PointsBase,
+ LabelsMixIn,
+ ColorMixIn,
+ YAxisMixIn,
+ FillMixIn,
+ LineMixIn,
+ LineGapColorMixIn,
+ LineStyleType,
+ SymbolMixIn,
+ BaselineMixIn,
+ HighlightedMixIn,
+ _Style,
+)
+
+
+_logger = logging.getLogger(__name__)
+
+
+class CurveStyle(_Style):
+ """Object storing the style of a curve.
+
+ Set a value to None to use the default
+
+ :param color: Color
+ :param linestyle: Style of the line
+ :param linewidth: Width of the line
+ :param symbol: Symbol for markers
+ :param symbolsize: Size of the markers
+ :param gapcolor: Color of gaps of dashed line
+ """
+
+ def __init__(
+ self,
+ color: colors.RGBAColorType | None = None,
+ linestyle: LineStyleType | None = None,
+ linewidth: float | None = None,
+ symbol: str | None = None,
+ symbolsize: float | None = None,
+ gapcolor: colors.RGBAColorType | None = None,
+ ):
+ if color is None:
+ self._color = None
+ else:
+ if isinstance(color, str):
+ color = colors.rgba(color)
+ else: # array-like expected
+ color = numpy.array(color, copy=False)
+ if color.ndim == 1: # Array is 1D, this is a single color
+ color = colors.rgba(color)
+ self._color = color
+
+ if not LineMixIn.isValidLineStyle(linestyle):
+ raise ValueError(f"Not a valid line style: {linestyle}")
+ self._linestyle = linestyle
+
+ self._linewidth = None if linewidth is None else float(linewidth)
+
+ if symbol is not None:
+ assert symbol in SymbolMixIn.getSupportedSymbols()
+ self._symbol = symbol
+
+ self._symbolsize = None if symbolsize is None else float(symbolsize)
+
+ self._gapcolor = None if gapcolor is None else colors.rgba(gapcolor)
+
+ def getColor(self, copy=True):
+ """Returns the color or None if not set.
+
+ :param bool copy: True to get a copy (default),
+ False to get internal representation (do not modify!)
+
+ :rtype: Union[List[float],None]
+ """
+ if isinstance(self._color, numpy.ndarray):
+ return numpy.array(self._color, copy=copy)
+ else:
+ return self._color
+
+ def getLineGapColor(self):
+ """Returns the color of dashed line gaps or None if not set.
+
+ :rtype: Union[List[float],None]
+ """
+ return self._gapcolor
+
+ def getLineStyle(self) -> LineStyleType | None:
+ """Return the type of the line or None if not set.
+
+ Type of line::
+
+ - ' ' no line
+ - '-' solid line
+ - '--' dashed line
+ - '-.' dash-dot line
+ - ':' dotted line
+ - (offset, (dash pattern))
+ """
+ return self._linestyle
+
+ def getLineWidth(self):
+ """Return the curve line width in pixels or None if not set.
+
+ :rtype: Union[float,None]
+ """
+ return self._linewidth
+
+ def getSymbol(self):
+ """Return the point marker type.
+
+ Marker type::
+
+ - 'o' circle
+ - '.' point
+ - ',' pixel
+ - '+' cross
+ - 'x' x-cross
+ - 'd' diamond
+ - 's' square
+
+ :rtype: Union[str,None]
+ """
+ return self._symbol
+
+ def getSymbolSize(self):
+ """Return the point marker size in points.
+
+ :rtype: Union[float,None]
+ """
+ return self._symbolsize
+
+ def __eq__(self, other):
+ if isinstance(other, CurveStyle):
+ return (
+ numpy.array_equal(self.getColor(), other.getColor())
+ and self.getLineStyle() == other.getLineStyle()
+ and self.getLineWidth() == other.getLineWidth()
+ and self.getSymbol() == other.getSymbol()
+ and self.getSymbolSize() == other.getSymbolSize()
+ and self.getLineGapColor() == other.getLineGapColor()
+ )
+ else:
+ return False
+
+
+class Curve(
+ PointsBase,
+ ColorMixIn,
+ YAxisMixIn,
+ FillMixIn,
+ LabelsMixIn,
+ LineMixIn,
+ LineGapColorMixIn,
+ BaselineMixIn,
+ HighlightedMixIn,
+):
+ """Description of a curve"""
+
+ _DEFAULT_Z_LAYER = 1
+ """Default overlay layer for curves"""
+
+ _DEFAULT_SELECTABLE = True
+ """Default selectable state for curves"""
+
+ _DEFAULT_LINEWIDTH = 1.0
+ """Default line width of the curve"""
+
+ _DEFAULT_LINESTYLE = "-"
+ """Default line style of the curve"""
+
+ _DEFAULT_HIGHLIGHT_STYLE = CurveStyle(color="black")
+ """Default highlight style of the item"""
+
+ _DEFAULT_BASELINE = None
+
+ def __init__(self):
+ PointsBase.__init__(self)
+ ColorMixIn.__init__(self)
+ YAxisMixIn.__init__(self)
+ FillMixIn.__init__(self)
+ LabelsMixIn.__init__(self)
+ LineMixIn.__init__(self)
+ LineGapColorMixIn.__init__(self)
+ BaselineMixIn.__init__(self)
+ HighlightedMixIn.__init__(self)
+
+ self._setBaseline(Curve._DEFAULT_BASELINE)
+
+ def _addBackendRenderer(self, backend):
+ """Update backend renderer"""
+ # Filter-out values <= 0
+ xFiltered, yFiltered, xerror, yerror = self.getData(copy=False, displayed=True)
+
+ if len(xFiltered) == 0 or not numpy.any(numpy.isfinite(xFiltered)):
+ return None # No data to display, do not add renderer to backend
+
+ style = self.getCurrentStyle()
+
+ return backend.addCurve(
+ xFiltered,
+ yFiltered,
+ color=style.getColor(),
+ gapcolor=style.getLineGapColor(),
+ symbol=style.getSymbol(),
+ linestyle=style.getLineStyle(),
+ linewidth=style.getLineWidth(),
+ yaxis=self.getYAxis(),
+ xerror=xerror,
+ yerror=yerror,
+ fill=self.isFill(),
+ alpha=self.getAlpha(),
+ symbolsize=style.getSymbolSize(),
+ baseline=self.getBaseline(copy=False),
+ )
+
+ def __getitem__(self, item):
+ """Compatibility with PyMca and silx <= 0.4.0"""
+ deprecated_warning(
+ "Attributes",
+ "__getitem__",
+ since_version="2.0.0",
+ replacement="Use Curve methods",
+ )
+ if isinstance(item, slice):
+ return [self[index] for index in range(*item.indices(5))]
+ elif item == 0:
+ return self.getXData(copy=False)
+ elif item == 1:
+ return self.getYData(copy=False)
+ elif item == 2:
+ return self.getName()
+ elif item == 3:
+ info = self.getInfo(copy=False)
+ return {} if info is None else info
+ elif item == 4:
+ params = {
+ "info": self.getInfo(),
+ "color": self.getColor(),
+ "symbol": self.getSymbol(),
+ "linewidth": self.getLineWidth(),
+ "linestyle": self.getLineStyle(),
+ "xlabel": self.getXLabel(),
+ "ylabel": self.getYLabel(),
+ "yaxis": self.getYAxis(),
+ "xerror": self.getXErrorData(copy=False),
+ "yerror": self.getYErrorData(copy=False),
+ "z": self.getZValue(),
+ "selectable": self.isSelectable(),
+ "fill": self.isFill(),
+ }
+ return params
+ else:
+ raise IndexError("Index out of range: %s", str(item))
+
+ def getCurrentStyle(self):
+ """Returns the current curve style.
+
+ Curve style depends on curve highlighting
+
+ :rtype: CurveStyle
+ """
+ if self.isHighlighted():
+ style = self.getHighlightedStyle()
+ color = style.getColor()
+ linestyle = style.getLineStyle()
+ linewidth = style.getLineWidth()
+ symbol = style.getSymbol()
+ symbolsize = style.getSymbolSize()
+ gapcolor = style.getLineGapColor()
+
+ return CurveStyle(
+ color=self.getColor() if color is None else color,
+ linestyle=self.getLineStyle() if linestyle is None else linestyle,
+ linewidth=self.getLineWidth() if linewidth is None else linewidth,
+ symbol=self.getSymbol() if symbol is None else symbol,
+ symbolsize=self.getSymbolSize() if symbolsize is None else symbolsize,
+ gapcolor=self.getLineGapColor() if gapcolor is None else gapcolor,
+ )
+
+ else:
+ return CurveStyle(
+ color=self.getColor(),
+ linestyle=self.getLineStyle(),
+ linewidth=self.getLineWidth(),
+ symbol=self.getSymbol(),
+ symbolsize=self.getSymbolSize(),
+ gapcolor=self.getLineGapColor(),
+ )
+
+ def setData(self, x, y, xerror=None, yerror=None, baseline=None, copy=True):
+ """Set the data of the curve.
+
+ :param numpy.ndarray x: The data corresponding to the x coordinates.
+ :param numpy.ndarray y: The data corresponding to the y coordinates.
+ :param xerror: Values with the uncertainties on the x values
+ :type xerror: A float, or a numpy.ndarray of float32.
+ If it is an array, it can either be a 1D array of
+ same length as the data or a 2D array with 2 rows
+ of same length as the data: row 0 for positive errors,
+ row 1 for negative errors.
+ :param yerror: Values with the uncertainties on the y values.
+ :type yerror: A float, or a numpy.ndarray of float32. See xerror.
+ :param baseline: curve baseline
+ :type baseline: Union[None,float,numpy.ndarray]
+ :param bool copy: True make a copy of the data (default),
+ False to use provided arrays.
+ """
+ PointsBase.setData(self, x=x, y=y, xerror=xerror, yerror=yerror, copy=copy)
+ self._setBaseline(baseline=baseline)
diff --git a/src/silx/gui/plot/items/histogram.py b/src/silx/gui/plot/items/histogram.py
new file mode 100644
index 0000000..1dc851b
--- /dev/null
+++ b/src/silx/gui/plot/items/histogram.py
@@ -0,0 +1,413 @@
+# /*##########################################################################
+#
+# Copyright (c) 2017-2023 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions::t
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""This module provides the :class:`Histogram` item of the :class:`Plot`.
+"""
+
+__authors__ = ["H. Payno", "T. Vincent"]
+__license__ = "MIT"
+__date__ = "28/08/2018"
+
+import logging
+import typing
+
+import numpy
+from collections import abc
+
+from ....utils.proxy import docstring
+from .core import (
+ DataItem,
+ AlphaMixIn,
+ BaselineMixIn,
+ ColorMixIn,
+ FillMixIn,
+ LineMixIn,
+ LineGapColorMixIn,
+ YAxisMixIn,
+ ItemChangedType,
+)
+from ._pick import PickingResult
+
+_logger = logging.getLogger(__name__)
+
+
+def _computeEdges(x, histogramType):
+ """Compute the edges from a set of xs and a rule to generate the edges
+
+ :param x: the x value of the curve to transform into an histogram
+ :param histogramType: the type of histogram we wan't to generate.
+ This define the way to center the histogram values compared to the
+ curve value. Possible values can be::
+
+ - 'left'
+ - 'right'
+ - 'center'
+
+ :return: the edges for the given x and the histogramType
+ """
+ # for now we consider that the spaces between xs are constant
+ edges = x.copy()
+ if histogramType == "left":
+ width = 1
+ if len(x) > 1:
+ width = x[1] - x[0]
+ edges = numpy.append(x[0] - width, edges)
+ if histogramType == "center":
+ edges = _computeEdges(edges, "right")
+ widths = (edges[1:] - edges[0:-1]) / 2.0
+ widths = numpy.append(widths, widths[-1])
+ edges = edges - widths
+ if histogramType == "right":
+ width = 1
+ if len(x) > 1:
+ width = x[-1] - x[-2]
+ edges = numpy.append(edges, x[-1] + width)
+
+ return edges
+
+
+def _getHistogramCurve(histogram, edges):
+ """Returns the x and y value of a curve corresponding to the histogram
+
+ :param numpy.ndarray histogram: The values of the histogram
+ :param numpy.ndarray edges: The bin edges of the histogram
+ :return: a tuple(x, y) which contains the value of the curve to use
+ to display the histogram
+ """
+ assert len(histogram) + 1 == len(edges)
+ x = numpy.empty(len(histogram) * 2, dtype=edges.dtype)
+ y = numpy.empty(len(histogram) * 2, dtype=histogram.dtype)
+ # Make a curve with stairs
+ x[:-1:2] = edges[:-1]
+ x[1::2] = edges[1:]
+ y[:-1:2] = histogram
+ y[1::2] = histogram
+
+ return x, y
+
+
+# TODO: Yerror, test log scale
+class Histogram(
+ DataItem,
+ AlphaMixIn,
+ ColorMixIn,
+ FillMixIn,
+ LineMixIn,
+ LineGapColorMixIn,
+ YAxisMixIn,
+ BaselineMixIn,
+):
+ """Description of an histogram"""
+
+ _DEFAULT_Z_LAYER = 1
+ """Default overlay layer for histograms"""
+
+ _DEFAULT_SELECTABLE = False
+ """Default selectable state for histograms"""
+
+ _DEFAULT_LINEWIDTH = 1.0
+ """Default line width of the histogram"""
+
+ _DEFAULT_LINESTYLE = "-"
+ """Default line style of the histogram"""
+
+ _DEFAULT_BASELINE = None
+
+ def __init__(self):
+ DataItem.__init__(self)
+ AlphaMixIn.__init__(self)
+ BaselineMixIn.__init__(self)
+ ColorMixIn.__init__(self)
+ FillMixIn.__init__(self)
+ LineMixIn.__init__(self)
+ LineGapColorMixIn.__init__(self)
+ YAxisMixIn.__init__(self)
+
+ self._histogram = ()
+ self._edges = ()
+ self._setBaseline(Histogram._DEFAULT_BASELINE)
+
+ def _addBackendRenderer(self, backend):
+ """Update backend renderer"""
+ values, edges, baseline = self.getData(copy=False)
+
+ if values.size == 0:
+ return None # No data to display, do not add renderer
+
+ if values.size == 0:
+ return None # No data to display, do not add renderer to backend
+
+ x, y = _getHistogramCurve(values, edges)
+
+ # Filter-out values <= 0
+ plot = self.getPlot()
+ if plot is not None:
+ xPositive = plot.getXAxis()._isLogarithmic()
+ yPositive = plot.getYAxis()._isLogarithmic()
+ else:
+ xPositive = False
+ yPositive = False
+
+ if xPositive or yPositive:
+ clipped = numpy.logical_or(
+ (x <= 0) if xPositive else False, (y <= 0) if yPositive else False
+ )
+ # Make a copy and replace negative points by NaN
+ x = numpy.array(x, dtype=numpy.float64)
+ y = numpy.array(y, dtype=numpy.float64)
+ x[clipped] = numpy.nan
+ y[clipped] = numpy.nan
+
+ return backend.addCurve(
+ x,
+ y,
+ color=self.getColor(),
+ gapcolor=self.getLineGapColor(),
+ symbol="",
+ linestyle=self.getLineStyle(),
+ linewidth=self.getLineWidth(),
+ yaxis=self.getYAxis(),
+ xerror=None,
+ yerror=None,
+ fill=self.isFill(),
+ alpha=self.getAlpha(),
+ baseline=baseline,
+ symbolsize=1,
+ )
+
+ def _getBounds(self):
+ values, edges, baseline = self.getData(copy=False)
+
+ plot = self.getPlot()
+ if plot is not None:
+ xPositive = plot.getXAxis()._isLogarithmic()
+ yPositive = plot.getYAxis()._isLogarithmic()
+ else:
+ xPositive = False
+ yPositive = False
+
+ if xPositive or yPositive:
+ values = numpy.array(values, copy=True, dtype=numpy.float64)
+
+ if xPositive:
+ # Replace edges <= 0 by NaN and corresponding values by NaN
+ clipped_edges = edges <= 0
+ edges = numpy.array(edges, copy=True, dtype=numpy.float64)
+ edges[clipped_edges] = numpy.nan
+ clipped_values = numpy.logical_or(clipped_edges[:-1], clipped_edges[1:])
+ else:
+ clipped_values = numpy.zeros_like(values, dtype=bool)
+
+ if yPositive:
+ # Replace values <= 0 by NaN, do not modify edges
+ clipped_values = numpy.logical_or(clipped_values, values <= 0)
+
+ values[clipped_values] = numpy.nan
+
+ if yPositive:
+ return (
+ numpy.nanmin(edges),
+ numpy.nanmax(edges),
+ numpy.nanmin(values),
+ numpy.nanmax(values),
+ )
+
+ else: # No log scale on y axis, include 0 in bounds
+ if numpy.all(numpy.isnan(values)):
+ return None
+ return (
+ numpy.nanmin(edges),
+ numpy.nanmax(edges),
+ min(0, numpy.nanmin(values)),
+ max(0, numpy.nanmax(values)),
+ )
+
+ def __pickFilledHistogram(
+ self, x: float, y: float
+ ) -> typing.Optional[PickingResult]:
+ """Picking implementation for filled histogram
+
+ :param x: X position in pixels
+ :param y: Y position in pixels
+ """
+ if not self.isFill():
+ return None
+
+ plot = self.getPlot()
+ if plot is None:
+ return None
+
+ xData, yData = plot.pixelToData(x, y, axis=self.getYAxis())
+ xmin, xmax, ymin, ymax = self.getBounds()
+ if not xmin < xData < xmax or not ymin < yData < ymax:
+ return None # Outside bounding box
+
+ # Check x
+ edges = self.getBinEdgesData(copy=False)
+ index = numpy.searchsorted(edges, (xData,), side="left")[0] - 1
+ # Safe indexing in histogram values
+ index = numpy.clip(index, 0, len(edges) - 2)
+
+ # Check y
+ baseline = self.getBaseline(copy=False)
+ if baseline is None:
+ baseline = 0 # Default value
+
+ value = self.getValueData(copy=False)[index]
+ if (baseline <= value and baseline <= yData <= value) or (
+ value < baseline and value <= yData <= baseline
+ ):
+ return PickingResult(self, numpy.array([index]))
+ else:
+ return None
+
+ @docstring(DataItem)
+ def pick(self, x, y):
+ if self.isFill():
+ return self.__pickFilledHistogram(x, y)
+ else:
+ result = super().pick(x, y)
+ if result is None:
+ return None
+ else: # Convert from curve indices to histogram indices
+ return PickingResult(self, numpy.unique(result.getIndices() // 2))
+
+ def getValueData(self, copy=True):
+ """The values of the histogram
+
+ :param copy: True (Default) to get a copy,
+ False to use internal representation (do not modify!)
+ :returns: The values of the histogram
+ :rtype: numpy.ndarray
+ """
+ return numpy.array(self._histogram, copy=copy)
+
+ def getBinEdgesData(self, copy=True):
+ """The bin edges of the histogram (number of histogram values + 1)
+
+ :param copy: True (Default) to get a copy,
+ False to use internal representation (do not modify!)
+ :returns: The bin edges of the histogram
+ :rtype: numpy.ndarray
+ """
+ return numpy.array(self._edges, copy=copy)
+
+ def getData(self, copy=True):
+ """Return the histogram values, bin edges and baseline
+
+ :param copy: True (Default) to get a copy,
+ False to use internal representation (do not modify!)
+ :returns: (N histogram value, N+1 bin edges)
+ :rtype: 2-tuple of numpy.nadarray
+ """
+ return (
+ self.getValueData(copy),
+ self.getBinEdgesData(copy),
+ self.getBaseline(copy),
+ )
+
+ def setData(self, histogram, edges, align="center", baseline=None, copy=True):
+ """Set the histogram values and bin edges.
+
+ :param numpy.ndarray histogram: The values of the histogram.
+ :param numpy.ndarray edges:
+ The bin edges of the histogram.
+ If histogram and edges have the same length, the bin edges
+ are computed according to the align parameter.
+ :param str align:
+ In case histogram values and edges have the same length N,
+ the N+1 bin edges are computed according to the alignment in:
+ 'center' (default), 'left', 'right'.
+ :param baseline: histogram baseline
+ :type baseline: Union[None,float,numpy.ndarray]
+ :param bool copy: True make a copy of the data (default),
+ False to use provided arrays.
+ """
+ histogram = numpy.array(histogram, copy=copy)
+ edges = numpy.array(edges, copy=copy)
+
+ assert histogram.ndim == 1
+ assert edges.ndim == 1
+ assert edges.size in (histogram.size, histogram.size + 1)
+ assert align in ("center", "left", "right")
+
+ if histogram.size == 0: # No data
+ self._histogram = ()
+ self._edges = ()
+ else:
+ if edges.size == histogram.size: # Compute true bin edges
+ edges = _computeEdges(edges, align)
+
+ # Check that bin edges are monotonic
+ edgesDiff = numpy.diff(edges)
+ edgesDiff = edgesDiff[numpy.logical_not(numpy.isnan(edgesDiff))]
+ assert numpy.all(edgesDiff >= 0) or numpy.all(edgesDiff <= 0)
+ # manage baseline
+ if isinstance(baseline, abc.Iterable):
+ baseline = numpy.array(baseline)
+ if baseline.size == histogram.size:
+ new_baseline = numpy.empty(baseline.shape[0] * 2)
+ for i_value, value in enumerate(baseline):
+ new_baseline[i_value * 2 : i_value * 2 + 2] = value
+ baseline = new_baseline
+ self._histogram = histogram
+ self._edges = edges
+ self._alignement = align
+ self._setBaseline(baseline)
+
+ self._boundsChanged()
+ self._updated(ItemChangedType.DATA)
+
+ def getAlignment(self):
+ """
+
+ :return: histogram alignement. Value in ('center', 'left', 'right').
+ """
+ return self._alignement
+
+ def _revertComputeEdges(self, x, histogramType):
+ """Compute the edges from a set of xs and a rule to generate the edges
+
+ :param x: the x value of the curve to transform into an histogram
+ :param histogramType: the type of histogram we wan't to generate.
+ This define the way to center the histogram values compared to the
+ curve value. Possible values can be::
+
+ - 'left'
+ - 'right'
+ - 'center'
+
+ :return: the edges for the given x and the histogramType
+ """
+ # for now we consider that the spaces between xs are constant
+ edges = x.copy()
+ if histogramType == "left":
+ return edges[1:]
+ if histogramType == "center":
+ edges = (edges[1:] + edges[:-1]) / 2.0
+ if histogramType == "right":
+ width = 1
+ if len(x) > 1:
+ width = x[-1] + x[-2]
+ edges = edges[:-1]
+ return edges
diff --git a/src/silx/gui/plot/items/image.py b/src/silx/gui/plot/items/image.py
new file mode 100644
index 0000000..18310d9
--- /dev/null
+++ b/src/silx/gui/plot/items/image.py
@@ -0,0 +1,673 @@
+# /*##########################################################################
+#
+# Copyright (c) 2017-2023 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""This module provides the :class:`ImageData` and :class:`ImageRgba` items
+of the :class:`Plot`.
+"""
+
+__authors__ = ["T. Vincent"]
+__license__ = "MIT"
+__date__ = "08/12/2020"
+
+from collections import abc
+import logging
+
+import numpy
+
+from ....utils.proxy import docstring
+from ....utils.deprecation import deprecated_warning
+from .core import (
+ DataItem,
+ LabelsMixIn,
+ DraggableMixIn,
+ ColormapMixIn,
+ AlphaMixIn,
+ ItemChangedType,
+)
+
+_logger = logging.getLogger(__name__)
+
+
+def _convertImageToRgba32(image, copy=True):
+ """Convert an RGB or RGBA image to RGBA32.
+
+ It converts from floats in [0, 1], bool, integer and uint in [0, 255]
+
+ If the input image is already an RGBA32 image,
+ the returned image shares the same data.
+
+ :param image: Image to convert to
+ :type image: numpy.ndarray with 3 dimensions: height, width, color channels
+ :param bool copy: True (Default) to get a copy, False, avoid copy if possible
+ :return: The image converted to RGBA32 with dimension: (height, width, 4)
+ :rtype: numpy.ndarray of uint8
+ """
+ assert image.ndim == 3
+ assert image.shape[-1] in (3, 4)
+
+ # Convert type to uint8
+ if image.dtype.name != "uint8":
+ if image.dtype.kind == "f": # Float in [0, 1]
+ image = (numpy.clip(image, 0.0, 1.0) * 255).astype(numpy.uint8)
+ elif image.dtype.kind == "b": # boolean
+ image = image.astype(numpy.uint8) * 255
+ elif image.dtype.kind in ("i", "u"): # int, uint
+ image = numpy.clip(image, 0, 255).astype(numpy.uint8)
+ else:
+ raise ValueError("Unsupported image dtype: %s", image.dtype.name)
+ copy = False # A copy as already been done, avoid next one
+
+ # Convert RGB to RGBA
+ if image.shape[-1] == 3:
+ new_image = numpy.empty((image.shape[0], image.shape[1], 4), dtype=numpy.uint8)
+ new_image[:, :, :3] = image
+ new_image[:, :, 3] = 255
+ return new_image # This is a copy anyway
+ else:
+ return numpy.array(image, copy=copy)
+
+
+class ImageBase(DataItem, LabelsMixIn, DraggableMixIn, AlphaMixIn):
+ """Description of an image
+
+ :param numpy.ndarray data: Initial image data
+ """
+
+ def __init__(self, data=None, mask=None):
+ DataItem.__init__(self)
+ LabelsMixIn.__init__(self)
+ DraggableMixIn.__init__(self)
+ AlphaMixIn.__init__(self)
+ if data is None:
+ data = numpy.zeros((0, 0, 4), dtype=numpy.uint8)
+ self._data = data
+ self._mask = mask
+ self.__valueDataCache = None # Store default data
+ self._origin = (0.0, 0.0)
+ self._scale = (1.0, 1.0)
+
+ def __getitem__(self, item):
+ """Compatibility with PyMca and silx <= 0.4.0"""
+ deprecated_warning(
+ "Attributes",
+ "__getitem__",
+ since_version="2.0.0",
+ replacement="Use ImageBase methods",
+ )
+ if isinstance(item, slice):
+ return [self[index] for index in range(*item.indices(5))]
+ elif item == 0:
+ return self.getData(copy=False)
+ elif item == 1:
+ return self.getName()
+ elif item == 2:
+ info = self.getInfo(copy=False)
+ return {} if info is None else info
+ elif item == 3:
+ return None
+ elif item == 4:
+ params = {
+ "info": self.getInfo(),
+ "origin": self.getOrigin(),
+ "scale": self.getScale(),
+ "z": self.getZValue(),
+ "selectable": self.isSelectable(),
+ "draggable": self.isDraggable(),
+ "colormap": None,
+ "xlabel": self.getXLabel(),
+ "ylabel": self.getYLabel(),
+ }
+ return params
+ else:
+ raise IndexError("Index out of range: %s" % str(item))
+
+ def _isPlotLinear(self, plot):
+ """Return True if plot only uses linear scale for both of x and y
+ axes."""
+ linear = plot.getXAxis().LINEAR
+ if plot.getXAxis().getScale() != linear:
+ return False
+ if plot.getYAxis().getScale() != linear:
+ return False
+ return True
+
+ def _getBounds(self):
+ if self.getData(copy=False).size == 0: # Empty data
+ return None
+
+ height, width = self.getData(copy=False).shape[:2]
+ origin = self.getOrigin()
+ scale = self.getScale()
+ # Taking care of scale might be < 0
+ xmin, xmax = origin[0], origin[0] + width * scale[0]
+ if xmin > xmax:
+ xmin, xmax = xmax, xmin
+ # Taking care of scale might be < 0
+ ymin, ymax = origin[1], origin[1] + height * scale[1]
+ if ymin > ymax:
+ ymin, ymax = ymax, ymin
+
+ plot = self.getPlot()
+ if plot is not None and not self._isPlotLinear(plot):
+ return None
+ else:
+ return xmin, xmax, ymin, ymax
+
+ @docstring(DraggableMixIn)
+ def drag(self, from_, to):
+ origin = self.getOrigin()
+ self.setOrigin((origin[0] + to[0] - from_[0], origin[1] + to[1] - from_[1]))
+
+ def getData(self, copy=True):
+ """Returns the image data
+
+ :param bool copy: True (Default) to get a copy,
+ False to use internal representation (do not modify!)
+ :rtype: numpy.ndarray
+ """
+ return numpy.array(self._data, copy=copy)
+
+ def setData(self, data):
+ """Set the image data
+
+ :param numpy.ndarray data:
+ """
+ previousShape = self._data.shape
+ self._data = data
+ self._valueDataChanged()
+ self._boundsChanged()
+ self._updated(ItemChangedType.DATA)
+
+ if (
+ self.getMaskData(copy=False) is not None
+ and previousShape != self._data.shape
+ ):
+ # Data shape changed, so mask shape changes.
+ # Send event, mask is lazily updated in getMaskData
+ self._updated(ItemChangedType.MASK)
+
+ def getMaskData(self, copy=True):
+ """Returns the mask data
+
+ :param bool copy: True (Default) to get a copy,
+ False to use internal representation (do not modify!)
+ :rtype: Union[None,numpy.ndarray]
+ """
+ if self._mask is None:
+ return None
+
+ # Update mask if it does not match data shape
+ shape = self.getData(copy=False).shape[:2]
+ if self._mask.shape != shape:
+ # Clip/extend mask to match data
+ newMask = numpy.zeros(shape, dtype=self._mask.dtype)
+ newMask[: self._mask.shape[0], : self._mask.shape[1]] = self._mask[
+ : shape[0], : shape[1]
+ ]
+ self._mask = newMask
+
+ return numpy.array(self._mask, copy=copy)
+
+ def setMaskData(self, mask, copy=True):
+ """Set the image data
+
+ :param numpy.ndarray data:
+ :param bool copy: True (Default) to make a copy,
+ False to use as is (do not modify!)
+ """
+ if mask is not None:
+ mask = numpy.array(mask, copy=copy)
+
+ shape = self.getData(copy=False).shape[:2]
+ if mask.shape != shape:
+ _logger.warning(
+ "Inconsistent shape between mask and data %s, %s", mask.shape, shape
+ )
+ # Clip/extent is done lazily in getMaskData
+ elif self._mask is None:
+ return # No update
+
+ self._mask = mask
+ self._valueDataChanged()
+ self._updated(ItemChangedType.MASK)
+
+ def _valueDataChanged(self):
+ """Clear cache of default data array"""
+ self.__valueDataCache = None
+
+ def _getValueData(self, copy=True):
+ """Return data used by :meth:`getValueData`
+
+ :param bool copy:
+ :rtype: numpy.ndarray
+ """
+ return self.getData(copy=copy)
+
+ def getValueData(self, copy=True):
+ """Return data (converted to int or float) with mask applied.
+
+ Masked values are set to Not-A-Number.
+ It returns a 2D array of values (int or float).
+
+ :param bool copy:
+ :rtype: numpy.ndarray
+ """
+ if self.__valueDataCache is None:
+ data = self._getValueData(copy=False)
+ mask = self.getMaskData(copy=False)
+ if mask is not None:
+ if numpy.issubdtype(data.dtype, numpy.floating):
+ dtype = data.dtype
+ else:
+ dtype = numpy.float64
+ data = numpy.array(data, dtype=dtype, copy=True)
+ data[mask != 0] = numpy.NaN
+ self.__valueDataCache = data
+ return numpy.array(self.__valueDataCache, copy=copy)
+
+ def getRgbaImageData(self, copy=True):
+ """Get the displayed RGB(A) image
+
+ :param bool copy: True (Default) to get a copy,
+ False to use internal representation (do not modify!)
+ :returns: numpy.ndarray of uint8 of shape (height, width, 4)
+ """
+ raise NotImplementedError("This MUST be implemented in sub-class")
+
+ def getOrigin(self):
+ """Returns the offset from origin at which to display the image.
+
+ :rtype: 2-tuple of float
+ """
+ return self._origin
+
+ def setOrigin(self, origin):
+ """Set the offset from origin at which to display the image.
+
+ :param origin: (ox, oy) Offset from origin
+ :type origin: float or 2-tuple of float
+ """
+ if isinstance(origin, abc.Sequence):
+ origin = float(origin[0]), float(origin[1])
+ else: # single value origin
+ origin = float(origin), float(origin)
+ if origin != self._origin:
+ self._origin = origin
+ self._boundsChanged()
+ self._updated(ItemChangedType.POSITION)
+
+ def getScale(self):
+ """Returns the scale of the image in data coordinates.
+
+ :rtype: 2-tuple of float
+ """
+ return self._scale
+
+ def setScale(self, scale):
+ """Set the scale of the image
+
+ :param scale: (sx, sy) Scale of the image
+ :type scale: float or 2-tuple of float
+ """
+ if isinstance(scale, abc.Sequence):
+ scale = float(scale[0]), float(scale[1])
+ else: # single value scale
+ scale = float(scale), float(scale)
+
+ if scale != self._scale:
+ self._scale = scale
+ self._boundsChanged()
+ self._updated(ItemChangedType.SCALE)
+
+
+class ImageDataBase(ImageBase, ColormapMixIn):
+ """Base class for colormapped 2D data image"""
+
+ def __init__(self):
+ ImageBase.__init__(self, numpy.zeros((0, 0), dtype=numpy.float32))
+ ColormapMixIn.__init__(self)
+
+ def _getColormapForRendering(self):
+ colormap = self.getColormap()
+ if colormap.isAutoscale():
+ # NOTE: Make sure getColormapRange comes from the original object
+ vrange = colormap.getColormapRange(self)
+ # Avoid backend to compute autoscale: use item cache
+ colormap = colormap.copy()
+ colormap.setVRange(*vrange)
+ return colormap
+
+ def getRgbaImageData(self, copy=True):
+ """Get the displayed RGB(A) image
+
+ :returns: Array of uint8 of shape (height, width, 4)
+ :rtype: numpy.ndarray
+ """
+ return self.getColormap().applyToData(self)
+
+ def setData(self, data, copy=True):
+ """Set the image data
+
+ :param numpy.ndarray data: Data array with 2 dimensions (h, w)
+ :param bool copy: True (Default) to get a copy,
+ False to use internal representation (do not modify!)
+ """
+ data = numpy.array(data, copy=copy)
+ assert data.ndim == 2
+ if data.dtype.kind == "b":
+ _logger.warning("Converting boolean image to int8 to plot it.")
+ data = numpy.array(data, copy=False, dtype=numpy.int8)
+ elif numpy.iscomplexobj(data):
+ _logger.warning("Converting complex image to absolute value to plot it.")
+ data = numpy.absolute(data)
+ super().setData(data)
+
+ def _updated(self, event=None, checkVisibility=True):
+ # Synchronizes colormapped data if changed
+ if event in (ItemChangedType.DATA, ItemChangedType.MASK):
+ self._setColormappedData(self.getValueData(copy=False), copy=False)
+ super()._updated(event=event, checkVisibility=checkVisibility)
+
+
+class ImageData(ImageDataBase):
+ """Description of a data image with a colormap"""
+
+ def __init__(self):
+ ImageDataBase.__init__(self)
+ self._alternativeImage = None
+ self.__alpha = None
+
+ def _addBackendRenderer(self, backend):
+ """Update backend renderer"""
+ plot = self.getPlot()
+ assert plot is not None
+ if not self._isPlotLinear(plot):
+ # Do not render with non linear scales
+ return None
+
+ if (
+ self.getAlternativeImageData(copy=False) is not None
+ or self.getAlphaData(copy=False) is not None
+ ):
+ dataToUse = self.getRgbaImageData(copy=False)
+ else:
+ dataToUse = self.getData(copy=False)
+
+ if dataToUse.size == 0:
+ return None # No data to display
+
+ return backend.addImage(
+ dataToUse,
+ origin=self.getOrigin(),
+ scale=self.getScale(),
+ colormap=self._getColormapForRendering(),
+ alpha=self.getAlpha(),
+ )
+
+ def __getitem__(self, item):
+ """Compatibility with PyMca and silx <= 0.4.0"""
+ deprecated_warning(
+ "Attributes",
+ "__getitem__",
+ since_version="2.0.0",
+ replacement="Use ImageData methods",
+ )
+ if item == 3:
+ return self.getAlternativeImageData(copy=False)
+
+ params = ImageBase.__getitem__(self, item)
+ if item == 4:
+ params["colormap"] = self.getColormap()
+
+ return params
+
+ def getRgbaImageData(self, copy=True):
+ """Get the displayed RGB(A) image
+
+ :returns: Array of uint8 of shape (height, width, 4)
+ :rtype: numpy.ndarray
+ """
+ alternative = self.getAlternativeImageData(copy=False)
+ if alternative is not None:
+ return _convertImageToRgba32(alternative, copy=copy)
+ else:
+ image = super().getRgbaImageData(copy=copy)
+ alphaImage = self.getAlphaData(copy=False)
+ if alphaImage is not None:
+ # Apply transparency
+ image[:, :, 3] = image[:, :, 3] * alphaImage
+ return image
+
+ def getAlternativeImageData(self, copy=True):
+ """Get the optional RGBA image that is displayed instead of the data
+
+ :param bool copy: True (Default) to get a copy,
+ False to use internal representation (do not modify!)
+ :rtype: Union[None,numpy.ndarray]
+ """
+ if self._alternativeImage is None:
+ return None
+ else:
+ return numpy.array(self._alternativeImage, copy=copy)
+
+ def getAlphaData(self, copy=True):
+ """Get the optional transparency image applied on the data
+
+ :param bool copy: True (Default) to get a copy,
+ False to use internal representation (do not modify!)
+ :rtype: Union[None,numpy.ndarray]
+ """
+ if self.__alpha is None:
+ return None
+ else:
+ return numpy.array(self.__alpha, copy=copy)
+
+ def setData(self, data, alternative=None, alpha=None, copy=True):
+ """Set the image data and optionally an alternative RGB(A) representation
+
+ :param numpy.ndarray data: Data array with 2 dimensions (h, w)
+ :param alternative: RGB(A) image to display instead of data,
+ shape: (h, w, 3 or 4)
+ :type alternative: Union[None,numpy.ndarray]
+ :param alpha: An array of transparency value in [0, 1] to use for
+ display with shape: (h, w)
+ :type alpha: Union[None,numpy.ndarray]
+ :param bool copy: True (Default) to get a copy,
+ False to use internal representation (do not modify!)
+ """
+ data = numpy.array(data, copy=copy)
+ assert data.ndim == 2
+
+ if alternative is not None:
+ alternative = numpy.array(alternative, copy=copy)
+ assert alternative.ndim == 3
+ assert alternative.shape[2] in (3, 4)
+ assert alternative.shape[:2] == data.shape[:2]
+ self._alternativeImage = alternative
+
+ if alpha is not None:
+ alpha = numpy.array(alpha, copy=copy)
+ assert alpha.shape == data.shape
+ if alpha.dtype.kind != "f":
+ alpha = alpha.astype(numpy.float32)
+ if numpy.any(numpy.logical_or(alpha < 0.0, alpha > 1.0)):
+ alpha = numpy.clip(alpha, 0.0, 1.0)
+ self.__alpha = alpha
+
+ super().setData(data)
+
+
+class ImageRgba(ImageBase):
+ """Description of an RGB(A) image"""
+
+ def __init__(self):
+ ImageBase.__init__(self, numpy.zeros((0, 0, 4), dtype=numpy.uint8))
+
+ def _addBackendRenderer(self, backend):
+ """Update backend renderer"""
+ plot = self.getPlot()
+ assert plot is not None
+ if not self._isPlotLinear(plot):
+ # Do not render with non linear scales
+ return None
+
+ data = self.getData(copy=False)
+
+ if data.size == 0:
+ return None # No data to display
+
+ return backend.addImage(
+ data,
+ origin=self.getOrigin(),
+ scale=self.getScale(),
+ colormap=None,
+ alpha=self.getAlpha(),
+ )
+
+ def getRgbaImageData(self, copy=True):
+ """Get the displayed RGB(A) image
+
+ :returns: numpy.ndarray of uint8 of shape (height, width, 4)
+ """
+ return _convertImageToRgba32(self.getData(copy=False), copy=copy)
+
+ def setData(self, data, copy=True):
+ """Set the image data
+
+ :param data: RGB(A) image data to set
+ :param bool copy: True (Default) to get a copy,
+ False to use internal representation (do not modify!)
+ """
+ data = numpy.array(data, copy=copy)
+ if data.ndim != 3:
+ raise ValueError(
+ f"RGB(A) image is expected to be a 3D dataset. Got {data.ndim} dimensions"
+ )
+ if data.shape[-1] not in (3, 4):
+ raise ValueError(
+ f"RGB(A) image is expected to have 3 or 4 elements as last dimension. Got {data.shape[-1]}"
+ )
+ super().setData(data)
+
+ def _getValueData(self, copy=True):
+ """Compute the intensity of the RGBA image as default data.
+
+ Conversion: https://en.wikipedia.org/wiki/YCbCr#ITU-R_BT.601_conversion
+
+ :param bool copy:
+ """
+ rgba = self.getRgbaImageData(copy=False).astype(numpy.float32)
+ intensity = (
+ rgba[:, :, 0] * 0.299 + rgba[:, :, 1] * 0.587 + rgba[:, :, 2] * 0.114
+ )
+ intensity *= rgba[:, :, 3] / 255.0
+ return intensity
+
+
+class MaskImageData(ImageData):
+ """Description of an image used as a mask.
+
+ This class is used to flag mask items. This information is used to improve
+ internal silx widgets.
+ """
+
+ pass
+
+
+class ImageStack(ImageData):
+ """Item to store a stack of images and to show it in the plot as one
+ of the images of the stack.
+
+ The stack is a 3D array ordered this way: `frame id, y, x`.
+ So the first image of the stack can be reached this way: `stack[0, :, :]`
+ """
+
+ def __init__(self):
+ ImageData.__init__(self)
+ self.__stack = None
+ """A 3D numpy array (or a mimic one, see ListOfImages)"""
+ self.__stackPosition = None
+ """Displayed position in the cube"""
+
+ def setStackData(self, stack, position=None, copy=True):
+ """Set the stack data
+
+ :param stack: A 3D numpy array like
+ :param int position: The position of the displayed image in the stack
+ :param bool copy: True (Default) to get a copy,
+ False to use internal representation (do not modify!)
+ """
+ if self.__stack is stack:
+ return
+ if copy:
+ stack = numpy.array(stack)
+ assert stack.ndim == 3
+ self.__stack = stack
+ if position is not None:
+ self.__stackPosition = position
+ if self.__stackPosition is None:
+ self.__stackPosition = 0
+ self.__updateDisplayedData()
+
+ def getStackData(self, copy=True):
+ """Get the stored stack array.
+
+ :param bool copy: True (Default) to get a copy,
+ False to use internal representation (do not modify!)
+ :rtype: A 3D numpy array, or numpy array like
+ """
+ if copy:
+ return numpy.array(self.__stack)
+ else:
+ return self.__stack
+
+ def setStackPosition(self, pos):
+ """Set the displayed position on the stack.
+
+ This function will clamp the stack position according to
+ the real size of the first axis of the stack.
+
+ :param int pos: A position on the first axis of the stack.
+ """
+ if self.__stackPosition == pos:
+ return
+ self.__stackPosition = pos
+ self.__updateDisplayedData()
+
+ def getStackPosition(self):
+ """Get the displayed position of the stack.
+
+ :rtype: int
+ """
+ return self.__stackPosition
+
+ def __updateDisplayedData(self):
+ """Update the displayed frame whenever the stack or the stack
+ position are updated."""
+ if self.__stack is None or self.__stackPosition is None:
+ empty = numpy.array([]).reshape(0, 0)
+ self.setData(empty, copy=False)
+ return
+ size = len(self.__stack)
+ self.__stackPosition = numpy.clip(self.__stackPosition, 0, size)
+ self.setData(self.__stack[self.__stackPosition], copy=False)
diff --git a/src/silx/gui/plot/items/image_aggregated.py b/src/silx/gui/plot/items/image_aggregated.py
new file mode 100644
index 0000000..b35e00a
--- /dev/null
+++ b/src/silx/gui/plot/items/image_aggregated.py
@@ -0,0 +1,230 @@
+# /*##########################################################################
+#
+# Copyright (c) 2021-2023 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""This module provides the :class:`ImageDataAggregated` items of the :class:`Plot`.
+"""
+
+__authors__ = ["T. Vincent"]
+__license__ = "MIT"
+__date__ = "07/07/2021"
+
+import enum
+import logging
+from typing import Tuple, Union
+import warnings
+
+import numpy
+
+from ....utils.enum import Enum as _Enum
+from ....utils.proxy import docstring
+from .axis import Axis
+from .core import ItemChangedType
+from .image import ImageDataBase
+from ._pick import PickingResult
+
+
+_logger = logging.getLogger(__name__)
+
+
+class ImageDataAggregated(ImageDataBase):
+ """Item displaying an image as a density map."""
+
+ @enum.unique
+ class Aggregation(_Enum):
+ NONE = "none"
+ "Do not aggregate data, display as is (default)"
+
+ MAX = "max"
+ "Aggregates elements with max (ignore NaNs)"
+
+ MEAN = "mean"
+ "Aggregates elements with mean (ignore NaNs)"
+
+ MIN = "min"
+ "Aggregates elements with min (ignore NaNs)"
+
+ def __init__(self):
+ super().__init__()
+ self.__cacheLODData = {}
+ self.__currentLOD = 0, 0
+ self.__aggregationMode = self.Aggregation.NONE
+
+ def setAggregationMode(self, mode: Union[str, Aggregation]):
+ """Set the aggregation method used to reduce the data to screen resolution.
+
+ :param Aggregation mode: The aggregation method
+ """
+ aggregationMode = self.Aggregation.from_value(mode)
+ if aggregationMode != self.__aggregationMode:
+ self.__aggregationMode = aggregationMode
+ self.__cacheLODData = {} # Clear cache
+ self._updated(ItemChangedType.VISUALIZATION_MODE)
+
+ def getAggregationMode(self) -> Aggregation:
+ """Returns the currently used aggregation method."""
+ return self.__aggregationMode
+
+ def _addBackendRenderer(self, backend):
+ """Update backend renderer"""
+ plot = self.getPlot()
+ assert plot is not None
+ if not self._isPlotLinear(plot):
+ # Do not render with non linear scales
+ return None
+
+ data = self.getData(copy=False)
+ if data.size == 0:
+ return None # No data to display
+
+ aggregationMode = self.getAggregationMode()
+ if aggregationMode == self.Aggregation.NONE: # Pass data as it is
+ displayedData = data
+ scale = self.getScale()
+
+ else: # Aggregate data according to level of details
+ if aggregationMode == self.Aggregation.MAX:
+ aggregator = numpy.nanmax
+ elif aggregationMode == self.Aggregation.MEAN:
+ aggregator = numpy.nanmean
+ elif aggregationMode == self.Aggregation.MIN:
+ aggregator = numpy.nanmin
+ else:
+ _logger.error("Unsupported aggregation mode")
+ return None
+
+ lodx, lody = self._getLevelOfDetails()
+
+ if (lodx, lody) not in self.__cacheLODData:
+ height, width = data.shape
+ with warnings.catch_warnings():
+ warnings.simplefilter("ignore", category=RuntimeWarning)
+ self.__cacheLODData[(lodx, lody)] = aggregator(
+ data[
+ : (height // lody) * lody, : (width // lodx) * lodx
+ ].reshape(height // lody, lody, width // lodx, lodx),
+ axis=(1, 3),
+ )
+
+ self.__currentLOD = lodx, lody
+ displayedData = self.__cacheLODData[self.__currentLOD]
+
+ sx, sy = self.getScale()
+ scale = sx * lodx, sy * lody
+
+ return backend.addImage(
+ displayedData,
+ origin=self.getOrigin(),
+ scale=scale,
+ colormap=self._getColormapForRendering(),
+ alpha=self.getAlpha(),
+ )
+
+ def _getPixelSizeInData(self, axis="left"):
+ """Returns the size of a pixel in plot data coordinates
+
+ :param str axis: Y axis to use in: 'left' (default), 'right'
+ :return:
+ Size (width, height) of a Qt pixel in data coordinates.
+ Size is None if it cannot be computed
+ :rtype: Union[List[float],None]
+ """
+ assert axis in ("left", "right")
+ plot = self.getPlot()
+ if plot is None:
+ return None
+
+ xaxis = plot.getXAxis()
+ yaxis = plot.getYAxis(axis)
+
+ if xaxis.getScale() != Axis.LINEAR or yaxis.getScale() != Axis.LINEAR:
+ raise RuntimeError("Only available with linear axes")
+
+ xmin, xmax = xaxis.getLimits()
+ ymin, ymax = yaxis.getLimits()
+ width, height = plot.getPlotBoundsInPixels()[2:]
+ if width == 0 or height == 0:
+ return None
+ else:
+ return (xmax - xmin) / width, (ymax - ymin) / height
+
+ def _getLevelOfDetails(self) -> Tuple[int, int]:
+ """Return current level of details the image is displayed with."""
+ plot = self.getPlot()
+ if plot is None or not self._isPlotLinear(plot):
+ return 1, 1 # Fallback to bas LOD
+
+ sx, sy = self.getScale()
+ xUnitPerPixel, yUnitPerPixel = self._getPixelSizeInData()
+ lodx = max(1, int(numpy.ceil(xUnitPerPixel / sx)))
+ lody = max(1, int(numpy.ceil(yUnitPerPixel / sy)))
+ return lodx, lody
+
+ @docstring(ImageDataBase)
+ def setData(self, data, copy=True):
+ self.__cacheLODData = {} # Reset cache
+ super().setData(data)
+
+ @docstring(ImageDataBase)
+ def _setPlot(self, plot):
+ """Refresh image when plot limits change"""
+ previousPlot = self.getPlot()
+ if previousPlot is not None:
+ for axis in (previousPlot.getXAxis(), previousPlot.getYAxis()):
+ axis.sigLimitsChanged.disconnect(self.__plotLimitsChanged)
+
+ super()._setPlot(plot)
+
+ if plot is not None:
+ for axis in (plot.getXAxis(), plot.getYAxis()):
+ axis.sigLimitsChanged.connect(self.__plotLimitsChanged)
+
+ def __plotLimitsChanged(self):
+ """Trigger update if level of details has changed"""
+ if (
+ self.getAggregationMode() != self.Aggregation.NONE
+ and self.__currentLOD != self._getLevelOfDetails()
+ ):
+ self._updated()
+
+ @docstring(ImageDataBase)
+ def pick(self, x, y):
+ result = super().pick(x, y)
+ if result is None:
+ return None
+
+ # Compute indices in initial data
+ plot = self.getPlot()
+ if plot is None:
+ return None
+ dataPos = plot.pixelToData(x, y, axis="left", check=True)
+ if dataPos is None:
+ return None # Outside plot area
+
+ ox, oy = self.getOrigin()
+ sx, sy = self.getScale()
+ col = int((dataPos[0] - ox) / sx)
+ row = int((dataPos[1] - oy) / sy)
+ height, width = self.getData(copy=False).shape[:2]
+ if 0 <= col < width and 0 <= row < height:
+ return PickingResult(self, ((row,), (col,)))
+ return None
diff --git a/src/silx/gui/plot/items/marker.py b/src/silx/gui/plot/items/marker.py
new file mode 100755
index 0000000..b3da451
--- /dev/null
+++ b/src/silx/gui/plot/items/marker.py
@@ -0,0 +1,335 @@
+# /*##########################################################################
+#
+# Copyright (c) 2017-2023 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""This module provides markers item of the :class:`Plot`.
+"""
+from __future__ import annotations
+
+__authors__ = ["T. Vincent"]
+__license__ = "MIT"
+__date__ = "06/03/2017"
+
+
+import logging
+import numpy
+
+from ....utils.proxy import docstring
+from .core import (
+ Item,
+ DraggableMixIn,
+ ColorMixIn,
+ LineMixIn,
+ SymbolMixIn,
+ ItemChangedType,
+ YAxisMixIn,
+)
+from silx import config
+from silx.gui import qt
+from silx.gui import colors
+
+
+_logger = logging.getLogger(__name__)
+
+
+class MarkerBase(Item, DraggableMixIn, ColorMixIn, YAxisMixIn):
+ """Base class for markers"""
+
+ sigDragStarted = qt.Signal()
+ """Signal emitted when the marker is pressed"""
+ sigDragFinished = qt.Signal()
+ """Signal emitted when the marker is released"""
+
+ _DEFAULT_COLOR = (0.0, 0.0, 0.0, 1.0)
+ """Default color of the markers"""
+
+ def __init__(self):
+ Item.__init__(self)
+ DraggableMixIn.__init__(self)
+ ColorMixIn.__init__(self)
+ YAxisMixIn.__init__(self)
+
+ self._text = ""
+ self._font = None
+ if config.DEFAULT_PLOT_MARKER_TEXT_FONT_SIZE is not None:
+ self._font = qt.QFont(
+ qt.QApplication.instance().font().family(),
+ config.DEFAULT_PLOT_MARKER_TEXT_FONT_SIZE,
+ )
+
+ self._x = None
+ self._y = None
+ self._bgColor: colors.RGBAColorType | None = None
+ self._constraint = self._defaultConstraint
+ self.__isBeingDragged = False
+
+ def _addRendererCall(self, backend, symbol=None, linestyle="-", linewidth=1):
+ """Perform the update of the backend renderer"""
+ return backend.addMarker(
+ x=self.getXPosition(),
+ y=self.getYPosition(),
+ text=self.getText(),
+ color=self.getColor(),
+ symbol=symbol,
+ linestyle=linestyle,
+ linewidth=linewidth,
+ constraint=self.getConstraint(),
+ yaxis=self.getYAxis(),
+ font=self._font, # Do not use getFont to spare creating a new QFont
+ bgcolor=self.getBackgroundColor(),
+ )
+
+ def _addBackendRenderer(self, backend):
+ """Update backend renderer"""
+ raise NotImplementedError()
+
+ @docstring(DraggableMixIn)
+ def drag(self, from_, to):
+ self.setPosition(to[0], to[1])
+
+ def isOverlay(self):
+ """Returns True: A marker is always rendered as an overlay.
+
+ :rtype: bool
+ """
+ return True
+
+ def getText(self):
+ """Returns marker text.
+
+ :rtype: str
+ """
+ return self._text
+
+ def setText(self, text):
+ """Set the text of the marker.
+
+ :param str text: The text to use
+ """
+ text = str(text)
+ if text != self._text:
+ self._text = text
+ self._updated(ItemChangedType.TEXT)
+
+ def getFont(self) -> qt.QFont | None:
+ """Returns a copy of the QFont used to render text.
+
+ To modify the text font, use :meth:`setFont`.
+ """
+ return None if self._font is None else qt.QFont(self._font)
+
+ def setFont(self, font: qt.QFont | None):
+ """Set the QFont used to render text, use None for default.
+
+ A copy is stored, so further modification of the provided font are not taken into account.
+ """
+ if font != self._font:
+ self._font = None if font is None else qt.QFont(font)
+ self._updated(ItemChangedType.FONT)
+
+ def getBackgroundColor(self) -> colors.RGBAColorType | None:
+ """Returns the RGBA background color of the item"""
+ return self._bgColor
+
+ def setBackgroundColor(self, color):
+ """Set item text background color
+
+ :param color: color(s) to be used as a str ("#RRGGBB") or (npoints, 4)
+ unsigned byte array or one of the predefined color names
+ defined in colors.py
+ """
+ if color is not None:
+ color = colors.rgba(color)
+ if self._bgColor != color:
+ self._bgColor = color
+ self._updated(ItemChangedType.BACKGROUND_COLOR)
+
+ def getXPosition(self):
+ """Returns the X position of the marker line in data coordinates
+
+ :rtype: float or None
+ """
+ return self._x
+
+ def getYPosition(self):
+ """Returns the Y position of the marker line in data coordinates
+
+ :rtype: float or None
+ """
+ return self._y
+
+ def getPosition(self) -> tuple[float | None, float | None]:
+ """Returns the (x, y) position of the marker in data coordinates
+
+ :rtype: 2-tuple of float or None
+ """
+ return self._x, self._y
+
+ def setPosition(self, x: float, y: float):
+ """Set marker position in data coordinates
+
+ Constraint are applied if any.
+
+ :param float x: X coordinates in data frame
+ :param float y: Y coordinates in data frame
+ """
+ x, y = self.getConstraint()(x, y)
+ x, y = float(x), float(y)
+ if x != self._x or y != self._y:
+ self._x, self._y = x, y
+ self._updated(ItemChangedType.POSITION)
+
+ def getConstraint(self):
+ """Returns the dragging constraint of this item"""
+ return self._constraint
+
+ def _setConstraint(self, constraint): # TODO support update
+ """Set the constraint.
+
+ This is private for now as update is not handled.
+
+ :param callable constraint:
+ :param constraint: A function filtering item displacement by
+ dragging operations or None for no filter.
+ This function is called each time the item is
+ moved.
+ This is only used if isDraggable returns True.
+ :type constraint: None or a callable that takes the coordinates of
+ the current cursor position in the plot as input
+ and that returns the filtered coordinates.
+ """
+ if constraint is None:
+ constraint = self._defaultConstraint
+ assert callable(constraint)
+ self._constraint = constraint
+
+ @staticmethod
+ def _defaultConstraint(*args):
+ """Default constraint not doing anything"""
+ return args
+
+ def _startDrag(self):
+ self.__isBeingDragged = True
+ self.sigDragStarted.emit()
+
+ def _endDrag(self):
+ self.__isBeingDragged = False
+ self.sigDragFinished.emit()
+
+ def isBeingDragged(self) -> bool:
+ """Returns whether the marker is currently dragged by the user."""
+ return self.__isBeingDragged
+
+
+class Marker(MarkerBase, SymbolMixIn):
+ """Description of a marker"""
+
+ _DEFAULT_SYMBOL = "+"
+ """Default symbol of the marker"""
+
+ def __init__(self):
+ MarkerBase.__init__(self)
+ SymbolMixIn.__init__(self)
+
+ self._x = 0.0
+ self._y = 0.0
+
+ def _addBackendRenderer(self, backend):
+ return self._addRendererCall(backend, symbol=self.getSymbol())
+
+ def _setConstraint(self, constraint):
+ """Set the constraint function of the marker drag.
+
+ It also supports 'horizontal' and 'vertical' str as constraint.
+
+ :param constraint: The constraint of the dragging of this marker
+ :type: constraint: callable or str
+ """
+ if constraint == "horizontal":
+ constraint = self._horizontalConstraint
+ elif constraint == "vertical":
+ constraint = self._verticalConstraint
+
+ super(Marker, self)._setConstraint(constraint)
+
+ def _horizontalConstraint(self, _, y):
+ return self.getXPosition(), y
+
+ def _verticalConstraint(self, x, _):
+ return x, self.getYPosition()
+
+
+class _LineMarker(MarkerBase, LineMixIn):
+ """Base class for line markers"""
+
+ def __init__(self):
+ MarkerBase.__init__(self)
+ LineMixIn.__init__(self)
+
+ def _addBackendRenderer(self, backend):
+ return self._addRendererCall(
+ backend, linestyle=self.getLineStyle(), linewidth=self.getLineWidth()
+ )
+
+
+class XMarker(_LineMarker):
+ """Description of a marker"""
+
+ def __init__(self):
+ _LineMarker.__init__(self)
+ self._x = 0.0
+
+ def setPosition(self, x, y):
+ """Set marker line position in data coordinates
+
+ Constraint are applied if any.
+
+ :param float x: X coordinates in data frame
+ :param float y: Y coordinates in data frame
+ """
+ x, _ = self.getConstraint()(x, y)
+ x = float(x)
+ if x != self._x:
+ self._x = x
+ self._updated(ItemChangedType.POSITION)
+
+
+class YMarker(_LineMarker):
+ """Description of a marker"""
+
+ def __init__(self):
+ _LineMarker.__init__(self)
+ self._y = 0.0
+
+ def setPosition(self, x, y):
+ """Set marker line position in data coordinates
+
+ Constraint are applied if any.
+
+ :param float x: X coordinates in data frame
+ :param float y: Y coordinates in data frame
+ """
+ _, y = self.getConstraint()(x, y)
+ y = float(y)
+ if y != self._y:
+ self._y = y
+ self._updated(ItemChangedType.POSITION)
diff --git a/src/silx/gui/plot/items/roi.py b/src/silx/gui/plot/items/roi.py
new file mode 100644
index 0000000..7390b88
--- /dev/null
+++ b/src/silx/gui/plot/items/roi.py
@@ -0,0 +1,1561 @@
+# /*##########################################################################
+#
+# Copyright (c) 2018-2022 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""This module provides ROI item for the :class:`~silx.gui.plot.PlotWidget`.
+
+.. inheritance-diagram::
+ silx.gui.plot.items.roi
+ :parts: 1
+"""
+
+__authors__ = ["T. Vincent"]
+__license__ = "MIT"
+__date__ = "28/06/2018"
+
+
+import logging
+import numpy
+from typing import Tuple
+
+from ... import utils
+from .. import items
+from ...colors import rgba
+from silx.image.shapes import Polygon
+from silx.image._boundingbox import _BoundingBox
+from ....utils.proxy import docstring
+from ..utils.intersections import segments_intersection
+from ._roi_base import _RegionOfInterestBase
+
+# He following imports have to be exposed by this module
+from ._roi_base import RegionOfInterest
+from ._roi_base import HandleBasedROI
+from ._arc_roi import ArcROI # noqa
+from ._band_roi import BandROI # noqa
+from ._roi_base import InteractionModeMixIn # noqa
+from ._roi_base import RoiInteractionMode # noqa
+
+
+logger = logging.getLogger(__name__)
+
+
+class PointROI(RegionOfInterest, items.SymbolMixIn):
+ """A ROI identifying a point in a 2D plot."""
+
+ ICON = "add-shape-point"
+ NAME = "point markers"
+ SHORT_NAME = "point"
+ """Metadata for this kind of ROI"""
+
+ _plotShape = "point"
+ """Plot shape which is used for the first interaction"""
+
+ _DEFAULT_SYMBOL = "+"
+ """Default symbol of the PointROI
+
+ It overwrite the `SymbolMixIn` class attribte.
+ """
+
+ def __init__(self, parent=None):
+ RegionOfInterest.__init__(self, parent=parent)
+ items.SymbolMixIn.__init__(self)
+ self._marker = items.Marker()
+ self._marker.sigItemChanged.connect(self._pointPositionChanged)
+ self._marker.setSymbol(self._DEFAULT_SYMBOL)
+ self._marker.sigDragStarted.connect(self._editingStarted)
+ self._marker.sigDragFinished.connect(self._editingFinished)
+ self.addItem(self._marker)
+
+ def setFirstShapePoints(self, points):
+ self.setPosition(points[0])
+
+ def _updated(self, event=None, checkVisibility=True):
+ if event == items.ItemChangedType.EDITABLE:
+ self._marker._setDraggable(self.isEditable())
+ elif event in [items.ItemChangedType.VISIBLE, items.ItemChangedType.SELECTABLE]:
+ self._updateItemProperty(event, self, self._marker)
+ super(PointROI, self)._updated(event, checkVisibility)
+
+ def _updateText(self, text: str):
+ self._marker.setText(text)
+
+ def _updatedStyle(self, event, style):
+ self._marker.setColor(style.getColor())
+
+ def getPosition(self) -> Tuple[float, float]:
+ """Returns the position of this ROI"""
+ return self._marker.getPosition()
+
+ def setPosition(self, pos):
+ """Set the position of this ROI
+
+ :param pos: 2d-coordinate of this point
+ """
+ self._marker.setPosition(*pos)
+
+ @docstring(_RegionOfInterestBase)
+ def contains(self, position):
+ roiPos = self.getPosition()
+ return position[0] == roiPos[0] and position[1] == roiPos[1]
+
+ def _pointPositionChanged(self, event):
+ """Handle position changed events of the marker"""
+ if event is items.ItemChangedType.POSITION:
+ self.sigRegionChanged.emit()
+
+ def __str__(self):
+ params = "%f %f" % self.getPosition()
+ return "%s(%s)" % (self.__class__.__name__, params)
+
+
+class CrossROI(HandleBasedROI, items.LineMixIn):
+ """A ROI identifying a point in a 2D plot and displayed as a cross"""
+
+ ICON = "add-shape-cross"
+ NAME = "cross marker"
+ SHORT_NAME = "cross"
+ """Metadata for this kind of ROI"""
+
+ _plotShape = "point"
+ """Plot shape which is used for the first interaction"""
+
+ def __init__(self, parent=None):
+ HandleBasedROI.__init__(self, parent=parent)
+ items.LineMixIn.__init__(self)
+ self._handle = self.addHandle()
+ self._handle.sigItemChanged.connect(self._handlePositionChanged)
+ self._handleLabel = self.addLabelHandle()
+ self._vmarker = self.addUserHandle(items.YMarker())
+ self._vmarker._setSelectable(False)
+ self._vmarker._setDraggable(False)
+ self._vmarker.setPosition(*self.getPosition())
+ self._hmarker = self.addUserHandle(items.XMarker())
+ self._hmarker._setSelectable(False)
+ self._hmarker._setDraggable(False)
+ self._hmarker.setPosition(*self.getPosition())
+
+ def _updated(self, event=None, checkVisibility=True):
+ if event in [items.ItemChangedType.VISIBLE]:
+ markers = (self._vmarker, self._hmarker)
+ self._updateItemProperty(event, self, markers)
+ super(CrossROI, self)._updated(event, checkVisibility)
+
+ def _updateText(self, text):
+ self._handleLabel.setText(text)
+
+ def _updatedStyle(self, event, style):
+ super(CrossROI, self)._updatedStyle(event, style)
+ for marker in [self._vmarker, self._hmarker]:
+ marker.setColor(style.getColor())
+ marker.setLineStyle(style.getLineStyle())
+ marker.setLineWidth(style.getLineWidth())
+
+ def setFirstShapePoints(self, points):
+ pos = points[0]
+ self.setPosition(pos)
+
+ def getPosition(self) -> Tuple[float, float]:
+ """Returns the position of this ROI"""
+ return self._handle.getPosition()
+
+ def setPosition(self, pos: Tuple[float, float]):
+ """Set the position of this ROI
+
+ :param pos: 2d-coordinate of this point
+ """
+ self._handle.setPosition(*pos)
+
+ def _handlePositionChanged(self, event):
+ """Handle center marker position updates"""
+ if event is items.ItemChangedType.POSITION:
+ position = self.getPosition()
+ self._handleLabel.setPosition(*position)
+ self._vmarker.setPosition(*position)
+ self._hmarker.setPosition(*position)
+ self.sigRegionChanged.emit()
+
+ @docstring(HandleBasedROI)
+ def contains(self, position):
+ roiPos = self.getPosition()
+ return position[0] == roiPos[0] or position[1] == roiPos[1]
+
+
+class LineROI(HandleBasedROI, items.LineMixIn):
+ """A ROI identifying a line in a 2D plot.
+
+ This ROI provides 1 anchor for each boundary of the line, plus an center
+ in the center to translate the full ROI.
+ """
+
+ ICON = "add-shape-diagonal"
+ NAME = "line ROI"
+ SHORT_NAME = "line"
+ """Metadata for this kind of ROI"""
+
+ _plotShape = "line"
+ """Plot shape which is used for the first interaction"""
+
+ def __init__(self, parent=None):
+ HandleBasedROI.__init__(self, parent=parent)
+ items.LineMixIn.__init__(self)
+ self._handleStart = self.addHandle()
+ self._handleEnd = self.addHandle()
+ self._handleCenter = self.addTranslateHandle()
+ self._handleLabel = self.addLabelHandle()
+
+ shape = items.Shape("polylines")
+ shape.setPoints([[0, 0], [0, 0]])
+ shape.setColor(rgba(self.getColor()))
+ shape.setFill(False)
+ shape.setOverlay(True)
+ shape.setLineStyle(self.getLineStyle())
+ shape.setLineWidth(self.getLineWidth())
+ self.__shape = shape
+ self.addItem(shape)
+
+ def _updated(self, event=None, checkVisibility=True):
+ if event == items.ItemChangedType.VISIBLE:
+ self._updateItemProperty(event, self, self.__shape)
+ super(LineROI, self)._updated(event, checkVisibility)
+
+ def _updatedStyle(self, event, style: items.CurveStyle):
+ super(LineROI, self)._updatedStyle(event, style)
+ self.__shape.setColor(style.getColor())
+ self.__shape.setLineStyle(style.getLineStyle())
+ self.__shape.setLineWidth(style.getLineWidth())
+ self.__shape.setLineGapColor(style.getLineGapColor())
+
+ def setFirstShapePoints(self, points):
+ assert len(points) == 2
+ self.setEndPoints(points[0], points[1])
+
+ def _updateText(self, text):
+ self._handleLabel.setText(text)
+
+ def setEndPoints(self, startPoint: numpy.ndarray, endPoint: numpy.ndarray):
+ """Set this line location using the ending points
+
+ :param numpy.ndarray startPoint: Staring bounding point of the line
+ :param numpy.ndarray endPoint: Ending bounding point of the line
+ """
+ if not numpy.array_equal((startPoint, endPoint), self.getEndPoints()):
+ self.__updateEndPoints(startPoint, endPoint)
+
+ def __updateEndPoints(self, startPoint: numpy.ndarray, endPoint: numpy.ndarray):
+ """Update marker and shape to match given end points
+
+ :param numpy.ndarray startPoint: Staring bounding point of the line
+ :param numpy.ndarray endPoint: Ending bounding point of the line
+ """
+ startPoint = numpy.array(startPoint)
+ endPoint = numpy.array(endPoint)
+ center = (startPoint + endPoint) * 0.5
+
+ with utils.blockSignals(self._handleStart):
+ self._handleStart.setPosition(startPoint[0], startPoint[1])
+ with utils.blockSignals(self._handleEnd):
+ self._handleEnd.setPosition(endPoint[0], endPoint[1])
+ with utils.blockSignals(self._handleCenter):
+ self._handleCenter.setPosition(center[0], center[1])
+ with utils.blockSignals(self._handleLabel):
+ self._handleLabel.setPosition(center[0], center[1])
+
+ line = numpy.array((startPoint, endPoint))
+ self.__shape.setPoints(line)
+ self.sigRegionChanged.emit()
+
+ def getEndPoints(self):
+ """Returns bounding points of this ROI.
+
+ :rtype: Tuple(numpy.ndarray,numpy.ndarray)
+ """
+ startPoint = numpy.array(self._handleStart.getPosition())
+ endPoint = numpy.array(self._handleEnd.getPosition())
+ return (startPoint, endPoint)
+
+ def handleDragUpdated(self, handle, origin, previous, current):
+ if handle is self._handleStart:
+ _start, end = self.getEndPoints()
+ self.__updateEndPoints(current, end)
+ elif handle is self._handleEnd:
+ start, _end = self.getEndPoints()
+ self.__updateEndPoints(start, current)
+ elif handle is self._handleCenter:
+ start, end = self.getEndPoints()
+ delta = current - previous
+ start += delta
+ end += delta
+ self.setEndPoints(start, end)
+
+ @docstring(_RegionOfInterestBase)
+ def contains(self, position):
+ bottom_left = position[0], position[1]
+ bottom_right = position[0] + 1, position[1]
+ top_left = position[0], position[1] + 1
+ top_right = position[0] + 1, position[1] + 1
+
+ points = self.__shape.getPoints()
+ line_pt1 = points[0]
+ line_pt2 = points[1]
+
+ bb1 = _BoundingBox.from_points(points)
+ if not bb1.contains(position):
+ return False
+
+ return (
+ segments_intersection(
+ seg1_start_pt=line_pt1,
+ seg1_end_pt=line_pt2,
+ seg2_start_pt=bottom_left,
+ seg2_end_pt=bottom_right,
+ )
+ or segments_intersection(
+ seg1_start_pt=line_pt1,
+ seg1_end_pt=line_pt2,
+ seg2_start_pt=bottom_right,
+ seg2_end_pt=top_right,
+ )
+ or segments_intersection(
+ seg1_start_pt=line_pt1,
+ seg1_end_pt=line_pt2,
+ seg2_start_pt=top_right,
+ seg2_end_pt=top_left,
+ )
+ or segments_intersection(
+ seg1_start_pt=line_pt1,
+ seg1_end_pt=line_pt2,
+ seg2_start_pt=top_left,
+ seg2_end_pt=bottom_left,
+ )
+ ) is not None
+
+ def __str__(self):
+ start, end = self.getEndPoints()
+ params = start[0], start[1], end[0], end[1]
+ params = "start: %f %f; end: %f %f" % params
+ return "%s(%s)" % (self.__class__.__name__, params)
+
+
+class HorizontalLineROI(RegionOfInterest, items.LineMixIn):
+ """A ROI identifying an horizontal line in a 2D plot."""
+
+ ICON = "add-shape-horizontal"
+ NAME = "horizontal line ROI"
+ SHORT_NAME = "hline"
+ """Metadata for this kind of ROI"""
+
+ _plotShape = "hline"
+ """Plot shape which is used for the first interaction"""
+
+ def __init__(self, parent=None):
+ RegionOfInterest.__init__(self, parent=parent)
+ items.LineMixIn.__init__(self)
+ self._marker = items.YMarker()
+ self._marker.sigItemChanged.connect(self._linePositionChanged)
+ self._marker.sigDragStarted.connect(self._editingStarted)
+ self._marker.sigDragFinished.connect(self._editingFinished)
+ self.addItem(self._marker)
+
+ def _updated(self, event=None, checkVisibility=True):
+ if event == items.ItemChangedType.EDITABLE:
+ self._marker._setDraggable(self.isEditable())
+ elif event in [items.ItemChangedType.VISIBLE, items.ItemChangedType.SELECTABLE]:
+ self._updateItemProperty(event, self, self._marker)
+ super(HorizontalLineROI, self)._updated(event, checkVisibility)
+
+ def _updateText(self, text: str):
+ self._marker.setText(text)
+
+ def _updatedStyle(self, event, style):
+ self._marker.setColor(style.getColor())
+ self._marker.setLineStyle(style.getLineStyle())
+ self._marker.setLineWidth(style.getLineWidth())
+
+ def setFirstShapePoints(self, points):
+ pos = points[0, 1]
+ if pos == self.getPosition():
+ return
+ self.setPosition(pos)
+
+ def getPosition(self) -> float:
+ """Returns the position of this line if the horizontal axis"""
+ pos = self._marker.getPosition()
+ return pos[1]
+
+ def setPosition(self, pos: float):
+ """Set the position of this ROI
+
+ :param pos: Horizontal position of this line
+ """
+ self._marker.setPosition(0, pos)
+
+ @docstring(_RegionOfInterestBase)
+ def contains(self, position):
+ return position[1] == self.getPosition()
+
+ def _linePositionChanged(self, event):
+ """Handle position changed events of the marker"""
+ if event is items.ItemChangedType.POSITION:
+ self.sigRegionChanged.emit()
+
+ def __str__(self):
+ params = "y: %f" % self.getPosition()
+ return "%s(%s)" % (self.__class__.__name__, params)
+
+
+class VerticalLineROI(RegionOfInterest, items.LineMixIn):
+ """A ROI identifying a vertical line in a 2D plot."""
+
+ ICON = "add-shape-vertical"
+ NAME = "vertical line ROI"
+ SHORT_NAME = "vline"
+ """Metadata for this kind of ROI"""
+
+ _plotShape = "vline"
+ """Plot shape which is used for the first interaction"""
+
+ def __init__(self, parent=None):
+ RegionOfInterest.__init__(self, parent=parent)
+ items.LineMixIn.__init__(self)
+ self._marker = items.XMarker()
+ self._marker.sigItemChanged.connect(self._linePositionChanged)
+ self._marker.sigDragStarted.connect(self._editingStarted)
+ self._marker.sigDragFinished.connect(self._editingFinished)
+ self.addItem(self._marker)
+
+ def _updated(self, event=None, checkVisibility=True):
+ if event == items.ItemChangedType.EDITABLE:
+ self._marker._setDraggable(self.isEditable())
+ elif event in [items.ItemChangedType.VISIBLE, items.ItemChangedType.SELECTABLE]:
+ self._updateItemProperty(event, self, self._marker)
+ super(VerticalLineROI, self)._updated(event, checkVisibility)
+
+ def _updateText(self, text: str):
+ self._marker.setText(text)
+
+ def _updatedStyle(self, event, style):
+ self._marker.setColor(style.getColor())
+ self._marker.setLineStyle(style.getLineStyle())
+ self._marker.setLineWidth(style.getLineWidth())
+
+ def setFirstShapePoints(self, points):
+ pos = points[0, 0]
+ self.setPosition(pos)
+
+ def getPosition(self) -> float:
+ """Returns the position of this line if the horizontal axis"""
+ pos = self._marker.getPosition()
+ return pos[0]
+
+ def setPosition(self, pos: float):
+ """Set the position of this ROI
+
+ :param float pos: Horizontal position of this line
+ """
+ self._marker.setPosition(pos, 0)
+
+ @docstring(RegionOfInterest)
+ def contains(self, position):
+ return position[0] == self.getPosition()
+
+ def _linePositionChanged(self, event):
+ """Handle position changed events of the marker"""
+ if event is items.ItemChangedType.POSITION:
+ self.sigRegionChanged.emit()
+
+ def __str__(self):
+ params = "x: %f" % self.getPosition()
+ return "%s(%s)" % (self.__class__.__name__, params)
+
+
+class RectangleROI(HandleBasedROI, items.LineMixIn):
+ """A ROI identifying a rectangle in a 2D plot.
+
+ This ROI provides 1 anchor for each corner, plus an anchor in the
+ center to translate the full ROI.
+ """
+
+ ICON = "add-shape-rectangle"
+ NAME = "rectangle ROI"
+ SHORT_NAME = "rectangle"
+ """Metadata for this kind of ROI"""
+
+ _plotShape = "rectangle"
+ """Plot shape which is used for the first interaction"""
+
+ def __init__(self, parent=None):
+ HandleBasedROI.__init__(self, parent=parent)
+ items.LineMixIn.__init__(self)
+ self._handleTopLeft = self.addHandle()
+ self._handleTopRight = self.addHandle()
+ self._handleBottomLeft = self.addHandle()
+ self._handleBottomRight = self.addHandle()
+ self._handleCenter = self.addTranslateHandle()
+ self._handleLabel = self.addLabelHandle()
+
+ shape = items.Shape("rectangle")
+ shape.setPoints([[0, 0], [0, 0]])
+ shape.setFill(False)
+ shape.setOverlay(True)
+ shape.setLineStyle(self.getLineStyle())
+ shape.setLineWidth(self.getLineWidth())
+ shape.setColor(rgba(self.getColor()))
+ self.__shape = shape
+ self.addItem(shape)
+
+ def _updated(self, event=None, checkVisibility=True):
+ if event in [items.ItemChangedType.VISIBLE]:
+ self._updateItemProperty(event, self, self.__shape)
+ super(RectangleROI, self)._updated(event, checkVisibility)
+
+ def _updatedStyle(self, event, style):
+ super(RectangleROI, self)._updatedStyle(event, style)
+ self.__shape.setColor(style.getColor())
+ self.__shape.setLineStyle(style.getLineStyle())
+ self.__shape.setLineWidth(style.getLineWidth())
+ self.__shape.setLineGapColor(style.getLineGapColor())
+
+ def setFirstShapePoints(self, points):
+ assert len(points) == 2
+ self._setBound(points)
+
+ def _setBound(self, points):
+ """Initialize the rectangle from a bunch of points"""
+ top = max(points[:, 1])
+ bottom = min(points[:, 1])
+ left = min(points[:, 0])
+ right = max(points[:, 0])
+ size = right - left, top - bottom
+ self._updateGeometry(origin=(left, bottom), size=size)
+
+ def _updateText(self, text):
+ self._handleLabel.setText(text)
+
+ def getCenter(self):
+ """Returns the central point of this rectangle
+
+ :rtype: numpy.ndarray([float,float])
+ """
+ pos = self._handleCenter.getPosition()
+ return numpy.array(pos)
+
+ def getOrigin(self):
+ """Returns the corner point with the smaller coordinates
+
+ :rtype: numpy.ndarray([float,float])
+ """
+ pos = self._handleBottomLeft.getPosition()
+ return numpy.array(pos)
+
+ def getSize(self):
+ """Returns the size of this rectangle
+
+ :rtype: numpy.ndarray([float,float])
+ """
+ vmin = self._handleBottomLeft.getPosition()
+ vmax = self._handleTopRight.getPosition()
+ vmin, vmax = numpy.array(vmin), numpy.array(vmax)
+ return vmax - vmin
+
+ def setOrigin(self, position):
+ """Set the origin position of this ROI
+
+ :param numpy.ndarray position: Location of the smaller corner of the ROI
+ """
+ size = self.getSize()
+ self.setGeometry(origin=position, size=size)
+
+ def setSize(self, size):
+ """Set the size of this ROI
+
+ :param numpy.ndarray size: Size of the center of the ROI
+ """
+ origin = self.getOrigin()
+ self.setGeometry(origin=origin, size=size)
+
+ def setCenter(self, position):
+ """Set the size of this ROI
+
+ :param numpy.ndarray position: Location of the center of the ROI
+ """
+ size = self.getSize()
+ self.setGeometry(center=position, size=size)
+
+ def setGeometry(self, origin=None, size=None, center=None):
+ """Set the geometry of the ROI"""
+ if (
+ (origin is None or numpy.array_equal(origin, self.getOrigin()))
+ and (center is None or numpy.array_equal(center, self.getCenter()))
+ and numpy.array_equal(size, self.getSize())
+ ):
+ return # Nothing has changed
+
+ self._updateGeometry(origin, size, center)
+
+ def _updateGeometry(self, origin=None, size=None, center=None):
+ """Forced update of the geometry of the ROI"""
+ if origin is not None:
+ origin = numpy.array(origin)
+ size = numpy.array(size)
+ points = numpy.array([origin, origin + size])
+ center = origin + size * 0.5
+ elif center is not None:
+ center = numpy.array(center)
+ size = numpy.array(size)
+ points = numpy.array([center - size * 0.5, center + size * 0.5])
+ else:
+ raise ValueError("Origin or center expected")
+
+ with utils.blockSignals(self._handleBottomLeft):
+ self._handleBottomLeft.setPosition(points[0, 0], points[0, 1])
+ with utils.blockSignals(self._handleBottomRight):
+ self._handleBottomRight.setPosition(points[1, 0], points[0, 1])
+ with utils.blockSignals(self._handleTopLeft):
+ self._handleTopLeft.setPosition(points[0, 0], points[1, 1])
+ with utils.blockSignals(self._handleTopRight):
+ self._handleTopRight.setPosition(points[1, 0], points[1, 1])
+ with utils.blockSignals(self._handleCenter):
+ self._handleCenter.setPosition(center[0], center[1])
+ with utils.blockSignals(self._handleLabel):
+ self._handleLabel.setPosition(points[0, 0], points[0, 1])
+
+ self.__shape.setPoints(points)
+ self.sigRegionChanged.emit()
+
+ @docstring(HandleBasedROI)
+ def contains(self, position):
+ assert isinstance(position, (tuple, list, numpy.array))
+ points = self.__shape.getPoints()
+ bb1 = _BoundingBox.from_points(points)
+ return bb1.contains(position)
+
+ def handleDragUpdated(self, handle, origin, previous, current):
+ if handle is self._handleCenter:
+ # It is the center anchor
+ size = self.getSize()
+ self._updateGeometry(center=current, size=size)
+ else:
+ opposed = {
+ self._handleBottomLeft: self._handleTopRight,
+ self._handleTopRight: self._handleBottomLeft,
+ self._handleBottomRight: self._handleTopLeft,
+ self._handleTopLeft: self._handleBottomRight,
+ }
+ handle2 = opposed[handle]
+ current2 = handle2.getPosition()
+ points = numpy.array([current, current2])
+
+ # Switch handles if they were crossed by interaction
+ if (
+ self._handleBottomLeft.getXPosition()
+ > self._handleBottomRight.getXPosition()
+ ):
+ self._handleBottomLeft, self._handleBottomRight = (
+ self._handleBottomRight,
+ self._handleBottomLeft,
+ )
+
+ if self._handleTopLeft.getXPosition() > self._handleTopRight.getXPosition():
+ self._handleTopLeft, self._handleTopRight = (
+ self._handleTopRight,
+ self._handleTopLeft,
+ )
+
+ if (
+ self._handleBottomLeft.getYPosition()
+ > self._handleTopLeft.getYPosition()
+ ):
+ self._handleBottomLeft, self._handleTopLeft = (
+ self._handleTopLeft,
+ self._handleBottomLeft,
+ )
+
+ if (
+ self._handleBottomRight.getYPosition()
+ > self._handleTopRight.getYPosition()
+ ):
+ self._handleBottomRight, self._handleTopRight = (
+ self._handleTopRight,
+ self._handleBottomRight,
+ )
+
+ self._setBound(points)
+
+ def __str__(self):
+ origin = self.getOrigin()
+ w, h = self.getSize()
+ params = origin[0], origin[1], w, h
+ params = "origin: %f %f; width: %f; height: %f" % params
+ return "%s(%s)" % (self.__class__.__name__, params)
+
+
+class CircleROI(HandleBasedROI, items.LineMixIn):
+ """A ROI identifying a circle in a 2D plot.
+
+ This ROI provides 1 anchor at the center to translate the circle,
+ and one anchor on the perimeter to change the radius.
+ """
+
+ ICON = "add-shape-circle"
+ NAME = "circle ROI"
+ SHORT_NAME = "circle"
+ """Metadata for this kind of ROI"""
+
+ _kind = "Circle"
+ """Label for this kind of ROI"""
+
+ _plotShape = "line"
+ """Plot shape which is used for the first interaction"""
+
+ def __init__(self, parent=None):
+ items.LineMixIn.__init__(self)
+ HandleBasedROI.__init__(self, parent=parent)
+ self._handlePerimeter = self.addHandle()
+ self._handleCenter = self.addTranslateHandle()
+ self._handleCenter.sigItemChanged.connect(self._centerPositionChanged)
+ self._handleLabel = self.addLabelHandle()
+
+ shape = items.Shape("polygon")
+ shape.setPoints([[0, 0], [0, 0]])
+ shape.setColor(rgba(self.getColor()))
+ shape.setFill(False)
+ shape.setOverlay(True)
+ shape.setLineStyle(self.getLineStyle())
+ shape.setLineWidth(self.getLineWidth())
+ self.__shape = shape
+ self.addItem(shape)
+
+ self.__radius = 0
+
+ def _updated(self, event=None, checkVisibility=True):
+ if event == items.ItemChangedType.VISIBLE:
+ self._updateItemProperty(event, self, self.__shape)
+ super(CircleROI, self)._updated(event, checkVisibility)
+
+ def _updatedStyle(self, event, style):
+ super(CircleROI, self)._updatedStyle(event, style)
+ self.__shape.setColor(style.getColor())
+ self.__shape.setLineStyle(style.getLineStyle())
+ self.__shape.setLineWidth(style.getLineWidth())
+ self.__shape.setLineGapColor(style.getLineGapColor())
+
+ def setFirstShapePoints(self, points):
+ assert len(points) == 2
+ self._setRay(points)
+
+ def _setRay(self, points):
+ """Initialize the circle from the center point and a
+ perimeter point."""
+ center = points[0]
+ radius = numpy.linalg.norm(points[0] - points[1])
+ self.setGeometry(center=center, radius=radius)
+
+ def _updateText(self, text):
+ self._handleLabel.setText(text)
+
+ def getCenter(self):
+ """Returns the central point of this rectangle
+
+ :rtype: numpy.ndarray([float,float])
+ """
+ pos = self._handleCenter.getPosition()
+ return numpy.array(pos)
+
+ def getRadius(self):
+ """Returns the radius of this circle
+
+ :rtype: float
+ """
+ return self.__radius
+
+ def setCenter(self, position):
+ """Set the center point of this ROI
+
+ :param numpy.ndarray position: Location of the center of the circle
+ """
+ self._handleCenter.setPosition(*position)
+
+ def setRadius(self, radius):
+ """Set the size of this ROI
+
+ :param float size: Radius of the circle
+ """
+ radius = float(radius)
+ if radius != self.__radius:
+ self.__radius = radius
+ self._updateGeometry()
+
+ def setGeometry(self, center, radius):
+ """Set the geometry of the ROI"""
+ if numpy.array_equal(center, self.getCenter()):
+ self.setRadius(radius)
+ else:
+ self.__radius = float(radius) # Update radius directly
+ self.setCenter(center) # Calls _updateGeometry
+
+ def _updateGeometry(self):
+ """Update the handles and shape according to given parameters"""
+ center = self.getCenter()
+ perimeter_point = numpy.array([center[0] + self.__radius, center[1]])
+
+ self._handlePerimeter.setPosition(perimeter_point[0], perimeter_point[1])
+ self._handleLabel.setPosition(center[0], center[1])
+
+ nbpoints = 27
+ angles = numpy.arange(nbpoints) * 2.0 * numpy.pi / nbpoints
+ circleShape = numpy.array(
+ (numpy.cos(angles) * self.__radius, numpy.sin(angles) * self.__radius)
+ ).T
+ circleShape += center
+ self.__shape.setPoints(circleShape)
+ self.sigRegionChanged.emit()
+
+ def _centerPositionChanged(self, event):
+ """Handle position changed events of the center marker"""
+ if event is items.ItemChangedType.POSITION:
+ self._updateGeometry()
+
+ def handleDragUpdated(self, handle, origin, previous, current):
+ if handle is self._handlePerimeter:
+ center = self.getCenter()
+ self.setRadius(numpy.linalg.norm(center - current))
+
+ @docstring(HandleBasedROI)
+ def contains(self, position):
+ return numpy.linalg.norm(self.getCenter() - position) <= self.getRadius()
+
+ def __str__(self):
+ center = self.getCenter()
+ radius = self.getRadius()
+ params = center[0], center[1], radius
+ params = "center: %f %f; radius: %f;" % params
+ return "%s(%s)" % (self.__class__.__name__, params)
+
+
+class EllipseROI(HandleBasedROI, items.LineMixIn):
+ """A ROI identifying an oriented ellipse in a 2D plot.
+
+ This ROI provides 1 anchor at the center to translate the circle,
+ and two anchors on the perimeter to modify the major-radius and
+ minor-radius. These two anchors also allow to change the orientation.
+ """
+
+ ICON = "add-shape-ellipse"
+ NAME = "ellipse ROI"
+ SHORT_NAME = "ellipse"
+ """Metadata for this kind of ROI"""
+
+ _plotShape = "line"
+ """Plot shape which is used for the first interaction"""
+
+ def __init__(self, parent=None):
+ items.LineMixIn.__init__(self)
+ HandleBasedROI.__init__(self, parent=parent)
+ self._handleAxis0 = self.addHandle()
+ self._handleAxis1 = self.addHandle()
+ self._handleCenter = self.addTranslateHandle()
+ self._handleCenter.sigItemChanged.connect(self._centerPositionChanged)
+ self._handleLabel = self.addLabelHandle()
+
+ shape = items.Shape("polygon")
+ shape.setPoints([[0, 0], [0, 0]])
+ shape.setColor(rgba(self.getColor()))
+ shape.setFill(False)
+ shape.setOverlay(True)
+ shape.setLineStyle(self.getLineStyle())
+ shape.setLineWidth(self.getLineWidth())
+ self.__shape = shape
+ self.addItem(shape)
+
+ self._radius = 0.0, 0.0
+ self._orientation = (
+ 0.0 # angle in radians between the X-axis and the _handleAxis0
+ )
+
+ def _updated(self, event=None, checkVisibility=True):
+ if event == items.ItemChangedType.VISIBLE:
+ self._updateItemProperty(event, self, self.__shape)
+ super(EllipseROI, self)._updated(event, checkVisibility)
+
+ def _updatedStyle(self, event, style):
+ super(EllipseROI, self)._updatedStyle(event, style)
+ self.__shape.setColor(style.getColor())
+ self.__shape.setLineStyle(style.getLineStyle())
+ self.__shape.setLineWidth(style.getLineWidth())
+ self.__shape.setLineGapColor(style.getLineGapColor())
+
+ def setFirstShapePoints(self, points):
+ assert len(points) == 2
+ self._setRay(points)
+
+ @staticmethod
+ def _calculateOrientation(p0, p1):
+ """return angle in radians between the vector p0-p1
+ and the X axis
+
+ :param p0: first point coordinates (x, y)
+ :param p1: second point coordinates
+ :return:
+ """
+ vector = (p1[0] - p0[0], p1[1] - p0[1])
+ x_unit_vector = (1, 0)
+ norm = numpy.linalg.norm(vector)
+ if norm != 0:
+ theta = numpy.arccos(numpy.dot(vector, x_unit_vector) / norm)
+ else:
+ theta = 0
+ if vector[1] < 0:
+ # arccos always returns values in range [0, pi]
+ theta = 2 * numpy.pi - theta
+ return theta
+
+ def _setRay(self, points):
+ """Initialize the circle from the center point and a
+ perimeter point."""
+ center = points[0]
+ radius = numpy.linalg.norm(points[0] - points[1])
+ orientation = self._calculateOrientation(points[0], points[1])
+ self.setGeometry(
+ center=center, radius=(radius, radius), orientation=orientation
+ )
+
+ def _updateText(self, text):
+ self._handleLabel.setText(text)
+
+ def getCenter(self):
+ """Returns the central point of this rectangle
+
+ :rtype: numpy.ndarray([float,float])
+ """
+ pos = self._handleCenter.getPosition()
+ return numpy.array(pos)
+
+ def getMajorRadius(self):
+ """Returns the half-diameter of the major axis.
+
+ :rtype: float
+ """
+ return max(self._radius)
+
+ def getMinorRadius(self):
+ """Returns the half-diameter of the minor axis.
+
+ :rtype: float
+ """
+ return min(self._radius)
+
+ def getOrientation(self):
+ """Return angle in radians between the horizontal (X) axis
+ and the major axis of the ellipse in [0, 2*pi[
+
+ :rtype: float:
+ """
+ return self._orientation
+
+ def setCenter(self, center):
+ """Set the center point of this ROI
+
+ :param numpy.ndarray position: Coordinates (X, Y) of the center
+ of the ellipse
+ """
+ self._handleCenter.setPosition(*center)
+
+ def setMajorRadius(self, radius):
+ """Set the half-diameter of the major axis of the ellipse.
+
+ :param float radius:
+ Major radius of the ellipsis. Must be a positive value.
+ """
+ if self._radius[0] > self._radius[1]:
+ newRadius = radius, self._radius[1]
+ else:
+ newRadius = self._radius[0], radius
+ self.setGeometry(radius=newRadius)
+
+ def setMinorRadius(self, radius):
+ """Set the half-diameter of the minor axis of the ellipse.
+
+ :param float radius:
+ Minor radius of the ellipsis. Must be a positive value.
+ """
+ if self._radius[0] > self._radius[1]:
+ newRadius = self._radius[0], radius
+ else:
+ newRadius = radius, self._radius[1]
+ self.setGeometry(radius=newRadius)
+
+ def setOrientation(self, orientation):
+ """Rotate the ellipse
+
+ :param float orientation: Angle in radians between the horizontal and
+ the major axis.
+ :return:
+ """
+ self.setGeometry(orientation=orientation)
+
+ def setGeometry(self, center=None, radius=None, orientation=None):
+ """
+
+ :param center: (X, Y) coordinates
+ :param float majorRadius:
+ :param float minorRadius:
+ :param float orientation: angle in radians between the major axis and the
+ horizontal
+ :return:
+ """
+ if center is None:
+ center = self.getCenter()
+
+ if radius is None:
+ radius = self._radius
+ else:
+ radius = float(radius[0]), float(radius[1])
+
+ if orientation is None:
+ orientation = self._orientation
+ else:
+ # ensure that we store the orientation in range [0, 2*pi
+ orientation = numpy.mod(orientation, 2 * numpy.pi)
+
+ if (
+ numpy.array_equal(center, self.getCenter())
+ or radius != self._radius
+ or orientation != self._orientation
+ ):
+ # Update parameters directly
+ self._radius = radius
+ self._orientation = orientation
+
+ if numpy.array_equal(center, self.getCenter()):
+ self._updateGeometry()
+ else:
+ # This will call _updateGeometry
+ self.setCenter(center)
+
+ def _updateGeometry(self):
+ """Update shape and markers"""
+ center = self.getCenter()
+
+ orientation = self.getOrientation()
+ if self._radius[1] > self._radius[0]:
+ # _handleAxis1 is the major axis
+ orientation -= numpy.pi / 2
+
+ point0 = numpy.array(
+ [
+ center[0] + self._radius[0] * numpy.cos(orientation),
+ center[1] + self._radius[0] * numpy.sin(orientation),
+ ]
+ )
+ point1 = numpy.array(
+ [
+ center[0] - self._radius[1] * numpy.sin(orientation),
+ center[1] + self._radius[1] * numpy.cos(orientation),
+ ]
+ )
+ with utils.blockSignals(self._handleAxis0):
+ self._handleAxis0.setPosition(*point0)
+ with utils.blockSignals(self._handleAxis1):
+ self._handleAxis1.setPosition(*point1)
+ with utils.blockSignals(self._handleLabel):
+ self._handleLabel.setPosition(*center)
+
+ nbpoints = 27
+ angles = numpy.arange(nbpoints) * 2.0 * numpy.pi / nbpoints
+ X = self._radius[0] * numpy.cos(angles) * numpy.cos(orientation) - self._radius[
+ 1
+ ] * numpy.sin(angles) * numpy.sin(orientation)
+ Y = self._radius[0] * numpy.cos(angles) * numpy.sin(orientation) + self._radius[
+ 1
+ ] * numpy.sin(angles) * numpy.cos(orientation)
+
+ ellipseShape = numpy.array((X, Y)).T
+ ellipseShape += center
+ self.__shape.setPoints(ellipseShape)
+ self.sigRegionChanged.emit()
+
+ def handleDragUpdated(self, handle, origin, previous, current):
+ if handle in (self._handleAxis0, self._handleAxis1):
+ center = self.getCenter()
+ orientation = self._calculateOrientation(center, current)
+ distance = numpy.linalg.norm(center - current)
+
+ if handle is self._handleAxis1:
+ if self._radius[0] > distance:
+ # _handleAxis1 is not the major axis, rotate -90 degrees
+ orientation -= numpy.pi / 2
+ radius = self._radius[0], distance
+
+ else: # _handleAxis0
+ if self._radius[1] > distance:
+ # _handleAxis0 is not the major axis, rotate +90 degrees
+ orientation += numpy.pi / 2
+ radius = distance, self._radius[1]
+
+ self.setGeometry(radius=radius, orientation=orientation)
+
+ def _centerPositionChanged(self, event):
+ """Handle position changed events of the center marker"""
+ if event is items.ItemChangedType.POSITION:
+ self._updateGeometry()
+
+ @docstring(HandleBasedROI)
+ def contains(self, position):
+ major, minor = self.getMajorRadius(), self.getMinorRadius()
+ delta = self.getOrientation()
+ x, y = position - self.getCenter()
+ return (
+ (x * numpy.cos(delta) + y * numpy.sin(delta)) ** 2 / major**2
+ + (x * numpy.sin(delta) - y * numpy.cos(delta)) ** 2 / minor**2
+ ) <= 1
+
+ def __str__(self):
+ center = self.getCenter()
+ major = self.getMajorRadius()
+ minor = self.getMinorRadius()
+ orientation = self.getOrientation()
+ params = center[0], center[1], major, minor, orientation
+ params = (
+ "center: %f %f; major radius: %f: minor radius: %f; orientation: %f"
+ % params
+ )
+ return "%s(%s)" % (self.__class__.__name__, params)
+
+
+class PolygonROI(HandleBasedROI, items.LineMixIn):
+ """A ROI identifying a closed polygon in a 2D plot.
+
+ This ROI provides 1 anchor for each point of the polygon.
+ """
+
+ ICON = "add-shape-polygon"
+ NAME = "polygon ROI"
+ SHORT_NAME = "polygon"
+ """Metadata for this kind of ROI"""
+
+ _plotShape = "polygon"
+ """Plot shape which is used for the first interaction"""
+
+ def __init__(self, parent=None):
+ HandleBasedROI.__init__(self, parent=parent)
+ items.LineMixIn.__init__(self)
+ self._handleLabel = self.addLabelHandle()
+ self._handleCenter = self.addTranslateHandle()
+ self._handlePoints = []
+ self._points = numpy.empty((0, 2))
+ self._handleClose = None
+
+ self._polygon_shape = None
+ shape = self.__createShape()
+ self.__shape = shape
+ self.addItem(shape)
+
+ def _updated(self, event=None, checkVisibility=True):
+ if event in [items.ItemChangedType.VISIBLE]:
+ self._updateItemProperty(event, self, self.__shape)
+ super(PolygonROI, self)._updated(event, checkVisibility)
+
+ def _updatedStyle(self, event, style):
+ super(PolygonROI, self)._updatedStyle(event, style)
+ self.__shape.setColor(style.getColor())
+ self.__shape.setLineStyle(style.getLineStyle())
+ self.__shape.setLineWidth(style.getLineWidth())
+ self.__shape.setLineGapColor(style.getLineGapColor())
+ if self._handleClose is not None:
+ color = self._computeHandleColor(style.getColor())
+ self._handleClose.setColor(color)
+
+ def __createShape(self, interaction=False):
+ kind = "polygon" if not interaction else "polylines"
+ shape = items.Shape(kind)
+ shape.setPoints([[0, 0], [0, 0]])
+ shape.setFill(False)
+ shape.setOverlay(True)
+ style = self.getCurrentStyle()
+ shape.setLineStyle(style.getLineStyle())
+ shape.setLineWidth(style.getLineWidth())
+ shape.setColor(rgba(style.getColor()))
+ return shape
+
+ def setFirstShapePoints(self, points):
+ if self._handleClose is not None:
+ self._handleClose.setPosition(*points[0])
+ self.setPoints(points)
+
+ def creationStarted(self):
+ """Called when the ROI creation interaction was started."""
+ # Handle to see where to close the polygon
+ self._handleClose = self.addUserHandle()
+ self._handleClose.setSymbol("o")
+ color = self._computeHandleColor(rgba(self.getColor()))
+ self._handleClose.setColor(color)
+
+ # Hide the center while creating the first shape
+ self._handleCenter.setSymbol("")
+
+ # In interaction replace the polygon by a line, to display something unclosed
+ self.removeItem(self.__shape)
+ self.__shape = self.__createShape(interaction=True)
+ self.__shape.setPoints(self._points)
+ self.addItem(self.__shape)
+
+ def isBeingCreated(self):
+ """Returns true if the ROI is in creation step"""
+ return self._handleClose is not None
+
+ def creationFinalized(self):
+ """Called when the ROI creation interaction was finalized."""
+ self.removeHandle(self._handleClose)
+ self._handleClose = None
+ self.removeItem(self.__shape)
+ self.__shape = self.__createShape()
+ self.__shape.setPoints(self._points)
+ self.addItem(self.__shape)
+ # Hide the center while creating the first shape
+ self._handleCenter.setSymbol("+")
+ for handle in self._handlePoints:
+ handle.setSymbol("s")
+
+ def _updateText(self, text):
+ self._handleLabel.setText(text)
+
+ def getPoints(self):
+ """Returns the list of the points of this polygon.
+
+ :rtype: numpy.ndarray
+ """
+ return self._points.copy()
+
+ def setPoints(self, points):
+ """Set the position of this ROI
+
+ :param numpy.ndarray pos: 2d-coordinate of this point
+ """
+ assert len(points.shape) == 2 and points.shape[1] == 2
+
+ if numpy.array_equal(points, self._points):
+ return # Nothing has changed
+
+ self._polygon_shape = None
+
+ # Update the needed handles
+ while len(self._handlePoints) != len(points):
+ if len(self._handlePoints) < len(points):
+ handle = self.addHandle()
+ self._handlePoints.append(handle)
+ if self.isBeingCreated():
+ handle.setSymbol("")
+ else:
+ handle = self._handlePoints.pop(-1)
+ self.removeHandle(handle)
+
+ for handle, position in zip(self._handlePoints, points):
+ with utils.blockSignals(handle):
+ handle.setPosition(position[0], position[1])
+
+ if len(points) > 0:
+ if not self.isHandleBeingDragged():
+ vmin = numpy.min(points, axis=0)
+ vmax = numpy.max(points, axis=0)
+ center = (vmax + vmin) * 0.5
+ with utils.blockSignals(self._handleCenter):
+ self._handleCenter.setPosition(center[0], center[1])
+
+ num = numpy.argmin(points[:, 1])
+ pos = points[num]
+ with utils.blockSignals(self._handleLabel):
+ self._handleLabel.setPosition(pos[0], pos[1])
+
+ if len(points) == 0:
+ self._points = numpy.empty((0, 2))
+ else:
+ self._points = points
+ self.__shape.setPoints(self._points)
+ self.sigRegionChanged.emit()
+
+ def translate(self, x, y):
+ points = self.getPoints()
+ delta = numpy.array([x, y])
+ self.setPoints(points)
+ self.setPoints(points + delta)
+
+ def handleDragUpdated(self, handle, origin, previous, current):
+ if handle is self._handleCenter:
+ delta = current - previous
+ self.translate(delta[0], delta[1])
+ else:
+ points = self.getPoints()
+ num = self._handlePoints.index(handle)
+ points[num] = current
+ self.setPoints(points)
+
+ def handleDragFinished(self, handle, origin, current):
+ points = self._points
+ if len(points) > 0:
+ # Only update the center at the end
+ # To avoid to disturb the interaction
+ vmin = numpy.min(points, axis=0)
+ vmax = numpy.max(points, axis=0)
+ center = (vmax + vmin) * 0.5
+ with utils.blockSignals(self._handleCenter):
+ self._handleCenter.setPosition(center[0], center[1])
+
+ def __str__(self):
+ points = self._points
+ params = "; ".join("%f %f" % (pt[0], pt[1]) for pt in points)
+ return "%s(%s)" % (self.__class__.__name__, params)
+
+ @docstring(HandleBasedROI)
+ def contains(self, position):
+ bb1 = _BoundingBox.from_points(self.getPoints())
+ if bb1.contains(position) is False:
+ return False
+
+ if self._polygon_shape is None:
+ self._polygon_shape = Polygon(vertices=self.getPoints())
+
+ # warning: both the polygon and the value are inverted
+ return self._polygon_shape.is_inside(row=position[0], col=position[1])
+
+ def _setControlPoints(self, points):
+ RegionOfInterest._setControlPoints(self, points=points)
+ self._polygon_shape = None
+
+
+class HorizontalRangeROI(RegionOfInterest, items.LineMixIn):
+ """A ROI identifying an horizontal range in a 1D plot."""
+
+ ICON = "add-range-horizontal"
+ NAME = "horizontal range ROI"
+ SHORT_NAME = "hrange"
+
+ _plotShape = "line"
+ """Plot shape which is used for the first interaction"""
+
+ def __init__(self, parent=None):
+ RegionOfInterest.__init__(self, parent=parent)
+ items.LineMixIn.__init__(self)
+ self._markerMin = items.XMarker()
+ self._markerMax = items.XMarker()
+ self._markerCen = items.XMarker()
+ self._markerCen.setLineStyle(" ")
+ self._markerMin._setConstraint(self.__positionMinConstraint)
+ self._markerMax._setConstraint(self.__positionMaxConstraint)
+ self._markerMin.sigDragStarted.connect(self._editingStarted)
+ self._markerMin.sigDragFinished.connect(self._editingFinished)
+ self._markerMax.sigDragStarted.connect(self._editingStarted)
+ self._markerMax.sigDragFinished.connect(self._editingFinished)
+ self._markerCen.sigDragStarted.connect(self._editingStarted)
+ self._markerCen.sigDragFinished.connect(self._editingFinished)
+ self.addItem(self._markerCen)
+ self.addItem(self._markerMin)
+ self.addItem(self._markerMax)
+ self.__filterReentrant = utils.LockReentrant()
+
+ def setFirstShapePoints(self, points):
+ vmin = min(points[:, 0])
+ vmax = max(points[:, 0])
+ self._updatePos(vmin, vmax)
+
+ def _updated(self, event=None, checkVisibility=True):
+ if event == items.ItemChangedType.EDITABLE:
+ self._updateEditable()
+ self._updateText(self.getText())
+ elif event == items.ItemChangedType.LINE_STYLE:
+ markers = [self._markerMin, self._markerMax]
+ self._updateItemProperty(event, self, markers)
+ elif event in [items.ItemChangedType.VISIBLE, items.ItemChangedType.SELECTABLE]:
+ markers = [self._markerMin, self._markerMax, self._markerCen]
+ self._updateItemProperty(event, self, markers)
+ super(HorizontalRangeROI, self)._updated(event, checkVisibility)
+
+ def _updatedStyle(self, event, style):
+ markers = [self._markerMin, self._markerMax, self._markerCen]
+ for m in markers:
+ m.setColor(style.getColor())
+ m.setLineWidth(style.getLineWidth())
+
+ def _updateText(self, text: str):
+ if self.isEditable():
+ self._markerMin.setText("")
+ self._markerCen.setText(text)
+ else:
+ self._markerMin.setText(text)
+ self._markerCen.setText("")
+
+ def _updateEditable(self):
+ editable = self.isEditable()
+ self._markerMin._setDraggable(editable)
+ self._markerMax._setDraggable(editable)
+ self._markerCen._setDraggable(editable)
+ if self.isEditable():
+ self._markerMin.sigItemChanged.connect(self._minPositionChanged)
+ self._markerMax.sigItemChanged.connect(self._maxPositionChanged)
+ self._markerCen.sigItemChanged.connect(self._cenPositionChanged)
+ self._markerCen.setLineStyle(":")
+ else:
+ self._markerMin.sigItemChanged.disconnect(self._minPositionChanged)
+ self._markerMax.sigItemChanged.disconnect(self._maxPositionChanged)
+ self._markerCen.sigItemChanged.disconnect(self._cenPositionChanged)
+ self._markerCen.setLineStyle(" ")
+
+ def _updatePos(self, vmin, vmax, force=False):
+ """Update marker position and emit signal.
+
+ :param float vmin:
+ :param float vmax:
+ :param bool force:
+ True to update even if already at the right position.
+ """
+ if not force and numpy.array_equal((vmin, vmax), self.getRange()):
+ return # Nothing has changed
+
+ center = (vmin + vmax) * 0.5
+ with self.__filterReentrant:
+ with utils.blockSignals(self._markerMin):
+ self._markerMin.setPosition(vmin, 0)
+ with utils.blockSignals(self._markerCen):
+ self._markerCen.setPosition(center, 0)
+ with utils.blockSignals(self._markerMax):
+ self._markerMax.setPosition(vmax, 0)
+ self.sigRegionChanged.emit()
+
+ def setRange(self, vmin, vmax):
+ """Set the range of this ROI.
+
+ :param float vmin: Staring location of the range
+ :param float vmax: Ending location of the range
+ """
+ if vmin is None or vmax is None:
+ err = "Can't set vmin or vmax to None"
+ raise ValueError(err)
+ if vmin > vmax:
+ err = (
+ "Can't set vmin and vmax because vmin >= vmax "
+ "vmin = %s, vmax = %s" % (vmin, vmax)
+ )
+ raise ValueError(err)
+ self._updatePos(vmin, vmax)
+
+ def getRange(self):
+ """Returns the range of this ROI.
+
+ :rtype: Tuple[float,float]
+ """
+ vmin = self.getMin()
+ vmax = self.getMax()
+ return vmin, vmax
+
+ def setMin(self, vmin):
+ """Set the min of this ROI.
+
+ :param float vmin: New min
+ """
+ vmax = self.getMax()
+ self._updatePos(vmin, vmax)
+
+ def getMin(self):
+ """Returns the min value of this ROI.
+
+ :rtype: float
+ """
+ return self._markerMin.getPosition()[0]
+
+ def setMax(self, vmax):
+ """Set the max of this ROI.
+
+ :param float vmax: New max
+ """
+ vmin = self.getMin()
+ self._updatePos(vmin, vmax)
+
+ def getMax(self):
+ """Returns the max value of this ROI.
+
+ :rtype: float
+ """
+ return self._markerMax.getPosition()[0]
+
+ def setCenter(self, center):
+ """Set the center of this ROI.
+
+ :param float center: New center
+ """
+ vmin, vmax = self.getRange()
+ previousCenter = (vmin + vmax) * 0.5
+ delta = center - previousCenter
+ self._updatePos(vmin + delta, vmax + delta)
+
+ def getCenter(self):
+ """Returns the center location of this ROI.
+
+ :rtype: float
+ """
+ vmin, vmax = self.getRange()
+ return (vmin + vmax) * 0.5
+
+ def __positionMinConstraint(self, x, y):
+ """Constraint of the min marker"""
+ if self.__filterReentrant.locked():
+ # Ignore the constraint when we set an explicit value
+ return x, y
+ vmax = self.getMax()
+ if vmax is None:
+ return x, y
+ return min(x, vmax), y
+
+ def __positionMaxConstraint(self, x, y):
+ """Constraint of the max marker"""
+ if self.__filterReentrant.locked():
+ # Ignore the constraint when we set an explicit value
+ return x, y
+ vmin = self.getMin()
+ if vmin is None:
+ return x, y
+ return max(x, vmin), y
+
+ def _minPositionChanged(self, event):
+ """Handle position changed events of the marker"""
+ if event is items.ItemChangedType.POSITION:
+ marker = self.sender()
+ self._updatePos(marker.getXPosition(), self.getMax(), force=True)
+
+ def _maxPositionChanged(self, event):
+ """Handle position changed events of the marker"""
+ if event is items.ItemChangedType.POSITION:
+ marker = self.sender()
+ self._updatePos(self.getMin(), marker.getXPosition(), force=True)
+
+ def _cenPositionChanged(self, event):
+ """Handle position changed events of the marker"""
+ if event is items.ItemChangedType.POSITION:
+ marker = self.sender()
+ self.setCenter(marker.getXPosition())
+
+ @docstring(HandleBasedROI)
+ def contains(self, position):
+ return self.getMin() <= position[0] <= self.getMax()
+
+ def __str__(self):
+ vrange = self.getRange()
+ params = "min: %f; max: %f" % vrange
+ return "%s(%s)" % (self.__class__.__name__, params)
diff --git a/src/silx/gui/plot/items/scatter.py b/src/silx/gui/plot/items/scatter.py
new file mode 100644
index 0000000..c46b60c
--- /dev/null
+++ b/src/silx/gui/plot/items/scatter.py
@@ -0,0 +1,1060 @@
+# /*##########################################################################
+#
+# Copyright (c) 2017-2023 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""This module provides the :class:`Scatter` item of the :class:`Plot`.
+"""
+
+__authors__ = ["T. Vincent", "P. Knobel"]
+__license__ = "MIT"
+__date__ = "29/03/2017"
+
+
+from collections import namedtuple
+import logging
+import threading
+import numpy
+from matplotlib.tri import LinearTriInterpolator, Triangulation
+
+from collections import defaultdict
+from concurrent.futures import ThreadPoolExecutor, CancelledError
+
+from ....utils.proxy import docstring
+from ....math.combo import min_max
+from ....math.histogram import Histogramnd
+from ....utils.weakref import WeakList
+from .core import PointsBase, ColormapMixIn, ScatterVisualizationMixIn
+from .axis import Axis
+from ._pick import PickingResult
+
+
+_logger = logging.getLogger(__name__)
+
+
+class _GreedyThreadPoolExecutor(ThreadPoolExecutor):
+ """:class:`ThreadPoolExecutor` with an extra :meth:`submit_greedy` method."""
+
+ def __init__(self, *args, **kwargs):
+ super(_GreedyThreadPoolExecutor, self).__init__(*args, **kwargs)
+ self.__futures = defaultdict(WeakList)
+ self.__lock = threading.RLock()
+
+ def submit_greedy(self, queue, fn, *args, **kwargs):
+ """Same as :meth:`submit` but cancel previous tasks in given queue.
+
+ This means that when a new task is submitted for a given queue,
+ all other pending tasks of that queue are cancelled.
+
+ :param queue: Identifier of the queue. This must be hashable.
+ :param callable fn: The callable to call with provided extra arguments
+ :return: Future corresponding to this task
+ :rtype: concurrent.futures.Future
+ """
+ with self.__lock:
+ # Cancel previous tasks in given queue
+ for future in self.__futures.pop(queue, []):
+ if not future.done():
+ future.cancel()
+
+ future = super(_GreedyThreadPoolExecutor, self).submit(fn, *args, **kwargs)
+ self.__futures[queue].append(future)
+
+ return future
+
+
+# Functions to guess grid shape from coordinates
+
+
+def _get_z_line_length(array):
+ """Return length of line if array is a Z-like 2D regular grid.
+
+ :param numpy.ndarray array: The 1D array of coordinates to check
+ :return: 0 if no line length could be found,
+ else the number of element per line.
+ :rtype: int
+ """
+ sign = numpy.sign(numpy.diff(array))
+ if len(sign) == 0 or sign[0] == 0: # We don't handle that
+ return 0
+ # Check this way to account for 0 sign (i.e., diff == 0)
+ beginnings = numpy.where(sign == -sign[0])[0] + 1
+ if len(beginnings) == 0:
+ return 0
+ length = beginnings[0]
+ if numpy.all(numpy.equal(numpy.diff(beginnings), length)):
+ return length
+ return 0
+
+
+def _guess_z_grid_shape(x, y):
+ """Guess the shape of a grid from (x, y) coordinates.
+
+ The grid might contain more elements than x and y,
+ as the last line might be partly filled.
+
+ :param numpy.ndarray x:
+ :paran numpy.ndarray y:
+ :returns: (order, (height, width)) of the regular grid,
+ or None if could not guess one.
+ 'order' is 'row' if X (i.e., column) is the fast dimension, else 'column'.
+ :rtype: Union[List(str,int),None]
+ """
+ width = _get_z_line_length(x)
+ if width != 0:
+ return "row", (int(numpy.ceil(len(x) / width)), width)
+ else:
+ height = _get_z_line_length(y)
+ if height != 0:
+ return "column", (height, int(numpy.ceil(len(y) / height)))
+ return None
+
+
+def is_monotonic(array):
+ """Returns whether array is monotonic (increasing or decreasing).
+
+ :param numpy.ndarray array: 1D array-like container.
+ :returns: 1 if array is monotonically increasing,
+ -1 if array is monotonically decreasing,
+ 0 if array is not monotonic
+ :rtype: int
+ """
+ diff = numpy.diff(numpy.ravel(array))
+ with numpy.errstate(invalid="ignore"):
+ if numpy.all(diff >= 0):
+ return 1
+ elif numpy.all(diff <= 0):
+ return -1
+ else:
+ return 0
+
+
+def _guess_grid(x, y):
+ """Guess a regular grid from the points.
+
+ Result convention is (x, y)
+
+ :param numpy.ndarray x: X coordinates of the points
+ :param numpy.ndarray y: Y coordinates of the points
+ :returns: (order, (height, width)
+ order is 'row' or 'column'
+ :rtype: Union[List[str,List[int]],None]
+ """
+ x, y = numpy.ravel(x), numpy.ravel(y)
+
+ guess = _guess_z_grid_shape(x, y)
+ if guess is not None:
+ return guess
+
+ else:
+ # Cannot guess a regular grid
+ # Let's assume it's a single line
+ order = "row" # or 'column' doesn't matter for a single line
+ y_monotonic = is_monotonic(y)
+ if is_monotonic(x) or y_monotonic: # we can guess a line
+ x_min, x_max = min_max(x)
+ y_min, y_max = min_max(y)
+
+ if not y_monotonic or x_max - x_min >= y_max - y_min:
+ # x only is monotonic or both are and X varies more
+ # line along X
+ shape = 1, len(x)
+ else:
+ # y only is monotonic or both are and Y varies more
+ # line along Y
+ shape = len(y), 1
+
+ else: # Cannot guess a line from the points
+ return None
+
+ return order, shape
+
+
+def _quadrilateral_grid_coords(points):
+ """Compute an irregular grid of quadrilaterals from a set of points
+
+ The input points are expected to lie on a grid.
+
+ :param numpy.ndarray points:
+ 3D data set of 2D input coordinates (height, width, 2)
+ height and width must be at least 2.
+ :return: 3D dataset of 2D coordinates of the grid (height+1, width+1, 2)
+ """
+ assert points.ndim == 3
+ assert points.shape[0] >= 2
+ assert points.shape[1] >= 2
+ assert points.shape[2] == 2
+
+ dim0, dim1 = points.shape[:2]
+ grid_points = numpy.zeros((dim0 + 1, dim1 + 1, 2), dtype=numpy.float64)
+
+ # Compute inner points as mean of 4 neighbours
+ neighbour_view = numpy.lib.stride_tricks.as_strided(
+ points,
+ shape=(dim0 - 1, dim1 - 1, 2, 2, points.shape[2]),
+ strides=points.strides[:2] + points.strides[:2] + points.strides[-1:],
+ writeable=False,
+ )
+ inner_points = numpy.mean(neighbour_view, axis=(2, 3))
+ grid_points[1:-1, 1:-1] = inner_points
+
+ # Compute 'vertical' sides
+ # Alternative: grid_points[1:-1, [0, -1]] = points[:-1, [0, -1]] + points[1:, [0, -1]] - inner_points[:, [0, -1]]
+ grid_points[1:-1, [0, -1], 0] = (
+ points[:-1, [0, -1], 0] + points[1:, [0, -1], 0] - inner_points[:, [0, -1], 0]
+ )
+ grid_points[1:-1, [0, -1], 1] = inner_points[:, [0, -1], 1]
+
+ # Compute 'horizontal' sides
+ grid_points[[0, -1], 1:-1, 0] = inner_points[[0, -1], :, 0]
+ grid_points[[0, -1], 1:-1, 1] = (
+ points[[0, -1], :-1, 1] + points[[0, -1], 1:, 1] - inner_points[[0, -1], :, 1]
+ )
+
+ # Compute corners
+ d0, d1 = [0, 0, -1, -1], [0, -1, -1, 0]
+ grid_points[d0, d1] = 2 * points[d0, d1] - inner_points[d0, d1]
+ return grid_points
+
+
+def _quadrilateral_grid_as_triangles(points):
+ """Returns the points and indices to make a grid of quadirlaterals
+
+ :param numpy.ndarray points:
+ 3D array of points (height, width, 2)
+ :return: triangle corners (4 * N, 2), triangle indices (2 * N, 3)
+ With N = height * width, the number of input points
+ """
+ nbpoints = numpy.prod(points.shape[:2])
+
+ grid = _quadrilateral_grid_coords(points)
+ coords = numpy.empty((4 * nbpoints, 2), dtype=grid.dtype)
+ coords[::4] = grid[:-1, :-1].reshape(-1, 2)
+ coords[1::4] = grid[1:, :-1].reshape(-1, 2)
+ coords[2::4] = grid[:-1, 1:].reshape(-1, 2)
+ coords[3::4] = grid[1:, 1:].reshape(-1, 2)
+
+ indices = numpy.empty((2 * nbpoints, 3), dtype=numpy.uint32)
+ indices[::2, 0] = numpy.arange(0, 4 * nbpoints, 4)
+ indices[::2, 1] = numpy.arange(1, 4 * nbpoints, 4)
+ indices[::2, 2] = numpy.arange(2, 4 * nbpoints, 4)
+ indices[1::2, 0] = indices[::2, 1]
+ indices[1::2, 1] = indices[::2, 2]
+ indices[1::2, 2] = numpy.arange(3, 4 * nbpoints, 4)
+
+ return coords, indices
+
+
+_RegularGridInfo = namedtuple(
+ "_RegularGridInfo", ["bounds", "origin", "scale", "shape", "order"]
+)
+
+
+_HistogramInfo = namedtuple(
+ "_HistogramInfo", ["mean", "count", "sum", "origin", "scale", "shape"]
+)
+
+
+class Scatter(PointsBase, ColormapMixIn, ScatterVisualizationMixIn):
+ """Description of a scatter"""
+
+ _DEFAULT_SELECTABLE = True
+ """Default selectable state for scatter plots"""
+
+ _SUPPORTED_SCATTER_VISUALIZATION = (
+ ScatterVisualizationMixIn.Visualization.POINTS,
+ ScatterVisualizationMixIn.Visualization.SOLID,
+ ScatterVisualizationMixIn.Visualization.REGULAR_GRID,
+ ScatterVisualizationMixIn.Visualization.IRREGULAR_GRID,
+ ScatterVisualizationMixIn.Visualization.BINNED_STATISTIC,
+ )
+ """Overrides supported Visualizations"""
+
+ def __init__(self):
+ PointsBase.__init__(self)
+ ColormapMixIn.__init__(self)
+ ScatterVisualizationMixIn.__init__(self)
+ self._value = ()
+ self.__alpha = None
+ # Cache Delaunay triangulation future object
+ self.__triangulationFuture = None
+ # Cache interpolator future object
+ self.__interpolatorFuture = None
+ self.__executor = None
+
+ # Cache triangles: x, y, indices
+ self.__cacheTriangles = None, None, None
+
+ # Cache regular grid and histogram info
+ self.__cacheRegularGridInfo = None
+ self.__cacheHistogramInfo = None
+
+ def _updateColormappedData(self):
+ """Update the colormapped data, to be called when changed"""
+ if self.getVisualization() is self.Visualization.BINNED_STATISTIC:
+ histoInfo = self.__getHistogramInfo()
+ if histoInfo is None:
+ data = None
+ else:
+ data = getattr(
+ histoInfo,
+ self.getVisualizationParameter(
+ self.VisualizationParameter.BINNED_STATISTIC_FUNCTION
+ ),
+ )
+ else:
+ data = self.getValueData(copy=False)
+ self._setColormappedData(data, copy=False)
+
+ @docstring(ScatterVisualizationMixIn)
+ def setVisualization(self, mode):
+ previous = self.getVisualization()
+ if super().setVisualization(mode):
+ if bool(mode is self.Visualization.BINNED_STATISTIC) ^ bool(
+ previous is self.Visualization.BINNED_STATISTIC
+ ):
+ self._updateColormappedData()
+ return True
+ else:
+ return False
+
+ @docstring(ScatterVisualizationMixIn)
+ def setVisualizationParameter(self, parameter, value):
+ parameter = self.VisualizationParameter.from_value(parameter)
+
+ if super(Scatter, self).setVisualizationParameter(parameter, value):
+ if parameter in (
+ self.VisualizationParameter.GRID_BOUNDS,
+ self.VisualizationParameter.GRID_MAJOR_ORDER,
+ self.VisualizationParameter.GRID_SHAPE,
+ ):
+ self.__cacheRegularGridInfo = None
+
+ if parameter in (
+ self.VisualizationParameter.BINNED_STATISTIC_SHAPE,
+ self.VisualizationParameter.BINNED_STATISTIC_FUNCTION,
+ self.VisualizationParameter.DATA_BOUNDS_HINT,
+ ):
+ if parameter in (
+ self.VisualizationParameter.BINNED_STATISTIC_SHAPE,
+ self.VisualizationParameter.DATA_BOUNDS_HINT,
+ ):
+ self.__cacheHistogramInfo = None # Clean-up cache
+ if self.getVisualization() is self.Visualization.BINNED_STATISTIC:
+ self._updateColormappedData()
+ return True
+ else:
+ return False
+
+ @docstring(ScatterVisualizationMixIn)
+ def getCurrentVisualizationParameter(self, parameter):
+ value = self.getVisualizationParameter(parameter)
+ if (
+ parameter is self.VisualizationParameter.DATA_BOUNDS_HINT
+ or value is not None
+ ):
+ return value # Value has been set, return it
+
+ elif parameter is self.VisualizationParameter.GRID_BOUNDS:
+ grid = self.__getRegularGridInfo()
+ return None if grid is None else grid.bounds
+
+ elif parameter is self.VisualizationParameter.GRID_MAJOR_ORDER:
+ grid = self.__getRegularGridInfo()
+ return None if grid is None else grid.order
+
+ elif parameter is self.VisualizationParameter.GRID_SHAPE:
+ grid = self.__getRegularGridInfo()
+ return None if grid is None else grid.shape
+
+ elif parameter is self.VisualizationParameter.BINNED_STATISTIC_SHAPE:
+ info = self.__getHistogramInfo()
+ return None if info is None else info.shape
+
+ else:
+ raise NotImplementedError()
+
+ def __getRegularGridInfo(self):
+ """Get grid info"""
+ if self.__cacheRegularGridInfo is None:
+ shape = self.getVisualizationParameter(
+ self.VisualizationParameter.GRID_SHAPE
+ )
+ order = self.getVisualizationParameter(
+ self.VisualizationParameter.GRID_MAJOR_ORDER
+ )
+ if shape is None or order is None:
+ guess = _guess_grid(
+ self.getXData(copy=False), self.getYData(copy=False)
+ )
+ if guess is None:
+ _logger.warning(
+ "Cannot guess a grid: Cannot display as regular grid image"
+ )
+ return None
+ if shape is None:
+ shape = guess[1]
+ if order is None:
+ order = guess[0]
+
+ nbpoints = len(self.getXData(copy=False))
+ if nbpoints > shape[0] * shape[1]:
+ # More data points that provided grid shape: enlarge grid
+ _logger.warning(
+ "More data points than provided grid shape size: extends grid"
+ )
+ dim0, dim1 = shape
+ if order == "row": # keep dim1, enlarge dim0
+ dim0 = nbpoints // dim1 + (1 if nbpoints % dim1 else 0)
+ else: # keep dim0, enlarge dim1
+ dim1 = nbpoints // dim0 + (1 if nbpoints % dim0 else 0)
+ shape = dim0, dim1
+
+ bounds = self.getVisualizationParameter(
+ self.VisualizationParameter.GRID_BOUNDS
+ )
+ if bounds is None:
+ x, y = self.getXData(copy=False), self.getYData(copy=False)
+ min_, max_ = min_max(x)
+ xRange = (min_, max_) if (x[0] - min_) < (max_ - x[0]) else (max_, min_)
+ min_, max_ = min_max(y)
+ yRange = (min_, max_) if (y[0] - min_) < (max_ - y[0]) else (max_, min_)
+ bounds = (xRange[0], yRange[0]), (xRange[1], yRange[1])
+
+ begin, end = bounds
+ scale = (
+ (end[0] - begin[0]) / max(1, shape[1] - 1),
+ (end[1] - begin[1]) / max(1, shape[0] - 1),
+ )
+ if scale[0] == 0 and scale[1] == 0:
+ scale = 1.0, 1.0
+ elif scale[0] == 0:
+ scale = scale[1], scale[1]
+ elif scale[1] == 0:
+ scale = scale[0], scale[0]
+
+ origin = begin[0] - 0.5 * scale[0], begin[1] - 0.5 * scale[1]
+
+ self.__cacheRegularGridInfo = _RegularGridInfo(
+ bounds=bounds, origin=origin, scale=scale, shape=shape, order=order
+ )
+
+ return self.__cacheRegularGridInfo
+
+ def __getHistogramInfo(self):
+ """Get histogram info"""
+ if self.__cacheHistogramInfo is None:
+ shape = self.getVisualizationParameter(
+ self.VisualizationParameter.BINNED_STATISTIC_SHAPE
+ )
+ if shape is None:
+ shape = 100, 100 # TODO compute auto shape
+
+ x, y, values = self.getData(copy=False)[:3]
+ if len(x) == 0: # No histogram
+ return None
+
+ if not numpy.issubdtype(x.dtype, numpy.floating):
+ x = x.astype(numpy.float64)
+ if not numpy.issubdtype(y.dtype, numpy.floating):
+ y = y.astype(numpy.float64)
+ if not numpy.issubdtype(values.dtype, numpy.floating):
+ values = values.astype(numpy.float64)
+
+ ranges = (tuple(min_max(y, finite=True)), tuple(min_max(x, finite=True)))
+ rangesHint = self.getVisualizationParameter(
+ self.VisualizationParameter.DATA_BOUNDS_HINT
+ )
+ if rangesHint is not None:
+ ranges = tuple(
+ (min(dataMin, hintMin), max(dataMax, hintMax))
+ for (dataMin, dataMax), (hintMin, hintMax) in zip(
+ ranges, rangesHint
+ )
+ )
+
+ points = numpy.transpose(numpy.array((y, x)))
+ counts, sums, bin_edges = Histogramnd(
+ points, histo_range=ranges, n_bins=shape, weights=values
+ )
+ yEdges, xEdges = bin_edges
+ origin = xEdges[0], yEdges[0]
+ scale = (
+ (xEdges[-1] - xEdges[0]) / (len(xEdges) - 1),
+ (yEdges[-1] - yEdges[0]) / (len(yEdges) - 1),
+ )
+
+ with numpy.errstate(divide="ignore", invalid="ignore"):
+ histo = sums / counts
+
+ self.__cacheHistogramInfo = _HistogramInfo(
+ mean=histo,
+ count=counts,
+ sum=sums,
+ origin=origin,
+ scale=scale,
+ shape=shape,
+ )
+
+ return self.__cacheHistogramInfo
+
+ def __applyColormapToData(self):
+ """Compute colors by applying colormap to values.
+
+ :returns: Array of RGBA colors
+ """
+ cmap = self.getColormap()
+ rgbacolors = cmap.applyToData(self)
+
+ if self.__alpha is not None:
+ rgbacolors[:, -1] = (rgbacolors[:, -1] * self.__alpha).astype(numpy.uint8)
+ return rgbacolors
+
+ def _addBackendRenderer(self, backend):
+ """Update backend renderer"""
+ # Filter-out values <= 0
+ xFiltered, yFiltered, valueFiltered, xerror, yerror = self.getData(
+ copy=False, displayed=True
+ )
+
+ # Remove not finite numbers (this includes filtered out x, y <= 0)
+ mask = numpy.logical_and(numpy.isfinite(xFiltered), numpy.isfinite(yFiltered))
+ xFiltered = xFiltered[mask]
+ yFiltered = yFiltered[mask]
+
+ if len(xFiltered) == 0:
+ return None # No data to display, do not add renderer to backend
+
+ visualization = self.getVisualization()
+
+ if visualization is self.Visualization.BINNED_STATISTIC:
+ plot = self.getPlot()
+ if (
+ plot is None
+ or plot.getXAxis().getScale() != Axis.LINEAR
+ or plot.getYAxis().getScale() != Axis.LINEAR
+ ):
+ # Those visualizations are not available with log scaled axes
+ return None
+
+ histoInfo = self.__getHistogramInfo()
+ if histoInfo is None:
+ return None
+ data = getattr(
+ histoInfo,
+ self.getVisualizationParameter(
+ self.VisualizationParameter.BINNED_STATISTIC_FUNCTION
+ ),
+ )
+
+ return backend.addImage(
+ data=data,
+ origin=histoInfo.origin,
+ scale=histoInfo.scale,
+ colormap=self.getColormap(),
+ alpha=self.getAlpha(),
+ )
+
+ elif visualization is self.Visualization.POINTS:
+ rgbacolors = self.__applyColormapToData()
+ return backend.addCurve(
+ xFiltered,
+ yFiltered,
+ color=rgbacolors[mask],
+ gapcolor=None,
+ symbol=self.getSymbol(),
+ linewidth=0,
+ linestyle="",
+ yaxis="left",
+ xerror=xerror,
+ yerror=yerror,
+ fill=False,
+ alpha=self.getAlpha(),
+ symbolsize=self.getSymbolSize(),
+ baseline=None,
+ )
+
+ else:
+ plot = self.getPlot()
+ if (
+ plot is None
+ or plot.getXAxis().getScale() != Axis.LINEAR
+ or plot.getYAxis().getScale() != Axis.LINEAR
+ ):
+ # Those visualizations are not available with log scaled axes
+ return None
+
+ if visualization is self.Visualization.SOLID:
+ try:
+ triangulation = self._getTriangulationFuture().result()
+ except (RuntimeError, ValueError):
+ _logger.warning(
+ "Cannot get a triangulation: Cannot display as solid surface"
+ )
+ return None
+ else:
+ rgbacolors = self.__applyColormapToData()
+ triangles = triangulation.triangles.astype(numpy.int32)
+ return backend.addTriangles(
+ xFiltered,
+ yFiltered,
+ triangles,
+ color=rgbacolors[mask],
+ alpha=self.getAlpha(),
+ )
+
+ elif visualization is self.Visualization.REGULAR_GRID:
+ gridInfo = self.__getRegularGridInfo()
+ if gridInfo is None:
+ return None
+
+ dim0, dim1 = gridInfo.shape
+ if gridInfo.order == "column": # transposition needed
+ dim0, dim1 = dim1, dim0
+
+ values = self.getValueData(copy=False)
+ if self.__alpha is None and len(values) == dim0 * dim1:
+ image = values.reshape(dim0, dim1)
+ else:
+ # The points do not fill the whole image
+ if self.__alpha is None and numpy.issubdtype(
+ values.dtype, numpy.floating
+ ):
+ image = numpy.empty(dim0 * dim1, dtype=values.dtype)
+ image[: len(values)] = values
+ image[len(values) :] = float("nan") # Transparent pixels
+ image.shape = dim0, dim1
+ else: # Per value alpha or no NaN, so convert to RGBA
+ rgbacolors = self.__applyColormapToData()
+ image = numpy.empty((dim0 * dim1, 4), dtype=numpy.uint8)
+ image[: len(rgbacolors)] = rgbacolors
+ image[len(rgbacolors) :] = (0, 0, 0, 0) # Transparent pixels
+ image.shape = dim0, dim1, 4
+
+ if gridInfo.order == "column":
+ if image.ndim == 2:
+ image = numpy.transpose(image)
+ else:
+ image = numpy.transpose(image, axes=(1, 0, 2))
+
+ if image.ndim == 2:
+ colormap = self.getColormap()
+ if colormap.isAutoscale():
+ # Avoid backend to compute autoscale: use item cache
+ colormap = colormap.copy()
+ colormap.setVRange(*colormap.getColormapRange(self))
+ else:
+ colormap = None
+
+ return backend.addImage(
+ data=image,
+ origin=gridInfo.origin,
+ scale=gridInfo.scale,
+ colormap=colormap,
+ alpha=self.getAlpha(),
+ )
+
+ elif visualization is self.Visualization.IRREGULAR_GRID:
+ gridInfo = self.__getRegularGridInfo()
+ if gridInfo is None:
+ return None
+
+ shape = gridInfo.shape
+ if shape is None: # No shape, no display
+ return None
+
+ rgbacolors = self.__applyColormapToData()
+
+ nbpoints = len(xFiltered)
+ if nbpoints == 1:
+ # single point, render as a square points
+ return backend.addCurve(
+ xFiltered,
+ yFiltered,
+ color=rgbacolors[mask],
+ gapcolor=None,
+ symbol="s",
+ linewidth=0,
+ linestyle="",
+ yaxis="left",
+ xerror=None,
+ yerror=None,
+ fill=False,
+ alpha=self.getAlpha(),
+ symbolsize=7,
+ baseline=None,
+ )
+
+ # Make shape include all points
+ gridOrder = gridInfo.order
+ if nbpoints != numpy.prod(shape):
+ if gridOrder == "row":
+ shape = int(numpy.ceil(nbpoints / shape[1])), shape[1]
+ else: # column-major order
+ shape = shape[0], int(numpy.ceil(nbpoints / shape[0]))
+
+ if shape[0] < 2 or shape[1] < 2: # Single line, at least 2 points
+ points = numpy.ones((2, nbpoints, 2), dtype=numpy.float64)
+ # Use row/column major depending on shape, not on info value
+ gridOrder = "row" if shape[0] == 1 else "column"
+
+ if gridOrder == "row":
+ points[0, :, 0] = xFiltered
+ points[0, :, 1] = yFiltered
+ else: # column-major order
+ points[0, :, 0] = yFiltered
+ points[0, :, 1] = xFiltered
+
+ # Add a second line that will be clipped in the end
+ points[1, :-1] = (
+ points[0, :-1]
+ + numpy.cross(points[0, 1:] - points[0, :-1], (0.0, 0.0, 1.0))[
+ :, :2
+ ]
+ )
+ points[1, -1] = (
+ points[0, -1]
+ + numpy.cross(points[0, -1] - points[0, -2], (0.0, 0.0, 1.0))[
+ :2
+ ]
+ )
+
+ points.shape = 2, nbpoints, 2 # Use same shape for both orders
+ coords, indices = _quadrilateral_grid_as_triangles(points)
+
+ elif gridOrder == "row": # row-major order
+ if nbpoints != numpy.prod(shape):
+ points = numpy.empty(
+ (numpy.prod(shape), 2), dtype=numpy.float64
+ )
+ points[:nbpoints, 0] = xFiltered
+ points[:nbpoints, 1] = yFiltered
+ # Index of last element of last fully filled row
+ index = (nbpoints // shape[1]) * shape[1]
+ points[nbpoints:, 0] = xFiltered[
+ index - (numpy.prod(shape) - nbpoints) : index
+ ]
+ points[nbpoints:, 1] = yFiltered[-1]
+ else:
+ points = numpy.transpose((xFiltered, yFiltered))
+ points.shape = shape[0], shape[1], 2
+
+ else: # column-major order
+ if nbpoints != numpy.prod(shape):
+ points = numpy.empty(
+ (numpy.prod(shape), 2), dtype=numpy.float64
+ )
+ points[:nbpoints, 0] = yFiltered
+ points[:nbpoints, 1] = xFiltered
+ # Index of last element of last fully filled column
+ index = (nbpoints // shape[0]) * shape[0]
+ points[nbpoints:, 0] = yFiltered[
+ index - (numpy.prod(shape) - nbpoints) : index
+ ]
+ points[nbpoints:, 1] = xFiltered[-1]
+ else:
+ points = numpy.transpose((yFiltered, xFiltered))
+ points.shape = shape[1], shape[0], 2
+
+ coords, indices = _quadrilateral_grid_as_triangles(points)
+
+ # Remove unused extra triangles
+ coords = coords[: 4 * nbpoints]
+ indices = indices[: 2 * nbpoints]
+
+ if gridOrder == "row":
+ x, y = coords[:, 0], coords[:, 1]
+ else: # column-major order
+ y, x = coords[:, 0], coords[:, 1]
+
+ rgbacolors = rgbacolors[mask] # Filter-out not finite points
+ gridcolors = numpy.empty(
+ (4 * nbpoints, rgbacolors.shape[-1]), dtype=rgbacolors.dtype
+ )
+ for first in range(4):
+ gridcolors[first::4] = rgbacolors[:nbpoints]
+
+ return backend.addTriangles(
+ x, y, indices, color=gridcolors, alpha=self.getAlpha()
+ )
+
+ else:
+ _logger.error("Unhandled visualization %s", visualization)
+ return None
+
+ @docstring(PointsBase)
+ def pick(self, x, y):
+ result = super(Scatter, self).pick(x, y)
+
+ if result is not None:
+ visualization = self.getVisualization()
+
+ if visualization is self.Visualization.IRREGULAR_GRID:
+ # Specific handling of picking for the irregular grid mode
+ index = result.getIndices(copy=False)[0] // 4
+ result = PickingResult(self, (index,))
+
+ elif visualization is self.Visualization.REGULAR_GRID:
+ # Specific handling of picking for the regular grid mode
+ picked = result.getIndices(copy=False)
+ if picked is None:
+ return None
+ row, column = picked[0][0], picked[1][0]
+
+ gridInfo = self.__getRegularGridInfo()
+ if gridInfo is None:
+ return None
+
+ if gridInfo.order == "row":
+ index = row * gridInfo.shape[1] + column
+ else:
+ index = row + column * gridInfo.shape[0]
+ if index >= len(
+ self.getXData(copy=False)
+ ): # OK as long as not log scale
+ return None # Image can be larger than scatter
+
+ result = PickingResult(self, (index,))
+
+ elif visualization is self.Visualization.BINNED_STATISTIC:
+ picked = result.getIndices(copy=False)
+ if picked is None or len(picked) == 0 or len(picked[0]) == 0:
+ return None
+ row, col = picked[0][0], picked[1][0]
+ histoInfo = self.__getHistogramInfo()
+ if histoInfo is None:
+ return None
+ sx, sy = histoInfo.scale
+ ox, oy = histoInfo.origin
+ xdata = self.getXData(copy=False)
+ ydata = self.getYData(copy=False)
+ indices = numpy.nonzero(
+ numpy.logical_and(
+ numpy.logical_and(
+ xdata >= ox + sx * col, xdata < ox + sx * (col + 1)
+ ),
+ numpy.logical_and(
+ ydata >= oy + sy * row, ydata < oy + sy * (row + 1)
+ ),
+ )
+ )[0]
+ result = None if len(indices) == 0 else PickingResult(self, indices)
+
+ return result
+
+ def __getExecutor(self):
+ """Returns async greedy executor
+
+ :rtype: _GreedyThreadPoolExecutor
+ """
+ if self.__executor is None:
+ self.__executor = _GreedyThreadPoolExecutor(max_workers=2)
+ return self.__executor
+
+ def _getTriangulationFuture(self):
+ """Returns a :class:`Future` which result is the Triangulation object.
+
+ :rtype: concurrent.futures.Future
+ """
+ if self.__triangulationFuture is None or self.__triangulationFuture.cancelled():
+ # Need to init a new delaunay
+ x, y = self.getData(copy=False)[:2]
+ # Remove not finite points
+ mask = numpy.logical_and(numpy.isfinite(x), numpy.isfinite(y))
+
+ self.__triangulationFuture = self.__getExecutor().submit_greedy(
+ "Triangulation", Triangulation, x[mask], y[mask]
+ )
+
+ return self.__triangulationFuture
+
+ @staticmethod
+ def __initInterpolator(triangulationFuture, values):
+ """Returns an interpolator for the given data points
+
+ :param concurrent.futures.Future triangulationFuture:
+ Future object which result is a Triangulation object
+ :param numpy.ndarray values: The data value of valid points.
+ :rtype: Union[callable,None]
+ """
+ # Wait for Triangulation to complete
+ try:
+ triangulation = triangulationFuture.result()
+ except (RuntimeError, ValueError):
+ return None # triangulation failed
+ except CancelledError:
+ return None
+
+ return LinearTriInterpolator(triangulation, values)
+
+ def _getInterpolatorFuture(self):
+ """Returns a :class:`Future` which result is the interpolator.
+
+ The interpolator is a callable taking an array Nx2 of points
+ as a single argument.
+ The :class:`Future` result is None in case the interpolator cannot
+ be initialized.
+
+ :rtype: concurrent.futures.Future
+ """
+ if self.__interpolatorFuture is None or self.__interpolatorFuture.cancelled():
+ # Need to init a new interpolator
+ x, y, values = self.getData(copy=False)[:3]
+ # Remove not finite points
+ mask = numpy.logical_and(numpy.isfinite(x), numpy.isfinite(y))
+ x, y, values = x[mask], y[mask], values[mask]
+
+ self.__interpolatorFuture = self.__getExecutor().submit_greedy(
+ "interpolator",
+ self.__initInterpolator,
+ self._getTriangulationFuture(),
+ values,
+ )
+ return self.__interpolatorFuture
+
+ def _logFilterData(self, xPositive, yPositive):
+ """Filter out values with x or y <= 0 on log axes
+
+ :param bool xPositive: True to filter arrays according to X coords.
+ :param bool yPositive: True to filter arrays according to Y coords.
+ :return: The filtered arrays or unchanged object if not filtering needed
+ :rtype: (x, y, value, xerror, yerror)
+ """
+ # overloaded from PointsBase to filter also value.
+ value = self.getValueData(copy=False)
+
+ if xPositive or yPositive:
+ clipped = self._getClippingBoolArray(xPositive, yPositive)
+
+ if numpy.any(clipped):
+ # copy to keep original array and convert to float
+ value = numpy.array(value, copy=True, dtype=numpy.float64)
+ value[clipped] = numpy.nan
+
+ x, y, xerror, yerror = PointsBase._logFilterData(self, xPositive, yPositive)
+
+ return x, y, value, xerror, yerror
+
+ def getValueData(self, copy=True):
+ """Returns the value assigned to the scatter data points.
+
+ :param copy: True (Default) to get a copy,
+ False to use internal representation (do not modify!)
+ :rtype: numpy.ndarray
+ """
+ return numpy.array(self._value, copy=copy)
+
+ def getAlphaData(self, copy=True):
+ """Returns the alpha (transparency) assigned to the scatter data points.
+
+ :param copy: True (Default) to get a copy,
+ False to use internal representation (do not modify!)
+ :rtype: numpy.ndarray
+ """
+ return numpy.array(self.__alpha, copy=copy)
+
+ def getData(self, copy=True, displayed=False):
+ """Returns the x, y coordinates and the value of the data points
+
+ :param copy: True (Default) to get a copy,
+ False to use internal representation (do not modify!)
+ :param bool displayed: True to only get curve points that are displayed
+ in the plot. Default: False.
+ Note: If plot has log scale, negative points
+ are not displayed.
+ :returns: (x, y, value, xerror, yerror)
+ :rtype: 5-tuple of numpy.ndarray
+ """
+ if displayed:
+ data = self._getCachedData()
+ if data is not None:
+ assert len(data) == 5
+ return data
+
+ return (
+ self.getXData(copy),
+ self.getYData(copy),
+ self.getValueData(copy),
+ self.getXErrorData(copy),
+ self.getYErrorData(copy),
+ )
+
+ # reimplemented from PointsBase to handle `value`
+ def setData(self, x, y, value, xerror=None, yerror=None, alpha=None, copy=True):
+ """Set the data of the scatter.
+
+ :param numpy.ndarray x: The data corresponding to the x coordinates.
+ :param numpy.ndarray y: The data corresponding to the y coordinates.
+ :param numpy.ndarray value: The data corresponding to the value of
+ the data points.
+ :param xerror: Values with the uncertainties on the x values
+ :type xerror: A float, or a numpy.ndarray of float32.
+ If it is an array, it can either be a 1D array of
+ same length as the data or a 2D array with 2 rows
+ of same length as the data: row 0 for lower errors,
+ row 1 for upper errors.
+ :param yerror: Values with the uncertainties on the y values
+ :type yerror: A float, or a numpy.ndarray of float32. See xerror.
+ :param alpha: Values with the transparency (between 0 and 1)
+ :type alpha: A float, or a numpy.ndarray of float32
+ :param bool copy: True make a copy of the data (default),
+ False to use provided arrays.
+ """
+ value = numpy.array(value, copy=copy)
+ assert value.ndim == 1
+ assert len(x) == len(value)
+
+ # Convert complex data
+ if numpy.iscomplexobj(value):
+ _logger.warning("Converting value data to absolute value to plot it.")
+ value = numpy.absolute(value)
+
+ # Reset triangulation and interpolator
+ if self.__triangulationFuture is not None:
+ self.__triangulationFuture.cancel()
+ self.__triangulationFuture = None
+ if self.__interpolatorFuture is not None:
+ self.__interpolatorFuture.cancel()
+ self.__interpolatorFuture = None
+
+ # Data changed, this needs update
+ self.__cacheRegularGridInfo = None
+ self.__cacheHistogramInfo = None
+
+ self._value = value
+
+ if alpha is not None:
+ # Make sure alpha is an array of float in [0, 1]
+ alpha = numpy.array(alpha, copy=copy)
+ assert alpha.ndim == 1
+ assert len(x) == len(alpha)
+ if alpha.dtype.kind != "f":
+ alpha = alpha.astype(numpy.float32)
+ if numpy.any(numpy.logical_or(alpha < 0.0, alpha > 1.0)):
+ alpha = numpy.clip(alpha, 0.0, 1.0)
+ self.__alpha = alpha
+
+ # set x, y, xerror, yerror
+
+ # call self._updated + plot._invalidateDataRange()
+ PointsBase.setData(self, x, y, xerror, yerror, copy)
+
+ self._updateColormappedData()
diff --git a/src/silx/gui/plot/items/shape.py b/src/silx/gui/plot/items/shape.py
new file mode 100644
index 0000000..c911924
--- /dev/null
+++ b/src/silx/gui/plot/items/shape.py
@@ -0,0 +1,393 @@
+# /*##########################################################################
+#
+# Copyright (c) 2017-2022 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""This module provides the :class:`Shape` item of the :class:`Plot`.
+"""
+
+__authors__ = ["T. Vincent"]
+__license__ = "MIT"
+__date__ = "21/12/2018"
+
+
+import logging
+
+import numpy
+
+from .core import (
+ Item,
+ DataItem,
+ AlphaMixIn,
+ ColorMixIn,
+ FillMixIn,
+ ItemChangedType,
+ LineMixIn,
+ LineGapColorMixIn,
+ YAxisMixIn,
+)
+from ....utils.deprecation import deprecated
+
+
+_logger = logging.getLogger(__name__)
+
+
+class _OverlayItem(Item):
+ """Item with settable overlay"""
+
+ def __init__(self):
+ self.__overlay = False
+ Item.__init__(self)
+
+ def isOverlay(self) -> bool:
+ """Return true if shape is drawn as an overlay"""
+ return self.__overlay
+
+ def setOverlay(self, overlay: bool):
+ """Set the overlay state of the shape
+
+ :param overlay: True to make it an overlay
+ """
+ overlay = bool(overlay)
+ if overlay != self.__overlay:
+ self.__overlay = overlay
+ self._updated(ItemChangedType.OVERLAY)
+
+
+class _TwoColorsLineMixIn(LineMixIn, LineGapColorMixIn):
+ """Mix-in class for items with a background color for dashes"""
+
+ def __init__(self):
+ LineMixIn.__init__(self)
+ LineGapColorMixIn.__init__(self)
+
+ @deprecated(replacement="getLineGapColor", since_version="2.0.0")
+ def getLineBgColor(self):
+ return self.getLineGapColor()
+
+ @deprecated(replacement="setLineGapColor", since_version="2.0.0")
+ def setLineBgColor(self, color, copy: bool = True):
+ self.setLineGapColor(color)
+ self._updated(ItemChangedType.LINE_BG_COLOR)
+
+
+# TODO probably make one class for each kind of shape
+# TODO check fill:polygon/polyline + fill = duplicated
+class Shape(_OverlayItem, ColorMixIn, FillMixIn, _TwoColorsLineMixIn):
+ """Description of a shape item
+
+ :param str type_: The type of shape in:
+ 'hline', 'polygon', 'rectangle', 'vline', 'polylines'
+ """
+
+ def __init__(self, type_):
+ _OverlayItem.__init__(self)
+ ColorMixIn.__init__(self)
+ FillMixIn.__init__(self)
+ _TwoColorsLineMixIn.__init__(self)
+ assert type_ in ("hline", "polygon", "rectangle", "vline", "polylines")
+ self._type = type_
+ self._points = ()
+ self._handle = None
+
+ def _addBackendRenderer(self, backend):
+ """Update backend renderer"""
+ points = self.getPoints(copy=False)
+ x, y = points.T[0], points.T[1]
+ return backend.addShape(
+ x,
+ y,
+ shape=self.getType(),
+ color=self.getColor(),
+ fill=self.isFill(),
+ overlay=self.isOverlay(),
+ linestyle=self.getLineStyle(),
+ linewidth=self.getLineWidth(),
+ gapcolor=self.getLineGapColor(),
+ )
+
+ def getType(self):
+ """Returns the type of shape to draw.
+
+ One of: 'hline', 'polygon', 'rectangle', 'vline', 'polylines'
+
+ :rtype: str
+ """
+ return self._type
+
+ def getPoints(self, copy=True):
+ """Get the control points of the shape.
+
+ :param bool copy: True (Default) to get a copy,
+ False to use internal representation (do not modify!)
+ :return: Array of point coordinates
+ :rtype: numpy.ndarray with 2 dimensions
+ """
+ return numpy.array(self._points, copy=copy)
+
+ def setPoints(self, points, copy=True):
+ """Set the point coordinates
+
+ :param numpy.ndarray points: Array of point coordinates
+ :param bool copy: True (Default) to get a copy,
+ False to use internal representation (do not modify!)
+ :return:
+ """
+ self._points = numpy.array(points, copy=copy)
+ self._updated(ItemChangedType.DATA)
+
+
+class BoundingRect(DataItem, YAxisMixIn):
+ """An invisible shape which enforce the plot view to display the defined
+ space on autoscale.
+
+ This item do not display anything. But if the visible property is true,
+ this bounding box is used by the plot, if not, the bounding box is
+ ignored. That's the default behaviour for plot items.
+
+ It can be applied on the "left" or "right" axes. Not both at the same time.
+ """
+
+ def __init__(self):
+ DataItem.__init__(self)
+ YAxisMixIn.__init__(self)
+ self.__bounds = None
+
+ def setBounds(self, rect):
+ """Set the bounding box of this item in data coordinates
+
+ :param Union[None,List[float]] rect: (xmin, xmax, ymin, ymax) or None
+ """
+ if rect is not None:
+ rect = float(rect[0]), float(rect[1]), float(rect[2]), float(rect[3])
+ assert rect[0] <= rect[1]
+ assert rect[2] <= rect[3]
+
+ if rect != self.__bounds:
+ self.__bounds = rect
+ self._boundsChanged()
+ self._updated(ItemChangedType.DATA)
+
+ def _getBounds(self):
+ if self.__bounds is None:
+ return None
+ plot = self.getPlot()
+ if plot is not None:
+ xPositive = plot.getXAxis()._isLogarithmic()
+ yPositive = plot.getYAxis()._isLogarithmic()
+ if xPositive or yPositive:
+ bounds = list(self.__bounds)
+ if xPositive and bounds[1] <= 0:
+ return None
+ if xPositive and bounds[0] <= 0:
+ bounds[0] = bounds[1]
+ if yPositive and bounds[3] <= 0:
+ return None
+ if yPositive and bounds[2] <= 0:
+ bounds[2] = bounds[3]
+ return tuple(bounds)
+
+ return self.__bounds
+
+
+class _BaseExtent(DataItem):
+ """Base class for :class:`XAxisExtent` and :class:`YAxisExtent`.
+
+ :param str axis: Either 'x' or 'y'.
+ """
+
+ def __init__(self, axis="x"):
+ assert axis in ("x", "y")
+ DataItem.__init__(self)
+ self.__axis = axis
+ self.__range = 1.0, 100.0
+
+ def setRange(self, min_, max_):
+ """Set the range of the extent of this item in data coordinates.
+
+ :param float min_: Lower bound of the extent
+ :param float max_: Upper bound of the extent
+ :raises ValueError: If min > max or not finite bounds
+ """
+ range_ = float(min_), float(max_)
+ if not numpy.all(numpy.isfinite(range_)):
+ raise ValueError("min_ and max_ must be finite numbers.")
+ if range_[0] > range_[1]:
+ raise ValueError("min_ must be lesser or equal to max_")
+
+ if range_ != self.__range:
+ self.__range = range_
+ self._boundsChanged()
+ self._updated(ItemChangedType.DATA)
+
+ def getRange(self):
+ """Returns the range (min, max) of the extent in data coordinates.
+
+ :rtype: List[float]
+ """
+ return self.__range
+
+ def _getBounds(self):
+ min_, max_ = self.getRange()
+
+ plot = self.getPlot()
+ if plot is not None:
+ axis = plot.getXAxis() if self.__axis == "x" else plot.getYAxis()
+ if axis._isLogarithmic():
+ if max_ <= 0:
+ return None
+ if min_ <= 0:
+ min_ = max_
+
+ if self.__axis == "x":
+ return min_, max_, float("nan"), float("nan")
+ else:
+ return float("nan"), float("nan"), min_, max_
+
+
+class XAxisExtent(_BaseExtent):
+ """Invisible item with a settable horizontal data extent.
+
+ This item do not display anything, but it behaves as a data
+ item with a horizontal extent regarding plot data bounds, i.e.,
+ :meth:`PlotWidget.resetZoom` will take this horizontal extent into account.
+ """
+
+ def __init__(self):
+ _BaseExtent.__init__(self, axis="x")
+
+
+class YAxisExtent(_BaseExtent, YAxisMixIn):
+ """Invisible item with a settable vertical data extent.
+
+ This item do not display anything, but it behaves as a data
+ item with a vertical extent regarding plot data bounds, i.e.,
+ :meth:`PlotWidget.resetZoom` will take this vertical extent into account.
+ """
+
+ def __init__(self):
+ _BaseExtent.__init__(self, axis="y")
+ YAxisMixIn.__init__(self)
+
+
+class Line(_OverlayItem, AlphaMixIn, ColorMixIn, _TwoColorsLineMixIn):
+ """Description of a infinite line item as y = slope * x + interecpt
+
+ Warning: If slope is not finite, then the line is x = intercept.
+ """
+
+ def __init__(self, slope: float = 0, intercept: float = 0):
+ assert numpy.isfinite(intercept)
+
+ _OverlayItem.__init__(self)
+ AlphaMixIn.__init__(self)
+ ColorMixIn.__init__(self)
+ _TwoColorsLineMixIn.__init__(self)
+ self.__slope = float(slope)
+ self.__intercept = float(intercept)
+ self.__coordinates = None
+ self._setVisibleBoundsTracking(True)
+
+ def __updatePoints(self):
+ if not self.isVisible():
+ return
+
+ plot = self.getPlot()
+ if plot is None or not plot.isVisible():
+ return
+
+ xmin, xmax = plot.getXAxis().getLimits()
+ ymin, ymax = plot.getYAxis().getLimits()
+
+ slope = self.getSlope()
+ intercept = self.getIntercept()
+
+ if not numpy.isfinite(slope):
+ if not xmin <= intercept <= xmax:
+ coordinates = None
+ else:
+ coordinates = (intercept, intercept), (ymin, ymax)
+ else:
+ ycoords = slope * xmin + intercept, slope * xmax + intercept
+
+ if min(ycoords) < ymax and max(ycoords) > ymin:
+ coordinates = (xmin, xmax), ycoords
+ else:
+ coordinates = None
+
+ if coordinates != self.__coordinates:
+ self.__coordinates = coordinates
+ self._updated()
+
+ def _visibleBoundsChanged(self, *args) -> None:
+ """Override method to benefit from bounds tracking"""
+ self.__updatePoints()
+ return super()._visibleBoundsChanged(*args)
+
+ def setSlope(self, slope: float):
+ slope = float(slope)
+ if slope != self.__slope:
+ self.__slope = slope
+ self.__updatePoints()
+ self._updated(ItemChangedType.DATA)
+
+ def getSlope(self) -> float:
+ return self.__slope
+
+ def setIntercept(self, intercept: float):
+ intercept = float(intercept)
+ assert numpy.isfinite(intercept)
+ if intercept != self.__intercept:
+ self.__intercept = intercept
+ self.__updatePoints()
+ self._updated(ItemChangedType.DATA)
+
+ def getIntercept(self) -> float:
+ return self.__intercept
+
+ def setSlopeInterceptFromPoints(self, point0, point1):
+ """Set slope and intercept from 2 (x, y) points"""
+ x0, y0 = point0
+ x1, y1 = point1
+ if x0 == x1: # Special case: vertical line
+ self.setSlope(float("inf"))
+ self.setIntercept(x0)
+ return
+
+ slope = (y1 - y0) / (x1 - x0)
+ self.setSlope(slope)
+ self.setIntercept(y0 - x0 * slope)
+
+ def _addBackendRenderer(self, backend):
+ """Update backend renderer"""
+ if self.__coordinates is None:
+ return None
+
+ return backend.addShape(
+ *self.__coordinates,
+ shape="polylines",
+ color=self.getColor(),
+ fill=False,
+ overlay=self.isOverlay(),
+ linestyle=self.getLineStyle(),
+ linewidth=self.getLineWidth(),
+ gapcolor=self.getLineGapColor(),
+ )
diff --git a/src/silx/gui/plot/stats/__init__.py b/src/silx/gui/plot/stats/__init__.py
new file mode 100644
index 0000000..dfaa865
--- /dev/null
+++ b/src/silx/gui/plot/stats/__init__.py
@@ -0,0 +1,32 @@
+# /*##########################################################################
+#
+# Copyright (c) 2017 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""
+"""
+
+__authors__ = ["H. Payno"]
+__license__ = "MIT"
+__date__ = "07/03/2018"
+
+
+from .stats import *
diff --git a/src/silx/gui/plot/stats/stats.py b/src/silx/gui/plot/stats/stats.py
new file mode 100644
index 0000000..d575e3f
--- /dev/null
+++ b/src/silx/gui/plot/stats/stats.py
@@ -0,0 +1,911 @@
+# /*##########################################################################
+#
+# Copyright (c) 2017-2023 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""This module provides mechanism relative to stats calculation within a
+:class:`PlotWidget`.
+It also include the implementation of the statistics themselves.
+"""
+
+__authors__ = ["H. Payno"]
+__license__ = "MIT"
+__date__ = "06/06/2018"
+
+
+from functools import lru_cache
+import logging
+
+import numpy
+import numpy.ma
+
+from .. import items
+from ..CurvesROIWidget import ROI
+from ..items.roi import RegionOfInterest
+
+from ....math.combo import min_max
+from silx.utils.proxy import docstring
+
+logger = logging.getLogger(__name__)
+
+
+class Stats(dict):
+ """Class to define a set of statistic relative to a dataset
+ (image, curve...).
+
+ The goal of this class is to avoid multiple recalculation of some
+ basic operations such as filtering data area where the statistics has to
+ be apply.
+ Min and max are also stored because they can be used several time.
+
+ :param List statslist: List of the :class:`Stat` object to be computed.
+ """
+
+ def __init__(self, statslist=None):
+ super().__init__()
+ _statslist = statslist if not None else []
+ if statslist is not None:
+ for stat in _statslist:
+ self.add(stat)
+
+ def calculate(
+ self, item, plot, onlimits, roi, data_changed=False, roi_changed=False
+ ):
+ """
+ Call all :class:`Stat` object registered and return the result of the
+ computation.
+
+ :param item: the item for which we want statistics
+ :param plot: plot containing the item
+ :param bool onlimits: True if we want to apply statistic only on
+ visible data.
+ :param roi: region of interest for statistic calculation. Incompatible
+ with the `onlimits` option.
+ :type roi: Union[None, :class:`~_RegionOfInterestBase`]
+ :param bool data_changed: did the data changed since last calculation.
+ :param bool roi_changed: did the associated roi (if any) has changed
+ since last calculation.
+ :return dict: dictionary with :class:`Stat` name as ket and result
+ of the calculation as value
+ """
+ res = {}
+ context = self._getContext(item=item, plot=plot, onlimits=onlimits, roi=roi)
+ for statName, stat in list(self.items()):
+ if context.kind not in stat.compatibleKinds:
+ logger.debug(
+ "kind %s not managed by statistic %s" % (context.kind, stat.name)
+ )
+ res[statName] = None
+ else:
+ if roi_changed is True:
+ context.clear_mask()
+ if data_changed is True or roi_changed is True:
+ # if data changed or mask changed
+ context.clipData(item=item, plot=plot, onlimits=onlimits, roi=roi)
+ # init roi and data
+ res[statName] = stat.calculate(context)
+ return res
+
+ def __setitem__(self, key, value):
+ assert isinstance(value, StatBase)
+ super().__setitem__(key, value)
+
+ def add(self, stat):
+ """Add a :class:`Stat` to the set
+
+ :param Stat stat: stat to add to the set
+ """
+ self.__setitem__(key=stat.name, value=stat)
+
+ @staticmethod
+ @lru_cache(maxsize=50)
+ def _getContext(item, plot, onlimits, roi):
+ context = None
+ # Check for PlotWidget items
+ if isinstance(item, items.Curve):
+ context = _CurveContext(item, plot, onlimits, roi=roi)
+ elif isinstance(item, items.ImageData):
+ context = _ImageContext(item, plot, onlimits, roi=roi)
+ elif isinstance(item, items.Scatter):
+ context = _ScatterContext(item, plot, onlimits, roi=roi)
+ elif isinstance(item, items.Histogram):
+ context = _HistogramContext(item, plot, onlimits, roi=roi)
+ else:
+ # Check for SceneWidget items
+ from ...plot3d import items as items3d # Lazy import
+
+ if isinstance(item, (items3d.Scatter2D, items3d.Scatter3D)):
+ context = _plot3DScatterContext(item, plot, onlimits, roi=roi)
+ elif isinstance(item, (items3d.ImageData, items3d.ScalarField3D)):
+ context = _plot3DArrayContext(item, plot, onlimits, roi=roi)
+ if context is None:
+ raise ValueError("Item type not managed")
+ return context
+
+
+class _StatsContext(object):
+ """
+ The context is designed to be a simple buffer and avoid repetition of
+ calculations that can appear during stats evaluation.
+
+ .. warning:: this class gives access to the data to be used for computation
+ . It deal with filtering data visible by the user on plot.
+ The filtering is a simple data sub-sampling. No interpolation
+ is made to fit data to boundaries.
+
+ :param item: the item for which we want to compute the context
+ :param str kind: the kind of the item
+ :param plot: the plot containing the item
+ :param bool onlimits: True if we want to apply statistic only on
+ visible data.
+ :param roi: Region of interest for computing the statistics.
+ For now, incompatible with `onlimits` calculation
+ :type roi: Union[None,:class:`_RegionOfInterestBase`]
+ """
+
+ def __init__(self, item, kind, plot, onlimits, roi):
+ assert item
+ assert plot
+ assert type(onlimits) is bool
+ self.kind = kind
+ self.min = None
+ self.max = None
+ self.data = None
+ self.roi = None
+ self.onlimits = onlimits
+
+ self.values = None
+ """The array of data with limit filtering if any. Is a numpy.ma.array,
+ meaning that it embed the mask applied by the roi if any"""
+
+ self.axes = None
+ """A list of array of position on each axis.
+
+ If the signal is an array,
+ then each axis has the length of that dimension,
+ and the order is (z, y, x) (i.e., as the array shape).
+ If the signal is not an array,
+ then each axis has the same length as the signal,
+ and the order is (x, y, z).
+ """
+
+ self.clipData(item, plot, onlimits, roi=roi)
+
+ def clear_mask(self):
+ """
+ Remove the mask to force recomputation of it on next iteration
+ :return:
+ """
+ raise NotImplementedError()
+
+ @property
+ def mask(self):
+ if self.values is not None:
+ assert isinstance(self.values, numpy.ma.MaskedArray)
+ return self.values.mask
+ else:
+ return None
+
+ @property
+ def is_mask_valid(self, **kwargs):
+ """Return if the mask is valid for the data or need to be recomputed"""
+ raise NotImplementedError("Base class")
+
+ def _set_mask_validity(self, **kwargs):
+ """User to set some values that allows to define the mask properties
+ and boundaries"""
+ raise NotImplementedError("Base class")
+
+ def clipData(self, item, plot, onlimits, roi):
+ """Clip the data to the current mask to have accurate statistics
+
+ Function called before computing each statistics associated to this
+ context. It will insure the context for the (item, plot, onlimits, roi)
+ is created.
+
+ :param item: item for which we want statistics
+ :param plot: plot containing the statistics
+ :param bool onlimits: True if we want to apply statistic only on
+ visible data.
+ :param roi: Region of interest for computing the statistics.
+ For now, incompatible with `onlimits` calculation
+ :type roi: Union[None,:class:`_RegionOfInterestBase`]
+ """
+ raise NotImplementedError("Base class")
+
+ def isStructuredData(self):
+ """Returns True if data as an array-like structure.
+
+ :rtype: bool
+ """
+ if self.values is None or self.axes is None:
+ return False
+
+ if numpy.prod([len(axis) for axis in self.axes]) == self.values.size:
+ return True
+ else:
+ # Make sure there is the right number of value in axes
+ for axis in self.axes:
+ assert len(axis) == self.values.size
+ return False
+
+ def isScalarData(self):
+ """Returns True if data is a scalar.
+
+ :rtype: bool
+ """
+ if self.values is None or self.axes is None:
+ return False
+ if self.isStructuredData():
+ return len(self.axes) == self.values.ndim
+ else:
+ return self.values.ndim == 1
+
+ def _checkContextInputs(self, item, plot, onlimits, roi):
+ if roi is not None and onlimits is True:
+ raise ValueError(
+ "Stats context is unable to manage both a ROI"
+ "and the `onlimits` option"
+ )
+
+
+class _ScatterCurveHistoMixInContext(_StatsContext):
+ def __init__(self, kind, item, plot, onlimits, roi):
+ self.clear_mask()
+ _StatsContext.__init__(
+ self, item=item, kind=kind, plot=plot, onlimits=onlimits, roi=roi
+ )
+
+ def _set_mask_validity(self, onlimits, from_, to_):
+ self._onlimits = onlimits
+ self._from_ = from_
+ self._to_ = to_
+
+ def clear_mask(self):
+ self._onlimits = None
+ self._from_ = None
+ self._to_ = None
+
+ def is_mask_valid(self, onlimits, from_, to_):
+ return onlimits == self.onlimits and from_ == self._from_ and to_ == self._to_
+
+
+class _CurveContext(_ScatterCurveHistoMixInContext):
+ """
+ StatsContext for :class:`Curve`
+
+ :param item: the item for which we want to compute the context
+ :param plot: the plot containing the item
+ :param bool onlimits: True if we want to apply statistic only on
+ visible data.
+ :param roi: Region of interest for computing the statistics.
+ For now, incompatible with `onlinits` calculation
+ :type roi: Union[None, :class:`ROI`]
+ """
+
+ def __init__(self, item, plot, onlimits, roi):
+ _ScatterCurveHistoMixInContext.__init__(
+ self, kind="curve", item=item, plot=plot, onlimits=onlimits, roi=roi
+ )
+
+ @docstring(_StatsContext)
+ def clipData(self, item, plot, onlimits, roi):
+ self._checkContextInputs(item=item, plot=plot, onlimits=onlimits, roi=roi)
+ self.roi = roi
+ self.onlimits = onlimits
+ xData, yData = item.getData(copy=True)[0:2]
+
+ if onlimits:
+ minX, maxX = plot.getXAxis().getLimits()
+ if self.is_mask_valid(onlimits=onlimits, from_=minX, to_=maxX):
+ mask = self.mask
+ else:
+ mask = (minX <= xData) & (xData <= maxX)
+ mask = mask == 0
+ self._set_mask_validity(onlimits=onlimits, from_=minX, to_=maxX)
+ elif roi:
+ minX, maxX = roi.getFrom(), roi.getTo()
+ if self.is_mask_valid(onlimits=onlimits, from_=minX, to_=maxX):
+ mask = self.mask
+ else:
+ mask = (minX <= xData) & (xData <= maxX)
+ mask = mask == 0
+ self._set_mask_validity(onlimits=onlimits, from_=minX, to_=maxX)
+ else:
+ mask = numpy.zeros_like(yData)
+
+ mask = mask.astype(numpy.uint32)
+ self.xData = xData
+ self.yData = yData
+ self.values = numpy.ma.array(yData, mask=mask)
+ unmasked_data = self.values.compressed()
+ if len(unmasked_data) > 0:
+ self.min, self.max = min_max(unmasked_data)
+ else:
+ self.min, self.max = None, None
+ self.data = (xData, yData)
+ self.axes = (xData,)
+
+ def _checkContextInputs(self, item, plot, onlimits, roi):
+ _StatsContext._checkContextInputs(
+ self, item=item, plot=plot, onlimits=onlimits, roi=roi
+ )
+ if roi is not None and not isinstance(roi, ROI):
+ raise TypeError("curve `context` can ony manage 1D roi")
+
+
+class _HistogramContext(_ScatterCurveHistoMixInContext):
+ """
+ StatsContext for :class:`Histogram`
+
+ :param item: the item for which we want to compute the context
+ :param plot: the plot containing the item
+ :param bool onlimits: True if we want to apply statistic only on
+ visible data.
+ :param roi: Region of interest for computing the statistics.
+ For now, incompatible with `onlinits` calculation
+ :type roi: Union[None, :class:`ROI`]
+ """
+
+ def __init__(self, item, plot, onlimits, roi):
+ _ScatterCurveHistoMixInContext.__init__(
+ self, kind="histogram", item=item, plot=plot, onlimits=onlimits, roi=roi
+ )
+
+ @docstring(_StatsContext)
+ def clipData(self, item, plot, onlimits, roi):
+ self._checkContextInputs(item=item, plot=plot, onlimits=onlimits, roi=roi)
+ yData, edges = item.getData(copy=True)[0:2]
+ xData = item._revertComputeEdges(x=edges, histogramType=item.getAlignment())
+
+ if onlimits:
+ minX, maxX = plot.getXAxis().getLimits()
+ if self.is_mask_valid(onlimits=onlimits, from_=minX, to_=maxX):
+ mask = self.mask
+ else:
+ mask = (minX <= xData) & (xData <= maxX)
+ mask = mask == 0
+ self._set_mask_validity(onlimits=onlimits, from_=minX, to_=maxX)
+ elif roi:
+ if self.is_mask_valid(
+ onlimits=onlimits, from_=roi._fromdata, to_=roi._todata
+ ):
+ mask = self.mask
+ else:
+ mask = (roi._fromdata <= xData) & (xData <= roi._todata)
+ mask = mask == 0
+ self._set_mask_validity(
+ onlimits=onlimits, from_=roi._fromdata, to_=roi._todata
+ )
+ else:
+ mask = numpy.zeros_like(yData)
+ mask = mask.astype(numpy.uint32)
+ self.xData = xData
+ self.yData = yData
+ self.values = numpy.ma.array(yData, mask=(mask))
+ unmasked_data = self.values.compressed()
+ if len(unmasked_data) > 0:
+ self.min, self.max = min_max(unmasked_data)
+ else:
+ self.min, self.max = None, None
+ self.data = (self.xData, self.yData)
+ self.axes = (self.xData,)
+
+ def _checkContextInputs(self, item, plot, onlimits, roi):
+ _StatsContext._checkContextInputs(
+ self, item=item, plot=plot, onlimits=onlimits, roi=roi
+ )
+
+ if roi is not None and not isinstance(roi, ROI):
+ raise TypeError("curve `context` can ony manage 1D roi")
+
+
+class _ScatterContext(_ScatterCurveHistoMixInContext):
+ """StatsContext scatter plots.
+
+ It supports :class:`~silx.gui.plot.items.Scatter`.
+
+ :param item: the item for which we want to compute the context
+ :param plot: the plot containing the item
+ :param bool onlimits: True if we want to apply statistic only on
+ visible data.
+ :param roi: Region of interest for computing the statistics.
+ For now, incompatible with `onlinits` calculation
+ :type roi: Union[None, :class:`ROI`]
+ """
+
+ def __init__(self, item, plot, onlimits, roi):
+ _ScatterCurveHistoMixInContext.__init__(
+ self, kind="scatter", item=item, plot=plot, onlimits=onlimits, roi=roi
+ )
+
+ @docstring(_ScatterCurveHistoMixInContext)
+ def clipData(self, item, plot, onlimits, roi):
+ self._checkContextInputs(item=item, plot=plot, onlimits=onlimits, roi=roi)
+ valueData = item.getValueData(copy=True)
+ xData = item.getXData(copy=True)
+ yData = item.getYData(copy=True)
+
+ if onlimits:
+ minX, maxX = plot.getXAxis().getLimits()
+ minY, maxY = plot.getYAxis().getLimits()
+
+ # filter on X axis
+ valueData = valueData[(minX <= xData) & (xData <= maxX)]
+ yData = yData[(minX <= xData) & (xData <= maxX)]
+ xData = xData[(minX <= xData) & (xData <= maxX)]
+ # filter on Y axis
+ valueData = valueData[(minY <= yData) & (yData <= maxY)]
+ xData = xData[(minY <= yData) & (yData <= maxY)]
+ yData = yData[(minY <= yData) & (yData <= maxY)]
+
+ if roi:
+ if self.is_mask_valid(
+ onlimits=onlimits, from_=roi.getFrom(), to_=roi.getTo()
+ ):
+ mask = self.mask
+ else:
+ mask = (xData < roi.getFrom()) | (xData > roi.getTo())
+ else:
+ mask = numpy.zeros_like(xData)
+
+ self.data = (xData, yData, valueData)
+ self.values = numpy.ma.array(valueData, mask=mask)
+ self.axes = (xData, yData)
+
+ unmasked_values = self.values.compressed()
+ if len(unmasked_values) > 0:
+ self.min, self.max = min_max(unmasked_values)
+ else:
+ self.min, self.max = None, None
+
+ def _checkContextInputs(self, item, plot, onlimits, roi):
+ _StatsContext._checkContextInputs(
+ self, item=item, plot=plot, onlimits=onlimits, roi=roi
+ )
+
+ if roi is not None and not isinstance(roi, ROI):
+ raise TypeError("curve `context` can ony manage 1D roi")
+
+
+class _ImageContext(_StatsContext):
+ """StatsContext for images.
+
+ It supports :class:`~silx.gui.plot.items.ImageData`.
+
+ :warning: behaviour of scale images: now the statistics are computed on
+ the entire data array (there is no sampling in the array or
+ interpolation regarding the scale).
+ This also mean that the result can differ from what is displayed.
+ But I guess there is no perfect behaviour.
+
+ :warning: `isIn` functions for image context: for now have basically a
+ binary approach, the pixel is in a roi or not. To have a fully
+ 'correct behaviour' we should add a weight on stats calculation
+ to moderate the pixel value.
+
+ :param item: the item for which we want to compute the context
+ :param plot: the plot containing the item
+ :param bool onlimits: True if we want to apply statistic only on
+ visible data.
+ :param roi: Region of interest for computing the statistics.
+ For now, incompatible with `onlinits` calculation
+ :type roi: Union[None, :class:`ROI`]
+ """
+
+ def __init__(self, item, plot, onlimits, roi):
+ self.clear_mask()
+ _StatsContext.__init__(
+ self, kind="image", item=item, plot=plot, onlimits=onlimits, roi=roi
+ )
+
+ def _set_mask_validity(self, xmin: float, xmax: float, ymin: float, ymax: float):
+ self._mask_x_min = xmin
+ self._mask_x_max = xmax
+ self._mask_y_min = ymin
+ self._mask_y_max = ymax
+
+ def clear_mask(self):
+ self._mask_x_min = None
+ self._mask_x_max = None
+ self._mask_y_min = None
+ self._mask_y_max = None
+
+ def is_mask_valid(self, xmin, xmax, ymin, ymax):
+ return (
+ xmin == self._mask_x_min
+ and xmax == self._mask_x_max
+ and ymin == self._mask_y_min
+ and ymax == self._mask_y_max
+ )
+
+ @docstring(_StatsContext)
+ def clipData(self, item, plot, onlimits, roi):
+ self._checkContextInputs(item=item, plot=plot, onlimits=onlimits, roi=roi)
+ self.origin = item.getOrigin()
+ self.scale = item.getScale()
+
+ self.data = item.getData(copy=True)
+ mask = numpy.zeros_like(self.data)
+ """mask use to know of the stat should be count in or not"""
+
+ if onlimits:
+ minX, maxX = plot.getXAxis().getLimits()
+ minY, maxY = plot.getYAxis().getLimits()
+
+ XMinBound = int((minX - self.origin[0]) / self.scale[0])
+ YMinBound = int((minY - self.origin[1]) / self.scale[1])
+ XMaxBound = int((maxX - self.origin[0]) / self.scale[0])
+ YMaxBound = int((maxY - self.origin[1]) / self.scale[1])
+
+ XMinBound = max(XMinBound, 0)
+ YMinBound = max(YMinBound, 0)
+
+ if onlimits:
+ if XMaxBound <= XMinBound or YMaxBound <= YMinBound:
+ self.data = None
+ else:
+ self.data = self.data[
+ YMinBound : YMaxBound + 1, XMinBound : XMaxBound + 1
+ ]
+ mask = numpy.zeros_like(self.data)
+ elif roi:
+ minX, maxX = 0, self.data.shape[1]
+ minY, maxY = 0, self.data.shape[0]
+
+ XMinBound = max(minX, 0)
+ YMinBound = max(minY, 0)
+ XMaxBound = min(maxX, self.data.shape[1])
+ YMaxBound = min(maxY, self.data.shape[0])
+
+ if self.is_mask_valid(
+ xmin=XMinBound, xmax=XMaxBound, ymin=YMinBound, ymax=YMaxBound
+ ):
+ mask = self.mask
+ else:
+ for x in range(XMinBound, XMaxBound):
+ for y in range(YMinBound, YMaxBound):
+ _x = (x * self.scale[0]) + self.origin[0]
+ _y = (y * self.scale[1]) + self.origin[1]
+ mask[y, x] = not roi.contains((_x, _y))
+ self._set_mask_validity(
+ xmin=XMinBound, xmax=XMaxBound, ymin=YMinBound, ymax=YMaxBound
+ )
+ self.values = numpy.ma.array(self.data, mask=mask)
+ if self.values.compressed().size > 0:
+ self.min, self.max = min_max(self.values.compressed())
+ else:
+ self.min, self.max = None, None
+
+ if self.values is not None:
+ self.axes = (
+ self.origin[1] + self.scale[1] * numpy.arange(self.data.shape[0]),
+ self.origin[0] + self.scale[0] * numpy.arange(self.data.shape[1]),
+ )
+
+ def _checkContextInputs(self, item, plot, onlimits, roi):
+ _StatsContext._checkContextInputs(
+ self, item=item, plot=plot, onlimits=onlimits, roi=roi
+ )
+
+ if roi is not None and not isinstance(roi, RegionOfInterest):
+ raise TypeError("curve `context` can ony manage 2D roi")
+
+
+class _plot3DScatterContext(_StatsContext):
+ """StatsContext for 3D scatter plots.
+
+ It supports :class:`~silx.gui.plot3d.items.Scatter2D` and
+ :class:`~silx.gui.plot3d.items.Scatter3D`.
+
+ :param item: the item for which we want to compute the context
+ :param plot: the plot containing the item
+ :param bool onlimits: True if we want to apply statistic only on
+ visible data.
+ :param roi: Region of interest for computing the statistics.
+ For now, incompatible with `onlinits` calculation
+ :type roi: Union[None, :class:`ROI`]
+ """
+
+ def __init__(self, item, plot, onlimits, roi):
+ _StatsContext.__init__(
+ self, kind="scatter", item=item, plot=plot, onlimits=onlimits, roi=roi
+ )
+
+ @docstring(_StatsContext)
+ def clipData(self, item, plot, onlimits, roi):
+ self._checkContextInputs(item=item, plot=plot, onlimits=onlimits, roi=roi)
+ if onlimits:
+ raise RuntimeError("Unsupported plot %s" % str(plot))
+ values = item.getValueData(copy=False)
+ if roi:
+ logger.warning("Roi are unsupported on volume for now")
+ mask = numpy.zeros_like(values)
+ else:
+ mask = numpy.zeros_like(values)
+
+ if values is not None and len(values) > 0:
+ self.values = values
+ axes = [item.getXData(copy=False), item.getYData(copy=False)]
+ if self.values.ndim == 3:
+ axes.append(item.getZData(copy=False))
+ self.axes = tuple(axes)
+ self.min, self.max = min_max(self.values)
+ self.values = numpy.ma.array(self.values, mask=mask)
+ else:
+ self.values = None
+ self.axes = None
+ self.min, self.max = None, None
+
+ def _checkContextInputs(self, item, plot, onlimits, roi):
+ _StatsContext._checkContextInputs(
+ self, item=item, plot=plot, onlimits=onlimits, roi=roi
+ )
+
+ if roi is not None and not isinstance(roi, RegionOfInterest):
+ raise TypeError("curve `context` can ony manage 2D roi")
+
+
+class _plot3DArrayContext(_StatsContext):
+ """StatsContext for 3D scalar field and data image.
+
+ It supports :class:`~silx.gui.plot3d.items.ScalarField3D` and
+ :class:`~silx.gui.plot3d.items.ImageData`.
+
+ :param item: the item for which we want to compute the context
+ :param plot: the plot containing the item
+ :param bool onlimits: True if we want to apply statistic only on
+ visible data.
+ :param roi: Region of interest for computing the statistics.
+ For now, incompatible with `onlinits` calculation
+ :type roi: Union[None, :class:`ROI`]
+ """
+
+ def __init__(self, item, plot, onlimits, roi):
+ _StatsContext.__init__(
+ self, kind="image", item=item, plot=plot, onlimits=onlimits, roi=roi
+ )
+
+ @docstring(_StatsContext)
+ def clipData(self, item, plot, onlimits, roi):
+ self._checkContextInputs(item=item, plot=plot, onlimits=onlimits, roi=roi)
+ if onlimits:
+ raise RuntimeError("Unsupported plot %s" % str(plot))
+
+ values = item.getData(copy=False)
+ if roi:
+ logger.warning("Roi are unsuported on volume for now")
+ mask = numpy.zeros_like(values)
+ else:
+ mask = numpy.zeros_like(values)
+
+ if values is not None and len(values) > 0:
+ self.values = values
+ self.axes = tuple([numpy.arange(size) for size in self.values.shape])
+ self.min, self.max = min_max(self.values)
+ self.values = numpy.ma.array(self.values, mask=mask)
+ else:
+ self.values = None
+ self.axes = None
+ self.min, self.max = None, None
+
+ def _checkContextInputs(self, item, plot, onlimits, roi):
+ _StatsContext._checkContextInputs(
+ self, item=item, plot=plot, onlimits=onlimits, roi=roi
+ )
+
+ if roi is not None and not isinstance(roi, RegionOfInterest):
+ raise TypeError("curve `context` can ony manage 2D roi")
+
+
+BASIC_COMPATIBLE_KINDS = "curve", "image", "scatter", "histogram"
+
+
+class StatBase(object):
+ """
+ Base class for defining a statistic.
+
+ :param str name: the name of the statistic. Must be unique.
+ :param List[str] compatibleKinds:
+ The kind of items (curve, scatter...) for which the statistic apply.
+ """
+
+ def __init__(self, name, compatibleKinds=BASIC_COMPATIBLE_KINDS, description=None):
+ self.name = name
+ self.compatibleKinds = compatibleKinds
+ self.description = description
+
+ def calculate(self, context):
+ """
+ compute the statistic for the given :class:`StatsContext`
+
+ :param _StatsContext context:
+ :return dict: key is stat name, statistic computed is the dict value
+ """
+ raise NotImplementedError("Base class")
+
+ def getToolTip(self, kind):
+ """
+ If necessary add a tooltip for a stat kind
+
+ :param str kind: the kind of item the statistic is compute for.
+ :return: tooltip or None if no tooltip
+ """
+ return None
+
+
+class Stat(StatBase):
+ """
+ Create a StatBase class based on a function pointer.
+
+ :param str name: name of the statistic. Used as id
+ :param fct: function which should have as unique mandatory parameter the
+ data. Should be able to adapt to all `kinds` defined as
+ compatible
+ :param tuple kinds: the compatible item kinds of the function (curve,
+ image...)
+ """
+
+ def __init__(self, name, fct, kinds=BASIC_COMPATIBLE_KINDS):
+ StatBase.__init__(self, name, kinds)
+ self._fct = fct
+
+ @docstring(StatBase)
+ def calculate(self, context):
+ if context.values is not None:
+ if context.kind in self.compatibleKinds:
+ return self._fct(context.values)
+ else:
+ raise ValueError(
+ "Kind %s not managed by %s" "" % (context.kind, self.name)
+ )
+ else:
+ return None
+
+
+class StatMin(StatBase):
+ """Compute the minimal value on data"""
+
+ def __init__(self):
+ StatBase.__init__(self, name="min")
+
+ @docstring(StatBase)
+ def calculate(self, context):
+ return context.min
+
+
+class StatMax(StatBase):
+ """Compute the maximal value on data"""
+
+ def __init__(self):
+ StatBase.__init__(self, name="max")
+
+ @docstring(StatBase)
+ def calculate(self, context):
+ return context.max
+
+
+class StatDelta(StatBase):
+ """Compute the delta between minimal and maximal on data"""
+
+ def __init__(self):
+ StatBase.__init__(self, name="delta")
+
+ @docstring(StatBase)
+ def calculate(self, context):
+ return context.max - context.min
+
+
+class _StatCoord(StatBase):
+ """Base class for argmin and argmax stats"""
+
+ def _indexToCoordinates(self, context, index):
+ """Returns the coordinates of data point at given index
+
+ If data is an array, coordinates are in reverse order from data shape.
+
+ :param _StatsContext context:
+ :param int index: Index in the flattened data array
+ :rtype: List[int]
+ """
+
+ axes = context.axes
+
+ if context.isStructuredData() or context.roi:
+ coordinates = []
+ for axis in reversed(axes):
+ coordinates.append(axis[index % len(axis)])
+ index = index // len(axis)
+ return tuple(coordinates)
+ else:
+ return tuple(axis[index] for axis in axes)
+
+
+class StatCoordMin(_StatCoord):
+ """Compute the coordinates of the first minimum value of the data"""
+
+ def __init__(self):
+ _StatCoord.__init__(self, name="coords min")
+
+ @docstring(StatBase)
+ def calculate(self, context):
+ if context.values is None or not context.isScalarData():
+ return None
+
+ index = context.values.argmin()
+ return self._indexToCoordinates(context, index)
+
+ @docstring(StatBase)
+ def getToolTip(self, kind):
+ return "Coordinates of the first minimum value of the data"
+
+
+class StatCoordMax(_StatCoord):
+ """Compute the coordinates of the first maximum value of the data"""
+
+ def __init__(self):
+ _StatCoord.__init__(self, name="coords max")
+
+ @docstring(StatBase)
+ def calculate(self, context):
+ if context.values is None or not context.isScalarData():
+ return None
+
+ # TODO: the values should be a mask array by default, will be simpler
+ # if possible
+ index = context.values.argmax()
+ return self._indexToCoordinates(context, index)
+
+ @docstring(StatBase)
+ def getToolTip(self, kind):
+ return "Coordinates of the first maximum value of the data"
+
+
+class StatCOM(StatBase):
+ """Compute data center of mass"""
+
+ def __init__(self):
+ StatBase.__init__(self, name="COM", description="Center of mass")
+
+ @docstring(StatBase)
+ def calculate(self, context):
+ if context.values is None or not context.isScalarData():
+ return None
+
+ values = numpy.ma.array(context.values, mask=context.mask, dtype=numpy.float64)
+ sum_ = numpy.sum(values)
+ if sum_ == 0.0 or numpy.ma.is_masked(sum_):
+ return (numpy.nan,) * len(context.axes)
+
+ if context.isStructuredData():
+ centerofmass = []
+ for index, axis in enumerate(context.axes):
+ axes = tuple([i for i in range(len(context.axes)) if i != index])
+ centerofmass.append(
+ numpy.sum(axis * numpy.sum(values, axis=axes)) / sum_
+ )
+ return tuple(reversed(centerofmass))
+ else:
+ return tuple(numpy.sum(axis * values) / sum_ for axis in context.axes)
+
+ @docstring(StatBase)
+ def getToolTip(self, kind):
+ return "Compute the center of mass of the dataset"
diff --git a/src/silx/gui/plot/stats/statshandler.py b/src/silx/gui/plot/stats/statshandler.py
new file mode 100644
index 0000000..8e7e08b
--- /dev/null
+++ b/src/silx/gui/plot/stats/statshandler.py
@@ -0,0 +1,222 @@
+# /*##########################################################################
+#
+# Copyright (c) 2017-2022 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""This module containts the classes relative to the management of statistics
+display.
+"""
+
+__authors__ = ["H. Payno"]
+__license__ = "MIT"
+__date__ = "05/06/2018"
+
+
+import logging
+import numbers
+
+import numpy
+
+from silx.gui import qt
+from silx.gui.plot import stats as statsmdl
+
+logger = logging.getLogger(__name__)
+
+
+class _FloatItem(qt.QTableWidgetItem):
+ """Simple QTableWidgetItem allowing ordering on floats"""
+
+ def __init__(self, type=qt.QTableWidgetItem.Type):
+ qt.QTableWidgetItem.__init__(self, type=type)
+
+ def __lt__(self, other):
+ self_values = self.text().lstrip("(").rstrip(")").split(",")
+ other_values = other.text().lstrip("(").rstrip(")").split(",")
+ for self_value, other_value in zip(self_values, other_values):
+ f_self_value = float(self_value)
+ f_other_value = float(other_value)
+ if f_self_value != f_other_value:
+ return f_self_value < f_other_value
+ return False
+
+
+class StatFormatter(object):
+ """
+ Class used to apply format on :class:`Stat`
+
+ :param formatter: the formatter. Defined as str.format()
+ :param qItemClass: the class inheriting from :class:`QTableWidgetItem`
+ which will be used to display the result of the
+ statistic computation.
+ """
+
+ DEFAULT_FORMATTER = "{0:.3f}"
+
+ def __init__(self, formatter=DEFAULT_FORMATTER, qItemClass=_FloatItem):
+ self.formatter = formatter
+ self.tabWidgetItemClass = qItemClass
+
+ def format(self, val):
+ if val is None or numpy.ma.is_masked(val):
+ return "--"
+
+ if self.formatter is not None and isinstance(val, numbers.Number):
+ return self.formatter.format(val)
+
+ return str(val)
+
+
+class StatsHandler(object):
+ """
+ Give
+ create:
+
+ * Stats object which will manage the statistic computation
+ * Associate formatter and :class:`Stat`
+
+ :param statFormatters: Stat and optional formatter.
+ If elements are given as a tuple, elements
+ should be (:class:`Stat`, formatter).
+ Otherwise should be :class:`Stat` elements.
+ :rtype: List or tuple
+ """
+
+ def __init__(self, statFormatters):
+ self.stats = statsmdl.Stats()
+ self.formatters = {}
+ for elmt in statFormatters:
+ stat, formatter = self._processStatArgument(elmt)
+ self.add(stat=stat, formatter=formatter)
+
+ @staticmethod
+ def _processStatArgument(arg):
+ """Process an element of the init arguments
+
+ :param arg: The argument to process
+ :return: Corresponding (StatBase, StatFormatter)
+ """
+ stat, formatter = None, None
+
+ if isinstance(arg, statsmdl.StatBase):
+ stat = arg
+ else:
+ assert len(arg) > 0
+ if isinstance(arg[0], statsmdl.StatBase):
+ stat = arg[0]
+ if len(arg) > 2:
+ raise ValueError(
+ "To many argument with %s. At most one "
+ "argument can be associated with the "
+ "BaseStat (the `StatFormatter`"
+ )
+ if len(arg) == 2:
+ assert arg[1] is None or isinstance(arg[1], (StatFormatter, str))
+ formatter = arg[1]
+ else:
+ if isinstance(arg[0], tuple):
+ if len(arg) > 1:
+ formatter = arg[1]
+ arg = arg[0]
+
+ if type(arg[0]) is not str:
+ raise ValueError(
+ "first element of the tuple should be a string"
+ " or a StatBase instance"
+ )
+ if len(arg) == 1:
+ raise ValueError(
+ "A function should be associated with the" "stat name"
+ )
+ if len(arg) > 3:
+ raise ValueError(
+ "Two much argument given for defining statistic."
+ "Take at most three arguments (name, function, "
+ "kinds)"
+ )
+ if len(arg) == 2:
+ stat = statsmdl.Stat(name=arg[0], fct=arg[1])
+ else:
+ stat = statsmdl.Stat(name=arg[0], fct=arg[1], kinds=arg[2])
+
+ return stat, formatter
+
+ def add(self, stat, formatter=None):
+ """Add a stat to the list.
+
+ :param StatBase stat:
+ :param Union[None,StatFormatter] formatter:
+ """
+ assert isinstance(stat, statsmdl.StatBase)
+ self.stats.add(stat)
+ _formatter = formatter
+ if type(_formatter) is str:
+ _formatter = StatFormatter(formatter=_formatter)
+ self.formatters[stat.name] = _formatter
+
+ def format(self, name, val):
+ """Apply the format for the `name` statistic and the given value
+
+ :param str name: the name of the associated statistic
+ :param val: value before formatting
+ :return: formatted value
+ """
+ if name not in self.formatters:
+ logger.warning("statistic %s haven't been registred" % name)
+ return val
+ else:
+ if self.formatters[name] is None:
+ return str(val)
+ else:
+ if isinstance(val, (tuple, list)):
+ res = []
+ [res.append(self.formatters[name].format(_val)) for _val in val]
+ return ", ".join(res)
+ else:
+ return self.formatters[name].format(val)
+
+ def calculate(
+ self, item, plot, onlimits, roi=None, data_changed=False, roi_changed=False
+ ):
+ """
+ compute all statistic registered and return the list of formatted
+ statistics result.
+
+ :param item: item for which we want to compute statistics
+ :param plot: plot containing the item
+ :param onlimits: True if we want to compute statistics on visible data
+ only
+ :type: bool
+ :param roi: region of interest for statistic calculation
+ :type: Union[None,:class:`_RegionOfInterestBase`]
+ :return: list of formatted statistics (as str)
+ :rtype: dict
+ """
+ res = self.stats.calculate(
+ item,
+ plot,
+ onlimits,
+ roi,
+ data_changed=data_changed,
+ roi_changed=roi_changed,
+ )
+ for resName, resValue in list(res.items()):
+ res[resName] = self.format(resName, res[resName])
+ return res
diff --git a/src/silx/gui/plot/test/__init__.py b/src/silx/gui/plot/test/__init__.py
new file mode 100644
index 0000000..78821ec
--- /dev/null
+++ b/src/silx/gui/plot/test/__init__.py
@@ -0,0 +1,23 @@
+# /*##########################################################################
+#
+# Copyright (c) 2016-2018 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
diff --git a/src/silx/gui/plot/test/conftest.py b/src/silx/gui/plot/test/conftest.py
new file mode 100644
index 0000000..78475fb
--- /dev/null
+++ b/src/silx/gui/plot/test/conftest.py
@@ -0,0 +1,43 @@
+# /*##########################################################################
+#
+# Copyright (c) 2023 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""Test PlotWidget active item"""
+
+__authors__ = ["T. Vincent"]
+__license__ = "MIT"
+__date__ = "13/12/2023"
+
+
+import pytest
+from silx.gui.plot import PlotWidget
+
+
+@pytest.fixture
+def plotWidget(qWidgetFactory, request):
+ try:
+ backend = request.param
+ except AttributeError:
+ backend = "mpl" # Backend was not defined
+ if backend == "gl":
+ request.getfixturevalue("use_opengl") # Skip test if OpenGL test disabled
+ yield qWidgetFactory(PlotWidget, backend=backend)
diff --git a/src/silx/gui/plot/test/testAlphaSlider.py b/src/silx/gui/plot/test/testAlphaSlider.py
new file mode 100644
index 0000000..e9ccb45
--- /dev/null
+++ b/src/silx/gui/plot/test/testAlphaSlider.py
@@ -0,0 +1,189 @@
+# /*##########################################################################
+#
+# Copyright (c) 2017-2019 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""Tests for ImageAlphaSlider"""
+
+
+__authors__ = ["P. Knobel"]
+__license__ = "MIT"
+__date__ = "28/03/2017"
+
+import numpy
+
+from silx.gui import qt
+from silx.gui.utils.testutils import TestCaseQt
+from silx.gui.plot import PlotWidget
+from silx.gui.plot import AlphaSlider
+
+
+class TestActiveImageAlphaSlider(TestCaseQt):
+ def setUp(self):
+ super(TestActiveImageAlphaSlider, self).setUp()
+ self.plot = PlotWidget()
+ self.aslider = AlphaSlider.ActiveImageAlphaSlider(plot=self.plot)
+ self.aslider.setOrientation(qt.Qt.Horizontal)
+
+ toolbar = qt.QToolBar("plot", self.plot)
+ toolbar.addWidget(self.aslider)
+ self.plot.addToolBar(toolbar)
+
+ self.plot.show()
+ self.qWaitForWindowExposed(self.plot)
+
+ self.mouseMove(self.plot) # Move to center
+ self.qapp.processEvents()
+
+ def tearDown(self):
+ self.qapp.processEvents()
+ self.plot.setAttribute(qt.Qt.WA_DeleteOnClose)
+ self.plot.close()
+ del self.plot
+ del self.aslider
+
+ super(TestActiveImageAlphaSlider, self).tearDown()
+
+ def testWidgetEnabled(self):
+ # no active image initially, slider must be deactivate
+ self.assertFalse(self.aslider.isEnabled())
+
+ self.plot.addImage(numpy.array([[0, 1, 2], [3, 4, 5]]))
+ # now we have an active image
+ self.assertTrue(self.aslider.isEnabled())
+
+ self.plot.setActiveImage(None)
+ self.assertFalse(self.aslider.isEnabled())
+
+ def testGetImage(self):
+ self.plot.addImage(numpy.array([[0, 1, 2], [3, 4, 5]]))
+ self.assertEqual(self.plot.getActiveImage(), self.aslider.getItem())
+
+ self.plot.addImage(numpy.array([[0, 1, 3], [2, 4, 6]]), legend="2")
+ self.plot.setActiveImage("2")
+ self.assertEqual(self.plot.getImage("2"), self.aslider.getItem())
+
+ def testGetAlpha(self):
+ self.plot.addImage(numpy.array([[0, 1, 2], [3, 4, 5]]), legend="1")
+ self.aslider.setValue(137)
+ self.assertAlmostEqual(self.aslider.getAlpha(), 137.0 / 255)
+
+
+class TestNamedImageAlphaSlider(TestCaseQt):
+ def setUp(self):
+ super(TestNamedImageAlphaSlider, self).setUp()
+ self.plot = PlotWidget()
+ self.aslider = AlphaSlider.NamedImageAlphaSlider(plot=self.plot)
+ self.aslider.setOrientation(qt.Qt.Horizontal)
+
+ toolbar = qt.QToolBar("plot", self.plot)
+ toolbar.addWidget(self.aslider)
+ self.plot.addToolBar(toolbar)
+
+ self.plot.show()
+ self.qWaitForWindowExposed(self.plot)
+
+ self.mouseMove(self.plot) # Move to center
+ self.qapp.processEvents()
+
+ def tearDown(self):
+ self.qapp.processEvents()
+ self.plot.setAttribute(qt.Qt.WA_DeleteOnClose)
+ self.plot.close()
+ del self.plot
+ del self.aslider
+
+ super(TestNamedImageAlphaSlider, self).tearDown()
+
+ def testWidgetEnabled(self):
+ # no image set initially, slider must be deactivate
+ self.assertFalse(self.aslider.isEnabled())
+
+ self.plot.addImage(numpy.array([[0, 1, 2], [3, 4, 5]]), legend="1")
+ self.aslider.setLegend("1")
+ # now we have an image set
+ self.assertTrue(self.aslider.isEnabled())
+
+ def testGetImage(self):
+ self.plot.addImage(numpy.array([[0, 1, 2], [3, 4, 5]]), legend="1")
+ self.plot.addImage(numpy.array([[0, 1, 3], [2, 4, 6]]), legend="2")
+ self.aslider.setLegend("1")
+ self.assertEqual(self.plot.getImage("1"), self.aslider.getItem())
+
+ self.aslider.setLegend("2")
+ self.assertEqual(self.plot.getImage("2"), self.aslider.getItem())
+
+ def testGetAlpha(self):
+ self.plot.addImage(numpy.array([[0, 1, 2], [3, 4, 5]]), legend="1")
+ self.aslider.setLegend("1")
+ self.aslider.setValue(128)
+ self.assertAlmostEqual(self.aslider.getAlpha(), 128.0 / 255)
+
+
+class TestNamedScatterAlphaSlider(TestCaseQt):
+ def setUp(self):
+ super(TestNamedScatterAlphaSlider, self).setUp()
+ self.plot = PlotWidget()
+ self.aslider = AlphaSlider.NamedScatterAlphaSlider(plot=self.plot)
+ self.aslider.setOrientation(qt.Qt.Horizontal)
+
+ toolbar = qt.QToolBar("plot", self.plot)
+ toolbar.addWidget(self.aslider)
+ self.plot.addToolBar(toolbar)
+
+ self.plot.show()
+ self.qWaitForWindowExposed(self.plot)
+
+ self.mouseMove(self.plot) # Move to center
+ self.qapp.processEvents()
+
+ def tearDown(self):
+ self.qapp.processEvents()
+ self.plot.setAttribute(qt.Qt.WA_DeleteOnClose)
+ self.plot.close()
+ del self.plot
+ del self.aslider
+
+ super(TestNamedScatterAlphaSlider, self).tearDown()
+
+ def testWidgetEnabled(self):
+ # no Scatter set initially, slider must be deactivate
+ self.assertFalse(self.aslider.isEnabled())
+
+ self.plot.addScatter([0, 1, 2], [2, 3, 4], [5, 6, 7], legend="1")
+ self.aslider.setLegend("1")
+ # now we have an image set
+ self.assertTrue(self.aslider.isEnabled())
+
+ def testGetScatter(self):
+ self.plot.addScatter([0, 1, 2], [2, 3, 4], [5, 6, 7], legend="1")
+ self.plot.addScatter([0, 10, 20], [20, 30, 40], [50, 60, 70], legend="2")
+ self.aslider.setLegend("1")
+ self.assertEqual(self.plot.getScatter("1"), self.aslider.getItem())
+
+ self.aslider.setLegend("2")
+ self.assertEqual(self.plot.getScatter("2"), self.aslider.getItem())
+
+ def testGetAlpha(self):
+ self.plot.addScatter([0, 10, 20], [20, 30, 40], [50, 60, 70], legend="1")
+ self.aslider.setLegend("1")
+ self.aslider.setValue(128)
+ self.assertAlmostEqual(self.aslider.getAlpha(), 128.0 / 255)
diff --git a/src/silx/gui/plot/test/testAxis.py b/src/silx/gui/plot/test/testAxis.py
new file mode 100644
index 0000000..dcf2f06
--- /dev/null
+++ b/src/silx/gui/plot/test/testAxis.py
@@ -0,0 +1,147 @@
+# /*##########################################################################
+#
+# Copyright (c) 2023 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""Tests of PlotWidget Axis items"""
+
+__authors__ = ["T. Vincent"]
+__license__ = "MIT"
+__date__ = "15/06/2023"
+
+
+from silx.gui.plot import PlotWidget
+
+
+def testAxisIsVisible(qapp, qWidgetFactory):
+ """Test Axis.isVisible method"""
+ plotWidget = qWidgetFactory(PlotWidget)
+
+ assert plotWidget.getXAxis().isVisible()
+ assert plotWidget.getYAxis().isVisible()
+ assert not plotWidget.getYAxis("right").isVisible()
+
+ # Add curve on right axis
+ plotWidget.addCurve((0, 1, 2), (1, 2, 3), yaxis="right")
+ qapp.processEvents()
+
+ assert plotWidget.getYAxis("right").isVisible()
+
+ # hide curve on right axis
+ curve = plotWidget.getItems()[0]
+ curve.setVisible(False)
+ qapp.processEvents()
+
+ assert not plotWidget.getYAxis("right").isVisible()
+
+ # show curve on right axis
+ curve.setVisible(True)
+ qapp.processEvents()
+
+ assert plotWidget.getYAxis("right").isVisible()
+
+ # Move curve to left axis
+ curve.setYAxis("left")
+ qapp.processEvents()
+
+ assert not plotWidget.getYAxis("right").isVisible()
+
+
+def testAxisSetScaleLogNoData(qapp, qWidgetFactory):
+ """Test Axis.setScale('log') method with an empty plot
+
+ Limits are reset only when negative
+ """
+ plotWidget = qWidgetFactory(PlotWidget)
+ xaxis = plotWidget.getXAxis()
+ yaxis = plotWidget.getYAxis()
+ y2axis = plotWidget.getYAxis("right")
+
+ xaxis.setLimits(-1.0, 1.0)
+ yaxis.setLimits(2.0, 3.0)
+ y2axis.setLimits(-2.0, -1.0)
+
+ xaxis.setScale("log")
+ qapp.processEvents()
+
+ assert xaxis.getLimits() == (1.0, 100.0)
+ assert yaxis.getLimits() == (2.0, 3.0)
+ assert y2axis.getLimits() == (-2.0, -1.0)
+
+ xaxis.setLimits(10.0, 20.0)
+
+ yaxis.setScale("log")
+ qapp.processEvents()
+
+ assert xaxis.getLimits() == (10.0, 20.0)
+ assert yaxis.getLimits() == (2.0, 3.0) # Positive range is preserved
+ assert y2axis.getLimits() == (1.0, 100.0) # Negative min is reset
+
+
+def testAxisSetScaleLogWithData(qapp, qWidgetFactory):
+ """Test Axis.setScale('log') method with data
+
+ Limits are reset only when negative and takes the data range into account
+ """
+ plotWidget = qWidgetFactory(PlotWidget)
+ xaxis = plotWidget.getXAxis()
+ yaxis = plotWidget.getYAxis()
+ plotWidget.addCurve((-1, 1, 2, 3), (-1, 1, 2, 3))
+
+ xaxis.setLimits(-1.0, 0.5) # Limits contains no positive data
+ yaxis.setLimits(-1.0, 2.0) # Limits contains positive data
+
+ xaxis.setScale("log")
+ yaxis.setScale("log")
+ qapp.processEvents()
+
+ assert xaxis.getLimits() == (1.0, 3.0) # Reset to positive data range
+ assert yaxis.getLimits() == (1.0, 2.0) # Keep max limit
+
+
+def testAxisSetScaleLinear(qapp, qWidgetFactory):
+ """Test Axis.setScale('linear') method: Limits are not changed"""
+ plotWidget = qWidgetFactory(PlotWidget)
+ xaxis = plotWidget.getXAxis()
+ yaxis = plotWidget.getYAxis()
+ y2axis = plotWidget.getYAxis("right")
+ xaxis.setScale("log")
+ yaxis.setScale("log")
+ plotWidget.resetZoom()
+ qapp.processEvents()
+
+ xaxis.setLimits(10.0, 1000.0)
+ yaxis.setLimits(20.0, 2000.0)
+ y2axis.setLimits(30.0, 3000.0)
+
+ xaxis.setScale("linear")
+ qapp.processEvents()
+
+ assert xaxis.getLimits() == (10.0, 1000.0)
+ assert yaxis.getLimits() == (20.0, 2000.0)
+ assert y2axis.getLimits() == (30.0, 3000.0)
+
+ yaxis.setScale("linear")
+ qapp.processEvents()
+
+ assert xaxis.getLimits() == (10.0, 1000.0)
+ assert yaxis.getLimits() == (20.0, 2000.0)
+ assert y2axis.getLimits() == (30.0, 3000.0)
diff --git a/src/silx/gui/plot/test/testColorBar.py b/src/silx/gui/plot/test/testColorBar.py
new file mode 100644
index 0000000..7202bc2
--- /dev/null
+++ b/src/silx/gui/plot/test/testColorBar.py
@@ -0,0 +1,317 @@
+# /*##########################################################################
+#
+# Copyright (c) 2016-2021 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""Basic tests for ColorBar featues and sub widgets of Colorbar module"""
+
+__authors__ = ["H. Payno"]
+__license__ = "MIT"
+__date__ = "24/04/2018"
+
+from silx.gui.utils.testutils import TestCaseQt
+from silx.gui.plot.ColorBar import _ColorScale
+from silx.gui.plot.ColorBar import ColorBarWidget
+from silx.gui.colors import Colormap
+from silx.math.colormap import LinearNormalization, LogarithmicNormalization
+from silx.gui.plot import Plot2D
+from silx.gui import qt
+import numpy
+
+
+class TestColorScale(TestCaseQt):
+ """Test that interaction with the colorScale is correct"""
+
+ def setUp(self):
+ super(TestColorScale, self).setUp()
+ self.colorScaleWidget = _ColorScale(colormap=None, parent=None)
+ self.colorScaleWidget.show()
+ self.qWaitForWindowExposed(self.colorScaleWidget)
+
+ def tearDown(self):
+ self.qapp.processEvents()
+ self.colorScaleWidget.setAttribute(qt.Qt.WA_DeleteOnClose)
+ self.colorScaleWidget.close()
+ del self.colorScaleWidget
+ super(TestColorScale, self).tearDown()
+
+ def testNoColormap(self):
+ """Test _ColorScale without a colormap"""
+ colormap = self.colorScaleWidget.getColormap()
+ self.assertIsNone(colormap)
+
+ def testRelativePositionLinear(self):
+ self.colorMapLin1 = Colormap(
+ name="gray", normalization=Colormap.LINEAR, vmin=0.0, vmax=1.0
+ )
+ self.colorScaleWidget.setColormap(self.colorMapLin1)
+
+ self.assertTrue(
+ self.colorScaleWidget.getValueFromRelativePosition(0.25) == 0.25
+ )
+ self.assertTrue(self.colorScaleWidget.getValueFromRelativePosition(0.5) == 0.5)
+ self.assertTrue(self.colorScaleWidget.getValueFromRelativePosition(1.0) == 1.0)
+
+ self.colorMapLin2 = Colormap(
+ name="viridis", normalization=Colormap.LINEAR, vmin=-10, vmax=0
+ )
+ self.colorScaleWidget.setColormap(self.colorMapLin2)
+
+ self.assertTrue(
+ self.colorScaleWidget.getValueFromRelativePosition(0.25) == -7.5
+ )
+ self.assertTrue(self.colorScaleWidget.getValueFromRelativePosition(0.5) == -5.0)
+ self.assertTrue(self.colorScaleWidget.getValueFromRelativePosition(1.0) == 0.0)
+
+ def testRelativePositionLog(self):
+ self.colorMapLog1 = Colormap(
+ name="temperature", normalization=Colormap.LOGARITHM, vmin=1.0, vmax=100.0
+ )
+
+ self.colorScaleWidget.setColormap(self.colorMapLog1)
+
+ val = self.colorScaleWidget.getValueFromRelativePosition(1.0)
+ self.assertAlmostEqual(val, 100.0)
+
+ val = self.colorScaleWidget.getValueFromRelativePosition(0.5)
+ self.assertAlmostEqual(val, 10.0)
+
+ val = self.colorScaleWidget.getValueFromRelativePosition(0.0)
+ self.assertTrue(val == 1.0)
+
+
+class TestNoAutoscale(TestCaseQt):
+ """Test that ticks and color displayed are correct in the case of a colormap
+ with no autoscale
+ """
+
+ def setUp(self):
+ super(TestNoAutoscale, self).setUp()
+ self.plot = Plot2D()
+ self.colorBar = self.plot.getColorBarWidget()
+ self.colorBar.setVisible(True) # Makes sure the colormap is visible
+ self.tickBar = self.colorBar.getColorScaleBar().getTickBar()
+ self.colorScale = self.colorBar.getColorScaleBar().getColorScale()
+
+ self.plot.show()
+ self.qWaitForWindowExposed(self.plot)
+
+ def tearDown(self):
+ self.qapp.processEvents()
+ self.tickBar = None
+ self.colorScale = None
+ del self.colorBar
+ self.plot.setAttribute(qt.Qt.WA_DeleteOnClose)
+ self.plot.close()
+ del self.plot
+ super(TestNoAutoscale, self).tearDown()
+
+ def testLogNormNoAutoscale(self):
+ colormapLog = Colormap(
+ name="gray", normalization=Colormap.LOGARITHM, vmin=1.0, vmax=100.0
+ )
+
+ data = numpy.linspace(10, 1e10, 9).reshape(3, 3)
+ self.plot.addImage(data=data, colormap=colormapLog, legend="toto")
+ self.plot.setActiveImage("toto")
+
+ # test Ticks
+ self.tickBar.setTicksNumber(10)
+ self.tickBar.computeTicks()
+
+ ticksTh = numpy.linspace(1.0, 100.0, 10)
+ ticksTh = 10**ticksTh
+ numpy.array_equal(self.tickBar.ticks, ticksTh)
+
+ # test ColorScale
+ val = self.colorScale.getValueFromRelativePosition(1.0)
+ self.assertAlmostEqual(val, 100.0)
+
+ val = self.colorScale.getValueFromRelativePosition(0.0)
+ self.assertTrue(val == 1.0)
+
+ def testLinearNormNoAutoscale(self):
+ colormapLog = Colormap(
+ name="gray", normalization=Colormap.LINEAR, vmin=-4, vmax=5
+ )
+
+ data = numpy.linspace(1, 9, 9).reshape(3, 3)
+ self.plot.addImage(data=data, colormap=colormapLog, legend="toto")
+ self.plot.setActiveImage("toto")
+
+ # test Ticks
+ self.tickBar.setTicksNumber(10)
+ self.tickBar.computeTicks()
+
+ numpy.array_equal(self.tickBar.ticks, numpy.linspace(-4, 5, 10))
+
+ # test ColorScale
+ val = self.colorScale.getValueFromRelativePosition(1.0)
+ self.assertTrue(val == 5.0)
+
+ val = self.colorScale.getValueFromRelativePosition(0.0)
+ self.assertTrue(val == -4.0)
+
+
+class TestColorBarWidget(TestCaseQt):
+ """Test interaction with the ColorBarWidget"""
+
+ def setUp(self):
+ super(TestColorBarWidget, self).setUp()
+ self.plot = Plot2D()
+ self.colorBar = self.plot.getColorBarWidget()
+ self.colorBar.setVisible(True) # Makes sure the colormap is visible
+
+ self.plot.show()
+ self.qWaitForWindowExposed(self.plot)
+
+ def tearDown(self):
+ self.qapp.processEvents()
+ del self.colorBar
+ self.plot.setAttribute(qt.Qt.WA_DeleteOnClose)
+ self.plot.close()
+ del self.plot
+ super(TestColorBarWidget, self).tearDown()
+
+ def testEmptyColorBar(self):
+ colorBar = ColorBarWidget(parent=None)
+ colorBar.show()
+ self.qWaitForWindowExposed(colorBar)
+
+ def testNegativeColormaps(self):
+ """test the behavior of the ColorBarWidget in the case of negative
+ values
+
+ Note : colorbar is modified by the Plot directly not ColorBarWidget
+ """
+ colormapLog = Colormap(
+ name="gray", normalization=Colormap.LOGARITHM, vmin=None, vmax=None
+ )
+
+ data = numpy.array([-5, -4, 0, 2, 3, 5, 10, 20, 30])
+ data = data.reshape(3, 3)
+ self.plot.addImage(data=data, colormap=colormapLog, legend="toto")
+ self.plot.setActiveImage("toto")
+
+ # default behavior when with log and negative values: should set vmin
+ # to 1 and vmax to 10
+ self.assertTrue(self.colorBar.getColorScaleBar().minVal == 2)
+ self.assertTrue(self.colorBar.getColorScaleBar().maxVal == 30)
+
+ # if data is positive
+ data[data < 1] = data.max()
+ self.plot.addImage(data=data, colormap=colormapLog, legend="toto", replace=True)
+ self.plot.setActiveImage("toto")
+
+ self.assertTrue(self.colorBar.getColorScaleBar().minVal == data.min())
+ self.assertTrue(self.colorBar.getColorScaleBar().maxVal == data.max())
+
+ def testPlotAssocation(self):
+ """Make sure the ColorBarWidget is properly connected with the plot"""
+ colormap = Colormap(
+ name="gray", normalization=Colormap.LINEAR, vmin=None, vmax=None
+ )
+
+ # make sure that default settings are the same (but a copy of the
+ self.colorBar.setPlot(self.plot)
+ self.assertTrue(self.colorBar.getColormap() is self.plot.getDefaultColormap())
+
+ data = numpy.linspace(0, 10, 100).reshape(10, 10)
+ self.plot.addImage(data=data, colormap=colormap, legend="toto")
+ self.plot.setActiveImage("toto")
+
+ # make sure the modification of the colormap has been done
+ self.assertFalse(self.colorBar.getColormap() is self.plot.getDefaultColormap())
+ self.assertTrue(self.colorBar.getColormap() is colormap)
+
+ # test that colorbar is updated when default plot colormap changes
+ self.plot.clear()
+ plotColormap = Colormap(
+ name="gray", normalization=Colormap.LOGARITHM, vmin=None, vmax=None
+ )
+ self.plot.setDefaultColormap(plotColormap)
+ self.assertTrue(self.colorBar.getColormap() is plotColormap)
+
+ def testColormapWithoutRange(self):
+ """Test with a colormap with vmin==vmax"""
+ colormap = Colormap(
+ name="gray", normalization=Colormap.LINEAR, vmin=1.0, vmax=1.0
+ )
+ self.colorBar.setColormap(colormap)
+
+
+class TestColorBarUpdate(TestCaseQt):
+ """Test that the ColorBar is correctly updated when the signal 'sigChanged'
+ of the colormap is emitted
+ """
+
+ def setUp(self):
+ super(TestColorBarUpdate, self).setUp()
+ self.plot = Plot2D()
+ self.colorBar = self.plot.getColorBarWidget()
+ self.colorBar.setVisible(True) # Makes sure the colormap is visible
+ self.colorBar.setPlot(self.plot)
+
+ self.plot.show()
+ self.qWaitForWindowExposed(self.plot)
+ self.data = numpy.random.rand(9).reshape(3, 3)
+
+ def tearDown(self):
+ self.qapp.processEvents()
+ del self.colorBar
+ self.plot.setAttribute(qt.Qt.WA_DeleteOnClose)
+ self.plot.close()
+ del self.plot
+ super(TestColorBarUpdate, self).tearDown()
+
+ def testUpdateColorMap(self):
+ colormap = Colormap(name="gray", normalization="linear", vmin=0, vmax=1)
+
+ # check inital state
+ self.plot.addImage(data=self.data, colormap=colormap, legend="toto")
+ self.plot.setActiveImage("toto")
+
+ self.assertTrue(self.colorBar.getColorScaleBar().minVal == 0)
+ self.assertTrue(self.colorBar.getColorScaleBar().maxVal == 1)
+ self.assertTrue(self.colorBar.getColorScaleBar().getTickBar()._vmin == 0)
+ self.assertTrue(self.colorBar.getColorScaleBar().getTickBar()._vmax == 1)
+ self.assertIsInstance(
+ self.colorBar.getColorScaleBar().getTickBar()._normalizer,
+ LinearNormalization,
+ )
+
+ # update colormap
+ colormap.setVMin(0.5)
+ self.assertTrue(self.colorBar.getColorScaleBar().minVal == 0.5)
+ self.assertTrue(self.colorBar.getColorScaleBar().getTickBar()._vmin == 0.5)
+
+ colormap.setVMax(0.8)
+ self.assertTrue(self.colorBar.getColorScaleBar().maxVal == 0.8)
+ self.assertTrue(self.colorBar.getColorScaleBar().getTickBar()._vmax == 0.8)
+
+ colormap.setNormalization("log")
+ self.assertIsInstance(
+ self.colorBar.getColorScaleBar().getTickBar()._normalizer,
+ LogarithmicNormalization,
+ )
+
+ # TODO : should also check that if the colormap is changing then values (especially in log scale)
+ # should be coherent if in autoscale
diff --git a/src/silx/gui/plot/test/testCompareImages.py b/src/silx/gui/plot/test/testCompareImages.py
new file mode 100644
index 0000000..4bc52b4
--- /dev/null
+++ b/src/silx/gui/plot/test/testCompareImages.py
@@ -0,0 +1,236 @@
+# /*##########################################################################
+#
+# Copyright (c) 2016-2017 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""Tests for CompareImages widget"""
+
+__authors__ = ["H. Payno"]
+__license__ = "MIT"
+__date__ = "23/07/2018"
+
+import pytest
+import numpy
+import weakref
+
+from silx.gui import qt
+from silx.gui.plot.CompareImages import CompareImages
+
+
+@pytest.fixture
+def compareImages(qapp, qapp_utils):
+ widget = CompareImages()
+ widget.setAttribute(qt.Qt.WA_DeleteOnClose)
+ yield widget
+ widget.close()
+ ref = weakref.ref(widget)
+ widget = None
+ qapp_utils.qWaitForDestroy(ref)
+
+
+def testIntensityImage(compareImages):
+ image1 = numpy.random.rand(10, 10)
+ image2 = numpy.random.rand(10, 10)
+ compareImages.setData(image1, image2)
+
+
+def testRgbImage(compareImages):
+ image1 = numpy.random.randint(0, 255, size=(10, 10, 3))
+ image2 = numpy.random.randint(0, 255, size=(10, 10, 3))
+ compareImages.setData(image1, image2)
+
+
+def testRgbaImage(compareImages):
+ image1 = numpy.random.randint(0, 255, size=(10, 10, 4))
+ image2 = numpy.random.randint(0, 255, size=(10, 10, 4))
+ compareImages.setData(image1, image2)
+
+
+def testAlignemnt(compareImages):
+ image1 = numpy.random.rand(10, 10)
+ image2 = numpy.random.rand(5, 5)
+ compareImages.setData(image1, image2)
+ for mode in CompareImages.AlignmentMode:
+ compareImages.setAlignmentMode(mode)
+
+
+def testGetPixel(compareImages):
+ image1 = numpy.random.rand(11, 11)
+ image2 = numpy.random.rand(5, 5)
+ image1[5, 5] = 111.111
+ image2[2, 2] = 222.222
+ compareImages.setData(image1, image2)
+ expectedValue = {}
+ expectedValue[CompareImages.AlignmentMode.CENTER] = 222.222
+ expectedValue[CompareImages.AlignmentMode.STRETCH] = 222.222
+ expectedValue[CompareImages.AlignmentMode.ORIGIN] = None
+ for mode in expectedValue.keys():
+ compareImages.setAlignmentMode(mode)
+ data = compareImages.getRawPixelData(11 / 2.0, 11 / 2.0)
+ data1, data2 = data
+ assert data1 == 111.111
+ assert data2 == expectedValue[mode]
+
+
+def testImageEmpty(compareImages):
+ compareImages.setData(image1=None, image2=None)
+
+
+def testSetImageSeparately(compareImages):
+ compareImages.setImage1(numpy.random.rand(10, 10))
+ compareImages.setImage2(numpy.random.rand(10, 10))
+
+
+@pytest.mark.parametrize(
+ "data",
+ [
+ (CompareImages.VisualizationMode.COMPOSITE_A_MINUS_B,),
+ (CompareImages.VisualizationMode.COMPOSITE_RED_BLUE_GRAY,),
+ (CompareImages.VisualizationMode.HORIZONTAL_LINE,),
+ (CompareImages.VisualizationMode.VERTICAL_LINE,),
+ (CompareImages.VisualizationMode.ONLY_A,),
+ (CompareImages.VisualizationMode.ONLY_B,),
+ ],
+)
+def testVisualizationMode(compareImages, data):
+ (visualizationMode,) = data
+ compareImages.setImage1(numpy.random.rand(10, 10))
+ compareImages.setImage2(numpy.random.rand(10, 10))
+ compareImages.setVisualizationMode(visualizationMode)
+
+
+@pytest.mark.parametrize(
+ "data",
+ [
+ (CompareImages.VisualizationMode.COMPOSITE_A_MINUS_B,),
+ (CompareImages.VisualizationMode.COMPOSITE_RED_BLUE_GRAY,),
+ (CompareImages.VisualizationMode.HORIZONTAL_LINE,),
+ (CompareImages.VisualizationMode.VERTICAL_LINE,),
+ (CompareImages.VisualizationMode.ONLY_A,),
+ (CompareImages.VisualizationMode.ONLY_B,),
+ ],
+)
+def testVisualizationModeWithoutImage(compareImages, data):
+ (visualizationMode,) = data
+ compareImages.setImage1(None)
+ compareImages.setImage2(None)
+ compareImages.setVisualizationMode(visualizationMode)
+
+
+@pytest.mark.parametrize(
+ "data",
+ [
+ (CompareImages.VisualizationMode.COMPOSITE_A_MINUS_B,),
+ (CompareImages.VisualizationMode.COMPOSITE_RED_BLUE_GRAY,),
+ (CompareImages.VisualizationMode.HORIZONTAL_LINE,),
+ (CompareImages.VisualizationMode.VERTICAL_LINE,),
+ (CompareImages.VisualizationMode.ONLY_A,),
+ (CompareImages.VisualizationMode.ONLY_B,),
+ ],
+)
+def testVisualizationModeWithOnlyImage1(compareImages, data):
+ (visualizationMode,) = data
+ compareImages.setImage1(numpy.random.rand(10, 10))
+ compareImages.setImage2(None)
+ compareImages.setVisualizationMode(visualizationMode)
+
+
+@pytest.mark.parametrize(
+ "data",
+ [
+ (CompareImages.VisualizationMode.COMPOSITE_A_MINUS_B,),
+ (CompareImages.VisualizationMode.COMPOSITE_RED_BLUE_GRAY,),
+ (CompareImages.VisualizationMode.HORIZONTAL_LINE,),
+ (CompareImages.VisualizationMode.VERTICAL_LINE,),
+ (CompareImages.VisualizationMode.ONLY_A,),
+ (CompareImages.VisualizationMode.ONLY_B,),
+ ],
+)
+def testVisualizationModeWithOnlyImage2(compareImages, data):
+ (visualizationMode,) = data
+ compareImages.setImage1(None)
+ compareImages.setImage2(numpy.random.rand(10, 10))
+ compareImages.setVisualizationMode(visualizationMode)
+
+
+@pytest.mark.parametrize(
+ "data",
+ [
+ (CompareImages.VisualizationMode.COMPOSITE_A_MINUS_B,),
+ (CompareImages.VisualizationMode.COMPOSITE_RED_BLUE_GRAY,),
+ (CompareImages.VisualizationMode.HORIZONTAL_LINE,),
+ (CompareImages.VisualizationMode.VERTICAL_LINE,),
+ (CompareImages.VisualizationMode.ONLY_A,),
+ (CompareImages.VisualizationMode.ONLY_B,),
+ ],
+)
+def testVisualizationModeWithRGBImage(compareImages, data):
+ (visualizationMode,) = data
+ image1 = numpy.random.rand(10, 10)
+ image2 = numpy.random.randint(0, 255, size=(10, 10, 3))
+ compareImages.setData(image1, image2)
+ compareImages.setVisualizationMode(visualizationMode)
+
+
+@pytest.mark.parametrize(
+ "data",
+ [
+ (CompareImages.AlignmentMode.STRETCH,),
+ (CompareImages.AlignmentMode.AUTO,),
+ (CompareImages.AlignmentMode.CENTER,),
+ (CompareImages.AlignmentMode.ORIGIN,),
+ ],
+)
+def testAlignemntModeWithoutImages(compareImages, data):
+ (alignmentMode,) = data
+ compareImages.setAlignmentMode(alignmentMode)
+
+
+@pytest.mark.parametrize(
+ "data",
+ [
+ (CompareImages.AlignmentMode.STRETCH,),
+ (CompareImages.AlignmentMode.AUTO,),
+ (CompareImages.AlignmentMode.CENTER,),
+ (CompareImages.AlignmentMode.ORIGIN,),
+ ],
+)
+def testAlignemntModeWithSingleImage(compareImages, data):
+ (alignmentMode,) = data
+ compareImages.setImage1(numpy.arange(9).reshape(3, 3))
+ compareImages.setAlignmentMode(alignmentMode)
+
+
+def testTooltip(compareImages):
+ compareImages.setImage1(numpy.arange(9).reshape(3, 3))
+ compareImages.setImage2(numpy.arange(9).reshape(3, 3))
+ compareImages.getRawPixelData(1.5, 1.5)
+
+
+def testTooltipWithoutImage(compareImages):
+ compareImages.setImage1(numpy.arange(9).reshape(3, 3))
+ compareImages.setImage2(numpy.arange(9).reshape(3, 3))
+ compareImages.getRawPixelData(1.5, 1.5)
+
+
+def testTooltipWithSingleImage(compareImages):
+ compareImages.setImage1(numpy.arange(9).reshape(3, 3))
+ compareImages.getRawPixelData(1.5, 1.5)
diff --git a/src/silx/gui/plot/test/testComplexImageView.py b/src/silx/gui/plot/test/testComplexImageView.py
new file mode 100644
index 0000000..f8b331b
--- /dev/null
+++ b/src/silx/gui/plot/test/testComplexImageView.py
@@ -0,0 +1,82 @@
+# /*##########################################################################
+#
+# Copyright (c) 2017-2020 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""Test suite for :class:`ComplexImageView`"""
+
+__authors__ = ["T. Vincent"]
+__license__ = "MIT"
+__date__ = "17/01/2018"
+
+
+import logging
+import numpy
+
+from silx.utils.testutils import ParametricTestCase
+from silx.gui.plot import ComplexImageView
+
+from .utils import PlotWidgetTestCase
+
+
+logger = logging.getLogger(__name__)
+
+
+class TestComplexImageView(PlotWidgetTestCase, ParametricTestCase):
+ """Test suite of ComplexImageView widget"""
+
+ def _createPlot(self):
+ return ComplexImageView.ComplexImageView()
+
+ def testPlot2DComplex(self):
+ """Test API of ComplexImageView widget"""
+ data = numpy.array(((0, 1j), (1, 1 + 1j)), dtype=numpy.complex64)
+ self.plot.setData(data)
+ self.plot.setKeepDataAspectRatio(True)
+ self.plot.getPlot().resetZoom()
+ self.qWait(100)
+
+ # Test colormap API
+ colormap = self.plot.getColormap().copy()
+ colormap.setName("magma")
+ self.plot.setColormap(colormap)
+ self.qWait(100)
+
+ # Test all modes
+ modes = self.plot.supportedComplexModes()
+ for mode in modes:
+ with self.subTest(mode=mode):
+ self.plot.setComplexMode(mode)
+ self.qWait(100)
+
+ # Test origin and scale API
+ self.plot.setScale((2, 1))
+ self.qWait(100)
+ self.plot.setOrigin((1, 1))
+ self.qWait(100)
+
+ # Test no data
+ self.plot.setData(numpy.zeros((0, 0), dtype=numpy.complex64))
+ self.qWait(100)
+
+ # Test float data
+ self.plot.setData(numpy.arange(100, dtype=numpy.float64).reshape(10, 10))
+ self.qWait(100)
diff --git a/src/silx/gui/plot/test/testCurvesROIWidget.py b/src/silx/gui/plot/test/testCurvesROIWidget.py
new file mode 100644
index 0000000..05acd36
--- /dev/null
+++ b/src/silx/gui/plot/test/testCurvesROIWidget.py
@@ -0,0 +1,489 @@
+# /*##########################################################################
+#
+# Copyright (c) 2016-2023 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""Basic tests for CurvesROIWidget"""
+
+__authors__ = ["T. Vincent", "P. Knobel", "H. Payno"]
+__license__ = "MIT"
+__date__ = "16/11/2017"
+
+
+import logging
+import os.path
+import numpy
+
+from silx.gui import qt
+from silx.gui.plot import items
+from silx.gui.plot import Plot1D
+from silx.test.utils import temp_dir
+from silx.gui.utils.testutils import TestCaseQt, SignalListener
+from silx.gui.plot import PlotWindow, CurvesROIWidget
+from silx.gui.utils.testutils import getQToolButtonFromAction
+
+_logger = logging.getLogger(__name__)
+
+
+class TestCurvesROIWidget(TestCaseQt):
+ """Basic test for CurvesROIWidget"""
+
+ def setUp(self):
+ super(TestCurvesROIWidget, self).setUp()
+ self.plot = PlotWindow()
+ self.plot.show()
+ self.qWaitForWindowExposed(self.plot)
+
+ self.widget = self.plot.getCurvesRoiDockWidget()
+
+ self.widget.show()
+ self.qWaitForWindowExposed(self.widget)
+
+ def tearDown(self):
+ self.plot.setAttribute(qt.Qt.WA_DeleteOnClose)
+ self.plot.close()
+ del self.plot
+
+ self.widget.setAttribute(qt.Qt.WA_DeleteOnClose)
+ self.widget.close()
+ del self.widget
+
+ super(TestCurvesROIWidget, self).tearDown()
+
+ def testDummyAPI(self):
+ """Simple test of the getRois and setRois API"""
+ roi_neg = CurvesROIWidget.ROI(
+ name="negative", fromdata=-20, todata=-10, type_="X"
+ )
+ roi_pos = CurvesROIWidget.ROI(
+ name="positive", fromdata=10, todata=20, type_="X"
+ )
+
+ self.widget.roiWidget.setRois((roi_pos, roi_neg))
+
+ rois_defs = self.widget.roiWidget.getRois()
+ self.widget.roiWidget.setRois(rois=rois_defs)
+
+ def testWithCurves(self):
+ """Plot with curves: test all ROI widget buttons"""
+ for offset in range(2):
+ self.plot.addCurve(
+ numpy.arange(1000),
+ offset + numpy.random.random(1000),
+ legend=str(offset),
+ )
+
+ # Add two ROI
+ self.mouseClick(self.widget.roiWidget.addButton, qt.Qt.LeftButton)
+ self.qWait(200)
+ self.mouseClick(self.widget.roiWidget.addButton, qt.Qt.LeftButton)
+ self.qWait(200)
+
+ # Change active curve
+ self.plot.setActiveCurve(str(1))
+
+ # Delete a ROI
+ self.mouseClick(self.widget.roiWidget.delButton, qt.Qt.LeftButton)
+ self.qWait(200)
+
+ with temp_dir() as tmpDir:
+ self.tmpFile = os.path.join(tmpDir, "test.ini")
+
+ # Save ROIs
+ self.widget.roiWidget.save(self.tmpFile)
+ self.assertTrue(os.path.isfile(self.tmpFile))
+ self.assertEqual(len(self.widget.getRois()), 2)
+
+ # Reset ROIs
+ self.mouseClick(self.widget.roiWidget.resetButton, qt.Qt.LeftButton)
+ self.qWait(200)
+ rois = self.widget.getRois()
+ self.assertEqual(len(rois), 1)
+ roiID = list(rois.keys())[0]
+ self.assertEqual(rois[roiID].getName(), "ICR")
+
+ # Load ROIs
+ self.widget.roiWidget.load(self.tmpFile)
+ self.assertEqual(len(self.widget.getRois()), 2)
+
+ del self.tmpFile
+
+ def testMiddleMarker(self):
+ """Test with middle marker enabled"""
+ self.widget.roiWidget.roiTable.setMiddleROIMarkerFlag(True)
+
+ # Add a ROI
+ self.mouseClick(self.widget.roiWidget.addButton, qt.Qt.LeftButton)
+
+ for roiID in self.widget.roiWidget.roiTable._markersHandler._roiMarkerHandlers:
+ handler = self.widget.roiWidget.roiTable._markersHandler._roiMarkerHandlers[
+ roiID
+ ]
+ assert handler.getMarker("min")
+ xleftMarker = handler.getMarker("min").getXPosition()
+ xMiddleMarker = handler.getMarker("middle").getXPosition()
+ xRightMarker = handler.getMarker("max").getXPosition()
+ thValue = xleftMarker + (xRightMarker - xleftMarker) / 2.0
+ self.assertAlmostEqual(xMiddleMarker, thValue)
+
+ def testAreaCalculation(self):
+ """Test result of area calculation"""
+ x = numpy.arange(100.0)
+ y = numpy.arange(100.0)
+
+ # Add two curves
+ self.plot.addCurve(x, y, legend="positive")
+ self.plot.addCurve(-x, y, legend="negative")
+
+ # Make sure there is an active curve and it is the positive one
+ self.plot.setActiveCurve("positive")
+
+ # Add two ROIs
+ roi_neg = CurvesROIWidget.ROI(
+ name="negative", fromdata=-20, todata=-10, type_="X"
+ )
+ roi_pos = CurvesROIWidget.ROI(
+ name="positive", fromdata=10, todata=20, type_="X"
+ )
+
+ self.widget.roiWidget.setRois((roi_pos, roi_neg))
+
+ posCurve = self.plot.getCurve("positive")
+ negCurve = self.plot.getCurve("negative")
+
+ self.assertEqual(
+ roi_pos.computeRawAndNetArea(posCurve),
+ (numpy.trapz(y=[10, 20], x=[10, 20]), 0.0),
+ )
+ self.assertEqual(roi_pos.computeRawAndNetArea(negCurve), (0.0, 0.0))
+ self.assertEqual(roi_neg.computeRawAndNetArea(posCurve), ((0.0), 0.0))
+ self.assertEqual(roi_neg.computeRawAndNetArea(negCurve), ((-150.0), 0.0))
+
+ def testCountsCalculation(self):
+ """Test result of count calculation"""
+ x = numpy.arange(100.0)
+ y = numpy.arange(100.0)
+
+ # Add two curves
+ self.plot.addCurve(x, y, legend="positive")
+ self.plot.addCurve(-x, y, legend="negative")
+
+ # Make sure there is an active curve and it is the positive one
+ self.plot.setActiveCurve("positive")
+
+ # Add two ROIs
+ roi_neg = CurvesROIWidget.ROI(
+ name="negative", fromdata=-20, todata=-10, type_="X"
+ )
+ roi_pos = CurvesROIWidget.ROI(
+ name="positive", fromdata=10, todata=20, type_="X"
+ )
+
+ self.widget.roiWidget.setRois((roi_pos, roi_neg))
+
+ posCurve = self.plot.getCurve("positive")
+ negCurve = self.plot.getCurve("negative")
+
+ self.assertEqual(
+ roi_pos.computeRawAndNetCounts(posCurve), (y[10:21].sum(), 0.0)
+ )
+ self.assertEqual(roi_pos.computeRawAndNetCounts(negCurve), (0.0, 0.0))
+ self.assertEqual(roi_neg.computeRawAndNetCounts(posCurve), ((0.0), 0.0))
+ self.assertEqual(
+ roi_neg.computeRawAndNetCounts(negCurve), (y[10:21].sum(), 0.0)
+ )
+
+ def testDeferedInit(self):
+ """Test behavior of the deferedInit"""
+ x = numpy.arange(100.0)
+ y = numpy.arange(100.0)
+ self.plot.addCurve(x=x, y=y, legend="name", replace="True")
+ roisDefs = dict(
+ [
+ ["range1", dict([["from", 20], ["to", 200], ["type", "energy"]])],
+ ["range2", dict([["from", 300], ["to", 500], ["type", "energy"]])],
+ ]
+ )
+
+ roiWidget = self.plot.getCurvesRoiDockWidget().roiWidget
+ self.plot.getCurvesRoiDockWidget().setRois(roisDefs)
+ self.assertEqual(len(roiWidget.getRois()), len(roisDefs))
+ self.plot.getCurvesRoiDockWidget().setVisible(True)
+ self.assertEqual(len(roiWidget.getRois()), len(roisDefs))
+
+ def testDictCompatibility(self):
+ """Test that ROI api is valid with dict and not information is lost"""
+ roiDict = {
+ "from": 20,
+ "to": 200,
+ "type": "energy",
+ "comment": "no",
+ "name": "myROI",
+ "calibration": [1, 2, 3],
+ }
+ roi = CurvesROIWidget.ROI._fromDict(roiDict)
+ self.assertEqual(roi.toDict(), roiDict)
+
+ def testShowAllROI(self):
+ """Test the show allROI action"""
+ x = numpy.arange(100.0)
+ y = numpy.arange(100.0)
+ self.plot.addCurve(x=x, y=y, legend="name", replace="True")
+
+ roisDefsDict = {
+ "range1": {"from": 20, "to": 200, "type": "energy"},
+ "range2": {"from": 300, "to": 500, "type": "energy"},
+ }
+
+ roisDefsObj = (
+ CurvesROIWidget.ROI(name="range3", fromdata=20, todata=200, type_="energy"),
+ CurvesROIWidget.ROI(
+ name="range4", fromdata=300, todata=500, type_="energy"
+ ),
+ )
+ self.widget.roiWidget.showAllMarkers(True)
+ roiWidget = self.plot.getCurvesRoiDockWidget().roiWidget
+ roiWidget.setRois(roisDefsDict)
+ markers = [
+ item for item in self.plot.getItems() if isinstance(item, items.MarkerBase)
+ ]
+ self.assertEqual(len(markers), 2 * 3)
+
+ markersHandler = self.widget.roiWidget.roiTable._markersHandler
+ roiWidget.showAllMarkers(True)
+ ICRROI = markersHandler.getVisibleRois()
+ self.assertEqual(len(ICRROI), 2)
+
+ roiWidget.showAllMarkers(False)
+ ICRROI = markersHandler.getVisibleRois()
+ self.assertEqual(len(ICRROI), 1)
+
+ roiWidget.setRois(roisDefsObj)
+ self.qapp.processEvents()
+ markers = [
+ item for item in self.plot.getItems() if isinstance(item, items.MarkerBase)
+ ]
+ self.assertEqual(len(markers), 2 * 3)
+
+ markersHandler = self.widget.roiWidget.roiTable._markersHandler
+ roiWidget.showAllMarkers(True)
+ ICRROI = markersHandler.getVisibleRois()
+ self.assertEqual(len(ICRROI), 2)
+
+ roiWidget.showAllMarkers(False)
+ ICRROI = markersHandler.getVisibleRois()
+ self.assertEqual(len(ICRROI), 1)
+
+ def testRoiEdition(self):
+ """Make sure if the ROI object is edited the ROITable will be updated"""
+ roi = CurvesROIWidget.ROI(name="linear", fromdata=0, todata=5)
+ self.widget.roiWidget.setRois((roi,))
+
+ x = (0, 1, 1, 2, 2, 3)
+ y = (1, 1, 2, 2, 1, 1)
+ self.plot.addCurve(x=x, y=y, legend="linearCurve")
+ self.plot.setActiveCurve(legend="linearCurve")
+ self.widget.calculateROIs()
+
+ roiTable = self.widget.roiWidget.roiTable
+ indexesColumns = CurvesROIWidget.ROITable.COLUMNS_INDEX
+ itemRawCounts = roiTable.item(0, indexesColumns["Raw Counts"])
+ itemNetCounts = roiTable.item(0, indexesColumns["Net Counts"])
+
+ self.assertTrue(itemRawCounts.text() == "8.0")
+ self.assertTrue(itemNetCounts.text() == "2.0")
+
+ itemRawArea = roiTable.item(0, indexesColumns["Raw Area"])
+ itemNetArea = roiTable.item(0, indexesColumns["Net Area"])
+
+ self.assertTrue(itemRawArea.text() == "4.0")
+ self.assertTrue(itemNetArea.text() == "1.0")
+
+ roi.setTo(2)
+ itemRawArea = roiTable.item(0, indexesColumns["Raw Area"])
+ self.assertTrue(itemRawArea.text() == "3.0")
+ roi.setFrom(1)
+ itemRawArea = roiTable.item(0, indexesColumns["Raw Area"])
+ self.assertTrue(itemRawArea.text() == "2.0")
+
+ def testRemoveActiveROI(self):
+ """Test widget behavior when removing the active ROI"""
+ roi = CurvesROIWidget.ROI(name="linear", fromdata=0, todata=5)
+ self.widget.roiWidget.setRois((roi,))
+
+ self.widget.roiWidget.roiTable.setActiveRoi(None)
+ self.assertEqual(len(self.widget.roiWidget.roiTable.selectedItems()), 0)
+ self.widget.roiWidget.setRois((roi,))
+ self.plot.setActiveCurve(legend="linearCurve")
+ self.widget.calculateROIs()
+
+ def testEmitCurrentROI(self):
+ """Test behavior of the CurvesROIWidget.sigROISignal"""
+ roi = CurvesROIWidget.ROI(name="linear", fromdata=0, todata=5)
+ self.widget.roiWidget.setRois((roi,))
+ signalListener = SignalListener()
+ self.widget.roiWidget.sigROISignal.connect(signalListener.partial())
+ self.widget.show()
+ self.qapp.processEvents()
+ self.assertEqual(signalListener.callCount(), 0)
+ self.assertIs(self.widget.roiWidget.roiTable.activeRoi, roi)
+ roi.setFrom(0.0)
+ self.qapp.processEvents()
+ self.assertEqual(signalListener.callCount(), 0)
+ roi.setFrom(0.3)
+ self.qapp.processEvents()
+ self.assertEqual(signalListener.callCount(), 1)
+
+
+class TestRoiWidgetSignals(TestCaseQt):
+ """Test Signals emitted by the RoiWidgetSignals"""
+
+ def setUp(self):
+ super().setUp()
+
+ self.plot = Plot1D()
+ x = range(20)
+ y = range(20)
+ self.plot.addCurve(x, y, legend="curve0")
+ self.listener = SignalListener()
+ self.curves_roi_widget = self.plot.getCurvesRoiWidget()
+ self.curves_roi_widget.sigROISignal.connect(self.listener)
+ assert self.curves_roi_widget.isVisible() is False
+ assert self.listener.callCount() == 0
+ self.plot.show()
+ self.qWaitForWindowExposed(self.plot)
+
+ toolButton = getQToolButtonFromAction(self.plot.getRoiAction())
+ self.qapp.processEvents()
+ self.mouseClick(widget=toolButton, button=qt.Qt.LeftButton)
+
+ self.curves_roi_widget.show()
+ self.qWaitForWindowExposed(self.curves_roi_widget)
+
+ def tearDown(self):
+ self.plot.setAttribute(qt.Qt.WA_DeleteOnClose)
+ self.plot.close()
+ del self.plot
+
+ self.curves_roi_widget.setAttribute(qt.Qt.WA_DeleteOnClose)
+ self.curves_roi_widget.close()
+ del self.curves_roi_widget
+
+ super().tearDown()
+
+ def testSigROISignalAddRmRois(self):
+ """Test SigROISignal when adding and removing ROIS"""
+ self.listener.clear()
+
+ roi1 = CurvesROIWidget.ROI(name="linear", fromdata=0, todata=5)
+ self.curves_roi_widget.roiTable.addRoi(roi1)
+ self.assertEqual(self.listener.callCount(), 1)
+ self.assertTrue(self.listener.arguments()[0][0]["current"] == "linear")
+ self.listener.clear()
+
+ roi2 = CurvesROIWidget.ROI(name="linear2", fromdata=0, todata=5)
+ self.curves_roi_widget.roiTable.addRoi(roi2)
+ self.assertEqual(self.listener.callCount(), 1)
+ self.assertTrue(self.listener.arguments()[0][0]["current"] == "linear2")
+ self.listener.clear()
+
+ self.curves_roi_widget.roiTable.removeROI(roi2)
+ self.assertEqual(self.listener.callCount(), 1)
+ self.assertTrue(self.curves_roi_widget.roiTable.activeRoi == roi1)
+ self.assertTrue(self.listener.arguments()[0][0]["current"] == "linear")
+ self.listener.clear()
+
+ self.curves_roi_widget.roiTable.deleteActiveRoi()
+ self.assertEqual(self.listener.callCount(), 1)
+ self.assertTrue(self.curves_roi_widget.roiTable.activeRoi is None)
+ self.assertTrue(self.listener.arguments()[0][0]["current"] is None)
+ self.listener.clear()
+
+ self.curves_roi_widget.roiTable.addRoi(roi1)
+ self.assertEqual(self.listener.callCount(), 1)
+ self.assertTrue(self.listener.arguments()[0][0]["current"] == "linear")
+ self.assertTrue(self.curves_roi_widget.roiTable.activeRoi == roi1)
+ self.listener.clear()
+ self.qapp.processEvents()
+
+ self.curves_roi_widget.roiTable.removeROI(roi1)
+ self.qapp.processEvents()
+ self.assertEqual(self.listener.callCount(), 1)
+ self.assertTrue(self.listener.arguments()[0][0]["current"] == "ICR")
+ self.listener.clear()
+
+ def testSigROISignalModifyROI(self):
+ """Test SigROISignal when modifying it"""
+ self.curves_roi_widget.roiTable.setMiddleROIMarkerFlag(True)
+ roi1 = CurvesROIWidget.ROI(name="linear", fromdata=2, todata=5)
+ self.curves_roi_widget.roiTable.addRoi(roi1)
+ self.curves_roi_widget.roiTable.setActiveRoi(roi1)
+
+ # test modify the roi2 object
+ self.listener.clear()
+ roi1.setFrom(0.56)
+ self.assertEqual(self.listener.callCount(), 1)
+ self.listener.clear()
+ roi1.setTo(2.56)
+ self.assertEqual(self.listener.callCount(), 1)
+ self.listener.clear()
+ roi1.setName("linear2")
+ self.assertEqual(self.listener.callCount(), 1)
+ self.listener.clear()
+ roi1.setType("new type")
+ self.assertEqual(self.listener.callCount(), 1)
+
+ widget = self.plot.getWidgetHandle()
+ widget.setFocus(qt.Qt.OtherFocusReason)
+ self.plot.raise_()
+ self.qapp.processEvents()
+
+ # modify roi limits (from the gui)
+ roi_marker_handler = (
+ self.curves_roi_widget.roiTable._markersHandler.getMarkerHandler(
+ roi1.getID()
+ )
+ )
+ for marker_type in ("min", "max", "middle"):
+ with self.subTest(marker_type=marker_type):
+ self.listener.clear()
+ marker = roi_marker_handler.getMarker(marker_type)
+ x_pix, y_pix = self.plot.dataToPixel(
+ marker.getXPosition(), marker.getYPosition()
+ )
+ self.mouseMove(widget, pos=(x_pix, y_pix))
+ self.qWait(100)
+ self.mousePress(widget, qt.Qt.LeftButton, pos=(x_pix, y_pix))
+ self.mouseMove(widget, pos=(x_pix + 20, y_pix))
+ self.qWait(100)
+ self.mouseRelease(widget, qt.Qt.LeftButton, pos=(x_pix + 20, y_pix))
+ self.qWait(100)
+ self.mouseMove(widget, pos=(x_pix, y_pix))
+ self.qapp.processEvents()
+ self.assertEqual(self.listener.callCount(), 1)
+
+ def testSetActiveCurve(self):
+ """Test sigRoiSignal when set an active curve"""
+ roi1 = CurvesROIWidget.ROI(name="linear", fromdata=2, todata=5)
+ self.curves_roi_widget.roiTable.setActiveRoi(roi1)
+ self.listener.clear()
+ self.plot.setActiveCurve("curve0")
+ self.assertEqual(self.listener.callCount(), 0)
diff --git a/src/silx/gui/plot/test/testImageStack.py b/src/silx/gui/plot/test/testImageStack.py
new file mode 100644
index 0000000..482cdfd
--- /dev/null
+++ b/src/silx/gui/plot/test/testImageStack.py
@@ -0,0 +1,234 @@
+# /*##########################################################################
+#
+# Copyright (c) 2020-2023 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""Basic tests for ImageStack"""
+
+__authors__ = ["H. Payno"]
+__license__ = "MIT"
+__date__ = "15/01/2020"
+
+
+import tempfile
+import numpy
+import h5py
+
+from silx.gui import qt
+from silx.gui.utils.testutils import TestCaseQt
+from silx.io.url import DataUrl
+from silx.gui.plot.ImageStack import ImageStack
+from silx.gui.utils.testutils import SignalListener
+import os
+import time
+import shutil
+
+
+class TestImageStack(TestCaseQt):
+ """Simple test of the Image stack"""
+
+ def setUp(self):
+ TestCaseQt.setUp(self)
+ self.urls = {}
+ self._raw_data = {}
+ self._folder = tempfile.mkdtemp()
+ self._n_urls = 10
+ file_name = os.path.join(self._folder, "test_inage_stack_file.h5")
+ with h5py.File(file_name, "w") as h5f:
+ for i in range(self._n_urls):
+ width = numpy.random.randint(10, 40)
+ height = numpy.random.randint(10, 40)
+ raw_data = numpy.random.random((width, height))
+ self._raw_data[i] = raw_data
+ h5f[str(i)] = raw_data
+ self.urls[i] = DataUrl(
+ file_path=file_name, data_path=str(i), scheme="silx"
+ )
+ self.widget = ImageStack()
+
+ self.urlLoadedListener = SignalListener()
+ self.widget.sigLoaded.connect(self.urlLoadedListener)
+
+ self.currentUrlChangedListener = SignalListener()
+ self.widget.sigCurrentUrlChanged.connect(self.currentUrlChangedListener)
+
+ def tearDown(self):
+ shutil.rmtree(self._folder)
+ self.widget.setAttribute(qt.Qt.WA_DeleteOnClose, True)
+ self.widget.close()
+ TestCaseQt.setUp(self)
+
+ def testControls(self):
+ """Test that selection using the url table and the slider are working"""
+ self.widget.show()
+ self.assertEqual(self.widget.getCurrentUrl(), None)
+ self.assertEqual(self.widget.getCurrentUrlIndex(), None)
+ self.widget.setUrls(list(self.urls.values()))
+
+ # wait for image to be loaded
+ self._waitUntilUrlLoaded()
+
+ self.assertEqual(self.widget.getCurrentUrl(), self.urls[0])
+
+ # make sure all image are loaded
+ self.assertEqual(self.urlLoadedListener.callCount(), self._n_urls)
+ numpy.testing.assert_array_equal(
+ self.widget.getPlotWidget().getActiveImage(just_legend=False).getData(),
+ self._raw_data[0],
+ )
+ self.assertEqual(self.widget._slider.value(), 0)
+
+ self.widget._urlsTable.setUrl(self.urls[4])
+ numpy.testing.assert_array_equal(
+ self.widget.getPlotWidget().getActiveImage(just_legend=False).getData(),
+ self._raw_data[4],
+ )
+ self.assertEqual(self.widget._slider.value(), 4)
+ self.assertEqual(self.widget.getCurrentUrl(), self.urls[4])
+ self.assertEqual(self.widget.getCurrentUrlIndex(), 4)
+
+ self.widget._slider.setUrlIndex(6)
+ numpy.testing.assert_array_equal(
+ self.widget.getPlotWidget().getActiveImage(just_legend=False).getData(),
+ self._raw_data[6],
+ )
+ self.assertEqual(
+ self.widget._urlsTable.currentItem().text(), self.urls[6].path()
+ )
+
+ def testCurrentUrlSignals(self):
+ """Test emission of 'currentUrlChangedListener'"""
+ # check initialization
+ self.assertEqual(self.currentUrlChangedListener.callCount(), 0)
+ self.widget.setUrls(list(self.urls.values()))
+ self.qapp.processEvents()
+ time.sleep(0.5)
+ self.qapp.processEvents()
+ # once loaded the two signals should have been sended
+ self.assertEqual(self.currentUrlChangedListener.callCount(), 1)
+ # if the slider is stuck to the same position no signal should be
+ # emitted
+ self.qapp.processEvents()
+ time.sleep(0.5)
+ self.qapp.processEvents()
+ self.assertEqual(self.widget._slider.value(), 0)
+ self.assertEqual(self.currentUrlChangedListener.callCount(), 1)
+ # if slider position is changed, one of each signal should have been
+ # emitted
+ self.widget._urlsTable.setUrl(self.urls[4])
+ self.qapp.processEvents()
+ time.sleep(1.5)
+ self.qapp.processEvents()
+ self.assertEqual(self.currentUrlChangedListener.callCount(), 2)
+
+ def testUtils(self):
+ """Test that some utils functions are working"""
+ self.widget.show()
+ self.widget.setUrls(list(self.urls.values()))
+ self.assertEqual(len(self.widget.getUrls()), len(self.urls))
+
+ # wait for image to be loaded
+ self._waitUntilUrlLoaded()
+
+ urls_values = list(self.urls.values())
+ self.assertEqual(urls_values[0], self.urls[0])
+ self.assertEqual(urls_values[7], self.urls[7])
+
+ self.assertEqual(
+ self.widget._getNextUrl(urls_values[2]).path(), urls_values[3].path()
+ )
+ self.assertEqual(self.widget._getPreviousUrl(urls_values[0]), None)
+ self.assertEqual(
+ self.widget._getPreviousUrl(urls_values[6]).path(), urls_values[5].path()
+ )
+
+ self.assertEqual(self.widget._getNNextUrls(2, urls_values[0]), urls_values[1:3])
+ self.assertEqual(self.widget._getNNextUrls(5, urls_values[7]), urls_values[8:])
+ self.assertEqual(
+ self.widget._getNPreviousUrls(3, urls_values[2]), urls_values[:2]
+ )
+ self.assertEqual(
+ self.widget._getNPreviousUrls(5, urls_values[8]), urls_values[3:8]
+ )
+
+ def testRemoveUrlFromList(self):
+ """
+ Test behavior when some item (url) are removed from the list
+ """
+ self.widget.setUrlsEditable(True)
+ self.widget.show()
+ self.widget.setUrls(list(self.urls.values()))
+ self.assertEqual(len(self.widget.getUrls()), len(self.urls))
+
+ # wait for image to be loaded
+ self._waitUntilUrlLoaded()
+ ll_slider = self.widget._slider._slider
+ assert ll_slider.maximum() - ll_slider.minimum() + 1 == len(self.urls)
+
+ # remove some urls from the list (~ simulating behavior with a right click)
+ urlsTable = self.widget._urlsTable._urlsTable
+ urlsTable.clearSelection()
+ urlsTable.item(1).setSelected(True)
+ urlsTable.item(2).setSelected(True)
+ urlsTable._removeSelectedItems()
+ self.qapp.processEvents()
+
+ # make sure slider has been updated
+ assert ll_slider.maximum() - ll_slider.minimum() + 1 == len(self.urls) - 2
+ # as the ImageStack widget
+ assert len(self.widget._urls) == len(self.urls) - 2
+ removed_urls = list(self.urls.values())[1:3]
+
+ existing_urls_as_str = [url.path() for url in self.widget._urls.values()]
+ for removed_url in removed_urls:
+ assert type(removed_url) == type(tuple(self.widget._urls.values())[0])
+ assert removed_url.path() not in existing_urls_as_str
+ # make sure we have some data plot
+ self.widget.getPlotWidget().getActiveImage() is not None
+
+ # test removing remaining urls
+ urlsTable.selectAll()
+ urlsTable._removeSelectedItems()
+ self.qapp.processEvents()
+ assert len(self.widget._urls) == 0
+ assert ll_slider.maximum() - ll_slider.minimum() == 0
+ # make sure if all urls are removed nothing is plot anymore
+ self.widget.getPlotWidget().getActiveImage() is None
+
+ def _waitUntilUrlLoaded(self, timeout=2.0):
+ """Wait until all image urls are loaded"""
+ loop_duration = 0.2
+ remaining_duration = timeout
+ while len(self.widget._loadingThreads) > 0 and remaining_duration > 0:
+ remaining_duration -= loop_duration
+ time.sleep(loop_duration)
+ self.qapp.processEvents()
+
+ if remaining_duration <= 0.0:
+ remaining_urls = []
+ for thread_ in self.widget._loadingThreads:
+ remaining_urls.append(thread_.url.path())
+ mess = (
+ "All images are not loaded after the time out. "
+ "Remaining urls are: " + str(remaining_urls)
+ )
+ raise TimeoutError(mess)
+ return True
diff --git a/src/silx/gui/plot/test/testImageView.py b/src/silx/gui/plot/test/testImageView.py
new file mode 100644
index 0000000..df19ab7
--- /dev/null
+++ b/src/silx/gui/plot/test/testImageView.py
@@ -0,0 +1,201 @@
+# /*##########################################################################
+#
+# Copyright (c) 2017-2021 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""Basic tests for PlotWindow"""
+
+__authors__ = ["T. Vincent"]
+__license__ = "MIT"
+__date__ = "24/04/2018"
+
+
+import numpy
+
+from silx.gui import qt
+from silx.gui.utils.testutils import TestCaseQt
+from silx.gui.plot import items
+
+from silx.gui.plot.ImageView import ImageView
+from silx.gui.colors import Colormap
+
+
+class TestImageView(TestCaseQt):
+ """Tests of ImageView widget."""
+
+ def setUp(self):
+ super(TestImageView, self).setUp()
+ self.plot = ImageView()
+ self.plot.show()
+ self.qWaitForWindowExposed(self.plot)
+
+ def tearDown(self):
+ self.qapp.processEvents()
+ self.plot.setAttribute(qt.Qt.WA_DeleteOnClose)
+ self.plot.close()
+ del self.plot
+ self.qapp.processEvents()
+ super(TestImageView, self).tearDown()
+
+ def testSetImage(self):
+ """Test setImage"""
+ image = numpy.arange(100).reshape(10, 10)
+
+ self.plot.setImage(image, reset=True)
+ self.qWait(100)
+ self.assertEqual(self.plot.getXAxis().getLimits(), (0, 10))
+ self.assertEqual(self.plot.getYAxis().getLimits(), (0, 10))
+
+ # With reset=False
+ self.plot.setImage(image[::2, ::2], reset=False)
+ self.qWait(100)
+ self.assertEqual(self.plot.getXAxis().getLimits(), (0, 10))
+ self.assertEqual(self.plot.getYAxis().getLimits(), (0, 10))
+
+ self.plot.setImage(image, origin=(10, 20), scale=(2, 4), reset=False)
+ self.qWait(100)
+ self.assertEqual(self.plot.getXAxis().getLimits(), (0, 10))
+ self.assertEqual(self.plot.getYAxis().getLimits(), (0, 10))
+
+ # With reset=True
+ self.plot.setImage(image, origin=(1, 2), scale=(1, 0.5), reset=True)
+ self.qWait(100)
+ self.assertEqual(self.plot.getXAxis().getLimits(), (1, 11))
+ self.assertEqual(self.plot.getYAxis().getLimits(), (2, 7))
+
+ self.plot.setImage(image[::2, ::2], reset=True)
+ self.qWait(100)
+ self.assertEqual(self.plot.getXAxis().getLimits(), (0, 5))
+ self.assertEqual(self.plot.getYAxis().getLimits(), (0, 5))
+
+ def testColormap(self):
+ """Test get|setColormap"""
+ image = numpy.arange(100).reshape(10, 10)
+ self.plot.setImage(image)
+
+ # Colormap as dict
+ self.plot.setColormap(
+ {
+ "name": "viridis",
+ "normalization": "log",
+ "autoscale": False,
+ "vmin": 0,
+ "vmax": 1,
+ }
+ )
+ colormap = self.plot.getColormap()
+ self.assertEqual(colormap.getName(), "viridis")
+ self.assertEqual(colormap.getNormalization(), "log")
+ self.assertEqual(colormap.getVMin(), 0)
+ self.assertEqual(colormap.getVMax(), 1)
+
+ # Colormap as keyword arguments
+ self.plot.setColormap(
+ colormap="magma", normalization="linear", autoscale=True, vmin=1, vmax=2
+ )
+ self.assertEqual(colormap.getName(), "magma")
+ self.assertEqual(colormap.getNormalization(), "linear")
+ self.assertEqual(colormap.getVMin(), None)
+ self.assertEqual(colormap.getVMax(), None)
+
+ # Update colormap with keyword argument
+ self.plot.setColormap(normalization="log")
+ self.assertEqual(colormap.getNormalization(), "log")
+
+ # Colormap as Colormap object
+ cmap = Colormap()
+ self.plot.setColormap(cmap)
+ self.assertIs(self.plot.getColormap(), cmap)
+
+ def testSetProfileWindowBehavior(self):
+ """Test change of profile window display behavior"""
+ self.assertIs(
+ self.plot.getProfileWindowBehavior(),
+ ImageView.ProfileWindowBehavior.POPUP,
+ )
+
+ self.plot.setProfileWindowBehavior("embedded")
+ self.assertIs(
+ self.plot.getProfileWindowBehavior(),
+ ImageView.ProfileWindowBehavior.EMBEDDED,
+ )
+
+ image = numpy.arange(100).reshape(10, 10)
+ self.plot.setImage(image)
+
+ self.plot.setProfileWindowBehavior(ImageView.ProfileWindowBehavior.POPUP)
+ self.assertIs(
+ self.plot.getProfileWindowBehavior(),
+ ImageView.ProfileWindowBehavior.POPUP,
+ )
+
+ def testRGBImage(self):
+ """Test setImage"""
+ image = numpy.arange(100 * 3, dtype=numpy.uint8).reshape(10, 10, 3)
+
+ self.plot.setImage(image, reset=True)
+ self.qWait(100)
+ self.assertEqual(self.plot.getXAxis().getLimits(), (0, 10))
+ self.assertEqual(self.plot.getYAxis().getLimits(), (0, 10))
+
+ def testRGBAImage(self):
+ """Test setImage"""
+ image = numpy.arange(100 * 4, dtype=numpy.uint8).reshape(10, 10, 4)
+
+ self.plot.setImage(image, reset=True)
+ self.qWait(100)
+ self.assertEqual(self.plot.getXAxis().getLimits(), (0, 10))
+ self.assertEqual(self.plot.getYAxis().getLimits(), (0, 10))
+
+ def testImageAggregationMode(self):
+ """Test setImage"""
+ image = numpy.arange(100).reshape(10, 10)
+ self.plot.setImage(image, reset=True)
+ self.qWait(100)
+ self.plot.getAggregationModeAction().setAggregationMode(
+ items.ImageDataAggregated.Aggregation.MAX
+ )
+ self.qWait(100)
+
+ def testImageAggregationModeBackToNormalMode(self):
+ """Test setImage"""
+ image = numpy.arange(100).reshape(10, 10)
+ self.plot.setImage(image, reset=True)
+ self.qWait(100)
+ self.plot.getAggregationModeAction().setAggregationMode(
+ items.ImageDataAggregated.Aggregation.MAX
+ )
+ self.qWait(100)
+ self.plot.getAggregationModeAction().setAggregationMode(
+ items.ImageDataAggregated.Aggregation.NONE
+ )
+ self.qWait(100)
+
+ def testRGBAInAggregationMode(self):
+ """Test setImage"""
+ image = numpy.arange(100 * 3, dtype=numpy.uint8).reshape(10, 10, 3)
+
+ self.plot.setImage(image, reset=True)
+ self.qWait(100)
+ self.plot.getAggregationModeAction().setAggregationMode(
+ items.ImageDataAggregated.Aggregation.MAX
+ )
+ self.qWait(100)
diff --git a/src/silx/gui/plot/test/testInteraction.py b/src/silx/gui/plot/test/testInteraction.py
new file mode 100644
index 0000000..b031454
--- /dev/null
+++ b/src/silx/gui/plot/test/testInteraction.py
@@ -0,0 +1,79 @@
+# /*##########################################################################
+#
+# Copyright (c) 2016-2020 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""Tests from interaction state machines"""
+
+__authors__ = ["T. Vincent"]
+__license__ = "MIT"
+__date__ = "18/02/2016"
+
+
+import unittest
+
+from silx.gui.plot import Interaction
+
+
+class TestInteraction(unittest.TestCase):
+ def testClickOrDrag(self):
+ """Minimalistic test for click or drag state machine."""
+ events = []
+
+ class TestClickOrDrag(Interaction.ClickOrDrag):
+ def click(self, x, y, btn):
+ events.append(("click", x, y, btn))
+
+ def beginDrag(self, x, y, btn):
+ events.append(("beginDrag", x, y, btn))
+
+ def drag(self, x, y, btn):
+ events.append(("drag", x, y, btn))
+
+ def endDrag(self, start, end, btn):
+ events.append(("endDrag", start, end, btn))
+
+ clickOrDrag = TestClickOrDrag()
+
+ # click
+ clickOrDrag.handleEvent("press", 10, 10, Interaction.LEFT_BTN)
+ self.assertEqual(len(events), 0)
+
+ clickOrDrag.handleEvent("release", 10, 10, Interaction.LEFT_BTN)
+ self.assertEqual(len(events), 1)
+ self.assertEqual(events[0], ("click", 10, 10, Interaction.LEFT_BTN))
+
+ # drag
+ events = []
+ clickOrDrag.handleEvent("press", 10, 10, Interaction.LEFT_BTN)
+ self.assertEqual(len(events), 0)
+ clickOrDrag.handleEvent("move", 15, 10)
+ self.assertEqual(len(events), 2) # Received beginDrag and drag
+ self.assertEqual(events[0], ("beginDrag", 10, 10, Interaction.LEFT_BTN))
+ self.assertEqual(events[1], ("drag", 15, 10, Interaction.LEFT_BTN))
+ clickOrDrag.handleEvent("move", 20, 10)
+ self.assertEqual(len(events), 3)
+ self.assertEqual(events[-1], ("drag", 20, 10, Interaction.LEFT_BTN))
+ clickOrDrag.handleEvent("release", 20, 10, Interaction.LEFT_BTN)
+ self.assertEqual(len(events), 4)
+ self.assertEqual(
+ events[-1], ("endDrag", (10, 10), (20, 10), Interaction.LEFT_BTN)
+ )
diff --git a/src/silx/gui/plot/test/testItem.py b/src/silx/gui/plot/test/testItem.py
new file mode 100644
index 0000000..8a6db40
--- /dev/null
+++ b/src/silx/gui/plot/test/testItem.py
@@ -0,0 +1,565 @@
+# /*##########################################################################
+#
+# Copyright (c) 2017-2023 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""Tests for PlotWidget items."""
+
+__authors__ = ["T. Vincent"]
+__license__ = "MIT"
+__date__ = "01/09/2017"
+
+
+import numpy
+import pytest
+
+from silx.gui.utils.testutils import SignalListener
+from silx.gui.plot.items.roi import RegionOfInterest
+from silx.gui.plot.items import ItemChangedType
+from silx.gui.plot import items
+from .utils import PlotWidgetTestCase
+
+
+class TestSigItemChangedSignal(PlotWidgetTestCase):
+ """Test item's sigItemChanged signal"""
+
+ def testCurveChanged(self):
+ """Test sigItemChanged for curve"""
+ self.plot.addCurve(numpy.arange(10), numpy.arange(10), legend="test")
+ curve = self.plot.getCurve("test")
+
+ listener = SignalListener()
+ curve.sigItemChanged.connect(listener)
+
+ # Test for signal in Item class
+ curve.setVisible(False)
+ curve.setVisible(True)
+ curve.setZValue(100)
+
+ # Test for signals in PointsBase class
+ curve.setData(numpy.arange(100), numpy.arange(100))
+
+ # SymbolMixIn
+ curve.setSymbol("Circle")
+ curve.setSymbol("d")
+ curve.setSymbolSize(20)
+
+ # AlphaMixIn
+ curve.setAlpha(0.5)
+
+ # Test for signals in Curve class
+ # ColorMixIn
+ curve.setColor("yellow")
+ # YAxisMixIn
+ curve.setYAxis("right")
+ # FillMixIn
+ curve.setFill(True)
+ # LineMixIn
+ curve.setLineStyle(":")
+ curve.setLineStyle(":") # Not sending event
+ curve.setLineWidth(2)
+
+ self.assertEqual(
+ listener.arguments(argumentIndex=0),
+ [
+ ItemChangedType.VISIBLE,
+ ItemChangedType.VISIBLE,
+ ItemChangedType.ZVALUE,
+ ItemChangedType.DATA,
+ ItemChangedType.SYMBOL,
+ ItemChangedType.SYMBOL,
+ ItemChangedType.SYMBOL_SIZE,
+ ItemChangedType.ALPHA,
+ ItemChangedType.COLOR,
+ ItemChangedType.YAXIS,
+ ItemChangedType.FILL,
+ ItemChangedType.LINE_STYLE,
+ ItemChangedType.LINE_WIDTH,
+ ],
+ )
+
+ def testHistogramChanged(self):
+ """Test sigItemChanged for Histogram"""
+ self.plot.addHistogram(numpy.arange(10), edges=numpy.arange(11), legend="test")
+ histogram = self.plot.getHistogram("test")
+ listener = SignalListener()
+ histogram.sigItemChanged.connect(listener)
+
+ # Test signals in Histogram class
+ histogram.setData(numpy.zeros(10), numpy.arange(11))
+
+ self.assertEqual(listener.arguments(argumentIndex=0), [ItemChangedType.DATA])
+
+ def testImageDataChanged(self):
+ """Test sigItemChanged for ImageData"""
+ self.plot.addImage(numpy.arange(100).reshape(10, 10), legend="test")
+ image = self.plot.getImage("test")
+
+ listener = SignalListener()
+ image.sigItemChanged.connect(listener)
+
+ # ColormapMixIn
+ colormap = self.plot.getDefaultColormap().copy()
+ image.setColormap(colormap)
+ image.getColormap().setName("viridis")
+
+ # Test of signals in ImageBase class
+ image.setOrigin(10)
+ image.setScale(2)
+
+ # Test of signals in ImageData class
+ image.setData(numpy.ones((10, 10)))
+
+ self.assertEqual(
+ listener.arguments(argumentIndex=0),
+ [
+ ItemChangedType.COLORMAP,
+ ItemChangedType.COLORMAP,
+ ItemChangedType.POSITION,
+ ItemChangedType.SCALE,
+ ItemChangedType.COLORMAP,
+ ItemChangedType.DATA,
+ ],
+ )
+
+ def testImageRgbaChanged(self):
+ """Test sigItemChanged for ImageRgba"""
+ self.plot.addImage(numpy.ones((10, 10, 3)), legend="rgb")
+ image = self.plot.getImage("rgb")
+
+ listener = SignalListener()
+ image.sigItemChanged.connect(listener)
+
+ # Test of signals in ImageRgba class
+ image.setData(numpy.zeros((10, 10, 3)))
+
+ self.assertEqual(listener.arguments(argumentIndex=0), [ItemChangedType.DATA])
+
+ def testMarkerChanged(self):
+ """Test sigItemChanged for markers"""
+ self.plot.addMarker(10, 20, legend="test")
+ marker = self.plot._getMarker("test")
+
+ listener = SignalListener()
+ marker.sigItemChanged.connect(listener)
+
+ # Test signals in _BaseMarker
+ marker.setPosition(10, 10)
+ marker.setPosition(10, 10) # Not sending event
+ marker.setText("toto")
+ self.assertEqual(
+ listener.arguments(argumentIndex=0),
+ [ItemChangedType.POSITION, ItemChangedType.TEXT],
+ )
+
+ # XMarker
+ self.plot.addXMarker(10, legend="x")
+ marker = self.plot._getMarker("x")
+
+ listener = SignalListener()
+ marker.sigItemChanged.connect(listener)
+ marker.setPosition(20, 20)
+ self.assertEqual(
+ listener.arguments(argumentIndex=0), [ItemChangedType.POSITION]
+ )
+
+ # YMarker
+ self.plot.addYMarker(10, legend="x")
+ marker = self.plot._getMarker("x")
+
+ listener = SignalListener()
+ marker.sigItemChanged.connect(listener)
+ marker.setPosition(20, 20)
+ self.assertEqual(
+ listener.arguments(argumentIndex=0), [ItemChangedType.POSITION]
+ )
+
+ def testScatterChanged(self):
+ """Test sigItemChanged for scatter"""
+ data = numpy.arange(10)
+ self.plot.addScatter(data, data, data, legend="test")
+ scatter = self.plot.getScatter("test")
+
+ listener = SignalListener()
+ scatter.sigItemChanged.connect(listener)
+
+ # ColormapMixIn
+ scatter.getColormap().setName("viridis")
+
+ # Test of signals in Scatter class
+ scatter.setData((0, 1, 2), (1, 0, 2), (0, 1, 2))
+
+ # Visualization mode changed
+ scatter.setVisualization(scatter.Visualization.SOLID)
+
+ self.assertEqual(
+ listener.arguments(),
+ [
+ (ItemChangedType.COLORMAP,),
+ (ItemChangedType.DATA,),
+ (ItemChangedType.COLORMAP,),
+ (ItemChangedType.VISUALIZATION_MODE,),
+ ],
+ )
+
+ def testShapeChanged(self):
+ """Test sigItemChanged for shape"""
+ data = numpy.array((1.0, 10.0))
+ self.plot.addShape(data, data, legend="test", shape="rectangle")
+ shape = self.plot._getItem(kind="item", legend="test")
+
+ listener = SignalListener()
+ shape.sigItemChanged.connect(listener)
+
+ shape.setOverlay(True)
+ shape.setPoints(((2.0, 2.0), (3.0, 3.0)))
+
+ self.assertEqual(
+ listener.arguments(), [(ItemChangedType.OVERLAY,), (ItemChangedType.DATA,)]
+ )
+
+
+class TestSymbol(PlotWidgetTestCase):
+ """Test item's symbol"""
+
+ def test(self):
+ """Test sigItemChanged for curve"""
+ self.plot.addCurve(numpy.arange(10), numpy.arange(10), legend="test")
+ curve = self.plot.getCurve("test")
+
+ # SymbolMixIn
+ curve.setSymbol("o")
+ name = curve.getSymbolName()
+ self.assertEqual("Circle", name)
+
+ name = curve.getSymbolName("d")
+ self.assertEqual("Diamond", name)
+
+
+class TestVisibleExtent(PlotWidgetTestCase):
+ """Test item's visible extent feature"""
+
+ def testGetVisibleBounds(self):
+ """Test Item.getVisibleBounds"""
+
+ # Create test items (with a bounding box of x: [1,3], y: [0,2])
+ curve = items.Curve()
+ curve.setData((1, 2, 3), (0, 1, 2))
+
+ histogram = items.Histogram()
+ histogram.setData((0, 1, 2), (1, 5 / 3, 7 / 3, 3))
+
+ image = items.ImageData()
+ image.setOrigin((1, 0))
+ image.setData(numpy.arange(4).reshape(2, 2))
+
+ scatter = items.Scatter()
+ scatter.setData((1, 2, 3), (0, 1, 2), (1, 2, 3))
+
+ bbox = items.BoundingRect()
+ bbox.setBounds((1, 3, 0, 2))
+
+ xaxis, yaxis = self.plot.getXAxis(), self.plot.getYAxis()
+ for item in (curve, histogram, image, scatter, bbox):
+ with self.subTest(item=item):
+ xaxis.setLimits(0, 100)
+ yaxis.setLimits(0, 100)
+ self.plot.addItem(item)
+ self.assertEqual(item.getVisibleBounds(), (1.0, 3.0, 0.0, 2.0))
+
+ xaxis.setLimits(0.5, 2.5)
+ self.assertEqual(item.getVisibleBounds(), (1, 2.5, 0.0, 2.0))
+
+ yaxis.setLimits(0.5, 1.5)
+ self.assertEqual(item.getVisibleBounds(), (1, 2.5, 0.5, 1.5))
+
+ item.setVisible(False)
+ self.assertIsNone(item.getVisibleBounds())
+
+ self.plot.clear()
+
+ def testVisibleExtentTracking(self):
+ """Test Item's visible extent tracking"""
+ image = items.ImageData()
+ image.setData(numpy.arange(6).reshape(2, 3))
+
+ listener = SignalListener()
+ image._sigVisibleBoundsChanged.connect(listener)
+ image._setVisibleBoundsTracking(True)
+ self.assertTrue(image._isVisibleBoundsTracking())
+
+ self.plot.addItem(image)
+ self.assertEqual(listener.callCount(), 1)
+
+ self.plot.getXAxis().setLimits(0, 1)
+ self.assertEqual(listener.callCount(), 2)
+
+ self.plot.hide()
+ self.qapp.processEvents()
+ # No event here
+ self.assertEqual(listener.callCount(), 2)
+
+ self.plot.getXAxis().setLimits(1, 2)
+ # No event since PlotWidget is hidden, delayed to PlotWidget show
+ self.assertEqual(listener.callCount(), 2)
+
+ self.plot.show()
+ self.qapp.processEvents()
+ # Receives delayed event now
+ self.assertEqual(listener.callCount(), 3)
+
+ image.setOrigin((-1, -1))
+ self.assertEqual(listener.callCount(), 4)
+
+ image.setVisible(False)
+ image.setOrigin((0, 0))
+ # No event since item is not visible
+ self.assertEqual(listener.callCount(), 4)
+
+ image.setVisible(True)
+ # Receives delayed event now
+ self.assertEqual(listener.callCount(), 5)
+
+
+class TestImageDataAggregated(PlotWidgetTestCase):
+ """Test ImageDataAggregated item"""
+
+ def test(self):
+ data = numpy.random.random(1024**2).reshape(1024, 1024)
+
+ item = items.ImageDataAggregated()
+ item.setData(data)
+ self.assertEqual(item.getAggregationMode(), item.Aggregation.NONE)
+ self.plot.addItem(item)
+
+ for mode in item.Aggregation.members():
+ with self.subTest(mode=mode):
+ self.plot.resetZoom()
+ self.qapp.processEvents()
+
+ item.setAggregationMode(mode)
+ self.qapp.processEvents()
+
+ # Zoom-out
+ for i in range(4):
+ xmin, xmax = self.plot.getXAxis().getLimits()
+ ymin, ymax = self.plot.getYAxis().getLimits()
+ self.plot.setLimits(
+ xmin - (xmax - xmin) / 2,
+ xmax + (xmax - xmin) / 2,
+ ymin - (ymax - ymin) / 2,
+ ymax + (ymax - ymin) / 2,
+ )
+ self.qapp.processEvents()
+
+
+def testRegionOfInterestText():
+ roi = RegionOfInterest()
+
+ listener = SignalListener()
+ roi.sigItemChanged.connect(listener)
+
+ assert roi.getName() == roi.getText()
+
+ roi.setText("some text")
+ assert listener.arguments(argumentIndex=0) == [ItemChangedType.TEXT]
+ listener.clear()
+ assert roi.getText() == "some text"
+
+ roi.setName("new_name")
+ assert listener.arguments(argumentIndex=0) == [ItemChangedType.NAME]
+ listener.clear()
+ assert roi.getText() == "some text"
+
+ roi.setText(None)
+ assert listener.arguments(argumentIndex=0) == [ItemChangedType.TEXT]
+ listener.clear()
+ assert roi.getText() == "new_name"
+
+ roi.setName("even_newer_name")
+ assert listener.arguments(argumentIndex=0) == [
+ ItemChangedType.NAME,
+ ItemChangedType.TEXT,
+ ]
+ assert roi.getText() == "even_newer_name"
+
+
+def testPlotAddItemsWithoutLegend(plotWidget):
+ curve1 = items.Curve()
+ curve1.setData([0, 10], [0, 20])
+ plotWidget.addItem(curve1)
+
+ curve2 = items.Curve()
+ curve2.setData([0, -10], [0, -20])
+ plotWidget.addItem(curve2)
+
+ assert plotWidget.getItems() == (curve1, curve2)
+
+ datarange = plotWidget.getDataRange()
+ assert datarange.x == (-10, 10)
+ assert datarange.y == (-20, 20)
+
+ plotWidget.resetZoom()
+ assert plotWidget.getXAxis().getLimits() == (-10, 10)
+ assert plotWidget.getYAxis().getLimits() == (-20, 20)
+
+
+def testPlotWidgetAddCurve(plotWidget):
+ curve = plotWidget.addCurve(x=(0, 1), y=(1, 0), legend="test", symbol="s")
+ assert isinstance(curve, items.Curve)
+ assert numpy.array_equal(curve.getXData(copy=False), (0, 1))
+ assert numpy.array_equal(curve.getYData(copy=False), (1, 0))
+ assert curve.getName() == "test"
+ assert curve.getSymbol() == "s"
+
+ curveUpdated = plotWidget.addCurve(
+ x=(0, 1, 2), y=(1, 0, 1), legend="test", symbol="o"
+ )
+ assert curveUpdated is curve
+ assert numpy.array_equal(curveUpdated.getXData(copy=False), (0, 1, 2))
+ assert numpy.array_equal(curveUpdated.getYData(copy=False), (1, 0, 1))
+ assert curveUpdated.getName() == "test"
+ assert curveUpdated.getSymbol() == "o"
+
+
+def testPlotWidgetAddImage(plotWidget):
+ image = plotWidget.addImage(((0, 1), (2, 3)), legend="test")
+ assert isinstance(image, items.ImageData)
+ assert numpy.array_equal(image.getData(copy=False), ((0, 1), (2, 3)))
+ assert image.getName() == "test"
+
+ imageUpdated = plotWidget.addImage([(0, 1)], legend="test")
+ assert imageUpdated is image
+ assert numpy.array_equal(image.getData(copy=False), [(0, 1)])
+ assert image.getName() == "test"
+
+ # Update with a 1pixel RGB image
+ imageRgb = plotWidget.addImage([[(0.0, 0.0, 1.0)]], legend="test")
+ assert isinstance(imageRgb, items.ImageRgba)
+ assert numpy.array_equal(imageRgb.getData(copy=False), [[(0.0, 0.0, 1.0)]])
+ assert imageRgb.getName() == "test"
+
+ # Update with a 1pixel RGB image
+ imageRgbUpdated = plotWidget.addImage([[(1.0, 0.0, 0.0)]], legend="test")
+ assert imageRgbUpdated is imageRgb
+ assert numpy.array_equal(imageRgbUpdated.getData(copy=False), [[(1.0, 0.0, 0.0)]])
+ assert imageRgbUpdated.getName() == "test"
+
+
+def testPlotWidgetAddScatter(plotWidget):
+ scatter = plotWidget.addScatter(
+ x=(0, 1), y=(0, 1), value=(0, 1), legend="test", symbol="s"
+ )
+ assert isinstance(scatter, items.Scatter)
+ assert numpy.array_equal(scatter.getXData(copy=False), (0, 1))
+ assert numpy.array_equal(scatter.getYData(copy=False), (0, 1))
+ assert numpy.array_equal(scatter.getValueData(copy=False), (0, 1))
+ assert scatter.getName() == "test"
+ assert scatter.getSymbol() == "s"
+
+
+def testPlotWidgetAddHistogram(plotWidget):
+ histogram = plotWidget.addHistogram(
+ histogram=[1], edges=(0, 1), legend="test", fill=True
+ )
+ assert isinstance(histogram, items.Histogram)
+ assert numpy.array_equal(histogram.getBinEdgesData(copy=False), (0, 1))
+ assert numpy.array_equal(histogram.getValueData(copy=False), [1])
+ assert histogram.getName() == "test"
+ assert histogram.isFill()
+
+
+def testPlotWidgetAddMarker(plotWidget):
+ marker = plotWidget.addMarker(x=0, y=1, legend="test")
+ assert isinstance(marker, items.Marker)
+ assert marker.getPosition() == (0, 1)
+ assert marker.getName() == "test"
+ assert plotWidget.getItems() == (marker,)
+
+ xmarker = plotWidget.addXMarker(1, legend="test")
+ assert isinstance(xmarker, items.XMarker)
+ assert xmarker.getPosition() == (1, None)
+ assert xmarker.getName() == "test"
+ assert plotWidget.getItems() == (xmarker,)
+
+ ymarker = plotWidget.addYMarker(2, legend="test")
+ assert isinstance(ymarker, items.YMarker)
+ assert ymarker.getPosition() == (None, 2)
+ assert ymarker.getName() == "test"
+ assert plotWidget.getItems() == (ymarker,)
+
+
+def testPlotWidgetAddShape(plotWidget):
+ shape = plotWidget.addShape(
+ xdata=(0, 1), ydata=(0, 1), legend="test", shape="polygon"
+ )
+ assert isinstance(shape, items.Shape)
+ assert numpy.array_equal(shape.getPoints(copy=False), ((0, 0), (1, 1)))
+ assert shape.getName() == "test"
+ assert shape.getType() == "polygon"
+
+
+@pytest.mark.parametrize(
+ "linestyle",
+ (
+ "",
+ "-",
+ "--",
+ "-.",
+ ":",
+ (0.0, None),
+ (0.5, ()),
+ (0.0, (5.0, 5.0)),
+ (4.0, (8.0, 4.0, 4.0, 4.0)),
+ ),
+)
+@pytest.mark.parametrize("plotWidget", ("mpl", "gl"), indirect=True)
+def testLineStyle(qapp_utils, plotWidget, linestyle):
+ """Test different line styles for LineMixIn items"""
+ plotWidget.setGraphTitle(f"Line style: {linestyle}")
+
+ curve = plotWidget.addCurve((0, 1), (0, 1), linestyle=linestyle)
+ assert curve.getLineStyle() == linestyle
+
+ histogram = plotWidget.addHistogram((0.25, 0.75, 0.25), (0.0, 0.33, 0.66, 1.0))
+ histogram.setLineStyle(linestyle)
+ assert histogram.getLineStyle() == linestyle
+
+ polylines = plotWidget.addShape(
+ (0, 1), (1, 0), shape="polylines", linestyle=linestyle
+ )
+ assert polylines.getLineStyle() == linestyle
+
+ rectangle = plotWidget.addShape(
+ (0.4, 0.6), (0.4, 0.6), shape="rectangle", linestyle=linestyle
+ )
+ assert rectangle.getLineStyle() == linestyle
+
+ xmarker = plotWidget.addXMarker(0.5)
+ xmarker.setLineStyle(linestyle)
+ assert xmarker.getLineStyle() == linestyle
+
+ ymarker = plotWidget.addYMarker(0.5)
+ ymarker.setLineStyle(linestyle)
+ assert ymarker.getLineStyle() == linestyle
+
+ plotWidget.replot()
+ qapp_utils.qWait(100)
diff --git a/src/silx/gui/plot/test/testLegendSelector.py b/src/silx/gui/plot/test/testLegendSelector.py
new file mode 100644
index 0000000..a1f000a
--- /dev/null
+++ b/src/silx/gui/plot/test/testLegendSelector.py
@@ -0,0 +1,139 @@
+# /*##########################################################################
+#
+# Copyright (c) 2004-2016 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""Basic tests for PlotWidget"""
+
+__authors__ = ["T. Rueter", "T. Vincent"]
+__license__ = "MIT"
+__date__ = "15/05/2017"
+
+
+import logging
+
+from silx.gui import qt
+from silx.gui.utils.testutils import TestCaseQt
+from silx.gui.plot import LegendSelector
+
+
+_logger = logging.getLogger(__name__)
+
+
+class TestLegendSelector(TestCaseQt):
+ """Basic test for LegendSelector"""
+
+ def testLegendSelector(self):
+ """Test copied from __main__ of LegendSelector in PyMca"""
+
+ class Notifier(qt.QObject):
+ def __init__(self):
+ qt.QObject.__init__(self)
+ self.chk = True
+
+ def signalReceived(self, **kw):
+ obj = self.sender()
+ _logger.info("NOTIFIER -- signal received\n\tsender: %s", str(obj))
+
+ notifier = Notifier()
+
+ legends = [
+ "Legend0",
+ "Legend1",
+ "Long Legend 2",
+ "Foo Legend 3",
+ "Even Longer Legend 4",
+ "Short Leg 5",
+ "Dot symbol 6",
+ "Comma symbol 7",
+ ]
+ colors = [
+ qt.Qt.darkRed,
+ qt.Qt.green,
+ qt.Qt.yellow,
+ qt.Qt.darkCyan,
+ qt.Qt.blue,
+ qt.Qt.darkBlue,
+ qt.Qt.red,
+ qt.Qt.darkYellow,
+ ]
+ symbols = ["o", "t", "+", "x", "s", "d", ".", ","]
+
+ win = LegendSelector.LegendListView()
+ # win = LegendListContextMenu()
+ # win = qt.QWidget()
+ # layout = qt.QVBoxLayout()
+ # layout.setContentsMargins(0,0,0,0)
+ llist = []
+
+ for _idx, (l, c, s) in enumerate(zip(legends, colors, symbols)):
+ ddict = {
+ "color": qt.QColor(c),
+ "linewidth": 4,
+ "symbol": s,
+ }
+ legend = l
+ llist.append((legend, ddict))
+ # item = qt.QListWidgetItem(win)
+ # legendWidget = LegendListItemWidget(l)
+ # legendWidget.icon.setSymbol(s)
+ # legendWidget.icon.setColor(qt.QColor(c))
+ # layout.addWidget(legendWidget)
+ # win.setItemWidget(item, legendWidget)
+
+ # win = LegendListItemWidget('Some Legend 1')
+ # print(llist)
+ model = LegendSelector.LegendModel(legendList=llist)
+ win.setModel(model)
+ win.setSelectionModel(qt.QItemSelectionModel(model))
+ win.setContextMenu()
+ # print('Edit triggers: %d'%win.editTriggers())
+
+ # win = LegendListWidget(None, legends)
+ # win[0].updateItem(ddict)
+ # win.setLayout(layout)
+ win.sigLegendSignal.connect(notifier.signalReceived)
+ win.show()
+
+ win.clear()
+ win.setLegendList(llist)
+
+ self.qWaitForWindowExposed(win)
+
+
+class TestRenameCurveDialog(TestCaseQt):
+ """Basic test for RenameCurveDialog"""
+
+ def testDialog(self):
+ """Create dialog, change name and press OK"""
+ self.dialog = LegendSelector.RenameCurveDialog(
+ None, "curve1", ["curve1", "curve2", "curve3"]
+ )
+ self.dialog.open()
+ self.qWaitForWindowExposed(self.dialog)
+ self.keyClicks(self.dialog.lineEdit, "changed")
+ self.mouseClick(self.dialog.okButton, qt.Qt.LeftButton)
+ self.qapp.processEvents()
+ ret = self.dialog.result()
+ self.assertEqual(ret, qt.QDialog.Accepted)
+ newName = self.dialog.getText()
+ self.assertEqual(newName, "curve1changed")
+ del self.dialog
diff --git a/src/silx/gui/plot/test/testLimitConstraints.py b/src/silx/gui/plot/test/testLimitConstraints.py
new file mode 100644
index 0000000..04a53e1
--- /dev/null
+++ b/src/silx/gui/plot/test/testLimitConstraints.py
@@ -0,0 +1,113 @@
+# /*##########################################################################
+#
+# Copyright (c) 2016-2018 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""Test setLimitConstaints on the PlotWidget"""
+
+__authors__ = ["V. Valls"]
+__license__ = "MIT"
+__date__ = "30/08/2017"
+
+
+import unittest
+from silx.gui.plot import PlotWidget
+
+
+class TestLimitConstaints(unittest.TestCase):
+ """Tests setLimitConstaints class"""
+
+ def setUp(self):
+ self.plot = PlotWidget()
+
+ def tearDown(self):
+ self.plot = None
+
+ def testApi(self):
+ """Test availability of the API"""
+ self.plot.getXAxis().setLimitsConstraints(minPos=1, maxPos=10)
+ self.plot.getXAxis().setRangeConstraints(minRange=1, maxRange=1)
+ self.plot.getYAxis().setLimitsConstraints(minPos=1, maxPos=10)
+ self.plot.getYAxis().setRangeConstraints(minRange=1, maxRange=1)
+
+ def testXMinMax(self):
+ """Test limit constains on x-axis"""
+ self.plot.getXAxis().setLimitsConstraints(minPos=0, maxPos=100)
+ self.plot.setLimits(xmin=-1, xmax=101, ymin=-1, ymax=101)
+ self.assertEqual(self.plot.getXAxis().getLimits(), (0, 100))
+ self.assertEqual(self.plot.getYAxis().getLimits(), (-1, 101))
+
+ def testYMinMax(self):
+ """Test limit constains on y-axis"""
+ self.plot.getYAxis().setLimitsConstraints(minPos=0, maxPos=100)
+ self.plot.setLimits(xmin=-1, xmax=101, ymin=-1, ymax=101)
+ self.assertEqual(self.plot.getXAxis().getLimits(), (-1, 101))
+ self.assertEqual(self.plot.getYAxis().getLimits(), (0, 100))
+
+ def testMinXRange(self):
+ """Test min range constains on x-axis"""
+ self.plot.getXAxis().setRangeConstraints(minRange=100)
+ self.plot.setLimits(xmin=1, xmax=99, ymin=1, ymax=99)
+ limits = self.plot.getXAxis().getLimits()
+ self.assertEqual(limits[1] - limits[0], 100)
+ limits = self.plot.getYAxis().getLimits()
+ self.assertNotEqual(limits[1] - limits[0], 100)
+
+ def testMaxXRange(self):
+ """Test max range constains on x-axis"""
+ self.plot.getXAxis().setRangeConstraints(maxRange=100)
+ self.plot.setLimits(xmin=-1, xmax=101, ymin=-1, ymax=101)
+ limits = self.plot.getXAxis().getLimits()
+ self.assertEqual(limits[1] - limits[0], 100)
+ limits = self.plot.getYAxis().getLimits()
+ self.assertNotEqual(limits[1] - limits[0], 100)
+
+ def testMinYRange(self):
+ """Test min range constains on y-axis"""
+ self.plot.getYAxis().setRangeConstraints(minRange=100)
+ self.plot.setLimits(xmin=1, xmax=99, ymin=1, ymax=99)
+ limits = self.plot.getXAxis().getLimits()
+ self.assertNotEqual(limits[1] - limits[0], 100)
+ limits = self.plot.getYAxis().getLimits()
+ self.assertEqual(limits[1] - limits[0], 100)
+
+ def testMaxYRange(self):
+ """Test max range constains on y-axis"""
+ self.plot.getYAxis().setRangeConstraints(maxRange=100)
+ self.plot.setLimits(xmin=-1, xmax=101, ymin=-1, ymax=101)
+ limits = self.plot.getXAxis().getLimits()
+ self.assertNotEqual(limits[1] - limits[0], 100)
+ limits = self.plot.getYAxis().getLimits()
+ self.assertEqual(limits[1] - limits[0], 100)
+
+ def testChangeOfConstraints(self):
+ """Test changing of the constraints"""
+ self.plot.getXAxis().setRangeConstraints(minRange=10, maxRange=10)
+ # There is no more constraints on the range
+ self.plot.getXAxis().setRangeConstraints(minRange=None, maxRange=None)
+ self.plot.setLimits(xmin=-1, xmax=101, ymin=-1, ymax=101)
+ self.assertEqual(self.plot.getXAxis().getLimits(), (-1, 101))
+
+ def testSettingConstraints(self):
+ """Test setting a constaint (setLimits first then the constaint)"""
+ self.plot.setLimits(xmin=-1, xmax=101, ymin=-1, ymax=101)
+ self.plot.getXAxis().setLimitsConstraints(minPos=0, maxPos=100)
+ self.assertEqual(self.plot.getXAxis().getLimits(), (0, 100))
diff --git a/src/silx/gui/plot/test/testMaskToolsWidget.py b/src/silx/gui/plot/test/testMaskToolsWidget.py
new file mode 100644
index 0000000..1428687
--- /dev/null
+++ b/src/silx/gui/plot/test/testMaskToolsWidget.py
@@ -0,0 +1,321 @@
+# /*##########################################################################
+#
+# Copyright (c) 2016-2017 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""Basic tests for MaskToolsWidget"""
+
+__authors__ = ["T. Vincent"]
+__license__ = "MIT"
+__date__ = "17/01/2018"
+
+
+import logging
+import os.path
+
+import numpy
+
+from silx.gui import qt
+from silx.test.utils import temp_dir
+from silx.utils.testutils import ParametricTestCase
+from silx.gui.utils.testutils import getQToolButtonFromAction
+from silx.gui.plot import PlotWindow, MaskToolsWidget
+from .utils import PlotWidgetTestCase
+
+
+_logger = logging.getLogger(__name__)
+
+
+class TestMaskToolsWidget(PlotWidgetTestCase, ParametricTestCase):
+ """Basic test for MaskToolsWidget"""
+
+ def _createPlot(self):
+ return PlotWindow()
+
+ def setUp(self):
+ super(TestMaskToolsWidget, self).setUp()
+ self.widget = MaskToolsWidget.MaskToolsDockWidget(plot=self.plot, name="TEST")
+ self.plot.addDockWidget(qt.Qt.BottomDockWidgetArea, self.widget)
+ self.maskWidget = self.widget.widget()
+
+ def tearDown(self):
+ del self.maskWidget
+ del self.widget
+ super(TestMaskToolsWidget, self).tearDown()
+
+ def testEmptyPlot(self):
+ """Empty plot, display MaskToolsDockWidget, toggle multiple masks"""
+ self.maskWidget.setMultipleMasks("single")
+ self.qapp.processEvents()
+
+ self.maskWidget.setMultipleMasks("exclusive")
+ self.qapp.processEvents()
+
+ def _drag(self):
+ """Drag from plot center to offset position"""
+ plot = self.plot.getWidgetHandle()
+ xCenter, yCenter = plot.width() // 2, plot.height() // 2
+ offset = min(plot.width(), plot.height()) // 10
+
+ pos0 = xCenter, yCenter
+ pos1 = xCenter + offset, yCenter + offset
+
+ self.mouseMove(plot, pos=(0, 0))
+ self.mouseMove(plot, pos=pos0)
+ self.qapp.processEvents()
+ self.mousePress(plot, qt.Qt.LeftButton, pos=pos0)
+ self.qapp.processEvents()
+ self.mouseMove(plot, pos=(pos0[0] + offset // 2, pos0[1] + offset // 2))
+ self.mouseMove(plot, pos=pos1)
+ self.qapp.processEvents()
+ self.mouseRelease(plot, qt.Qt.LeftButton, pos=pos1)
+ self.qapp.processEvents()
+ self.mouseMove(plot, pos=(0, 0))
+
+ def _drawPolygon(self):
+ """Draw a star polygon in the plot"""
+ plot = self.plot.getWidgetHandle()
+ x, y = plot.width() // 2, plot.height() // 2
+ offset = min(plot.width(), plot.height()) // 10
+
+ star = [
+ (x, y + offset),
+ (x - offset, y - offset),
+ (x + offset, y),
+ (x - offset, y),
+ (x + offset, y - offset),
+ (x, y + offset),
+ ] # Close polygon
+
+ self.mouseMove(plot, pos=(0, 0))
+ for pos in star:
+ self.mouseMove(plot, pos=pos)
+ self.qapp.processEvents()
+ self.mousePress(plot, qt.Qt.LeftButton, pos=pos)
+ self.qapp.processEvents()
+ self.mouseRelease(plot, qt.Qt.LeftButton, pos=pos)
+ self.qapp.processEvents()
+
+ def _drawPencil(self):
+ """Draw a star polygon in the plot"""
+ plot = self.plot.getWidgetHandle()
+ x, y = plot.width() // 2, plot.height() // 2
+ offset = min(plot.width(), plot.height()) // 10
+
+ star = [
+ (x, y + offset),
+ (x - offset, y - offset),
+ (x + offset, y),
+ (x - offset, y),
+ (x + offset, y - offset),
+ ]
+
+ self.mouseMove(plot, pos=(0, 0))
+ for start, end in zip(star[:-1], star[1:]):
+ self.mouseMove(plot, pos=start)
+ self.mousePress(plot, qt.Qt.LeftButton, pos=start)
+ self.qapp.processEvents()
+ self.mouseMove(plot, pos=end)
+ self.qapp.processEvents()
+ self.mouseRelease(plot, qt.Qt.LeftButton, pos=end)
+ self.qapp.processEvents()
+
+ def _isMaskItemSync(self):
+ """Check if masks from item and tools are sync or not"""
+ if self.maskWidget.isItemMaskUpdated():
+ return numpy.all(
+ numpy.equal(
+ self.maskWidget.getSelectionMask(),
+ self.plot.getActiveImage().getMaskData(copy=False),
+ )
+ )
+ else:
+ return True
+
+ def testWithAnImage(self):
+ """Plot with an image: test MaskToolsWidget interactions"""
+
+ # Add and remove a image (this should enable/disable GUI + change mask)
+ self.plot.addImage(
+ numpy.random.random(1024**2).reshape(1024, 1024), legend="test"
+ )
+ self.qapp.processEvents()
+
+ self.plot.remove("test", kind="image")
+ self.qapp.processEvents()
+
+ tests = [
+ ((0, 0), (1, 1)),
+ ((1000, 1000), (1, 1)),
+ ((0, 0), (-1, -1)),
+ ((1000, 1000), (-1, -1)),
+ ]
+
+ for itemMaskUpdated in (False, True):
+ for origin, scale in tests:
+ with self.subTest(origin=origin, scale=scale):
+ self.maskWidget.setItemMaskUpdated(itemMaskUpdated)
+ self.plot.addImage(
+ numpy.arange(1024**2).reshape(1024, 1024),
+ legend="test",
+ origin=origin,
+ scale=scale,
+ )
+ self.qapp.processEvents()
+
+ self.assertEqual(
+ self.maskWidget.isItemMaskUpdated(), itemMaskUpdated
+ )
+
+ # Test draw rectangle #
+ toolButton = getQToolButtonFromAction(self.maskWidget.rectAction)
+ self.assertIsNot(toolButton, None)
+ self.mouseClick(toolButton, qt.Qt.LeftButton)
+
+ # mask
+ self.maskWidget.maskStateGroup.button(1).click()
+ self.qapp.processEvents()
+ self._drag()
+ self.assertFalse(
+ numpy.all(numpy.equal(self.maskWidget.getSelectionMask(), 0))
+ )
+ self.assertTrue(self._isMaskItemSync())
+
+ # unmask same region
+ self.maskWidget.maskStateGroup.button(0).click()
+ self.qapp.processEvents()
+ self._drag()
+ self.assertTrue(
+ numpy.all(numpy.equal(self.maskWidget.getSelectionMask(), 0))
+ )
+ self.assertTrue(self._isMaskItemSync())
+
+ # Test draw polygon #
+ toolButton = getQToolButtonFromAction(self.maskWidget.polygonAction)
+ self.assertIsNot(toolButton, None)
+ self.mouseClick(toolButton, qt.Qt.LeftButton)
+
+ # mask
+ self.maskWidget.maskStateGroup.button(1).click()
+ self.qapp.processEvents()
+ self._drawPolygon()
+ self.assertFalse(
+ numpy.all(numpy.equal(self.maskWidget.getSelectionMask(), 0))
+ )
+ self.assertTrue(self._isMaskItemSync())
+
+ # unmask same region
+ self.maskWidget.maskStateGroup.button(0).click()
+ self.qapp.processEvents()
+ self._drawPolygon()
+ self.assertTrue(
+ numpy.all(numpy.equal(self.maskWidget.getSelectionMask(), 0))
+ )
+ self.assertTrue(self._isMaskItemSync())
+
+ # Test draw pencil #
+ toolButton = getQToolButtonFromAction(self.maskWidget.pencilAction)
+ self.assertIsNot(toolButton, None)
+ self.mouseClick(toolButton, qt.Qt.LeftButton)
+
+ self.maskWidget.pencilSpinBox.setValue(30)
+ self.qapp.processEvents()
+
+ # mask
+ self.maskWidget.maskStateGroup.button(1).click()
+ self.qapp.processEvents()
+ self._drawPencil()
+ self.assertFalse(
+ numpy.all(numpy.equal(self.maskWidget.getSelectionMask(), 0))
+ )
+ self.assertTrue(self._isMaskItemSync())
+
+ # unmask same region
+ self.maskWidget.maskStateGroup.button(0).click()
+ self.qapp.processEvents()
+ self._drawPencil()
+ self.assertTrue(
+ numpy.all(numpy.equal(self.maskWidget.getSelectionMask(), 0))
+ )
+ self.assertTrue(self._isMaskItemSync())
+
+ # Test no draw tool #
+ toolButton = getQToolButtonFromAction(self.maskWidget.browseAction)
+ self.assertIsNot(toolButton, None)
+ self.mouseClick(toolButton, qt.Qt.LeftButton)
+
+ self.plot.clear()
+
+ def __loadSave(self, file_format):
+ """Plot with an image: test MaskToolsWidget operations"""
+ self.plot.addImage(numpy.arange(1024**2).reshape(1024, 1024), legend="test")
+ self.qapp.processEvents()
+
+ # Draw a polygon mask
+ toolButton = getQToolButtonFromAction(self.maskWidget.polygonAction)
+ self.assertIsNot(toolButton, None)
+ self.mouseClick(toolButton, qt.Qt.LeftButton)
+ self._drawPolygon()
+
+ ref_mask = self.maskWidget.getSelectionMask()
+ self.assertFalse(numpy.all(numpy.equal(ref_mask, 0)))
+
+ with temp_dir() as tmp:
+ mask_filename = os.path.join(tmp, "mask." + file_format)
+ self.maskWidget.save(mask_filename, file_format)
+
+ self.maskWidget.resetSelectionMask()
+ self.assertTrue(
+ numpy.all(numpy.equal(self.maskWidget.getSelectionMask(), 0))
+ )
+
+ self.maskWidget.load(mask_filename)
+ self.assertTrue(
+ numpy.all(numpy.equal(self.maskWidget.getSelectionMask(), ref_mask))
+ )
+
+ def testLoadSaveNpy(self):
+ self.__loadSave("npy")
+
+ def testLoadSaveFit2D(self):
+ self.__loadSave("msk")
+
+ def testSigMaskChangedEmitted(self):
+ self.plot.addImage(numpy.arange(512**2).reshape(512, 512), legend="test")
+ self.plot.resetZoom()
+ self.qapp.processEvents()
+
+ l = []
+
+ def slot():
+ l.append(1)
+
+ self.maskWidget.sigMaskChanged.connect(slot)
+
+ # rectangle mask
+ toolButton = getQToolButtonFromAction(self.maskWidget.rectAction)
+ self.assertIsNot(toolButton, None)
+ self.mouseClick(toolButton, qt.Qt.LeftButton)
+ self.maskWidget.maskStateGroup.button(1).click()
+ self.qapp.processEvents()
+ self._drag()
+
+ self.assertGreater(len(l), 0)
diff --git a/src/silx/gui/plot/test/testPixelIntensityHistoAction.py b/src/silx/gui/plot/test/testPixelIntensityHistoAction.py
new file mode 100644
index 0000000..7fd87e8
--- /dev/null
+++ b/src/silx/gui/plot/test/testPixelIntensityHistoAction.py
@@ -0,0 +1,149 @@
+# /*##########################################################################
+#
+# Copyright (c) 2016-2018 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""Basic tests for PixelIntensitiesHistoAction"""
+
+__authors__ = ["T. Vincent"]
+__license__ = "MIT"
+__date__ = "02/03/2018"
+
+
+import numpy
+
+from silx.utils.testutils import ParametricTestCase
+from silx.gui.utils.testutils import TestCaseQt, getQToolButtonFromAction
+from silx.gui import qt
+from silx.gui.plot import Plot2D
+
+
+class TestPixelIntensitiesHisto(TestCaseQt, ParametricTestCase):
+ """Tests for PixelIntensitiesHistoAction widget."""
+
+ def setUp(self):
+ super(TestPixelIntensitiesHisto, self).setUp()
+ self.image = numpy.random.rand(10, 10)
+ self.plotImage = Plot2D()
+ self.plotImage.getIntensityHistogramAction().setVisible(True)
+
+ def tearDown(self):
+ del self.plotImage
+ super(TestPixelIntensitiesHisto, self).tearDown()
+
+ def testShowAndHide(self):
+ """Simple test that the plot is showing and hiding when activating the
+ action"""
+ self.plotImage.addImage(self.image, origin=(0, 0), legend="sino")
+ self.plotImage.show()
+
+ histoAction = self.plotImage.getIntensityHistogramAction()
+
+ # test the pixel intensity diagram is showing
+ button = getQToolButtonFromAction(histoAction)
+ self.assertIsNot(button, None)
+ self.mouseMove(button)
+ self.mouseClick(button, qt.Qt.LeftButton)
+ self.qapp.processEvents()
+ self.assertTrue(histoAction.getHistogramWidget().isVisible())
+
+ # test the pixel intensity diagram is hiding
+ self.plotImage.activateWindow()
+ self.qapp.processEvents()
+ self.mouseMove(button)
+ self.mouseClick(button, qt.Qt.LeftButton)
+ self.qapp.processEvents()
+ self.assertFalse(histoAction.getHistogramWidget().isVisible())
+
+ def testImageFormatInput(self):
+ """Test multiple type as image input"""
+ typesToTest = [
+ numpy.uint8,
+ numpy.int8,
+ numpy.int16,
+ numpy.int32,
+ numpy.float32,
+ numpy.float64,
+ ]
+ self.plotImage.addImage(self.image, origin=(0, 0), legend="sino")
+ self.plotImage.show()
+ button = getQToolButtonFromAction(self.plotImage.getIntensityHistogramAction())
+ self.mouseMove(button)
+ self.mouseClick(button, qt.Qt.LeftButton)
+ self.qapp.processEvents()
+ for typeToTest in typesToTest:
+ with self.subTest(typeToTest=typeToTest):
+ self.plotImage.addImage(
+ self.image.astype(typeToTest), origin=(0, 0), legend="sino"
+ )
+
+ def testScatter(self):
+ """Test that an histogram from a scatter is displayed"""
+ xx = numpy.arange(10)
+ yy = numpy.arange(10)
+ value = numpy.sin(xx)
+ self.plotImage.addScatter(xx, yy, value)
+ self.plotImage.show()
+
+ histoAction = self.plotImage.getIntensityHistogramAction()
+
+ # test the pixel intensity diagram is showing
+ button = getQToolButtonFromAction(histoAction)
+ self.assertIsNot(button, None)
+ self.mouseMove(button)
+ self.mouseClick(button, qt.Qt.LeftButton)
+ self.qapp.processEvents()
+
+ widget = histoAction.getHistogramWidget()
+ self.assertTrue(widget.isVisible())
+ items = widget.getPlotWidget().getItems()
+ self.assertEqual(len(items), 1)
+
+ def testChangeItem(self):
+ """Test that histogram changes it the item changes"""
+ xx = numpy.arange(10)
+ yy = numpy.arange(10)
+ value = numpy.sin(xx)
+ self.plotImage.addScatter(xx, yy, value)
+ self.plotImage.show()
+
+ histoAction = self.plotImage.getIntensityHistogramAction()
+
+ # test the pixel intensity diagram is showing
+ button = getQToolButtonFromAction(histoAction)
+ self.assertIsNot(button, None)
+ self.mouseMove(button)
+ self.mouseClick(button, qt.Qt.LeftButton)
+ self.qapp.processEvents()
+
+ # Reach histogram from the first item
+ widget = histoAction.getHistogramWidget()
+ self.assertTrue(widget.isVisible())
+ items = widget.getPlotWidget().getItems()
+ data1 = items[0].getValueData(copy=False)
+
+ # Set another item to the plot
+ self.plotImage.addImage(self.image, origin=(0, 0), legend="sino")
+ self.qapp.processEvents()
+ data2 = items[0].getValueData(copy=False)
+
+ # Histogram is not the same
+ self.assertFalse(numpy.array_equal(data1, data2))
diff --git a/src/silx/gui/plot/test/testPlotActions.py b/src/silx/gui/plot/test/testPlotActions.py
new file mode 100644
index 0000000..9f56aad
--- /dev/null
+++ b/src/silx/gui/plot/test/testPlotActions.py
@@ -0,0 +1,106 @@
+# /*##########################################################################
+#
+# Copyright (c) 2016-2020 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""Test of actions integrated in the plot window"""
+
+__authors__ = ["T. Vincent"]
+__license__ = "MIT"
+__date__ = "09/11/2018"
+
+
+import pytest
+import weakref
+
+from silx.gui import qt
+from silx.gui.colors import Colormap
+from silx.gui.plot.PlotWindow import PlotWindow
+
+import numpy
+
+
+@pytest.fixture
+def colormap1():
+ colormap = Colormap(name="gray", vmin=10.0, vmax=20.0, normalization="linear")
+ yield colormap
+
+
+@pytest.fixture
+def colormap2():
+ colormap = Colormap(name="red", vmin=10.0, vmax=20.0, normalization="linear")
+ yield colormap
+
+
+@pytest.fixture
+def plot(qapp):
+ plot = PlotWindow()
+ plot.setAttribute(qt.Qt.WA_DeleteOnClose)
+ yield weakref.proxy(plot)
+ plot.close()
+ qapp.processEvents()
+
+
+def test_action_active_colormap(qapp_utils, plot, colormap1, colormap2):
+ plot.getColormapAction()._actionTriggered(checked=True)
+ colormapDialog = plot.getColormapAction()._dialog
+
+ defaultColormap = plot.getDefaultColormap()
+ assert colormapDialog.getColormap() is defaultColormap
+
+ plot.addImage(
+ data=numpy.random.rand(10, 10), legend="img1", origin=(0, 0), colormap=colormap1
+ )
+ plot.setActiveImage("img1")
+ assert colormapDialog.getColormap() is colormap1
+
+ plot.addImage(
+ data=numpy.random.rand(10, 10), legend="img2", origin=(0, 0), colormap=colormap2
+ )
+ plot.addImage(data=numpy.random.rand(10, 10), legend="img3", origin=(0, 0))
+
+ plot.setActiveImage("img3")
+ assert colormapDialog.getColormap() is defaultColormap
+ plot.getActiveImage().setColormap(colormap2)
+ assert colormapDialog.getColormap() is colormap2
+
+ plot.remove("img2")
+ plot.remove("img3")
+ plot.remove("img1")
+ assert colormapDialog.getColormap() is defaultColormap
+
+
+def test_action_show_hide_colormap_dialog(qapp_utils, plot, colormap1):
+ plot.getColormapAction()._actionTriggered(checked=True)
+ colormapDialog = plot.getColormapAction()._dialog
+
+ plot.getColormapAction()._actionTriggered(checked=False)
+ assert not plot.getColormapAction().isChecked()
+ plot.getColormapAction()._actionTriggered(checked=True)
+ assert plot.getColormapAction().isChecked()
+ plot.addImage(
+ data=numpy.random.rand(10, 10), legend="img1", origin=(0, 0), colormap=colormap1
+ )
+ colormap1.setName("red")
+ plot.getColormapAction()._actionTriggered()
+ colormap1.setName("blue")
+ colormapDialog.close()
+ assert not plot.getColormapAction().isChecked()
diff --git a/src/silx/gui/plot/test/testPlotInteraction.py b/src/silx/gui/plot/test/testPlotInteraction.py
new file mode 100644
index 0000000..a97a694
--- /dev/null
+++ b/src/silx/gui/plot/test/testPlotInteraction.py
@@ -0,0 +1,232 @@
+# /*##########################################################################
+#
+# Copyright (c) 2016-2023 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""Tests of plot interaction, through a PlotWidget"""
+
+__authors__ = ["T. Vincent"]
+__license__ = "MIT"
+__date__ = "01/09/2017"
+
+import pytest
+
+from silx.gui import qt
+from silx.gui.plot import PlotWidget
+from .utils import PlotWidgetTestCase
+
+
+class _SignalDump(object):
+ """Callable object that store passed arguments in a list"""
+
+ def __init__(self):
+ self._received = []
+
+ def __call__(self, *args):
+ self._received.append(args)
+
+ @property
+ def received(self):
+ """Return a shallow copy of the list of received arguments"""
+ return list(self._received)
+
+
+class TestSelectPolygon(PlotWidgetTestCase):
+ """Test polygon selection interaction"""
+
+ def _interactionModeChanged(self, source):
+ """Check that source received in event is the correct one"""
+ self.assertEqual(source, self)
+
+ def _draw(self, polygon):
+ """Draw a polygon in the plot
+
+ :param polygon: List of points (x, y) of the polygon (closed)
+ """
+ plot = self.plot.getWidgetHandle()
+
+ dump = _SignalDump()
+ self.plot.sigPlotSignal.connect(dump)
+
+ for pos in polygon:
+ self.mouseMove(plot, pos=pos)
+ self.qapp.processEvents()
+ self.mousePress(plot, qt.Qt.LeftButton, pos=pos)
+ self.qapp.processEvents()
+ self.mouseRelease(plot, qt.Qt.LeftButton, pos=pos)
+ self.qapp.processEvents()
+
+ self.plot.sigPlotSignal.disconnect(dump)
+ return [args[0] for args in dump.received]
+
+ def test(self):
+ """Test draw polygons + events"""
+ self.plot.sigInteractiveModeChanged.connect(self._interactionModeChanged)
+
+ self.plot.setInteractiveMode("draw", shape="polygon", label="test", source=self)
+ interaction = self.plot.getInteractiveMode()
+
+ self.assertEqual(interaction["mode"], "draw")
+ self.assertEqual(interaction["shape"], "polygon")
+
+ self.plot.sigInteractiveModeChanged.disconnect(self._interactionModeChanged)
+
+ plot = self.plot.getWidgetHandle()
+ xCenter, yCenter = plot.width() // 2, plot.height() // 2
+ offset = min(plot.width(), plot.height()) // 10
+
+ # Star polygon
+ star = [
+ (xCenter, yCenter + offset),
+ (xCenter - offset, yCenter - offset),
+ (xCenter + offset, yCenter),
+ (xCenter - offset, yCenter),
+ (xCenter + offset, yCenter - offset),
+ (xCenter, yCenter + offset),
+ ] # Close polygon
+
+ # Draw while dumping signals
+ events = self._draw(star)
+
+ # Test last event
+ drawEvents = [event for event in events if event["event"].startswith("drawing")]
+ self.assertEqual(drawEvents[-1]["event"], "drawingFinished")
+ self.assertEqual(len(drawEvents[-1]["points"]), 6)
+
+ # Large square
+ largeSquare = [
+ (xCenter - offset, yCenter - offset),
+ (xCenter + offset, yCenter - offset),
+ (xCenter + offset, yCenter + offset),
+ (xCenter - offset, yCenter + offset),
+ (xCenter - offset, yCenter - offset),
+ ] # Close polygon
+
+ # Draw while dumping signals
+ events = self._draw(largeSquare)
+
+ # Test last event
+ drawEvents = [event for event in events if event["event"].startswith("drawing")]
+ self.assertEqual(drawEvents[-1]["event"], "drawingFinished")
+ self.assertEqual(len(drawEvents[-1]["points"]), 5)
+
+ # Rectangle too thin along X: Some points are ignored
+ thinRectX = [
+ (xCenter, yCenter - offset),
+ (xCenter, yCenter + offset),
+ (xCenter + 1, yCenter + offset),
+ (xCenter + 1, yCenter - offset),
+ ] # Close polygon
+
+ # Draw while dumping signals
+ events = self._draw(thinRectX)
+
+ # Test last event
+ drawEvents = [event for event in events if event["event"].startswith("drawing")]
+ self.assertEqual(drawEvents[-1]["event"], "drawingFinished")
+ self.assertEqual(len(drawEvents[-1]["points"]), 3)
+
+ # Rectangle too thin along Y: Some points are ignored
+ thinRectY = [
+ (xCenter - offset, yCenter),
+ (xCenter + offset, yCenter),
+ (xCenter + offset, yCenter + 1),
+ (xCenter - offset, yCenter + 1),
+ ] # Close polygon
+
+ # Draw while dumping signals
+ events = self._draw(thinRectY)
+
+ # Test last event
+ drawEvents = [event for event in events if event["event"].startswith("drawing")]
+ self.assertEqual(drawEvents[-1]["event"], "drawingFinished")
+ self.assertEqual(len(drawEvents[-1]["points"]), 3)
+
+
+@pytest.mark.parametrize("scale", ["linear", "log"])
+@pytest.mark.parametrize("xaxis", [True, False])
+@pytest.mark.parametrize("yaxis", [True, False])
+@pytest.mark.parametrize("y2axis", [True, False])
+def testZoomEnabledAxes(qapp, qWidgetFactory, scale, xaxis, yaxis, y2axis):
+ """Test PlotInteraction.setZoomEnabledAxes effect on zoom interaction"""
+ plotWidget = qWidgetFactory(PlotWidget)
+ plotWidget.getXAxis().setScale(scale)
+ plotWidget.getYAxis("left").setScale(scale)
+ plotWidget.getYAxis("right").setScale(scale)
+ qapp.processEvents()
+
+ xLimits = plotWidget.getXAxis().getLimits()
+ yLimits = plotWidget.getYAxis("left").getLimits()
+ y2Limits = plotWidget.getYAxis("right").getLimits()
+
+ interaction = plotWidget.interaction()
+
+ assert interaction.getZoomEnabledAxes() == (True, True, True)
+
+ enabledAxes = xaxis, yaxis, y2axis
+ interaction.setZoomEnabledAxes(*enabledAxes)
+ assert interaction.getZoomEnabledAxes() == enabledAxes
+
+ cx, cy = plotWidget.width() // 2, plotWidget.height() // 2
+ plotWidget.onMouseWheel(cx, cy, 10)
+ qapp.processEvents()
+
+ xZoomed = plotWidget.getXAxis().getLimits() != xLimits
+ yZoomed = plotWidget.getYAxis("left").getLimits() != yLimits
+ y2Zoomed = plotWidget.getYAxis("right").getLimits() != y2Limits
+
+ assert xZoomed == enabledAxes[0]
+ assert yZoomed == enabledAxes[1]
+ assert y2Zoomed == enabledAxes[2]
+
+
+@pytest.mark.parametrize("scale", ["linear", "log"])
+@pytest.mark.parametrize("zoomOnWheel", [True, False])
+def testZoomOnWheelEnabled(qapp, qWidgetFactory, zoomOnWheel, scale):
+ """Test PlotInteraction.setZoomOnWheelEnabled"""
+ plotWidget = qWidgetFactory(PlotWidget)
+ plotWidget.getXAxis().setScale(scale)
+ plotWidget.getYAxis("left").setScale(scale)
+ plotWidget.getYAxis("right").setScale(scale)
+ qapp.processEvents()
+
+ xLimits = plotWidget.getXAxis().getLimits()
+ yLimits = plotWidget.getYAxis("left").getLimits()
+ y2Limits = plotWidget.getYAxis("right").getLimits()
+
+ interaction = plotWidget.interaction()
+
+ assert interaction.isZoomOnWheelEnabled()
+
+ interaction.setZoomOnWheelEnabled(zoomOnWheel)
+ assert interaction.isZoomOnWheelEnabled() == zoomOnWheel
+
+ cx, cy = plotWidget.width() // 2, plotWidget.height() // 2
+ plotWidget.onMouseWheel(cx, cy, 10)
+ qapp.processEvents()
+
+ xZoomed = plotWidget.getXAxis().getLimits() != xLimits
+ yZoomed = plotWidget.getYAxis("left").getLimits() != yLimits
+ y2Zoomed = plotWidget.getYAxis("right").getLimits() != y2Limits
+
+ assert xZoomed == zoomOnWheel
+ assert yZoomed == zoomOnWheel
+ assert y2Zoomed == zoomOnWheel
diff --git a/src/silx/gui/plot/test/testPlotWidget.py b/src/silx/gui/plot/test/testPlotWidget.py
new file mode 100755
index 0000000..842e880
--- /dev/null
+++ b/src/silx/gui/plot/test/testPlotWidget.py
@@ -0,0 +1,2049 @@
+# /*##########################################################################
+#
+# Copyright (c) 2016-2023 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""Basic tests for PlotWidget"""
+
+__authors__ = ["T. Vincent"]
+__license__ = "MIT"
+__date__ = "03/01/2019"
+
+
+import unittest
+import numpy
+import pytest
+
+from silx.utils.testutils import ParametricTestCase
+from silx.gui.utils.testutils import SignalListener
+from silx.gui.utils.testutils import TestCaseQt
+
+from silx.gui import qt
+from silx.gui.plot import PlotWidget
+from silx.gui.plot.items import BoundingRect, XAxisExtent, YAxisExtent, Axis
+from silx.gui.colors import Colormap
+
+from .utils import PlotWidgetTestCase
+
+
+SIZE = 1024
+"""Size of the test image"""
+
+DATA_2D = numpy.arange(SIZE**2).reshape(SIZE, SIZE)
+"""Image data set"""
+
+
+class TestSpecialBackend(PlotWidgetTestCase, ParametricTestCase):
+ def __init__(self, methodName="runTest", backend=None):
+ TestCaseQt.__init__(self, methodName=methodName)
+ self.__backend = backend
+
+ def _createPlot(self):
+ return PlotWidget(backend=self.__backend)
+
+ def testPlot(self):
+ self.assertIsNotNone(self.plot)
+
+
+class TestPlotWidget(PlotWidgetTestCase, ParametricTestCase):
+ """Basic tests for PlotWidget"""
+
+ def testShow(self):
+ """Most basic test"""
+ pass
+
+ def testSetTitleLabels(self):
+ """Set title and axes labels"""
+
+ title, xlabel, ylabel = "the title", "x label", "y label"
+ self.plot.setGraphTitle(title)
+ self.plot.getXAxis().setLabel(xlabel)
+ self.plot.getYAxis().setLabel(ylabel)
+ self.qapp.processEvents()
+
+ self.assertEqual(self.plot.getGraphTitle(), title)
+ self.assertEqual(self.plot.getXAxis().getLabel(), xlabel)
+ self.assertEqual(self.plot.getYAxis().getLabel(), ylabel)
+
+ def _checkLimits(self, expectedXLim=None, expectedYLim=None, expectedRatio=None):
+ """Assert that limits are as expected"""
+ xlim = self.plot.getXAxis().getLimits()
+ ylim = self.plot.getYAxis().getLimits()
+ ratio = abs(xlim[1] - xlim[0]) / abs(ylim[1] - ylim[0])
+
+ if expectedXLim is not None:
+ self.assertEqual(expectedXLim, xlim)
+
+ if expectedYLim is not None:
+ self.assertEqual(expectedYLim, ylim)
+
+ if expectedRatio is not None:
+ self.assertTrue(numpy.allclose(expectedRatio, ratio, atol=0.01))
+
+ def testChangeLimitsWithAspectRatio(self):
+ self.plot.setKeepDataAspectRatio()
+ self.qapp.processEvents()
+ xlim = self.plot.getXAxis().getLimits()
+ ylim = self.plot.getYAxis().getLimits()
+ defaultRatio = abs(xlim[1] - xlim[0]) / abs(ylim[1] - ylim[0])
+
+ self.plot.getXAxis().setLimits(1.0, 10.0)
+ self._checkLimits(expectedXLim=(1.0, 10.0), expectedRatio=defaultRatio)
+ self.qapp.processEvents()
+ self._checkLimits(expectedXLim=(1.0, 10.0), expectedRatio=defaultRatio)
+
+ self.plot.getYAxis().setLimits(1.0, 10.0)
+ self._checkLimits(expectedYLim=(1.0, 10.0), expectedRatio=defaultRatio)
+ self.qapp.processEvents()
+ self._checkLimits(expectedYLim=(1.0, 10.0), expectedRatio=defaultRatio)
+
+ def testResizeWidget(self):
+ """Test resizing the widget and receiving limitsChanged events"""
+ self.plot.resize(200, 200)
+ self.qapp.processEvents()
+ self.qWait(100)
+
+ xlim = self.plot.getXAxis().getLimits()
+ ylim = self.plot.getYAxis().getLimits()
+
+ listener = SignalListener()
+ self.plot.getXAxis().sigLimitsChanged.connect(listener.partial("x"))
+ self.plot.getYAxis().sigLimitsChanged.connect(listener.partial("y"))
+
+ # Resize without aspect ratio
+ self.plot.resize(200, 300)
+ self.qapp.processEvents()
+ self.qWait(100)
+ self._checkLimits(expectedXLim=xlim, expectedYLim=ylim)
+ self.assertEqual(listener.callCount(), 0)
+
+ # Resize with aspect ratio
+ self.plot.setKeepDataAspectRatio(True)
+ self.qapp.processEvents()
+ self.qWait(1000)
+ listener.clear() # Clean-up received signal
+
+ self.plot.resize(200, 200)
+ self.qapp.processEvents()
+ self.qWait(100)
+ self.assertNotEqual(listener.callCount(), 0)
+
+ def testAddRemoveItemSignals(self):
+ """Test sigItemAdded and sigItemAboutToBeRemoved"""
+ listener = SignalListener()
+ self.plot.sigItemAdded.connect(listener.partial("add"))
+ self.plot.sigItemAboutToBeRemoved.connect(listener.partial("remove"))
+
+ self.plot.addCurve((1, 2, 3), (3, 2, 1), legend="curve")
+ self.assertEqual(listener.callCount(), 1)
+
+ curve = self.plot.getCurve("curve")
+ self.plot.remove("curve")
+ self.assertEqual(listener.callCount(), 2)
+ self.assertEqual(listener.arguments(callIndex=0), ("add", curve))
+ self.assertEqual(listener.arguments(callIndex=1), ("remove", curve))
+
+ def testGetItems(self):
+ """Test getItems method"""
+ curve_x = 1, 2
+ self.plot.addCurve(curve_x, (3, 4))
+ image = (0, 1), (2, 3)
+ self.plot.addImage(image)
+ scatter_x = 10, 11
+ self.plot.addScatter(scatter_x, (12, 13), (0, 1))
+ marker_pos = 5, 5
+ self.plot.addMarker(*marker_pos)
+ marker_x = 6
+ self.plot.addXMarker(marker_x)
+ self.plot.addShape((0, 5), (2, 10), shape="rectangle")
+
+ items = self.plot.getItems()
+ self.assertEqual(len(items), 6)
+ self.assertTrue(numpy.all(numpy.equal(items[0].getXData(), curve_x)))
+ self.assertTrue(numpy.all(numpy.equal(items[1].getData(), image)))
+ self.assertTrue(numpy.all(numpy.equal(items[2].getXData(), scatter_x)))
+ self.assertTrue(numpy.all(numpy.equal(items[3].getPosition(), marker_pos)))
+ self.assertTrue(numpy.all(numpy.equal(items[4].getPosition()[0], marker_x)))
+ self.assertEqual(items[5].getType(), "rectangle")
+
+ def testRemoveDiscardItem(self):
+ """Test removeItem and discardItem"""
+ self.plot.addCurve((1, 2, 3), (1, 2, 3))
+ curve = self.plot.getItems()[0]
+ self.plot.removeItem(curve)
+ with self.assertRaises(ValueError):
+ self.plot.removeItem(curve)
+
+ self.plot.addCurve((1, 2, 3), (1, 2, 3))
+ curve = self.plot.getItems()[0]
+ result = self.plot.discardItem(curve)
+ self.assertTrue(result)
+ result = self.plot.discardItem(curve)
+ self.assertFalse(result)
+
+ def testBackGroundColors(self):
+ self.plot.setVisible(True)
+ self.qWaitForWindowExposed(self.plot)
+ self.qapp.processEvents()
+
+ # Custom the full background
+ color = self.plot.getBackgroundColor()
+ self.assertTrue(color.isValid())
+ self.assertEqual(color, qt.QColor(255, 255, 255))
+ self.plot.setBackgroundColor("red")
+ color = self.plot.getBackgroundColor()
+ self.assertTrue(color.isValid())
+ self.qapp.processEvents()
+
+ # Custom the data background
+ color = self.plot.getDataBackgroundColor()
+ self.assertFalse(color.isValid())
+ self.plot.setDataBackgroundColor("red")
+ color = self.plot.getDataBackgroundColor()
+ self.assertTrue(color.isValid())
+ self.qapp.processEvents()
+
+ # Back to default
+ self.plot.setBackgroundColor("white")
+ self.plot.setDataBackgroundColor(None)
+ color = self.plot.getBackgroundColor()
+ self.assertTrue(color.isValid())
+ self.assertEqual(color, qt.QColor(255, 255, 255))
+ color = self.plot.getDataBackgroundColor()
+ self.assertFalse(color.isValid())
+ self.qapp.processEvents()
+
+
+class TestPlotImage(PlotWidgetTestCase, ParametricTestCase):
+ """Basic tests for addImage"""
+
+ def setUp(self):
+ super(TestPlotImage, self).setUp()
+
+ self.plot.getYAxis().setLabel("Rows")
+ self.plot.getXAxis().setLabel("Columns")
+
+ def testPlotColormapTemperature(self):
+ self.plot.setGraphTitle("Temp. Linear")
+
+ colormap = Colormap(
+ name="temperature", normalization="linear", vmin=None, vmax=None
+ )
+ self.plot.addImage(DATA_2D, legend="image 1", colormap=colormap)
+
+ def testPlotColormapGray(self):
+ self.plot.setKeepDataAspectRatio(False)
+ self.plot.setGraphTitle("Gray Linear")
+
+ colormap = Colormap(name="gray", normalization="linear", vmin=None, vmax=None)
+ self.plot.addImage(DATA_2D, legend="image 1", colormap=colormap)
+
+ def testPlotColormapTemperatureLog(self):
+ self.plot.setGraphTitle("Temp. Log")
+
+ colormap = Colormap(
+ name="temperature", normalization=Colormap.LOGARITHM, vmin=None, vmax=None
+ )
+ self.plot.addImage(DATA_2D, legend="image 1", colormap=colormap)
+
+ def testPlotRgbRgba(self):
+ self.plot.setKeepDataAspectRatio(False)
+ self.plot.setGraphTitle("RGB + RGBA")
+
+ rgb = numpy.array(
+ (
+ ((0, 0, 0), (128, 0, 0), (255, 0, 0)),
+ ((0, 128, 0), (0, 128, 128), (0, 128, 255)),
+ ),
+ dtype=numpy.uint8,
+ )
+
+ self.plot.addImage(
+ rgb, legend="rgb_uint8", origin=(0, 0), scale=(1, 1), resetzoom=False
+ )
+
+ rgb = numpy.array(
+ (
+ ((0, 0, 0), (32768, 0, 0), (65535, 0, 0)),
+ ((0, 32768, 0), (0, 32768, 32768), (0, 32768, 65535)),
+ ),
+ dtype=numpy.uint16,
+ )
+
+ self.plot.addImage(
+ rgb, legend="rgb_uint16", origin=(3, 2), scale=(2, 2), resetzoom=False
+ )
+
+ rgba = numpy.array(
+ (
+ ((0, 0, 0, 0.5), (0.5, 0, 0, 1), (1, 0, 0, 0.5)),
+ ((0, 0.5, 0, 1), (0, 0.5, 0.5, 1), (0, 1, 1, 0.5)),
+ ),
+ dtype=numpy.float32,
+ )
+
+ self.plot.addImage(
+ rgba, legend="rgba_float32", origin=(9, 6), scale=(1, 1), resetzoom=False
+ )
+
+ self.plot.resetZoom()
+
+ def testPlotColormapCustom(self):
+ self.plot.setKeepDataAspectRatio(False)
+ self.plot.setGraphTitle("Custom colormap")
+
+ colormap = Colormap(
+ name=None,
+ normalization=Colormap.LINEAR,
+ vmin=None,
+ vmax=None,
+ colors=((0.0, 0.0, 0.0), (1.0, 0.0, 0.0), (0.0, 1.0, 0.0), (0.0, 0.0, 1.0)),
+ )
+ self.plot.addImage(
+ DATA_2D, legend="image 1", colormap=colormap, resetzoom=False
+ )
+
+ colormap = Colormap(
+ name=None,
+ normalization=Colormap.LINEAR,
+ vmin=None,
+ vmax=None,
+ colors=numpy.array(
+ (
+ (0, 0, 0, 0),
+ (0, 0, 0, 128),
+ (128, 128, 128, 128),
+ (255, 255, 255, 255),
+ ),
+ dtype=numpy.uint8,
+ ),
+ )
+ self.plot.addImage(
+ DATA_2D,
+ legend="image 2",
+ colormap=colormap,
+ origin=(DATA_2D.shape[0], 0),
+ resetzoom=False,
+ )
+ self.plot.resetZoom()
+
+ def testPlotColormapNaNColor(self):
+ self.plot.setKeepDataAspectRatio(False)
+ self.plot.setGraphTitle("Colormap with NaN color")
+
+ colormap = Colormap()
+ colormap.setNaNColor("red")
+ self.assertEqual(colormap.getNaNColor(), qt.QColor(255, 0, 0))
+ data = DATA_2D.astype(numpy.float32)
+ data[len(data) // 2 :] = numpy.nan
+ self.plot.addImage(data, legend="image 1", colormap=colormap, resetzoom=False)
+ self.plot.resetZoom()
+
+ colormap.setNaNColor((0.0, 1.0, 0.0, 1.0))
+ self.assertEqual(colormap.getNaNColor(), qt.QColor(0, 255, 0))
+ self.qapp.processEvents()
+
+ def testImageOriginScale(self):
+ """Test of image with different origin and scale"""
+ self.plot.setGraphTitle("origin and scale")
+
+ tests = [ # (origin, scale)
+ ((10, 20), (1, 1)),
+ ((10, 20), (-1, -1)),
+ ((-10, 20), (2, 1)),
+ ((10, -20), (-1, -2)),
+ (100, 2),
+ (-100, (1, 1)),
+ ((10, 20), 2),
+ ]
+
+ for origin, scale in tests:
+ with self.subTest(origin=origin, scale=scale):
+ self.plot.addImage(DATA_2D, origin=origin, scale=scale)
+
+ try:
+ ox, oy = origin
+ except TypeError:
+ ox, oy = origin, origin
+ try:
+ sx, sy = scale
+ except TypeError:
+ sx, sy = scale, scale
+ xbounds = ox, ox + DATA_2D.shape[1] * sx
+ ybounds = oy, oy + DATA_2D.shape[0] * sy
+
+ # Check limits without aspect ratio
+ xmin, xmax = self.plot.getXAxis().getLimits()
+ ymin, ymax = self.plot.getYAxis().getLimits()
+ self.assertEqual(xmin, min(xbounds))
+ self.assertEqual(xmax, max(xbounds))
+ self.assertEqual(ymin, min(ybounds))
+ self.assertEqual(ymax, max(ybounds))
+
+ # Check limits with aspect ratio
+ self.plot.setKeepDataAspectRatio(True)
+ xmin, xmax = self.plot.getXAxis().getLimits()
+ ymin, ymax = self.plot.getYAxis().getLimits()
+ self.assertTrue(round(xmin, 7) <= min(xbounds))
+ self.assertTrue(round(xmax, 7) >= max(xbounds))
+ self.assertTrue(round(ymin, 7) <= min(ybounds))
+ self.assertTrue(round(ymax, 7) >= max(ybounds))
+
+ self.plot.setKeepDataAspectRatio(False) # Reset aspect ratio
+ self.plot.clear()
+ self.plot.resetZoom()
+
+ def testPlotColormapDictAPI(self):
+ """Test that the addImage API using a colormap dictionary is still
+ working"""
+ self.plot.setGraphTitle("Temp. Log")
+
+ colormap = {
+ "name": "temperature",
+ "normalization": "log",
+ "vmin": None,
+ "vmax": None,
+ }
+ self.plot.addImage(DATA_2D, legend="image 1", colormap=colormap)
+
+ def testPlotComplexImage(self):
+ """Test that a complex image is displayed as its absolute value."""
+ data = numpy.linspace(1, 1j, 100).reshape(10, 10)
+ self.plot.addImage(data, legend="complex")
+
+ image = self.plot.getActiveImage()
+ retrievedData = image.getData(copy=False)
+ self.assertTrue(numpy.all(numpy.equal(retrievedData, numpy.absolute(data))))
+
+ def testPlotBooleanImage(self):
+ """Test that a boolean image is displayed and converted to int8."""
+ data = numpy.zeros((10, 10), dtype=bool)
+ data[::2, ::2] = True
+ self.plot.addImage(data, legend="boolean")
+
+ image = self.plot.getActiveImage()
+ retrievedData = image.getData(copy=False)
+ self.assertTrue(numpy.all(numpy.equal(retrievedData, data)))
+ self.assertIs(retrievedData.dtype.type, numpy.int8)
+
+ def testPlotAlphaImage(self):
+ """Test with an alpha image layer"""
+ data = numpy.random.random((10, 10))
+ alpha = numpy.linspace(0, 1, 100).reshape(10, 10)
+ self.plot.addImage(data, legend="image")
+ image = self.plot.getActiveImage()
+ image.setData(data, alpha=alpha)
+ self.qapp.processEvents()
+ self.assertTrue(numpy.array_equal(alpha, image.getAlphaData()))
+
+
+class TestPlotCurve(PlotWidgetTestCase):
+ """Basic tests for addCurve."""
+
+ # Test data sets
+ xData = numpy.arange(1000)
+ yData = -500 + 100 * numpy.sin(xData)
+ xData2 = xData + 1000
+ yData2 = xData - 1000 + 200 * numpy.random.random(1000)
+
+ def setUp(self):
+ super(TestPlotCurve, self).setUp()
+ self.plot.setGraphTitle("Curve")
+ self.plot.getYAxis().setLabel("Rows")
+ self.plot.getXAxis().setLabel("Columns")
+
+ self.plot.setActiveCurveHandling(False)
+
+ def testPlotCurveInfinite(self):
+ """Test plot curves with not finite data"""
+ tests = {
+ "y all not finite": ([0, 1, 2], [numpy.inf, numpy.nan, -numpy.inf]),
+ "x all not finite": ([numpy.inf, numpy.nan, -numpy.inf], [0, 1, 2]),
+ "x some inf": ([0, numpy.inf, 2], [0, 1, 2]),
+ "y some inf": ([0, 1, 2], [0, numpy.inf, 2]),
+ }
+ for name, args in tests.items():
+ with self.subTest(name):
+ self.plot.addCurve(*args)
+ self.plot.resetZoom()
+ self.qapp.processEvents()
+ self.plot.clear()
+
+ def testPlotCurveColorFloat(self):
+ color = numpy.array(numpy.random.random(3 * 1000), dtype=numpy.float32).reshape(
+ 1000, 3
+ )
+
+ self.plot.addCurve(
+ self.xData,
+ self.yData,
+ legend="curve 1",
+ replace=False,
+ resetzoom=False,
+ color=color,
+ linestyle="",
+ symbol="s",
+ )
+ self.plot.addCurve(
+ self.xData2,
+ self.yData2,
+ legend="curve 2",
+ replace=False,
+ resetzoom=False,
+ color="green",
+ linestyle="-",
+ symbol="o",
+ )
+ self.plot.resetZoom()
+
+ def testPlotCurveColorByte(self):
+ color = numpy.array(
+ 255 * numpy.random.random(3 * 1000), dtype=numpy.uint8
+ ).reshape(1000, 3)
+
+ self.plot.addCurve(
+ self.xData,
+ self.yData,
+ legend="curve 1",
+ replace=False,
+ resetzoom=False,
+ color=color,
+ linestyle="",
+ symbol="s",
+ )
+ self.plot.addCurve(
+ self.xData2,
+ self.yData2,
+ legend="curve 2",
+ replace=False,
+ resetzoom=False,
+ color="green",
+ linestyle="-",
+ symbol="o",
+ )
+ self.plot.resetZoom()
+
+ def testPlotCurveColors(self):
+ color = numpy.array(numpy.random.random(3 * 1000), dtype=numpy.float32).reshape(
+ 1000, 3
+ )
+
+ self.plot.addCurve(
+ self.xData,
+ self.yData,
+ legend="curve 2",
+ replace=False,
+ resetzoom=False,
+ color=color,
+ linestyle="-",
+ symbol="o",
+ )
+ self.plot.resetZoom()
+
+ # Test updating color array
+
+ # From array to array
+ newColors = numpy.ones((len(self.xData), 3), dtype=numpy.float32)
+ self.plot.addCurve(
+ self.xData,
+ self.yData,
+ legend="curve 2",
+ replace=False,
+ resetzoom=False,
+ color=newColors,
+ symbol="o",
+ )
+
+ # Array to single color
+ self.plot.addCurve(
+ self.xData,
+ self.yData,
+ legend="curve 2",
+ replace=False,
+ resetzoom=False,
+ color="green",
+ symbol="o",
+ )
+
+ # single color to array
+ self.plot.addCurve(
+ self.xData,
+ self.yData,
+ legend="curve 2",
+ replace=False,
+ resetzoom=False,
+ color=color,
+ symbol="o",
+ )
+
+ def testPlotBaselineNumpyArray(self):
+ """simple test of the API with baseline as a numpy array"""
+ x = numpy.arange(0, 10, step=0.1)
+ my_sin = numpy.sin(x)
+ y = numpy.arange(-4, 6, step=0.1) + my_sin
+ baseline = y - 1.0
+
+ self.plot.addCurve(
+ x=x, y=y, color="grey", legend="curve1", fill=True, baseline=baseline
+ )
+
+ def testPlotBaselineScalar(self):
+ """simple test of the API with baseline as an int"""
+ x = numpy.arange(0, 10, step=0.1)
+ my_sin = numpy.sin(x)
+ y = numpy.arange(-4, 6, step=0.1) + my_sin
+
+ self.plot.addCurve(
+ x=x, y=y, color="grey", legend="curve1", fill=True, baseline=0
+ )
+
+ def testPlotBaselineList(self):
+ """simple test of the API with baseline as an int"""
+ x = numpy.arange(0, 10, step=0.1)
+ my_sin = numpy.sin(x)
+ y = numpy.arange(-4, 6, step=0.1) + my_sin
+
+ self.plot.addCurve(
+ x=x,
+ y=y,
+ color="grey",
+ legend="curve1",
+ fill=True,
+ baseline=list(range(0, 100, 1)),
+ )
+
+ def testPlotCurveComplexData(self):
+ """Test curve with complex data"""
+ data = numpy.arange(100.0) + 1j
+ self.plot.addCurve(x=data, y=data, xerror=data, yerror=data)
+
+ def testPlotCurveGapColor(self):
+ """Test dashed curve with gap color"""
+ data = numpy.arange(100)
+ self.plot.addCurve(
+ x=data, y=data, legend="curve1", linestyle="--", color="blue"
+ )
+ curve = self.plot.getCurve("curve1")
+ assert curve.getLineGapColor() is None
+ curve.setLineGapColor("red")
+ assert curve.getLineGapColor() == (1.0, 0.0, 0.0, 1.0)
+
+
+class TestPlotHistogram(PlotWidgetTestCase):
+ """Basic tests for add Histogram"""
+
+ def setUp(self):
+ super(TestPlotHistogram, self).setUp()
+ self.edges = numpy.arange(0, 10, step=1)
+ self.histogram = numpy.random.random(len(self.edges))
+
+ def testPlot(self):
+ self.plot.addHistogram(
+ histogram=self.histogram, edges=self.edges, legend="histogram1"
+ )
+
+ def testPlotBaseline(self):
+ self.plot.addHistogram(
+ histogram=self.histogram,
+ edges=self.edges,
+ legend="histogram1",
+ color="blue",
+ baseline=-2,
+ z=2,
+ fill=True,
+ )
+
+ def testPlotGapColor(self):
+ """Test dashed histogram with gap color"""
+ data = numpy.arange(100)
+ self.plot.addHistogram(
+ histogram=self.histogram,
+ edges=self.edges,
+ legend="histogram1",
+ color="blue",
+ )
+ histogram = self.plot.getItems()[0]
+ assert histogram.getLineGapColor() is None
+ histogram.setLineGapColor("red")
+ assert histogram.getLineGapColor() == (1.0, 0.0, 0.0, 1.0)
+ histogram.setLineStyle(":")
+
+
+class TestPlotScatter(PlotWidgetTestCase, ParametricTestCase):
+ """Basic tests for addScatter"""
+
+ def testScatter(self):
+ x = numpy.arange(100)
+ y = numpy.arange(100)
+ value = numpy.arange(100)
+ self.plot.addScatter(x, y, value)
+ self.plot.resetZoom()
+
+ def testScatterComplexData(self):
+ """Test scatter item with complex data"""
+ data = numpy.arange(100.0) + 1j
+ self.plot.addScatter(x=data, y=data, value=data, xerror=data, yerror=data)
+ self.plot.resetZoom()
+
+ def testScatterVisualization(self):
+ self.plot.addScatter((0, 1, 0, 1), (0, 0, 2, 2), (0, 1, 2, 3))
+ self.plot.resetZoom()
+ self.qapp.processEvents()
+
+ scatter = self.plot.getItems()[0]
+
+ for visualization in (
+ "solid",
+ "points",
+ "regular_grid",
+ "irregular_grid",
+ "binned_statistic",
+ scatter.Visualization.SOLID,
+ scatter.Visualization.POINTS,
+ scatter.Visualization.REGULAR_GRID,
+ scatter.Visualization.IRREGULAR_GRID,
+ scatter.Visualization.BINNED_STATISTIC,
+ ):
+ with self.subTest(visualization=visualization):
+ scatter.setVisualization(visualization)
+ self.qapp.processEvents()
+
+ def testGridVisualization(self):
+ """Test regular and irregular grid mode with different points"""
+ points = { # name: (x, y, order)
+ "single point": ((1.0,), (1.0,), "row"),
+ "horizontal line": ((0, 1, 2), (0, 0, 0), "row"),
+ "horizontal line backward": ((2, 1, 0), (0, 0, 0), "row"),
+ "vertical line": ((0, 0, 0), (0, 1, 2), "row"),
+ "vertical line backward": ((0, 0, 0), (2, 1, 0), "row"),
+ "grid fast x, +x +y": ((0, 1, 2, 0, 1, 2), (0, 0, 0, 1, 1, 1), "row"),
+ "grid fast x, +x -y": ((0, 1, 2, 0, 1, 2), (1, 1, 1, 0, 0, 0), "row"),
+ "grid fast x, -x -y": ((2, 1, 0, 2, 1, 0), (1, 1, 1, 0, 0, 0), "row"),
+ "grid fast x, -x +y": ((2, 1, 0, 2, 1, 0), (0, 0, 0, 1, 1, 1), "row"),
+ "grid fast y, +x +y": ((0, 0, 0, 1, 1, 1), (0, 1, 2, 0, 1, 2), "column"),
+ "grid fast y, +x -y": ((0, 0, 0, 1, 1, 1), (2, 1, 0, 2, 1, 0), "column"),
+ "grid fast y, -x -y": ((1, 1, 1, 0, 0, 0), (2, 1, 0, 2, 1, 0), "column"),
+ "grid fast y, -x +y": ((1, 1, 1, 0, 0, 0), (0, 1, 2, 0, 1, 2), "column"),
+ }
+
+ self.plot.addScatter((), (), ())
+ scatter = self.plot.getItems()[0]
+
+ self.qapp.processEvents()
+
+ for visualization in (
+ scatter.Visualization.REGULAR_GRID,
+ scatter.Visualization.IRREGULAR_GRID,
+ ):
+ scatter.setVisualization(visualization)
+ self.assertIs(scatter.getVisualization(), visualization)
+
+ for name, (x, y, ref_order) in points.items():
+ with self.subTest(name=name, visualization=visualization.name):
+ scatter.setData(x, y, numpy.arange(len(x)))
+ self.plot.setGraphTitle(name)
+ self.plot.resetZoom()
+ self.qapp.processEvents()
+
+ order = scatter.getCurrentVisualizationParameter(
+ scatter.VisualizationParameter.GRID_MAJOR_ORDER
+ )
+ self.assertEqual(ref_order, order)
+
+ ref_bounds = (x[0], y[0]), (x[-1], y[-1])
+ bounds = scatter.getCurrentVisualizationParameter(
+ scatter.VisualizationParameter.GRID_BOUNDS
+ )
+ self.assertEqual(ref_bounds, bounds)
+
+ shape = scatter.getCurrentVisualizationParameter(
+ scatter.VisualizationParameter.GRID_SHAPE
+ )
+
+ self.plot.getXAxis().setLimits(numpy.min(x) - 1, numpy.max(x) + 1)
+ self.plot.getYAxis().setLimits(numpy.min(y) - 1, numpy.max(y) + 1)
+ self.qapp.processEvents()
+
+ for index, position in enumerate(zip(x, y)):
+ xpixel, ypixel = self.plot.dataToPixel(*position)
+ result = scatter.pick(xpixel, ypixel)
+ self.assertIsNotNone(result)
+ self.assertIs(result.getItem(), scatter)
+ self.assertEqual(result.getIndices(), (index,))
+
+ def testBinnedStatisticVisualization(self):
+ """Test binned display"""
+ self.plot.addScatter((), (), ())
+ scatter = self.plot.getItems()[0]
+ scatter.setVisualization(scatter.Visualization.BINNED_STATISTIC)
+ self.assertIs(
+ scatter.getVisualization(), scatter.Visualization.BINNED_STATISTIC
+ )
+ self.assertEqual(
+ scatter.getVisualizationParameter(
+ scatter.VisualizationParameter.BINNED_STATISTIC_FUNCTION
+ ),
+ "mean",
+ )
+
+ self.qapp.processEvents()
+
+ scatter.setData(*numpy.random.random(300).reshape(3, -1))
+ self.qapp.processEvents()
+
+ # Update data
+ scatter.setData(*numpy.random.random(3000).reshape(3, -1))
+ self.qapp.processEvents()
+
+ for reduction in ("count", "sum", "mean"):
+ with self.subTest(reduction=reduction):
+ scatter.setVisualizationParameter(
+ scatter.VisualizationParameter.BINNED_STATISTIC_FUNCTION, reduction
+ )
+ self.assertEqual(
+ scatter.getVisualizationParameter(
+ scatter.VisualizationParameter.BINNED_STATISTIC_FUNCTION
+ ),
+ reduction,
+ )
+
+ self.qapp.processEvents()
+
+
+class TestPlotMarker(PlotWidgetTestCase):
+ """Basic tests for add*Marker"""
+
+ def setUp(self):
+ super(TestPlotMarker, self).setUp()
+ self.plot.getYAxis().setLabel("Rows")
+ self.plot.getXAxis().setLabel("Columns")
+
+ self.plot.getXAxis().setAutoScale(False)
+ self.plot.getYAxis().setAutoScale(False)
+ self.plot.setKeepDataAspectRatio(False)
+ self.plot.setLimits(0.0, 100.0, -100.0, 100.0)
+
+ def testPlotMarkerX(self):
+ self.plot.setGraphTitle("Markers X")
+
+ markers = [
+ (10.0, "blue", False, False),
+ (20.0, "red", False, False),
+ (40.0, "green", True, False),
+ (60.0, "gray", True, True),
+ (80.0, "black", False, True),
+ ]
+
+ for x, color, select, drag in markers:
+ name = str(x)
+ if select:
+ name += " sel."
+ if drag:
+ name += " drag"
+ self.plot.addXMarker(x, name, name, color, select, drag)
+ self.plot.resetZoom()
+
+ def testPlotMarkerY(self):
+ self.plot.setGraphTitle("Markers Y")
+
+ markers = [
+ (-50.0, "blue", False, False),
+ (-30.0, "red", False, False),
+ (0.0, "green", True, False),
+ (10.0, "gray", True, True),
+ (80.0, "black", False, True),
+ ]
+
+ for y, color, select, drag in markers:
+ name = str(y)
+ if select:
+ name += " sel."
+ if drag:
+ name += " drag"
+ self.plot.addYMarker(y, name, name, color, select, drag)
+ self.plot.resetZoom()
+
+ def testPlotMarkerPt(self):
+ self.plot.setGraphTitle("Markers Pt")
+
+ markers = [
+ (10.0, -50.0, "blue", False, False),
+ (40.0, -30.0, "red", False, False),
+ (50.0, 0.0, "green", True, False),
+ (50.0, 20.0, "gray", True, True),
+ (70.0, 50.0, "black", False, True),
+ ]
+ for x, y, color, select, drag in markers:
+ name = "{0},{1}".format(x, y)
+ if select:
+ name += " sel."
+ if drag:
+ name += " drag"
+ self.plot.addMarker(x, y, name, name, color, select, drag)
+
+ self.plot.resetZoom()
+
+ def testPlotMarkerWithoutLegend(self):
+ self.plot.setGraphTitle("Markers without legend")
+ self.plot.getYAxis().setInverted(True)
+
+ # Markers without legend
+ self.plot.addMarker(10, 10)
+ self.plot.addMarker(10, 20)
+ self.plot.addMarker(40, 50, text="test", symbol=None)
+ self.plot.addMarker(40, 50, text="test", symbol="+")
+ self.plot.addXMarker(25)
+ self.plot.addXMarker(35)
+ self.plot.addXMarker(45, text="test")
+ self.plot.addYMarker(55)
+ self.plot.addYMarker(65)
+ self.plot.addYMarker(75, text="test")
+
+ self.plot.resetZoom()
+
+ def testPlotMarkerYAxis(self):
+ # Check only the API
+
+ item = self.plot.addMarker(10, 10)
+ self.assertEqual(item.getYAxis(), "left")
+
+ item = self.plot.addMarker(10, 10, yaxis="right")
+ self.assertEqual(item.getYAxis(), "right")
+
+ item = self.plot.addMarker(10, 10, yaxis="left")
+ self.assertEqual(item.getYAxis(), "left")
+
+ item = self.plot.addXMarker(10, yaxis="right")
+ self.assertEqual(item.getYAxis(), "right")
+
+ item = self.plot.addXMarker(10, yaxis="left")
+ self.assertEqual(item.getYAxis(), "left")
+
+ item = self.plot.addYMarker(10, yaxis="right")
+ self.assertEqual(item.getYAxis(), "right")
+
+ item = self.plot.addYMarker(10, yaxis="left")
+ self.assertEqual(item.getYAxis(), "left")
+
+ self.plot.resetZoom()
+
+
+# TestPlotItem ################################################################
+
+
+class TestPlotItem(PlotWidgetTestCase):
+ """Basic tests for addItem."""
+
+ # Polygon coordinates and color
+ POLYGONS = [ # legend, x coords, y coords, color
+ ("triangle", numpy.array((10, 30, 50)), numpy.array((55, 70, 55)), "red"),
+ (
+ "square",
+ numpy.array((10, 10, 50, 50)),
+ numpy.array((10, 50, 50, 10)),
+ "green",
+ ),
+ (
+ "star",
+ numpy.array((60, 70, 80, 60, 80)),
+ numpy.array((25, 50, 25, 40, 40)),
+ "blue",
+ ),
+ (
+ "2 triangles-simple",
+ numpy.array((90.0, 95.0, 100.0, numpy.nan, 90.0, 95.0, 100.0)),
+ numpy.array((25.0, 5.0, 25.0, numpy.nan, 30.0, 50.0, 30.0)),
+ "pink",
+ ),
+ (
+ "2 triangles-extra NaN",
+ numpy.array(
+ (
+ numpy.nan,
+ 90.0,
+ 95.0,
+ 100.0,
+ numpy.nan,
+ 0.0,
+ 90.0,
+ 95.0,
+ 100.0,
+ numpy.nan,
+ )
+ ),
+ numpy.array(
+ (
+ 0.0,
+ 55.0,
+ 70.0,
+ 55.0,
+ numpy.nan,
+ numpy.nan,
+ 75.0,
+ 90.0,
+ 75.0,
+ numpy.nan,
+ )
+ ),
+ "black",
+ ),
+ ]
+
+ # Rectangle coordinantes and color
+ RECTANGLES = [ # legend, x coords, y coords, color
+ ("square 1", numpy.array((1.0, 10.0)), numpy.array((1.0, 10.0)), "red"),
+ ("square 2", numpy.array((10.0, 20.0)), numpy.array((10.0, 20.0)), "green"),
+ ("square 3", numpy.array((20.0, 30.0)), numpy.array((20.0, 30.0)), "blue"),
+ ("rect 1", numpy.array((1.0, 30.0)), numpy.array((35.0, 40.0)), "black"),
+ ("line h", numpy.array((1.0, 30.0)), numpy.array((45.0, 45.0)), "darkRed"),
+ ]
+
+ SCALES = Axis.LINEAR, Axis.LOGARITHMIC
+
+ def setUp(self):
+ super(TestPlotItem, self).setUp()
+
+ self.plot.getYAxis().setLabel("Rows")
+ self.plot.getXAxis().setLabel("Columns")
+ self.plot.getXAxis().setAutoScale(False)
+ self.plot.getYAxis().setAutoScale(False)
+ self.plot.setKeepDataAspectRatio(False)
+ self.plot.setLimits(0.0, 100.0, -100.0, 100.0)
+
+ def testPlotItemPolygonFill(self):
+ for scale in self.SCALES:
+ with self.subTest(scale=scale):
+ self.plot.clear()
+ self.plot.getXAxis().setScale(scale)
+ self.plot.getYAxis().setScale(scale)
+ self.plot.setGraphTitle("Item Fill %s" % scale)
+
+ for legend, xList, yList, color in self.POLYGONS:
+ self.plot.addShape(
+ xList,
+ yList,
+ legend=legend,
+ replace=False,
+ linestyle="--",
+ shape="polygon",
+ fill=True,
+ color=color,
+ )
+ self.plot.resetZoom()
+
+ def testPlotItemPolygonNoFill(self):
+ for scale in self.SCALES:
+ with self.subTest(scale=scale):
+ self.plot.clear()
+ self.plot.getXAxis().setScale(scale)
+ self.plot.getYAxis().setScale(scale)
+ self.plot.setGraphTitle("Item No Fill %s" % scale)
+
+ for legend, xList, yList, color in self.POLYGONS:
+ self.plot.addShape(
+ xList,
+ yList,
+ legend=legend,
+ replace=False,
+ linestyle="--",
+ shape="polygon",
+ fill=False,
+ color=color,
+ )
+ self.plot.resetZoom()
+
+ def testPlotItemRectangleFill(self):
+ for scale in self.SCALES:
+ with self.subTest(scale=scale):
+ self.plot.clear()
+ self.plot.getXAxis().setScale(scale)
+ self.plot.getYAxis().setScale(scale)
+ self.plot.setGraphTitle("Rectangle Fill %s" % scale)
+
+ for legend, xList, yList, color in self.RECTANGLES:
+ self.plot.addShape(
+ xList,
+ yList,
+ legend=legend,
+ replace=False,
+ shape="rectangle",
+ fill=True,
+ color=color,
+ )
+ self.plot.resetZoom()
+
+ def testPlotItemRectangleNoFill(self):
+ for scale in self.SCALES:
+ with self.subTest(scale=scale):
+ self.plot.clear()
+ self.plot.getXAxis().setScale(scale)
+ self.plot.getYAxis().setScale(scale)
+ self.plot.setGraphTitle("Rectangle No Fill %s" % scale)
+
+ for legend, xList, yList, color in self.RECTANGLES:
+ self.plot.addShape(
+ xList,
+ yList,
+ legend=legend,
+ replace=False,
+ shape="rectangle",
+ fill=False,
+ color=color,
+ )
+ self.plot.resetZoom()
+
+
+##############################################################################
+# Log
+##############################################################################
+
+
+class TestPlotEmptyLog(PlotWidgetTestCase):
+ """Basic tests for log plot"""
+
+ def testEmptyPlotTitleLabelsLog(self):
+ self.plot.setGraphTitle("Empty Log Log")
+ self.plot.getXAxis().setLabel("X")
+ self.plot.getYAxis().setLabel("Y")
+ self.plot.getXAxis()._setLogarithmic(True)
+ self.plot.getYAxis()._setLogarithmic(True)
+ self.plot.resetZoom()
+
+
+class TestPlotAxes(TestCaseQt, ParametricTestCase):
+ # Test data
+ xData = numpy.arange(1, 10)
+ yData = xData**2
+
+ def __init__(self, methodName="runTest", backend=None):
+ unittest.TestCase.__init__(self, methodName)
+ self.__backend = backend
+
+ def setUp(self):
+ super(TestPlotAxes, self).setUp()
+ self.plot = PlotWidget(backend=self.__backend)
+ # It is not needed to display the plot
+ # It saves a lot of time
+ # self.plot.show()
+ # self.qWaitForWindowExposed(self.plot)
+
+ def tearDown(self):
+ self.qapp.processEvents()
+ self.plot.setAttribute(qt.Qt.WA_DeleteOnClose)
+ self.plot.close()
+ del self.plot
+ super(TestPlotAxes, self).tearDown()
+
+ def testDefaultAxes(self):
+ axis = self.plot.getXAxis()
+ self.assertEqual(axis.getScale(), axis.LINEAR)
+ axis = self.plot.getYAxis()
+ self.assertEqual(axis.getScale(), axis.LINEAR)
+ axis = self.plot.getYAxis(axis="right")
+ self.assertEqual(axis.getScale(), axis.LINEAR)
+
+ def testOldPlotAxis_getterSetter(self):
+ """Test silx API prior to silx 0.6"""
+ x = self.plot.getXAxis()
+ y = self.plot.getYAxis()
+ p = self.plot
+
+ tests = [
+ # setters
+ (p.setGraphXLimits, (10, 20), x.getLimits, (10, 20)),
+ (p.setGraphYLimits, (10, 20), y.getLimits, (10, 20)),
+ (p.setGraphXLabel, "foox", x.getLabel, "foox"),
+ (p.setGraphYLabel, "fooy", y.getLabel, "fooy"),
+ (p.setYAxisInverted, True, y.isInverted, True),
+ (p.setXAxisLogarithmic, True, x.getScale, x.LOGARITHMIC),
+ (p.setYAxisLogarithmic, True, y.getScale, y.LOGARITHMIC),
+ (p.setXAxisAutoScale, False, x.isAutoScale, False),
+ (p.setYAxisAutoScale, False, y.isAutoScale, False),
+ # getters
+ (x.setLimits, (11, 20), p.getGraphXLimits, (11, 20)),
+ (y.setLimits, (11, 20), p.getGraphYLimits, (11, 20)),
+ (x.setLabel, "fooxx", p.getGraphXLabel, "fooxx"),
+ (y.setLabel, "fooyy", p.getGraphYLabel, "fooyy"),
+ (y.setInverted, False, p.isYAxisInverted, False),
+ (x.setScale, x.LINEAR, p.isXAxisLogarithmic, False),
+ (y.setScale, y.LINEAR, p.isYAxisLogarithmic, False),
+ (x.setAutoScale, True, p.isXAxisAutoScale, True),
+ (y.setAutoScale, True, p.isYAxisAutoScale, True),
+ ]
+ for testCase in tests:
+ setter, value, getter, expected = testCase
+ with self.subTest():
+ if setter is not None:
+ if not isinstance(value, tuple):
+ value = (value,)
+ setter(*value)
+ if getter is not None:
+ self.assertEqual(getter(), expected)
+
+ def testOldPlotAxis_Logarithmic(self):
+ """Test silx API prior to silx 0.6"""
+ x = self.plot.getXAxis()
+ y = self.plot.getYAxis()
+ yright = self.plot.getYAxis(axis="right")
+
+ self.assertEqual(x.getScale(), x.LINEAR)
+ self.assertEqual(y.getScale(), x.LINEAR)
+ self.assertEqual(yright.getScale(), x.LINEAR)
+
+ self.plot.setXAxisLogarithmic(True)
+ self.assertEqual(x.getScale(), x.LOGARITHMIC)
+ self.assertEqual(y.getScale(), x.LINEAR)
+ self.assertEqual(yright.getScale(), x.LINEAR)
+ self.assertEqual(self.plot.isXAxisLogarithmic(), True)
+ self.assertEqual(self.plot.isYAxisLogarithmic(), False)
+
+ self.plot.setYAxisLogarithmic(True)
+ self.assertEqual(x.getScale(), x.LOGARITHMIC)
+ self.assertEqual(y.getScale(), x.LOGARITHMIC)
+ self.assertEqual(yright.getScale(), x.LOGARITHMIC)
+ self.assertEqual(self.plot.isXAxisLogarithmic(), True)
+ self.assertEqual(self.plot.isYAxisLogarithmic(), True)
+
+ yright.setScale(yright.LINEAR)
+ self.assertEqual(x.getScale(), x.LOGARITHMIC)
+ self.assertEqual(y.getScale(), x.LINEAR)
+ self.assertEqual(yright.getScale(), x.LINEAR)
+ self.assertEqual(self.plot.isXAxisLogarithmic(), True)
+ self.assertEqual(self.plot.isYAxisLogarithmic(), False)
+
+ def testOldPlotAxis_AutoScale(self):
+ """Test silx API prior to silx 0.6"""
+ x = self.plot.getXAxis()
+ y = self.plot.getYAxis()
+ yright = self.plot.getYAxis(axis="right")
+
+ self.assertEqual(x.isAutoScale(), True)
+ self.assertEqual(y.isAutoScale(), True)
+ self.assertEqual(yright.isAutoScale(), True)
+
+ self.plot.setXAxisAutoScale(False)
+ self.assertEqual(x.isAutoScale(), False)
+ self.assertEqual(y.isAutoScale(), True)
+ self.assertEqual(yright.isAutoScale(), True)
+ self.assertEqual(self.plot.isXAxisAutoScale(), False)
+ self.assertEqual(self.plot.isYAxisAutoScale(), True)
+
+ self.plot.setYAxisAutoScale(False)
+ self.assertEqual(x.isAutoScale(), False)
+ self.assertEqual(y.isAutoScale(), False)
+ self.assertEqual(yright.isAutoScale(), False)
+ self.assertEqual(self.plot.isXAxisAutoScale(), False)
+ self.assertEqual(self.plot.isYAxisAutoScale(), False)
+
+ yright.setAutoScale(True)
+ self.assertEqual(x.isAutoScale(), False)
+ self.assertEqual(y.isAutoScale(), True)
+ self.assertEqual(yright.isAutoScale(), True)
+ self.assertEqual(self.plot.isXAxisAutoScale(), False)
+ self.assertEqual(self.plot.isYAxisAutoScale(), True)
+
+ def testOldPlotAxis_Inverted(self):
+ """Test silx API prior to silx 0.6"""
+ x = self.plot.getXAxis()
+ y = self.plot.getYAxis()
+ yright = self.plot.getYAxis(axis="right")
+
+ self.assertEqual(x.isInverted(), False)
+ self.assertEqual(y.isInverted(), False)
+ self.assertEqual(yright.isInverted(), False)
+
+ self.plot.setYAxisInverted(True)
+ self.assertEqual(x.isInverted(), False)
+ self.assertEqual(y.isInverted(), True)
+ self.assertEqual(yright.isInverted(), True)
+ self.assertEqual(self.plot.isYAxisInverted(), True)
+
+ yright.setInverted(False)
+ self.assertEqual(x.isInverted(), False)
+ self.assertEqual(y.isInverted(), False)
+ self.assertEqual(yright.isInverted(), False)
+ self.assertEqual(self.plot.isYAxisInverted(), False)
+
+ def testLogXWithData(self):
+ self.plot.setGraphTitle("Curve X: Log Y: Linear")
+ self.plot.addCurve(
+ self.xData,
+ self.yData,
+ legend="curve",
+ replace=False,
+ resetzoom=True,
+ color="green",
+ linestyle="-",
+ symbol="o",
+ )
+ axis = self.plot.getXAxis()
+ axis.setScale(axis.LOGARITHMIC)
+
+ self.assertEqual(axis.getScale(), axis.LOGARITHMIC)
+
+ def testLogYWithData(self):
+ self.plot.setGraphTitle("Curve X: Linear Y: Log")
+ self.plot.addCurve(
+ self.xData,
+ self.yData,
+ legend="curve",
+ replace=False,
+ resetzoom=True,
+ color="green",
+ linestyle="-",
+ symbol="o",
+ )
+ axis = self.plot.getYAxis()
+ axis.setScale(axis.LOGARITHMIC)
+
+ self.assertEqual(axis.getScale(), axis.LOGARITHMIC)
+ axis = self.plot.getYAxis(axis="right")
+ self.assertEqual(axis.getScale(), axis.LOGARITHMIC)
+
+ def testLogYRightWithData(self):
+ self.plot.setGraphTitle("Curve X: Linear Y: Log")
+ self.plot.addCurve(
+ self.xData,
+ self.yData,
+ legend="curve",
+ replace=False,
+ resetzoom=True,
+ color="green",
+ linestyle="-",
+ symbol="o",
+ )
+ axis = self.plot.getYAxis(axis="right")
+ axis.setScale(axis.LOGARITHMIC)
+
+ self.assertEqual(axis.getScale(), axis.LOGARITHMIC)
+ axis = self.plot.getYAxis()
+ self.assertEqual(axis.getScale(), axis.LOGARITHMIC)
+
+ def testLimitsChanged_setLimits(self):
+ self.plot.addCurve(
+ self.xData,
+ self.yData,
+ legend="curve",
+ replace=False,
+ resetzoom=False,
+ color="green",
+ linestyle="-",
+ symbol="o",
+ )
+ listener = SignalListener()
+ self.plot.getXAxis().sigLimitsChanged.connect(listener.partial(axis="x"))
+ self.plot.getYAxis().sigLimitsChanged.connect(listener.partial(axis="y"))
+ self.plot.getYAxis(axis="right").sigLimitsChanged.connect(
+ listener.partial(axis="y2")
+ )
+ self.plot.setLimits(0, 1, 0, 1, 0, 1)
+ # at least one event per axis
+ self.assertEqual(len(set(listener.karguments(argumentName="axis"))), 3)
+
+ def testLimitsChanged_resetZoom(self):
+ self.plot.addCurve(
+ self.xData,
+ self.yData,
+ legend="curve",
+ replace=False,
+ resetzoom=False,
+ color="green",
+ linestyle="-",
+ symbol="o",
+ )
+ listener = SignalListener()
+ self.plot.getXAxis().sigLimitsChanged.connect(listener.partial(axis="x"))
+ self.plot.getYAxis().sigLimitsChanged.connect(listener.partial(axis="y"))
+ self.plot.getYAxis(axis="right").sigLimitsChanged.connect(
+ listener.partial(axis="y2")
+ )
+ self.plot.resetZoom()
+ # at least one event per axis
+ self.assertEqual(len(set(listener.karguments(argumentName="axis"))), 3)
+
+ def testLimitsChanged_setXLimit(self):
+ self.plot.addCurve(
+ self.xData,
+ self.yData,
+ legend="curve",
+ replace=False,
+ resetzoom=False,
+ color="green",
+ linestyle="-",
+ symbol="o",
+ )
+ listener = SignalListener()
+ axis = self.plot.getXAxis()
+ axis.sigLimitsChanged.connect(listener)
+ axis.setLimits(20, 30)
+ # at least one event per axis
+ self.assertEqual(listener.arguments(callIndex=-1), (20.0, 30.0))
+ self.assertEqual(axis.getLimits(), (20.0, 30.0))
+
+ def testLimitsChanged_setYLimit(self):
+ self.plot.addCurve(
+ self.xData,
+ self.yData,
+ legend="curve",
+ replace=False,
+ resetzoom=False,
+ color="green",
+ linestyle="-",
+ symbol="o",
+ )
+ listener = SignalListener()
+ axis = self.plot.getYAxis()
+ axis.sigLimitsChanged.connect(listener)
+ axis.setLimits(20, 30)
+ # at least one event per axis
+ self.assertEqual(listener.arguments(callIndex=-1), (20.0, 30.0))
+ self.assertEqual(axis.getLimits(), (20.0, 30.0))
+
+ def testLimitsChanged_setYRightLimit(self):
+ self.plot.addCurve(
+ self.xData,
+ self.yData,
+ legend="curve",
+ replace=False,
+ resetzoom=False,
+ color="green",
+ linestyle="-",
+ symbol="o",
+ )
+ listener = SignalListener()
+ axis = self.plot.getYAxis(axis="right")
+ axis.sigLimitsChanged.connect(listener)
+ axis.setLimits(20, 30)
+ # at least one event per axis
+ self.assertEqual(listener.arguments(callIndex=-1), (20.0, 30.0))
+ self.assertEqual(axis.getLimits(), (20.0, 30.0))
+
+ def testScaleProxy(self):
+ listener = SignalListener()
+ y = self.plot.getYAxis()
+ yright = self.plot.getYAxis(axis="right")
+ y.sigScaleChanged.connect(listener.partial("left"))
+ yright.sigScaleChanged.connect(listener.partial("right"))
+ yright.setScale(yright.LOGARITHMIC)
+
+ self.assertEqual(y.getScale(), y.LOGARITHMIC)
+ events = listener.arguments()
+ self.assertEqual(len(events), 2)
+ self.assertIn(("left", y.LOGARITHMIC), events)
+ self.assertIn(("right", y.LOGARITHMIC), events)
+
+ def testAutoScaleProxy(self):
+ listener = SignalListener()
+ y = self.plot.getYAxis()
+ yright = self.plot.getYAxis(axis="right")
+ y.sigAutoScaleChanged.connect(listener.partial("left"))
+ yright.sigAutoScaleChanged.connect(listener.partial("right"))
+ yright.setAutoScale(False)
+
+ self.assertEqual(y.isAutoScale(), False)
+ events = listener.arguments()
+ self.assertEqual(len(events), 2)
+ self.assertIn(("left", False), events)
+ self.assertIn(("right", False), events)
+
+ def testInvertedProxy(self):
+ listener = SignalListener()
+ y = self.plot.getYAxis()
+ yright = self.plot.getYAxis(axis="right")
+ y.sigInvertedChanged.connect(listener.partial("left"))
+ yright.sigInvertedChanged.connect(listener.partial("right"))
+ yright.setInverted(True)
+
+ self.assertEqual(y.isInverted(), True)
+ events = listener.arguments()
+ self.assertEqual(len(events), 2)
+ self.assertIn(("left", True), events)
+ self.assertIn(("right", True), events)
+
+ def testAxesDisplayedFalse(self):
+ """Test coverage on setAxesDisplayed(False)"""
+ self.plot.setAxesDisplayed(False)
+
+ def testAxesDisplayedTrue(self):
+ """Test coverage on setAxesDisplayed(True)"""
+ self.plot.setAxesDisplayed(True)
+
+ def testAxesMargins(self):
+ """Test PlotWidget's getAxesMargins and setAxesMargins"""
+ self.plot.show()
+ self.qWaitForWindowExposed(self.plot)
+
+ margins = self.plot.getAxesMargins()
+ self.assertEqual(margins, (0.15, 0.1, 0.1, 0.15))
+
+ for margins in ((0.0, 0.0, 0.0, 0.0), (0.15, 0.1, 0.1, 0.15)):
+ with self.subTest(margins=margins):
+ self.plot.setAxesMargins(*margins)
+ self.qapp.processEvents()
+ self.assertEqual(self.plot.getAxesMargins(), margins)
+
+ def testBoundingRectItem(self):
+ item = BoundingRect()
+ item.setBounds((-1000, 1000, -2000, 2000))
+ self.plot.addItem(item)
+ self.plot.resetZoom()
+ limits = numpy.array(self.plot.getXAxis().getLimits())
+ numpy.testing.assert_almost_equal(limits, numpy.array([-1000, 1000]))
+ limits = numpy.array(self.plot.getYAxis().getLimits())
+ numpy.testing.assert_almost_equal(limits, numpy.array([-2000, 2000]))
+
+ def testBoundingRectRightItem(self):
+ item = BoundingRect()
+ item.setYAxis("right")
+ item.setBounds((-1000, 1000, -2000, 2000))
+ self.plot.addItem(item)
+ self.plot.resetZoom()
+ limits = numpy.array(self.plot.getXAxis().getLimits())
+ numpy.testing.assert_almost_equal(limits, numpy.array([-1000, 1000]))
+ limits = numpy.array(self.plot.getYAxis("right").getLimits())
+ numpy.testing.assert_almost_equal(limits, numpy.array([-2000, 2000]))
+
+ def testBoundingRectArguments(self):
+ item = BoundingRect()
+ with self.assertRaises(Exception):
+ item.setBounds((1000, -1000, -2000, 2000))
+ with self.assertRaises(Exception):
+ item.setBounds((-1000, 1000, 2000, -2000))
+
+ def testBoundingRectWithLog(self):
+ item = BoundingRect()
+ self.plot.addItem(item)
+
+ item.setBounds((-1000, 1000, -2000, 2000))
+ self.plot.getXAxis()._setLogarithmic(True)
+ self.plot.getYAxis()._setLogarithmic(False)
+ self.assertEqual(item.getBounds(), (1000, 1000, -2000, 2000))
+
+ item.setBounds((-1000, 1000, -2000, 2000))
+ self.plot.getXAxis()._setLogarithmic(False)
+ self.plot.getYAxis()._setLogarithmic(True)
+ self.assertEqual(item.getBounds(), (-1000, 1000, 2000, 2000))
+
+ item.setBounds((-1000, 0, -2000, 2000))
+ self.plot.getXAxis()._setLogarithmic(True)
+ self.plot.getYAxis()._setLogarithmic(False)
+ self.assertIsNone(item.getBounds())
+
+ def testAxisExtent(self):
+ """Test XAxisExtent and yAxisExtent"""
+ for cls, axis in (
+ (XAxisExtent, self.plot.getXAxis()),
+ (YAxisExtent, self.plot.getYAxis()),
+ ):
+ for range_, logRange in (
+ ((2, 3), (2, 3)),
+ ((-2, -1), (1, 100)),
+ ((-1, 3), (3.0 * 0.9, 3.0 * 1.1)),
+ ):
+ extent = cls()
+ extent.setRange(*range_)
+ self.plot.addItem(extent)
+
+ for isLog, plotRange in ((False, range_), (True, logRange)):
+ with self.subTest(cls=cls.__name__, range=range_, isLog=isLog):
+ axis._setLogarithmic(isLog)
+ self.plot.resetZoom()
+ self.qapp.processEvents()
+ self.assertEqual(axis.getLimits(), plotRange)
+
+ axis._setLogarithmic(False)
+ self.plot.clear()
+
+ def testAxisLimitOverflow(self):
+ """Test setting limis beyond supported range"""
+ xaxis, yaxis = self.plot.getXAxis(), self.plot.getYAxis()
+ for scale in ("linear", "log"):
+ xaxis.setScale(scale)
+ yaxis.setScale(scale)
+ for limits in ((1e300, 1e308), (-1e308, 1e308), (1e-300, 2e-300)):
+ with self.subTest(scale=scale, limits=limits):
+ xaxis.setLimits(*limits)
+ self.qapp.processEvents()
+ self.assertNotEqual(xaxis.getLimits(), limits)
+ yaxis.setLimits(*limits)
+ self.qapp.processEvents()
+ self.assertNotEqual(yaxis.getLimits(), limits)
+
+
+class TestPlotCurveLog(PlotWidgetTestCase, ParametricTestCase):
+ """Basic tests for addCurve with log scale axes"""
+
+ # Test data
+ xData = numpy.arange(1000) + 1
+ yData = xData**2
+
+ def _setLabels(self):
+ self.plot.getXAxis().setLabel("X")
+ self.plot.getYAxis().setLabel("X * X")
+
+ def testPlotCurveLogX(self):
+ self._setLabels()
+ self.plot.getXAxis()._setLogarithmic(True)
+ self.plot.setGraphTitle("Curve X: Log Y: Linear")
+
+ self.plot.addCurve(
+ self.xData,
+ self.yData,
+ legend="curve",
+ replace=False,
+ resetzoom=True,
+ color="green",
+ linestyle="-",
+ symbol="o",
+ )
+
+ def testPlotCurveLogY(self):
+ self._setLabels()
+ self.plot.getYAxis()._setLogarithmic(True)
+
+ self.plot.setGraphTitle("Curve X: Linear Y: Log")
+
+ self.plot.addCurve(
+ self.xData,
+ self.yData,
+ legend="curve",
+ replace=False,
+ resetzoom=True,
+ color="green",
+ linestyle="-",
+ symbol="o",
+ )
+
+ def testPlotCurveLogXY(self):
+ self._setLabels()
+ self.plot.getXAxis()._setLogarithmic(True)
+ self.plot.getYAxis()._setLogarithmic(True)
+
+ self.plot.setGraphTitle("Curve X: Log Y: Log")
+
+ self.plot.addCurve(
+ self.xData,
+ self.yData,
+ legend="curve",
+ replace=False,
+ resetzoom=True,
+ color="green",
+ linestyle="-",
+ symbol="o",
+ )
+
+ def testPlotCurveErrorLogXY(self):
+ self.plot.getXAxis()._setLogarithmic(True)
+ self.plot.getYAxis()._setLogarithmic(True)
+
+ # Every second error leads to negative number
+ errors = numpy.ones_like(self.xData)
+ errors[::2] = self.xData[::2] + 1
+
+ tests = [ # name, xerror, yerror
+ ("xerror=3", 3, None),
+ ("xerror=N array", errors, None),
+ ("xerror=Nx1 array", errors.reshape(len(errors), 1), None),
+ ("xerror=2xN array", numpy.array((errors, errors)), None),
+ ("yerror=6", None, 6),
+ ("yerror=N array", None, errors**2),
+ ("yerror=Nx1 array", None, (errors**2).reshape(len(errors), 1)),
+ ("yerror=2xN array", None, numpy.array((errors, errors)) ** 2),
+ ]
+
+ for name, xError, yError in tests:
+ with self.subTest(name):
+ self.plot.setGraphTitle(name)
+ self.plot.addCurve(
+ self.xData,
+ self.yData,
+ legend=name,
+ xerror=xError,
+ yerror=yError,
+ replace=False,
+ resetzoom=True,
+ color="green",
+ linestyle="-",
+ symbol="o",
+ )
+
+ self.qapp.processEvents()
+
+ if xError is None:
+ dataMin, dataMax = numpy.min(self.xData), numpy.max(self.xData)
+ else:
+ xMinusError = self.xData - numpy.atleast_2d(xError)[0]
+ dataMin = numpy.min(xMinusError[xMinusError > 0])
+ xPlusError = self.xData + numpy.atleast_2d(xError)[-1]
+ dataMax = numpy.max(xPlusError[xPlusError > 0])
+ plotMin, plotMax = self.plot.getXAxis().getLimits()
+ assert numpy.allclose((dataMin, dataMax), (plotMin, plotMax))
+
+ if yError is None:
+ dataMin, dataMax = numpy.min(self.yData), numpy.max(self.yData)
+ else:
+ yMinusError = self.yData - numpy.atleast_2d(yError)[0]
+ dataMin = numpy.min(yMinusError[yMinusError > 0])
+ yPlusError = self.yData + numpy.atleast_2d(yError)[-1]
+ dataMax = numpy.max(yPlusError[yPlusError > 0])
+ plotMin, plotMax = self.plot.getYAxis().getLimits()
+ assert numpy.allclose((dataMin, dataMax), (plotMin, plotMax))
+
+ self.plot.clear()
+ self.plot.resetZoom()
+ self.qapp.processEvents()
+
+ def testPlotCurveToggleLog(self):
+ """Add a curve with negative data and toggle log axis"""
+ arange = numpy.arange(1000) + 1
+ tests = [ # name, xData, yData
+ ("x>0, some negative y", arange, arange - 500),
+ ("x>0, y<0", arange, -arange),
+ ("some negative x, y>0", arange - 500, arange),
+ ("x<0, y>0", -arange, arange),
+ ("some negative x and y", arange - 500, arange - 500),
+ ("x<0, y<0", -arange, -arange),
+ ]
+
+ for name, xData, yData in tests:
+ with self.subTest(name):
+ self.plot.addCurve(xData, yData, resetzoom=True)
+ self.qapp.processEvents()
+
+ # no log axis
+ xLim = self.plot.getXAxis().getLimits()
+ self.assertEqual(xLim, (min(xData), max(xData)))
+ yLim = self.plot.getYAxis().getLimits()
+ self.assertEqual(yLim, (min(yData), max(yData)))
+
+ # x axis log
+ self.plot.getXAxis()._setLogarithmic(True)
+ self.qapp.processEvents()
+
+ xLim = self.plot.getXAxis().getLimits()
+ yLim = self.plot.getYAxis().getLimits()
+ positives = xData > 0
+ if numpy.any(positives):
+ self.assertTrue(
+ numpy.allclose(
+ xLim, (min(xData[positives]), max(xData[positives]))
+ )
+ )
+ else: # No positive x in the curve
+ self.assertEqual(xLim, (1.0, 100.0))
+ self.assertEqual(yLim, (min(yData), max(yData)))
+
+ # x axis and y axis log
+ previousXLim = self.plot.getXAxis().getLimits()
+ previousYLim = self.plot.getYAxis().getLimits()
+ self.plot.getYAxis()._setLogarithmic(True)
+ self.qapp.processEvents()
+
+ xLim = self.plot.getXAxis().getLimits()
+ yLim = self.plot.getYAxis().getLimits()
+
+ self.assertEqual(xLim, previousXLim)
+ positives = numpy.logical_and(xData > 0, yData > 0)
+ if previousYLim[0] > 0:
+ self.assertEqual(yLim, previousYLim)
+ elif numpy.any(positives):
+ expectedLimits = min(yData[positives]), max(yData[positives])
+ self.assertTrue(
+ numpy.allclose(yLim, expectedLimits),
+ f"{yLim} != {expectedLimits}",
+ )
+ else: # No positive x and y in the curve
+ self.assertEqual(yLim, (1.0, 100.0))
+
+ # y axis log
+ previousXLim = self.plot.getXAxis().getLimits()
+ self.plot.getXAxis()._setLogarithmic(False)
+ self.qapp.processEvents()
+
+ xLim = self.plot.getXAxis().getLimits()
+ yLim = self.plot.getYAxis().getLimits()
+ self.assertEqual(xLim, previousXLim)
+ positives = yData > 0
+ if numpy.any(positives):
+ self.assertTrue(
+ numpy.allclose(
+ yLim, (min(yData[positives]), max(yData[positives]))
+ )
+ )
+ else: # No positive y in the curve
+ self.assertEqual(yLim, (1.0, 100.0))
+
+ # no log axis
+ previousXLim = self.plot.getXAxis().getLimits()
+ previousYLim = self.plot.getYAxis().getLimits()
+ self.plot.getYAxis()._setLogarithmic(False)
+ self.qapp.processEvents()
+
+ xLim = self.plot.getXAxis().getLimits()
+ self.assertEqual(xLim, previousXLim)
+ yLim = self.plot.getYAxis().getLimits()
+ self.assertEqual(yLim, previousYLim)
+
+ self.plot.clear()
+ self.plot.resetZoom()
+ self.qapp.processEvents()
+
+
+class TestPlotImageLog(PlotWidgetTestCase):
+ """Basic tests for addImage with log scale axes."""
+
+ def setUp(self):
+ super(TestPlotImageLog, self).setUp()
+
+ self.plot.getXAxis().setLabel("Columns")
+ self.plot.getYAxis().setLabel("Rows")
+
+ def testPlotColormapGrayLogX(self):
+ self.plot.getXAxis()._setLogarithmic(True)
+ self.plot.setGraphTitle("CMap X: Log Y: Linear")
+
+ colormap = Colormap(name="gray", normalization="linear", vmin=None, vmax=None)
+ self.plot.addImage(
+ DATA_2D,
+ legend="image 1",
+ origin=(1.0, 1.0),
+ scale=(1.0, 1.0),
+ resetzoom=False,
+ colormap=colormap,
+ )
+ self.plot.resetZoom()
+
+ def testPlotColormapGrayLogY(self):
+ self.plot.getYAxis()._setLogarithmic(True)
+ self.plot.setGraphTitle("CMap X: Linear Y: Log")
+
+ colormap = Colormap(name="gray", normalization="linear", vmin=None, vmax=None)
+ self.plot.addImage(
+ DATA_2D,
+ legend="image 1",
+ origin=(1.0, 1.0),
+ scale=(1.0, 1.0),
+ resetzoom=False,
+ colormap=colormap,
+ )
+ self.plot.resetZoom()
+
+ def testPlotColormapGrayLogXY(self):
+ self.plot.getXAxis()._setLogarithmic(True)
+ self.plot.getYAxis()._setLogarithmic(True)
+ self.plot.setGraphTitle("CMap X: Log Y: Log")
+
+ colormap = Colormap(name="gray", normalization="linear", vmin=None, vmax=None)
+ self.plot.addImage(
+ DATA_2D,
+ legend="image 1",
+ origin=(1.0, 1.0),
+ scale=(1.0, 1.0),
+ resetzoom=False,
+ colormap=colormap,
+ )
+ self.plot.resetZoom()
+
+ def testPlotRgbRgbaLogXY(self):
+ self.plot.getXAxis()._setLogarithmic(True)
+ self.plot.getYAxis()._setLogarithmic(True)
+ self.plot.setGraphTitle("RGB + RGBA X: Log Y: Log")
+
+ rgb = numpy.array(
+ (
+ ((0, 0, 0), (128, 0, 0), (255, 0, 0)),
+ ((0, 128, 0), (0, 128, 128), (0, 128, 255)),
+ ),
+ dtype=numpy.uint8,
+ )
+
+ self.plot.addImage(
+ rgb, legend="rgb", origin=(1, 1), scale=(10, 10), resetzoom=False
+ )
+
+ rgba = numpy.array(
+ (
+ ((0, 0, 0, 0.5), (0.5, 0, 0, 1), (1, 0, 0, 0.5)),
+ ((0, 0.5, 0, 1), (0, 0.5, 0.5, 1), (0, 1, 1, 0.5)),
+ ),
+ dtype=numpy.float32,
+ )
+
+ self.plot.addImage(
+ rgba, legend="rgba", origin=(5.0, 5.0), scale=(10.0, 10.0), resetzoom=False
+ )
+ self.plot.resetZoom()
+
+
+class TestPlotMarkerLog(PlotWidgetTestCase):
+ """Basic tests for markers on log scales"""
+
+ # Test marker parameters
+ markers = [ # x, y, color, selectable, draggable
+ (10.0, 10.0, "blue", False, False),
+ (20.0, 20.0, "red", False, False),
+ (40.0, 100.0, "green", True, False),
+ (40.0, 500.0, "gray", True, True),
+ (60.0, 800.0, "black", False, True),
+ ]
+
+ def setUp(self):
+ super(TestPlotMarkerLog, self).setUp()
+
+ self.plot.getYAxis().setLabel("Rows")
+ self.plot.getXAxis().setLabel("Columns")
+ self.plot.getXAxis().setAutoScale(False)
+ self.plot.getYAxis().setAutoScale(False)
+ self.plot.setKeepDataAspectRatio(False)
+ self.plot.setLimits(1.0, 100.0, 1.0, 1000.0)
+ self.plot.getXAxis()._setLogarithmic(True)
+ self.plot.getYAxis()._setLogarithmic(True)
+
+ def testPlotMarkerXLog(self):
+ self.plot.setGraphTitle("Markers X, Log axes")
+
+ for x, _, color, select, drag in self.markers:
+ name = str(x)
+ if select:
+ name += " sel."
+ if drag:
+ name += " drag"
+ self.plot.addXMarker(x, name, name, color, select, drag)
+ self.plot.resetZoom()
+
+ def testPlotMarkerYLog(self):
+ self.plot.setGraphTitle("Markers Y, Log axes")
+
+ for _, y, color, select, drag in self.markers:
+ name = str(y)
+ if select:
+ name += " sel."
+ if drag:
+ name += " drag"
+ self.plot.addYMarker(y, name, name, color, select, drag)
+ self.plot.resetZoom()
+
+ def testPlotMarkerPtLog(self):
+ self.plot.setGraphTitle("Markers Pt, Log axes")
+
+ for x, y, color, select, drag in self.markers:
+ name = "{0},{1}".format(x, y)
+ if select:
+ name += " sel."
+ if drag:
+ name += " drag"
+ self.plot.addMarker(x, y, name, name, color, select, drag)
+ self.plot.resetZoom()
+
+
+@pytest.mark.usefixtures("test_options_class_attr")
+class TestPlotWidgetSwitchBackend(PlotWidgetTestCase):
+ """Test [get|set]Backend to switch backend"""
+
+ @pytest.mark.usefixtures("test_options")
+ def testSwitchBackend(self):
+ """Test switching a plot with a few items"""
+ backends = {"none": "BackendBase", "mpl": "BackendMatplotlibQt"}
+ if self.test_options.WITH_GL_TEST:
+ backends["gl"] = "BackendOpenGL"
+
+ self.plot.addImage(numpy.arange(100).reshape(10, 10))
+ self.plot.addCurve((-3, -2, -1), (1, 2, 3))
+ self.plot.resetZoom()
+ xlimits = self.plot.getXAxis().getLimits()
+ ylimits = self.plot.getYAxis().getLimits()
+ items = self.plot.getItems()
+ self.assertEqual(len(items), 2)
+
+ for backend, className in backends.items():
+ with self.subTest(backend=backend):
+ self.plot.setBackend(backend)
+ self.plot.replot()
+
+ retrievedBackend = self.plot.getBackend()
+ self.assertEqual(type(retrievedBackend).__name__, className)
+ self.assertEqual(self.plot.getXAxis().getLimits(), xlimits)
+ self.assertEqual(self.plot.getYAxis().getLimits(), ylimits)
+ self.assertEqual(self.plot.getItems(), items)
+
+
+@pytest.mark.usefixtures("use_opengl")
+class TestPlotWidget_Gl(TestPlotWidget):
+ backend = "gl"
+
+
+@pytest.mark.usefixtures("use_opengl")
+class TestPlotImage_Gl(TestPlotImage):
+ backend = "gl"
+
+
+@pytest.mark.usefixtures("use_opengl")
+class TestPlotCurve_Gl(TestPlotCurve):
+ backend = "gl"
+
+
+@pytest.mark.usefixtures("use_opengl")
+class TestPlotHistogram_Gl(TestPlotHistogram):
+ backend = "gl"
+
+
+@pytest.mark.usefixtures("use_opengl")
+class TestPlotScatter_Gl(TestPlotScatter):
+ backend = "gl"
+
+
+@pytest.mark.usefixtures("use_opengl")
+class TestPlotMarker_Gl(TestPlotMarker):
+ backend = "gl"
+
+
+@pytest.mark.usefixtures("use_opengl")
+class TestPlotItem_Gl(TestPlotItem):
+ backend = "gl"
+
+
+@pytest.mark.usefixtures("use_opengl")
+class TestPlotAxes_Gl(TestPlotAxes):
+ backend = "gl"
+
+
+@pytest.mark.usefixtures("use_opengl")
+class TestPlotEmptyLog_Gl(TestPlotEmptyLog):
+ backend = "gl"
+
+
+@pytest.mark.usefixtures("use_opengl")
+class TestPlotCurveLog_Gl(TestPlotCurveLog):
+ backend = "gl"
+
+
+@pytest.mark.usefixtures("use_opengl")
+class TestPlotImageLog_Gl(TestPlotImageLog):
+ backend = "gl"
+
+
+@pytest.mark.usefixtures("use_opengl")
+class TestPlotMarkerLog_Gl(TestPlotMarkerLog):
+ backend = "gl"
+
+
+class TestSpecial_ExplicitMplBackend(TestSpecialBackend):
+ backend = "mpl"
diff --git a/src/silx/gui/plot/test/testPlotWidgetActiveItem.py b/src/silx/gui/plot/test/testPlotWidgetActiveItem.py
new file mode 100755
index 0000000..99285a8
--- /dev/null
+++ b/src/silx/gui/plot/test/testPlotWidgetActiveItem.py
@@ -0,0 +1,416 @@
+# /*##########################################################################
+#
+# Copyright (c) 2023 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""Test PlotWidget active item"""
+
+__authors__ = ["T. Vincent"]
+__license__ = "MIT"
+__date__ = "11/12/2023"
+
+
+import numpy
+import pytest
+
+from silx.gui.utils.testutils import SignalListener
+from silx.gui.plot.items.curve import CurveStyle
+
+
+@pytest.mark.parametrize("plotWidget", ("mpl", "gl"), indirect=True)
+def testActiveCurveAndLabels(plotWidget):
+ # Active curve handling off, no label change
+ plotWidget.setActiveCurveHandling(False)
+ plotWidget.getXAxis().setLabel("XLabel")
+ plotWidget.getYAxis().setLabel("YLabel")
+ plotWidget.addCurve((1, 2), (1, 2))
+ assert plotWidget.getXAxis().getLabel() == "XLabel"
+ assert plotWidget.getYAxis().getLabel() == "YLabel"
+
+ plotWidget.addCurve((1, 2), (2, 3), xlabel="x1", ylabel="y1")
+ assert plotWidget.getXAxis().getLabel() == "XLabel"
+ assert plotWidget.getYAxis().getLabel() == "YLabel"
+
+ plotWidget.clear()
+ assert plotWidget.getXAxis().getLabel() == "XLabel"
+ assert plotWidget.getYAxis().getLabel() == "YLabel"
+
+ # Active curve handling on, label changes
+ plotWidget.setActiveCurveHandling(True)
+ plotWidget.getXAxis().setLabel("XLabel")
+ plotWidget.getYAxis().setLabel("YLabel")
+
+ # labels changed as active curve
+ plotWidget.addCurve((1, 2), (1, 2), legend="1", xlabel="x1", ylabel="y1")
+ plotWidget.setActiveCurve("1")
+ assert plotWidget.getXAxis().getLabel() == "x1"
+ assert plotWidget.getYAxis().getLabel() == "y1"
+
+ # labels not changed as not active curve
+ plotWidget.addCurve((1, 2), (2, 3), legend="2")
+ assert plotWidget.getXAxis().getLabel() == "x1"
+ assert plotWidget.getYAxis().getLabel() == "y1"
+
+ # labels changed
+ plotWidget.setActiveCurve("2")
+ assert plotWidget.getXAxis().getLabel() == "XLabel"
+ assert plotWidget.getYAxis().getLabel() == "YLabel"
+
+ plotWidget.setActiveCurve("1")
+ assert plotWidget.getXAxis().getLabel() == "x1"
+ assert plotWidget.getYAxis().getLabel() == "y1"
+
+ plotWidget.clear()
+ assert plotWidget.getXAxis().getLabel() == "XLabel"
+ assert plotWidget.getYAxis().getLabel() == "YLabel"
+
+ plotWidget.setActiveCurveHandling(False)
+
+
+@pytest.mark.parametrize("plotWidget", ("mpl", "gl"), indirect=True)
+def testPlotActiveCurveSelectionMode(plotWidget):
+ xData = numpy.arange(1000)
+ yData = -500 + 100 * numpy.sin(xData)
+ xData2 = xData + 1000
+ yData2 = xData - 1000 + 200 * numpy.random.random(1000)
+
+ plotWidget.clear()
+ plotWidget.setActiveCurveHandling(True)
+ legend = "curve 1"
+ plotWidget.addCurve(xData, yData, legend=legend, color="green")
+
+ # active curve should be None
+ assert plotWidget.getActiveCurve(just_legend=True) is None
+
+ # active curve should be None when None is set as active curve
+ plotWidget.setActiveCurve(legend)
+ current = plotWidget.getActiveCurve(just_legend=True)
+ assert current == legend
+ plotWidget.setActiveCurve(None)
+ current = plotWidget.getActiveCurve(just_legend=True)
+ assert current is None
+
+ # testing it automatically toggles if there is only one
+ plotWidget.setActiveCurveSelectionMode("legacy")
+ current = plotWidget.getActiveCurve(just_legend=True)
+ assert current == legend
+
+ # active curve should not change when None set as active curve
+ assert plotWidget.getActiveCurveSelectionMode() == "legacy"
+ plotWidget.setActiveCurve(None)
+ current = plotWidget.getActiveCurve(just_legend=True)
+ assert current == legend
+
+ # situation where no curve is active
+ plotWidget.clear()
+ plotWidget.setActiveCurveHandling(True)
+ assert plotWidget.getActiveCurveSelectionMode() == "atmostone"
+ plotWidget.addCurve(xData, yData, legend=legend, color="green")
+ assert plotWidget.getActiveCurve(just_legend=True) is None
+ plotWidget.addCurve(xData2, yData2, legend="curve 2", color="red")
+ assert plotWidget.getActiveCurve(just_legend=True) is None
+ plotWidget.setActiveCurveSelectionMode("legacy")
+ assert plotWidget.getActiveCurve(just_legend=True) is None
+
+ # the first curve added should be active
+ plotWidget.clear()
+ plotWidget.addCurve(xData, yData, legend=legend, color="green")
+ assert plotWidget.getActiveCurve(just_legend=True) == legend
+ plotWidget.addCurve(xData2, yData2, legend="curve 2", color="red")
+ assert plotWidget.getActiveCurve(just_legend=True) == legend
+
+ plotWidget.setActiveCurveHandling(False)
+
+
+@pytest.mark.parametrize("plotWidget", ("mpl", "gl"), indirect=True)
+def testActiveCurveStyle(plotWidget):
+ """Test change of active curve style"""
+ plotWidget.setActiveCurveHandling(True)
+ plotWidget.setActiveCurveStyle(color="black")
+ style = plotWidget.getActiveCurveStyle()
+ assert style.getColor() == (0.0, 0.0, 0.0, 1.0)
+ assert style.getLineStyle() is None
+ assert style.getLineWidth() is None
+ assert style.getSymbol() is None
+ assert style.getSymbolSize() is None
+
+ xData = numpy.arange(1000)
+ yData = -500 + 100 * numpy.sin(xData)
+ plotWidget.addCurve(x=xData, y=yData, legend="curve1")
+ curve = plotWidget.getCurve("curve1")
+ curve.setColor("blue")
+ curve.setLineStyle("-")
+ curve.setLineWidth(1)
+ curve.setSymbol("o")
+ curve.setSymbolSize(5)
+
+ # Check default current style
+ defaultStyle = curve.getCurrentStyle()
+ assert defaultStyle == CurveStyle(
+ color="blue", linestyle="-", linewidth=1, symbol="o", symbolsize=5
+ )
+
+ # Activate curve with highlight color=black
+ plotWidget.setActiveCurve("curve1")
+ style = curve.getCurrentStyle()
+ assert style.getColor() == (0.0, 0.0, 0.0, 1.0)
+ assert style.getLineStyle() == "-"
+ assert style.getLineWidth() == 1
+ assert style.getSymbol() == "o"
+ assert style.getSymbolSize() == 5
+
+ # Change highlight to linewidth=2
+ plotWidget.setActiveCurveStyle(linewidth=2)
+ style = curve.getCurrentStyle()
+ assert style.getColor() == (0.0, 0.0, 1.0, 1.0)
+ assert style.getLineStyle() == "-"
+ assert style.getLineWidth() == 2
+ assert style.getSymbol() == "o"
+ assert style.getSymbolSize() == 5
+
+ plotWidget.setActiveCurve(None)
+ assert curve.getCurrentStyle() == defaultStyle
+
+ plotWidget.setActiveCurveHandling(False)
+
+
+@pytest.mark.parametrize("plotWidget", ("mpl", "gl"), indirect=True)
+def testActiveImageAndLabels(plotWidget):
+ # Active image handling always on, no API for toggling it
+ plotWidget.getXAxis().setLabel("XLabel")
+ plotWidget.getYAxis().setLabel("YLabel")
+
+ # labels changed as active curve
+ plotWidget.addImage(
+ numpy.arange(100).reshape(10, 10), legend="1", xlabel="x1", ylabel="y1"
+ )
+ assert plotWidget.getXAxis().getLabel() == "x1"
+ assert plotWidget.getYAxis().getLabel() == "y1"
+
+ # labels not changed as not active curve
+ plotWidget.addImage(numpy.arange(100).reshape(10, 10), legend="2")
+ assert plotWidget.getXAxis().getLabel() == "x1"
+ assert plotWidget.getYAxis().getLabel() == "y1"
+
+ # labels changed
+ plotWidget.setActiveImage("2")
+ assert plotWidget.getXAxis().getLabel() == "XLabel"
+ assert plotWidget.getYAxis().getLabel() == "YLabel"
+
+ plotWidget.setActiveImage("1")
+ assert plotWidget.getXAxis().getLabel() == "x1"
+ assert plotWidget.getYAxis().getLabel() == "y1"
+
+ plotWidget.clear()
+ assert plotWidget.getXAxis().getLabel() == "XLabel"
+ assert plotWidget.getYAxis().getLabel() == "YLabel"
+
+ plotWidget.setActiveCurveHandling(False)
+
+
+def _checkSelection(selection, current=None, selected=()):
+ """Check current item and selected items."""
+ assert selection.getCurrentItem() is current
+ assert selection.getSelectedItems() == selected
+
+
+@pytest.mark.parametrize("plotWidget", ("mpl", "gl"), indirect=True)
+def testSelectionSyncWithActiveItems(plotWidget):
+ """Test update of PlotWidgetSelection according to active items"""
+ listener = SignalListener()
+
+ selection = plotWidget.selection()
+ selection.sigCurrentItemChanged.connect(listener)
+ _checkSelection(selection)
+
+ # Active item is current
+ plotWidget.addImage(((0, 1), (2, 3)), legend="image")
+ image = plotWidget.getActiveImage()
+ assert listener.callCount() == 1
+ _checkSelection(selection, image, (image,))
+
+ # No active = no current
+ plotWidget.setActiveImage(None)
+ assert listener.callCount() == 2
+ _checkSelection(selection)
+
+ # Active item is current
+ plotWidget.setActiveImage("image")
+ assert listener.callCount() == 3
+ _checkSelection(selection, image, (image,))
+
+ # Mosted recently "actived" item is current
+ plotWidget.addScatter((3, 2, 1), (0, 1, 2), (0, 1, 2), legend="scatter")
+ scatter = plotWidget.getActiveScatter()
+ assert listener.callCount() == 4
+ _checkSelection(selection, scatter, (scatter, image))
+
+ # Previously mosted recently "actived" item is current
+ plotWidget.setActiveScatter(None)
+ assert listener.callCount() == 5
+ _checkSelection(selection, image, (image,))
+
+ # Mosted recently "actived" item is current
+ plotWidget.setActiveScatter("scatter")
+ assert listener.callCount() == 6
+ _checkSelection(selection, scatter, (scatter, image))
+
+ # No active = no current
+ plotWidget.setActiveImage(None)
+ plotWidget.setActiveScatter(None)
+ assert listener.callCount() == 7
+ _checkSelection(selection)
+
+ # Mosted recently "actived" item is current
+ plotWidget.setActiveScatter("scatter")
+ assert listener.callCount() == 8
+ plotWidget.setActiveImage("image")
+ assert listener.callCount() == 9
+ _checkSelection(selection, image, (image, scatter))
+
+ # Add a curve which is not active by default
+ plotWidget.addCurve((0, 1, 2), (0, 1, 2), legend="curve")
+ curve = plotWidget.getCurve("curve")
+ assert listener.callCount() == 9
+ _checkSelection(selection, image, (image, scatter))
+
+ # Mosted recently "actived" item is current
+ plotWidget.setActiveCurve("curve")
+ assert listener.callCount() == 10
+ _checkSelection(selection, curve, (curve, image, scatter))
+
+ # Add a curve which is not active by default
+ plotWidget.addCurve((0, 1, 2), (0, 1, 2), legend="curve2")
+ curve2 = plotWidget.getCurve("curve2")
+ assert listener.callCount() == 10
+ _checkSelection(selection, curve, (curve, image, scatter))
+
+ # Mosted recently "actived" item is current, previous curve is removed
+ plotWidget.setActiveCurve("curve2")
+ assert listener.callCount() == 11
+ _checkSelection(selection, curve2, (curve2, image, scatter))
+
+ # No items = no current
+ plotWidget.clear()
+ assert listener.callCount() == 12
+ _checkSelection(selection)
+
+
+@pytest.mark.parametrize("plotWidget", ("mpl", "gl"), indirect=True)
+def testSelectionWithItems(plotWidget):
+ """Test init of selection on a plot with items"""
+ plotWidget.addImage(((0, 1), (2, 3)), legend="image")
+ plotWidget.addScatter((3, 2, 1), (0, 1, 2), (0, 1, 2), legend="scatter")
+ plotWidget.addCurve((0, 1, 2), (0, 1, 2), legend="curve")
+ plotWidget.setActiveCurve("curve")
+
+ selection = plotWidget.selection()
+ assert selection.getCurrentItem() is not None
+ selected = selection.getSelectedItems()
+ assert len(selected) == 3
+ assert plotWidget.getActiveCurve() in selected
+ assert plotWidget.getActiveImage() in selected
+ assert plotWidget.getActiveScatter() in selected
+
+
+@pytest.mark.parametrize("plotWidget", ("mpl", "gl"), indirect=True)
+def testSelectionSetCurrentItem(plotWidget):
+ """Test setCurrentItem"""
+ # Add items to the plot
+ plotWidget.addImage(((0, 1), (2, 3)), legend="image")
+ image = plotWidget.getActiveImage()
+ plotWidget.addScatter((3, 2, 1), (0, 1, 2), (0, 1, 2), legend="scatter")
+ scatter = plotWidget.getActiveScatter()
+ plotWidget.addCurve((0, 1, 2), (0, 1, 2), legend="curve")
+ plotWidget.setActiveCurve("curve")
+ curve = plotWidget.getActiveCurve()
+
+ selection = plotWidget.selection()
+ assert selection.getCurrentItem() is not None
+ assert len(selection.getSelectedItems()) == 3
+
+ # Set current to None reset all active items
+ selection.setCurrentItem(None)
+ _checkSelection(selection)
+ assert plotWidget.getActiveCurve() is None
+ assert plotWidget.getActiveImage() is None
+ assert plotWidget.getActiveScatter() is None
+
+ # Set current to an item makes it active
+ selection.setCurrentItem(image)
+ _checkSelection(selection, image, (image,))
+ assert plotWidget.getActiveCurve() is None
+ assert plotWidget.getActiveImage() is image
+ assert plotWidget.getActiveScatter() is None
+
+ # Set current to an item makes it active and keeps other active
+ selection.setCurrentItem(curve)
+ _checkSelection(selection, curve, (curve, image))
+ assert plotWidget.getActiveCurve() is curve
+ assert plotWidget.getActiveImage() is image
+ assert plotWidget.getActiveScatter() is None
+
+ # Set current to an item makes it active and keeps other active
+ selection.setCurrentItem(scatter)
+ _checkSelection(selection, scatter, (scatter, curve, image))
+ assert plotWidget.getActiveCurve() is curve
+ assert plotWidget.getActiveImage() is image
+ assert plotWidget.getActiveScatter() is scatter
+
+
+def testSetActiveCurveWithInstance(plotWidget):
+ """Test setting the active curve with a curve item instance"""
+ plotWidget.addCurve((0, 1), (0, 1), legend="curve0")
+ plotWidget.addCurve((0, 1), (1, 0), legend="curve1")
+ curve0, curve1 = plotWidget.getItems()
+
+ plotWidget.setActiveCurve(curve0)
+ assert plotWidget.getActiveCurve() is curve0
+
+ plotWidget.setActiveCurve(curve1)
+ assert plotWidget.getActiveCurve() is curve1
+
+ plotWidget.setActiveCurve(None)
+ assert plotWidget.getActiveCurve() is None
+
+
+def testSetActiveImageWithInstance(plotWidget):
+ """Test setting the active image with an image item instance"""
+ plotWidget.addImage(((0, 1), (2, 3)), legend="image")
+ image = plotWidget.getItems()[0]
+
+ plotWidget.setActiveImage(None)
+ assert plotWidget.getActiveImage() is None
+
+ plotWidget.setActiveImage(image)
+ assert plotWidget.getActiveImage() is image
+
+
+def testSetActiveScatterWithInstance(plotWidget):
+ """Test setting the active scatter with a scatter item instance"""
+ plotWidget.addScatter((0, 1), (0, 1), (0, 1), legend="scatter")
+ scatter = plotWidget.getItems()[0]
+
+ plotWidget.setActiveScatter(None)
+ assert plotWidget.getActiveScatter() is None
+
+ plotWidget.setActiveScatter(scatter)
+ assert plotWidget.getActiveScatter() is scatter
diff --git a/src/silx/gui/plot/test/testPlotWidgetDataMargins.py b/src/silx/gui/plot/test/testPlotWidgetDataMargins.py
new file mode 100644
index 0000000..4eb5134
--- /dev/null
+++ b/src/silx/gui/plot/test/testPlotWidgetDataMargins.py
@@ -0,0 +1,135 @@
+# /*##########################################################################
+#
+# Copyright (c) 2023 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""Test PlotWidget features related to data margins"""
+
+__authors__ = ["T. Vincent"]
+__license__ = "MIT"
+__date__ = "11/05/2023"
+
+import numpy
+import pytest
+
+
+def testDefaultDataMargins(plotWidget):
+ """Test default PlotWidget data margins: No margins"""
+ assert plotWidget.getDataMargins() == (0, 0, 0, 0)
+
+
+def testResetZoomDataMarginsLinearAxes(qapp, plotWidget):
+ """Test PlotWidget.setDataMargins effect on resetZoom with linear axis scales"""
+
+ margins = 0.1, 0.2, 0.3, 0.4
+ plotWidget.setDataMargins(*margins)
+
+ plotWidget.resetZoom()
+ qapp.processEvents()
+
+ retrievedMargins = plotWidget.getDataMargins()
+ assert retrievedMargins == margins
+
+ dataRange = 100 - 1
+ expectedXLimits = 1 - 0.1 * dataRange, 100 + 0.2 * dataRange
+ expectedYLimits = 1 - 0.3 * dataRange, 100 + 0.4 * dataRange
+
+ assert plotWidget.getXAxis().getLimits() == expectedXLimits
+ assert plotWidget.getYAxis().getLimits() == expectedYLimits
+ assert plotWidget.getYAxis(axis="right").getLimits() == expectedYLimits
+
+
+def testResetZoomDataMarginsLogAxes(qapp, plotWidget):
+ """Test PlotWidget.setDataMargins effect on resetZoom with log axis scales"""
+ plotWidget.getXAxis().setScale("log")
+ plotWidget.getYAxis().setScale("log")
+
+ dataMargins = 0.1, 0.2, 0.3, 0.4
+ plotWidget.setDataMargins(*dataMargins)
+
+ plotWidget.resetZoom()
+ qapp.processEvents()
+
+ retrievedMargins = plotWidget.getDataMargins()
+ assert retrievedMargins == dataMargins
+
+ logMin, logMax = numpy.log10(1), numpy.log10(100)
+ logRange = logMax - logMin
+ expectedXLimits = pow(10.0, logMin - 0.1 * logRange), pow(
+ 10.0, logMax + 0.2 * logRange
+ )
+ expectedYLimits = pow(10.0, logMin - 0.3 * logRange), pow(
+ 10.0, logMax + 0.4 * logRange
+ )
+
+ assert plotWidget.getXAxis().getLimits() == expectedXLimits
+ assert plotWidget.getYAxis().getLimits() == expectedYLimits
+ assert plotWidget.getYAxis(axis="right").getLimits() == expectedYLimits
+
+
+@pytest.mark.parametrize("margins", [False, True, (0, 0, 0, 0)])
+def testSetLimitsNoDataMargins(plotWidget, margins):
+ """Test PlotWidget.setLimits without data margins"""
+ xlimits = 1, 2
+ ylimits = 3, 4
+ y2limits = 5, 6
+ plotWidget.setLimits(*xlimits, *ylimits, *y2limits, margins=margins)
+
+ assert plotWidget.getXAxis().getLimits() == xlimits
+ assert plotWidget.getYAxis().getLimits() == ylimits
+ assert plotWidget.getYAxis(axis="right").getLimits() == y2limits
+
+
+@pytest.mark.parametrize(
+ "margins,expectedLimits",
+ [
+ # margins=False: use limits as is
+ (
+ False,
+ (1, 2, 3, 4, 5, 6),
+ ),
+ # margins=True: apply data margins
+ (
+ True,
+ (1 - 0.1, 2 + 0.2, 3 - 0.3, 4 + 0.4, 5 - 0.3, 6 + 0.4),
+ ),
+ # margins=tuple: apply provided margins
+ (
+ (0.4, 0.3, 0.2, 0.1),
+ (1 - 0.4, 2 + 0.3, 3 - 0.2, 4 + 0.1, 5 - 0.2, 6 + 0.1),
+ ),
+ ],
+)
+def testSetLimitsWithDataMargins(qapp, plotWidget, margins, expectedLimits):
+ """Test PlotWidget.setLimits with data margins"""
+ dataMargins = 0.1, 0.2, 0.3, 0.4
+ limits = 1, 2, 3, 4, 5, 6
+
+ plotWidget.setDataMargins(*dataMargins)
+ plotWidget.setLimits(*limits, margins=margins)
+ qapp.processEvents()
+
+ retrievedLimits = (
+ *plotWidget.getXAxis().getLimits(),
+ *plotWidget.getYAxis().getLimits(),
+ *plotWidget.getYAxis(axis="right").getLimits(),
+ )
+ assert retrievedLimits == expectedLimits
diff --git a/src/silx/gui/plot/test/testPlotWidgetNoBackend.py b/src/silx/gui/plot/test/testPlotWidgetNoBackend.py
new file mode 100644
index 0000000..d9d5706
--- /dev/null
+++ b/src/silx/gui/plot/test/testPlotWidgetNoBackend.py
@@ -0,0 +1,685 @@
+# /*##########################################################################
+#
+# Copyright (c) 2016-2023 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""Basic tests for PlotWidget with 'none' backend"""
+
+__authors__ = ["T. Vincent"]
+__license__ = "MIT"
+__date__ = "17/01/2018"
+
+
+import unittest
+from functools import reduce
+from silx.utils.testutils import ParametricTestCase
+
+import numpy
+
+import silx
+from silx.gui.colors import rgba
+from silx.gui.plot.PlotWidget import PlotWidget
+from silx.gui.plot.items.histogram import _getHistogramCurve, _computeEdges
+
+
+class TestPlot(unittest.TestCase):
+ """Basic tests of Plot without backend"""
+
+ def testPlotTitleLabels(self):
+ """Create a Plot and set the labels"""
+
+ plot = PlotWidget(backend="none")
+
+ title, xlabel, ylabel = "the title", "x label", "y label"
+ plot.setGraphTitle(title)
+ plot.getXAxis().setLabel(xlabel)
+ plot.getYAxis().setLabel(ylabel)
+
+ self.assertEqual(plot.getGraphTitle(), title)
+ self.assertEqual(plot.getXAxis().getLabel(), xlabel)
+ self.assertEqual(plot.getYAxis().getLabel(), ylabel)
+
+ def testAddNoRemove(self):
+ """add objects to the Plot"""
+
+ plot = PlotWidget(backend="none")
+ plot.addCurve(x=(1, 2, 3), y=(3, 2, 1))
+ plot.addImage(numpy.arange(100.0).reshape(10, -1))
+ plot.addShape(
+ numpy.array((1.0, 10.0)), numpy.array((10.0, 10.0)), shape="rectangle"
+ )
+ plot.addXMarker(10.0)
+
+
+class TestPlotRanges(ParametricTestCase):
+ """Basic tests of Plot data ranges without backend"""
+
+ _getValidValues = {
+ True: lambda ar: ar > 0,
+ False: lambda ar: numpy.ones(shape=ar.shape, dtype=bool),
+ }
+
+ @staticmethod
+ def _getRanges(arrays, are_logs):
+ gen = (
+ TestPlotRanges._getValidValues[is_log](ar)
+ for (ar, is_log) in zip(arrays, are_logs)
+ )
+ indices = numpy.where(reduce(numpy.logical_and, gen))[0]
+ if len(indices) > 0:
+ ranges = [(ar[indices[0]], ar[indices[-1]]) for ar in arrays]
+ else:
+ ranges = [None] * len(arrays)
+
+ return ranges
+
+ @staticmethod
+ def _getRangesMinmax(ranges):
+ # TODO : error if None in ranges.
+ rangeMin = numpy.min([rng[0] for rng in ranges])
+ rangeMax = numpy.max([rng[1] for rng in ranges])
+ return rangeMin, rangeMax
+
+ def testDataRangeNoPlot(self):
+ """empty plot data range"""
+
+ plot = PlotWidget(backend="none")
+
+ for logX, logY in (
+ (False, False),
+ (True, False),
+ (True, True),
+ (False, True),
+ (False, False),
+ ):
+ with self.subTest(logX=logX, logY=logY):
+ plot.getXAxis()._setLogarithmic(logX)
+ plot.getYAxis()._setLogarithmic(logY)
+ dataRange = plot.getDataRange()
+ self.assertIsNone(dataRange.x)
+ self.assertIsNone(dataRange.y)
+ self.assertIsNone(dataRange.yright)
+
+ def testDataRangeLeft(self):
+ """left axis range"""
+
+ plot = PlotWidget(backend="none")
+
+ xData = numpy.arange(10) - 4.9 # range : -4.9 , 4.1
+ yData = numpy.arange(10) - 6.9 # range : -6.9 , 2.1
+
+ plot.addCurve(x=xData, y=yData, legend="plot_0", yaxis="left")
+
+ for logX, logY in (
+ (False, False),
+ (True, False),
+ (True, True),
+ (False, True),
+ (False, False),
+ ):
+ with self.subTest(logX=logX, logY=logY):
+ plot.getXAxis()._setLogarithmic(logX)
+ plot.getYAxis()._setLogarithmic(logY)
+ dataRange = plot.getDataRange()
+ xRange, yRange = self._getRanges([xData, yData], [logX, logY])
+ self.assertSequenceEqual(dataRange.x, xRange)
+ self.assertSequenceEqual(dataRange.y, yRange)
+ self.assertIsNone(dataRange.yright)
+
+ def testDataRangeRight(self):
+ """right axis range"""
+
+ plot = PlotWidget(backend="none")
+ xData = numpy.arange(10) - 4.9 # range : -4.9 , 4.1
+ yData = numpy.arange(10) - 6.9 # range : -6.9 , 2.1
+ plot.addCurve(x=xData, y=yData, legend="plot_0", yaxis="right")
+
+ for logX, logY in (
+ (False, False),
+ (True, False),
+ (True, True),
+ (False, True),
+ (False, False),
+ ):
+ with self.subTest(logX=logX, logY=logY):
+ plot.getXAxis()._setLogarithmic(logX)
+ plot.getYAxis()._setLogarithmic(logY)
+ dataRange = plot.getDataRange()
+ xRange, yRange = self._getRanges([xData, yData], [logX, logY])
+ self.assertSequenceEqual(dataRange.x, xRange)
+ self.assertIsNone(dataRange.y)
+ self.assertSequenceEqual(dataRange.yright, yRange)
+
+ def testDataRangeImage(self):
+ """image data range"""
+
+ origin = (-10, 25)
+ scale = (3.0, 8.0)
+ image = numpy.arange(100.0).reshape(20, 5)
+
+ plot = PlotWidget(backend="none")
+ plot.addImage(image, origin=origin, scale=scale)
+
+ xRange = numpy.array([0.0, image.shape[1] * scale[0]]) + origin[0]
+ yRange = numpy.array([0.0, image.shape[0] * scale[1]]) + origin[1]
+
+ ranges = {
+ (False, False): (xRange, yRange),
+ (True, False): (None, None),
+ (True, True): (None, None),
+ (False, True): (None, None),
+ }
+
+ for logX, logY in (
+ (False, False),
+ (True, False),
+ (True, True),
+ (False, True),
+ (False, False),
+ ):
+ with self.subTest(logX=logX, logY=logY):
+ plot.getXAxis()._setLogarithmic(logX)
+ plot.getYAxis()._setLogarithmic(logY)
+ dataRange = plot.getDataRange()
+ xRange, yRange = ranges[logX, logY]
+ self.assertTrue(
+ numpy.array_equal(dataRange.x, xRange),
+ msg="{0} != {1}".format(dataRange.x, xRange),
+ )
+ self.assertTrue(
+ numpy.array_equal(dataRange.y, yRange),
+ msg="{0} != {1}".format(dataRange.y, yRange),
+ )
+ self.assertIsNone(dataRange.yright)
+
+ def testDataRangeLeftRight(self):
+ """right+left axis range"""
+
+ plot = PlotWidget(backend="none")
+
+ xData_l = numpy.arange(10) - 0.9 # range : -0.9 , 8.1
+ yData_l = numpy.arange(10) - 1.9 # range : -1.9 , 7.1
+ plot.addCurve(x=xData_l, y=yData_l, legend="plot_l", yaxis="left")
+
+ xData_r = numpy.arange(10) - 4.9 # range : -4.9 , 4.1
+ yData_r = numpy.arange(10) - 6.9 # range : -6.9 , 2.1
+ plot.addCurve(x=xData_r, y=yData_r, legend="plot_r", yaxis="right")
+
+ for logX, logY in (
+ (False, False),
+ (True, False),
+ (True, True),
+ (False, True),
+ (False, False),
+ ):
+ with self.subTest(logX=logX, logY=logY):
+ plot.getXAxis()._setLogarithmic(logX)
+ plot.getYAxis()._setLogarithmic(logY)
+ dataRange = plot.getDataRange()
+ xRangeL, yRangeL = self._getRanges([xData_l, yData_l], [logX, logY])
+ xRangeR, yRangeR = self._getRanges([xData_r, yData_r], [logX, logY])
+ xRangeLR = self._getRangesMinmax([xRangeL, xRangeR])
+ self.assertSequenceEqual(dataRange.x, xRangeLR)
+ self.assertSequenceEqual(dataRange.y, yRangeL)
+ self.assertSequenceEqual(dataRange.yright, yRangeR)
+
+ def testDataRangeCurveImage(self):
+ """right+left+image axis range"""
+
+ # overlapping ranges :
+ # image sets x min and y max
+ # plot_left sets y min
+ # plot_right sets x max (and yright)
+ plot = PlotWidget(backend="none")
+
+ origin = (-10, 5)
+ scale = (3.0, 8.0)
+ image = numpy.arange(100.0).reshape(20, 5)
+
+ plot.addImage(image, origin=origin, scale=scale, legend="image")
+
+ xData_l = numpy.arange(10) - 0.9 # range : -0.9 , 8.1
+ yData_l = numpy.arange(10) - 1.9 # range : -1.9 , 7.1
+ plot.addCurve(x=xData_l, y=yData_l, legend="plot_l", yaxis="left")
+
+ xData_r = numpy.arange(10) + 4.1 # range : 4.1 , 13.1
+ yData_r = numpy.arange(10) - 0.9 # range : -0.9 , 8.1
+ plot.addCurve(x=xData_r, y=yData_r, legend="plot_r", yaxis="right")
+
+ imgXRange = numpy.array([0.0, image.shape[1] * scale[0]]) + origin[0]
+ imgYRange = numpy.array([0.0, image.shape[0] * scale[1]]) + origin[1]
+
+ for logX, logY in (
+ (False, False),
+ (True, False),
+ (True, True),
+ (False, True),
+ (False, False),
+ ):
+ with self.subTest(logX=logX, logY=logY):
+ plot.getXAxis()._setLogarithmic(logX)
+ plot.getYAxis()._setLogarithmic(logY)
+ dataRange = plot.getDataRange()
+ xRangeL, yRangeL = self._getRanges([xData_l, yData_l], [logX, logY])
+ xRangeR, yRangeR = self._getRanges([xData_r, yData_r], [logX, logY])
+ if logX or logY:
+ xRangeLR = self._getRangesMinmax([xRangeL, xRangeR])
+ else:
+ xRangeLR = self._getRangesMinmax([xRangeL, xRangeR, imgXRange])
+ yRangeL = self._getRangesMinmax([yRangeL, imgYRange])
+ self.assertSequenceEqual(dataRange.x, xRangeLR)
+ self.assertSequenceEqual(dataRange.y, yRangeL)
+ self.assertSequenceEqual(dataRange.yright, yRangeR)
+
+ def testDataRangeImageNegativeScaleX(self):
+ """image data range, negative scale"""
+
+ origin = (-10, 25)
+ scale = (-3.0, 8.0)
+ image = numpy.arange(100.0).reshape(20, 5)
+
+ plot = PlotWidget(backend="none")
+ plot.addImage(image, origin=origin, scale=scale)
+
+ xRange = numpy.array([0.0, image.shape[1] * scale[0]]) + origin[0]
+ xRange.sort() # negative scale!
+ yRange = numpy.array([0.0, image.shape[0] * scale[1]]) + origin[1]
+
+ ranges = {
+ (False, False): (xRange, yRange),
+ (True, False): (None, None),
+ (True, True): (None, None),
+ (False, True): (None, None),
+ }
+
+ for logX, logY in (
+ (False, False),
+ (True, False),
+ (True, True),
+ (False, True),
+ (False, False),
+ ):
+ with self.subTest(logX=logX, logY=logY):
+ plot.getXAxis()._setLogarithmic(logX)
+ plot.getYAxis()._setLogarithmic(logY)
+ dataRange = plot.getDataRange()
+ xRange, yRange = ranges[logX, logY]
+ self.assertTrue(
+ numpy.array_equal(dataRange.x, xRange),
+ msg="{0} != {1}".format(dataRange.x, xRange),
+ )
+ self.assertTrue(
+ numpy.array_equal(dataRange.y, yRange),
+ msg="{0} != {1}".format(dataRange.y, yRange),
+ )
+ self.assertIsNone(dataRange.yright)
+
+ def testDataRangeImageNegativeScaleY(self):
+ """image data range, negative scale"""
+
+ origin = (-10, 25)
+ scale = (3.0, -8.0)
+ image = numpy.arange(100.0).reshape(20, 5)
+
+ plot = PlotWidget(backend="none")
+ plot.addImage(image, origin=origin, scale=scale)
+
+ xRange = numpy.array([0.0, image.shape[1] * scale[0]]) + origin[0]
+ yRange = numpy.array([0.0, image.shape[0] * scale[1]]) + origin[1]
+ yRange.sort() # negative scale!
+
+ ranges = {
+ (False, False): (xRange, yRange),
+ (True, False): (None, None),
+ (True, True): (None, None),
+ (False, True): (None, None),
+ }
+
+ for logX, logY in (
+ (False, False),
+ (True, False),
+ (True, True),
+ (False, True),
+ (False, False),
+ ):
+ with self.subTest(logX=logX, logY=logY):
+ plot.getXAxis()._setLogarithmic(logX)
+ plot.getYAxis()._setLogarithmic(logY)
+ dataRange = plot.getDataRange()
+ xRange, yRange = ranges[logX, logY]
+ self.assertTrue(
+ numpy.array_equal(dataRange.x, xRange),
+ msg="{0} != {1}".format(dataRange.x, xRange),
+ )
+ self.assertTrue(
+ numpy.array_equal(dataRange.y, yRange),
+ msg="{0} != {1}".format(dataRange.y, yRange),
+ )
+ self.assertIsNone(dataRange.yright)
+
+ def testDataRangeHiddenCurve(self):
+ """curves with a hidden curve"""
+ plot = PlotWidget(backend="none")
+ plot.addCurve((0, 1), (0, 1), legend="shown")
+ plot.addCurve((0, 1, 2), (5, 5, 5), legend="hidden")
+ range1 = plot.getDataRange()
+ self.assertEqual(range1.x, (0, 2))
+ self.assertEqual(range1.y, (0, 5))
+ plot.hideCurve("hidden")
+ range2 = plot.getDataRange()
+ self.assertEqual(range2.x, (0, 1))
+ self.assertEqual(range2.y, (0, 1))
+
+
+class TestPlotGetCurveImage(unittest.TestCase):
+ """Test of plot getCurve and getImage methods"""
+
+ def testGetCurve(self):
+ """PlotWidget.getCurve and Plot.getActiveCurve tests"""
+
+ plot = PlotWidget(backend="none")
+
+ # No curve
+ curve = plot.getCurve()
+ self.assertIsNone(curve) # No curve
+
+ plot.setActiveCurveHandling(True)
+ plot.addCurve(x=(0, 1), y=(0, 1), legend="curve 0")
+ plot.addCurve(x=(0, 1), y=(0, 1), legend="curve 1")
+ plot.addCurve(x=(0, 1), y=(0, 1), legend="curve 2")
+ plot.setActiveCurve("curve 0")
+
+ # Active curve
+ active = plot.getActiveCurve()
+ self.assertEqual(active.getName(), "curve 0")
+ curve = plot.getCurve()
+ self.assertEqual(curve.getName(), "curve 0")
+
+ # No active curve and curves
+ plot.setActiveCurveHandling(False)
+ active = plot.getActiveCurve()
+ self.assertIsNone(active) # No active curve
+ curve = plot.getCurve()
+ self.assertEqual(curve.getName(), "curve 2") # Last added curve
+
+ # Last curve hidden
+ plot.hideCurve("curve 2", True)
+ curve = plot.getCurve()
+ self.assertEqual(curve.getName(), "curve 1") # Last added curve
+
+ # All curves hidden
+ plot.hideCurve("curve 1", True)
+ plot.hideCurve("curve 0", True)
+ curve = plot.getCurve()
+ self.assertIsNone(curve)
+
+ def testGetCurveOldApi(self):
+ """old API PlotWidget.getCurve and Plot.getActiveCurve tests"""
+
+ plot = PlotWidget(backend="none")
+
+ # No curve
+ curve = plot.getCurve()
+ self.assertIsNone(curve) # No curve
+
+ plot.setActiveCurveHandling(True)
+ x = numpy.arange(10.0).astype(numpy.float32)
+ y = x * x
+ plot.addCurve(x=x, y=y, legend="curve 0", info=["whatever"])
+ plot.addCurve(x=x, y=2 * x, legend="curve 1", info="anything")
+ plot.setActiveCurve("curve 0")
+
+ # Active curve (4 elements)
+ xOut, yOut, legend, info = plot.getActiveCurve()[:4]
+ self.assertEqual(legend, "curve 0")
+ self.assertTrue(numpy.allclose(xOut, x), "curve 0 wrong x data")
+ self.assertTrue(numpy.allclose(yOut, y), "curve 0 wrong y data")
+
+ # Active curve (5 elements)
+ xOut, yOut, legend, info, params = plot.getCurve("curve 1")
+ self.assertEqual(legend, "curve 1")
+ self.assertEqual(info, "anything")
+ self.assertTrue(numpy.allclose(xOut, x), "curve 1 wrong x data")
+ self.assertTrue(numpy.allclose(yOut, 2 * x), "curve 1 wrong y data")
+
+ def testGetImage(self):
+ """PlotWidget.getImage and PlotWidget.getActiveImage tests"""
+
+ plot = PlotWidget(backend="none")
+
+ # No image
+ image = plot.getImage()
+ self.assertIsNone(image)
+
+ plot.addImage(((0, 1), (2, 3)), legend="image 0")
+ plot.addImage(((0, 1), (2, 3)), legend="image 1")
+
+ # Active image
+ active = plot.getActiveImage()
+ self.assertEqual(active.getName(), "image 0")
+ image = plot.getImage()
+ self.assertEqual(image.getName(), "image 0")
+
+ # No active image
+ plot.addImage(((0, 1), (2, 3)), legend="image 2")
+ plot.setActiveImage(None)
+ active = plot.getActiveImage()
+ self.assertIsNone(active)
+ image = plot.getImage()
+ self.assertEqual(image.getName(), "image 2")
+
+ # Active image
+ plot.setActiveImage("image 1")
+ active = plot.getActiveImage()
+ self.assertEqual(active.getName(), "image 1")
+ image = plot.getImage()
+ self.assertEqual(image.getName(), "image 1")
+
+ def testGetImageOldApi(self):
+ """PlotWidget.getImage and PlotWidget.getActiveImage old API tests"""
+
+ plot = PlotWidget(backend="none")
+
+ # No image
+ image = plot.getImage()
+ self.assertIsNone(image)
+
+ image = numpy.arange(10).astype(numpy.float32)
+ image.shape = 5, 2
+
+ plot.addImage(image, legend="image 0", info=["Hi!"])
+
+ # Active image
+ data, legend, info, something, params = plot.getActiveImage()
+ self.assertEqual(legend, "image 0")
+ self.assertEqual(info, ["Hi!"])
+ self.assertTrue(numpy.allclose(data, image), "image 0 data not correct")
+
+ def testGetAllImages(self):
+ """PlotWidget.getAllImages test"""
+
+ plot = PlotWidget(backend="none")
+
+ # No image
+ images = plot.getAllImages()
+ self.assertEqual(len(images), 0)
+
+ # 2 images
+ data = numpy.arange(100).reshape(10, 10)
+ plot.addImage(data, legend="1")
+ plot.addImage(data, origin=(10, 10), legend="2")
+ images = plot.getAllImages(just_legend=True)
+ self.assertEqual(list(images), ["1", "2"])
+ images = plot.getAllImages(just_legend=False)
+ self.assertEqual(len(images), 2)
+ self.assertEqual(images[0].getName(), "1")
+ self.assertEqual(images[1].getName(), "2")
+
+
+class TestPlotAddScatter(unittest.TestCase):
+ """Test of plot addScatter"""
+
+ def testAddGetScatter(self):
+ plot = PlotWidget(backend="none")
+
+ # No curve
+ scatter = plot._getItem(kind="scatter")
+ self.assertIsNone(scatter) # No curve
+
+ plot.addScatter(x=(0, 1), y=(0, 1), value=(0, 1), legend="scatter 0")
+ plot.addScatter(x=(0, 1), y=(0, 1), value=(0, 1), legend="scatter 1")
+ plot.addScatter(x=(0, 1), y=(0, 1), value=(0, 1), legend="scatter 2")
+ plot.setActiveScatter("scatter 0")
+
+ # Active scatter
+ active = plot.getActiveScatter()
+ self.assertEqual(active.getName(), "scatter 0")
+
+ # check default values
+ self.assertAlmostEqual(active.getSymbolSize(), active._DEFAULT_SYMBOL_SIZE)
+ self.assertEqual(active.getSymbol(), "o")
+ self.assertAlmostEqual(active.getAlpha(), 1.0)
+
+ # modify parameters
+ active.setSymbolSize(20.5)
+ active.setSymbol("d")
+ active.setAlpha(0.777)
+
+ s0 = plot.getScatter("scatter 0")
+
+ self.assertAlmostEqual(s0.getSymbolSize(), 20.5)
+ self.assertEqual(s0.getSymbol(), "d")
+ self.assertAlmostEqual(s0.getAlpha(), 0.777)
+
+ scatter1 = plot._getItem(kind="scatter", legend="scatter 1")
+ self.assertEqual(scatter1.getName(), "scatter 1")
+
+ def testGetAllScatters(self):
+ """PlotWidget.getAllImages test"""
+
+ plot = PlotWidget(backend="none")
+
+ items = plot.getItems()
+ self.assertEqual(len(items), 0)
+
+ plot.addScatter(x=(0, 1), y=(0, 1), value=(0, 1), legend="scatter 0")
+ plot.addScatter(x=(0, 1), y=(0, 1), value=(0, 1), legend="scatter 1")
+ plot.addScatter(x=(0, 1), y=(0, 1), value=(0, 1), legend="scatter 2")
+
+ items = plot.getItems()
+ self.assertEqual(len(items), 3)
+ self.assertEqual(items[0].getName(), "scatter 0")
+ self.assertEqual(items[1].getName(), "scatter 1")
+ self.assertEqual(items[2].getName(), "scatter 2")
+
+
+class TestPlotHistogram(unittest.TestCase):
+ """Basic tests for histogram."""
+
+ def testEdges(self):
+ x = numpy.array([0, 1, 2])
+ edgesRight = numpy.array([0, 1, 2, 3])
+ edgesLeft = numpy.array([-1, 0, 1, 2])
+ edgesCenter = numpy.array([-0.5, 0.5, 1.5, 2.5])
+
+ # testing x values for right
+ edges = _computeEdges(x, "right")
+ numpy.testing.assert_array_equal(edges, edgesRight)
+
+ edges = _computeEdges(x, "center")
+ numpy.testing.assert_array_equal(edges, edgesCenter)
+
+ edges = _computeEdges(x, "left")
+ numpy.testing.assert_array_equal(edges, edgesLeft)
+
+ def testHistogramCurve(self):
+ y = numpy.array([3, 2, 5])
+ edges = numpy.array([0, 1, 2, 3])
+
+ xHisto, yHisto = _getHistogramCurve(y, edges)
+ numpy.testing.assert_array_equal(yHisto, numpy.array([3, 3, 2, 2, 5, 5]))
+
+ y = numpy.array([-3, 2, 5, 0])
+ edges = numpy.array([-2, -1, 0, 1, 2])
+ xHisto, yHisto = _getHistogramCurve(y, edges)
+ numpy.testing.assert_array_equal(
+ yHisto, numpy.array([-3, -3, 2, 2, 5, 5, 0, 0])
+ )
+
+
+def testSetDefaultColors(qWidgetFactory):
+ """Basic test of PlotWidget.get|setDefaultColors"""
+ plot = qWidgetFactory(PlotWidget)
+
+ # By default using config
+ assert numpy.array_equal(
+ plot.getDefaultColors(), silx.config.DEFAULT_PLOT_CURVE_COLORS
+ )
+
+ # Use own colors
+ colors = "red", "green", "blue"
+ plot.setDefaultColors(colors)
+ assert plot.getDefaultColors() == colors
+
+ # Reset to default
+ plot.setDefaultColors(None)
+ assert numpy.array_equal(
+ plot.getDefaultColors(), silx.config.DEFAULT_PLOT_CURVE_COLORS
+ )
+
+
+def testSetDefaultColorsAddCurve(qWidgetFactory):
+ """Test that PlotWidget.setDefaultColors reset color index"""
+ plot = qWidgetFactory(PlotWidget)
+
+ plot.addCurve((0, 1), (0, 0), legend="curve0")
+ plot.addCurve((0, 1), (1, 1), legend="curve1")
+ plot.addCurve((0, 1), (2, 2), legend="curve2")
+
+ colors = "#123456", "#abcdef"
+ plot.setDefaultColors(colors)
+ assert plot.getDefaultColors() == colors
+
+ # Check that the color index is reset
+ curve = plot.addCurve((1, 2), (0, 1), legend="newcurve")
+ assert curve.getColor() == rgba(colors[0])
+
+
+def testDefaultColorsUpdateConfig(qWidgetFactory):
+ """Test that color index is reset if needed when default colors config is updated"""
+ plot = qWidgetFactory(PlotWidget)
+
+ plot.addCurve((0, 1), (0, 0), legend="curve0")
+ plot.addCurve((0, 1), (1, 1), legend="curve1")
+ plot.addCurve((0, 1), (2, 2), legend="curve2")
+
+ previous_colors = silx.config.DEFAULT_PLOT_CURVE_COLORS
+ try:
+ colors = "#123456", "#abcdef"
+ silx.config.DEFAULT_PLOT_CURVE_COLORS = colors
+ assert plot.getDefaultColors() == colors
+
+ # Check that the color index is reset
+ curve = plot.addCurve((1, 2), (0, 1), legend="newcurve")
+ assert curve.getColor() == rgba(colors[0])
+
+ finally:
+ silx.config.DEFAULT_PLOT_CURVE_COLORS = previous_colors
diff --git a/src/silx/gui/plot/test/testPlotWindow.py b/src/silx/gui/plot/test/testPlotWindow.py
new file mode 100644
index 0000000..8f17bf1
--- /dev/null
+++ b/src/silx/gui/plot/test/testPlotWindow.py
@@ -0,0 +1,175 @@
+# /*##########################################################################
+#
+# Copyright (c) 2016-2020 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""Basic tests for PlotWindow"""
+
+__authors__ = ["T. Vincent"]
+__license__ = "MIT"
+__date__ = "27/06/2017"
+
+
+import numpy
+import pytest
+
+from silx.gui.utils.testutils import TestCaseQt, getQToolButtonFromAction
+
+from silx.gui import qt
+from silx.gui.plot import PlotWindow
+from silx.gui.colors import Colormap
+
+
+class TestPlotWindow(TestCaseQt):
+ """Base class for tests of PlotWindow."""
+
+ def setUp(self):
+ super(TestPlotWindow, self).setUp()
+ self.plot = PlotWindow()
+ self.plot.show()
+ self.qWaitForWindowExposed(self.plot)
+
+ def tearDown(self):
+ self.plot.setAttribute(qt.Qt.WA_DeleteOnClose)
+ self.plot.close()
+ del self.plot
+ super(TestPlotWindow, self).tearDown()
+
+ def testActions(self):
+ """Test the actions QToolButtons"""
+ self.plot.setLimits(1, 100, 1, 100)
+
+ checkList = [ # QAction, Plot state getter
+ (self.plot.xAxisAutoScaleAction, self.plot.getXAxis().isAutoScale),
+ (self.plot.yAxisAutoScaleAction, self.plot.getYAxis().isAutoScale),
+ (self.plot.xAxisLogarithmicAction, self.plot.getXAxis()._isLogarithmic),
+ (self.plot.yAxisLogarithmicAction, self.plot.getYAxis()._isLogarithmic),
+ (self.plot.gridAction, self.plot.getGraphGrid),
+ ]
+
+ for action, getter in checkList:
+ self.mouseMove(self.plot)
+ initialState = getter()
+ toolButton = getQToolButtonFromAction(action)
+ self.assertIsNot(toolButton, None)
+ self.mouseClick(toolButton, qt.Qt.LeftButton)
+ self.assertNotEqual(
+ getter(), initialState, msg='"%s" state not changed' % action.text()
+ )
+
+ self.mouseClick(toolButton, qt.Qt.LeftButton)
+ self.assertEqual(
+ getter(), initialState, msg='"%s" state not changed' % action.text()
+ )
+
+ # Trigger a zoom reset
+ self.mouseMove(self.plot)
+ resetZoomAction = self.plot.resetZoomAction
+ toolButton = getQToolButtonFromAction(resetZoomAction)
+ self.assertIsNot(toolButton, None)
+ self.mouseClick(toolButton, qt.Qt.LeftButton)
+
+ def testDockWidgets(self):
+ """Test add/remove dock widgets"""
+ dock1 = qt.QDockWidget("Test 1")
+ dock1.setWidget(qt.QLabel("Test 1"))
+
+ self.plot.addTabbedDockWidget(dock1)
+ self.qapp.processEvents()
+
+ self.plot.removeDockWidget(dock1)
+ self.qapp.processEvents()
+
+ dock2 = qt.QDockWidget("Test 2")
+ dock2.setWidget(qt.QLabel("Test 2"))
+
+ self.plot.addTabbedDockWidget(dock2)
+ self.qapp.processEvents()
+
+ self.assertNotEqual(
+ self.plot.layout().indexOf(dock2),
+ -1,
+ "dock2 not properly displayed",
+ )
+
+ def testToolAspectRatio(self):
+ self.plot.toolBar()
+ self.plot.keepDataAspectRatioButton.keepDataAspectRatio()
+ self.assertTrue(self.plot.isKeepDataAspectRatio())
+ self.plot.keepDataAspectRatioButton.dontKeepDataAspectRatio()
+ self.assertFalse(self.plot.isKeepDataAspectRatio())
+
+ def testToolYAxisOrigin(self):
+ self.plot.toolBar()
+ self.plot.yAxisInvertedButton.setYAxisUpward()
+ self.assertFalse(self.plot.getYAxis().isInverted())
+ self.plot.yAxisInvertedButton.setYAxisDownward()
+ self.assertTrue(self.plot.getYAxis().isInverted())
+
+ def testColormapAutoscaleCache(self):
+ # Test that the min/max cache is not computed twice
+
+ old = Colormap._computeAutoscaleRange
+ self._count = 0
+
+ def _computeAutoscaleRange(colormap, data):
+ self._count = self._count + 1
+ return 10, 20
+
+ Colormap._computeAutoscaleRange = _computeAutoscaleRange
+ try:
+ colormap = Colormap(name="red")
+ self.plot.setVisible(True)
+
+ # Add an image
+ data = numpy.arange(8**2).reshape(8, 8)
+ self.plot.addImage(data, legend="foo", colormap=colormap)
+ self.plot.setActiveImage("foo")
+
+ # Use the colorbar
+ self.plot.getColorBarWidget().setVisible(True)
+ self.qWait(50)
+
+ # Remove and add again the same item
+ image = self.plot.getImage("foo")
+ self.plot.removeImage("foo")
+ self.plot.addItem(image)
+ self.qWait(50)
+ finally:
+ Colormap._computeAutoscaleRange = old
+ self.assertEqual(self._count, 1)
+ del self._count
+
+ @pytest.mark.usefixtures("use_opengl")
+ def testSwitchBackend(self):
+ """Test switching an empty plot"""
+ self.plot.resetZoom()
+ xlimits = self.plot.getXAxis().getLimits()
+ ylimits = self.plot.getYAxis().getLimits()
+ isKeepAspectRatio = self.plot.isKeepDataAspectRatio()
+
+ for backend in ("gl", "mpl"):
+ with self.subTest():
+ self.plot.setBackend(backend)
+ self.plot.replot()
+ self.assertEqual(self.plot.getXAxis().getLimits(), xlimits)
+ self.assertEqual(self.plot.getYAxis().getLimits(), ylimits)
+ self.assertEqual(self.plot.isKeepDataAspectRatio(), isKeepAspectRatio)
diff --git a/src/silx/gui/plot/test/testRoiStatsWidget.py b/src/silx/gui/plot/test/testRoiStatsWidget.py
new file mode 100644
index 0000000..759ebe2
--- /dev/null
+++ b/src/silx/gui/plot/test/testRoiStatsWidget.py
@@ -0,0 +1,260 @@
+# /*##########################################################################
+#
+# Copyright (c) 2016-2019 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""Tests for ROIStatsWidget"""
+
+
+from silx.gui.utils.testutils import TestCaseQt
+from silx.gui import qt
+from silx.gui.plot import PlotWindow
+from silx.gui.plot.stats.stats import Stats
+from silx.gui.plot.ROIStatsWidget import ROIStatsWidget
+from silx.gui.plot.CurvesROIWidget import ROI
+from silx.gui.plot.items.roi import RectangleROI, PolygonROI
+from silx.gui.plot.StatsWidget import UpdateMode
+import numpy
+
+
+class _TestRoiStatsBase(TestCaseQt):
+ """Base class for several unittest relative to ROIStatsWidget"""
+
+ def setUp(self):
+ TestCaseQt.setUp(self)
+ # define plot
+ self.plot = PlotWindow()
+ self.plot.addImage(numpy.arange(10000).reshape(100, 100), legend="img1")
+ self.img_item = self.plot.getImage("img1")
+ self.plot.addCurve(
+ x=numpy.linspace(0, 10, 56), y=numpy.arange(56), legend="curve1"
+ )
+ self.curve_item = self.plot.getCurve("curve1")
+ self.plot.addHistogram(
+ edges=numpy.linspace(0, 10, 56), histogram=numpy.arange(56), legend="histo1"
+ )
+ self.histogram_item = self.plot.getHistogram(legend="histo1")
+ self.plot.addScatter(
+ x=numpy.linspace(0, 10, 56),
+ y=numpy.linspace(0, 10, 56),
+ value=numpy.arange(56),
+ legend="scatter1",
+ )
+ self.scatter_item = self.plot.getScatter(legend="scatter1")
+
+ # stats widget
+ self.statsWidget = ROIStatsWidget(plot=self.plot)
+
+ # define stats
+ stats = [
+ ("sum", numpy.sum),
+ ("mean", numpy.mean),
+ ]
+ self.statsWidget.setStats(stats=stats)
+
+ # define rois
+ self.roi1D = ROI(name="range1", fromdata=0, todata=4, type_="energy")
+ self.rectangle_roi = RectangleROI()
+ self.rectangle_roi.setGeometry(origin=(0, 0), size=(20, 20))
+ self.rectangle_roi.setName("Initial ROI")
+ self.polygon_roi = PolygonROI()
+ points = numpy.array([[0, 5], [5, 0], [10, 5], [5, 10]])
+ self.polygon_roi.setPoints(points)
+
+ def statsTable(self):
+ return self.statsWidget._statsROITable
+
+ def tearDown(self):
+ Stats._getContext.cache_clear()
+ self.statsWidget.setAttribute(qt.Qt.WA_DeleteOnClose, True)
+ self.statsWidget.close()
+ self.statsWidget = None
+ self.plot.setAttribute(qt.Qt.WA_DeleteOnClose, True)
+ self.plot.close()
+ self.plot = None
+ TestCaseQt.tearDown(self)
+
+
+class TestRoiStatsCouple(_TestRoiStatsBase):
+ """
+ Test different possible couple (roi, plotItem).
+ Check that:
+
+ * computation is correct if couple is valid
+ * raise an error if couple is invalid
+ """
+
+ def testROICurve(self):
+ """
+ Test that the couple (ROI, curveItem) can be used for stats
+ """
+ item = self.statsWidget.addItem(roi=self.roi1D, plotItem=self.curve_item)
+ assert item is not None
+ tableItems = self.statsTable()._itemToTableItems(item)
+ self.assertEqual(tableItems["sum"].text(), "253")
+ self.assertEqual(tableItems["mean"].text(), "11.0")
+
+ def testRectangleImage(self):
+ """
+ Test that the couple (RectangleROI, imageItem) can be used for stats
+ """
+ item = self.statsWidget.addItem(roi=self.rectangle_roi, plotItem=self.img_item)
+ assert item is not None
+ self.plot.addImage(numpy.ones(10000).reshape(100, 100), legend="img1")
+ self.qapp.processEvents()
+ tableItems = self.statsTable()._itemToTableItems(item)
+ self.assertEqual(tableItems["sum"].text(), str(float(21 * 21)))
+ self.assertEqual(tableItems["mean"].text(), "1.0")
+
+ def testPolygonImage(self):
+ """
+ Test that the couple (PolygonROI, imageItem) can be used for stats
+ """
+ item = self.statsWidget.addItem(roi=self.polygon_roi, plotItem=self.img_item)
+ assert item is not None
+ tableItems = self.statsTable()._itemToTableItems(item)
+ self.assertEqual(tableItems["sum"].text(), "22750")
+ self.assertEqual(tableItems["mean"].text(), "455.0")
+
+ def testROIImage(self):
+ """
+ Test that the couple (ROI, imageItem) is raising an error
+ """
+ with self.assertRaises(TypeError):
+ self.statsWidget.addItem(roi=self.roi1D, plotItem=self.img_item)
+
+ def testRectangleCurve(self):
+ """
+ Test that the couple (rectangleROI, curveItem) is raising an error
+ """
+ with self.assertRaises(TypeError):
+ self.statsWidget.addItem(roi=self.rectangle_roi, plotItem=self.curve_item)
+
+ def testROIHistogram(self):
+ """
+ Test that the couple (PolygonROI, imageItem) can be used for stats
+ """
+ item = self.statsWidget.addItem(roi=self.roi1D, plotItem=self.histogram_item)
+ assert item is not None
+ tableItems = self.statsTable()._itemToTableItems(item)
+ self.assertEqual(tableItems["sum"].text(), "253")
+ self.assertEqual(tableItems["mean"].text(), "11.0")
+
+ def testROIScatter(self):
+ """
+ Test that the couple (PolygonROI, imageItem) can be used for stats
+ """
+ item = self.statsWidget.addItem(roi=self.roi1D, plotItem=self.scatter_item)
+ assert item is not None
+ tableItems = self.statsTable()._itemToTableItems(item)
+ self.assertEqual(tableItems["sum"].text(), "253")
+ self.assertEqual(tableItems["mean"].text(), "11.0")
+
+
+class TestRoiStatsAddRemoveItem(_TestRoiStatsBase):
+ """Test adding and removing (roi, plotItem) items"""
+
+ def testAddRemoveItems(self):
+ item1 = self.statsWidget.addItem(roi=self.roi1D, plotItem=self.scatter_item)
+ self.assertTrue(item1 is not None)
+ self.assertEqual(self.statsTable().rowCount(), 1)
+ item2 = self.statsWidget.addItem(roi=self.roi1D, plotItem=self.histogram_item)
+ self.assertTrue(item2 is not None)
+ self.assertEqual(self.statsTable().rowCount(), 2)
+ # try to add twice the same item
+ item3 = self.statsWidget.addItem(roi=self.roi1D, plotItem=self.histogram_item)
+ self.assertTrue(item3 is None)
+ self.assertEqual(self.statsTable().rowCount(), 2)
+ item4 = self.statsWidget.addItem(roi=self.roi1D, plotItem=self.curve_item)
+ self.assertTrue(item4 is not None)
+ self.assertEqual(self.statsTable().rowCount(), 3)
+
+ self.statsWidget.removeItem(plotItem=item4._plot_item, roi=item4._roi)
+ self.assertEqual(self.statsTable().rowCount(), 2)
+ # try to remove twice the same item
+ self.statsWidget.removeItem(plotItem=item4._plot_item, roi=item4._roi)
+ self.assertEqual(self.statsTable().rowCount(), 2)
+ self.statsWidget.removeItem(plotItem=item2._plot_item, roi=item2._roi)
+ self.statsWidget.removeItem(plotItem=item1._plot_item, roi=item1._roi)
+ self.assertEqual(self.statsTable().rowCount(), 0)
+
+
+class TestRoiStatsRoiUpdate(_TestRoiStatsBase):
+ """Test that the stats will be updated if the roi is updated"""
+
+ def testChangeRoi(self):
+ item = self.statsWidget.addItem(roi=self.rectangle_roi, plotItem=self.img_item)
+ assert item is not None
+ tableItems = self.statsTable()._itemToTableItems(item)
+ self.assertEqual(tableItems["sum"].text(), "445410")
+ self.assertEqual(tableItems["mean"].text(), "1010.0")
+
+ # update roi
+ self.rectangle_roi.setOrigin(position=(10, 10))
+ self.assertNotEqual(tableItems["sum"].text(), "445410")
+ self.assertNotEqual(tableItems["mean"].text(), "1010.0")
+
+ def testUpdateModeScenario(self):
+ """Test update according to a simple scenario"""
+ self.statsWidget._setUpdateMode(UpdateMode.AUTO)
+ item = self.statsWidget.addItem(roi=self.rectangle_roi, plotItem=self.img_item)
+
+ assert item is not None
+ tableItems = self.statsTable()._itemToTableItems(item)
+ self.assertEqual(tableItems["sum"].text(), "445410")
+ self.assertEqual(tableItems["mean"].text(), "1010.0")
+ self.statsWidget._setUpdateMode(UpdateMode.MANUAL)
+ self.rectangle_roi.setOrigin(position=(10, 10))
+ self.qapp.processEvents()
+ self.assertNotEqual(tableItems["sum"].text(), "445410")
+ self.assertNotEqual(tableItems["mean"].text(), "1010.0")
+ self.statsWidget._updateAllStats(is_request=True)
+ self.assertNotEqual(tableItems["sum"].text(), "445410")
+ self.assertNotEqual(tableItems["mean"].text(), "1010.0")
+
+
+class TestRoiStatsPlotItemUpdate(_TestRoiStatsBase):
+ """Test that the stats will be updated if the plot item is updated"""
+
+ def testChangeImage(self):
+ self.statsWidget._setUpdateMode(UpdateMode.AUTO)
+ item = self.statsWidget.addItem(roi=self.rectangle_roi, plotItem=self.img_item)
+
+ assert item is not None
+ tableItems = self.statsTable()._itemToTableItems(item)
+ self.assertEqual(tableItems["mean"].text(), "1010.0")
+
+ # update plot
+ self.plot.addImage(numpy.arange(100, 10100).reshape(100, 100), legend="img1")
+ self.assertNotEqual(tableItems["mean"].text(), "1059.5")
+
+ def testUpdateModeScenario(self):
+ """Test update according to a simple scenario"""
+ self.statsWidget._setUpdateMode(UpdateMode.MANUAL)
+ item = self.statsWidget.addItem(roi=self.rectangle_roi, plotItem=self.img_item)
+
+ assert item is not None
+ tableItems = self.statsTable()._itemToTableItems(item)
+ self.assertEqual(tableItems["mean"].text(), "1010.0")
+ self.plot.addImage(numpy.arange(100, 10100).reshape(100, 100), legend="img1")
+ self.assertEqual(tableItems["mean"].text(), "1010.0")
+ self.statsWidget._updateAllStats(is_request=True)
+ self.assertEqual(tableItems["mean"].text(), "1110.0")
diff --git a/src/silx/gui/plot/test/testSaveAction.py b/src/silx/gui/plot/test/testSaveAction.py
new file mode 100644
index 0000000..f8ac7ee
--- /dev/null
+++ b/src/silx/gui/plot/test/testSaveAction.py
@@ -0,0 +1,131 @@
+# /*##########################################################################
+#
+# Copyright (c) 2017-2019 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""Test the plot's save action (consistency of output)"""
+
+__authors__ = ["P. Knobel"]
+__license__ = "MIT"
+__date__ = "28/11/2017"
+
+
+import unittest
+import tempfile
+import os
+
+from silx.gui.plot.test.utils import PlotWidgetTestCase
+
+from silx.gui.plot import PlotWidget
+from silx.gui.plot.actions.io import SaveAction
+
+
+class TestSaveActionSaveCurvesAsSpec(unittest.TestCase):
+ def setUp(self):
+ self.plot = PlotWidget(backend="none")
+ self.saveAction = SaveAction(plot=self.plot)
+
+ self.tempdir = tempfile.mkdtemp()
+ self.out_fname = os.path.join(self.tempdir, "out.dat")
+
+ def tearDown(self):
+ os.unlink(self.out_fname)
+ os.rmdir(self.tempdir)
+
+ def testSaveMultipleCurvesAsSpec(self):
+ """Test that labels are properly used."""
+ self.plot.setGraphXLabel("graph x label")
+ self.plot.setGraphYLabel("graph y label")
+
+ self.plot.addCurve(
+ [0, 1], [1, 2], "curve with labels", xlabel="curve0 X", ylabel="curve0 Y"
+ )
+ self.plot.addCurve([-1, 3], [-6, 2], "curve with X label", xlabel="curve1 X")
+ self.plot.addCurve([-2, 0], [8, 12], "curve with Y label", ylabel="curve2 Y")
+ self.plot.addCurve([3, 1], [7, 6], "curve with no labels")
+
+ self.saveAction._saveCurves(
+ self.plot, self.out_fname, SaveAction.DEFAULT_ALL_CURVES_FILTERS[0]
+ ) # "All curves as SpecFile (*.dat)"
+
+ with open(self.out_fname, "rb") as f:
+ file_content = f.read()
+ if hasattr(file_content, "decode"):
+ file_content = file_content.decode()
+
+ # case with all curve labels specified
+ self.assertIn("#S 1 curve0 Y", file_content)
+ self.assertIn("#L curve0 X curve0 Y", file_content)
+
+ # graph X&Y labels are used when no curve label is specified
+ self.assertIn("#S 2 graph y label", file_content)
+ self.assertIn("#L curve1 X graph y label", file_content)
+
+ self.assertIn("#S 3 curve2 Y", file_content)
+ self.assertIn("#L graph x label curve2 Y", file_content)
+
+ self.assertIn("#S 4 graph y label", file_content)
+ self.assertIn("#L graph x label graph y label", file_content)
+
+
+class TestSaveActionExtension(PlotWidgetTestCase):
+ """Test SaveAction file filter API"""
+
+ def _dummySaveFunction(self, plot, filename, nameFilter):
+ pass
+
+ def testFileFilterAPI(self):
+ """Test addition/update of a file filter"""
+ saveAction = SaveAction(plot=self.plot, parent=self.plot)
+
+ # Add a new file filter
+ nameFilter = "Dummy file (*.dummy)"
+ saveAction.setFileFilter("all", nameFilter, self._dummySaveFunction)
+ self.assertTrue(nameFilter in saveAction.getFileFilters("all"))
+ self.assertEqual(
+ saveAction.getFileFilters("all")[nameFilter], self._dummySaveFunction
+ )
+
+ # Add a new file filter at a particular position
+ nameFilter = "Dummy file2 (*.dummy)"
+ saveAction.setFileFilter("all", nameFilter, self._dummySaveFunction, index=3)
+ self.assertTrue(nameFilter in saveAction.getFileFilters("all"))
+ filters = saveAction.getFileFilters("all")
+ self.assertEqual(filters[nameFilter], self._dummySaveFunction)
+ self.assertEqual(list(filters.keys()).index(nameFilter), 3)
+
+ # Update an existing file filter
+ nameFilter = SaveAction.IMAGE_FILTER_EDF
+ saveAction.setFileFilter("image", nameFilter, self._dummySaveFunction)
+ self.assertEqual(
+ saveAction.getFileFilters("image")[nameFilter], self._dummySaveFunction
+ )
+
+ # Change the position of an existing file filter
+ nameFilter = "Dummy file2 (*.dummy)"
+ oldIndex = list(saveAction.getFileFilters("all")).index(nameFilter)
+ newIndex = oldIndex - 1
+ saveAction.setFileFilter(
+ "all", nameFilter, self._dummySaveFunction, index=newIndex
+ )
+ filters = saveAction.getFileFilters("all")
+ self.assertEqual(filters[nameFilter], self._dummySaveFunction)
+ self.assertEqual(list(filters.keys()).index(nameFilter), newIndex)
diff --git a/src/silx/gui/plot/test/testScatterMaskToolsWidget.py b/src/silx/gui/plot/test/testScatterMaskToolsWidget.py
new file mode 100644
index 0000000..5dc14e1
--- /dev/null
+++ b/src/silx/gui/plot/test/testScatterMaskToolsWidget.py
@@ -0,0 +1,307 @@
+# /*##########################################################################
+#
+# Copyright (c) 2016-2017 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""Basic tests for MaskToolsWidget"""
+
+__authors__ = ["T. Vincent", "P. Knobel"]
+__license__ = "MIT"
+__date__ = "17/01/2018"
+
+
+import logging
+import os.path
+
+import numpy
+
+from silx.gui import qt
+from silx.test.utils import temp_dir
+from silx.utils.testutils import ParametricTestCase
+from silx.gui.utils.testutils import getQToolButtonFromAction
+from silx.gui.plot import PlotWindow, ScatterMaskToolsWidget
+from .utils import PlotWidgetTestCase
+
+
+_logger = logging.getLogger(__name__)
+
+
+class TestScatterMaskToolsWidget(PlotWidgetTestCase, ParametricTestCase):
+ """Basic test for MaskToolsWidget"""
+
+ def _createPlot(self):
+ return PlotWindow()
+
+ def setUp(self):
+ super(TestScatterMaskToolsWidget, self).setUp()
+ self.widget = ScatterMaskToolsWidget.ScatterMaskToolsDockWidget(
+ plot=self.plot, name="TEST"
+ )
+ self.plot.addDockWidget(qt.Qt.BottomDockWidgetArea, self.widget)
+
+ self.maskWidget = self.widget.widget()
+
+ def tearDown(self):
+ del self.maskWidget
+ del self.widget
+ super(TestScatterMaskToolsWidget, self).tearDown()
+
+ def testEmptyPlot(self):
+ """Empty plot, display MaskToolsDockWidget, toggle multiple masks"""
+ self.maskWidget.setMultipleMasks("single")
+ self.qapp.processEvents()
+
+ self.maskWidget.setMultipleMasks("exclusive")
+ self.qapp.processEvents()
+
+ def _drag(self):
+ """Drag from plot center to offset position"""
+ plot = self.plot.getWidgetHandle()
+ xCenter, yCenter = plot.width() // 2, plot.height() // 2
+ offset = min(plot.width(), plot.height()) // 10
+
+ pos0 = xCenter, yCenter
+ pos1 = xCenter + offset, yCenter + offset
+
+ self.mouseMove(plot, pos=(0, 0))
+ self.mouseMove(plot, pos=pos0)
+ self.qapp.processEvents()
+ self.mousePress(plot, qt.Qt.LeftButton, pos=pos0)
+ self.qapp.processEvents()
+
+ self.mouseMove(plot, pos=(pos0[0] + offset // 2, pos0[1] + offset // 2))
+ self.mouseMove(plot, pos=pos1)
+ self.qapp.processEvents()
+ self.mouseRelease(plot, qt.Qt.LeftButton, pos=pos1)
+ self.qapp.processEvents()
+ self.mouseMove(plot, pos=(0, 0))
+
+ def _drawPolygon(self):
+ """Draw a star polygon in the plot"""
+ plot = self.plot.getWidgetHandle()
+ x, y = plot.width() // 2, plot.height() // 2
+ offset = min(plot.width(), plot.height()) // 10
+
+ star = [
+ (x, y + offset),
+ (x - offset, y - offset),
+ (x + offset, y),
+ (x - offset, y),
+ (x + offset, y - offset),
+ (x, y + offset),
+ ] # Close polygon
+
+ self.mouseMove(plot, pos=[0, 0])
+ for pos in star:
+ self.mouseMove(plot, pos=pos)
+ self.qapp.processEvents()
+ self.mousePress(plot, qt.Qt.LeftButton, pos=pos)
+ self.qapp.processEvents()
+ self.mouseRelease(plot, qt.Qt.LeftButton, pos=pos)
+ self.qapp.processEvents()
+
+ def _drawPencil(self):
+ """Draw a star polygon in the plot"""
+ plot = self.plot.getWidgetHandle()
+ x, y = plot.width() // 2, plot.height() // 2
+ offset = min(plot.width(), plot.height()) // 10
+
+ star = [
+ (x, y + offset),
+ (x - offset, y - offset),
+ (x + offset, y),
+ (x - offset, y),
+ (x + offset, y - offset),
+ ]
+
+ self.mouseMove(plot, pos=[0, 0])
+ self.mouseMove(plot, pos=star[0])
+ self.mousePress(plot, qt.Qt.LeftButton, pos=star[0])
+ for pos in star[1:]:
+ self.mouseMove(plot, pos=pos)
+ self.mouseRelease(plot, qt.Qt.LeftButton, pos=star[-1])
+
+ def testWithAScatter(self):
+ """Plot with a Scatter: test MaskToolsWidget interactions"""
+
+ # Add and remove a scatter (this should enable/disable GUI + change mask)
+ self.plot.addScatter(
+ x=numpy.arange(256),
+ y=numpy.arange(256),
+ value=numpy.random.random(256),
+ legend="test",
+ )
+ self.plot.setActiveScatter("test")
+ self.qapp.processEvents()
+
+ self.plot.remove("test", kind="scatter")
+ self.qapp.processEvents()
+
+ self.plot.addScatter(
+ x=numpy.arange(1000),
+ y=1000 * (numpy.arange(1000) % 20),
+ value=numpy.random.random(1000),
+ legend="test",
+ )
+ self.plot.setActiveScatter("test")
+ self.plot.resetZoom()
+ self.qapp.processEvents()
+
+ # Test draw rectangle #
+ toolButton = getQToolButtonFromAction(self.maskWidget.rectAction)
+ self.assertIsNot(toolButton, None)
+ self.mouseClick(toolButton, qt.Qt.LeftButton)
+
+ # mask
+ self.maskWidget.maskStateGroup.button(1).click()
+ self.qapp.processEvents()
+ self._drag()
+
+ self.assertFalse(numpy.all(numpy.equal(self.maskWidget.getSelectionMask(), 0)))
+
+ # unmask same region
+ self.maskWidget.maskStateGroup.button(0).click()
+ self.qapp.processEvents()
+ self._drag()
+ self.assertTrue(numpy.all(numpy.equal(self.maskWidget.getSelectionMask(), 0)))
+
+ # Test draw polygon #
+ toolButton = getQToolButtonFromAction(self.maskWidget.polygonAction)
+ self.assertIsNot(toolButton, None)
+ self.mouseClick(toolButton, qt.Qt.LeftButton)
+
+ # mask
+ self.maskWidget.maskStateGroup.button(1).click()
+ self.qapp.processEvents()
+ self._drawPolygon()
+ self.assertFalse(numpy.all(numpy.equal(self.maskWidget.getSelectionMask(), 0)))
+
+ # unmask same region
+ self.maskWidget.maskStateGroup.button(0).click()
+ self.qapp.processEvents()
+ self._drawPolygon()
+ self.assertTrue(numpy.all(numpy.equal(self.maskWidget.getSelectionMask(), 0)))
+
+ # Test draw pencil #
+ toolButton = getQToolButtonFromAction(self.maskWidget.pencilAction)
+ self.assertIsNot(toolButton, None)
+ self.mouseClick(toolButton, qt.Qt.LeftButton)
+
+ self.maskWidget.pencilSpinBox.setValue(30)
+ self.qapp.processEvents()
+
+ # mask
+ self.maskWidget.maskStateGroup.button(1).click()
+ self.qapp.processEvents()
+ self._drawPencil()
+ self.assertFalse(numpy.all(numpy.equal(self.maskWidget.getSelectionMask(), 0)))
+
+ # unmask same region
+ self.maskWidget.maskStateGroup.button(0).click()
+ self.qapp.processEvents()
+ self._drawPencil()
+ self.assertTrue(numpy.all(numpy.equal(self.maskWidget.getSelectionMask(), 0)))
+
+ # Test no draw tool #
+ toolButton = getQToolButtonFromAction(self.maskWidget.browseAction)
+ self.assertIsNot(toolButton, None)
+ self.mouseClick(toolButton, qt.Qt.LeftButton)
+
+ self.plot.clear()
+
+ def __loadSave(self, file_format):
+ self.plot.addScatter(
+ x=numpy.arange(256),
+ y=25 * (numpy.arange(256) % 10),
+ value=numpy.random.random(256),
+ legend="test",
+ )
+ self.plot.setActiveScatter("test")
+ self.plot.resetZoom()
+ self.qapp.processEvents()
+
+ # Draw a polygon mask
+ toolButton = getQToolButtonFromAction(self.maskWidget.polygonAction)
+ self.assertIsNot(toolButton, None)
+ self.mouseClick(toolButton, qt.Qt.LeftButton)
+ self._drawPolygon()
+
+ ref_mask = self.maskWidget.getSelectionMask()
+ self.assertFalse(numpy.all(numpy.equal(ref_mask, 0)))
+
+ with temp_dir() as tmp:
+ mask_filename = os.path.join(tmp, "mask." + file_format)
+ self.maskWidget.save(mask_filename, file_format)
+
+ self.maskWidget.resetSelectionMask()
+ self.assertTrue(
+ numpy.all(numpy.equal(self.maskWidget.getSelectionMask(), 0))
+ )
+
+ self.maskWidget.load(mask_filename)
+ self.assertTrue(
+ numpy.all(numpy.equal(self.maskWidget.getSelectionMask(), ref_mask))
+ )
+
+ def testLoadSaveNpy(self):
+ self.__loadSave("npy")
+
+ def testLoadSaveCsv(self):
+ self.__loadSave("csv")
+
+ def testSigMaskChangedEmitted(self):
+ self.qapp.processEvents()
+ self.plot.addScatter(
+ x=numpy.arange(1000),
+ y=1000 * (numpy.arange(1000) % 20),
+ value=numpy.ones((1000,)),
+ legend="test",
+ )
+ self.plot.setActiveScatter("test")
+ self.plot.resetZoom()
+ self.qapp.processEvents()
+
+ self.plot.remove("test", kind="scatter")
+ self.qapp.processEvents()
+
+ self.plot.addScatter(
+ x=numpy.arange(1000),
+ y=1000 * (numpy.arange(1000) % 20),
+ value=numpy.random.random(1000),
+ legend="test",
+ )
+
+ l = []
+
+ def slot():
+ l.append(1)
+
+ self.maskWidget.sigMaskChanged.connect(slot)
+
+ # rectangle mask
+ toolButton = getQToolButtonFromAction(self.maskWidget.rectAction)
+ self.assertIsNot(toolButton, None)
+ self.mouseClick(toolButton, qt.Qt.LeftButton)
+ self.maskWidget.maskStateGroup.button(1).click()
+ self.qapp.processEvents()
+ self._drag()
+
+ self.assertGreater(len(l), 0)
diff --git a/src/silx/gui/plot/test/testScatterView.py b/src/silx/gui/plot/test/testScatterView.py
new file mode 100644
index 0000000..d6853b1
--- /dev/null
+++ b/src/silx/gui/plot/test/testScatterView.py
@@ -0,0 +1,122 @@
+# /*##########################################################################
+#
+# Copyright (c) 2018 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""Basic tests for ScatterView"""
+
+__authors__ = ["T. Vincent"]
+__license__ = "MIT"
+__date__ = "06/03/2018"
+
+
+import numpy
+
+from silx.gui.plot.items import Axis, Scatter
+from silx.gui.plot import ScatterView
+from silx.gui.plot.test.utils import PlotWidgetTestCase
+
+
+class TestScatterView(PlotWidgetTestCase):
+ """Test of ScatterView widget"""
+
+ def _createPlot(self):
+ return ScatterView()
+
+ def test(self):
+ """Simple tests"""
+ x = numpy.arange(100)
+ y = numpy.arange(100)
+ value = numpy.arange(100)
+ self.plot.setData(x, y, value)
+ self.qapp.processEvents()
+
+ data = self.plot.getData()
+ self.assertEqual(len(data), 5)
+ self.assertTrue(numpy.all(numpy.equal(x, data[0])))
+ self.assertTrue(numpy.all(numpy.equal(y, data[1])))
+ self.assertTrue(numpy.all(numpy.equal(value, data[2])))
+ self.assertIsNone(data[3]) # xerror
+ self.assertIsNone(data[4]) # yerror
+
+ # Test access to scatter item
+ self.assertIsInstance(self.plot.getScatterItem(), Scatter)
+
+ # Test toolbar actions
+
+ action = self.plot.getScatterToolBar().getXAxisLogarithmicAction()
+ action.trigger()
+ self.qapp.processEvents()
+
+ maskAction = self.plot.getScatterToolBar().actions()[-1]
+ maskAction.trigger()
+ self.qapp.processEvents()
+
+ # Test proxy API
+
+ self.plot.resetZoom()
+ self.qapp.processEvents()
+
+ scale = self.plot.getXAxis().getScale()
+ self.assertEqual(scale, Axis.LOGARITHMIC)
+
+ scale = self.plot.getYAxis().getScale()
+ self.assertEqual(scale, Axis.LINEAR)
+
+ title = "Test ScatterView"
+ self.plot.setGraphTitle(title)
+ self.assertEqual(self.plot.getGraphTitle(), title)
+
+ self.qapp.processEvents()
+
+ # Reset scatter data
+
+ self.plot.setData(None, None, None)
+ self.qapp.processEvents()
+
+ data = self.plot.getData()
+ self.assertEqual(len(data), 5)
+ self.assertEqual(len(data[0]), 0) # x
+ self.assertEqual(len(data[1]), 0) # y
+ self.assertEqual(len(data[2]), 0) # value
+ self.assertIsNone(data[3]) # xerror
+ self.assertIsNone(data[4]) # yerror
+
+ def testAlpha(self):
+ """Test alpha transparency in setData"""
+ _pts = 100
+ _levels = 100
+ _fwhm = 50
+ x = numpy.random.rand(_pts) * _levels
+ y = numpy.random.rand(_pts) * _levels
+ value = numpy.random.rand(_pts) * _levels
+ x0 = x[int(_pts / 2)]
+ y0 = x[int(_pts / 2)]
+ # 2D Gaussian kernel
+ alpha = numpy.exp(
+ -4 * numpy.log(2) * ((x - x0) ** 2 + (y - y0) ** 2) / _fwhm**2
+ )
+
+ self.plot.setData(x, y, value, alpha=alpha)
+ self.qapp.processEvents()
+
+ alphaData = self.plot.getScatterItem().getAlphaData()
+ self.assertTrue(numpy.all(numpy.equal(alpha, alphaData)))
diff --git a/src/silx/gui/plot/test/testStackView.py b/src/silx/gui/plot/test/testStackView.py
new file mode 100644
index 0000000..5e0ead5
--- /dev/null
+++ b/src/silx/gui/plot/test/testStackView.py
@@ -0,0 +1,266 @@
+# /*##########################################################################
+#
+# Copyright (c) 2016-2023 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""Basic tests for StackView"""
+
+__authors__ = ["P. Knobel"]
+__license__ = "MIT"
+__date__ = "20/03/2017"
+
+
+import numpy
+
+from silx.gui.utils.testutils import TestCaseQt, SignalListener
+
+from silx.gui import qt
+from silx.gui.plot import StackView
+from silx.gui.plot.StackView import StackViewMainWindow
+
+from silx.utils.array_like import ListOfImages
+
+
+class TestStackView(TestCaseQt):
+ """Base class for tests of StackView."""
+
+ def setUp(self):
+ super(TestStackView, self).setUp()
+ self.stackview = StackView()
+ self.stackview.show()
+ self.qWaitForWindowExposed(self.stackview)
+ self.mystack = numpy.fromfunction(
+ lambda i, j, k: numpy.sin(i / 15.0)
+ + numpy.cos(j / 4.0)
+ + 2 * numpy.sin(k / 6.0),
+ (10, 20, 30),
+ )
+
+ def tearDown(self):
+ self.stackview.setAttribute(qt.Qt.WA_DeleteOnClose)
+ self.stackview.close()
+ del self.stackview
+ super(TestStackView, self).tearDown()
+
+ def testScaleColormapRangeToStack(self):
+ """Test scaleColormapRangeToStack"""
+ self.stackview.setStack(self.mystack)
+ self.stackview.setColormap("viridis")
+ colormap = self.stackview.getColormap()
+
+ # Colormap autoscale to image
+ self.assertEqual(colormap.getVRange(), (None, None))
+ self.stackview.scaleColormapRangeToStack()
+
+ # Colormap range set according to stack range
+ self.assertEqual(colormap.getVRange(), (self.mystack.min(), self.mystack.max()))
+
+ def testSetStack(self):
+ self.stackview.setStack(self.mystack)
+ self.stackview.setColormap("viridis")
+ my_trans_stack, params = self.stackview.getStack()
+ self.assertEqual(my_trans_stack.shape, self.mystack.shape)
+ self.assertTrue(numpy.array_equal(self.mystack, my_trans_stack))
+ self.assertEqual(params["colormap"]["name"], "viridis")
+
+ def testSetStackPerspective(self):
+ self.stackview.setStack(self.mystack, perspective=1)
+ # my_orig_stack, params = self.stackview.getStack()
+ my_trans_stack, params = self.stackview.getCurrentView()
+
+ # get stack returns the transposed data, depending on the perspective
+ self.assertEqual(
+ my_trans_stack.shape,
+ (self.mystack.shape[1], self.mystack.shape[0], self.mystack.shape[2]),
+ )
+ self.assertTrue(
+ numpy.array_equal(
+ numpy.transpose(self.mystack, axes=(1, 0, 2)), my_trans_stack
+ )
+ )
+
+ def testSetStackListOfImages(self):
+ loi = [self.mystack[i] for i in range(self.mystack.shape[0])]
+
+ self.stackview.setStack(loi)
+ my_orig_stack, params = self.stackview.getStack(returnNumpyArray=True)
+ my_trans_stack, params = self.stackview.getStack(returnNumpyArray=True)
+ self.assertEqual(my_trans_stack.shape, self.mystack.shape)
+ self.assertTrue(numpy.array_equal(self.mystack, my_trans_stack))
+ self.assertTrue(numpy.array_equal(self.mystack, my_orig_stack))
+ self.assertIsInstance(my_trans_stack, numpy.ndarray)
+
+ self.stackview.setStack(loi, perspective=2)
+ my_orig_stack, params = self.stackview.getStack(copy=False)
+ my_trans_stack, params = self.stackview.getCurrentView(copy=False)
+ # getStack(copy=False) must return the object set in setStack
+ self.assertIs(my_orig_stack, loi)
+ # getCurrentView(copy=False) returns a ListOfImages whose .images
+ # attr is the original data
+ self.assertEqual(
+ my_trans_stack.shape,
+ (self.mystack.shape[2], self.mystack.shape[0], self.mystack.shape[1]),
+ )
+ self.assertTrue(
+ numpy.array_equal(
+ numpy.array(my_trans_stack),
+ numpy.transpose(self.mystack, axes=(2, 0, 1)),
+ )
+ )
+ self.assertIsInstance(
+ my_trans_stack, ListOfImages
+ ) # returnNumpyArray=False by default in getStack
+ self.assertIs(my_trans_stack.images, loi)
+
+ def testPerspective(self):
+ self.stackview.setStack(numpy.arange(24).reshape((2, 3, 4)))
+ self.assertEqual(
+ self.stackview._perspective, 0, "Default perspective is not 0 (dim1-dim2)."
+ )
+
+ self.stackview._StackView__planeSelection.setPerspective(1)
+ self.assertEqual(
+ self.stackview._perspective,
+ 1,
+ "Plane selection combobox not updating perspective",
+ )
+
+ self.stackview.setStack(numpy.arange(6).reshape((1, 2, 3)))
+ self.assertEqual(
+ self.stackview._perspective,
+ 1,
+ "Perspective not preserved when calling setStack "
+ "without specifying the perspective parameter.",
+ )
+
+ self.stackview.setStack(numpy.arange(24).reshape((2, 3, 4)), perspective=2)
+ self.assertEqual(
+ self.stackview._perspective,
+ 2,
+ "Perspective not set in setStack(..., perspective=2).",
+ )
+
+ def testDefaultTitle(self):
+ """Test that the plot title contains the proper Z information"""
+ self.stackview.setStack(
+ numpy.arange(24).reshape((4, 3, 2)),
+ calibrations=[(0, 1), (-10, 10), (3.14, 3.14)],
+ )
+ self.assertEqual(self.stackview._plot.getGraphTitle(), "Image z=0")
+ self.stackview.setFrameNumber(2)
+ self.assertEqual(self.stackview._plot.getGraphTitle(), "Image z=2")
+
+ self.stackview._StackView__planeSelection.setPerspective(1)
+ self.stackview.setFrameNumber(0)
+ self.assertEqual(self.stackview._plot.getGraphTitle(), "Image z=-10")
+ self.stackview.setFrameNumber(2)
+ self.assertEqual(self.stackview._plot.getGraphTitle(), "Image z=10")
+
+ self.stackview._StackView__planeSelection.setPerspective(2)
+ self.stackview.setFrameNumber(0)
+ self.assertEqual(self.stackview._plot.getGraphTitle(), "Image z=3.14")
+ self.stackview.setFrameNumber(1)
+ self.assertEqual(self.stackview._plot.getGraphTitle(), "Image z=6.28")
+
+ def testCustomTitle(self):
+ """Test setting the plot title with a user defined callback"""
+ self.stackview.setStack(
+ numpy.arange(24).reshape((4, 3, 2)),
+ calibrations=[(0, 1), (-10, 10), (3.14, 3.14)],
+ )
+
+ def title_callback(frame_idx):
+ return "Cubed index title %d" % (frame_idx**3)
+
+ self.stackview.setTitleCallback(title_callback)
+ self.assertEqual(self.stackview._plot.getGraphTitle(), "Cubed index title 0")
+ self.stackview.setFrameNumber(2)
+ self.assertEqual(self.stackview._plot.getGraphTitle(), "Cubed index title 8")
+
+ # perspective should not matter, only frame index
+ self.stackview._StackView__planeSelection.setPerspective(1)
+ self.stackview.setFrameNumber(0)
+ self.assertEqual(self.stackview._plot.getGraphTitle(), "Cubed index title 0")
+ self.stackview.setFrameNumber(2)
+ self.assertEqual(self.stackview._plot.getGraphTitle(), "Cubed index title 8")
+
+ with self.assertRaises(TypeError):
+ # setTitleCallback should not accept non-callable objects like strings
+ self.stackview.setTitleCallback(
+ "Là, vous faites sirop de vingt-et-un et vous dites : "
+ "beau sirop, mi-sirop, siroté, gagne-sirop, sirop-grelot,"
+ " passe-montagne, sirop au bon goût."
+ )
+
+ def testStackFrameNumber(self):
+ self.stackview.setStack(self.mystack)
+ self.assertEqual(self.stackview.getFrameNumber(), 0)
+
+ listener = SignalListener()
+ self.stackview.sigFrameChanged.connect(listener)
+
+ self.stackview.setFrameNumber(1)
+ self.assertEqual(self.stackview.getFrameNumber(), 1)
+ self.assertEqual(listener.arguments(), [(1,)])
+
+
+class TestStackViewMainWindow(TestCaseQt):
+ """Base class for tests of StackView."""
+
+ def setUp(self):
+ super(TestStackViewMainWindow, self).setUp()
+ self.stackview = StackViewMainWindow()
+ self.stackview.show()
+ self.qWaitForWindowExposed(self.stackview)
+ self.mystack = numpy.fromfunction(
+ lambda i, j, k: numpy.sin(i / 15.0)
+ + numpy.cos(j / 4.0)
+ + 2 * numpy.sin(k / 6.0),
+ (10, 20, 30),
+ )
+
+ def tearDown(self):
+ self.stackview.setAttribute(qt.Qt.WA_DeleteOnClose)
+ self.stackview.close()
+ del self.stackview
+ super(TestStackViewMainWindow, self).tearDown()
+
+ def testSetStack(self):
+ self.stackview.setStack(self.mystack)
+ self.stackview.setColormap("viridis")
+ my_trans_stack, params = self.stackview.getStack()
+ self.assertEqual(my_trans_stack.shape, self.mystack.shape)
+ self.assertTrue(numpy.array_equal(self.mystack, my_trans_stack))
+ self.assertEqual(params["colormap"]["name"], "viridis")
+
+ def testSetStackPerspective(self):
+ self.stackview.setStack(self.mystack, perspective=1)
+ my_trans_stack, params = self.stackview.getCurrentView()
+ # get stack returns the transposed data, depending on the perspective
+ self.assertEqual(
+ my_trans_stack.shape,
+ (self.mystack.shape[1], self.mystack.shape[0], self.mystack.shape[2]),
+ )
+ self.assertTrue(
+ numpy.array_equal(
+ numpy.transpose(self.mystack, axes=(1, 0, 2)), my_trans_stack
+ )
+ )
diff --git a/src/silx/gui/plot/test/testStats.py b/src/silx/gui/plot/test/testStats.py
new file mode 100644
index 0000000..2a2793e
--- /dev/null
+++ b/src/silx/gui/plot/test/testStats.py
@@ -0,0 +1,1143 @@
+# /*##########################################################################
+#
+# Copyright (c) 2016-2021 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""Basic tests for CurvesROIWidget"""
+
+__authors__ = ["H. Payno"]
+__license__ = "MIT"
+__date__ = "07/03/2018"
+
+
+from silx.gui import qt
+from silx.gui.plot.stats import stats
+from silx.gui.plot import StatsWidget
+from silx.gui.plot.stats import statshandler
+from silx.gui.utils.testutils import TestCaseQt, SignalListener
+from silx.gui.plot import Plot1D, Plot2D
+from silx.gui.plot.items.roi import RectangleROI, PolygonROI
+from silx.gui.plot.tools.roi import RegionOfInterestManager
+from silx.gui.plot.stats.stats import Stats
+from silx.gui.plot.CurvesROIWidget import ROI
+from silx.utils.testutils import ParametricTestCase
+import logging
+import numpy
+
+_logger = logging.getLogger(__name__)
+
+
+class TestStatsBase(object):
+ """Base class for stats TestCase"""
+
+ def setUp(self):
+ self.createCurveContext()
+ self.createImageContext()
+ self.createScatterContext()
+
+ def tearDown(self):
+ self.plot1d.setAttribute(qt.Qt.WA_DeleteOnClose)
+ self.plot1d.close()
+ del self.plot1d
+ self.plot2d.setAttribute(qt.Qt.WA_DeleteOnClose)
+ self.plot2d.close()
+ del self.plot2d
+ self.scatterPlot.setAttribute(qt.Qt.WA_DeleteOnClose)
+ self.scatterPlot.close()
+ del self.scatterPlot
+
+ def createCurveContext(self):
+ self.plot1d = Plot1D()
+ x = range(20)
+ y = range(20)
+ self.plot1d.addCurve(x, y, legend="curve0")
+
+ self.curveContext = stats._CurveContext(
+ item=self.plot1d.getCurve("curve0"),
+ plot=self.plot1d,
+ onlimits=False,
+ roi=None,
+ )
+
+ def createScatterContext(self):
+ self.scatterPlot = Plot2D()
+ lgd = "scatter plot"
+ self.xScatterData = numpy.array([0, 2, 3, 20, 50, 60, 36])
+ self.yScatterData = numpy.array([2, 3, 4, 26, 69, 6, 18])
+ self.valuesScatterData = numpy.array([5, 6, 7, 10, 90, 20, 5])
+ self.scatterPlot.addScatter(
+ self.xScatterData, self.yScatterData, self.valuesScatterData, legend=lgd
+ )
+ self.scatterContext = stats._ScatterContext(
+ item=self.scatterPlot.getScatter(lgd),
+ plot=self.scatterPlot,
+ onlimits=False,
+ roi=None,
+ )
+
+ def createImageContext(self):
+ self.plot2d = Plot2D()
+ self._imgLgd = "test image"
+ self.imageData = numpy.arange(32 * 128).reshape(32, 128)
+ self.plot2d.addImage(data=self.imageData, legend=self._imgLgd, replace=False)
+ self.imageContext = stats._ImageContext(
+ item=self.plot2d.getImage(self._imgLgd),
+ plot=self.plot2d,
+ onlimits=False,
+ roi=None,
+ )
+
+ def getBasicStats(self):
+ return {
+ "min": stats.StatMin(),
+ "minCoords": stats.StatCoordMin(),
+ "max": stats.StatMax(),
+ "maxCoords": stats.StatCoordMax(),
+ "std": stats.Stat(name="std", fct=numpy.std),
+ "mean": stats.Stat(name="mean", fct=numpy.mean),
+ "com": stats.StatCOM(),
+ }
+
+
+class TestStats(TestStatsBase, TestCaseQt):
+ """
+ Test :class:`BaseClass` class and inheriting classes
+ """
+
+ def setUp(self):
+ TestCaseQt.setUp(self)
+ TestStatsBase.setUp(self)
+
+ def tearDown(self):
+ TestStatsBase.tearDown(self)
+ TestCaseQt.tearDown(self)
+
+ def testBasicStatsCurve(self):
+ """Test result for simple stats on a curve"""
+ _stats = self.getBasicStats()
+ xData = yData = numpy.array(range(20))
+ self.assertEqual(_stats["min"].calculate(self.curveContext), 0)
+ self.assertEqual(_stats["max"].calculate(self.curveContext), 19)
+ self.assertEqual(_stats["minCoords"].calculate(self.curveContext), (0,))
+ self.assertEqual(_stats["maxCoords"].calculate(self.curveContext), (19,))
+ self.assertEqual(_stats["std"].calculate(self.curveContext), numpy.std(yData))
+ self.assertEqual(_stats["mean"].calculate(self.curveContext), numpy.mean(yData))
+ com = numpy.sum(xData * yData) / numpy.sum(yData)
+ self.assertEqual(_stats["com"].calculate(self.curveContext), com)
+
+ def testBasicStatsImage(self):
+ """Test result for simple stats on an image"""
+ _stats = self.getBasicStats()
+ self.assertEqual(_stats["min"].calculate(self.imageContext), 0)
+ self.assertEqual(_stats["max"].calculate(self.imageContext), 128 * 32 - 1)
+ self.assertEqual(_stats["minCoords"].calculate(self.imageContext), (0, 0))
+ self.assertEqual(_stats["maxCoords"].calculate(self.imageContext), (127, 31))
+ self.assertEqual(
+ _stats["std"].calculate(self.imageContext), numpy.std(self.imageData)
+ )
+ self.assertEqual(
+ _stats["mean"].calculate(self.imageContext), numpy.mean(self.imageData)
+ )
+
+ yData = numpy.sum(self.imageData.astype(numpy.float64), axis=1)
+ xData = numpy.sum(self.imageData.astype(numpy.float64), axis=0)
+ dataXRange = range(self.imageData.shape[1])
+ dataYRange = range(self.imageData.shape[0])
+
+ ycom = numpy.sum(yData * dataYRange) / numpy.sum(yData)
+ xcom = numpy.sum(xData * dataXRange) / numpy.sum(xData)
+
+ self.assertEqual(_stats["com"].calculate(self.imageContext), (xcom, ycom))
+
+ def testStatsImageAdv(self):
+ """Test that scale and origin are taking into account for images"""
+
+ image2Data = numpy.arange(32 * 128).reshape(32, 128)
+ self.plot2d.addImage(
+ data=image2Data,
+ legend=self._imgLgd,
+ replace=True,
+ origin=(100, 10),
+ scale=(2, 0.5),
+ )
+ image2Context = stats._ImageContext(
+ item=self.plot2d.getImage(self._imgLgd),
+ plot=self.plot2d,
+ onlimits=False,
+ roi=None,
+ )
+ _stats = self.getBasicStats()
+ self.assertEqual(_stats["min"].calculate(image2Context), 0)
+ self.assertEqual(_stats["max"].calculate(image2Context), 128 * 32 - 1)
+ self.assertEqual(_stats["minCoords"].calculate(image2Context), (100, 10))
+ self.assertEqual(
+ _stats["maxCoords"].calculate(image2Context),
+ (127 * 2.0 + 100, 31 * 0.5 + 10),
+ )
+ self.assertEqual(
+ _stats["std"].calculate(image2Context), numpy.std(self.imageData)
+ )
+ self.assertEqual(
+ _stats["mean"].calculate(image2Context), numpy.mean(self.imageData)
+ )
+
+ yData = numpy.sum(self.imageData, axis=1)
+ xData = numpy.sum(self.imageData, axis=0)
+ dataXRange = numpy.arange(self.imageData.shape[1], dtype=numpy.float64)
+ dataYRange = numpy.arange(self.imageData.shape[0], dtype=numpy.float64)
+
+ ycom = numpy.sum(yData * dataYRange) / numpy.sum(yData)
+ ycom = (ycom * 0.5) + 10
+ xcom = numpy.sum(xData * dataXRange) / numpy.sum(xData)
+ xcom = (xcom * 2.0) + 100
+ self.assertTrue(
+ numpy.allclose(_stats["com"].calculate(image2Context), (xcom, ycom))
+ )
+
+ def testBasicStatsScatter(self):
+ """Test result for simple stats on a scatter"""
+ _stats = self.getBasicStats()
+ self.assertEqual(_stats["min"].calculate(self.scatterContext), 5)
+ self.assertEqual(_stats["max"].calculate(self.scatterContext), 90)
+ self.assertEqual(_stats["minCoords"].calculate(self.scatterContext), (0, 2))
+ self.assertEqual(_stats["maxCoords"].calculate(self.scatterContext), (50, 69))
+ self.assertEqual(
+ _stats["std"].calculate(self.scatterContext),
+ numpy.std(self.valuesScatterData),
+ )
+ self.assertEqual(
+ _stats["mean"].calculate(self.scatterContext),
+ numpy.mean(self.valuesScatterData),
+ )
+
+ data = self.valuesScatterData.astype(numpy.float64)
+ comx = numpy.sum(self.xScatterData * data) / numpy.sum(data)
+ comy = numpy.sum(self.yScatterData * data) / numpy.sum(data)
+ self.assertEqual(_stats["com"].calculate(self.scatterContext), (comx, comy))
+
+ def testKindNotManagedByStat(self):
+ """Make sure an exception is raised if we try to execute calculate
+ of the base class"""
+ b = stats.StatBase(name="toto", compatibleKinds="curve")
+ with self.assertRaises(NotImplementedError):
+ b.calculate(self.imageContext)
+
+ def testKindNotManagedByContext(self):
+ """
+ Make sure an error is raised if we try to calculate a statistic with
+ a context not managed
+ """
+ myStat = stats.Stat(name="toto", fct=numpy.std, kinds=("curve"))
+ myStat.calculate(self.curveContext)
+ with self.assertRaises(ValueError):
+ myStat.calculate(self.scatterContext)
+ with self.assertRaises(ValueError):
+ myStat.calculate(self.imageContext)
+
+ def testOnLimits(self):
+ stat = stats.StatMin()
+
+ self.plot1d.getXAxis().setLimitsConstraints(minPos=2, maxPos=5)
+ curveContextOnLimits = stats._CurveContext(
+ item=self.plot1d.getCurve("curve0"),
+ plot=self.plot1d,
+ onlimits=True,
+ roi=None,
+ )
+ self.assertEqual(stat.calculate(curveContextOnLimits), 2)
+
+ self.plot2d.getXAxis().setLimitsConstraints(minPos=32)
+ imageContextOnLimits = stats._ImageContext(
+ item=self.plot2d.getImage("test image"),
+ plot=self.plot2d,
+ onlimits=True,
+ roi=None,
+ )
+ self.assertEqual(stat.calculate(imageContextOnLimits), 32)
+
+ self.scatterPlot.getXAxis().setLimitsConstraints(minPos=40)
+ scatterContextOnLimits = stats._ScatterContext(
+ item=self.scatterPlot.getScatter("scatter plot"),
+ plot=self.scatterPlot,
+ onlimits=True,
+ roi=None,
+ )
+ self.assertEqual(stat.calculate(scatterContextOnLimits), 20)
+
+
+class TestStatsFormatter(TestCaseQt):
+ """Simple test to check usage of the :class:`StatsFormatter`"""
+
+ def setUp(self):
+ TestCaseQt.setUp(self)
+ self.plot1d = Plot1D()
+ x = range(20)
+ y = range(20)
+ self.plot1d.addCurve(x, y, legend="curve0")
+
+ self.curveContext = stats._CurveContext(
+ item=self.plot1d.getCurve("curve0"),
+ plot=self.plot1d,
+ onlimits=False,
+ roi=None,
+ )
+
+ self.stat = stats.StatMin()
+
+ def tearDown(self):
+ self.plot1d.setAttribute(qt.Qt.WA_DeleteOnClose)
+ self.plot1d.close()
+ del self.plot1d
+ TestCaseQt.tearDown(self)
+
+ def testEmptyFormatter(self):
+ """Make sure a formatter with no formatter definition will return a
+ simple cast to str"""
+ emptyFormatter = statshandler.StatFormatter()
+ self.assertEqual(
+ emptyFormatter.format(self.stat.calculate(self.curveContext)), "0.000"
+ )
+
+ def testSettedFormatter(self):
+ """Make sure a formatter with no formatter definition will return a
+ simple cast to str"""
+ formatter = statshandler.StatFormatter(formatter="{0:.3f}")
+ self.assertEqual(
+ formatter.format(self.stat.calculate(self.curveContext)), "0.000"
+ )
+
+
+class TestStatsHandler(TestCaseQt):
+ """Make sure the StatHandler is correctly making the link between
+ :class:`StatBase` and :class:`StatFormatter` and checking the API is valid
+ """
+
+ def setUp(self):
+ TestCaseQt.setUp(self)
+ self.plot1d = Plot1D()
+ x = range(20)
+ y = range(20)
+ self.plot1d.addCurve(x, y, legend="curve0")
+ self.curveItem = self.plot1d.getCurve("curve0")
+
+ self.stat = stats.StatMin()
+
+ def tearDown(self):
+ Stats._getContext.cache_clear()
+ self.plot1d.setAttribute(qt.Qt.WA_DeleteOnClose)
+ self.plot1d.close()
+ self.plot1d = None
+ TestCaseQt.tearDown(self)
+
+ def testConstructor(self):
+ """Make sure the constructor can deal will all possible arguments:
+
+ * tuple of :class:`StatBase` derivated classes
+ * tuple of tuples (:class:`StatBase`, :class:`StatFormatter`)
+ * tuple of tuples (str, pointer to function, kind)
+ """
+ handler0 = statshandler.StatsHandler((stats.StatMin(), stats.StatMax()))
+
+ res = handler0.calculate(item=self.curveItem, plot=self.plot1d, onlimits=False)
+ self.assertTrue("min" in res)
+ self.assertEqual(res["min"], "0")
+ self.assertTrue("max" in res)
+ self.assertEqual(res["max"], "19")
+
+ handler1 = statshandler.StatsHandler(
+ (
+ (stats.StatMin(), statshandler.StatFormatter(formatter=None)),
+ (stats.StatMax(), statshandler.StatFormatter()),
+ )
+ )
+
+ res = handler1.calculate(item=self.curveItem, plot=self.plot1d, onlimits=False)
+ self.assertTrue("min" in res)
+ self.assertEqual(res["min"], "0")
+ self.assertTrue("max" in res)
+ self.assertEqual(res["max"], "19.000")
+
+ handler2 = statshandler.StatsHandler(
+ ((stats.StatMin(), None), (stats.StatMax(), statshandler.StatFormatter()))
+ )
+
+ res = handler2.calculate(item=self.curveItem, plot=self.plot1d, onlimits=False)
+ self.assertTrue("min" in res)
+ self.assertEqual(res["min"], "0")
+ self.assertTrue("max" in res)
+ self.assertEqual(res["max"], "19.000")
+
+ handler3 = statshandler.StatsHandler(
+ (
+ (("amin", numpy.argmin), statshandler.StatFormatter()),
+ ("amax", numpy.argmax),
+ )
+ )
+
+ res = handler3.calculate(item=self.curveItem, plot=self.plot1d, onlimits=False)
+ self.assertTrue("amin" in res)
+ self.assertEqual(res["amin"], "0.000")
+ self.assertTrue("amax" in res)
+ self.assertEqual(res["amax"], "19")
+
+ with self.assertRaises(ValueError):
+ statshandler.StatsHandler(("name"))
+
+
+class TestStatsWidgetWithCurves(TestCaseQt, ParametricTestCase):
+ """Basic test for StatsWidget with curves"""
+
+ def setUp(self):
+ TestCaseQt.setUp(self)
+ self.plot = Plot1D()
+ self.plot.show()
+ x = range(20)
+ y = range(20)
+ self.plot.addCurve(x, y, legend="curve0")
+ y = range(12, 32)
+ self.plot.addCurve(x, y, legend="curve1")
+ y = range(-2, 18)
+ self.plot.addCurve(x, y, legend="curve2")
+ self.widget = StatsWidget.StatsWidget(plot=self.plot)
+ self.statsTable = self.widget._statsTable
+
+ mystats = statshandler.StatsHandler(
+ (
+ stats.StatMin(),
+ (
+ stats.StatCoordMin(),
+ statshandler.StatFormatter(None, qt.QTableWidgetItem),
+ ),
+ stats.StatMax(),
+ (
+ stats.StatCoordMax(),
+ statshandler.StatFormatter(None, qt.QTableWidgetItem),
+ ),
+ stats.StatDelta(),
+ ("std", numpy.std),
+ ("mean", numpy.mean),
+ stats.StatCOM(),
+ )
+ )
+
+ self.statsTable.setStats(mystats)
+
+ def tearDown(self):
+ Stats._getContext.cache_clear()
+ self.plot.setAttribute(qt.Qt.WA_DeleteOnClose)
+ self.plot.close()
+ self.statsTable = None
+ self.widget.setAttribute(qt.Qt.WA_DeleteOnClose)
+ self.widget.close()
+ self.widget = None
+ self.plot = None
+ TestCaseQt.tearDown(self)
+
+ def testDisplayActiveItemsSyncOptions(self):
+ """
+ Test that the several option of the sync options are well
+ synchronized between the different object"""
+ widget = StatsWidget.StatsWidget(plot=self.plot)
+ table = StatsWidget.StatsTable(plot=self.plot)
+
+ def check_display_only_active_item(only_active):
+ # check internal value
+ self.assertIs(widget._statsTable._displayOnlyActItem, only_active)
+ # self.assertTrue(table._displayOnlyActItem is only_active)
+ # check gui display
+ self.assertEqual(widget._options.isActiveItemMode(), only_active)
+
+ for displayOnlyActiveItems in (True, False):
+ with self.subTest(displayOnlyActiveItems=displayOnlyActiveItems):
+ widget.setDisplayOnlyActiveItem(displayOnlyActiveItems)
+ # table.setDisplayOnlyActiveItem(displayOnlyActiveItems)
+ check_display_only_active_item(displayOnlyActiveItems)
+
+ check_display_only_active_item(only_active=False)
+ widget.setAttribute(qt.Qt.WA_DeleteOnClose)
+ table.setAttribute(qt.Qt.WA_DeleteOnClose)
+ widget.close()
+ table.close()
+
+ def testInit(self):
+ """Make sure all the curves are registred on initialization"""
+ self.assertEqual(self.statsTable.rowCount(), 3)
+
+ def testRemoveCurve(self):
+ """Make sure the Curves stats take into account the curve removal from
+ plot"""
+ self.plot.removeCurve("curve2")
+ self.assertEqual(self.statsTable.rowCount(), 2)
+ for iRow in range(2):
+ self.assertTrue(
+ self.statsTable.item(iRow, 0).text() in ("curve0", "curve1")
+ )
+
+ self.plot.removeCurve("curve0")
+ self.assertEqual(self.statsTable.rowCount(), 1)
+ self.plot.removeCurve("curve1")
+ self.assertEqual(self.statsTable.rowCount(), 0)
+
+ def testAddCurve(self):
+ """Make sure the Curves stats take into account the add curve action"""
+ self.plot.addCurve(legend="curve3", x=range(10), y=range(10))
+ self.assertEqual(self.statsTable.rowCount(), 4)
+
+ def testUpdateCurveFromAddCurve(self):
+ """Make sure the stats of the cuve will be removed after updating a
+ curve"""
+ self.plot.addCurve(legend="curve0", x=range(10), y=range(10))
+ self.qapp.processEvents()
+ self.assertEqual(self.statsTable.rowCount(), 3)
+ curve = self.plot._getItem(kind="curve", legend="curve0")
+ tableItems = self.statsTable._itemToTableItems(curve)
+ self.assertEqual(tableItems["max"].text(), "9")
+
+ def testUpdateCurveFromCurveObj(self):
+ self.plot.getCurve("curve0").setData(x=range(4), y=range(4))
+ self.qapp.processEvents()
+ self.assertEqual(self.statsTable.rowCount(), 3)
+ curve = self.plot._getItem(kind="curve", legend="curve0")
+ tableItems = self.statsTable._itemToTableItems(curve)
+ self.assertEqual(tableItems["max"].text(), "3")
+
+ def testSetAnotherPlot(self):
+ plot2 = Plot1D()
+ plot2.addCurve(x=range(26), y=range(26), legend="new curve")
+ self.statsTable.setPlot(plot2)
+ self.assertEqual(self.statsTable.rowCount(), 1)
+ self.qapp.processEvents()
+ plot2.setAttribute(qt.Qt.WA_DeleteOnClose)
+ plot2.close()
+ plot2 = None
+
+ def testUpdateMode(self):
+ """Make sure the update modes are well take into account"""
+ self.plot.setActiveCurve("curve0")
+ for display_only_active in (True, False):
+ with self.subTest(display_only_active=display_only_active):
+ self.widget.setDisplayOnlyActiveItem(display_only_active)
+ self.plot.getCurve("curve0").setData(x=range(4), y=range(4))
+ self.widget.setUpdateMode(StatsWidget.UpdateMode.AUTO)
+ update_stats_action = self.widget._options.getUpdateStatsAction()
+ # test from api
+ self.assertEqual(
+ self.widget.getUpdateMode(), StatsWidget.UpdateMode.AUTO
+ )
+ self.widget.show()
+ # check stats change in auto mode
+ self.plot.getCurve("curve0").setData(x=range(4), y=range(-1, 3))
+ self.qapp.processEvents()
+ tableItems = self.statsTable._itemToTableItems(
+ self.plot.getCurve("curve0")
+ )
+ curve0_min = tableItems["min"].text()
+ self.assertTrue(float(curve0_min) == -1.0)
+
+ self.plot.getCurve("curve0").setData(x=range(4), y=range(1, 5))
+ self.qapp.processEvents()
+ tableItems = self.statsTable._itemToTableItems(
+ self.plot.getCurve("curve0")
+ )
+ curve0_min = tableItems["min"].text()
+ self.assertTrue(float(curve0_min) == 1.0)
+
+ # check stats change in manual mode only if requested
+ self.widget.setUpdateMode(StatsWidget.UpdateMode.MANUAL)
+ self.assertEqual(
+ self.widget.getUpdateMode(), StatsWidget.UpdateMode.MANUAL
+ )
+
+ self.plot.getCurve("curve0").setData(x=range(4), y=range(2, 6))
+ self.qapp.processEvents()
+ tableItems = self.statsTable._itemToTableItems(
+ self.plot.getCurve("curve0")
+ )
+ curve0_min = tableItems["min"].text()
+ self.assertTrue(float(curve0_min) == 1.0)
+
+ update_stats_action.trigger()
+ tableItems = self.statsTable._itemToTableItems(
+ self.plot.getCurve("curve0")
+ )
+ curve0_min = tableItems["min"].text()
+ self.assertTrue(float(curve0_min) == 2.0)
+
+ def testItemHidden(self):
+ """Test if an item is hide, then the associated stats item is also
+ hide"""
+ curve0 = self.plot.getCurve("curve0")
+ curve1 = self.plot.getCurve("curve1")
+ curve2 = self.plot.getCurve("curve2")
+
+ self.plot.show()
+ self.widget.show()
+ self.qWaitForWindowExposed(self.widget)
+ self.assertFalse(self.statsTable.isRowHidden(0))
+ self.assertFalse(self.statsTable.isRowHidden(1))
+ self.assertFalse(self.statsTable.isRowHidden(2))
+
+ curve0.setVisible(False)
+ self.qapp.processEvents()
+ self.assertTrue(self.statsTable.isRowHidden(0))
+ curve0.setVisible(True)
+ self.qapp.processEvents()
+ self.assertFalse(self.statsTable.isRowHidden(0))
+ curve1.setVisible(False)
+ self.qapp.processEvents()
+ self.assertTrue(self.statsTable.isRowHidden(1))
+ tableItems = self.statsTable._itemToTableItems(curve2)
+ curve2_min = tableItems["min"].text()
+ self.assertTrue(float(curve2_min) == -2.0)
+
+ curve0.setVisible(False)
+ curve1.setVisible(False)
+ curve2.setVisible(False)
+ self.qapp.processEvents()
+ self.assertTrue(self.statsTable.isRowHidden(0))
+ self.assertTrue(self.statsTable.isRowHidden(1))
+ self.assertTrue(self.statsTable.isRowHidden(2))
+
+
+class TestStatsWidgetWithImages(TestCaseQt):
+ """Basic test for StatsWidget with images"""
+
+ IMAGE_LEGEND = "test image"
+
+ def setUp(self):
+ TestCaseQt.setUp(self)
+ self.plot = Plot2D()
+
+ self.plot.addImage(
+ data=numpy.arange(128 * 128).reshape(128, 128),
+ legend=self.IMAGE_LEGEND,
+ replace=False,
+ )
+
+ self.widget = StatsWidget.StatsTable(plot=self.plot)
+
+ mystats = statshandler.StatsHandler(
+ (
+ (stats.StatMin(), statshandler.StatFormatter()),
+ (
+ stats.StatCoordMin(),
+ statshandler.StatFormatter(None, qt.QTableWidgetItem),
+ ),
+ (stats.StatMax(), statshandler.StatFormatter()),
+ (
+ stats.StatCoordMax(),
+ statshandler.StatFormatter(None, qt.QTableWidgetItem),
+ ),
+ (stats.StatDelta(), statshandler.StatFormatter()),
+ ("std", numpy.std),
+ ("mean", numpy.mean),
+ (stats.StatCOM(), statshandler.StatFormatter(None)),
+ )
+ )
+
+ self.widget.setStats(mystats)
+
+ def tearDown(self):
+ Stats._getContext.cache_clear()
+ self.plot.setAttribute(qt.Qt.WA_DeleteOnClose)
+ self.plot.close()
+ self.widget.setAttribute(qt.Qt.WA_DeleteOnClose)
+ self.widget.close()
+ self.widget = None
+ self.plot = None
+ TestCaseQt.tearDown(self)
+
+ def test(self):
+ image = self.plot._getItem(kind="image", legend=self.IMAGE_LEGEND)
+ tableItems = self.widget._itemToTableItems(image)
+
+ maxText = "{0:.3f}".format((128 * 128) - 1)
+ self.assertEqual(tableItems["legend"].text(), self.IMAGE_LEGEND)
+ self.assertEqual(tableItems["min"].text(), "0.000")
+ self.assertEqual(tableItems["max"].text(), maxText)
+ self.assertEqual(tableItems["delta"].text(), maxText)
+ self.assertEqual(tableItems["coords min"].text(), "0.0, 0.0")
+ self.assertEqual(tableItems["coords max"].text(), "127.0, 127.0")
+
+ def testItemHidden(self):
+ """Test if an item is hide, then the associated stats item is also
+ hide"""
+ self.widget.show()
+ self.plot.show()
+ self.qWaitForWindowExposed(self.widget)
+ self.assertFalse(self.widget.isRowHidden(0))
+ self.plot.getImage(self.IMAGE_LEGEND).setVisible(False)
+ self.qapp.processEvents()
+ self.assertTrue(self.widget.isRowHidden(0))
+
+
+class TestStatsWidgetWithScatters(TestCaseQt):
+ SCATTER_LEGEND = "scatter plot"
+
+ def setUp(self):
+ TestCaseQt.setUp(self)
+ self.scatterPlot = Plot2D()
+ self.scatterPlot.addScatter(
+ [0, 1, 2, 20, 50, 60],
+ [2, 3, 4, 26, 69, 6],
+ [5, 6, 7, 10, 90, 20],
+ legend=self.SCATTER_LEGEND,
+ )
+ self.widget = StatsWidget.StatsTable(plot=self.scatterPlot)
+
+ mystats = statshandler.StatsHandler(
+ (
+ stats.StatMin(),
+ (
+ stats.StatCoordMin(),
+ statshandler.StatFormatter(None, qt.QTableWidgetItem),
+ ),
+ stats.StatMax(),
+ (
+ stats.StatCoordMax(),
+ statshandler.StatFormatter(None, qt.QTableWidgetItem),
+ ),
+ stats.StatDelta(),
+ ("std", numpy.std),
+ ("mean", numpy.mean),
+ stats.StatCOM(),
+ )
+ )
+
+ self.widget.setStats(mystats)
+
+ def tearDown(self):
+ Stats._getContext.cache_clear()
+ self.scatterPlot.setAttribute(qt.Qt.WA_DeleteOnClose)
+ self.scatterPlot.close()
+ self.widget.setAttribute(qt.Qt.WA_DeleteOnClose)
+ self.widget.close()
+ self.widget = None
+ self.scatterPlot = None
+ TestCaseQt.tearDown(self)
+
+ def testStats(self):
+ scatter = self.scatterPlot._getItem(kind="scatter", legend=self.SCATTER_LEGEND)
+ tableItems = self.widget._itemToTableItems(scatter)
+ self.assertEqual(tableItems["legend"].text(), self.SCATTER_LEGEND)
+ self.assertEqual(tableItems["min"].text(), "5")
+ self.assertEqual(tableItems["coords min"].text(), "0, 2")
+ self.assertEqual(tableItems["max"].text(), "90")
+ self.assertEqual(tableItems["coords max"].text(), "50, 69")
+ self.assertEqual(tableItems["delta"].text(), "85")
+
+
+class TestEmptyStatsWidget(TestCaseQt):
+ def test(self):
+ widget = StatsWidget.StatsWidget()
+ widget.show()
+ self.qWaitForWindowExposed(widget)
+
+
+class TestLineWidget(TestCaseQt):
+ """Some test for the StatsLineWidget."""
+
+ def setUp(self):
+ TestCaseQt.setUp(self)
+
+ mystats = statshandler.StatsHandler(
+ ((stats.StatMin(), statshandler.StatFormatter()),)
+ )
+
+ self.plot = Plot1D()
+ self.plot.show()
+ self.x = range(20)
+ self.y0 = range(20)
+ self.plot.addCurve(self.x, self.y0, legend="curve0")
+ self.y1 = range(12, 32)
+ self.plot.addCurve(self.x, self.y1, legend="curve1")
+ self.y2 = range(-2, 18)
+ self.plot.addCurve(self.x, self.y2, legend="curve2")
+ self.widget = StatsWidget.BasicGridStatsWidget(
+ plot=self.plot, kind="curve", stats=mystats
+ )
+
+ def tearDown(self):
+ Stats._getContext.cache_clear()
+ self.qapp.processEvents()
+ self.plot.setAttribute(qt.Qt.WA_DeleteOnClose)
+ self.plot.close()
+ self.widget.setPlot(None)
+ self.widget._lineStatsWidget._statQlineEdit.clear()
+ self.widget.setAttribute(qt.Qt.WA_DeleteOnClose)
+ self.widget.close()
+ self.widget = None
+ self.plot = None
+ TestCaseQt.tearDown(self)
+
+ def testProcessing(self):
+ self.widget._lineStatsWidget.setStatsOnVisibleData(False)
+ self.qapp.processEvents()
+ self.plot.setActiveCurve(legend="curve0")
+ self.assertTrue(
+ self.widget._lineStatsWidget._statQlineEdit["min"].text() == "0.000"
+ )
+ self.plot.setActiveCurve(legend="curve1")
+ self.assertTrue(
+ self.widget._lineStatsWidget._statQlineEdit["min"].text() == "12.000"
+ )
+ self.plot.getXAxis().setLimitsConstraints(minPos=2, maxPos=5)
+ self.widget.setStatsOnVisibleData(True)
+ self.qapp.processEvents()
+ self.assertTrue(
+ self.widget._lineStatsWidget._statQlineEdit["min"].text() == "14.000"
+ )
+ self.plot.setActiveCurve(None)
+ self.assertIsNone(self.plot.getActiveCurve())
+ self.widget.setStatsOnVisibleData(False)
+ self.qapp.processEvents()
+ self.assertFalse(
+ self.widget._lineStatsWidget._statQlineEdit["min"].text() == "14.000"
+ )
+ self.widget.setKind("image")
+ self.plot.addImage(numpy.arange(100 * 100).reshape(100, 100) + 0.312)
+ self.qapp.processEvents()
+ self.assertTrue(
+ self.widget._lineStatsWidget._statQlineEdit["min"].text() == "0.312"
+ )
+
+ def testUpdateMode(self):
+ """Make sure the update modes are well take into account"""
+ self.plot.setActiveCurve("curve0")
+ _autoRB = self.widget._options._autoRB
+ _manualRB = self.widget._options._manualRB
+ # test from api
+ self.widget.setUpdateMode(StatsWidget.UpdateMode.AUTO)
+ self.assertTrue(_autoRB.isChecked())
+ self.assertFalse(_manualRB.isChecked())
+
+ # check stats change in auto mode
+ curve0_min = self.widget._lineStatsWidget._statQlineEdit["min"].text()
+ new_y = numpy.array(self.y0) - 2.56
+ self.plot.addCurve(x=self.x, y=new_y, legend="curve0")
+ curve0_min2 = self.widget._lineStatsWidget._statQlineEdit["min"].text()
+ self.assertTrue(curve0_min != curve0_min2)
+
+ # check stats change in manual mode only if requested
+ self.widget.setUpdateMode(StatsWidget.UpdateMode.MANUAL)
+ self.assertFalse(_autoRB.isChecked())
+ self.assertTrue(_manualRB.isChecked())
+
+ new_y = numpy.array(self.y0) - 1.2
+ self.plot.addCurve(x=self.x, y=new_y, legend="curve0")
+ curve0_min3 = self.widget._lineStatsWidget._statQlineEdit["min"].text()
+ self.assertTrue(curve0_min3 == curve0_min2)
+ self.widget._options._updateRequested()
+ curve0_min3 = self.widget._lineStatsWidget._statQlineEdit["min"].text()
+ self.assertTrue(curve0_min3 != curve0_min2)
+
+ # test from gui
+ self.widget.showRadioButtons(True)
+ self.widget._options._autoRB.toggle()
+ self.assertTrue(_autoRB.isChecked())
+ self.assertFalse(_manualRB.isChecked())
+
+ self.widget._options._manualRB.toggle()
+ self.assertFalse(_autoRB.isChecked())
+ self.assertTrue(_manualRB.isChecked())
+
+
+class TestUpdateModeWidget(TestCaseQt):
+ """Test UpdateModeWidget"""
+
+ def setUp(self):
+ TestCaseQt.setUp(self)
+ self.widget = StatsWidget.UpdateModeWidget(parent=None)
+
+ def tearDown(self):
+ self.widget.setAttribute(qt.Qt.WA_DeleteOnClose)
+ self.widget.close()
+ self.widget = None
+ TestCaseQt.tearDown(self)
+
+ def testSignals(self):
+ """Test the signal emission of the widget"""
+ self.widget.setUpdateMode(StatsWidget.UpdateMode.AUTO)
+ modeChangedListener = SignalListener()
+ manualUpdateListener = SignalListener()
+ self.widget.sigUpdateModeChanged.connect(modeChangedListener)
+ self.widget.sigUpdateRequested.connect(manualUpdateListener)
+ self.widget.setUpdateMode(StatsWidget.UpdateMode.AUTO)
+ self.assertEqual(self.widget.getUpdateMode(), StatsWidget.UpdateMode.AUTO)
+ self.assertEqual(modeChangedListener.callCount(), 0)
+ self.qapp.processEvents()
+
+ self.widget.setUpdateMode(StatsWidget.UpdateMode.MANUAL)
+ self.assertEqual(self.widget.getUpdateMode(), StatsWidget.UpdateMode.MANUAL)
+ self.qapp.processEvents()
+ self.assertEqual(modeChangedListener.callCount(), 1)
+ self.assertEqual(manualUpdateListener.callCount(), 0)
+ self.widget._updatePB.click()
+ self.widget._updatePB.click()
+ self.assertEqual(manualUpdateListener.callCount(), 2)
+
+ self.widget._autoRB.setChecked(True)
+ self.assertEqual(modeChangedListener.callCount(), 2)
+ self.widget._updatePB.click()
+ self.assertEqual(manualUpdateListener.callCount(), 2)
+
+
+class TestStatsROI(TestStatsBase, TestCaseQt):
+ """
+ Test stats based on ROI
+ """
+
+ def setUp(self):
+ TestCaseQt.setUp(self)
+ self.createRois()
+ TestStatsBase.setUp(self)
+ self.createHistogramContext()
+
+ self.roiManager = RegionOfInterestManager(self.plot2d)
+ self.roiManager.addRoi(self._2Droi_rect)
+ self.roiManager.addRoi(self._2Droi_poly)
+
+ def tearDown(self):
+ self.roiManager.clear()
+ self.roiManager = None
+ self._1Droi = None
+ self._2Droi_rect = None
+ self._2Droi_poly = None
+ self.plotHisto.setAttribute(qt.Qt.WA_DeleteOnClose)
+ self.plotHisto.close()
+ self.plotHisto = None
+ TestStatsBase.tearDown(self)
+ TestCaseQt.tearDown(self)
+
+ def createRois(self):
+ self._1Droi = ROI(name="my1DRoi", fromdata=2.0, todata=5.0)
+ self._2Droi_rect = RectangleROI()
+ self._2Droi_rect.setGeometry(size=(10, 10), origin=(10, 0))
+ self._2Droi_poly = PolygonROI()
+ points = numpy.array(((0, 20), (0, 0), (10, 0)))
+ self._2Droi_poly.setPoints(points=points)
+
+ def createCurveContext(self):
+ TestStatsBase.createCurveContext(self)
+ self.curveContext = stats._CurveContext(
+ item=self.plot1d.getCurve("curve0"),
+ plot=self.plot1d,
+ onlimits=False,
+ roi=self._1Droi,
+ )
+
+ def createHistogramContext(self):
+ self.plotHisto = Plot1D()
+ x = range(20)
+ y = range(20)
+ self.plotHisto.addHistogram(x, y, legend="histo0")
+
+ self.histoContext = stats._HistogramContext(
+ item=self.plotHisto.getHistogram("histo0"),
+ plot=self.plotHisto,
+ onlimits=False,
+ roi=self._1Droi,
+ )
+
+ def createScatterContext(self):
+ TestStatsBase.createScatterContext(self)
+ self.scatterContext = stats._ScatterContext(
+ item=self.scatterPlot.getScatter("scatter plot"),
+ plot=self.scatterPlot,
+ onlimits=False,
+ roi=self._1Droi,
+ )
+
+ def createImageContext(self):
+ TestStatsBase.createImageContext(self)
+
+ self.imageContext = stats._ImageContext(
+ item=self.plot2d.getImage(self._imgLgd),
+ plot=self.plot2d,
+ onlimits=False,
+ roi=self._2Droi_rect,
+ )
+
+ self.imageContext_2 = stats._ImageContext(
+ item=self.plot2d.getImage(self._imgLgd),
+ plot=self.plot2d,
+ onlimits=False,
+ roi=self._2Droi_poly,
+ )
+
+ def testErrors(self):
+ # test if onlimits is True and give also a roi
+ with self.assertRaises(ValueError):
+ stats._CurveContext(
+ item=self.plot1d.getCurve("curve0"),
+ plot=self.plot1d,
+ onlimits=True,
+ roi=self._1Droi,
+ )
+
+ # test if is a curve context and give an invalid 2D roi
+ with self.assertRaises(TypeError):
+ stats._CurveContext(
+ item=self.plot1d.getCurve("curve0"),
+ plot=self.plot1d,
+ onlimits=False,
+ roi=self._2Droi_rect,
+ )
+
+ def testBasicStatsCurve(self):
+ """Test result for simple stats on a curve"""
+ _stats = self.getBasicStats()
+ xData = yData = numpy.array(range(0, 10))
+ self.assertEqual(_stats["min"].calculate(self.curveContext), 2)
+ self.assertEqual(_stats["max"].calculate(self.curveContext), 5)
+ self.assertEqual(_stats["minCoords"].calculate(self.curveContext), (2,))
+ self.assertEqual(_stats["maxCoords"].calculate(self.curveContext), (5,))
+ self.assertEqual(
+ _stats["std"].calculate(self.curveContext), numpy.std(yData[2:6])
+ )
+ self.assertEqual(
+ _stats["mean"].calculate(self.curveContext), numpy.mean(yData[2:6])
+ )
+ com = numpy.sum(xData[2:6] * yData[2:6]) / numpy.sum(yData[2:6])
+ self.assertEqual(_stats["com"].calculate(self.curveContext), com)
+
+ def testBasicStatsImageRectRoi(self):
+ """Test result for simple stats on an image"""
+ self.assertEqual(self.imageContext.values.compressed().size, 121)
+ _stats = self.getBasicStats()
+ self.assertEqual(_stats["min"].calculate(self.imageContext), 10)
+ self.assertEqual(_stats["max"].calculate(self.imageContext), 1300)
+ self.assertEqual(_stats["minCoords"].calculate(self.imageContext), (10, 0))
+ self.assertEqual(_stats["maxCoords"].calculate(self.imageContext), (20.0, 10.0))
+ self.assertAlmostEqual(
+ _stats["std"].calculate(self.imageContext),
+ numpy.std(self.imageData[0:11, 10:21]),
+ )
+ self.assertAlmostEqual(
+ _stats["mean"].calculate(self.imageContext),
+ numpy.mean(self.imageData[0:11, 10:21]),
+ )
+
+ compressed_values = self.imageContext.values.compressed()
+ compressed_values = compressed_values.reshape(11, 11)
+ yData = numpy.sum(compressed_values.astype(numpy.float64), axis=1)
+ xData = numpy.sum(compressed_values.astype(numpy.float64), axis=0)
+
+ dataYRange = range(11)
+ dataXRange = range(10, 21)
+
+ ycom = numpy.sum(yData * dataYRange) / numpy.sum(yData)
+ xcom = numpy.sum(xData * dataXRange) / numpy.sum(xData)
+ self.assertEqual(_stats["com"].calculate(self.imageContext), (xcom, ycom))
+
+ def testBasicStatsImagePolyRoi(self):
+ """Test a simple rectangle ROI"""
+ _stats = self.getBasicStats()
+ self.assertEqual(_stats["min"].calculate(self.imageContext_2), 0)
+ self.assertEqual(_stats["max"].calculate(self.imageContext_2), 2432)
+ self.assertEqual(_stats["minCoords"].calculate(self.imageContext_2), (0.0, 0.0))
+ # not 0.0, 19.0 because not fully in. Should all pixel have a weight,
+ # on to manage them in stats. For now 0 if the center is not in, else 1
+ self.assertEqual(
+ _stats["maxCoords"].calculate(self.imageContext_2), (0.0, 19.0)
+ )
+
+ def testBasicStatsScatter(self):
+ self.assertEqual(self.scatterContext.values.compressed().size, 2)
+ _stats = self.getBasicStats()
+ self.assertEqual(_stats["min"].calculate(self.scatterContext), 6)
+ self.assertEqual(_stats["max"].calculate(self.scatterContext), 7)
+ self.assertEqual(_stats["minCoords"].calculate(self.scatterContext), (2, 3))
+ self.assertEqual(_stats["maxCoords"].calculate(self.scatterContext), (3, 4))
+ self.assertEqual(
+ _stats["std"].calculate(self.scatterContext), numpy.std([6, 7])
+ )
+ self.assertEqual(
+ _stats["mean"].calculate(self.scatterContext), numpy.mean([6, 7])
+ )
+
+ def testBasicHistogram(self):
+ _stats = self.getBasicStats()
+ xData = yData = numpy.array(range(2, 6))
+ self.assertEqual(_stats["min"].calculate(self.histoContext), 2)
+ self.assertEqual(_stats["max"].calculate(self.histoContext), 5)
+ self.assertEqual(_stats["minCoords"].calculate(self.histoContext), (2,))
+ self.assertEqual(_stats["maxCoords"].calculate(self.histoContext), (5,))
+ self.assertEqual(_stats["std"].calculate(self.histoContext), numpy.std(yData))
+ self.assertEqual(_stats["mean"].calculate(self.histoContext), numpy.mean(yData))
+ com = numpy.sum(xData * yData) / numpy.sum(yData)
+ self.assertEqual(_stats["com"].calculate(self.histoContext), com)
+
+
+class TestAdvancedROIImageContext(TestCaseQt):
+ """Test stats result on an image context with different scale and
+ origins"""
+
+ def setUp(self):
+ TestCaseQt.setUp(self)
+ self.data_dims = (100, 100)
+ self.data = numpy.random.rand(*self.data_dims)
+ self.plot = Plot2D()
+
+ def tearDown(self):
+ self.plot.setAttribute(qt.Qt.WA_DeleteOnClose)
+ self.plot.close()
+ self.plot = None
+ TestCaseQt.tearDown(self)
+
+ def test(self):
+ """Test stats result on an image context with different scale and
+ origins"""
+ roi_origins = [(0, 0), (2, 10), (14, 20)]
+ img_origins = [(0, 0), (14, 20), (2, 10)]
+ img_scales = [1.0, 0.5, 2.0]
+ _stats = {
+ "sum": stats.Stat(name="sum", fct=numpy.sum),
+ }
+ for roi_origin in roi_origins:
+ for img_origin in img_origins:
+ for img_scale in img_scales:
+ with self.subTest(
+ roi_origin=roi_origin,
+ img_origin=img_origin,
+ img_scale=img_scale,
+ ):
+ self.plot.addImage(
+ self.data, legend="img", origin=img_origin, scale=img_scale
+ )
+ roi = RectangleROI()
+ roi.setGeometry(origin=roi_origin, size=(20, 20))
+ context = stats._ImageContext(
+ item=self.plot.getImage("img"),
+ plot=self.plot,
+ onlimits=False,
+ roi=roi,
+ )
+ x_start = int((roi_origin[0] - img_origin[0]) / img_scale)
+ x_end = int(x_start + (20 / img_scale)) + 1
+ y_start = int((roi_origin[1] - img_origin[1]) / img_scale)
+ y_end = int(y_start + (20 / img_scale)) + 1
+ x_start = max(x_start, 0)
+ x_end = min(max(x_end, 0), self.data_dims[1])
+ y_start = max(y_start, 0)
+ y_end = min(max(y_end, 0), self.data_dims[0])
+ th_sum = numpy.sum(self.data[y_start:y_end, x_start:x_end])
+ self.assertAlmostEqual(_stats["sum"].calculate(context), th_sum)
diff --git a/src/silx/gui/plot/test/testUtilsAxis.py b/src/silx/gui/plot/test/testUtilsAxis.py
new file mode 100644
index 0000000..d749845
--- /dev/null
+++ b/src/silx/gui/plot/test/testUtilsAxis.py
@@ -0,0 +1,239 @@
+# /*##########################################################################
+#
+# Copyright (c) 2016 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""Basic tests for PlotWidget"""
+
+__authors__ = ["V. Valls"]
+__license__ = "MIT"
+__date__ = "20/11/2018"
+
+
+from silx.gui.plot import PlotWidget
+from silx.gui.utils.testutils import TestCaseQt
+from silx.gui.plot.utils.axis import SyncAxes
+
+
+class TestAxisSync(TestCaseQt):
+ """Tests AxisSync class"""
+
+ def setUp(self):
+ TestCaseQt.setUp(self)
+ self.plot1 = PlotWidget()
+ self.plot2 = PlotWidget()
+ self.plot3 = PlotWidget()
+
+ def tearDown(self):
+ self.plot1 = None
+ self.plot2 = None
+ self.plot3 = None
+ TestCaseQt.tearDown(self)
+
+ def testMoveFirstAxis(self):
+ """Test synchronization after construction"""
+ _sync = SyncAxes(
+ [self.plot1.getXAxis(), self.plot2.getXAxis(), self.plot3.getXAxis()]
+ )
+
+ self.plot1.getXAxis().setLimits(10, 500)
+ self.assertEqual(self.plot1.getXAxis().getLimits(), (10, 500))
+ self.assertEqual(self.plot2.getXAxis().getLimits(), (10, 500))
+ self.assertEqual(self.plot3.getXAxis().getLimits(), (10, 500))
+
+ def testMoveSecondAxis(self):
+ """Test synchronization after construction"""
+ _sync = SyncAxes(
+ [self.plot1.getXAxis(), self.plot2.getXAxis(), self.plot3.getXAxis()]
+ )
+
+ self.plot2.getXAxis().setLimits(10, 500)
+ self.assertEqual(self.plot1.getXAxis().getLimits(), (10, 500))
+ self.assertEqual(self.plot2.getXAxis().getLimits(), (10, 500))
+ self.assertEqual(self.plot3.getXAxis().getLimits(), (10, 500))
+
+ def testMoveTwoAxes(self):
+ """Test synchronization after construction"""
+ _sync = SyncAxes(
+ [self.plot1.getXAxis(), self.plot2.getXAxis(), self.plot3.getXAxis()]
+ )
+
+ self.plot1.getXAxis().setLimits(1, 50)
+ self.plot2.getXAxis().setLimits(10, 500)
+ self.assertEqual(self.plot1.getXAxis().getLimits(), (10, 500))
+ self.assertEqual(self.plot2.getXAxis().getLimits(), (10, 500))
+ self.assertEqual(self.plot3.getXAxis().getLimits(), (10, 500))
+
+ def testDestruction(self):
+ """Test synchronization when sync object is destroyed"""
+ sync = SyncAxes(
+ [self.plot1.getXAxis(), self.plot2.getXAxis(), self.plot3.getXAxis()]
+ )
+ del sync
+
+ self.plot1.getXAxis().setLimits(10, 500)
+ self.assertEqual(self.plot1.getXAxis().getLimits(), (10, 500))
+ self.assertNotEqual(self.plot2.getXAxis().getLimits(), (10, 500))
+ self.assertNotEqual(self.plot3.getXAxis().getLimits(), (10, 500))
+
+ def testAxisDestruction(self):
+ """Test synchronization when an axis disappear"""
+ _sync = SyncAxes(
+ [self.plot1.getXAxis(), self.plot2.getXAxis(), self.plot3.getXAxis()]
+ )
+
+ # Destroy the plot is possible
+ import weakref
+
+ plot = weakref.ref(self.plot2)
+ self.plot2 = None
+ result = self.qWaitForDestroy(plot)
+ if not result:
+ # We can't test
+ self.skipTest("Object not destroyed")
+
+ self.plot1.getXAxis().setLimits(10, 500)
+ self.assertEqual(self.plot3.getXAxis().getLimits(), (10, 500))
+
+ def testStop(self):
+ """Test synchronization after calling stop"""
+ sync = SyncAxes(
+ [self.plot1.getXAxis(), self.plot2.getXAxis(), self.plot3.getXAxis()]
+ )
+ sync.stop()
+
+ self.plot1.getXAxis().setLimits(10, 500)
+ self.assertEqual(self.plot1.getXAxis().getLimits(), (10, 500))
+ self.assertNotEqual(self.plot2.getXAxis().getLimits(), (10, 500))
+ self.assertNotEqual(self.plot3.getXAxis().getLimits(), (10, 500))
+
+ def testStopMovingStart(self):
+ """Test synchronization after calling stop, moving an axis, then start again"""
+ sync = SyncAxes(
+ [self.plot1.getXAxis(), self.plot2.getXAxis(), self.plot3.getXAxis()]
+ )
+ sync.stop()
+ self.plot1.getXAxis().setLimits(10, 500)
+ self.plot2.getXAxis().setLimits(1, 50)
+ self.assertEqual(self.plot1.getXAxis().getLimits(), (10, 500))
+ sync.start()
+
+ # The first axis is the reference
+ self.assertEqual(self.plot1.getXAxis().getLimits(), (10, 500))
+ self.assertEqual(self.plot2.getXAxis().getLimits(), (10, 500))
+ self.assertEqual(self.plot3.getXAxis().getLimits(), (10, 500))
+
+ def testDoubleStop(self):
+ """Test double stop"""
+ sync = SyncAxes(
+ [self.plot1.getXAxis(), self.plot2.getXAxis(), self.plot3.getXAxis()]
+ )
+ sync.stop()
+ self.assertRaises(RuntimeError, sync.stop)
+
+ def testDoubleStart(self):
+ """Test double stop"""
+ sync = SyncAxes(
+ [self.plot1.getXAxis(), self.plot2.getXAxis(), self.plot3.getXAxis()]
+ )
+ self.assertRaises(RuntimeError, sync.start)
+
+ def testScale(self):
+ """Test scale change"""
+ _sync = SyncAxes(
+ [self.plot1.getXAxis(), self.plot2.getXAxis(), self.plot3.getXAxis()]
+ )
+ self.plot1.getXAxis().setScale(self.plot1.getXAxis().LOGARITHMIC)
+ self.assertEqual(
+ self.plot1.getXAxis().getScale(), self.plot1.getXAxis().LOGARITHMIC
+ )
+ self.assertEqual(
+ self.plot2.getXAxis().getScale(), self.plot1.getXAxis().LOGARITHMIC
+ )
+ self.assertEqual(
+ self.plot3.getXAxis().getScale(), self.plot1.getXAxis().LOGARITHMIC
+ )
+
+ def testDirection(self):
+ """Test direction change"""
+ _sync = SyncAxes(
+ [self.plot1.getYAxis(), self.plot2.getYAxis(), self.plot3.getYAxis()]
+ )
+ self.plot1.getYAxis().setInverted(True)
+ self.assertEqual(self.plot1.getYAxis().isInverted(), True)
+ self.assertEqual(self.plot2.getYAxis().isInverted(), True)
+ self.assertEqual(self.plot3.getYAxis().isInverted(), True)
+
+ def testSyncCenter(self):
+ """Test direction change"""
+ # Not the same scale
+ self.plot1.getXAxis().setLimits(0, 200)
+ self.plot2.getXAxis().setLimits(0, 20)
+ self.plot3.getXAxis().setLimits(0, 2)
+ _sync = SyncAxes(
+ [self.plot1.getXAxis(), self.plot2.getXAxis(), self.plot3.getXAxis()],
+ syncLimits=False,
+ syncCenter=True,
+ )
+
+ self.assertEqual(self.plot1.getXAxis().getLimits(), (0, 200))
+ self.assertEqual(self.plot2.getXAxis().getLimits(), (100 - 10, 100 + 10))
+ self.assertEqual(self.plot3.getXAxis().getLimits(), (100 - 1, 100 + 1))
+
+ def testSyncCenterAndZoom(self):
+ """Test direction change"""
+ # Not the same scale
+ self.plot1.getXAxis().setLimits(0, 200)
+ self.plot2.getXAxis().setLimits(0, 20)
+ self.plot3.getXAxis().setLimits(0, 2)
+ _sync = SyncAxes(
+ [self.plot1.getXAxis(), self.plot2.getXAxis(), self.plot3.getXAxis()],
+ syncLimits=False,
+ syncCenter=True,
+ syncZoom=True,
+ )
+
+ # Supposing all the plots use the same size
+ self.assertEqual(self.plot1.getXAxis().getLimits(), (0, 200))
+ self.assertEqual(self.plot2.getXAxis().getLimits(), (0, 200))
+ self.assertEqual(self.plot3.getXAxis().getLimits(), (0, 200))
+
+ def testAddAxis(self):
+ """Test synchronization after construction"""
+ sync = SyncAxes([self.plot1.getXAxis(), self.plot2.getXAxis()])
+ sync.addAxis(self.plot3.getXAxis())
+
+ self.plot1.getXAxis().setLimits(10, 500)
+ self.assertEqual(self.plot1.getXAxis().getLimits(), (10, 500))
+ self.assertEqual(self.plot2.getXAxis().getLimits(), (10, 500))
+ self.assertEqual(self.plot3.getXAxis().getLimits(), (10, 500))
+
+ def testRemoveAxis(self):
+ """Test synchronization after construction"""
+ sync = SyncAxes(
+ [self.plot1.getXAxis(), self.plot2.getXAxis(), self.plot3.getXAxis()]
+ )
+ sync.removeAxis(self.plot3.getXAxis())
+
+ self.plot1.getXAxis().setLimits(10, 500)
+ self.assertEqual(self.plot1.getXAxis().getLimits(), (10, 500))
+ self.assertEqual(self.plot2.getXAxis().getLimits(), (10, 500))
+ self.assertNotEqual(self.plot3.getXAxis().getLimits(), (10, 500))
diff --git a/src/silx/gui/plot/test/utils.py b/src/silx/gui/plot/test/utils.py
new file mode 100644
index 0000000..d48a467
--- /dev/null
+++ b/src/silx/gui/plot/test/utils.py
@@ -0,0 +1,92 @@
+# /*##########################################################################
+#
+# Copyright (c) 2016-2021 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""Basic tests for PlotWidget"""
+
+__authors__ = ["T. Vincent"]
+__license__ = "MIT"
+__date__ = "26/01/2018"
+
+
+import logging
+import pytest
+
+from silx.gui.utils.testutils import TestCaseQt
+
+from silx.gui import qt
+from silx.gui.plot import PlotWidget
+
+
+logger = logging.getLogger(__name__)
+
+
+@pytest.mark.usefixtures("test_options_class_attr")
+class PlotWidgetTestCase(TestCaseQt):
+ """Base class for tests of PlotWidget, not a TestCase in itself.
+
+ plot attribute is the PlotWidget created for the test.
+ """
+
+ __screenshot_already_taken = False
+ backend = None
+
+ def _createPlot(self):
+ return PlotWidget(backend=self.backend)
+
+ def setUp(self):
+ super(PlotWidgetTestCase, self).setUp()
+ self.plot = self._createPlot()
+ self.plot.show()
+ self.plotAlive = True
+ self.qWaitForWindowExposed(self.plot)
+ TestCaseQt.mouseClick(self, self.plot, button=qt.Qt.LeftButton, pos=(0, 0))
+
+ def __onPlotDestroyed(self):
+ self.plotAlive = False
+
+ def _waitForPlotClosed(self):
+ self.plot.setAttribute(qt.Qt.WA_DeleteOnClose)
+ self.plot.destroyed.connect(self.__onPlotDestroyed)
+ self.plot.close()
+ del self.plot
+ for _ in range(100):
+ if not self.plotAlive:
+ break
+ self.qWait(10)
+ else:
+ logger.error("Plot is still alive")
+
+ def tearDown(self):
+ if not self._currentTestSucceeded():
+ # MPL is the only widget which uses the real system mouse.
+ # In case of a the windows is outside of the screen, minimzed,
+ # overlapped by a system popup, the MPL widget will not receive the
+ # mouse event.
+ # Taking a screenshot help debuging this cases in the continuous
+ # integration environement.
+ if not PlotWidgetTestCase.__screenshot_already_taken:
+ PlotWidgetTestCase.__screenshot_already_taken = True
+ self.logScreenShot()
+ self.qapp.processEvents()
+ self._waitForPlotClosed()
+ super(PlotWidgetTestCase, self).tearDown()
diff --git a/src/silx/gui/plot/tools/CurveLegendsWidget.py b/src/silx/gui/plot/tools/CurveLegendsWidget.py
new file mode 100644
index 0000000..0ebea0d
--- /dev/null
+++ b/src/silx/gui/plot/tools/CurveLegendsWidget.py
@@ -0,0 +1,245 @@
+# /*##########################################################################
+#
+# Copyright (c) 2018-2020 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""This module provides a widget to display :class:`PlotWidget` curve legends.
+"""
+
+__authors__ = ["T. Vincent"]
+__license__ = "MIT"
+__date__ = "20/07/2018"
+
+
+import logging
+import weakref
+
+
+from ... import qt
+from ...widgets.FlowLayout import FlowLayout as _FlowLayout
+from ..LegendSelector import LegendIcon as _LegendIcon
+from .. import items
+
+
+_logger = logging.getLogger(__name__)
+
+
+class _LegendWidget(qt.QWidget):
+ """Widget displaying curve style and its legend
+
+ :param QWidget parent: See :class:`QWidget`
+ :param ~silx.gui.plot.items.Curve curve: Associated curve
+ """
+
+ def __init__(self, parent, curve):
+ super(_LegendWidget, self).__init__(parent)
+ layout = qt.QHBoxLayout(self)
+ layout.setContentsMargins(10, 0, 10, 0)
+
+ curve.sigItemChanged.connect(self._curveChanged)
+
+ icon = _LegendIcon(curve=curve)
+ layout.addWidget(icon)
+
+ label = qt.QLabel(curve.getName())
+ label.setAlignment(qt.Qt.AlignLeft | qt.Qt.AlignVCenter)
+ layout.addWidget(label)
+
+ self._update()
+
+ def getCurve(self):
+ """Returns curve associated to this widget
+
+ :rtype: Union[~silx.gui.plot.items.Curve,None]
+ """
+ icon = self.findChild(_LegendIcon)
+ return icon.getCurve()
+
+ def _update(self):
+ """Update widget according to current curve state."""
+ curve = self.getCurve()
+ if curve is None:
+ _logger.error("Curve no more exists")
+ self.setVisible(False)
+ return
+
+ self.setEnabled(curve.isVisible())
+
+ label = self.findChild(qt.QLabel)
+ if curve.isHighlighted():
+ label.setStyleSheet("border: 1px solid black")
+ else:
+ label.setStyleSheet("")
+
+ def _curveChanged(self, event):
+ """Handle update of curve item
+
+ :param event: Kind of change
+ """
+ if event in (
+ items.ItemChangedType.VISIBLE,
+ items.ItemChangedType.HIGHLIGHTED,
+ items.ItemChangedType.HIGHLIGHTED_STYLE,
+ ):
+ self._update()
+
+
+class CurveLegendsWidget(qt.QWidget):
+ """Widget displaying curves legends in a plot
+
+ :param QWidget parent: See :class:`QWidget`
+ """
+
+ sigCurveClicked = qt.Signal(object)
+ """Signal emitted when the legend of a curve is clicked
+
+ It provides the corresponding curve.
+ """
+
+ def __init__(self, parent=None):
+ super(CurveLegendsWidget, self).__init__(parent)
+ self._clicked = None
+ self._legends = {}
+ self._plotRef = None
+
+ def layout(self):
+ layout = super(CurveLegendsWidget, self).layout()
+ if layout is None:
+ # Lazy layout initialization to allow overloading
+ layout = _FlowLayout()
+ layout.setHorizontalSpacing(0)
+ self.setLayout(layout)
+ return layout
+
+ def getPlotWidget(self):
+ """Returns the associated :class:`PlotWidget`
+
+ :rtype: Union[~silx.gui.plot.PlotWidget,None]
+ """
+ return None if self._plotRef is None else self._plotRef()
+
+ def setPlotWidget(self, plot):
+ """Set the associated :class:`PlotWidget`
+
+ :param ~silx.gui.plot.PlotWidget plot: Plot widget to attach
+ """
+ previousPlot = self.getPlotWidget()
+ if previousPlot is not None:
+ previousPlot.sigItemAdded.disconnect(self._itemAdded)
+ previousPlot.sigItemAboutToBeRemoved.disconnect(self._itemRemoved)
+ for legend in list(self._legends.keys()):
+ self._removeLegend(legend)
+
+ self._plotRef = None if plot is None else weakref.ref(plot)
+
+ if plot is not None:
+ plot.sigItemAdded.connect(self._itemAdded)
+ plot.sigItemAboutToBeRemoved.connect(self._itemRemoved)
+
+ for legend in plot.getAllCurves(just_legend=True):
+ self._addLegend(legend)
+
+ def curveAt(self, *args):
+ """Returns the curve object represented at the given position
+
+ Either takes a QPoint or x and y as input in widget coordinates.
+
+ :rtype: Union[~silx.gui.plot.items.Curve,None]
+ """
+ if len(args) == 1:
+ point = args[0]
+ elif len(args) == 2:
+ point = qt.QPoint(*args)
+ else:
+ raise ValueError("Unsupported arguments")
+ assert isinstance(point, qt.QPoint)
+
+ widget = self.childAt(point)
+ while widget not in (self, None):
+ if isinstance(widget, _LegendWidget):
+ return widget.getCurve()
+ widget = widget.parent()
+ return None # No widget or not in _LegendWidget
+
+ def _itemAdded(self, item):
+ """Handle item added to the plot content"""
+ if isinstance(item, items.Curve):
+ self._addLegend(item.getName())
+
+ def _itemRemoved(self, item):
+ """Handle item removed from the plot content"""
+ if isinstance(item, items.Curve):
+ self._removeLegend(item.getName())
+
+ def _addLegend(self, legend):
+ """Add a curve to the legends
+
+ :param str legend: Curve's legend
+ """
+ if legend in self._legends:
+ return # Can happen when changing curve's y axis
+
+ plot = self.getPlotWidget()
+ if plot is None:
+ return None
+
+ curve = plot.getCurve(legend)
+ if curve is None:
+ _logger.error("Curve not found: %s" % legend)
+ return
+
+ widget = _LegendWidget(parent=self, curve=curve)
+ self.layout().addWidget(widget)
+ self._legends[legend] = widget
+
+ def _removeLegend(self, legend):
+ """Remove a curve from the legends if it exists
+
+ :param str legend: The curve's legend
+ """
+ widget = self._legends.pop(legend, None)
+ if widget is None:
+ _logger.warning("Unknown legend: %s" % legend)
+ else:
+ self.layout().removeWidget(widget)
+ widget.setParent(None)
+
+ def mousePressEvent(self, event):
+ if event.button() == qt.Qt.LeftButton:
+ self._clicked = event.pos()
+
+ _CLICK_THRESHOLD = 5
+ """Threshold for clicks"""
+
+ def mouseMoveEvent(self, event):
+ if self._clicked is not None:
+ dx = abs(self._clicked.x() - event.pos().x())
+ dy = abs(self._clicked.y() - event.pos().y())
+ if dx > self._CLICK_THRESHOLD or dy > self._CLICK_THRESHOLD:
+ self._clicked = None # Click is cancelled
+
+ def mouseReleaseEvent(self, event):
+ if event.button() == qt.Qt.LeftButton and self._clicked is not None:
+ curve = self.curveAt(event.pos())
+ if curve is not None:
+ self.sigCurveClicked.emit(curve)
+
+ self._clicked = None
diff --git a/src/silx/gui/plot/tools/LimitsToolBar.py b/src/silx/gui/plot/tools/LimitsToolBar.py
new file mode 100644
index 0000000..5ed09f7
--- /dev/null
+++ b/src/silx/gui/plot/tools/LimitsToolBar.py
@@ -0,0 +1,123 @@
+# /*##########################################################################
+#
+# Copyright (c) 2016-2018 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""A toolbar to display and edit limits of a PlotWidget
+"""
+
+__authors__ = ["V.A. Sole", "T. Vincent"]
+__license__ = "MIT"
+__date__ = "16/10/2017"
+
+
+from ... import qt
+from ...widgets.FloatEdit import FloatEdit
+
+
+class LimitsToolBar(qt.QToolBar):
+ """QToolBar displaying and controlling the limits of a :class:`PlotWidget`.
+
+ To run the following sample code, a QApplication must be initialized.
+ First, create a PlotWindow:
+
+ >>> from silx.gui.plot import PlotWindow
+ >>> plot = PlotWindow() # Create a PlotWindow to add the toolbar to
+
+ Then, create the LimitsToolBar and add it to the PlotWindow.
+
+ >>> from silx.gui import qt
+ >>> from silx.gui.plot.tools import LimitsToolBar
+
+ >>> toolbar = LimitsToolBar(plot=plot) # Create the toolbar
+ >>> plot.addToolBar(qt.Qt.BottomToolBarArea, toolbar) # Add it to the plot
+ >>> plot.show() # To display the PlotWindow with the limits toolbar
+
+ :param parent: See :class:`QToolBar`.
+ :param plot: :class:`PlotWidget` instance on which to operate.
+ :param str title: See :class:`QToolBar`.
+ """
+
+ def __init__(self, parent=None, plot=None, title="Limits"):
+ super(LimitsToolBar, self).__init__(title, parent)
+ assert plot is not None
+ self._plot = plot
+ self._plot.sigPlotSignal.connect(self._plotWidgetSlot)
+
+ self._initWidgets()
+
+ @property
+ def plot(self):
+ """The :class:`PlotWidget` the toolbar is attached to."""
+ return self._plot
+
+ def _initWidgets(self):
+ """Create and init Toolbar widgets."""
+ xMin, xMax = self.plot.getXAxis().getLimits()
+ yMin, yMax = self.plot.getYAxis().getLimits()
+
+ self.addWidget(qt.QLabel("Limits: "))
+ self.addWidget(qt.QLabel(" X: "))
+ self._xMinFloatEdit = FloatEdit(self, xMin)
+ self._xMinFloatEdit.editingFinished[()].connect(self._xFloatEditChanged)
+ self.addWidget(self._xMinFloatEdit)
+
+ self._xMaxFloatEdit = FloatEdit(self, xMax)
+ self._xMaxFloatEdit.editingFinished[()].connect(self._xFloatEditChanged)
+ self.addWidget(self._xMaxFloatEdit)
+
+ self.addWidget(qt.QLabel(" Y: "))
+ self._yMinFloatEdit = FloatEdit(self, yMin)
+ self._yMinFloatEdit.editingFinished[()].connect(self._yFloatEditChanged)
+ self.addWidget(self._yMinFloatEdit)
+
+ self._yMaxFloatEdit = FloatEdit(self, yMax)
+ self._yMaxFloatEdit.editingFinished[()].connect(self._yFloatEditChanged)
+ self.addWidget(self._yMaxFloatEdit)
+
+ def _plotWidgetSlot(self, event):
+ """Listen to :class:`PlotWidget` events."""
+ if event["event"] not in ("limitsChanged",):
+ return
+
+ xMin, xMax = self.plot.getXAxis().getLimits()
+ yMin, yMax = self.plot.getYAxis().getLimits()
+
+ self._xMinFloatEdit.setValue(xMin)
+ self._xMaxFloatEdit.setValue(xMax)
+ self._yMinFloatEdit.setValue(yMin)
+ self._yMaxFloatEdit.setValue(yMax)
+
+ def _xFloatEditChanged(self):
+ """Handle X limits changed from the GUI."""
+ xMin, xMax = self._xMinFloatEdit.value(), self._xMaxFloatEdit.value()
+ if xMax < xMin:
+ xMin, xMax = xMax, xMin
+
+ self.plot.getXAxis().setLimits(xMin, xMax)
+
+ def _yFloatEditChanged(self):
+ """Handle Y limits changed from the GUI."""
+ yMin, yMax = self._yMinFloatEdit.value(), self._yMaxFloatEdit.value()
+ if yMax < yMin:
+ yMin, yMax = yMax, yMin
+
+ self.plot.getYAxis().setLimits(yMin, yMax)
diff --git a/src/silx/gui/plot/tools/PlotToolButton.py b/src/silx/gui/plot/tools/PlotToolButton.py
new file mode 100644
index 0000000..3a14f77
--- /dev/null
+++ b/src/silx/gui/plot/tools/PlotToolButton.py
@@ -0,0 +1,92 @@
+# /*##########################################################################
+#
+# Copyright (c) 2023 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""This module provides an abstract PlotToolButton that can be use to create
+plot tools for a toolbar.
+"""
+
+from __future__ import annotations
+
+__authors__ = ["V. Valls"]
+__license__ = "MIT"
+__date__ = "20/12/2023"
+
+
+import logging
+import weakref
+
+from silx.gui import qt
+
+
+_logger = logging.getLogger(__name__)
+
+
+class PlotToolButton(qt.QToolButton):
+ """A QToolButton connected to a :class:`~silx.gui.plot.PlotWidget`."""
+
+ def __init__(self, parent: qt.QWidget | None = None, plot=None):
+ super(PlotToolButton, self).__init__(parent)
+ self._plotRef = None
+ if plot is not None:
+ self.setPlot(plot)
+
+ def plot(self):
+ """
+ Returns the plot connected to the widget.
+ """
+ return None if self._plotRef is None else self._plotRef()
+
+ def setPlot(self, plot):
+ """
+ Set the plot connected to the widget
+
+ :param plot: :class:`.PlotWidget` instance on which to operate.
+ """
+ previousPlot = self.plot()
+
+ if previousPlot is plot:
+ return
+ if previousPlot is not None:
+ self._disconnectPlot(previousPlot)
+
+ if plot is None:
+ self._plotRef = None
+ else:
+ self._plotRef = weakref.ref(plot)
+ self._connectPlot(plot)
+
+ def _connectPlot(self, plot):
+ """
+ Called when the plot is connected to the widget
+
+ :param plot: :class:`.PlotWidget` instance
+ """
+ pass
+
+ def _disconnectPlot(self, plot):
+ """
+ Called when the plot is disconnected from the widget
+
+ :param plot: :class:`.PlotWidget` instance
+ """
+ pass
diff --git a/src/silx/gui/plot/tools/PositionInfo.py b/src/silx/gui/plot/tools/PositionInfo.py
new file mode 100644
index 0000000..e3b8425
--- /dev/null
+++ b/src/silx/gui/plot/tools/PositionInfo.py
@@ -0,0 +1,357 @@
+# /*##########################################################################
+#
+# Copyright (c) 2016-2023 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""This module provides a widget displaying mouse coordinates in a PlotWidget.
+
+It can be configured to provide more information.
+"""
+
+__authors__ = ["V.A. Sole", "T. Vincent"]
+__license__ = "MIT"
+__date__ = "16/10/2017"
+
+
+import logging
+import numbers
+import traceback
+import weakref
+
+import numpy
+
+from ... import qt
+from .. import items
+from ...widgets.ElidedLabel import ElidedLabel
+
+
+_logger = logging.getLogger(__name__)
+
+
+class _PositionInfoLabel(ElidedLabel):
+ """QLabel with a default size larger than what is displayed."""
+
+ def __init__(self, *args, **kwargs):
+ super().__init__(*args, **kwargs)
+ self.setTextInteractionFlags(qt.Qt.TextSelectableByMouse)
+
+ def sizeHint(self):
+ hint = super().sizeHint()
+ width = self.fontMetrics().boundingRect("##############").width()
+ return qt.QSize(max(hint.width(), width), hint.height())
+
+
+# PositionInfo ################################################################
+
+
+class PositionInfo(qt.QWidget):
+ """QWidget displaying coords converted from data coords of the mouse.
+
+ Provide this widget with a list of couple:
+
+ - A name to display before the data
+ - A function that takes (x, y) as arguments and returns something that
+ gets converted to a string.
+ If the result is a float it is converted with '%.7g' format.
+
+ To run the following sample code, a QApplication must be initialized.
+ First, create a PlotWindow and add a QToolBar where to place the
+ PositionInfo widget.
+
+ >>> from silx.gui.plot import PlotWindow
+ >>> from silx.gui import qt
+
+ >>> plot = PlotWindow() # Create a PlotWindow to add the widget to
+ >>> toolBar = qt.QToolBar() # Create a toolbar to place the widget in
+ >>> plot.addToolBar(qt.Qt.BottomToolBarArea, toolBar) # Add it to plot
+
+ Then, create the PositionInfo widget and add it to the toolbar.
+ The PositionInfo widget is created with a list of converters, here
+ to display polar coordinates of the mouse position.
+
+ >>> import numpy
+ >>> from silx.gui.plot.tools import PositionInfo
+
+ >>> position = PositionInfo(plot=plot, converters=[
+ ... ('Radius', lambda x, y: numpy.sqrt(x*x + y*y)),
+ ... ('Angle', lambda x, y: numpy.degrees(numpy.arctan2(y, x)))])
+ >>> toolBar.addWidget(position) # Add the widget to the toolbar
+ <...>
+ >>> plot.show() # To display the PlotWindow with the position widget
+
+ :param plot: The PlotWidget this widget is displaying data coords from.
+ :param converters:
+ List of 2-tuple: name to display and conversion function from (x, y)
+ in data coords to displayed value.
+ If None, the default, it displays X and Y.
+ :param parent: Parent widget
+ """
+
+ SNAP_THRESHOLD_DIST = 5
+
+ def __init__(self, parent=None, plot=None, converters=None):
+ assert plot is not None
+ self._plotRef = weakref.ref(plot)
+ self._snappingMode = self.SNAPPING_DISABLED
+
+ super(PositionInfo, self).__init__(parent)
+
+ if converters is None:
+ converters = (("X", lambda x, y: x), ("Y", lambda x, y: y))
+
+ self._fields = [] # To store (QLineEdit, name, function (x, y)->v)
+
+ # Create a new layout with new widgets
+ layout = qt.QHBoxLayout()
+ layout.setContentsMargins(0, 0, 0, 0)
+ # layout.setSpacing(0)
+
+ # Create all QLabel and store them with the corresponding converter
+ for name, func in converters:
+ layout.addWidget(qt.QLabel("<b>" + name + ":</b>"))
+
+ contentWidget = _PositionInfoLabel(self)
+ contentWidget.setText("------")
+ layout.addWidget(contentWidget)
+ self._fields.append((contentWidget, name, func))
+
+ layout.addStretch(1)
+ self.setLayout(layout)
+
+ # Connect to Plot events
+ plot.sigPlotSignal.connect(self._plotEvent)
+
+ def getPlotWidget(self):
+ """Returns the PlotWidget this widget is attached to or None.
+
+ :rtype: Union[~silx.gui.plot.PlotWidget,None]
+ """
+ return self._plotRef()
+
+ def getConverters(self):
+ """Return the list of converters as 2-tuple (name, function)."""
+ return [(name, func) for _label, name, func in self._fields]
+
+ def _plotEvent(self, event):
+ """Handle events from the Plot.
+
+ :param dict event: Plot event
+ """
+ if event["event"] == "mouseMoved":
+ x, y = event["x"], event["y"]
+ xPixel, yPixel = event["xpixel"], event["ypixel"]
+ self._updateStatusBar(x, y, xPixel, yPixel)
+
+ def updateInfo(self):
+ """Update displayed information"""
+ plot = self.getPlotWidget()
+ if plot is None:
+ _logger.error(
+ "Trying to update PositionInfo " "while PlotWidget no longer exists"
+ )
+ return
+
+ widget = plot.getWidgetHandle()
+ position = widget.mapFromGlobal(qt.QCursor.pos())
+ xPixel, yPixel = position.x(), position.y()
+ dataPos = plot.pixelToData(xPixel, yPixel, check=True)
+ if dataPos is not None: # Inside plot area
+ x, y = dataPos
+ self._updateStatusBar(x, y, xPixel, yPixel)
+
+ def _updateStatusBar(self, x, y, xPixel, yPixel):
+ """Update information from the status bar using the definitions.
+
+ :param float x: Position-x in data
+ :param float y: Position-y in data
+ :param float xPixel: Position-x in pixels
+ :param float yPixel: Position-y in pixels
+ """
+ plot = self.getPlotWidget()
+ if plot is None:
+ return
+
+ styleSheet = "" # Default style
+ xData, yData = x, y
+
+ snappingMode = self.getSnappingMode()
+
+ # Snapping when crosshair either not requested or active
+ if snappingMode & (self.SNAPPING_CURVE | self.SNAPPING_SCATTER) and (
+ not (snappingMode & self.SNAPPING_CROSSHAIR) or plot.getGraphCursor()
+ ):
+ styleSheet = "color: rgb(255, 0, 0);" # Style far from item
+
+ if snappingMode & self.SNAPPING_ACTIVE_ONLY:
+ selectedItems = []
+
+ if snappingMode & self.SNAPPING_CURVE:
+ activeCurve = plot.getActiveCurve()
+ if activeCurve:
+ selectedItems.append(activeCurve)
+
+ if snappingMode & self.SNAPPING_SCATTER:
+ activeScatter = plot.getActiveScatter()
+ if activeScatter:
+ selectedItems.append(activeScatter)
+
+ else:
+ kinds = []
+ if snappingMode & self.SNAPPING_CURVE:
+ kinds.append(items.Curve)
+ kinds.append(items.Histogram)
+ if snappingMode & self.SNAPPING_SCATTER:
+ kinds.append(items.Scatter)
+ selectedItems = [
+ item
+ for item in plot.getItems()
+ if isinstance(item, tuple(kinds)) and item.isVisible()
+ ]
+
+ # Compute distance threshold
+ window = plot.window()
+ windowHandle = window.windowHandle()
+ if windowHandle is not None:
+ ratio = windowHandle.devicePixelRatio()
+ else:
+ ratio = qt.QGuiApplication.primaryScreen().devicePixelRatio()
+
+ # Baseline squared distance threshold
+ sqDistInPixels = (self.SNAP_THRESHOLD_DIST * ratio) ** 2
+
+ for item in selectedItems:
+ if snappingMode & self.SNAPPING_SYMBOLS_ONLY and (
+ not isinstance(item, items.SymbolMixIn) or not item.getSymbol()
+ ):
+ # Only handled if item symbols are visible
+ continue
+
+ if isinstance(item, items.Histogram):
+ result = item.pick(xPixel, yPixel)
+ if result is not None: # Histogram picked
+ index = result.getIndices()[0]
+ edges = item.getBinEdgesData(copy=False)
+
+ # Snap to bin center and value
+ xData = 0.5 * (edges[index] + edges[index + 1])
+ yData = item.getValueData(copy=False)[index]
+
+ # Update label style sheet
+ styleSheet = ""
+ break
+
+ else: # Curve, Scatter
+ result = item.pick(xPixel, yPixel)
+ if result is None:
+ continue
+ indices = result.getIndices(copy=False)
+ if indices is None:
+ continue
+
+ if isinstance(item, items.YAxisMixIn):
+ axis = item.getYAxis()
+ else:
+ axis = "left"
+
+ xArray = item.getXData(copy=False)[indices]
+ yArray = item.getYData(copy=False)[indices]
+ pixelPositions = plot.dataToPixel(xArray, yArray, axis=axis)
+ if pixelPositions is None:
+ continue
+ sqDistances = (pixelPositions[0] - xPixel) ** 2 + (
+ pixelPositions[1] - yPixel
+ ) ** 2
+ if not numpy.any(numpy.isfinite(sqDistances)):
+ continue
+ closestIndex = numpy.nanargmin(sqDistances)
+ closestSqDistInPixels = sqDistances[closestIndex]
+
+ if closestSqDistInPixels <= sqDistInPixels:
+ # Update label style sheet
+ styleSheet = ""
+
+ # if close enough, snap to data point coord
+ xData, yData = xArray[closestIndex], yArray[closestIndex]
+ sqDistInPixels = closestSqDistInPixels
+
+ for label, name, func in self._fields:
+ label.setStyleSheet(styleSheet)
+
+ try:
+ value = func(xData, yData)
+ text = self.valueToString(value)
+ label.setText(text)
+ except:
+ label.setText("Error")
+ _logger.error(
+ "Error while converting coordinates (%f, %f)"
+ "with converter '%s'" % (xPixel, yPixel, name)
+ )
+ _logger.error(traceback.format_exc())
+
+ def valueToString(self, value):
+ if isinstance(value, (tuple, list)):
+ value = [self.valueToString(v) for v in value]
+ return ", ".join(value)
+ elif isinstance(value, numbers.Real):
+ # Use this for floats and int
+ return "%.7g" % value
+ else:
+ # Fallback for other types
+ return str(value)
+
+ # Snapping mode
+
+ SNAPPING_DISABLED = 0
+ """No snapping occurs"""
+
+ SNAPPING_CROSSHAIR = 1 << 0
+ """Snapping only enabled when crosshair cursor is enabled"""
+
+ SNAPPING_ACTIVE_ONLY = 1 << 1
+ """Snapping only enabled for active item"""
+
+ SNAPPING_SYMBOLS_ONLY = 1 << 2
+ """Snapping only when symbols are visible"""
+
+ SNAPPING_CURVE = 1 << 3
+ """Snapping on curves"""
+
+ SNAPPING_SCATTER = 1 << 4
+ """Snapping on scatter"""
+
+ def setSnappingMode(self, mode):
+ """Set the snapping mode.
+
+ The mode is a mask.
+
+ :param int mode: The mode to use
+ """
+ if mode != self._snappingMode:
+ self._snappingMode = mode
+ self.updateInfo()
+
+ def getSnappingMode(self):
+ """Returns the snapping mode as a mask
+
+ :rtype: int
+ """
+ return self._snappingMode
diff --git a/src/silx/gui/plot/tools/RadarView.py b/src/silx/gui/plot/tools/RadarView.py
new file mode 100644
index 0000000..8ddb98b
--- /dev/null
+++ b/src/silx/gui/plot/tools/RadarView.py
@@ -0,0 +1,359 @@
+# /*##########################################################################
+#
+# Copyright (c) 2015-2018 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""QWidget displaying an overview of a 2D plot.
+
+This shows the available range of the data, and the current location of the
+plot view.
+"""
+
+__authors__ = ["T. Vincent"]
+__license__ = "MIT"
+__date__ = "22/02/2021"
+
+import logging
+import weakref
+from ... import qt
+from ...utils import LockReentrant
+
+_logger = logging.getLogger(__name__)
+
+
+class _DraggableRectItem(qt.QGraphicsRectItem):
+ """RectItem which signals its change through visibleRectDragged."""
+
+ def __init__(self, *args, **kwargs):
+ super(_DraggableRectItem, self).__init__(*args, **kwargs)
+
+ self._previousCursor = None
+ self.setFlag(qt.QGraphicsItem.ItemIsMovable)
+ self.setFlag(qt.QGraphicsItem.ItemSendsGeometryChanges)
+ self.setAcceptHoverEvents(True)
+ self._ignoreChange = False
+ self._constraint = 0, 0, 0, 0
+
+ def setConstraintRect(self, left, top, width, height):
+ """Set the constraint rectangle for dragging.
+
+ The coordinates are in the _DraggableRectItem coordinate system.
+
+ This constraint only applies to modification through interaction
+ (i.e., this constraint is not applied to change through API).
+
+ If the _DraggableRectItem is smaller than the constraint rectangle,
+ the _DraggableRectItem remains within the constraint rectangle.
+ If the _DraggableRectItem is wider than the constraint rectangle,
+ the constraint rectangle remains within the _DraggableRectItem.
+ """
+ self._constraint = left, left + width, top, top + height
+
+ def setPos(self, *args, **kwargs):
+ """Overridden to ignore changes from API in itemChange."""
+ self._ignoreChange = True
+ super(_DraggableRectItem, self).setPos(*args, **kwargs)
+ self._ignoreChange = False
+
+ def moveBy(self, *args, **kwargs):
+ """Overridden to ignore changes from API in itemChange."""
+ self._ignoreChange = True
+ super(_DraggableRectItem, self).moveBy(*args, **kwargs)
+ self._ignoreChange = False
+
+ def itemChange(self, change, value):
+ """Callback called before applying changes to the item."""
+ if change == qt.QGraphicsItem.ItemPositionChange and not self._ignoreChange:
+ # Makes sure that the visible area is in the data
+ # or that data is in the visible area if area is too wide
+ x, y = value.x(), value.y()
+ xMin, xMax, yMin, yMax = self._constraint
+
+ if self.rect().width() <= (xMax - xMin):
+ if x < xMin:
+ value.setX(xMin)
+ elif x > xMax - self.rect().width():
+ value.setX(xMax - self.rect().width())
+ else:
+ if x > xMin:
+ value.setX(xMin)
+ elif x < xMax - self.rect().width():
+ value.setX(xMax - self.rect().width())
+
+ if self.rect().height() <= (yMax - yMin):
+ if y < yMin:
+ value.setY(yMin)
+ elif y > yMax - self.rect().height():
+ value.setY(yMax - self.rect().height())
+ else:
+ if y > yMin:
+ value.setY(yMin)
+ elif y < yMax - self.rect().height():
+ value.setY(yMax - self.rect().height())
+
+ if self.pos() != value:
+ # Notify change through signal
+ views = self.scene().views()
+ assert len(views) == 1
+ views[0].visibleRectDragged.emit(
+ value.x() + self.rect().left(),
+ value.y() + self.rect().top(),
+ self.rect().width(),
+ self.rect().height(),
+ )
+
+ return value
+
+ return super(_DraggableRectItem, self).itemChange(change, value)
+
+ def hoverEnterEvent(self, event):
+ """Called when the mouse enters the rectangle area"""
+ self._previousCursor = self.cursor()
+ self.setCursor(qt.Qt.OpenHandCursor)
+
+ def hoverLeaveEvent(self, event):
+ """Called when the mouse leaves the rectangle area"""
+ if self._previousCursor is not None:
+ self.setCursor(self._previousCursor)
+ self._previousCursor = None
+
+
+class RadarView(qt.QGraphicsView):
+ """Widget presenting a synthetic view of a 2D area and
+ the current visible area.
+
+ Coordinates are as in QGraphicsView:
+ x goes from left to right and y goes from top to bottom.
+ This widget preserves the aspect ratio of the areas.
+
+ The 2D area and the visible area can be set with :meth:`setDataRect`
+ and :meth:`setVisibleRect`.
+ When the visible area has been dragged by the user, its new position
+ is signaled by the *visibleRectDragged* signal.
+
+ It is possible to invert the direction of the axes by using the
+ :meth:`scale` method of QGraphicsView.
+ """
+
+ visibleRectDragged = qt.Signal(float, float, float, float)
+ """Signals that the visible rectangle has been dragged.
+
+ It provides: left, top, width, height in data coordinates.
+ """
+
+ _DATA_PEN = qt.QPen(qt.QColor("white"))
+ _DATA_BRUSH = qt.QBrush(qt.QColor("light gray"))
+ _ACTIVEDATA_PEN = qt.QPen(qt.QColor("black"))
+ _ACTIVEDATA_BRUSH = qt.QBrush(qt.QColor("transparent"))
+ _ACTIVEDATA_PEN.setWidth(2)
+ _ACTIVEDATA_PEN.setCosmetic(True)
+ _VISIBLE_PEN = qt.QPen(qt.QColor("blue"))
+ _VISIBLE_PEN.setWidth(2)
+ _VISIBLE_PEN.setCosmetic(True)
+ _VISIBLE_BRUSH = qt.QBrush(qt.QColor(0, 0, 0, 0))
+ _TOOLTIP = "Radar View:\nRed contour: Visible area\nGray area: The image"
+
+ _PIXMAP_SIZE = 256
+
+ def __init__(self, parent=None):
+ self.__plotRef = None
+ self._scene = qt.QGraphicsScene()
+ self._dataRect = self._scene.addRect(
+ 0, 0, 1, 1, self._DATA_PEN, self._DATA_BRUSH
+ )
+ self._imageRect = self._scene.addRect(
+ 0, 0, 1, 1, self._ACTIVEDATA_PEN, self._ACTIVEDATA_BRUSH
+ )
+ self._imageRect.setVisible(False)
+ self._scatterRect = self._scene.addRect(
+ 0, 0, 1, 1, self._ACTIVEDATA_PEN, self._ACTIVEDATA_BRUSH
+ )
+ self._scatterRect.setVisible(False)
+ self._curveRect = self._scene.addRect(
+ 0, 0, 1, 1, self._ACTIVEDATA_PEN, self._ACTIVEDATA_BRUSH
+ )
+ self._curveRect.setVisible(False)
+
+ self._visibleRect = _DraggableRectItem(0, 0, 1, 1)
+ self._visibleRect.setPen(self._VISIBLE_PEN)
+ self._visibleRect.setBrush(self._VISIBLE_BRUSH)
+ self._scene.addItem(self._visibleRect)
+
+ super(RadarView, self).__init__(self._scene, parent)
+ self.setHorizontalScrollBarPolicy(qt.Qt.ScrollBarAlwaysOff)
+ self.setVerticalScrollBarPolicy(qt.Qt.ScrollBarAlwaysOff)
+ self.setFocusPolicy(qt.Qt.NoFocus)
+ self.setStyleSheet("border: 0px")
+ self.setToolTip(self._TOOLTIP)
+
+ self.__reentrant = LockReentrant()
+ self.visibleRectDragged.connect(self._viewRectDragged)
+
+ self.__timer = qt.QTimer(self)
+ self.__timer.timeout.connect(self._updateDataContent)
+
+ def sizeHint(self):
+ # """Overridden to avoid sizeHint to depend on content size."""
+ return self.minimumSizeHint()
+
+ def wheelEvent(self, event):
+ # """Overridden to disable vertical scrolling with wheel."""
+ event.ignore()
+
+ def resizeEvent(self, event):
+ # """Overridden to fit current content to new size."""
+ self.fitInView(self._scene.itemsBoundingRect(), qt.Qt.KeepAspectRatio)
+ super(RadarView, self).resizeEvent(event)
+
+ def setDataRect(self, left, top, width, height):
+ """Set the bounds of the data rectangular area.
+
+ This sets the coordinate system.
+ """
+ self._dataRect.setRect(left, top, width, height)
+ self._visibleRect.setConstraintRect(left, top, width, height)
+ self.fitInView(self._scene.itemsBoundingRect(), qt.Qt.KeepAspectRatio)
+
+ def setVisibleRect(self, left, top, width, height):
+ """Set the visible rectangular area.
+
+ The coordinates are relative to the data rect.
+ """
+ self.__visibleRect = left, top, width, height
+ self._visibleRect.setRect(0, 0, width, height)
+ self._visibleRect.setPos(left, top)
+ self.fitInView(self._scene.itemsBoundingRect(), qt.Qt.KeepAspectRatio)
+
+ def __setVisibleRectFromPlot(self, plot):
+ """Update radar view visible area.
+
+ Takes care of y coordinate conversion.
+ """
+ xMin, xMax = plot.getXAxis().getLimits()
+ yMin, yMax = plot.getYAxis().getLimits()
+ self.setVisibleRect(xMin, yMin, xMax - xMin, yMax - yMin)
+
+ def getPlotWidget(self):
+ """Returns the connected plot
+
+ :rtype: Union[None,PlotWidget]
+ """
+ if self.__plotRef is None:
+ return None
+ plot = self.__plotRef()
+ if plot is None:
+ self.__plotRef = None
+ return plot
+
+ def setPlotWidget(self, plot):
+ """Set the PlotWidget this radar view connects to.
+
+ As result `setDataRect` and `setVisibleRect` will be called
+ automatically.
+
+ :param Union[None,PlotWidget] plot:
+ """
+ previousPlot = self.getPlotWidget()
+ if previousPlot is not None: # Disconnect previous plot
+ plot.getXAxis().sigLimitsChanged.disconnect(self._xLimitChanged)
+ plot.getYAxis().sigLimitsChanged.disconnect(self._yLimitChanged)
+ plot.getYAxis().sigInvertedChanged.disconnect(self._updateYAxisInverted)
+
+ # Reset plot and timer
+ # FIXME: It would be good to clean up the display here
+ self.__plotRef = None
+ self.__timer.stop()
+
+ if plot is not None: # Connect new plot
+ self.__plotRef = weakref.ref(plot)
+ plot.getXAxis().sigLimitsChanged.connect(self._xLimitChanged)
+ plot.getYAxis().sigLimitsChanged.connect(self._yLimitChanged)
+ plot.getYAxis().sigInvertedChanged.connect(self._updateYAxisInverted)
+ self.__setVisibleRectFromPlot(plot)
+ self._updateYAxisInverted()
+ self.__timer.start(500)
+
+ def _xLimitChanged(self, vmin, vmax):
+ plot = self.getPlotWidget()
+ self.__setVisibleRectFromPlot(plot)
+
+ def _yLimitChanged(self, vmin, vmax):
+ plot = self.getPlotWidget()
+ self.__setVisibleRectFromPlot(plot)
+
+ def _updateYAxisInverted(self, inverted=None):
+ """Sync radar view axis orientation."""
+ plot = self.getPlotWidget()
+ if inverted is None:
+ # Do not perform this when called from plot signal
+ inverted = plot.getYAxis().isInverted()
+ # Use scale to invert radarView
+ # RadarView default Y direction is from top to bottom
+ # As opposed to Plot. So invert RadarView when Plot is NOT inverted.
+ self.resetTransform()
+ if not inverted:
+ self.scale(1.0, -1.0)
+ self.update()
+
+ def _viewRectDragged(self, left, top, width, height):
+ """Slot for radar view visible rectangle changes."""
+ plot = self.getPlotWidget()
+ if plot is None:
+ return
+
+ if self.__reentrant.locked():
+ return
+
+ with self.__reentrant:
+ plot.setLimits(left, left + width, top, top + height)
+
+ def _updateDataContent(self):
+ """Update the content to the current data content"""
+ plot = self.getPlotWidget()
+ if plot is None:
+ return
+ ranges = plot.getDataRange()
+ xmin, xmax = ranges.x if ranges.x is not None else (0, 0)
+ ymin, ymax = ranges.y if ranges.y is not None else (0, 0)
+ self.setDataRect(xmin, ymin, xmax - xmin, ymax - ymin)
+
+ self.__updateItem(self._imageRect, plot.getActiveImage())
+ self.__updateItem(self._scatterRect, plot.getActiveScatter())
+ self.__updateItem(self._curveRect, plot.getActiveCurve())
+
+ def __updateItem(self, rect, item):
+ """Sync rect with item bounds
+
+ :param QGraphicsRectItem rect:
+ :param Item item:
+ """
+ if item is None:
+ rect.setVisible(False)
+ return
+ ranges = item._getBounds()
+ if ranges is None:
+ rect.setVisible(False)
+ return
+ xmin, xmax, ymin, ymax = ranges
+ width = xmax - xmin
+ height = ymax - ymin
+ rect.setRect(xmin, ymin, width, height)
+ rect.setVisible(True)
diff --git a/src/silx/gui/plot/tools/RulerToolButton.py b/src/silx/gui/plot/tools/RulerToolButton.py
new file mode 100644
index 0000000..55cc02f
--- /dev/null
+++ b/src/silx/gui/plot/tools/RulerToolButton.py
@@ -0,0 +1,183 @@
+# /*##########################################################################
+#
+# Copyright (c) 20023 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""
+PlotToolButton to measure a distance in a plot
+"""
+
+__authors__ = ["H. Payno"]
+__license__ = "MIT"
+__date__ = "30/10/2023"
+
+
+import logging
+import numpy
+import weakref
+import typing
+
+from silx.gui import icons
+
+from .PlotToolButton import PlotToolButton
+
+from silx.gui.plot.tools.roi import RegionOfInterestManager
+from silx.gui.plot.items.roi import LineROI
+from silx.gui.plot import items
+
+
+_logger = logging.getLogger(__name__)
+
+
+class _RulerROI(LineROI):
+ def __init__(self, parent=None):
+ super().__init__(parent)
+ self._formatFunction: typing.Optional[
+ typing.Callable[
+ [numpy.ndarray, numpy.ndarray], str
+ ]
+ ] = None
+ self.setColor("#001122") # Only there to trig updateStyle
+
+ def registerFormatFunction(
+ self,
+ fct: typing.Callable[
+ [numpy.ndarray, numpy.ndarray], str
+ ],
+ ):
+ """Register a function for the formatting of the label"""
+ self._formatFunction = fct
+
+ def _updatedStyle(self, event, style: items.CurveStyle):
+ style = items.CurveStyle(
+ color="red",
+ gapcolor="white",
+ linestyle=(0, (5, 5)),
+ linewidth=style.getLineWidth())
+ LineROI._updatedStyle(self, event, style)
+ self._handleLabel.setColor("black")
+ self._handleLabel.setBackgroundColor("#FFFFFF60")
+ self._handleLabel.setZValue(1000)
+
+ def setEndPoints(self, startPoint: numpy.ndarray, endPoint: numpy.ndarray):
+ super().setEndPoints(startPoint=startPoint, endPoint=endPoint)
+ if self._formatFunction is not None:
+ ruler_text = self._formatFunction(
+ startPoint=startPoint, endPoint=endPoint
+ )
+ self._updateText(ruler_text)
+
+
+class RulerToolButton(PlotToolButton):
+ """
+ Button to active measurement between two point of the plot
+
+ An instance of `RulerToolButton` can be added to a plot toolbar like:
+ .. code-block:: python
+
+ plot = Plot2D()
+
+ rulerButton = RulerToolButton(parent=plot, plot=plot)
+ plot.toolBar().addWidget(rulerButton)
+ """
+
+ def __init__(
+ self,
+ parent=None,
+ plot=None,
+ ):
+ super().__init__(parent=parent, plot=plot)
+ self.setCheckable(True)
+ self._roiManager = None
+ self.__lastRoiCreated = None
+ self.setIcon(icons.getQIcon("ruler"))
+ self.toggled.connect(self._callback)
+ self._connectPlot(plot)
+
+ def setPlot(self, plot):
+ return super().setPlot(plot)
+
+ @property
+ def _lastRoiCreated(self):
+ if self.__lastRoiCreated is None:
+ return None
+ return self.__lastRoiCreated()
+
+ def _callback(self, *args, **kwargs):
+ if not self._roiManager:
+ return
+ if self._lastRoiCreated is not None:
+ self._lastRoiCreated.setVisible(self.isChecked())
+ if self.isChecked():
+ self._roiManager.start(_RulerROI, self)
+ self.__interactiveModeStarted(self._roiManager)
+ else:
+ source = self._roiManager.getInteractionSource()
+ if source is self:
+ self._roiManager.stop()
+
+ def __interactiveModeStarted(self, roiManager):
+ roiManager.sigInteractiveModeFinished.connect(self.__interactiveModeFinished)
+
+ def __interactiveModeFinished(self):
+ roiManager = self._roiManager
+ if roiManager is not None:
+ roiManager.sigInteractiveModeFinished.disconnect(
+ self.__interactiveModeFinished
+ )
+ self.setChecked(False)
+
+ def _connectPlot(self, plot):
+ """
+ Called when the plot is connected to the widget
+
+ :param plot: :class:`.PlotWidget` instance
+ """
+ if plot is None:
+ return
+ self._roiManager = RegionOfInterestManager(plot)
+ self._roiManager.sigRoiAdded.connect(self._registerCurrentROI)
+
+ def _disconnectPlot(self, plot):
+ if plot and self._lastRoiCreated is not None:
+ self._roiManager.removeRoi(self._lastRoiCreated)
+ self.__lastRoiCreated = None
+ return super()._disconnectPlot(plot)
+
+ def _registerCurrentROI(self, currentRoi):
+ if self._lastRoiCreated is None:
+ self.__lastRoiCreated = weakref.ref(currentRoi)
+ self._lastRoiCreated.registerFormatFunction(self.buildDistanceText)
+ elif currentRoi is not self._lastRoiCreated and self._roiManager is not None:
+ self._roiManager.removeRoi(self._lastRoiCreated)
+ currentRoi.registerFormatFunction(self.buildDistanceText)
+ self.__lastRoiCreated = weakref.ref(currentRoi)
+
+ def buildDistanceText(self, startPoint: numpy.ndarray, endPoint: numpy.ndarray) -> str:
+ """
+ Define the text to be displayed by the ruler.
+
+ It can be redefine to modify precision or handle other parameters
+ (handling pixel size to display metric distance, display distance
+ on each distance - for non-square pixels...)
+ """
+ distance = numpy.linalg.norm(endPoint - startPoint)
+ return f"{distance: .1f}px"
diff --git a/src/silx/gui/plot/tools/__init__.py b/src/silx/gui/plot/tools/__init__.py
new file mode 100644
index 0000000..5b6b74c
--- /dev/null
+++ b/src/silx/gui/plot/tools/__init__.py
@@ -0,0 +1,49 @@
+# /*##########################################################################
+#
+# Copyright (c) 2018 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""This package provides a set of widgets working with :class:`PlotWidget`.
+
+It provides some QToolBar and QWidget:
+
+- :class:`InteractiveModeToolBar`
+- :class:`OutputToolBar`
+- :class:`ImageToolBar`
+- :class:`CurveToolBar`
+- :class:`LimitsToolBar`
+- :class:`PositionInfo`
+
+It also provides a :mod:`~silx.gui.plot.tools.roi` module to handle
+interactive region of interest on a :class:`~silx.gui.plot.PlotWidget`.
+"""
+
+__authors__ = ["T. Vincent"]
+__license__ = "MIT"
+__date__ = "01/03/2018"
+
+
+from .toolbars import InteractiveModeToolBar # noqa
+from .toolbars import OutputToolBar # noqa
+from .toolbars import ImageToolBar, CurveToolBar, ScatterToolBar # noqa
+
+from .LimitsToolBar import LimitsToolBar # noqa
+from .PositionInfo import PositionInfo # noqa
diff --git a/src/silx/gui/plot/tools/compare/__init__.py b/src/silx/gui/plot/tools/compare/__init__.py
new file mode 100644
index 0000000..7f23852
--- /dev/null
+++ b/src/silx/gui/plot/tools/compare/__init__.py
@@ -0,0 +1,29 @@
+# /*##########################################################################
+#
+# Copyright (c) 2018 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""This module provides tools related to the compare image plot.
+"""
+
+__authors__ = ["V. Valls"]
+__license__ = "MIT"
+__date__ = "09/06/2023"
diff --git a/src/silx/gui/plot/tools/compare/core.py b/src/silx/gui/plot/tools/compare/core.py
new file mode 100644
index 0000000..90dbb79
--- /dev/null
+++ b/src/silx/gui/plot/tools/compare/core.py
@@ -0,0 +1,198 @@
+# /*##########################################################################
+#
+# Copyright (c) 2018-2019 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""This module provides main objects shared by the compare image plot.
+"""
+
+__authors__ = ["V. Valls"]
+__license__ = "MIT"
+__date__ = "09/06/2023"
+
+
+import numpy
+import enum
+import contextlib
+from typing import NamedTuple
+
+from silx.gui.plot.items.image import ImageBase
+from silx.gui.plot.items.core import ItemChangedType, ColormapMixIn
+
+from silx.opencl import ocl
+
+if ocl is not None:
+ try:
+ from silx.opencl import sift
+ except ImportError:
+ # sift module is not available (e.g., in official Debian packages)
+ sift = None
+else: # No OpenCL device or no pyopencl
+ sift = None
+
+
+@enum.unique
+class VisualizationMode(enum.Enum):
+ """Enum for each visualization mode available."""
+
+ ONLY_A = "a"
+ ONLY_B = "b"
+ VERTICAL_LINE = "vline"
+ HORIZONTAL_LINE = "hline"
+ COMPOSITE_RED_BLUE_GRAY = "rbgchannel"
+ COMPOSITE_RED_BLUE_GRAY_NEG = "rbgnegchannel"
+ COMPOSITE_A_MINUS_B = "aminusb"
+
+
+@enum.unique
+class AlignmentMode(enum.Enum):
+ """Enum for each alignment mode available."""
+
+ ORIGIN = "origin"
+ CENTER = "center"
+ STRETCH = "stretch"
+ AUTO = "auto"
+
+
+class AffineTransformation(NamedTuple):
+ """Description of a 2D affine transformation: translation, scale and
+ rotation.
+ """
+
+ tx: float
+ ty: float
+ sx: float
+ sy: float
+ rot: float
+
+
+class _CompareImageItem(ImageBase, ColormapMixIn):
+ """Description of a virtual item of images to compare, in order to share
+ the data through the silx components.
+ """
+
+ def __init__(self):
+ ImageBase.__init__(self)
+ ColormapMixIn.__init__(self)
+ self.__image1 = None
+ self.__image2 = None
+ self.__vizualisationMode = VisualizationMode.ONLY_A
+
+ def getImageData1(self):
+ return self.__image1
+
+ def getImageData2(self):
+ return self.__image2
+
+ def setImageData1(self, image1):
+ if self.__image1 is image1:
+ return
+ self.__image1 = image1
+ self._updated(ItemChangedType.DATA)
+
+ def setImageData2(self, image2):
+ if self.__image2 is image2:
+ return
+ self.__image2 = image2
+ self._updated(ItemChangedType.DATA)
+
+ def getVizualisationMode(self) -> VisualizationMode:
+ return self.__vizualisationMode
+
+ @contextlib.contextmanager
+ def _updateColormapRange(self, previousMode, mode):
+ """COMPOSITE_A_MINUS_B don't have the same data range than others.
+
+ If the colormap is using a fixed range, it is updated in order to set
+ a similar range with the new data.
+ """
+ normalize_colormap = (
+ previousMode == VisualizationMode.COMPOSITE_A_MINUS_B
+ or mode == VisualizationMode.COMPOSITE_A_MINUS_B
+ )
+ if normalize_colormap:
+ data = self._getConcatenatedData(copy=False)
+ if data is None or data.size == 0:
+ normalize_colormap = False
+ else:
+ std1 = numpy.nanstd(data)
+ mean1 = numpy.nanmean(data)
+ yield
+
+ def transfer(v, std1, mean1, std2, mean2):
+ """Transfer a value from a data range to another using statistics"""
+ if v is None:
+ return None
+ rv = (v - mean1) / std1
+ return rv * std2 + mean2
+
+ if normalize_colormap:
+ data = self._getConcatenatedData(copy=False)
+ if data is not None and data.size != 0:
+ std2 = numpy.nanstd(data)
+ mean2 = numpy.nanmean(data)
+ c = self.getColormap()
+ if c is not None:
+ vmin, vmax = c.getVRange()
+ vmin = transfer(vmin, std1, mean1, std2, mean2)
+ vmax = transfer(vmax, std1, mean1, std2, mean2)
+ c.setVRange(vmin, vmax)
+
+ def setVizualisationMode(self, mode: VisualizationMode):
+ if self.__vizualisationMode == mode:
+ return None
+ with self._updateColormapRange(self.__vizualisationMode, mode):
+ self.__vizualisationMode = mode
+ self._updated(ItemChangedType.DATA)
+
+ def _getConcatenatedData(self, copy=True):
+ if self.__image1 is None and self.__image2 is None:
+ return None
+ if self.__image1 is None:
+ return numpy.array(self.__image2, copy=copy)
+ if self.__image2 is None:
+ return numpy.array(self.__image1, copy=copy)
+
+ if self.__vizualisationMode == VisualizationMode.COMPOSITE_A_MINUS_B:
+ # In this case the histogram have to be special
+ if self.__image1.shape == self.__image2.shape:
+ return self.__image1.astype(numpy.float32) - self.__image2.astype(
+ numpy.float32
+ )
+ else:
+ d1 = self.__image1[numpy.isfinite(self.__image1)]
+ d2 = self.__image2[numpy.isfinite(self.__image2)]
+ return numpy.concatenate((d1, d2))
+
+ def _updated(self, event=None, checkVisibility=True):
+ # Synchronizes colormapped data if changed
+ if event in (ItemChangedType.DATA, ItemChangedType.MASK):
+ data = self._getConcatenatedData(copy=False)
+ return self._setColormappedData(data, copy=False)
+ super()._updated(event=event, checkVisibility=checkVisibility)
+
+ def getColormappedData(self, copy=True):
+ """
+ Reimplementation of the `ColormapMixIn.getColormappedData` method.
+
+ This is used to provide a consistent auto scale on the compared images.
+ """
+ return self._getConcatenatedData(copy=copy)
diff --git a/src/silx/gui/plot/tools/compare/profile.py b/src/silx/gui/plot/tools/compare/profile.py
new file mode 100644
index 0000000..afe0eba
--- /dev/null
+++ b/src/silx/gui/plot/tools/compare/profile.py
@@ -0,0 +1,173 @@
+# /*##########################################################################
+#
+# Copyright (c) 2018-2019 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""This provides profile ROIs.
+"""
+
+__authors__ = ["V. Valls"]
+__license__ = "MIT"
+__date__ = "09/06/2023"
+
+
+import numpy
+
+from silx.gui.plot.tools.profile import rois
+from silx.gui.plot.tools.profile import core
+from .core import _CompareImageItem
+
+
+COLOR_A = "C0"
+COLOR_B = "C8"
+
+
+class ProfileImageLineROI(rois.ProfileImageLineROI):
+ """ROI for a compare image profile between 2 points.
+
+ The X profile of this ROI is the projection into one of the x/y axes,
+ using its scale and its orientation.
+ """
+
+ def computeProfile(self, item):
+ if not isinstance(item, _CompareImageItem):
+ raise TypeError("Unexpected class %s" % type(item))
+
+ origin = item.getOrigin()
+ scale = item.getScale()
+ method = self.getProfileMethod()
+ lineWidth = self.getProfileLineWidth()
+ roiInfo = self._getRoiInfo()
+
+ def createProfile2(currentData):
+ coords, profile, _area, profileName, xLabel = core.createProfile(
+ roiInfo=roiInfo,
+ currentData=currentData,
+ origin=origin,
+ scale=scale,
+ lineWidth=lineWidth,
+ method=method,
+ )
+ return coords, profile, profileName, xLabel
+
+ currentData1 = item.getImageData1()
+ currentData2 = item.getImageData2()
+
+ yLabel = "%s" % str(method).capitalize()
+ coords, profile1, title, xLabel = createProfile2(currentData1)
+ title = title + "; width = %d" % lineWidth
+ _coords, profile2, _title, _xLabel = createProfile2(currentData2)
+
+ profile1.shape = -1
+ profile2.shape = -1
+
+ title = title.format(xlabel="width", ylabel="height")
+ xLabel = xLabel.format(xlabel="width", ylabel="height")
+ yLabel = yLabel.format(xlabel="width", ylabel="height")
+
+ data = core.CurvesProfileData(
+ coords=coords,
+ profiles=[
+ core.CurveProfileDesc(profile1, color=COLOR_A, name="profileA"),
+ core.CurveProfileDesc(profile2, color=COLOR_B, name="profileB"),
+ ],
+ title=title,
+ xLabel=xLabel,
+ yLabel=yLabel,
+ )
+ return data
+
+
+class ProfileImageDirectedLineROI(rois.ProfileImageDirectedLineROI):
+ """ROI for a compare image profile between 2 points.
+
+ The X profile of the line is displayed projected into the line itself,
+ using its scale and its orientation. It's the distance from the origin.
+ """
+
+ def computeProfile(self, item):
+ if not isinstance(item, _CompareImageItem):
+ raise TypeError("Unexpected class %s" % type(item))
+
+ from silx.image.bilinear import BilinearImage
+
+ origin = item.getOrigin()
+ scale = item.getScale()
+ method = self.getProfileMethod()
+ lineWidth = self.getProfileLineWidth()
+
+ roiInfo = self._getRoiInfo()
+ roiStart, roiEnd, _lineProjectionMode = roiInfo
+
+ startPt = (
+ (roiStart[1] - origin[1]) / scale[1],
+ (roiStart[0] - origin[0]) / scale[0],
+ )
+ endPt = ((roiEnd[1] - origin[1]) / scale[1], (roiEnd[0] - origin[0]) / scale[0])
+
+ if numpy.array_equal(startPt, endPt):
+ return None
+
+ def computeProfile(data):
+ bilinear = BilinearImage(data)
+ profile = bilinear.profile_line(
+ (startPt[0] - 0.5, startPt[1] - 0.5),
+ (endPt[0] - 0.5, endPt[1] - 0.5),
+ lineWidth,
+ method=method,
+ )
+ return profile
+
+ currentData1 = item.getImageData1()
+ currentData2 = item.getImageData2()
+ profile1 = computeProfile(currentData1)
+ profile2 = computeProfile(currentData2)
+
+ # Compute the line size
+ lineSize = numpy.sqrt(
+ (roiEnd[1] - roiStart[1]) ** 2 + (roiEnd[0] - roiStart[0]) ** 2
+ )
+ coords = numpy.linspace(
+ 0, lineSize, len(profile1), endpoint=True, dtype=numpy.float32
+ )
+
+ title = rois._lineProfileTitle(*roiStart, *roiEnd)
+ title = title + "; width = %d" % lineWidth
+ xLabel = "√({xlabel}²+{ylabel}²)"
+ yLabel = str(method).capitalize()
+
+ # Use the axis names from the original plot
+ profileManager = self.getProfileManager()
+ plot = profileManager.getPlotWidget()
+ xLabel = rois._relabelAxes(plot, xLabel)
+ title = rois._relabelAxes(plot, title)
+
+ data = core.CurvesProfileData(
+ coords=coords,
+ profiles=[
+ core.CurveProfileDesc(profile1, color=COLOR_A, name="profileA"),
+ core.CurveProfileDesc(profile2, color=COLOR_B, name="profileB"),
+ ],
+ title=title,
+ xLabel=xLabel,
+ yLabel=yLabel,
+ )
+ return data
diff --git a/src/silx/gui/plot/tools/compare/statusbar.py b/src/silx/gui/plot/tools/compare/statusbar.py
new file mode 100644
index 0000000..5e43a37
--- /dev/null
+++ b/src/silx/gui/plot/tools/compare/statusbar.py
@@ -0,0 +1,218 @@
+# /*##########################################################################
+#
+# Copyright (c) 2018-2019 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""This module provides tool bar helper.
+"""
+
+__authors__ = ["V. Valls"]
+__license__ = "MIT"
+__date__ = "09/06/2023"
+
+
+import logging
+import weakref
+import numpy
+
+from silx.gui import qt
+
+
+_logger = logging.getLogger(__name__)
+
+
+class CompareImagesStatusBar(qt.QStatusBar):
+ """StatusBar containing specific information contained in a
+ :class:`CompareImages` widget
+
+ Use :meth:`setCompareWidget` to connect this toolbar to a specific
+ :class:`CompareImages` widget.
+
+ :param Union[qt.QWidget,None] parent: Parent of this widget.
+ """
+
+ def __init__(self, parent=None):
+ qt.QStatusBar.__init__(self, parent)
+ self.setSizeGripEnabled(False)
+ self.layout().setSpacing(0)
+ self.__compareWidget = None
+ self._label1 = qt.QLabel(self)
+ self._label1.setFrameShape(qt.QFrame.WinPanel)
+ self._label1.setFrameShadow(qt.QFrame.Sunken)
+ self._label2 = qt.QLabel(self)
+ self._label2.setFrameShape(qt.QFrame.WinPanel)
+ self._label2.setFrameShadow(qt.QFrame.Sunken)
+ self._transform = qt.QLabel(self)
+ self._transform.setFrameShape(qt.QFrame.WinPanel)
+ self._transform.setFrameShadow(qt.QFrame.Sunken)
+ self.addWidget(self._label1)
+ self.addWidget(self._label2)
+ self.addWidget(self._transform)
+ self._pos = None
+ self._updateStatusBar()
+
+ def setCompareWidget(self, widget):
+ """
+ Connect this tool bar to a specific :class:`CompareImages` widget.
+
+ :param Union[None,CompareImages] widget: The widget to connect with.
+ """
+ compareWidget = self.getCompareWidget()
+ if compareWidget is not None:
+ compareWidget.getPlot().sigPlotSignal.disconnect(self.__plotSignalReceived)
+ compareWidget.sigConfigurationChanged.disconnect(self.__dataChanged)
+ compareWidget = widget
+ if compareWidget is None:
+ self.__compareWidget = None
+ else:
+ self.__compareWidget = weakref.ref(compareWidget)
+ if compareWidget is not None:
+ compareWidget.getPlot().sigPlotSignal.connect(self.__plotSignalReceived)
+ compareWidget.sigConfigurationChanged.connect(self.__dataChanged)
+
+ def getCompareWidget(self):
+ """Returns the connected widget.
+
+ :rtype: CompareImages
+ """
+ if self.__compareWidget is None:
+ return None
+ else:
+ return self.__compareWidget()
+
+ def __plotSignalReceived(self, event):
+ """Called when old style signals at emmited from the plot."""
+ if event["event"] == "mouseMoved":
+ x, y = event["x"], event["y"]
+ self.__mouseMoved(x, y)
+
+ def __mouseMoved(self, x, y):
+ """Called when mouse move over the plot."""
+ self._pos = x, y
+ self._updateStatusBar()
+
+ def __dataChanged(self):
+ """Called when internal data from the connected widget changes."""
+ self._updateStatusBar()
+
+ def _formatData(self, data):
+ """Format pixel of an image.
+
+ It supports intensity, RGB, and RGBA.
+
+ :param Union[int,float,numpy.ndarray,str]: Value of a pixel
+ :rtype: str
+ """
+ if data is None:
+ return "No data"
+ if isinstance(data, (int, numpy.integer)):
+ return "%d" % data
+ if isinstance(data, (float, numpy.floating)):
+ return "%f" % data
+ if isinstance(data, numpy.ndarray):
+ # RGBA value
+ if data.shape == (3,):
+ return "R:%d G:%d B:%d" % (data[0], data[1], data[2])
+ elif data.shape == (4,):
+ return "R:%d G:%d B:%d A:%d" % (data[0], data[1], data[2], data[3])
+ _logger.debug("Unsupported data format %s. Cast it to string.", type(data))
+ return str(data)
+
+ def _updateStatusBar(self):
+ """Update the content of the status bar"""
+ widget = self.getCompareWidget()
+ if widget is None:
+ self._label1.setText("ImageA: NA")
+ self._label2.setText("ImageB: NA")
+ self._transform.setVisible(False)
+ else:
+ transform = widget.getTransformation()
+ self._transform.setVisible(transform is not None)
+ if transform is not None:
+ has_notable_translation = not numpy.isclose(
+ transform.tx, 0.0, atol=0.01
+ ) or not numpy.isclose(transform.ty, 0.0, atol=0.01)
+ has_notable_scale = not numpy.isclose(
+ transform.sx, 1.0, atol=0.01
+ ) or not numpy.isclose(transform.sy, 1.0, atol=0.01)
+ has_notable_rotation = not numpy.isclose(transform.rot, 0.0, atol=0.01)
+
+ strings = []
+ if has_notable_translation:
+ strings.append("Translation")
+ if has_notable_scale:
+ strings.append("Scale")
+ if has_notable_rotation:
+ strings.append("Rotation")
+ if strings == []:
+ has_translation = not numpy.isclose(
+ transform.tx, 0.0
+ ) or not numpy.isclose(transform.ty, 0.0)
+ has_scale = not numpy.isclose(
+ transform.sx, 1.0
+ ) or not numpy.isclose(transform.sy, 1.0)
+ has_rotation = not numpy.isclose(transform.rot, 0.0)
+ if has_translation or has_scale or has_rotation:
+ text = "No big changes"
+ else:
+ text = "No changes"
+ else:
+ text = "+".join(strings)
+ self._transform.setText("Align: " + text)
+
+ strings = []
+ if not numpy.isclose(transform.ty, 0.0):
+ strings.append("Translation x: %0.3fpx" % transform.tx)
+ if not numpy.isclose(transform.ty, 0.0):
+ strings.append("Translation y: %0.3fpx" % transform.ty)
+ if not numpy.isclose(transform.sx, 1.0):
+ strings.append("Scale x: %0.3f" % transform.sx)
+ if not numpy.isclose(transform.sy, 1.0):
+ strings.append("Scale y: %0.3f" % transform.sy)
+ if not numpy.isclose(transform.rot, 0.0):
+ strings.append(
+ "Rotation: %0.3fdeg" % (transform.rot * 180 / numpy.pi)
+ )
+ if strings == []:
+ text = "No transformation"
+ else:
+ text = "\n".join(strings)
+ self._transform.setToolTip(text)
+
+ if self._pos is None:
+ self._label1.setText("ImageA: NA")
+ self._label2.setText("ImageB: NA")
+ else:
+ data1, data2 = widget.getRawPixelData(self._pos[0], self._pos[1])
+ if isinstance(data1, str):
+ self._label1.setToolTip(data1)
+ text1 = "NA"
+ else:
+ self._label1.setToolTip("")
+ text1 = self._formatData(data1)
+ if isinstance(data2, str):
+ self._label2.setToolTip(data2)
+ text2 = "NA"
+ else:
+ self._label2.setToolTip("")
+ text2 = self._formatData(data2)
+ self._label1.setText("ImageA: %s" % text1)
+ self._label2.setText("ImageB: %s" % text2)
diff --git a/src/silx/gui/plot/tools/compare/toolbar.py b/src/silx/gui/plot/tools/compare/toolbar.py
new file mode 100644
index 0000000..a7f56ec
--- /dev/null
+++ b/src/silx/gui/plot/tools/compare/toolbar.py
@@ -0,0 +1,390 @@
+# /*##########################################################################
+#
+# Copyright (c) 2018-2019 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""This module provides tool bar helper.
+"""
+
+__authors__ = ["V. Valls"]
+__license__ = "MIT"
+__date__ = "28/06/2018"
+
+
+import logging
+import weakref
+from typing import List, Optional
+
+from silx.gui import qt
+from silx.gui import icons
+from .core import AlignmentMode
+from .core import VisualizationMode
+from .core import sift
+
+
+_logger = logging.getLogger(__name__)
+
+
+class AlignmentModeToolButton(qt.QToolButton):
+ """ToolButton to select a AlignmentMode"""
+
+ sigSelected = qt.Signal(AlignmentMode)
+
+ def __init__(self, parent=None):
+ super(AlignmentModeToolButton, self).__init__(parent=parent)
+
+ menu = qt.QMenu(self)
+ self.setMenu(menu)
+
+ self.__group = qt.QActionGroup(self)
+ self.__group.setExclusive(True)
+ self.__group.triggered.connect(self.__selectionChanged)
+
+ icon = icons.getQIcon("compare-align-origin")
+ action = qt.QAction(icon, "Align images on their upper-left pixel", self)
+ action.setProperty("enum", AlignmentMode.ORIGIN)
+ action.setIconVisibleInMenu(True)
+ action.setCheckable(True)
+ self.__originAlignAction = action
+ menu.addAction(action)
+ self.__group.addAction(action)
+
+ icon = icons.getQIcon("compare-align-center")
+ action = qt.QAction(icon, "Center images", self)
+ action.setProperty("enum", AlignmentMode.CENTER)
+ action.setIconVisibleInMenu(True)
+ action.setCheckable(True)
+ self.__centerAlignAction = action
+ menu.addAction(action)
+ self.__group.addAction(action)
+
+ icon = icons.getQIcon("compare-align-stretch")
+ action = qt.QAction(icon, "Stretch the second image on the first one", self)
+ action.setProperty("enum", AlignmentMode.STRETCH)
+ action.setIconVisibleInMenu(True)
+ action.setCheckable(True)
+ self.__stretchAlignAction = action
+ menu.addAction(action)
+ self.__group.addAction(action)
+
+ icon = icons.getQIcon("compare-align-auto")
+ action = qt.QAction(icon, "Auto-alignment of the second image", self)
+ action.setProperty("enum", AlignmentMode.AUTO)
+ action.setIconVisibleInMenu(True)
+ action.setCheckable(True)
+ self.__autoAlignAction = action
+ menu.addAction(action)
+ if sift is None:
+ action.setEnabled(False)
+ action.setToolTip("Sift module is not available")
+ self.__group.addAction(action)
+
+ def getActionFromMode(self, mode: AlignmentMode) -> Optional[qt.QAction]:
+ """Returns an action from it's mode"""
+ for action in self.__group.actions():
+ actionMode = action.property("enum")
+ if mode == actionMode:
+ return action
+ return None
+
+ def setVisibleModes(self, modes: List[AlignmentMode]):
+ """Make visible only a set of modes.
+
+ The order does not matter.
+ """
+ modes = set(modes)
+ for action in self.__group.actions():
+ mode = action.property("enum")
+ action.setVisible(mode in modes)
+
+ def __selectionChanged(self, selectedAction: qt.QAction):
+ """Called when user requesting changes of the alignment mode."""
+ self.__updateMenu()
+ mode = self.getSelected()
+ self.sigSelected.emit(mode)
+
+ def __updateMenu(self):
+ """Update the state of the action containing alignment menu."""
+ selectedAction = self.__group.checkedAction()
+ if selectedAction is not None:
+ self.setText(selectedAction.text())
+ self.setIcon(selectedAction.icon())
+ self.setToolTip(selectedAction.toolTip())
+ else:
+ self.setText("")
+ self.setIcon(qt.QIcon())
+ self.setToolTip("")
+
+ def getSelected(self) -> AlignmentMode:
+ action = self.__group.checkedAction()
+ if action is None:
+ return None
+ return action.property("enum")
+
+ def setSelected(self, mode: AlignmentMode):
+ action = self.getActionFromMode(mode)
+ old = self.__group.blockSignals(True)
+ if action is not None:
+ # Check this action
+ action.setChecked(True)
+ else:
+ action = self.__group.checkedAction()
+ if action is not None:
+ # Uncheck this action
+ action.setChecked(False)
+ self.__updateMenu()
+ self.__group.blockSignals(old)
+
+
+class VisualizationModeToolButton(qt.QToolButton):
+ """ToolButton to select a VisualisationMode"""
+
+ sigSelected = qt.Signal(VisualizationMode)
+
+ def __init__(self, parent=None):
+ super(VisualizationModeToolButton, self).__init__(parent=parent)
+
+ menu = qt.QMenu(self)
+ self.setMenu(menu)
+
+ self.__group = qt.QActionGroup(self)
+ self.__group.setExclusive(True)
+ self.__group.triggered.connect(self.__selectionChanged)
+
+ icon = icons.getQIcon("compare-mode-a")
+ action = qt.QAction(icon, "Display the first image only", self)
+ action.setIconVisibleInMenu(True)
+ action.setCheckable(True)
+ action.setShortcut(qt.QKeySequence(qt.Qt.Key_A))
+ action.setProperty("enum", VisualizationMode.ONLY_A)
+ menu.addAction(action)
+ self.__aModeAction = action
+ self.__group.addAction(action)
+
+ icon = icons.getQIcon("compare-mode-b")
+ action = qt.QAction(icon, "Display the second image only", self)
+ action.setIconVisibleInMenu(True)
+ action.setCheckable(True)
+ action.setShortcut(qt.QKeySequence(qt.Qt.Key_B))
+ action.setProperty("enum", VisualizationMode.ONLY_B)
+ menu.addAction(action)
+ self.__bModeAction = action
+ self.__group.addAction(action)
+
+ icon = icons.getQIcon("compare-mode-vline")
+ action = qt.QAction(icon, "Vertical compare mode", self)
+ action.setIconVisibleInMenu(True)
+ action.setCheckable(True)
+ action.setShortcut(qt.QKeySequence(qt.Qt.Key_V))
+ action.setProperty("enum", VisualizationMode.VERTICAL_LINE)
+ menu.addAction(action)
+ self.__vlineModeAction = action
+ self.__group.addAction(action)
+
+ icon = icons.getQIcon("compare-mode-hline")
+ action = qt.QAction(icon, "Horizontal compare mode", self)
+ action.setIconVisibleInMenu(True)
+ action.setCheckable(True)
+ action.setShortcut(qt.QKeySequence(qt.Qt.Key_H))
+ action.setProperty("enum", VisualizationMode.HORIZONTAL_LINE)
+ menu.addAction(action)
+ self.__hlineModeAction = action
+ self.__group.addAction(action)
+
+ icon = icons.getQIcon("compare-mode-rb-channel")
+ action = qt.QAction(icon, "Blue/red compare mode (additive mode)", self)
+ action.setIconVisibleInMenu(True)
+ action.setCheckable(True)
+ action.setShortcut(qt.QKeySequence(qt.Qt.Key_C))
+ action.setProperty("enum", VisualizationMode.COMPOSITE_RED_BLUE_GRAY)
+ menu.addAction(action)
+ self.__brChannelModeAction = action
+ self.__group.addAction(action)
+
+ icon = icons.getQIcon("compare-mode-rbneg-channel")
+ action = qt.QAction(icon, "Yellow/cyan compare mode (subtractive mode)", self)
+ action.setIconVisibleInMenu(True)
+ action.setCheckable(True)
+ action.setShortcut(qt.QKeySequence(qt.Qt.Key_Y))
+ action.setProperty("enum", VisualizationMode.COMPOSITE_RED_BLUE_GRAY_NEG)
+ menu.addAction(action)
+ self.__ycChannelModeAction = action
+ self.__group.addAction(action)
+
+ icon = icons.getQIcon("compare-mode-a-minus-b")
+ action = qt.QAction(icon, "Raw A minus B compare mode", self)
+ action.setIconVisibleInMenu(True)
+ action.setCheckable(True)
+ action.setShortcut(qt.QKeySequence(qt.Qt.Key_W))
+ action.setProperty("enum", VisualizationMode.COMPOSITE_A_MINUS_B)
+ menu.addAction(action)
+ self.__ycChannelModeAction = action
+ self.__group.addAction(action)
+
+ def getActionFromMode(self, mode: VisualizationMode) -> Optional[qt.QAction]:
+ """Returns an action from it's mode"""
+ for action in self.__group.actions():
+ actionMode = action.property("enum")
+ if mode == actionMode:
+ return action
+ return None
+
+ def setVisibleModes(self, modes: List[VisualizationMode]):
+ """Make visible only a set of modes.
+
+ The order does not matter.
+ """
+ modes = set(modes)
+ for action in self.__group.actions():
+ mode = action.property("enum")
+ action.setVisible(mode in modes)
+
+ def __selectionChanged(self, selectedAction: qt.QAction):
+ """Called when user requesting changes of the visualization mode."""
+ self.__updateMenu()
+ mode = self.getSelected()
+ self.sigSelected.emit(mode)
+
+ def __updateMenu(self):
+ """Update the state of the action containing visualization menu."""
+ selectedAction = self.__group.checkedAction()
+ if selectedAction is not None:
+ self.setText(selectedAction.text())
+ self.setIcon(selectedAction.icon())
+ self.setToolTip(selectedAction.toolTip())
+ else:
+ self.setText("")
+ self.setIcon(qt.QIcon())
+ self.setToolTip("")
+
+ def getSelected(self) -> VisualizationMode:
+ action = self.__group.checkedAction()
+ if action is None:
+ return None
+ return action.property("enum")
+
+ def setSelected(self, mode: VisualizationMode):
+ action = self.getActionFromMode(mode)
+ old = self.__group.blockSignals(True)
+ if action is not None:
+ # Check this action
+ action.setChecked(True)
+ else:
+ action = self.__group.checkedAction()
+ if action is not None:
+ # Uncheck this action
+ action.setChecked(False)
+ self.__updateMenu()
+ self.__group.blockSignals(old)
+
+
+class CompareImagesToolBar(qt.QToolBar):
+ """ToolBar containing specific tools to custom the configuration of a
+ :class:`CompareImages` widget
+
+ Use :meth:`setCompareWidget` to connect this toolbar to a specific
+ :class:`CompareImages` widget.
+
+ :param Union[qt.QWidget,None] parent: Parent of this widget.
+ """
+
+ def __init__(self, parent=None):
+ qt.QToolBar.__init__(self, parent)
+ self.setWindowTitle("Compare images")
+
+ self.__compareWidget = None
+
+ self.__visualizationToolButton = VisualizationModeToolButton(self)
+ self.__visualizationToolButton.setPopupMode(qt.QToolButton.InstantPopup)
+ self.__visualizationToolButton.sigSelected.connect(self.__visualizationChanged)
+ self.addWidget(self.__visualizationToolButton)
+
+ self.__alignmentToolButton = AlignmentModeToolButton(self)
+ self.__alignmentToolButton.setPopupMode(qt.QToolButton.InstantPopup)
+ self.__alignmentToolButton.sigSelected.connect(self.__alignmentChanged)
+ self.addWidget(self.__alignmentToolButton)
+
+ icon = icons.getQIcon("compare-keypoints")
+ action = qt.QAction(icon, "Display/hide alignment keypoints", self)
+ action.setCheckable(True)
+ action.triggered.connect(self.__keypointVisibilityChanged)
+ self.addAction(action)
+ self.__displayKeypoints = action
+
+ def __visualizationChanged(self, mode: VisualizationMode):
+ widget = self.getCompareWidget()
+ if widget is not None:
+ widget.setVisualizationMode(mode)
+
+ def __alignmentChanged(self, mode: AlignmentMode):
+ widget = self.getCompareWidget()
+ if widget is not None:
+ widget.setAlignmentMode(mode)
+
+ def setCompareWidget(self, widget):
+ """
+ Connect this tool bar to a specific :class:`CompareImages` widget.
+
+ :param Union[None,CompareImages] widget: The widget to connect with.
+ """
+ compareWidget = self.getCompareWidget()
+ if compareWidget is not None:
+ compareWidget.sigConfigurationChanged.disconnect(
+ self.__updateSelectedActions
+ )
+ compareWidget = widget
+ self.setEnabled(compareWidget is not None)
+ if compareWidget is None:
+ self.__compareWidget = None
+ else:
+ self.__compareWidget = weakref.ref(compareWidget)
+ if compareWidget is not None:
+ widget.sigConfigurationChanged.connect(self.__updateSelectedActions)
+ self.__updateSelectedActions()
+
+ def getCompareWidget(self):
+ """Returns the connected widget.
+
+ :rtype: CompareImages
+ """
+ if self.__compareWidget is None:
+ return None
+ else:
+ return self.__compareWidget()
+
+ def __updateSelectedActions(self):
+ """
+ Update the state of this tool bar according to the state of the
+ connected :class:`CompareImages` widget.
+ """
+ widget = self.getCompareWidget()
+ if widget is None:
+ return
+ self.__visualizationToolButton.setSelected(widget.getVisualizationMode())
+ self.__alignmentToolButton.setSelected(widget.getAlignmentMode())
+ self.__displayKeypoints.setChecked(widget.getKeypointsVisible())
+
+ def __keypointVisibilityChanged(self):
+ """Called when action managing keypoints visibility changes"""
+ widget = self.getCompareWidget()
+ if widget is not None:
+ keypointsVisible = self.__displayKeypoints.isChecked()
+ widget.setKeypointsVisible(keypointsVisible)
diff --git a/src/silx/gui/plot/tools/menus.py b/src/silx/gui/plot/tools/menus.py
new file mode 100644
index 0000000..c748b6e
--- /dev/null
+++ b/src/silx/gui/plot/tools/menus.py
@@ -0,0 +1,93 @@
+# /*##########################################################################
+#
+# Copyright (c) 2023 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""
+This module provides :class:`PlotWidget`-related QMenu.
+
+The following QMenu is available:
+
+- :class:`ZoomEnabledAxesMenu`
+"""
+
+__authors__ = ["T. Vincent"]
+__license__ = "MIT"
+__date__ = "12/06/2023"
+
+
+import weakref
+from typing import Optional
+
+from silx.gui import qt
+
+from ..PlotWidget import PlotWidget
+
+
+class ZoomEnabledAxesMenu(qt.QMenu):
+ """Menu to toggle axes for zoom interaction"""
+
+ def __init__(self, plot: PlotWidget, parent: Optional[qt.QWidget] = None):
+ super().__init__(parent)
+ self.setTitle("Zoom axes")
+
+ assert isinstance(plot, PlotWidget)
+ self.__plotRef = weakref.ref(plot)
+
+ self.addSection("Enabled axes")
+ self.__xAxisAction = qt.QAction("X axis", parent=self)
+ self.__yAxisAction = qt.QAction("Y left axis", parent=self)
+ self.__y2AxisAction = qt.QAction("Y right axis", parent=self)
+
+ for action in (self.__xAxisAction, self.__yAxisAction, self.__y2AxisAction):
+ action.setCheckable(True)
+ action.setChecked(True)
+ action.triggered.connect(self._axesActionTriggered)
+ self.addAction(action)
+
+ # Listen to interaction configuration change
+ plot.interaction().sigChanged.connect(self._interactionChanged)
+ # Init the state
+ self._interactionChanged()
+
+ def getPlotWidget(self) -> Optional[PlotWidget]:
+ return self.__plotRef()
+
+ def _axesActionTriggered(self, checked=False):
+ plot = self.getPlotWidget()
+ if plot is None:
+ return
+
+ plot.interaction().setZoomEnabledAxes(
+ self.__xAxisAction.isChecked(),
+ self.__yAxisAction.isChecked(),
+ self.__y2AxisAction.isChecked(),
+ )
+
+ def _interactionChanged(self):
+ plot = self.getPlotWidget()
+ if plot is None:
+ return
+
+ enabledAxes = plot.interaction().getZoomEnabledAxes()
+ self.__xAxisAction.setChecked(enabledAxes.xaxis)
+ self.__yAxisAction.setChecked(enabledAxes.yaxis)
+ self.__y2AxisAction.setChecked(enabledAxes.y2axis)
diff --git a/src/silx/gui/plot/tools/profile/ScatterProfileToolBar.py b/src/silx/gui/plot/tools/profile/ScatterProfileToolBar.py
new file mode 100644
index 0000000..271adb8
--- /dev/null
+++ b/src/silx/gui/plot/tools/profile/ScatterProfileToolBar.py
@@ -0,0 +1,44 @@
+# /*##########################################################################
+#
+# Copyright (c) 2018-2023 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""This module profile tools for scatter plots.
+"""
+
+__authors__ = ["T. Vincent"]
+__license__ = "MIT"
+__date__ = "28/06/2018"
+
+
+from . import toolbar
+
+
+class ScatterProfileToolBar(toolbar.ProfileToolBar):
+ """QToolBar providing scatter plot profiling tools
+
+ :param parent: See :class:`QToolBar`.
+ :param plot: :class:`~silx.gui.plot.PlotWidget` on which to operate.
+ """
+
+ def __init__(self, parent=None, plot=None):
+ super(ScatterProfileToolBar, self).__init__(parent, plot)
+ self.setScheme("scatter")
diff --git a/src/silx/gui/plot/tools/profile/__init__.py b/src/silx/gui/plot/tools/profile/__init__.py
new file mode 100644
index 0000000..a72b5d2
--- /dev/null
+++ b/src/silx/gui/plot/tools/profile/__init__.py
@@ -0,0 +1,37 @@
+# /*##########################################################################
+#
+# Copyright (c) 2018 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""This module provides tools to get profiles on plot data.
+
+It provides:
+
+- :class:`ScatterProfileToolBar`: a QToolBar to handle profile on scatter data
+
+"""
+
+__authors__ = ["T. Vincent"]
+__license__ = "MIT"
+__date__ = "07/06/2018"
+
+
+from .ScatterProfileToolBar import ScatterProfileToolBar # noqa
diff --git a/src/silx/gui/plot/tools/profile/core.py b/src/silx/gui/plot/tools/profile/core.py
new file mode 100644
index 0000000..194f459
--- /dev/null
+++ b/src/silx/gui/plot/tools/profile/core.py
@@ -0,0 +1,567 @@
+# /*##########################################################################
+#
+# Copyright (c) 2018-2023 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""This module define core objects for profile tools.
+"""
+
+from __future__ import annotations
+
+__authors__ = ["V.A. Sole", "T. Vincent", "P. Knobel", "H. Payno", "V. Valls"]
+__license__ = "MIT"
+__date__ = "17/04/2020"
+
+import typing
+import numpy
+import weakref
+
+from silx.image.bilinear import BilinearImage
+from silx.gui import qt
+from silx.gui import colors
+import silx.gui.plot.items
+
+
+class CurveProfileData(typing.NamedTuple):
+ coords: numpy.ndarray
+ profile: numpy.ndarray
+ title: str
+ xLabel: str
+ yLabel: str
+
+
+class RgbaProfileData(typing.NamedTuple):
+ coords: numpy.ndarray
+ profile: numpy.ndarray
+ profile_r: numpy.ndarray
+ profile_g: numpy.ndarray
+ profile_b: numpy.ndarray
+ profile_a: numpy.ndarray
+ title: str
+ xLabel: str
+ yLabel: str
+
+
+class ImageProfileData(typing.NamedTuple):
+ coords: numpy.ndarray
+ profile: numpy.ndarray
+ title: str
+ xLabel: str
+ yLabel: str
+ colormap: colors.Colormap
+
+
+class CurveProfileDesc(typing.NamedTuple):
+ profile: numpy.ndarray
+ name: typing.Optional[str] = None
+ color: typing.Optional[str] = None
+
+
+class CurvesProfileData(typing.NamedTuple):
+ coords: numpy.ndarray
+ profiles: typing.List[CurveProfileDesc]
+ title: str
+ xLabel: str
+ yLabel: str
+
+
+class ProfileRoiMixIn:
+ """Base mix-in for ROI which can be used to select a profile.
+
+ This mix-in have to be applied to a :class:`~silx.gui.plot.items.roi.RegionOfInterest`
+ in order to be usable by a :class:`~silx.gui.plot.tools.profile.manager.ProfileManager`.
+ """
+
+ ITEM_KIND = None
+ """Define the plot item which can be used with this profile ROI"""
+
+ sigProfilePropertyChanged = qt.Signal()
+ """Emitted when a property of this profile have changed"""
+
+ sigPlotItemChanged = qt.Signal()
+ """Emitted when the plot item linked to this profile have changed"""
+
+ def __init__(self, parent=None):
+ self.__profileWindow = None
+ self.__profileManager = None
+ self.__plotItem = None
+ self.setName("Profile")
+ self.setEditable(True)
+ self.setSelectable(True)
+
+ def invalidateProfile(self):
+ """Must be called by the implementation when the profile have to be
+ recomputed."""
+ profileManager = self.getProfileManager()
+ if profileManager is not None:
+ profileManager.requestUpdateProfile(self)
+
+ def invalidateProperties(self):
+ """Must be called when a property of the profile have changed."""
+ self.sigProfilePropertyChanged.emit()
+
+ def _setPlotItem(self, plotItem):
+ """Specify the plot item to use with this profile
+
+ :param `~silx.gui.plot.items.Item` plotItem: A plot item
+ """
+ previousPlotItem = self.getPlotItem()
+ if previousPlotItem is plotItem:
+ return
+ self.__plotItem = weakref.ref(plotItem)
+ self.sigPlotItemChanged.emit()
+
+ def getPlotItem(self):
+ """Returns the plot item used by this profile
+
+ :rtype: `~silx.gui.plot.items.Item`
+ """
+ if self.__plotItem is None:
+ return None
+ plotItem = self.__plotItem()
+ if plotItem is None:
+ self.__plotItem = None
+ return plotItem
+
+ def _setProfileManager(self, profileManager):
+ self.__profileManager = profileManager
+
+ def getProfileManager(self):
+ """
+ Returns the profile manager connected to this ROI.
+
+ :rtype: ~silx.gui.plot.tools.profile.manager.ProfileManager
+ """
+ return self.__profileManager
+
+ def getProfileWindow(self):
+ """
+ Returns the windows associated to this ROI, else None.
+
+ :rtype: ProfileWindow
+ """
+ return self.__profileWindow
+
+ def setProfileWindow(self, profileWindow):
+ """
+ Associate a window to this ROI. Can be None.
+
+ :param ProfileWindow profileWindow: A main window
+ to display the profile.
+ """
+ if profileWindow is self.__profileWindow:
+ return
+ if self.__profileWindow is not None:
+ self.__profileWindow.sigClose.disconnect(self.__profileWindowAboutToClose)
+ self.__profileWindow.setRoiProfile(None)
+ self.__profileWindow = profileWindow
+ if self.__profileWindow is not None:
+ self.__profileWindow.sigClose.connect(self.__profileWindowAboutToClose)
+ self.__profileWindow.setRoiProfile(self)
+
+ def __profileWindowAboutToClose(self):
+ profileManager = self.getProfileManager()
+ roiManager = profileManager.getRoiManager()
+ try:
+ roiManager.removeRoi(self)
+ except ValueError:
+ pass
+
+ def computeProfile(
+ self, item: silx.gui.plot.items.Item
+ ) -> typing.Union[
+ CurveProfileData, ImageProfileData, RgbaProfileData, CurvesProfileData
+ ]:
+ """
+ Compute the profile which will be displayed.
+
+ This method is not called from the main Qt thread, but from a thread
+ pool.
+
+ :param item: A plot item
+ """
+ raise NotImplementedError()
+
+
+def _alignedFullProfile(data, origin, scale, position, roiWidth, axis, method):
+ """Get a profile along one axis on a stack of images
+
+ :param numpy.ndarray data: 3D volume (stack of 2D images)
+ The first dimension is the image index.
+ :param origin: Origin of image in plot (ox, oy)
+ :param scale: Scale of image in plot (sx, sy)
+ :param float position: Position of profile line in plot coords
+ on the axis orthogonal to the profile direction.
+ :param int roiWidth: Width of the profile in image pixels.
+ :param int axis: 0 for horizontal profile, 1 for vertical.
+ :param str method: method to compute the profile. Can be 'mean' or 'sum' or
+ 'none'
+ :return: profile image + effective ROI area corners in plot coords
+ """
+ assert axis in (0, 1)
+ assert len(data.shape) == 3
+ assert method in ("mean", "sum", "none")
+
+ # Convert from plot to image coords
+ imgPos = int((position - origin[1 - axis]) / scale[1 - axis])
+
+ if axis == 1: # Vertical profile
+ # Transpose image to always do a horizontal profile
+ data = numpy.transpose(data, (0, 2, 1))
+
+ nimages, height, width = data.shape
+
+ roiWidth = min(height, roiWidth) # Clip roi width to image size
+
+ # Get [start, end[ coords of the roi in the data
+ start = int(int(imgPos) + 0.5 - roiWidth / 2.0)
+ start = min(max(0, start), height - roiWidth)
+ end = start + roiWidth
+
+ if method == "none":
+ profile = None
+ else:
+ if start < height and end > 0:
+ if method == "mean":
+ fct = numpy.mean
+ elif method == "sum":
+ fct = numpy.sum
+ else:
+ raise ValueError("method not managed")
+ profile = fct(data[:, max(0, start) : min(end, height), :], axis=1).astype(
+ numpy.float32
+ )
+ else:
+ profile = numpy.zeros((nimages, width), dtype=numpy.float32)
+
+ # Compute effective ROI in plot coords
+ profileBounds = (
+ numpy.array((0, width, width, 0), dtype=numpy.float32) * scale[axis]
+ + origin[axis]
+ )
+ roiBounds = (
+ numpy.array((start, start, end, end), dtype=numpy.float32) * scale[1 - axis]
+ + origin[1 - axis]
+ )
+
+ if axis == 0: # Horizontal profile
+ area = profileBounds, roiBounds
+ else: # vertical profile
+ area = roiBounds, profileBounds
+
+ return profile, area
+
+
+def _alignedPartialProfile(data, rowRange, colRange, axis, method):
+ """Mean of a rectangular region (ROI) of a stack of images
+ along a given axis.
+
+ Returned values and all parameters are in image coordinates.
+
+ :param numpy.ndarray data: 3D volume (stack of 2D images)
+ The first dimension is the image index.
+ :param rowRange: [min, max[ of ROI rows (upper bound excluded).
+ :type rowRange: 2-tuple of int (min, max) with min < max
+ :param colRange: [min, max[ of ROI columns (upper bound excluded).
+ :type colRange: 2-tuple of int (min, max) with min < max
+ :param int axis: The axis along which to take the profile of the ROI.
+ 0: Sum rows along columns.
+ 1: Sum columns along rows.
+ :param str method: method to compute the profile. Can be 'mean' or 'sum'
+ :return: Profile image along the ROI as the mean of the intersection
+ of the ROI and the image.
+ """
+ assert axis in (0, 1)
+ assert len(data.shape) == 3
+ assert rowRange[0] < rowRange[1]
+ assert colRange[0] < colRange[1]
+ assert method in ("mean", "sum")
+
+ nimages, height, width = data.shape
+
+ # Range aligned with the integration direction
+ profileRange = colRange if axis == 0 else rowRange
+
+ profileLength = abs(profileRange[1] - profileRange[0])
+
+ # Subset of the image to use as intersection of ROI and image
+ rowStart = min(max(0, rowRange[0]), height)
+ rowEnd = min(max(0, rowRange[1]), height)
+ colStart = min(max(0, colRange[0]), width)
+ colEnd = min(max(0, colRange[1]), width)
+
+ if method == "mean":
+ _fct = numpy.mean
+ elif method == "sum":
+ _fct = numpy.sum
+ else:
+ raise ValueError("method not managed")
+
+ imgProfile = _fct(
+ data[:, rowStart:rowEnd, colStart:colEnd], axis=axis + 1, dtype=numpy.float32
+ )
+
+ # Profile including out of bound area
+ profile = numpy.zeros((nimages, profileLength), dtype=numpy.float32)
+
+ # Place imgProfile in full profile
+ offset = -min(0, profileRange[0])
+ profile[:, offset : offset + imgProfile.shape[1]] = imgProfile
+
+ return profile
+
+
+def createProfile(roiInfo, currentData, origin, scale, lineWidth, method):
+ """Create the profile line for the the given image.
+
+ :param roiInfo: information about the ROI: start point, end point and
+ type ("X", "Y", "D")
+ :param numpy.ndarray currentData: the 2D image or the 3D stack of images
+ on which we compute the profile.
+ :param origin: (ox, oy) the offset from origin
+ :type origin: 2-tuple of float
+ :param scale: (sx, sy) the scale to use
+ :type scale: 2-tuple of float
+ :param int lineWidth: width of the profile line
+ :param str method: method to compute the profile. Can be 'mean' or 'sum'
+ or 'none': to compute everything except the profile
+ :return: `coords, profile, area, profileName, xLabel`, where:
+ - coords is the X coordinate to use to display the profile
+ - profile is a 2D array of the profiles of the stack of images.
+ For a single image, the profile is a curve, so this parameter
+ has a shape *(1, len(curve))*
+ - area is a tuple of two 1D arrays with 4 values each. They represent
+ the effective ROI area corners in plot coords.
+ - profileName is a string describing the ROI, meant to be used as
+ title of the profile plot
+ - xLabel the label for X in the profile window
+
+ :rtype: tuple(ndarray,ndarray,(ndarray,ndarray),str)
+ """
+ if currentData is None or roiInfo is None or lineWidth is None:
+ raise ValueError("createProfile called with invalide arguments")
+
+ # force 3D data (stack of images)
+ if len(currentData.shape) == 2:
+ currentData3D = currentData.reshape((1,) + currentData.shape)
+ elif len(currentData.shape) == 3:
+ currentData3D = currentData
+
+ roiWidth = max(1, lineWidth)
+ roiStart, roiEnd, lineProjectionMode = roiInfo
+
+ if lineProjectionMode == "X": # Horizontal profile on the whole image
+ profile, area = _alignedFullProfile(
+ currentData3D, origin, scale, roiStart[1], roiWidth, axis=0, method=method
+ )
+
+ if method == "none":
+ coords = None
+ else:
+ coords = numpy.arange(len(profile[0]), dtype=numpy.float32)
+ coords = coords * scale[0] + origin[0]
+
+ yMin, yMax = min(area[1]), max(area[1]) - 1
+ if roiWidth <= 1:
+ profileName = "{ylabel} = %g" % yMin
+ else:
+ profileName = "{ylabel} = [%g, %g]" % (yMin, yMax)
+ xLabel = "{xlabel}"
+
+ elif lineProjectionMode == "Y": # Vertical profile on the whole image
+ profile, area = _alignedFullProfile(
+ currentData3D, origin, scale, roiStart[0], roiWidth, axis=1, method=method
+ )
+
+ if method == "none":
+ coords = None
+ else:
+ coords = numpy.arange(len(profile[0]), dtype=numpy.float32)
+ coords = coords * scale[1] + origin[1]
+
+ xMin, xMax = min(area[0]), max(area[0]) - 1
+ if roiWidth <= 1:
+ profileName = "{xlabel} = %g" % xMin
+ else:
+ profileName = "{xlabel} = [%g, %g]" % (xMin, xMax)
+ xLabel = "{ylabel}"
+
+ else: # Free line profile
+ # Convert start and end points in image coords as (row, col)
+ startPt = (
+ (roiStart[1] - origin[1]) / scale[1],
+ (roiStart[0] - origin[0]) / scale[0],
+ )
+ endPt = ((roiEnd[1] - origin[1]) / scale[1], (roiEnd[0] - origin[0]) / scale[0])
+
+ if int(startPt[0]) == int(endPt[0]) or int(startPt[1]) == int(endPt[1]):
+ # Profile is aligned with one of the axes
+
+ # Convert to int
+ startPt = int(startPt[0]), int(startPt[1])
+ endPt = int(endPt[0]), int(endPt[1])
+
+ # Ensure startPt <= endPt
+ if startPt[0] > endPt[0] or startPt[1] > endPt[1]:
+ startPt, endPt = endPt, startPt
+
+ if startPt[0] == endPt[0]: # Row aligned
+ rowRange = (
+ int(startPt[0] + 0.5 - 0.5 * roiWidth),
+ int(startPt[0] + 0.5 + 0.5 * roiWidth),
+ )
+ colRange = startPt[1], endPt[1] + 1
+ if method == "none":
+ profile = None
+ else:
+ profile = _alignedPartialProfile(
+ currentData3D, rowRange, colRange, axis=0, method=method
+ )
+
+ else: # Column aligned
+ rowRange = startPt[0], endPt[0] + 1
+ colRange = (
+ int(startPt[1] + 0.5 - 0.5 * roiWidth),
+ int(startPt[1] + 0.5 + 0.5 * roiWidth),
+ )
+ if method == "none":
+ profile = None
+ else:
+ profile = _alignedPartialProfile(
+ currentData3D, rowRange, colRange, axis=1, method=method
+ )
+ # Convert ranges to plot coords to draw ROI area
+ area = (
+ numpy.array(
+ (colRange[0], colRange[1], colRange[1], colRange[0]),
+ dtype=numpy.float32,
+ )
+ * scale[0]
+ + origin[0],
+ numpy.array(
+ (rowRange[0], rowRange[0], rowRange[1], rowRange[1]),
+ dtype=numpy.float32,
+ )
+ * scale[1]
+ + origin[1],
+ )
+
+ else: # General case: use bilinear interpolation
+ # Ensure startPt <= endPt
+ if startPt[1] > endPt[1] or (
+ startPt[1] == endPt[1] and startPt[0] > endPt[0]
+ ):
+ startPt, endPt = endPt, startPt
+
+ if method == "none":
+ profile = None
+ else:
+ profile = []
+ for slice_idx in range(currentData3D.shape[0]):
+ bilinear = BilinearImage(currentData3D[slice_idx, :, :])
+
+ profile.append(
+ bilinear.profile_line(
+ (startPt[0] - 0.5, startPt[1] - 0.5),
+ (endPt[0] - 0.5, endPt[1] - 0.5),
+ roiWidth,
+ method=method,
+ )
+ )
+ profile = numpy.array(profile)
+
+ # Extend ROI with half a pixel on each end, and
+ # Convert back to plot coords (x, y)
+ length = numpy.sqrt(
+ (endPt[0] - startPt[0]) ** 2 + (endPt[1] - startPt[1]) ** 2
+ )
+ dRow = (endPt[0] - startPt[0]) / length
+ dCol = (endPt[1] - startPt[1]) / length
+
+ # Extend ROI with half a pixel on each end
+ roiStartPt = startPt[0] - 0.5 * dRow, startPt[1] - 0.5 * dCol
+ roiEndPt = endPt[0] + 0.5 * dRow, endPt[1] + 0.5 * dCol
+
+ # Rotate deltas by 90 degrees to apply line width
+ dRow, dCol = dCol, -dRow
+
+ area = (
+ numpy.array(
+ (
+ roiStartPt[1] - 0.5 * roiWidth * dCol,
+ roiStartPt[1] + 0.5 * roiWidth * dCol,
+ roiEndPt[1] + 0.5 * roiWidth * dCol,
+ roiEndPt[1] - 0.5 * roiWidth * dCol,
+ ),
+ dtype=numpy.float32,
+ )
+ * scale[0]
+ + origin[0],
+ numpy.array(
+ (
+ roiStartPt[0] - 0.5 * roiWidth * dRow,
+ roiStartPt[0] + 0.5 * roiWidth * dRow,
+ roiEndPt[0] + 0.5 * roiWidth * dRow,
+ roiEndPt[0] - 0.5 * roiWidth * dRow,
+ ),
+ dtype=numpy.float32,
+ )
+ * scale[1]
+ + origin[1],
+ )
+
+ # Convert start and end points back to plot coords
+ y0 = startPt[0] * scale[1] + origin[1]
+ x0 = startPt[1] * scale[0] + origin[0]
+ y1 = endPt[0] * scale[1] + origin[1]
+ x1 = endPt[1] * scale[0] + origin[0]
+
+ if startPt[1] == endPt[1]:
+ profileName = "{xlabel} = %g; {ylabel} = [%g, %g]" % (x0, y0, y1)
+ if method == "none":
+ coords = None
+ else:
+ coords = numpy.arange(len(profile[0]), dtype=numpy.float32)
+ coords = coords * scale[1] + y0
+ xLabel = "{ylabel}"
+
+ elif startPt[0] == endPt[0]:
+ profileName = "{ylabel} = %g; {xlabel} = [%g, %g]" % (y0, x0, x1)
+ if method == "none":
+ coords = None
+ else:
+ coords = numpy.arange(len(profile[0]), dtype=numpy.float32)
+ coords = coords * scale[0] + x0
+ xLabel = "{xlabel}"
+
+ else:
+ m = (y1 - y0) / (x1 - x0)
+ b = y0 - m * x0
+ profileName = "{ylabel} = %g * {xlabel} %+g" % (m, b)
+ if method == "none":
+ coords = None
+ else:
+ coords = numpy.linspace(
+ x0, x1, len(profile[0]), endpoint=True, dtype=numpy.float32
+ )
+ xLabel = "{xlabel}"
+
+ return coords, profile, area, profileName, xLabel
diff --git a/src/silx/gui/plot/tools/profile/editors.py b/src/silx/gui/plot/tools/profile/editors.py
new file mode 100644
index 0000000..d53f775
--- /dev/null
+++ b/src/silx/gui/plot/tools/profile/editors.py
@@ -0,0 +1,309 @@
+# /*##########################################################################
+#
+# Copyright (c) 2018-2020 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""This module provides editors which are used to custom profile ROI properties.
+"""
+
+__authors__ = ["V. Valls"]
+__license__ = "MIT"
+__date__ = "28/06/2018"
+
+import logging
+
+from silx.gui import qt
+
+from silx.gui.utils import blockSignals
+from silx.gui.plot.PlotToolButtons import ProfileOptionToolButton
+from silx.gui.plot.PlotToolButtons import ProfileToolButton
+from . import rois
+from . import core
+
+
+_logger = logging.getLogger(__name__)
+
+
+class _NoProfileRoiEditor(qt.QWidget):
+ sigDataCommited = qt.Signal()
+
+ def setEditorData(self, roi):
+ pass
+
+ def setRoiData(self, roi):
+ pass
+
+
+class _DefaultImageProfileRoiEditor(qt.QWidget):
+ sigDataCommited = qt.Signal()
+
+ def __init__(self, parent=None):
+ qt.QWidget.__init__(self, parent=parent)
+ layout = qt.QHBoxLayout(self)
+ layout.setContentsMargins(0, 0, 0, 0)
+ self._initLayout(layout)
+
+ def _initLayout(self, layout):
+ self._lineWidth = qt.QSpinBox(self)
+ self._lineWidth.setRange(1, 1000)
+ self._lineWidth.setValue(1)
+ self._lineWidth.valueChanged[int].connect(self._widgetChanged)
+
+ self._methodsButton = ProfileOptionToolButton(parent=self, plot=None)
+ self._methodsButton.sigMethodChanged.connect(self._widgetChanged)
+
+ label = qt.QLabel("W:")
+ label.setToolTip("Line width in pixels")
+ layout.addWidget(label)
+ layout.addWidget(self._lineWidth)
+ layout.addWidget(self._methodsButton)
+
+ def _widgetChanged(self, value=None):
+ self.commitData()
+
+ def commitData(self):
+ self.sigDataCommited.emit()
+
+ def setEditorData(self, roi):
+ with blockSignals(self._lineWidth):
+ self._lineWidth.setValue(roi.getProfileLineWidth())
+ with blockSignals(self._methodsButton):
+ method = roi.getProfileMethod()
+ self._methodsButton.setMethod(method)
+
+ def setRoiData(self, roi):
+ lineWidth = self._lineWidth.value()
+ roi.setProfileLineWidth(lineWidth)
+ method = self._methodsButton.getMethod()
+ roi.setProfileMethod(method)
+
+
+class _DefaultImageStackProfileRoiEditor(_DefaultImageProfileRoiEditor):
+ def _initLayout(self, layout):
+ super(_DefaultImageStackProfileRoiEditor, self)._initLayout(layout)
+ self._profileDim = ProfileToolButton(parent=self, plot=None)
+ self._profileDim.sigDimensionChanged.connect(self._widgetChanged)
+ layout.addWidget(self._profileDim)
+
+ def setEditorData(self, roi):
+ super(_DefaultImageStackProfileRoiEditor, self).setEditorData(roi)
+ with blockSignals(self._profileDim):
+ kind = roi.getProfileType()
+ dim = {"1D": 1, "2D": 2}[kind]
+ self._profileDim.setDimension(dim)
+
+ def setRoiData(self, roi):
+ super(_DefaultImageStackProfileRoiEditor, self).setRoiData(roi)
+ dim = self._profileDim.getDimension()
+ kind = {1: "1D", 2: "2D"}[dim]
+ roi.setProfileType(kind)
+
+
+class _DefaultScatterProfileRoiEditor(qt.QWidget):
+ sigDataCommited = qt.Signal()
+
+ def __init__(self, parent=None):
+ qt.QWidget.__init__(self, parent=parent)
+
+ self._nPoints = qt.QSpinBox(self)
+ self._nPoints.setRange(1, 9999)
+ self._nPoints.setValue(1024)
+ self._nPoints.valueChanged[int].connect(self.__widgetChanged)
+
+ layout = qt.QHBoxLayout(self)
+ layout.setContentsMargins(0, 0, 0, 0)
+ label = qt.QLabel("Samples:")
+ label.setToolTip("Number of sample points of the profile")
+ layout.addWidget(label)
+ layout.addWidget(self._nPoints)
+
+ def __widgetChanged(self, value=None):
+ self.commitData()
+
+ def commitData(self):
+ self.sigDataCommited.emit()
+
+ def setEditorData(self, roi):
+ with blockSignals(self._nPoints):
+ self._nPoints.setValue(roi.getNPoints())
+
+ def setRoiData(self, roi):
+ nPoints = self._nPoints.value()
+ roi.setNPoints(nPoints)
+
+
+class ProfileRoiEditorAction(qt.QWidgetAction):
+ """
+ Action displaying GUI to edit the selected ROI.
+
+ :param qt.QWidget parent: Parent widget
+ """
+
+ def __init__(self, parent=None):
+ super(ProfileRoiEditorAction, self).__init__(parent)
+ self.__roiManager = None
+ self.__roi = None
+ self.__inhibiteReentance = None
+
+ def createWidget(self, parent):
+ """Inherit the method to create a new editor"""
+ widget = qt.QWidget(parent)
+ layout = qt.QHBoxLayout(widget)
+ if isinstance(parent, qt.QMenu):
+ margins = layout.contentsMargins()
+ layout.setContentsMargins(margins.left(), 0, margins.right(), 0)
+ else:
+ layout.setContentsMargins(0, 0, 0, 0)
+
+ editorClass = self.getEditorClass(self.__roi)
+ editor = editorClass(parent)
+ editor.setEditorData(self.__roi)
+ self.__setEditor(widget, editor)
+ return widget
+
+ def deleteWidget(self, widget):
+ """Inherit the method to delete an editor"""
+ self.__setEditor(widget, None)
+ return qt.QWidgetAction.deleteWidget(self, widget)
+
+ def _getEditor(self, widget):
+ """Returns the editor contained in the widget holder"""
+ layout = widget.layout()
+ if layout.count() == 0:
+ return None
+ return layout.itemAt(0).widget()
+
+ def setRoiManager(self, roiManager):
+ """
+ Connect this action to a ROI manager.
+
+ :param RegionOfInterestManager roiManager: A ROI manager
+ """
+ if self.__roiManager is roiManager:
+ return
+ if self.__roiManager is not None:
+ self.__roiManager.sigCurrentRoiChanged.disconnect(self.__currentRoiChanged)
+ self.__roiManager = roiManager
+ if self.__roiManager is not None:
+ self.__roiManager.sigCurrentRoiChanged.connect(self.__currentRoiChanged)
+ self.__currentRoiChanged(roiManager.getCurrentRoi())
+
+ def __currentRoiChanged(self, roi):
+ """Handle changes of the selected ROI"""
+ if roi is not None and not isinstance(roi, core.ProfileRoiMixIn):
+ return
+ self.setProfileRoi(roi)
+
+ def setProfileRoi(self, roi):
+ """Set a profile ROI to edit.
+
+ :param ProfileRoiMixIn roi: A profile ROI
+ """
+ if self.__roi is roi:
+ return
+ if self.__roi is not None:
+ self.__roi.sigProfilePropertyChanged.disconnect(self.__roiPropertyChanged)
+ self.__roi = roi
+ if self.__roi is not None:
+ self.__roi.sigProfilePropertyChanged.connect(self.__roiPropertyChanged)
+ self._updateWidgets()
+
+ def getRoiProfile(self):
+ """Returns the edited profile ROI.
+
+ :rtype: ProfileRoiMixIn
+ """
+ return self.__roi
+
+ def __roiPropertyChanged(self):
+ """Handle changes on the property defining the ROI."""
+ self._updateWidgetValues()
+
+ def __setEditor(self, widget, editor):
+ """Set the editor to display.
+
+ :param qt.QWidget editor: The editor to display
+ """
+ previousEditor = self._getEditor(widget)
+ if previousEditor is editor:
+ return
+ layout = widget.layout()
+ if previousEditor is not None:
+ try:
+ previousEditor.sigDataCommited.disconnect(self._editorDataCommited)
+ except (RuntimeError, TypeError):
+ pass
+ layout.removeWidget(previousEditor)
+ previousEditor.deleteLater()
+ if editor is not None:
+ editor.sigDataCommited.connect(self._editorDataCommited)
+ layout.addWidget(editor)
+
+ def getEditorClass(self, roi):
+ """Returns the editor class to use according to the ROI."""
+ if roi is None:
+ editorClass = _NoProfileRoiEditor
+ elif isinstance(
+ roi,
+ (rois._DefaultImageStackProfileRoiMixIn, rois.ProfileImageStackCrossROI),
+ ):
+ # Must be done before the default image ROI
+ # Cause ImageStack ROIs inherit from Image ROIs
+ editorClass = _DefaultImageStackProfileRoiEditor
+ elif isinstance(
+ roi, (rois._DefaultImageProfileRoiMixIn, rois.ProfileImageCrossROI)
+ ):
+ editorClass = _DefaultImageProfileRoiEditor
+ elif isinstance(
+ roi, (rois._DefaultScatterProfileRoiMixIn, rois.ProfileScatterCrossROI)
+ ):
+ editorClass = _DefaultScatterProfileRoiEditor
+ else:
+ # Unsupported
+ editorClass = _NoProfileRoiEditor
+ return editorClass
+
+ def _updateWidgets(self):
+ """Update the kind of editor to display, according to the selected
+ profile ROI."""
+ parent = self.parent()
+ editorClass = self.getEditorClass(self.__roi)
+ for widget in self.createdWidgets():
+ editor = editorClass(parent)
+ editor.setEditorData(self.__roi)
+ self.__setEditor(widget, editor)
+
+ def _updateWidgetValues(self):
+ """Update the content of the displayed editor, according to the
+ selected profile ROI."""
+ for widget in self.createdWidgets():
+ editor = self._getEditor(widget)
+ if self.__inhibiteReentance is editor:
+ continue
+ editor.setEditorData(self.__roi)
+
+ def _editorDataCommited(self):
+ """Handle changes from the editor."""
+ editor = self.sender()
+ if self.__roi is not None:
+ self.__inhibiteReentance = editor
+ editor.setRoiData(self.__roi)
+ self.__inhibiteReentance = None
diff --git a/src/silx/gui/plot/tools/profile/manager.py b/src/silx/gui/plot/tools/profile/manager.py
new file mode 100644
index 0000000..6f4ba35
--- /dev/null
+++ b/src/silx/gui/plot/tools/profile/manager.py
@@ -0,0 +1,1109 @@
+# /*##########################################################################
+#
+# Copyright (c) 2018-2021 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""This module provides a manager to compute and display profiles.
+"""
+
+__authors__ = ["V. Valls"]
+__license__ = "MIT"
+__date__ = "28/06/2018"
+
+import logging
+import weakref
+
+from silx.gui import qt
+from silx.gui import colors
+from silx.gui import utils
+
+from silx.utils.weakref import WeakMethodProxy
+from silx.gui import icons
+from silx.gui.plot import PlotWidget
+from silx.gui.plot.tools.roi import RegionOfInterestManager
+from silx.gui.plot.tools.roi import CreateRoiModeAction
+from silx.gui.plot import items
+from silx.gui.qt import silxGlobalThreadPool
+from silx.gui.qt import inspect
+from . import rois
+from . import core
+from . import editors
+
+
+_logger = logging.getLogger(__name__)
+
+
+class _RunnableComputeProfile(qt.QRunnable):
+ """Runner to process profiles
+
+ :param qt.QThreadPool threadPool: The thread which will be used to
+ execute this runner. It is used to update the used signals
+ :param ~silx.gui.plot.items.Item item: Item in which the profile is
+ computed
+ :param ~silx.gui.plot.tools.profile.core.ProfileRoiMixIn roi: ROI
+ defining the profile shape and other characteristics
+ """
+
+ class _Signals(qt.QObject):
+ """Signal holder"""
+
+ resultReady = qt.Signal(object, object)
+ runnerFinished = qt.Signal(object)
+
+ def __init__(self, threadPool, item, roi):
+ """Constructor"""
+ super(_RunnableComputeProfile, self).__init__()
+ self._signals = self._Signals()
+ self._signals.moveToThread(threadPool.thread())
+ self._item = item
+ self._roi = roi
+ self._cancelled = False
+
+ def _lazyCancel(self):
+ """Cancel the runner if it is not yet started.
+
+ The threadpool will still execute the runner, but this will process
+ nothing.
+
+ This is only used with Qt<5.9 where QThreadPool.tryTake is not available.
+ """
+ self._cancelled = True
+
+ def autoDelete(self):
+ return False
+
+ def getRoi(self):
+ """Returns the ROI in which the runner will compute a profile.
+
+ :rtype: ~silx.gui.plot.tools.profile.core.ProfileRoiMixIn
+ """
+ return self._roi
+
+ @property
+ def resultReady(self):
+ """Signal emitted when the result of the computation is available.
+
+ This signal provides 2 values: The ROI, and the computation result.
+ """
+ return self._signals.resultReady
+
+ @property
+ def runnerFinished(self):
+ """Signal emitted when runner have finished.
+
+ This signal provides a single value: the runner itself.
+ """
+ return self._signals.runnerFinished
+
+ def run(self):
+ """Process the profile computation."""
+ if not self._cancelled:
+ try:
+ profileData = self._roi.computeProfile(self._item)
+ except Exception:
+ _logger.error("Error while computing profile", exc_info=True)
+ else:
+ self.resultReady.emit(self._roi, profileData)
+ self.runnerFinished.emit(self)
+
+
+class ProfileWindow(qt.QMainWindow):
+ """
+ Display a computed profile.
+
+ The content can be described using :meth:`setRoiProfile` if the source of
+ the profile is a profile ROI, and :meth:`setProfile` for the data content.
+ """
+
+ sigClose = qt.Signal()
+ """Emitted by :meth:`closeEvent` (e.g. when the window is closed
+ through the window manager's close icon)."""
+
+ def __init__(self, parent=None, backend=None):
+ qt.QMainWindow.__init__(self, parent=parent, flags=qt.Qt.Dialog)
+
+ self.setWindowTitle("Profile window")
+ self._plot1D = None
+ self._plot2D = None
+ self._backend = backend
+ self._data = None
+
+ widget = qt.QWidget()
+ self._layout = qt.QStackedLayout(widget)
+ self._layout.setContentsMargins(0, 0, 0, 0)
+ self.setCentralWidget(widget)
+
+ def prepareWidget(self, roi):
+ """Called before the show to prepare the window to use with
+ a specific ROI."""
+ if isinstance(roi, rois._DefaultImageStackProfileRoiMixIn):
+ profileType = roi.getProfileType()
+ else:
+ profileType = "1D"
+ if profileType == "1D":
+ self.getPlot1D()
+ elif profileType == "2D":
+ self.getPlot2D()
+
+ def createPlot1D(self, parent, backend):
+ """Inherit this function to create your own plot to render 1D
+ profiles. The default value is a `Plot1D`.
+
+ :param parent: The parent of this widget or None.
+ :param backend: The backend to use for the plot.
+ See :class:`PlotWidget` for the list of supported backend.
+ :rtype: PlotWidget
+ """
+ # import here to avoid circular import
+ from ...PlotWindow import Plot1D
+
+ plot = Plot1D(parent=parent, backend=backend)
+ plot.setDataMargins(yMinMargin=0.1, yMaxMargin=0.1)
+ plot.setGraphYLabel("Profile")
+ plot.setGraphXLabel("")
+ positionInfo = plot.getPositionInfoWidget()
+ positionInfo.setSnappingMode(positionInfo.SNAPPING_CURVE)
+ return plot
+
+ def createPlot2D(self, parent, backend):
+ """Inherit this function to create your own plot to render 2D
+ profiles. The default value is a `Plot2D`.
+
+ :param parent: The parent of this widget or None.
+ :param backend: The backend to use for the plot.
+ See :class:`PlotWidget` for the list of supported backend.
+ :rtype: PlotWidget
+ """
+ # import here to avoid circular import
+ from ...PlotWindow import Plot2D
+
+ return Plot2D(parent=parent, backend=backend)
+
+ def getPlot1D(self, init=True):
+ """Return the current plot used to display curves and create it if it
+ does not yet exists and `init` is True. Else returns None."""
+ if not init:
+ return self._plot1D
+ if self._plot1D is None:
+ self._plot1D = self.createPlot1D(self, self._backend)
+ self._layout.addWidget(self._plot1D)
+ return self._plot1D
+
+ def _showPlot1D(self):
+ plot = self.getPlot1D()
+ self._layout.setCurrentWidget(plot)
+
+ def getPlot2D(self, init=True):
+ """Return the current plot used to display image and create it if it
+ does not yet exists and `init` is True. Else returns None."""
+ if not init:
+ return self._plot2D
+ if self._plot2D is None:
+ self._plot2D = self.createPlot2D(parent=self, backend=self._backend)
+ self._layout.addWidget(self._plot2D)
+ return self._plot2D
+
+ def _showPlot2D(self):
+ plot = self.getPlot2D()
+ self._layout.setCurrentWidget(plot)
+
+ def getCurrentPlotWidget(self):
+ return self._layout.currentWidget()
+
+ def closeEvent(self, qCloseEvent):
+ self.sigClose.emit()
+ qCloseEvent.accept()
+
+ def setRoiProfile(self, roi):
+ """Set the profile ROI which it the source of the following data
+ to display.
+
+ :param ProfileRoiMixIn roi: The profile ROI data source
+ """
+ if roi is None:
+ return
+ self.__color = colors.rgba(roi.getColor())
+
+ def _setImageProfile(self, data: core.ImageProfileData):
+ """
+ Setup the window to display a new profile data which is represented
+ by an image.
+
+ :param data: Computed data profile
+ """
+ plot = self.getPlot2D()
+
+ plot.clear()
+ plot.setGraphTitle(data.title)
+ plot.getXAxis().setLabel(data.xLabel)
+
+ coords = data.coords
+ colormap = data.colormap
+ profileScale = (coords[-1] - coords[0]) / data.profile.shape[1], 1
+ plot.addImage(
+ data.profile,
+ legend="profile",
+ colormap=colormap,
+ origin=(coords[0], 0),
+ scale=profileScale,
+ )
+ plot.getYAxis().setLabel("Frame index (depth)")
+
+ self._showPlot2D()
+
+ def _setCurveProfile(self, data: core.CurveProfileData):
+ """
+ Setup the window to display a new profile data which is represented
+ by a curve.
+
+ :param data: Computed data profile
+ """
+ plot = self.getPlot1D()
+
+ plot.clear()
+ plot.setGraphTitle(data.title)
+ plot.getXAxis().setLabel(data.xLabel)
+ plot.getYAxis().setLabel(data.yLabel)
+
+ plot.addCurve(data.coords, data.profile, legend="level", color=self.__color)
+
+ self._showPlot1D()
+
+ def _setRgbaProfile(self, data: core.RgbaProfileData):
+ """
+ Setup the window to display a new profile data which is represented
+ by a curve.
+
+ :param data: Computed data profile
+ """
+ plot = self.getPlot1D()
+
+ plot.clear()
+ plot.setGraphTitle(data.title)
+ plot.getXAxis().setLabel(data.xLabel)
+ plot.getYAxis().setLabel(data.yLabel)
+
+ self._showPlot1D()
+
+ plot.addCurve(data.coords, data.profile, legend="level", color="black")
+ plot.addCurve(data.coords, data.profile_r, legend="red", color="red")
+ plot.addCurve(data.coords, data.profile_g, legend="green", color="green")
+ plot.addCurve(data.coords, data.profile_b, legend="blue", color="blue")
+ if data.profile_a is not None:
+ plot.addCurve(data.coords, data.profile_a, legend="alpha", color="gray")
+
+ def _setCurvesProfile(self, data: core.CurvesProfileData):
+ """
+ Setup the window to display a new profile data which is represented
+ by multiple curves.
+
+ :param data: Computed data profile
+ """
+ plot = self.getPlot1D()
+
+ plot.clear()
+ plot.setGraphTitle(data.title)
+ plot.getXAxis().setLabel(data.xLabel)
+ plot.getYAxis().setLabel(data.yLabel)
+
+ self._showPlot1D()
+
+ for i, desc in enumerate(data.profiles):
+ name = desc.name if desc.name is not None else f"profile{i}"
+ plot.addCurve(data.coords, desc.profile, legend=name, color=desc.color)
+
+ def clear(self):
+ """Clear the window profile"""
+ plot = self.getPlot1D(init=False)
+ if plot is not None:
+ plot.clear()
+ plot = self.getPlot2D(init=False)
+ if plot is not None:
+ plot.clear()
+
+ def getProfile(self):
+ """Returns the profile data which is displayed"""
+ return self.__data
+
+ def setProfile(self, data):
+ """
+ Setup the window to display a new profile data.
+
+ This method dispatch the result to a specific method according to the
+ data type.
+
+ :param data: Computed data profile
+ """
+ self.__data = data
+ if data is None:
+ self.clear()
+ elif isinstance(data, core.ImageProfileData):
+ self._setImageProfile(data)
+ elif isinstance(data, core.RgbaProfileData):
+ self._setRgbaProfile(data)
+ elif isinstance(data, core.CurveProfileData):
+ self._setCurveProfile(data)
+ elif isinstance(data, core.CurvesProfileData):
+ self._setCurvesProfile(data)
+ else:
+ raise TypeError("Unsupported type %s" % type(data))
+
+
+class _ClearAction(qt.QAction):
+ """Action to clear the profile manager
+
+ The action is only enabled if something can be cleaned up.
+ """
+
+ def __init__(self, parent, profileManager):
+ super(_ClearAction, self).__init__(parent)
+ self.__profileManager = weakref.ref(profileManager)
+ icon = icons.getQIcon("profile-clear")
+ self.setIcon(icon)
+ self.setText("Clear profile")
+ self.setToolTip("Clear the profiles")
+ self.setCheckable(False)
+ self.setEnabled(False)
+ self.triggered.connect(profileManager.clearProfile)
+ plot = profileManager.getPlotWidget()
+ roiManager = profileManager.getRoiManager()
+ plot.sigInteractiveModeChanged.connect(self.__modeUpdated)
+ roiManager.sigRoiChanged.connect(self.__roiListUpdated)
+
+ def getProfileManager(self):
+ return self.__profileManager()
+
+ def __roiListUpdated(self):
+ self.__update()
+
+ def __modeUpdated(self, source):
+ self.__update()
+
+ def __update(self):
+ profileManager = self.getProfileManager()
+ if profileManager is None:
+ return
+ roiManager = profileManager.getRoiManager()
+ if roiManager is None:
+ return
+ enabled = roiManager.isStarted() or len(roiManager.getRois()) > 0
+ self.setEnabled(enabled)
+
+
+class _StoreLastParamBehavior(qt.QObject):
+ """This object allow to store and restore the properties of the ROI
+ profiles"""
+
+ def __init__(self, parent):
+ assert isinstance(parent, ProfileManager)
+ super(_StoreLastParamBehavior, self).__init__(parent=parent)
+ self.__properties = {}
+ self.__profileRoi = None
+ self.__filter = utils.LockReentrant()
+
+ def _roi(self):
+ """Return the spied ROI"""
+ if self.__profileRoi is None:
+ return None
+ roi = self.__profileRoi()
+ if roi is None:
+ self.__profileRoi = None
+ return roi
+
+ def setProfileRoi(self, roi):
+ """Set a profile ROI to spy.
+
+ :param ProfileRoiMixIn roi: A profile ROI
+ """
+ previousRoi = self._roi()
+ if previousRoi is roi:
+ return
+ if previousRoi is not None:
+ previousRoi.sigProfilePropertyChanged.disconnect(
+ self._profilePropertyChanged
+ )
+ self.__profileRoi = None if roi is None else weakref.ref(roi)
+ if roi is not None:
+ roi.sigProfilePropertyChanged.connect(self._profilePropertyChanged)
+
+ def _profilePropertyChanged(self):
+ """Handle changes on the properties defining the profile ROI."""
+ if self.__filter.locked():
+ return
+ roi = self.sender()
+ self.storeProperties(roi)
+
+ def storeProperties(self, roi):
+ if isinstance(
+ roi,
+ (rois._DefaultImageStackProfileRoiMixIn, rois.ProfileImageStackCrossROI),
+ ):
+ self.__properties["method"] = roi.getProfileMethod()
+ self.__properties["line-width"] = roi.getProfileLineWidth()
+ self.__properties["type"] = roi.getProfileType()
+ elif isinstance(
+ roi, (rois._DefaultImageProfileRoiMixIn, rois.ProfileImageCrossROI)
+ ):
+ self.__properties["method"] = roi.getProfileMethod()
+ self.__properties["line-width"] = roi.getProfileLineWidth()
+ elif isinstance(
+ roi, (rois._DefaultScatterProfileRoiMixIn, rois.ProfileScatterCrossROI)
+ ):
+ self.__properties["npoints"] = roi.getNPoints()
+
+ def restoreProperties(self, roi):
+ with self.__filter:
+ if isinstance(
+ roi,
+ (
+ rois._DefaultImageStackProfileRoiMixIn,
+ rois.ProfileImageStackCrossROI,
+ ),
+ ):
+ value = self.__properties.get("method", None)
+ if value is not None:
+ roi.setProfileMethod(value)
+ value = self.__properties.get("line-width", None)
+ if value is not None:
+ roi.setProfileLineWidth(value)
+ value = self.__properties.get("type", None)
+ if value is not None:
+ roi.setProfileType(value)
+ elif isinstance(
+ roi, (rois._DefaultImageProfileRoiMixIn, rois.ProfileImageCrossROI)
+ ):
+ value = self.__properties.get("method", None)
+ if value is not None:
+ roi.setProfileMethod(value)
+ value = self.__properties.get("line-width", None)
+ if value is not None:
+ roi.setProfileLineWidth(value)
+ elif isinstance(
+ roi, (rois._DefaultScatterProfileRoiMixIn, rois.ProfileScatterCrossROI)
+ ):
+ value = self.__properties.get("npoints", None)
+ if value is not None:
+ roi.setNPoints(value)
+
+
+class ProfileManager(qt.QObject):
+ """Base class for profile management tools
+
+ :param plot: :class:`~silx.gui.plot.PlotWidget` on which to operate.
+ :param plot: :class:`~silx.gui.plot.tools.roi.RegionOfInterestManager`
+ on which to operate.
+ """
+
+ def __init__(self, parent=None, plot=None, roiManager=None):
+ super(ProfileManager, self).__init__(parent)
+
+ assert isinstance(plot, PlotWidget)
+ self._plotRef = weakref.ref(plot, WeakMethodProxy(self.__plotDestroyed))
+
+ # Set-up interaction manager
+ if roiManager is None:
+ roiManager = RegionOfInterestManager(plot)
+
+ self._roiManagerRef = weakref.ref(roiManager)
+ self._rois = []
+ self._pendingRunners = []
+ """List of ROIs which have to be updated"""
+
+ self.__reentrantResults = {}
+ """Store reentrant result to avoid to skip some of them
+ cause the implementation uses a QEventLoop."""
+
+ self._profileWindowClass = ProfileWindow
+ """Class used to display the profile results"""
+
+ self._computedProfiles = 0
+ """Statistics for tests"""
+
+ self.__itemTypes = []
+ """Kind of items to use"""
+
+ self.__tracking = False
+ """Is the plot active items are tracked"""
+
+ self.__useColorFromCursor = True
+ """If true, force the ROI color with the colormap marker color"""
+
+ self._item = None
+ """The selected item"""
+
+ self.__singleProfileAtATime = True
+ """When it's true, only a single profile is displayed at a time."""
+
+ self._previousWindowGeometry = []
+
+ self._storeProperties = _StoreLastParamBehavior(self)
+ """If defined the profile properties of the last ROI are reused to the
+ new created ones"""
+
+ # Listen to plot limits changed
+ plot.getXAxis().sigLimitsChanged.connect(self.requestUpdateAllProfile)
+ plot.getYAxis().sigLimitsChanged.connect(self.requestUpdateAllProfile)
+
+ roiManager.sigInteractiveModeFinished.connect(self.__interactionFinished)
+ roiManager.sigInteractiveRoiCreated.connect(self.__roiCreated)
+ roiManager.sigRoiAdded.connect(self.__roiAdded)
+ roiManager.sigRoiAboutToBeRemoved.connect(self.__roiRemoved)
+
+ def setSingleProfile(self, enable):
+ """
+ Enable or disable the single profile mode.
+
+ In single mode, the manager enforce a single ROI at the same
+ time. A new one will remove the previous one.
+
+ If this mode is not enabled, many ROIs can be created, and many
+ profile windows will be displayed.
+ """
+ self.__singleProfileAtATime = enable
+
+ def isSingleProfile(self):
+ """
+ Returns true if the manager is in a single profile mode.
+
+ :rtype: bool
+ """
+ return self.__singleProfileAtATime
+
+ def __interactionFinished(self):
+ """Handle end of interactive mode"""
+ pass
+
+ def __roiAdded(self, roi):
+ """Handle new ROI"""
+ # Filter out non profile ROIs
+ if not isinstance(roi, core.ProfileRoiMixIn):
+ return
+ self.__addProfile(roi)
+
+ def __roiRemoved(self, roi):
+ """Handle removed ROI"""
+ # Filter out non profile ROIs
+ if not isinstance(roi, core.ProfileRoiMixIn):
+ return
+ self.__removeProfile(roi)
+
+ def createProfileAction(self, profileRoiClass, parent=None):
+ """Create an action from a class of ProfileRoi
+
+ :param core.ProfileRoiMixIn profileRoiClass: A class of a profile ROI
+ :param qt.QObject parent: The parent of the created action.
+ :rtype: qt.QAction
+ """
+ if not issubclass(profileRoiClass, core.ProfileRoiMixIn):
+ raise TypeError("Type %s not expected" % type(profileRoiClass))
+ roiManager = self.getRoiManager()
+ action = CreateRoiModeAction(parent, roiManager, profileRoiClass)
+ if hasattr(profileRoiClass, "ICON"):
+ action.setIcon(icons.getQIcon(profileRoiClass.ICON))
+ if hasattr(profileRoiClass, "NAME"):
+
+ def articulify(word):
+ """Add an an/a article in the front of the word"""
+ first = word[1] if word[0] == "h" else word[0]
+ if first in "aeiou":
+ return "an " + word
+ return "a " + word
+
+ action.setText("Define %s" % articulify(profileRoiClass.NAME))
+ action.setToolTip("Enables %s selection mode" % profileRoiClass.NAME)
+ action.setSingleShot(True)
+ return action
+
+ def createClearAction(self, parent):
+ """Create an action to clean up the plot from the profile ROIs.
+
+ :param qt.QObject parent: The parent of the created action.
+ :rtype: qt.QAction
+ """
+ action = _ClearAction(parent, self)
+ return action
+
+ def createImageActions(self, parent):
+ """Create actions designed for image items. This actions created
+ new ROIs.
+
+ :param qt.QObject parent: The parent of the created action.
+ :rtype: List[qt.QAction]
+ """
+ profileClasses = [
+ rois.ProfileImageHorizontalLineROI,
+ rois.ProfileImageVerticalLineROI,
+ rois.ProfileImageLineROI,
+ rois.ProfileImageDirectedLineROI,
+ rois.ProfileImageCrossROI,
+ ]
+ return [self.createProfileAction(pc, parent=parent) for pc in profileClasses]
+
+ def createScatterActions(self, parent):
+ """Create actions designed for scatter items. This actions created
+ new ROIs.
+
+ :param qt.QObject parent: The parent of the created action.
+ :rtype: List[qt.QAction]
+ """
+ profileClasses = [
+ rois.ProfileScatterHorizontalLineROI,
+ rois.ProfileScatterVerticalLineROI,
+ rois.ProfileScatterLineROI,
+ rois.ProfileScatterCrossROI,
+ ]
+ return [self.createProfileAction(pc, parent=parent) for pc in profileClasses]
+
+ def createScatterSliceActions(self, parent):
+ """Create actions designed for regular scatter items. This actions
+ created new ROIs.
+
+ This ROIs was designed to use the input data without interpolation,
+ like you could do with an image.
+
+ :param qt.QObject parent: The parent of the created action.
+ :rtype: List[qt.QAction]
+ """
+ profileClasses = [
+ rois.ProfileScatterHorizontalSliceROI,
+ rois.ProfileScatterVerticalSliceROI,
+ rois.ProfileScatterCrossSliceROI,
+ ]
+ return [self.createProfileAction(pc, parent=parent) for pc in profileClasses]
+
+ def createImageStackActions(self, parent):
+ """Create actions designed for stack image items. This actions
+ created new ROIs.
+
+ This ROIs was designed to create both profile on the displayed image
+ and profile on the full stack (2D result).
+
+ :param qt.QObject parent: The parent of the created action.
+ :rtype: List[qt.QAction]
+ """
+ profileClasses = [
+ rois.ProfileImageStackHorizontalLineROI,
+ rois.ProfileImageStackVerticalLineROI,
+ rois.ProfileImageStackLineROI,
+ rois.ProfileImageStackCrossROI,
+ ]
+ return [self.createProfileAction(pc, parent=parent) for pc in profileClasses]
+
+ def createEditorAction(self, parent):
+ """Create an action containing GUI to edit the selected profile ROI.
+
+ :param qt.QObject parent: The parent of the created action.
+ :rtype: qt.QAction
+ """
+ action = editors.ProfileRoiEditorAction(parent)
+ action.setRoiManager(self.getRoiManager())
+ return action
+
+ def setItemType(self, image=False, scatter=False):
+ """Set the item type to use and select the active one.
+
+ :param bool image: Image item are allowed
+ :param bool scatter: Scatter item are allowed
+ """
+ self.__itemTypes = []
+ plot = self.getPlotWidget()
+ item = None
+ if image:
+ self.__itemTypes.append("image")
+ item = plot.getActiveImage()
+ if scatter:
+ self.__itemTypes.append("scatter")
+ if item is None:
+ item = plot.getActiveScatter()
+ self.setPlotItem(item)
+
+ def setProfileWindowClass(self, profileWindowClass):
+ """Set the class which will be instantiated to display profile result."""
+ self._profileWindowClass = profileWindowClass
+
+ def setActiveItemTracking(self, tracking):
+ """Enable/disable the tracking of the active item of the plot.
+
+ :param bool tracking: Tracking mode
+ """
+ if self.__tracking == tracking:
+ return
+ plot = self.getPlotWidget()
+ if self.__tracking:
+ plot.sigActiveImageChanged.disconnect(self._activeImageChanged)
+ plot.sigActiveScatterChanged.disconnect(self._activeScatterChanged)
+ self.__tracking = tracking
+ if self.__tracking:
+ plot.sigActiveImageChanged.connect(self.__activeImageChanged)
+ plot.sigActiveScatterChanged.connect(self.__activeScatterChanged)
+
+ def setDefaultColorFromCursorColor(self, enabled):
+ """Enabled/disable the use of the colormap cursor color to display the
+ ROIs.
+
+ If set, the manager will update the color of the profile ROIs using the
+ current colormap cursor color from the selected item.
+ """
+ self.__useColorFromCursor = enabled
+
+ def __activeImageChanged(self, previous, legend):
+ """Handle plot item selection"""
+ if "image" in self.__itemTypes:
+ plot = self.getPlotWidget()
+ item = plot.getImage(legend)
+ self.setPlotItem(item)
+
+ def __activeScatterChanged(self, previous, legend):
+ """Handle plot item selection"""
+ if "scatter" in self.__itemTypes:
+ plot = self.getPlotWidget()
+ item = plot.getScatter(legend)
+ self.setPlotItem(item)
+
+ def __roiCreated(self, roi):
+ """Handle ROI creation"""
+ # Filter out non profile ROIs
+ if isinstance(roi, core.ProfileRoiMixIn):
+ if self._storeProperties is not None:
+ # Initialize the properties with the previous ones
+ self._storeProperties.restoreProperties(roi)
+
+ def __addProfile(self, profileRoi):
+ """Add a new ROI to the manager."""
+ if profileRoi.getFocusProxy() is None:
+ if self._storeProperties is not None:
+ # Follow changes on properties
+ self._storeProperties.setProfileRoi(profileRoi)
+ if self.__singleProfileAtATime:
+ # FIXME: It would be good to reuse the windows to avoid blinking
+ self.clearProfile()
+
+ profileRoi._setProfileManager(self)
+ self._updateRoiColor(profileRoi)
+ self._rois.append(profileRoi)
+ self.requestUpdateProfile(profileRoi)
+
+ def __removeProfile(self, profileRoi):
+ """Remove a ROI from the manager."""
+ window = self._disconnectProfileWindow(profileRoi)
+ if window is not None:
+ geometry = window.geometry()
+ if not geometry.isEmpty():
+ self._previousWindowGeometry.append(geometry)
+ self.clearProfileWindow(window)
+ if profileRoi in self._rois:
+ self._rois.remove(profileRoi)
+
+ def _disconnectProfileWindow(self, profileRoi):
+ """Handle profile window close."""
+ window = profileRoi.getProfileWindow()
+ profileRoi.setProfileWindow(None)
+ return window
+
+ def clearProfile(self):
+ """Clear the associated ROI profile"""
+ roiManager = self.getRoiManager()
+ for roi in list(self._rois):
+ if roi.getFocusProxy() is not None:
+ # Skip sub ROIs, it will be removed by their parents
+ continue
+ roiManager.removeRoi(roi)
+
+ if not roiManager.isDrawing():
+ # Clean the selected mode
+ roiManager.stop()
+
+ def hasPendingOperations(self):
+ """Returns true if a thread is still computing or displaying a profile.
+
+ :rtype: bool
+ """
+ return len(self.__reentrantResults) > 0 or len(self._pendingRunners) > 0
+
+ def requestUpdateAllProfile(self):
+ """Request to update the profile of all the managed ROIs."""
+ for roi in self._rois:
+ self.requestUpdateProfile(roi)
+
+ def requestUpdateProfile(self, profileRoi):
+ """Request to update a specific profile ROI.
+
+ :param ~core.ProfileRoiMixIn profileRoi:
+ """
+ if profileRoi.computeProfile is None:
+ return
+ threadPool = silxGlobalThreadPool()
+
+ # Clean up deprecated runners
+ for runner in list(self._pendingRunners):
+ if not inspect.isValid(runner):
+ self._pendingRunners.remove(runner)
+ continue
+ if runner.getRoi() is profileRoi:
+ if hasattr(threadPool, "tryTake"):
+ if threadPool.tryTake(runner):
+ self._pendingRunners.remove(runner)
+ else: # Support Qt<5.9
+ runner._lazyCancel()
+
+ item = self.getPlotItem()
+ if item is None or not isinstance(item, profileRoi.ITEM_KIND):
+ # This item is not compatible with this profile
+ profileRoi._setPlotItem(None)
+ profileWindow = profileRoi.getProfileWindow()
+ if profileWindow is not None:
+ profileWindow.setProfile(None)
+ return
+
+ profileRoi._setPlotItem(item)
+ runner = _RunnableComputeProfile(threadPool, item, profileRoi)
+ runner.runnerFinished.connect(self.__cleanUpRunner)
+ runner.resultReady.connect(self.__displayResult)
+ self._pendingRunners.append(runner)
+ threadPool.start(runner)
+
+ def __cleanUpRunner(self, runner):
+ """Remove a thread pool runner from the list of hold tasks.
+
+ Called at the termination of the runner.
+ """
+ if runner in self._pendingRunners:
+ self._pendingRunners.remove(runner)
+
+ def __displayResult(self, roi, profileData):
+ """Display the result of a ROI.
+
+ :param ~core.ProfileRoiMixIn profileRoi: A managed ROI
+ :param ~core.CurveProfileData profileData: Computed data profile
+ """
+ if roi in self.__reentrantResults:
+ # Store the data to process it in the main loop
+ # And not a sub loop created by initProfileWindow
+ # This also remove the duplicated requested
+ self.__reentrantResults[roi] = profileData
+ return
+
+ self.__reentrantResults[roi] = profileData
+ self._computedProfiles = self._computedProfiles + 1
+ window = roi.getProfileWindow()
+ if window is None:
+ plot = self.getPlotWidget()
+ window = self.createProfileWindow(plot, roi)
+ # roi.profileWindow have to be set before initializing the window
+ # Cause the initialization is using QEventLoop
+ roi.setProfileWindow(window)
+ self.initProfileWindow(window, roi)
+ window.show()
+
+ lastData = self.__reentrantResults.pop(roi)
+ window.setProfile(lastData)
+
+ def __plotDestroyed(self, ref):
+ """Handle finalization of PlotWidget
+
+ :param ref: weakref to the plot
+ """
+ self._plotRef = None
+ self._roiManagerRef = None
+ self._pendingRunners = []
+
+ def setPlotItem(self, item):
+ """Set the plot item focused by the profile manager.
+
+ :param ~silx.gui.plot.items.Item item: A plot item
+ """
+ previous = self.getPlotItem()
+ if previous is item:
+ return
+ if item is None:
+ self._item = None
+ else:
+ item.sigItemChanged.connect(self.__itemChanged)
+ self._item = weakref.ref(item)
+ self._updateRoiColors()
+ self.requestUpdateAllProfile()
+
+ def getDefaultColor(self, item):
+ """Returns the default ROI color to use according to the given item.
+
+ :param ~silx.gui.plot.items.item.Item item: AN item
+ :rtype: qt.QColor
+ """
+ color = "pink"
+ if isinstance(item, items.ColormapMixIn):
+ colormap = item.getColormap()
+ name = colormap.getName()
+ if name is not None:
+ color = colors.cursorColorForColormap(name)
+ color = colors.asQColor(color)
+ return color
+
+ def _updateRoiColors(self):
+ """Update ROI color according to the item selection"""
+ if not self.__useColorFromCursor:
+ return
+ item = self.getPlotItem()
+ color = self.getDefaultColor(item)
+ for roi in self._rois:
+ roi.setColor(color)
+
+ def _updateRoiColor(self, roi):
+ """Update a specific ROI according to the current selected item.
+
+ :param RegionOfInterest roi: The ROI to update
+ """
+ if not self.__useColorFromCursor:
+ return
+ item = self.getPlotItem()
+ color = self.getDefaultColor(item)
+ roi.setColor(color)
+
+ def __itemChanged(self, changeType):
+ """Handle item changes."""
+ if changeType in (
+ items.ItemChangedType.DATA,
+ items.ItemChangedType.MASK,
+ items.ItemChangedType.POSITION,
+ items.ItemChangedType.SCALE,
+ ):
+ self.requestUpdateAllProfile()
+ elif changeType == (items.ItemChangedType.COLORMAP):
+ self._updateRoiColors()
+
+ def getPlotItem(self):
+ """Returns the item focused by the profile manager.
+
+ :rtype: ~silx.gui.plot.items.Item
+ """
+ if self._item is None:
+ return None
+ item = self._item()
+ if item is None:
+ self._item = None
+ return item
+
+ def getPlotWidget(self):
+ """The plot associated to the profile manager.
+
+ :rtype: ~silx.gui.plot.PlotWidget
+ """
+ if self._plotRef is None:
+ return None
+ plot = self._plotRef()
+ if plot is None:
+ self._plotRef = None
+ return plot
+
+ def getCurrentRoi(self):
+ """Returns the currently selected ROI, else None.
+
+ :rtype: core.ProfileRoiMixIn
+ """
+ roiManager = self.getRoiManager()
+ if roiManager is None:
+ return None
+ roi = roiManager.getCurrentRoi()
+ if not isinstance(roi, core.ProfileRoiMixIn):
+ return None
+ return roi
+
+ def getRoiManager(self):
+ """Returns the used ROI manager
+
+ :rtype: RegionOfInterestManager
+ """
+ return self._roiManagerRef()
+
+ def createProfileWindow(self, plot, roi):
+ """Create a new profile window.
+
+ :param ~core.ProfileRoiMixIn roi: The plot containing the raw data
+ :param ~core.ProfileRoiMixIn roi: A managed ROI
+ :rtype: ~ProfileWindow
+ """
+ return self._profileWindowClass(plot)
+
+ def initProfileWindow(self, profileWindow, roi):
+ """This function is called just after the profile window creation in
+ order to initialize the window location.
+
+ :param ~ProfileWindow profileWindow:
+ The profile window to initialize.
+ """
+ # Enforce the use of one of the widgets
+ # To have the correct window size
+ profileWindow.prepareWidget(roi)
+ profileWindow.adjustSize()
+
+ # Trick to avoid blinking while retrieving the right window size
+ # Display the window, hide it and wait for some event loops
+ profileWindow.show()
+ profileWindow.hide()
+ eventLoop = qt.QEventLoop(self)
+ for _ in range(10):
+ if not eventLoop.processEvents():
+ break
+
+ profileWindow.show()
+ if len(self._previousWindowGeometry) > 0:
+ geometry = self._previousWindowGeometry.pop()
+ profileWindow.setGeometry(geometry)
+ return
+
+ window = self.getPlotWidget().window()
+ winGeom = window.frameGeometry()
+ if qt.BINDING == "PyQt5":
+ qapp = qt.QApplication.instance()
+ desktop = qapp.desktop()
+ screenGeom = desktop.availableGeometry(window)
+ else: # Qt6 (and also Qt>=5.14)
+ screenGeom = window.screen().availableGeometry()
+ spaceOnLeftSide = winGeom.left()
+ spaceOnRightSide = screenGeom.width() - winGeom.right()
+
+ profileGeom = profileWindow.frameGeometry()
+ profileWidth = profileGeom.width()
+
+ # Align vertically to the center of the window
+ top = winGeom.top() + (winGeom.height() - profileGeom.height()) // 2
+
+ margin = 5
+ if profileWidth < spaceOnRightSide:
+ # Place profile on the right
+ left = winGeom.right() + margin
+ elif profileWidth < spaceOnLeftSide:
+ # Place profile on the left
+ left = max(0, winGeom.left() - profileWidth - margin)
+ else:
+ # Move it as much as possible where there is more space
+ if spaceOnLeftSide > spaceOnRightSide:
+ left = 0
+ else:
+ left = screenGeom.width() - profileGeom.width()
+ profileWindow.move(left, top)
+
+ def clearProfileWindow(self, profileWindow):
+ """Called when a profile window is not anymore needed.
+
+ By default the window will be closed. But it can be
+ inherited to change this behavior.
+ """
+ profileWindow.deleteLater()
diff --git a/src/silx/gui/plot/tools/profile/rois.py b/src/silx/gui/plot/tools/profile/rois.py
new file mode 100644
index 0000000..23f086a
--- /dev/null
+++ b/src/silx/gui/plot/tools/profile/rois.py
@@ -0,0 +1,1176 @@
+# /*##########################################################################
+#
+# Copyright (c) 2018-2023 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""This module define ROIs for profile tools.
+
+.. inheritance-diagram::
+ silx.gui.plot.tools.profile.rois
+ :top-classes: silx.gui.plot.tools.profile.core.ProfileRoiMixIn, silx.gui.plot.items.roi.RegionOfInterest
+ :parts: 1
+ :private-bases:
+"""
+
+__authors__ = ["V. Valls"]
+__license__ = "MIT"
+__date__ = "01/12/2020"
+
+import numpy
+import weakref
+from concurrent.futures import CancelledError
+
+from silx.gui import colors
+
+from silx.gui.plot import items
+from silx.gui.plot.items import roi as roi_items
+from . import core
+from silx.gui import utils
+from .....utils.proxy import docstring
+
+
+def _relabelAxes(plot, text):
+ """Relabel {xlabel} and {ylabel} from this text using the corresponding
+ plot axis label. If the axis label is empty, label it with "X" and "Y".
+
+ :rtype: str
+ """
+ xLabel = plot.getXAxis().getLabel()
+ if not xLabel:
+ xLabel = "X"
+ yLabel = plot.getYAxis().getLabel()
+ if not yLabel:
+ yLabel = "Y"
+ return text.format(xlabel=xLabel, ylabel=yLabel)
+
+
+def _lineProfileTitle(x0, y0, x1, y1):
+ """Compute corresponding plot title
+
+ This can be overridden to change title behavior.
+
+ :param float x0: Profile start point X coord
+ :param float y0: Profile start point Y coord
+ :param float x1: Profile end point X coord
+ :param float y1: Profile end point Y coord
+ :return: Title to use
+ :rtype: str
+ """
+ if x0 == x1:
+ title = "{xlabel} = %g; {ylabel} = [%g, %g]" % (x0, y0, y1)
+ elif y0 == y1:
+ title = "{ylabel} = %g; {xlabel} = [%g, %g]" % (y0, x0, x1)
+ else:
+ m = (y1 - y0) / (x1 - x0)
+ b = y0 - m * x0
+ title = "{ylabel} = %g * {xlabel} %+g" % (m, b)
+
+ return title
+
+
+class _ImageProfileArea(items.Shape):
+ """This shape displays the location of pixels used to compute the
+ profile."""
+
+ def __init__(self, parentRoi):
+ items.Shape.__init__(self, "polygon")
+ color = colors.rgba(parentRoi.getColor())
+ self.setColor(color)
+ self.setFill(True)
+ self.setOverlay(True)
+ self.setPoints([[0, 0], [0, 0]]) # Else it segfault
+
+ self.__parentRoi = weakref.ref(parentRoi)
+ parentRoi.sigItemChanged.connect(self._updateAreaProperty)
+ parentRoi.sigRegionChanged.connect(self._updateArea)
+ parentRoi.sigProfilePropertyChanged.connect(self._updateArea)
+ parentRoi.sigPlotItemChanged.connect(self._updateArea)
+
+ def getParentRoi(self):
+ if self.__parentRoi is None:
+ return None
+ parentRoi = self.__parentRoi()
+ if parentRoi is None:
+ self.__parentRoi = None
+ return parentRoi
+
+ def _updateAreaProperty(self, event=None, checkVisibility=True):
+ parentRoi = self.sender()
+ if event == items.ItemChangedType.COLOR:
+ parentRoi._updateItemProperty(event, parentRoi, self)
+ elif event == items.ItemChangedType.VISIBLE:
+ if self.getPlotItem() is not None:
+ parentRoi._updateItemProperty(event, parentRoi, self)
+
+ def _updateArea(self):
+ roi = self.getParentRoi()
+ item = roi.getPlotItem()
+ if item is None:
+ self.setVisible(False)
+ return
+ polygon = self._computePolygon(item)
+ self.setVisible(True)
+ polygon = numpy.array(polygon).T
+ self.setLineStyle("--")
+ self.setPoints(polygon, copy=False)
+
+ def _computePolygon(self, item):
+ if not isinstance(item, items.ImageBase):
+ raise TypeError("Unexpected class %s" % type(item))
+
+ currentData = item.getValueData(copy=False)
+
+ roi = self.getParentRoi()
+ origin = item.getOrigin()
+ scale = item.getScale()
+ _coords, _profile, area, _profileName, _xLabel = core.createProfile(
+ roiInfo=roi._getRoiInfo(),
+ currentData=currentData,
+ origin=origin,
+ scale=scale,
+ lineWidth=roi.getProfileLineWidth(),
+ method="none",
+ )
+ return area
+
+
+class _SliceProfileArea(items.Shape):
+ """This shape displays the location a profile in a scatter.
+
+ Each point used to compute the slice are linked together.
+ """
+
+ def __init__(self, parentRoi):
+ items.Shape.__init__(self, "polygon")
+ color = colors.rgba(parentRoi.getColor())
+ self.setColor(color)
+ self.setFill(True)
+ self.setOverlay(True)
+ self.setPoints([[0, 0], [0, 0]]) # Else it segfault
+
+ self.__parentRoi = weakref.ref(parentRoi)
+ parentRoi.sigItemChanged.connect(self._updateAreaProperty)
+ parentRoi.sigRegionChanged.connect(self._updateArea)
+ parentRoi.sigProfilePropertyChanged.connect(self._updateArea)
+ parentRoi.sigPlotItemChanged.connect(self._updateArea)
+
+ def getParentRoi(self):
+ if self.__parentRoi is None:
+ return None
+ parentRoi = self.__parentRoi()
+ if parentRoi is None:
+ self.__parentRoi = None
+ return parentRoi
+
+ def _updateAreaProperty(self, event=None, checkVisibility=True):
+ parentRoi = self.sender()
+ if event == items.ItemChangedType.COLOR:
+ parentRoi._updateItemProperty(event, parentRoi, self)
+ elif event == items.ItemChangedType.VISIBLE:
+ if self.getPlotItem() is not None:
+ parentRoi._updateItemProperty(event, parentRoi, self)
+
+ def _updateArea(self):
+ roi = self.getParentRoi()
+ item = roi.getPlotItem()
+ if item is None:
+ self.setVisible(False)
+ return
+ polylines = self._computePolylines(roi, item)
+ if polylines is None:
+ self.setVisible(False)
+ return
+ self.setVisible(True)
+ self.setLineStyle("--")
+ self.setPoints(polylines, copy=False)
+
+ def _computePolylines(self, roi, item):
+ slicing = roi._getSlice(item)
+ if slicing is None:
+ return None
+ xx, yy, _values, _xx_error, _yy_error = item.getData(copy=False)
+ xx, yy = xx[slicing], yy[slicing]
+ polylines = numpy.array((xx, yy)).T
+ if len(polylines) == 0:
+ return None
+ return polylines
+
+
+class _DefaultImageProfileRoiMixIn(core.ProfileRoiMixIn):
+ """Provide common behavior for silx default image profile ROI."""
+
+ ITEM_KIND = items.ImageBase
+
+ def __init__(self, parent=None):
+ core.ProfileRoiMixIn.__init__(self, parent=parent)
+ self.__method = "mean"
+ self.__width = 1
+ self.sigRegionChanged.connect(self.__regionChanged)
+ self.sigPlotItemChanged.connect(self.__updateArea)
+ self.__area = _ImageProfileArea(self)
+ self.addItem(self.__area)
+
+ def __regionChanged(self):
+ self.invalidateProfile()
+ self.__updateArea()
+
+ def setProfileMethod(self, method):
+ """
+ :param str method: method to compute the profile. Can be 'mean' or 'sum'
+ """
+ if self.__method == method:
+ return
+ self.__method = method
+ self.invalidateProperties()
+ self.invalidateProfile()
+
+ def getProfileMethod(self):
+ return self.__method
+
+ def setProfileLineWidth(self, width):
+ if self.__width == width:
+ return
+ self.__width = width
+ self.__updateArea()
+ self.invalidateProperties()
+ self.invalidateProfile()
+
+ def getProfileLineWidth(self):
+ return self.__width
+
+ def __updateArea(self):
+ plotItem = self.getPlotItem()
+ if plotItem is None:
+ self.setLineStyle("-")
+ else:
+ self.setLineStyle("--")
+
+ def _getRoiInfo(self):
+ """Wrapper to allow to reuse the previous Profile code.
+
+ It would be good to remove it at one point.
+ """
+ if isinstance(self, roi_items.HorizontalLineROI):
+ lineProjectionMode = "X"
+ y = self.getPosition()
+ roiStart = (0, y)
+ roiEnd = (1, y)
+ elif isinstance(self, roi_items.VerticalLineROI):
+ lineProjectionMode = "Y"
+ x = self.getPosition()
+ roiStart = (x, 0)
+ roiEnd = (x, 1)
+ elif isinstance(self, roi_items.LineROI):
+ lineProjectionMode = "D"
+ roiStart, roiEnd = self.getEndPoints()
+ else:
+ assert False
+
+ return roiStart, roiEnd, lineProjectionMode
+
+ def computeProfile(self, item):
+ if not isinstance(item, items.ImageBase):
+ raise TypeError("Unexpected class %s" % type(item))
+
+ origin = item.getOrigin()
+ scale = item.getScale()
+ method = self.getProfileMethod()
+ lineWidth = self.getProfileLineWidth()
+ roiInfo = self._getRoiInfo()
+
+ def createProfile2(currentData):
+ coords, profile, _area, profileName, xLabel = core.createProfile(
+ roiInfo=roiInfo,
+ currentData=currentData,
+ origin=origin,
+ scale=scale,
+ lineWidth=lineWidth,
+ method=method,
+ )
+ return coords, profile, profileName, xLabel
+
+ currentData = item.getValueData(copy=False)
+
+ yLabel = "%s" % str(method).capitalize()
+ coords, profile, title, xLabel = createProfile2(currentData)
+ title = title + "; width = %d" % lineWidth
+
+ # Use the axis names from the original plot
+ profileManager = self.getProfileManager()
+ plot = profileManager.getPlotWidget()
+ title = _relabelAxes(plot, title)
+ xLabel = _relabelAxes(plot, xLabel)
+
+ if isinstance(item, items.ImageRgba):
+ rgba = item.getData(copy=False)
+ _coords, r, _profileName, _xLabel = createProfile2(rgba[..., 0])
+ _coords, g, _profileName, _xLabel = createProfile2(rgba[..., 1])
+ _coords, b, _profileName, _xLabel = createProfile2(rgba[..., 2])
+ if rgba.shape[-1] == 4:
+ _coords, a, _profileName, _xLabel = createProfile2(rgba[..., 3])
+ else:
+ a = [None]
+ data = core.RgbaProfileData(
+ coords=coords,
+ profile=profile[0],
+ profile_r=r[0],
+ profile_g=g[0],
+ profile_b=b[0],
+ profile_a=a[0],
+ title=title,
+ xLabel=xLabel,
+ yLabel=yLabel,
+ )
+ else:
+ data = core.CurveProfileData(
+ coords=coords,
+ profile=profile[0],
+ title=title,
+ xLabel=xLabel,
+ yLabel=yLabel,
+ )
+ return data
+
+
+class ProfileImageHorizontalLineROI(
+ roi_items.HorizontalLineROI, _DefaultImageProfileRoiMixIn
+):
+ """ROI for an horizontal profile at a location of an image"""
+
+ ICON = "shape-horizontal"
+ NAME = "horizontal line profile"
+
+ def __init__(self, parent=None):
+ roi_items.HorizontalLineROI.__init__(self, parent=parent)
+ _DefaultImageProfileRoiMixIn.__init__(self, parent=parent)
+
+
+class ProfileImageVerticalLineROI(
+ roi_items.VerticalLineROI, _DefaultImageProfileRoiMixIn
+):
+ """ROI for a vertical profile at a location of an image"""
+
+ ICON = "shape-vertical"
+ NAME = "vertical line profile"
+
+ def __init__(self, parent=None):
+ roi_items.VerticalLineROI.__init__(self, parent=parent)
+ _DefaultImageProfileRoiMixIn.__init__(self, parent=parent)
+
+
+class ProfileImageLineROI(roi_items.LineROI, _DefaultImageProfileRoiMixIn):
+ """ROI for an image profile between 2 points.
+
+ The X profile of this ROI is the projecting into one of the x/y axes,
+ using its scale and its orientation.
+ """
+
+ ICON = "shape-diagonal"
+ NAME = "line profile"
+
+ def __init__(self, parent=None):
+ roi_items.LineROI.__init__(self, parent=parent)
+ _DefaultImageProfileRoiMixIn.__init__(self, parent=parent)
+
+
+class ProfileImageDirectedLineROI(roi_items.LineROI, _DefaultImageProfileRoiMixIn):
+ """ROI for an image profile between 2 points.
+
+ The X profile of the line is displayed projected into the line itself,
+ using its scale and its orientation. It's the distance from the origin.
+ """
+
+ ICON = "shape-diagonal-directed"
+ NAME = "directed line profile"
+
+ def __init__(self, parent=None):
+ roi_items.LineROI.__init__(self, parent=parent)
+ _DefaultImageProfileRoiMixIn.__init__(self, parent=parent)
+ self._handleStart.setSymbol("o")
+
+ def computeProfile(self, item):
+ if not isinstance(item, items.ImageBase):
+ raise TypeError("Unexpected class %s" % type(item))
+
+ from silx.image.bilinear import BilinearImage
+
+ origin = item.getOrigin()
+ scale = item.getScale()
+ method = self.getProfileMethod()
+ lineWidth = self.getProfileLineWidth()
+ currentData = item.getValueData(copy=False)
+
+ roiInfo = self._getRoiInfo()
+ roiStart, roiEnd, _lineProjectionMode = roiInfo
+
+ startPt = (
+ (roiStart[1] - origin[1]) / scale[1],
+ (roiStart[0] - origin[0]) / scale[0],
+ )
+ endPt = ((roiEnd[1] - origin[1]) / scale[1], (roiEnd[0] - origin[0]) / scale[0])
+
+ if numpy.array_equal(startPt, endPt):
+ return None
+
+ bilinear = BilinearImage(currentData)
+ profile = bilinear.profile_line(
+ (startPt[0] - 0.5, startPt[1] - 0.5),
+ (endPt[0] - 0.5, endPt[1] - 0.5),
+ lineWidth,
+ method=method,
+ )
+
+ # Compute the line size
+ lineSize = numpy.sqrt(
+ (roiEnd[1] - roiStart[1]) ** 2 + (roiEnd[0] - roiStart[0]) ** 2
+ )
+ coords = numpy.linspace(
+ 0, lineSize, len(profile), endpoint=True, dtype=numpy.float32
+ )
+
+ title = _lineProfileTitle(*roiStart, *roiEnd)
+ title = title + "; width = %d" % lineWidth
+ xLabel = "√({xlabel}²+{ylabel}²)"
+ yLabel = str(method).capitalize()
+
+ # Use the axis names from the original plot
+ profileManager = self.getProfileManager()
+ plot = profileManager.getPlotWidget()
+ xLabel = _relabelAxes(plot, xLabel)
+ title = _relabelAxes(plot, title)
+
+ data = core.CurveProfileData(
+ coords=coords,
+ profile=profile,
+ title=title,
+ xLabel=xLabel,
+ yLabel=yLabel,
+ )
+ return data
+
+
+class _ProfileCrossROI(roi_items.HandleBasedROI, core.ProfileRoiMixIn):
+
+ """ROI to manage a cross of profiles
+
+ It is managed using 2 sub ROIs for vertical and horizontal.
+ """
+
+ _kind = "Cross"
+ """Label for this kind of ROI"""
+
+ _plotShape = "point"
+ """Plot shape which is used for the first interaction"""
+
+ def __init__(self, parent=None):
+ roi_items.HandleBasedROI.__init__(self, parent=parent)
+ core.ProfileRoiMixIn.__init__(self, parent=parent)
+ self.sigRegionChanged.connect(self.__regionChanged)
+ self.sigAboutToBeRemoved.connect(self.__aboutToBeRemoved)
+ self.__position = 0, 0
+ self.__vline = None
+ self.__hline = None
+ self.__handle = self.addHandle()
+ self.__handleLabel = self.addLabelHandle()
+ self.__handleLabel.setText(self.getName())
+ self.__inhibitReentance = utils.LockReentrant()
+ self.computeProfile = None
+ self.sigItemChanged.connect(self.__updateLineProperty)
+
+ # Make sure the marker is over the ROIs
+ self.__handle.setZValue(1)
+ # Create the vline and the hline
+ self._createSubRois()
+
+ @docstring(roi_items.HandleBasedROI)
+ def contains(self, position):
+ roiPos = self.getPosition()
+ return position[0] == roiPos[0] or position[1] == roiPos[1]
+
+ def setFirstShapePoints(self, points):
+ pos = points[0]
+ self.setPosition(pos)
+
+ def getPosition(self):
+ """Returns the position of this ROI
+
+ :rtype: numpy.ndarray
+ """
+ return self.__position
+
+ def setPosition(self, pos):
+ """Set the position of this ROI
+
+ :param numpy.ndarray pos: 2d-coordinate of this point
+ """
+ self.__position = pos
+ with utils.blockSignals(self.__handle):
+ self.__handle.setPosition(*pos)
+ with utils.blockSignals(self.__handleLabel):
+ self.__handleLabel.setPosition(*pos)
+ self.sigRegionChanged.emit()
+
+ def handleDragUpdated(self, handle, origin, previous, current):
+ if handle is self.__handle:
+ self.setPosition(current)
+
+ def __updateLineProperty(self, event=None, checkVisibility=True):
+ if event == items.ItemChangedType.NAME:
+ self.__handleLabel.setText(self.getName())
+ elif event in [items.ItemChangedType.COLOR, items.ItemChangedType.VISIBLE]:
+ lines = []
+ if self.__vline:
+ lines.append(self.__vline)
+ if self.__hline:
+ lines.append(self.__hline)
+ self._updateItemProperty(event, self, lines)
+
+ def _createLines(self, parent):
+ """Inherit this function to return 2 ROI objects for respectivly
+ the horizontal, and the vertical lines."""
+ raise NotImplementedError()
+
+ def _setProfileManager(self, profileManager):
+ core.ProfileRoiMixIn._setProfileManager(self, profileManager)
+ # Connecting the vline and the hline
+ roiManager = profileManager.getRoiManager()
+ roiManager.addRoi(self.__vline)
+ roiManager.addRoi(self.__hline)
+
+ def _createSubRois(self):
+ hline, vline = self._createLines(parent=None)
+ for i, line in enumerate([vline, hline]):
+ line.setPosition(self.__position[i])
+ line.setEditable(True)
+ line.setSelectable(True)
+ line.setFocusProxy(self)
+ line.setName("")
+ self.__vline = vline
+ self.__hline = hline
+ vline.sigAboutToBeRemoved.connect(self.__vlineRemoved)
+ vline.sigRegionChanged.connect(self.__vlineRegionChanged)
+ hline.sigAboutToBeRemoved.connect(self.__hlineRemoved)
+ hline.sigRegionChanged.connect(self.__hlineRegionChanged)
+
+ def _getLines(self):
+ return self.__hline, self.__vline
+
+ def __regionChanged(self):
+ if self.__inhibitReentance.locked():
+ return
+ x, y = self.getPosition()
+ hline, vline = self._getLines()
+ if hline is None:
+ return
+ with self.__inhibitReentance:
+ hline.setPosition(y)
+ vline.setPosition(x)
+
+ def __vlineRegionChanged(self):
+ if self.__inhibitReentance.locked():
+ return
+ pos = self.getPosition()
+ vline = self.__vline
+ pos = vline.getPosition(), pos[1]
+ with self.__inhibitReentance:
+ self.setPosition(pos)
+
+ def __hlineRegionChanged(self):
+ if self.__inhibitReentance.locked():
+ return
+ pos = self.getPosition()
+ hline = self.__hline
+ pos = pos[0], hline.getPosition()
+ with self.__inhibitReentance:
+ self.setPosition(pos)
+
+ def __aboutToBeRemoved(self):
+ vline = self.__vline
+ hline = self.__hline
+ # Avoid side remove signals
+ if hline is not None:
+ hline.sigAboutToBeRemoved.disconnect(self.__hlineRemoved)
+ hline.sigRegionChanged.disconnect(self.__hlineRegionChanged)
+ if vline is not None:
+ vline.sigAboutToBeRemoved.disconnect(self.__vlineRemoved)
+ vline.sigRegionChanged.disconnect(self.__vlineRegionChanged)
+ # Clean up the child
+ profileManager = self.getProfileManager()
+ roiManager = profileManager.getRoiManager()
+ if hline is not None:
+ roiManager.removeRoi(hline)
+ self.__hline = None
+ if vline is not None:
+ roiManager.removeRoi(vline)
+ self.__vline = None
+
+ def __hlineRemoved(self):
+ self.__lineRemoved(isHline=True)
+
+ def __vlineRemoved(self):
+ self.__lineRemoved(isHline=False)
+
+ def __lineRemoved(self, isHline):
+ """If any of the lines is removed: disconnect this objects, and let the
+ other one persist"""
+ hline, vline = self._getLines()
+
+ hline.sigAboutToBeRemoved.disconnect(self.__hlineRemoved)
+ vline.sigAboutToBeRemoved.disconnect(self.__vlineRemoved)
+ hline.sigRegionChanged.disconnect(self.__hlineRegionChanged)
+ vline.sigRegionChanged.disconnect(self.__vlineRegionChanged)
+
+ self.__hline = None
+ self.__vline = None
+ profileManager = self.getProfileManager()
+ roiManager = profileManager.getRoiManager()
+ if isHline:
+ self.__releaseLine(vline)
+ else:
+ self.__releaseLine(hline)
+ roiManager.removeRoi(self)
+
+ def __releaseLine(self, line):
+ """Release the line in order to make it independent"""
+ line.setFocusProxy(None)
+ line.setName(self.getName())
+ line.setEditable(self.isEditable())
+ line.setSelectable(self.isSelectable())
+
+
+class ProfileImageCrossROI(_ProfileCrossROI):
+ """ROI to manage a cross of profiles
+
+ It is managed using 2 sub ROIs for vertical and horizontal.
+ """
+
+ ICON = "shape-cross"
+ NAME = "cross profile"
+ ITEM_KIND = items.ImageBase
+
+ def _createLines(self, parent):
+ vline = ProfileImageVerticalLineROI(parent=parent)
+ hline = ProfileImageHorizontalLineROI(parent=parent)
+ return hline, vline
+
+ def setProfileMethod(self, method):
+ """
+ :param str method: method to compute the profile. Can be 'mean' or 'sum'
+ """
+ hline, vline = self._getLines()
+ hline.setProfileMethod(method)
+ vline.setProfileMethod(method)
+ self.invalidateProperties()
+
+ def getProfileMethod(self):
+ hline, _vline = self._getLines()
+ return hline.getProfileMethod()
+
+ def setProfileLineWidth(self, width):
+ hline, vline = self._getLines()
+ hline.setProfileLineWidth(width)
+ vline.setProfileLineWidth(width)
+ self.invalidateProperties()
+
+ def getProfileLineWidth(self):
+ hline, _vline = self._getLines()
+ return hline.getProfileLineWidth()
+
+
+class _DefaultScatterProfileRoiMixIn(core.ProfileRoiMixIn):
+ """Provide common behavior for silx default scatter profile ROI."""
+
+ ITEM_KIND = items.Scatter
+
+ def __init__(self, parent=None):
+ core.ProfileRoiMixIn.__init__(self, parent=parent)
+ self.__nPoints = 1024
+ self.sigRegionChanged.connect(self.__regionChanged)
+
+ def __regionChanged(self):
+ self.invalidateProfile()
+
+ # Number of points
+
+ def getNPoints(self):
+ """Returns the number of points of the profiles
+
+ :rtype: int
+ """
+ return self.__nPoints
+
+ def setNPoints(self, npoints):
+ """Set the number of points of the profiles
+
+ :param int npoints:
+ """
+ npoints = int(npoints)
+ if npoints < 1:
+ raise ValueError("Unsupported number of points: %d" % npoints)
+ elif npoints != self.__nPoints:
+ self.__nPoints = npoints
+ self.invalidateProperties()
+ self.invalidateProfile()
+
+ def _computeProfile(self, scatter, x0, y0, x1, y1):
+ """Compute corresponding profile
+
+ :param float x0: Profile start point X coord
+ :param float y0: Profile start point Y coord
+ :param float x1: Profile end point X coord
+ :param float y1: Profile end point Y coord
+ :return: (points, values) profile data or None
+ """
+ future = scatter._getInterpolatorFuture()
+ try:
+ interpolator = future.result()
+ except CancelledError:
+ return None
+ if interpolator is None:
+ return None # Cannot init an interpolator
+
+ nPoints = self.getNPoints()
+ x = numpy.linspace(x0, x1, nPoints, endpoint=True)
+ y = numpy.linspace(y0, y1, nPoints, endpoint=True)
+
+ values = interpolator(x, y)
+ if not numpy.any(numpy.isfinite(values)):
+ return None # Profile outside convex hull
+
+ points = numpy.transpose((x, y))
+ return points, values
+
+ def computeProfile(self, item):
+ """Update profile according to current ROI"""
+ if not isinstance(item, items.Scatter):
+ raise TypeError("Unexpected class %s" % type(item))
+
+ # Get end points
+ if isinstance(self, roi_items.LineROI):
+ points = self.getEndPoints()
+ x0, y0 = points[0]
+ x1, y1 = points[1]
+ elif isinstance(self, (roi_items.VerticalLineROI, roi_items.HorizontalLineROI)):
+ profileManager = self.getProfileManager()
+ plot = profileManager.getPlotWidget()
+
+ if isinstance(self, roi_items.HorizontalLineROI):
+ x0, x1 = plot.getXAxis().getLimits()
+ y0 = y1 = self.getPosition()
+
+ elif isinstance(self, roi_items.VerticalLineROI):
+ x0 = x1 = self.getPosition()
+ y0, y1 = plot.getYAxis().getLimits()
+ else:
+ raise RuntimeError("Unsupported ROI for profile: {}".format(self.__class__))
+
+ if x1 < x0 or (x1 == x0 and y1 < y0):
+ # Invert points
+ x0, y0, x1, y1 = x1, y1, x0, y0
+
+ profile = self._computeProfile(item, x0, y0, x1, y1)
+ if profile is None:
+ return None
+
+ title = _lineProfileTitle(x0, y0, x1, y1)
+ points = profile[0]
+ values = profile[1]
+
+ if numpy.abs(points[-1, 0] - points[0, 0]) > numpy.abs(
+ points[-1, 1] - points[0, 1]
+ ):
+ xProfile = points[:, 0]
+ xLabel = "{xlabel}"
+ else:
+ xProfile = points[:, 1]
+ xLabel = "{ylabel}"
+
+ # Use the axis names from the original
+ profileManager = self.getProfileManager()
+ plot = profileManager.getPlotWidget()
+ title = _relabelAxes(plot, title)
+ xLabel = _relabelAxes(plot, xLabel)
+
+ data = core.CurveProfileData(
+ coords=xProfile,
+ profile=values,
+ title=title,
+ xLabel=xLabel,
+ yLabel="Profile",
+ )
+ return data
+
+
+class ProfileScatterHorizontalLineROI(
+ roi_items.HorizontalLineROI, _DefaultScatterProfileRoiMixIn
+):
+ """ROI for an horizontal profile at a location of a scatter"""
+
+ ICON = "shape-horizontal"
+ NAME = "horizontal line profile"
+
+ def __init__(self, parent=None):
+ roi_items.HorizontalLineROI.__init__(self, parent=parent)
+ _DefaultScatterProfileRoiMixIn.__init__(self, parent=parent)
+
+
+class ProfileScatterVerticalLineROI(
+ roi_items.VerticalLineROI, _DefaultScatterProfileRoiMixIn
+):
+ """ROI for an horizontal profile at a location of a scatter"""
+
+ ICON = "shape-vertical"
+ NAME = "vertical line profile"
+
+ def __init__(self, parent=None):
+ roi_items.VerticalLineROI.__init__(self, parent=parent)
+ _DefaultScatterProfileRoiMixIn.__init__(self, parent=parent)
+
+
+class ProfileScatterLineROI(roi_items.LineROI, _DefaultScatterProfileRoiMixIn):
+ """ROI for an horizontal profile at a location of a scatter"""
+
+ ICON = "shape-diagonal"
+ NAME = "line profile"
+
+ def __init__(self, parent=None):
+ roi_items.LineROI.__init__(self, parent=parent)
+ _DefaultScatterProfileRoiMixIn.__init__(self, parent=parent)
+
+
+class ProfileScatterCrossROI(_ProfileCrossROI):
+ """ROI to manage a cross of profiles for scatters."""
+
+ ICON = "shape-cross"
+ NAME = "cross profile"
+ ITEM_KIND = items.Scatter
+
+ def _createLines(self, parent):
+ vline = ProfileScatterVerticalLineROI(parent=parent)
+ hline = ProfileScatterHorizontalLineROI(parent=parent)
+ return hline, vline
+
+ def getNPoints(self):
+ """Returns the number of points of the profiles
+
+ :rtype: int
+ """
+ hline, _vline = self._getLines()
+ return hline.getNPoints()
+
+ def setNPoints(self, npoints):
+ """Set the number of points of the profiles
+
+ :param int npoints:
+ """
+ hline, vline = self._getLines()
+ hline.setNPoints(npoints)
+ vline.setNPoints(npoints)
+ self.invalidateProperties()
+
+
+class _DefaultScatterProfileSliceRoiMixIn(core.ProfileRoiMixIn):
+ """Default ROI to allow to slice in the scatter data."""
+
+ ITEM_KIND = items.Scatter
+
+ def __init__(self, parent=None):
+ core.ProfileRoiMixIn.__init__(self, parent=parent)
+ self.__area = _SliceProfileArea(self)
+ self.addItem(self.__area)
+ self.sigRegionChanged.connect(self._regionChanged)
+ self.sigPlotItemChanged.connect(self._updateArea)
+
+ def _regionChanged(self):
+ self.invalidateProfile()
+ self._updateArea()
+
+ def _updateArea(self):
+ plotItem = self.getPlotItem()
+ if plotItem is None:
+ self.setLineStyle("-")
+ else:
+ self.setLineStyle("--")
+
+ def _getSlice(self, item):
+ position = self.getPosition()
+ bounds = item.getCurrentVisualizationParameter(
+ items.Scatter.VisualizationParameter.GRID_BOUNDS
+ )
+ if isinstance(self, roi_items.HorizontalLineROI):
+ axis = 1
+ elif isinstance(self, roi_items.VerticalLineROI):
+ axis = 0
+ else:
+ assert False
+ if bounds is None or position < bounds[0][axis] or position > bounds[1][axis]:
+ # ROI outside of the scatter bound
+ return None
+
+ major_order = item.getCurrentVisualizationParameter(
+ items.Scatter.VisualizationParameter.GRID_MAJOR_ORDER
+ )
+ assert major_order == "row"
+ max_grid_yy, max_grid_xx = item.getCurrentVisualizationParameter(
+ items.Scatter.VisualizationParameter.GRID_SHAPE
+ )
+
+ xx, yy, _values, _xx_error, _yy_error = item.getData(copy=False)
+ if isinstance(self, roi_items.HorizontalLineROI):
+ axis = yy
+ max_grid_first = max_grid_yy
+ max_grid_second = max_grid_xx
+ major_axis = major_order == "column"
+ elif isinstance(self, roi_items.VerticalLineROI):
+ axis = xx
+ max_grid_first = max_grid_xx
+ max_grid_second = max_grid_yy
+ major_axis = major_order == "row"
+ else:
+ assert False
+
+ def argnearest(array, value):
+ array = numpy.abs(array - value)
+ return numpy.argmin(array)
+
+ if major_axis:
+ # slice in the middle of the scatter
+ actual_size_grid_second = len(axis) // max_grid_first
+ start = actual_size_grid_second // 2 * max_grid_first
+ vslice = axis[start : start + max_grid_first]
+ if len(vslice) == 0:
+ return None
+ index = argnearest(vslice, position)
+ slicing = slice(index, None, max_grid_first)
+ else:
+ # slice in the middle of the scatter
+ actual_size_grid_second = len(axis) // max_grid_first
+ vslice = axis[actual_size_grid_second // 2 :: max_grid_second]
+ if len(vslice) == 0:
+ return None
+ index = argnearest(vslice, position)
+ start = index * max_grid_second
+ slicing = slice(start, start + max_grid_second)
+
+ return slicing
+
+ def computeProfile(self, item):
+ if not isinstance(item, items.Scatter):
+ raise TypeError("Unsupported %s item" % type(item))
+
+ slicing = self._getSlice(item)
+ if slicing is None:
+ # ROI out of bounds
+ return None
+
+ _xx, _yy, values, _xx_error, _yy_error = item.getData(copy=False)
+ profile = values[slicing]
+
+ if isinstance(self, roi_items.HorizontalLineROI):
+ title = "Horizontal slice"
+ xLabel = "{xlabel} index"
+ elif isinstance(self, roi_items.VerticalLineROI):
+ title = "Vertical slice"
+ xLabel = "{ylabel} index"
+ else:
+ assert False
+
+ # Use the axis names from the original plot
+ profileManager = self.getProfileManager()
+ plot = profileManager.getPlotWidget()
+ xLabel = _relabelAxes(plot, xLabel)
+
+ data = core.CurveProfileData(
+ coords=numpy.arange(len(profile)),
+ profile=profile,
+ title=title,
+ xLabel=xLabel,
+ yLabel="Profile",
+ )
+ return data
+
+
+class ProfileScatterHorizontalSliceROI(
+ roi_items.HorizontalLineROI, _DefaultScatterProfileSliceRoiMixIn
+):
+ """ROI for an horizontal profile at a location of a scatter
+ using data slicing.
+ """
+
+ ICON = "slice-horizontal"
+ NAME = "horizontal data slice profile"
+
+ def __init__(self, parent=None):
+ roi_items.HorizontalLineROI.__init__(self, parent=parent)
+ _DefaultScatterProfileSliceRoiMixIn.__init__(self, parent=parent)
+
+
+class ProfileScatterVerticalSliceROI(
+ roi_items.VerticalLineROI, _DefaultScatterProfileSliceRoiMixIn
+):
+ """ROI for a vertical profile at a location of a scatter
+ using data slicing.
+ """
+
+ ICON = "slice-vertical"
+ NAME = "vertical data slice profile"
+
+ def __init__(self, parent=None):
+ roi_items.VerticalLineROI.__init__(self, parent=parent)
+ _DefaultScatterProfileSliceRoiMixIn.__init__(self, parent=parent)
+
+
+class ProfileScatterCrossSliceROI(_ProfileCrossROI):
+ """ROI to manage a cross of slicing profiles on scatters."""
+
+ ICON = "slice-cross"
+ NAME = "cross data slice profile"
+ ITEM_KIND = items.Scatter
+
+ def _createLines(self, parent):
+ vline = ProfileScatterVerticalSliceROI(parent=parent)
+ hline = ProfileScatterHorizontalSliceROI(parent=parent)
+ return hline, vline
+
+
+class _DefaultImageStackProfileRoiMixIn(_DefaultImageProfileRoiMixIn):
+ ITEM_KIND = items.ImageStack
+
+ def __init__(self, parent=None):
+ super(_DefaultImageStackProfileRoiMixIn, self).__init__(parent=parent)
+ self.__profileType = "1D"
+ """Kind of profile"""
+
+ def getProfileType(self):
+ return self.__profileType
+
+ def setProfileType(self, kind):
+ assert kind in ["1D", "2D"]
+ if self.__profileType == kind:
+ return
+ self.__profileType = kind
+ self.invalidateProperties()
+ self.invalidateProfile()
+
+ def computeProfile(self, item):
+ if not isinstance(item, items.ImageStack):
+ raise TypeError("Unexpected class %s" % type(item))
+
+ kind = self.getProfileType()
+ if kind == "1D":
+ result = _DefaultImageProfileRoiMixIn.computeProfile(self, item)
+ # z = item.getStackPosition()
+ return result
+
+ assert kind == "2D"
+
+ currentData = numpy.array(item.getStackData(copy=False))
+ origin = item.getOrigin()
+ scale = item.getScale()
+ colormap = item.getColormap()
+ method = self.getProfileMethod()
+ roiInfo = self._getRoiInfo()
+
+ def createProfile2(currentData):
+ coords, profile, _area, profileName, xLabel = core.createProfile(
+ roiInfo=roiInfo,
+ currentData=currentData,
+ origin=origin,
+ scale=scale,
+ lineWidth=self.getProfileLineWidth(),
+ method=method,
+ )
+ return coords, profile, profileName, xLabel
+
+ coords, profile, profileName, xLabel = createProfile2(currentData)
+
+ profileManager = self.getProfileManager()
+ plot = profileManager.getPlotWidget()
+
+ data = core.ImageProfileData(
+ coords=coords,
+ profile=profile,
+ title=_relabelAxes(plot, profileName),
+ xLabel=_relabelAxes(plot, xLabel),
+ yLabel="Profile",
+ colormap=colormap,
+ )
+ return data
+
+
+class ProfileImageStackHorizontalLineROI(
+ roi_items.HorizontalLineROI, _DefaultImageStackProfileRoiMixIn
+):
+ """ROI for an horizontal profile at a location of a stack of images"""
+
+ ICON = "shape-horizontal"
+ NAME = "horizontal line profile"
+
+ def __init__(self, parent=None):
+ roi_items.HorizontalLineROI.__init__(self, parent=parent)
+ _DefaultImageStackProfileRoiMixIn.__init__(self, parent=parent)
+
+
+class ProfileImageStackVerticalLineROI(
+ roi_items.VerticalLineROI, _DefaultImageStackProfileRoiMixIn
+):
+ """ROI for an vertical profile at a location of a stack of images"""
+
+ ICON = "shape-vertical"
+ NAME = "vertical line profile"
+
+ def __init__(self, parent=None):
+ roi_items.VerticalLineROI.__init__(self, parent=parent)
+ _DefaultImageStackProfileRoiMixIn.__init__(self, parent=parent)
+
+
+class ProfileImageStackLineROI(roi_items.LineROI, _DefaultImageStackProfileRoiMixIn):
+ """ROI for an vertical profile at a location of a stack of images"""
+
+ ICON = "shape-diagonal"
+ NAME = "line profile"
+
+ def __init__(self, parent=None):
+ roi_items.LineROI.__init__(self, parent=parent)
+ _DefaultImageStackProfileRoiMixIn.__init__(self, parent=parent)
+
+
+class ProfileImageStackCrossROI(ProfileImageCrossROI):
+ """ROI for an vertical profile at a location of a stack of images"""
+
+ ICON = "shape-cross"
+ NAME = "cross profile"
+ ITEM_KIND = items.ImageStack
+
+ def _createLines(self, parent):
+ vline = ProfileImageStackVerticalLineROI(parent=parent)
+ hline = ProfileImageStackHorizontalLineROI(parent=parent)
+ return hline, vline
+
+ def getProfileType(self):
+ hline, _vline = self._getLines()
+ return hline.getProfileType()
+
+ def setProfileType(self, kind):
+ hline, vline = self._getLines()
+ hline.setProfileType(kind)
+ vline.setProfileType(kind)
+ self.invalidateProperties()
diff --git a/src/silx/gui/plot/tools/profile/toolbar.py b/src/silx/gui/plot/tools/profile/toolbar.py
new file mode 100644
index 0000000..d073717
--- /dev/null
+++ b/src/silx/gui/plot/tools/profile/toolbar.py
@@ -0,0 +1,172 @@
+# /*##########################################################################
+#
+# Copyright (c) 2018-2019 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""This module provides tool bar helper.
+"""
+
+__authors__ = ["V. Valls"]
+__license__ = "MIT"
+__date__ = "28/06/2018"
+
+
+import logging
+import weakref
+
+from silx.gui import qt
+from silx.gui.widgets.MultiModeAction import MultiModeAction
+from . import manager
+from .. import roi as roi_mdl
+from silx.gui.plot import items
+
+
+_logger = logging.getLogger(__name__)
+
+
+class ProfileToolBar(qt.QToolBar):
+ """Tool bar to provide profile for a plot.
+
+ It is an helper class. For a dedicated application it would be better to
+ use an own tool bar in order in order have more flexibility.
+ """
+
+ def __init__(self, parent=None, plot=None):
+ super(ProfileToolBar, self).__init__(parent=parent)
+ self.__scheme = None
+ self.__manager = None
+ self.__plot = weakref.ref(plot)
+ self.__multiAction = None
+
+ def getPlotWidget(self):
+ """The :class:`~silx.gui.plot.PlotWidget` associated to the toolbar.
+
+ :rtype: Union[~silx.gui.plot.PlotWidget,None]
+ """
+ if self.__plot is None:
+ return None
+ plot = self.__plot()
+ if self.__plot is None:
+ self.__plot = None
+ return plot
+
+ def setScheme(self, scheme):
+ """Initialize the tool bar using a configuration scheme.
+
+ It have to be done once and only once.
+
+ :param str scheme: One of "scatter", "image", "imagestack"
+ """
+ assert self.__scheme is None
+ self.__scheme = scheme
+
+ plot = self.getPlotWidget()
+ self.__manager = manager.ProfileManager(self, plot)
+
+ if scheme == "image":
+ self.__manager.setItemType(image=True)
+ self.__manager.setActiveItemTracking(True)
+
+ multiAction = MultiModeAction(self)
+ self.addAction(multiAction)
+ for action in self.__manager.createImageActions(self):
+ multiAction.addAction(action)
+ self.__multiAction = multiAction
+
+ cleanAction = self.__manager.createClearAction(self)
+ self.addAction(cleanAction)
+ editorAction = self.__manager.createEditorAction(self)
+ self.addAction(editorAction)
+
+ plot.sigActiveImageChanged.connect(self._activeImageChanged)
+ self._activeImageChanged()
+
+ elif scheme == "scatter":
+ self.__manager.setItemType(scatter=True)
+ self.__manager.setActiveItemTracking(True)
+
+ multiAction = MultiModeAction(self)
+ self.addAction(multiAction)
+ for action in self.__manager.createScatterActions(self):
+ multiAction.addAction(action)
+ for action in self.__manager.createScatterSliceActions(self):
+ multiAction.addAction(action)
+ self.__multiAction = multiAction
+
+ cleanAction = self.__manager.createClearAction(self)
+ self.addAction(cleanAction)
+ editorAction = self.__manager.createEditorAction(self)
+ self.addAction(editorAction)
+
+ plot.sigActiveScatterChanged.connect(self._activeScatterChanged)
+ self._activeScatterChanged()
+
+ elif scheme == "imagestack":
+ self.__manager.setItemType(image=True)
+ self.__manager.setActiveItemTracking(True)
+
+ multiAction = MultiModeAction(self)
+ self.addAction(multiAction)
+ for action in self.__manager.createImageStackActions(self):
+ multiAction.addAction(action)
+ self.__multiAction = multiAction
+
+ cleanAction = self.__manager.createClearAction(self)
+ self.addAction(cleanAction)
+ editorAction = self.__manager.createEditorAction(self)
+ self.addAction(editorAction)
+
+ plot.sigActiveImageChanged.connect(self._activeImageChanged)
+ self._activeImageChanged()
+
+ else:
+ raise ValueError("Toolbar scheme %s unsupported" % scheme)
+
+ def _setRoiActionEnabled(self, itemKind, enabled):
+ for action in self.__multiAction.getMenu().actions():
+ if not isinstance(action, roi_mdl.CreateRoiModeAction):
+ continue
+ roiClass = action.getRoiClass()
+ if issubclass(itemKind, roiClass.ITEM_KIND):
+ action.setEnabled(enabled)
+
+ def _activeImageChanged(self, previous=None, legend=None):
+ """Handle active image change to toggle actions"""
+ if legend is None:
+ self._setRoiActionEnabled(items.ImageStack, False)
+ self._setRoiActionEnabled(items.ImageBase, False)
+ else:
+ plot = self.getPlotWidget()
+ image = plot.getActiveImage()
+ # Disable for empty image
+ enabled = image.getData(copy=False).size > 0
+ self._setRoiActionEnabled(type(image), enabled)
+
+ def _activeScatterChanged(self, previous=None, legend=None):
+ """Handle active scatter change to toggle actions"""
+ if legend is None:
+ self._setRoiActionEnabled(items.Scatter, False)
+ else:
+ plot = self.getPlotWidget()
+ scatter = plot.getActiveScatter()
+ # Disable for empty image
+ enabled = scatter.getValueData(copy=False).size > 0
+ self._setRoiActionEnabled(type(scatter), enabled)
diff --git a/src/silx/gui/plot/tools/roi.py b/src/silx/gui/plot/tools/roi.py
new file mode 100644
index 0000000..21b9409
--- /dev/null
+++ b/src/silx/gui/plot/tools/roi.py
@@ -0,0 +1,1515 @@
+# /*##########################################################################
+#
+# Copyright (c) 2018-2023 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""This module provides ROI interaction for :class:`~silx.gui.plot.PlotWidget`.
+"""
+
+__authors__ = ["T. Vincent"]
+__license__ = "MIT"
+__date__ = "28/06/2018"
+
+
+import enum
+import logging
+import time
+import weakref
+import functools
+from typing import Optional
+
+import numpy
+
+from ... import qt, icons
+from ...utils import blockSignals
+from ...utils import LockReentrant
+from .. import PlotWidget
+from ..items import roi as roi_items
+from ..items import ItemChangedType
+from ..items.roi import RegionOfInterest
+
+from ...colors import rgba
+
+
+logger = logging.getLogger(__name__)
+
+
+class CreateRoiModeAction(qt.QAction):
+ """
+ This action is a plot mode which allows to create new ROIs using a ROI
+ manager.
+
+ A ROI is created using a specific `roiClass`. `initRoi` and `finalizeRoi`
+ can be inherited to custom the ROI initialization.
+
+ :param class roiClass: The ROI class which will be created by this action.
+ :param qt.QObject parent: The action parent
+ :param RegionOfInterestManager roiManager: The ROI manager
+ """
+
+ def __init__(self, parent, roiManager, roiClass):
+ assert roiManager is not None
+ assert roiClass is not None
+ qt.QAction.__init__(self, parent=parent)
+ self._roiManager = weakref.ref(roiManager)
+ self._roiClass = roiClass
+ self._singleShot = False
+ self._initAction()
+ self.triggered[bool].connect(self._actionTriggered)
+
+ def _initAction(self):
+ """Default initialization of the action"""
+ roiClass = self._roiClass
+
+ name = None
+ iconName = None
+ if hasattr(roiClass, "NAME"):
+ name = roiClass.NAME
+ if hasattr(roiClass, "ICON"):
+ iconName = roiClass.ICON
+
+ if iconName is None:
+ iconName = "add-shape-unknown"
+ if name is None:
+ name = roiClass.__name__
+ text = "Add %s" % name
+ self.setIcon(icons.getQIcon(iconName))
+ self.setText(text)
+ self.setCheckable(True)
+ self.setToolTip(text)
+
+ def getRoiClass(self):
+ """Return the ROI class used by this action to create ROIs"""
+ return self._roiClass
+
+ def getRoiManager(self):
+ return self._roiManager()
+
+ def setSingleShot(self, singleShot):
+ """Set it to True to deactivate the action after the first creation
+ of a ROI.
+
+ :param bool singleShot: New single short state
+ """
+ self._singleShot = singleShot
+
+ def getSingleShot(self):
+ """If True, after the first creation of a ROI with this mode,
+ the mode is deactivated.
+
+ :rtype: bool
+ """
+ return self._singleShot
+
+ def _actionTriggered(self, checked):
+ """Handle mode actions being checked by the user
+
+ :param bool checked:
+ :param str kind: Corresponding shape kind
+ """
+ roiManager = self.getRoiManager()
+ if roiManager is None:
+ return
+
+ if checked:
+ roiManager.start(self._roiClass, self)
+ self.__interactiveModeStarted(roiManager)
+ else:
+ source = roiManager.getInteractionSource()
+ if source is self:
+ roiManager.stop()
+
+ def __interactiveModeStarted(self, roiManager):
+ roiManager.sigInteractiveRoiCreated.connect(self.initRoi)
+ roiManager.sigInteractiveRoiFinalized.connect(self.__finalizeRoi)
+ roiManager.sigInteractiveModeFinished.connect(self.__interactiveModeFinished)
+
+ def __interactiveModeFinished(self):
+ roiManager = self.getRoiManager()
+ if roiManager is not None:
+ roiManager.sigInteractiveRoiCreated.disconnect(self.initRoi)
+ roiManager.sigInteractiveRoiFinalized.disconnect(self.__finalizeRoi)
+ roiManager.sigInteractiveModeFinished.disconnect(
+ self.__interactiveModeFinished
+ )
+ self.setChecked(False)
+
+ def initRoi(self, roi):
+ """Inherit it to custom the new ROI at it's creation during the
+ interaction."""
+ pass
+
+ def __finalizeRoi(self, roi):
+ self.finalizeRoi(roi)
+ if self._singleShot:
+ roiManager = self.getRoiManager()
+ if roiManager is not None:
+ roiManager.stop()
+
+ def finalizeRoi(self, roi):
+ """Inherit it to custom the new ROI after it's creation when the
+ interaction is finalized."""
+ pass
+
+
+class RoiModeSelector(qt.QWidget):
+ def __init__(self, parent=None):
+ super(RoiModeSelector, self).__init__(parent=parent)
+ self.__roi = None
+ self.__reentrant = LockReentrant()
+
+ layout = qt.QHBoxLayout(self)
+ if isinstance(parent, qt.QMenu):
+ margins = layout.contentsMargins()
+ layout.setContentsMargins(margins.left(), 0, margins.right(), 0)
+ else:
+ layout.setContentsMargins(0, 0, 0, 0)
+
+ self._label = qt.QLabel(self)
+ self._label.setText("Mode:")
+ self._label.setToolTip("Select a specific interaction to edit the ROI")
+ self._combo = qt.QComboBox(self)
+ self._combo.currentIndexChanged.connect(self._modeSelected)
+ layout.addWidget(self._label)
+ layout.addWidget(self._combo)
+ self._updateAvailableModes()
+
+ def getRoi(self):
+ """Returns the edited ROI.
+
+ :rtype: roi_items.RegionOfInterest
+ """
+ return self.__roi
+
+ def setRoi(self, roi):
+ """Returns the edited ROI.
+
+ :rtype: roi_items.RegionOfInterest
+ """
+ if self.__roi is roi:
+ return
+ if not isinstance(roi, roi_items.InteractionModeMixIn):
+ self.__roi = None
+ self._updateAvailableModes()
+ return
+
+ if self.__roi is not None:
+ self.__roi.sigInteractionModeChanged.disconnect(self._modeChanged)
+ self.__roi = roi
+ if self.__roi is not None:
+ self.__roi.sigInteractionModeChanged.connect(self._modeChanged)
+ self._updateAvailableModes()
+
+ def isEmpty(self):
+ return not self._label.isVisibleTo(self)
+
+ def _updateAvailableModes(self):
+ roi = self.getRoi()
+ if isinstance(roi, roi_items.InteractionModeMixIn):
+ modes = roi.availableInteractionModes()
+ else:
+ modes = []
+ if len(modes) <= 1:
+ self._label.setVisible(False)
+ self._combo.setVisible(False)
+ else:
+ self._label.setVisible(True)
+ self._combo.setVisible(True)
+ with blockSignals(self._combo):
+ self._combo.clear()
+ for im, m in enumerate(modes):
+ self._combo.addItem(m.label, m)
+ self._combo.setItemData(im, m.description, qt.Qt.ToolTipRole)
+ mode = roi.getInteractionMode()
+ self._modeChanged(mode)
+ index = modes.index(mode)
+ self._combo.setCurrentIndex(index)
+
+ def _modeChanged(self, mode):
+ """Triggered when the ROI interaction mode was changed externally"""
+ if self.__reentrant.locked():
+ # This event was initialised by the widget
+ return
+ roi = self.__roi
+ modes = roi.availableInteractionModes()
+ index = modes.index(mode)
+ with blockSignals(self._combo):
+ self._combo.setCurrentIndex(index)
+
+ def _modeSelected(self):
+ """Triggered when the ROI interaction mode was selected in the widget"""
+ index = self._combo.currentIndex()
+ if index == -1:
+ return
+ roi = self.getRoi()
+ if roi is not None:
+ mode = self._combo.itemData(index, qt.Qt.UserRole)
+ with self.__reentrant:
+ roi.setInteractionMode(mode)
+
+
+class RoiModeSelectorAction(qt.QWidgetAction):
+ """Display the selected mode of a ROI and allow to change it"""
+
+ def __init__(self, parent=None):
+ super(RoiModeSelectorAction, self).__init__(parent)
+ self.__roiManager = None
+
+ def createWidget(self, parent):
+ """Inherit the method to create a new widget"""
+ widget = RoiModeSelector(parent)
+ manager = self.__roiManager
+ if manager is not None:
+ roi = manager.getCurrentRoi()
+ widget.setRoi(roi)
+ self.setVisible(not widget.isEmpty())
+ return widget
+
+ def deleteWidget(self, widget):
+ """Inherit the method to delete a widget"""
+ widget.setRoi(None)
+ return qt.QWidgetAction.deleteWidget(self, widget)
+
+ def setRoiManager(self, roiManager):
+ """
+ Connect this action to a ROI manager.
+
+ :param RegionOfInterestManager roiManager: A ROI manager
+ """
+ if self.__roiManager is roiManager:
+ return
+ if self.__roiManager is not None:
+ self.__roiManager.sigCurrentRoiChanged.disconnect(self.__currentRoiChanged)
+ self.__roiManager = roiManager
+ if self.__roiManager is not None:
+ self.__roiManager.sigCurrentRoiChanged.connect(self.__currentRoiChanged)
+ self.__currentRoiChanged(roiManager.getCurrentRoi())
+
+ def __currentRoiChanged(self, roi):
+ """Handle changes of the selected ROI"""
+ self.setRoi(roi)
+
+ def setRoi(self, roi):
+ """Set a profile ROI to edit.
+
+ :param ProfileRoiMixIn roi: A profile ROI
+ """
+ widget = None
+ for widget in self.createdWidgets():
+ widget.setRoi(roi)
+ if widget is not None:
+ self.setVisible(not widget.isEmpty())
+
+
+class RegionOfInterestManager(qt.QObject):
+ """Class handling ROI interaction on a PlotWidget.
+
+ It supports the multiple ROIs: points, rectangles, polygons,
+ lines, horizontal and vertical lines.
+
+ See ``plotInteractiveImageROI.py`` sample code (:ref:`sample-code`).
+
+ :param silx.gui.plot.PlotWidget parent:
+ The plot widget in which to control the ROIs.
+ """
+
+ sigRoiAdded = qt.Signal(roi_items.RegionOfInterest)
+ """Signal emitted when a new ROI has been added.
+
+ It provides the newly add :class:`RegionOfInterest` object.
+ """
+
+ sigRoiAboutToBeRemoved = qt.Signal(roi_items.RegionOfInterest)
+ """Signal emitted just before a ROI is removed.
+
+ It provides the :class:`RegionOfInterest` object that is about to be removed.
+ """
+
+ sigRoiChanged = qt.Signal()
+ """Signal emitted whenever the ROIs have changed."""
+
+ sigCurrentRoiChanged = qt.Signal(object)
+ """Signal emitted whenever a ROI is selected."""
+
+ sigInteractiveModeStarted = qt.Signal(object)
+ """Signal emitted when switching to ROI drawing interactive mode.
+
+ It provides the class of the ROI which will be created by the interactive
+ mode.
+ """
+
+ sigInteractiveRoiCreated = qt.Signal(object)
+ """Signal emitted when a ROI is created during the interaction.
+ The interaction is still incomplete and can be aborted.
+
+ It provides the ROI object which was just been created.
+ """
+
+ sigInteractiveRoiFinalized = qt.Signal(object)
+ """Signal emitted when a ROI creation is complet.
+
+ It provides the ROI object which was just been created.
+ """
+
+ sigInteractiveModeFinished = qt.Signal()
+ """Signal emitted when leaving interactive ROI drawing mode.
+ """
+
+ ROI_CLASSES = (
+ roi_items.PointROI,
+ roi_items.CrossROI,
+ roi_items.RectangleROI,
+ roi_items.CircleROI,
+ roi_items.EllipseROI,
+ roi_items.PolygonROI,
+ roi_items.LineROI,
+ roi_items.HorizontalLineROI,
+ roi_items.VerticalLineROI,
+ roi_items.ArcROI,
+ roi_items.HorizontalRangeROI,
+ roi_items.BandROI,
+ )
+
+ def __init__(self, parent):
+ assert isinstance(parent, PlotWidget)
+ super(RegionOfInterestManager, self).__init__(parent)
+ self._rois = [] # List of ROIs
+ self._drawnROI = None # New ROI being currently drawn
+
+ self._roiClass = None
+ self._source = None
+ self._lastHoveredMarkerLabel = None
+ self._color = rgba("red")
+
+ self._label = "__RegionOfInterestManager__%d" % id(self)
+
+ self._currentRoi = None
+ """Hold currently selected ROI"""
+
+ self._eventLoop = None
+
+ self._modeActions = {}
+
+ parent.sigPlotSignal.connect(self._plotSignals)
+
+ parent.sigInteractiveModeChanged.connect(self._plotInteractiveModeChanged)
+
+ parent.sigItemRemoved.connect(self._itemRemoved)
+
+ parent._sigDefaultContextMenu.connect(self._feedContextMenu)
+
+ @classmethod
+ def getSupportedRoiClasses(cls):
+ """Returns the default available ROI classes
+
+ :rtype: List[class]
+ """
+ return tuple(cls.ROI_CLASSES)
+
+ # Associated QActions
+
+ def getInteractionModeAction(self, roiClass):
+ """Returns the QAction corresponding to a kind of ROI
+
+ The QAction allows to enable the corresponding drawing
+ interactive mode.
+
+ :param class roiClass: The ROI class which will be created by this action.
+ :rtype: QAction
+ :raise ValueError: If kind is not supported
+ """
+ if not issubclass(roiClass, roi_items.RegionOfInterest):
+ raise ValueError("Unsupported ROI class %s" % roiClass)
+
+ action = self._modeActions.get(roiClass, None)
+ if action is None: # Lazy-loading
+ action = CreateRoiModeAction(self, self, roiClass)
+ self._modeActions[roiClass] = action
+ return action
+
+ # PlotWidget eventFilter and listeners
+
+ def _plotInteractiveModeChanged(self, source):
+ """Handle change of interactive mode in the plot"""
+ if source is not self:
+ self.__roiInteractiveModeEnded()
+
+ def _getRoiFromItem(self, item):
+ """Returns the ROI which own this item, else None
+ if this manager do not have knowledge of this ROI."""
+ for roi in self._rois:
+ if isinstance(roi, roi_items.RegionOfInterest):
+ for child in roi.getItems():
+ if child is item:
+ return roi
+ return None
+
+ def _itemRemoved(self, item):
+ """Called after an item was removed from the plot."""
+ if not hasattr(item, "_roiGroup"):
+ # Early break to avoid to use _getRoiFromItem
+ # And to avoid reentrant signal when the ROI remove the item itself
+ return
+ roi = self._getRoiFromItem(item)
+ if roi is not None:
+ self.removeRoi(roi)
+
+ # Handle ROI interaction
+
+ def _handleInteraction(self, event):
+ """Handle mouse interaction for ROI addition"""
+ roiClass = self.getCurrentInteractionModeRoiClass()
+ if roiClass is None:
+ return # Should not happen
+
+ kind = roiClass.getFirstInteractionShape()
+ if kind == "point":
+ if event["event"] == "mouseClicked" and event["button"] == "left":
+ points = numpy.array([(event["x"], event["y"])], dtype=numpy.float64)
+ # Not an interactive creation
+ roi = self._createInteractiveRoi(roiClass, points=points)
+ roi.creationFinalized()
+ self.sigInteractiveRoiFinalized.emit(roi)
+ else: # other shapes
+ if (
+ event["event"] in ("drawingProgress", "drawingFinished")
+ and event["parameters"]["label"] == self._label
+ ):
+ points = numpy.array(
+ (event["xdata"], event["ydata"]), dtype=numpy.float64
+ ).T
+
+ if self._drawnROI is None: # Create new ROI
+ # NOTE: Set something before createRoi, so isDrawing is True
+ self._drawnROI = object()
+ self._drawnROI = self._createInteractiveRoi(roiClass, points=points)
+ else:
+ self._drawnROI.setFirstShapePoints(points)
+
+ if event["event"] == "drawingFinished":
+ if kind == "polygon" and len(points) > 1:
+ self._drawnROI.setFirstShapePoints(points[:-1])
+ roi = self._drawnROI
+ self._drawnROI = None # Stop drawing
+ roi.creationFinalized()
+ self.sigInteractiveRoiFinalized.emit(roi)
+
+ # RegionOfInterest selection
+
+ def __getRoiFromMarker(self, marker):
+ """Returns a ROI from a marker, else None"""
+ # This should be speed up
+ for roi in self._rois:
+ if isinstance(roi, roi_items.HandleBasedROI):
+ for m in roi.getHandles():
+ if m is marker:
+ return roi
+ else:
+ for m in roi.getItems():
+ if m is marker:
+ return roi
+ return None
+
+ def setCurrentRoi(self, roi: Optional[RegionOfInterest]):
+ """Set the currently selected ROI, and emit a signal.
+
+ :param Union[RegionOfInterest,None] roi: The ROI to select
+ """
+ if self._currentRoi is roi:
+ return
+ if roi is not None:
+ # Note: Fixed range to avoid infinite loops
+ for _ in range(10):
+ target = roi.getFocusProxy()
+ if target is None:
+ break
+ roi = target
+ else:
+ raise RuntimeError("Max selection proxy depth (10) reached.")
+
+ if self._currentRoi is not None:
+ self._currentRoi.setHighlighted(False)
+ self._currentRoi = roi
+ if self._currentRoi is not None:
+ self._currentRoi.setHighlighted(True)
+ self.sigCurrentRoiChanged.emit(roi)
+
+ def getCurrentRoi(self) -> Optional[RegionOfInterest]:
+ """Returns the currently selected ROI, else None."""
+ return self._currentRoi
+
+ def _plotSignals(self, event):
+ """Handle mouse interaction for ROI addition"""
+ clicked = False
+ roi = None
+ if event["event"] in ("markerClicked", "markerMoving"):
+ plot = self.parent()
+ legend = event["label"]
+ marker = plot._getMarker(legend=legend)
+ roi = self.__getRoiFromMarker(marker)
+ elif event["event"] == "mouseClicked" and event["button"] == "left":
+ # Marker click is only for dnd
+ # This also can click on a marker
+ clicked = True
+ plot = self.parent()
+ marker = plot._getMarkerAt(event["xpixel"], event["ypixel"])
+ roi = self.__getRoiFromMarker(marker)
+ elif event["event"] == "hover":
+ self._lastHoveredMarkerLabel = event["label"]
+ else:
+ return
+
+ if roi not in self._rois:
+ # The ROI is not own by this manager
+ return
+
+ if roi is not None:
+ currentRoi = self.getCurrentRoi()
+ if currentRoi is roi:
+ if clicked:
+ self.__updateMode(roi)
+ elif roi.isSelectable():
+ self.setCurrentRoi(roi)
+ else:
+ self.setCurrentRoi(None)
+
+ def __updateMode(self, roi: RegionOfInterest):
+ if isinstance(roi, roi_items.InteractionModeMixIn):
+ available = roi.availableInteractionModes()
+ mode = roi.getInteractionMode()
+ imode = available.index(mode)
+ mode = available[(imode + 1) % len(available)]
+ roi.setInteractionMode(mode)
+
+ def _feedContextMenu(self, menu: qt.QMenu):
+ """Called when the default plot context menu is about to be displayed"""
+ roi = self.getCurrentRoi()
+ if roi is not None:
+ if roi.isEditable():
+ if self._isMouseHoverRoi(roi):
+ roiMenu = self._createMenuForRoi(menu, roi)
+ menu.addMenu(roiMenu)
+
+ def _isMouseHoverRoi(self, roi: RegionOfInterest) -> bool:
+ """Check that the mouse hovers this roi"""
+ plot = self.parent()
+
+ if self._lastHoveredMarkerLabel is not None:
+ marker = plot._getMarker(self._lastHoveredMarkerLabel)
+ if marker is not None:
+ r = self.__getRoiFromMarker(marker)
+ if roi is r:
+ return True
+
+ # Filter by data position
+ # FIXME: It would be better to use GUI coords for it
+ pos = plot.getWidgetHandle().mapFromGlobal(qt.QCursor.pos())
+ data = plot.pixelToData(pos.x(), pos.y())
+ return roi.contains(data)
+
+ def _createMenuForRoi(self, parent: qt.QWidget, roi: RegionOfInterest) -> qt.QMenu:
+ """Create a QMenu for the given RegionOfInterest"""
+ roiMenu = qt.QMenu(parent)
+ roiMenu.setTitle(roi.getName())
+
+ if isinstance(roi, roi_items.InteractionModeMixIn):
+ interactionMenu = roi.createMenuForInteractionMode(roiMenu)
+ roiMenu.addMenu(interactionMenu)
+
+ removeAction = qt.QAction(roiMenu)
+ removeAction.setText("Remove")
+ callback = functools.partial(self.removeRoi, roi)
+ removeAction.triggered.connect(callback)
+ roiMenu.addAction(removeAction)
+
+ roi.populateContextMenu(roiMenu)
+
+ return roiMenu
+
+ # RegionOfInterest API
+
+ def getRois(self):
+ """Returns the list of ROIs.
+
+ It returns an empty tuple if there is currently no ROI.
+
+ :return: Tuple of arrays of objects describing the ROIs
+ :rtype: List[RegionOfInterest]
+ """
+ return tuple(self._rois)
+
+ def clear(self):
+ """Reset current ROIs
+
+ :return: True if ROIs were reset.
+ :rtype: bool
+ """
+ if self.getRois(): # Something to reset
+ for roi in self._rois:
+ roi.sigRegionChanged.disconnect(self._regionOfInterestChanged)
+ roi.setParent(None)
+ self._rois = []
+ self._roisUpdated()
+ return True
+
+ else:
+ return False
+
+ def _regionOfInterestChanged(self, event=None):
+ """Handle ROI object changed"""
+ self.sigRoiChanged.emit()
+
+ def _createInteractiveRoi(self, roiClass, points, label=None, index=None):
+ """Create a new ROI with interactive creation.
+
+ :param class roiClass: The class of the ROI to create
+ :param numpy.ndarray points: The first shape used to create the ROI
+ :param str label: The label to display along with the ROI.
+ :param int index: The position where to insert the ROI.
+ By default it is appended to the end of the list.
+ :return: The created ROI object
+ :rtype: roi_items.RegionOfInterest
+ :raise RuntimeError: When ROI cannot be added because the maximum
+ number of ROIs has been reached.
+ """
+ roi = roiClass(parent=None)
+ if label is not None:
+ roi.setName(str(label))
+ roi.creationStarted()
+ roi.setFirstShapePoints(points)
+
+ self.addRoi(roi, index)
+ if roi.isSelectable():
+ self.setCurrentRoi(roi)
+ self.sigInteractiveRoiCreated.emit(roi)
+ return roi
+
+ def containsRoi(self, roi):
+ """Returns true if the ROI is part of this manager.
+
+ :param roi_items.RegionOfInterest roi: The ROI to add
+ :rtype: bool
+ """
+ return roi in self._rois
+
+ def addRoi(self, roi, index=None, useManagerColor=True):
+ """Add the ROI to the list of ROIs.
+
+ :param roi_items.RegionOfInterest roi: The ROI to add
+ :param int index: The position where to insert the ROI,
+ By default it is appended to the end of the list of ROIs
+ :param bool useManagerColor:
+ Whether to set the ROI color to the default one of the manager or not.
+ (Default: True).
+ :raise RuntimeError: When ROI cannot be added because the maximum
+ number of ROIs has been reached.
+ """
+ plot = self.parent()
+ if plot is None:
+ raise RuntimeError("Cannot add ROI: PlotWidget no more available")
+
+ roi.setParent(self)
+
+ if useManagerColor:
+ roi.setColor(self.getColor())
+
+ roi.sigRegionChanged.connect(self._regionOfInterestChanged)
+ roi.sigItemChanged.connect(self._regionOfInterestChanged)
+
+ if index is None:
+ self._rois.append(roi)
+ else:
+ self._rois.insert(index, roi)
+ self.sigRoiAdded.emit(roi)
+ self._roisUpdated()
+
+ def removeRoi(self, roi):
+ """Remove a ROI from the list of ROIs.
+
+ :param roi_items.RegionOfInterest roi: The ROI to remove
+ :raise ValueError: When ROI does not belong to this object
+ """
+ if not (
+ isinstance(roi, roi_items.RegionOfInterest)
+ and roi.parent() is self
+ and roi in self._rois
+ ):
+ raise ValueError("RegionOfInterest does not belong to this instance")
+
+ roi.sigAboutToBeRemoved.emit()
+ self.sigRoiAboutToBeRemoved.emit(roi)
+
+ if roi is self._currentRoi:
+ self.setCurrentRoi(None)
+
+ mustRestart = False
+ if roi is self._drawnROI:
+ self._drawnROI = None
+ mustRestart = True
+ self._rois.remove(roi)
+ roi.sigRegionChanged.disconnect(self._regionOfInterestChanged)
+ roi.sigItemChanged.disconnect(self._regionOfInterestChanged)
+ roi.setParent(None)
+ self._roisUpdated()
+
+ if mustRestart:
+ self._restart()
+
+ def _roisUpdated(self):
+ """Handle update of the ROI list"""
+ self.sigRoiChanged.emit()
+
+ # RegionOfInterest parameters
+
+ def getColor(self):
+ """Return the default color of created ROIs
+
+ :rtype: QColor
+ """
+ return qt.QColor.fromRgbF(*self._color)
+
+ def setColor(self, color):
+ """Set the default color to use when creating ROIs.
+
+ Existing ROIs are not affected.
+
+ :param color: The color to use for displaying ROIs as
+ either a color name, a QColor, a list of uint8 or float in [0, 1].
+ """
+ self._color = rgba(color)
+
+ # Control ROI
+
+ def getCurrentInteractionModeRoiClass(self):
+ """Returns the current ROI class used by the interactive drawing mode.
+
+ Returns None if the ROI manager is not in an interactive mode.
+
+ :rtype: Union[class,None]
+ """
+ return self._roiClass
+
+ def getInteractionSource(self):
+ """Returns the object which have requested the ROI creation.
+
+ Returns None if the ROI manager is not in an interactive mode.
+
+ :rtype: Union[object,None]
+ """
+ return self._source
+
+ def isStarted(self):
+ """Returns True if an interactive ROI drawing mode is active.
+
+ :rtype: bool
+ """
+ return self._roiClass is not None
+
+ def isDrawing(self):
+ """Returns True if an interactive ROI is drawing.
+
+ :rtype: bool
+ """
+ return self._drawnROI is not None
+
+ def start(self, roiClass, source=None):
+ """Start an interactive ROI drawing mode.
+
+ :param class roiClass: The ROI class to create. It have to inherite from
+ `roi_items.RegionOfInterest`.
+ :param object source: SOurce of the ROI interaction.
+ :return: True if interactive ROI drawing was started, False otherwise
+ :rtype: bool
+ :raise ValueError: If roiClass is not supported
+ """
+ self.stop()
+
+ if not issubclass(roiClass, roi_items.RegionOfInterest):
+ raise ValueError("Unsupported ROI class %s" % roiClass)
+
+ plot = self.parent()
+ if plot is None:
+ return False
+
+ self._roiClass = roiClass
+ self._source = source
+
+ self._restart()
+
+ plot.sigPlotSignal.connect(self._handleInteraction)
+
+ self.sigInteractiveModeStarted.emit(roiClass)
+
+ return True
+
+ def _restart(self):
+ """Restart the plot interaction without changing the
+ source or the ROI class.
+ """
+ roiClass = self._roiClass
+ plot = self.parent()
+ firstInteractionShapeKind = roiClass.getFirstInteractionShape()
+
+ if firstInteractionShapeKind == "point":
+ plot.setInteractiveMode(mode="select", source=self)
+ else:
+ if roiClass.showFirstInteractionShape():
+ color = rgba(self.getColor())
+ else:
+ color = None
+ plot.setInteractiveMode(
+ mode="draw",
+ source=self,
+ shape=firstInteractionShapeKind,
+ color=color,
+ label=self._label,
+ )
+
+ def __roiInteractiveModeEnded(self):
+ """Handle end of ROI draw interactive mode"""
+ if self.isStarted():
+ self._roiClass = None
+ self._source = None
+
+ if self._drawnROI is not None:
+ # Cancel ROI create
+ roi = self._drawnROI
+ self._drawnROI = None
+ self.removeRoi(roi)
+
+ plot = self.parent()
+ if plot is not None:
+ plot.sigPlotSignal.disconnect(self._handleInteraction)
+
+ self.sigInteractiveModeFinished.emit()
+
+ def stop(self):
+ """Stop interactive ROI drawing mode.
+
+ :return: True if an interactive ROI drawing mode was actually stopped
+ :rtype: bool
+ """
+ if not self.isStarted():
+ return False
+
+ plot = self.parent()
+ if plot is not None:
+ # This leads to call __roiInteractiveModeEnded through
+ # interactive mode changed signal
+ plot.resetInteractiveMode()
+ else: # Fallback
+ self.__roiInteractiveModeEnded()
+
+ return True
+
+ def exec(self, roiClass):
+ """Block until :meth:`quit` is called.
+
+ :param class kind: The class of the ROI which have to be created.
+ See `silx.gui.plot.items.roi`.
+ :return: The list of ROIs
+ :rtype: tuple
+ """
+ self.start(roiClass)
+
+ plot = self.parent()
+ plot.show()
+ plot.raise_()
+
+ self._eventLoop = qt.QEventLoop()
+ self._eventLoop.exec()
+ self._eventLoop = None
+
+ self.stop()
+
+ rois = self.getRois()
+ self.clear()
+ return rois
+
+ def exec_(self, roiClass): # Qt5-like compatibility
+ return self.exec(roiClass)
+
+ def quit(self):
+ """Stop a blocking :meth:`exec` and call :meth:`stop`"""
+ if self._eventLoop is not None:
+ self._eventLoop.quit()
+ self._eventLoop = None
+ self.stop()
+
+
+class InteractiveRegionOfInterestManager(RegionOfInterestManager):
+ """RegionOfInterestManager with features for use from interpreter.
+
+ It is meant to be used through the :meth:`exec`.
+ It provides some messages to display in a status bar and
+ different modes to end blocking calls to :meth:`exec`.
+
+ :param parent: See QObject
+ """
+
+ sigMessageChanged = qt.Signal(str)
+ """Signal emitted when a new message should be displayed to the user
+
+ It provides the message as a str.
+ """
+
+ def __init__(self, parent):
+ super(InteractiveRegionOfInterestManager, self).__init__(parent)
+ self._maxROI = None
+ self.__timeoutEndTime = None
+ self.__message = ""
+ self.__validationMode = self.ValidationMode.ENTER
+ self.__execClass = None
+
+ self.sigRoiAdded.connect(self.__added)
+ self.sigRoiAboutToBeRemoved.connect(self.__aboutToBeRemoved)
+ self.sigInteractiveModeStarted.connect(self.__started)
+ self.sigInteractiveModeFinished.connect(self.__finished)
+
+ # Max ROI
+
+ def getMaxRois(self):
+ """Returns the maximum number of ROIs or None if no limit.
+
+ :rtype: Union[int,None]
+ """
+ return self._maxROI
+
+ def setMaxRois(self, max_):
+ """Set the maximum number of ROIs.
+
+ :param Union[int,None] max_: The max limit or None for no limit.
+ :raise ValueError: If there is more ROIs than max value
+ """
+ if max_ is not None:
+ max_ = int(max_)
+ if max_ <= 0:
+ raise ValueError("Max limit must be strictly positive")
+
+ if len(self.getRois()) > max_:
+ raise ValueError("Cannot set max limit: Already too many ROIs")
+
+ self._maxROI = max_
+
+ def isMaxRois(self):
+ """Returns True if the maximum number of ROIs is reached.
+
+ :rtype: bool
+ """
+ max_ = self.getMaxRois()
+ return max_ is not None and len(self.getRois()) >= max_
+
+ # Validation mode
+
+ @enum.unique
+ class ValidationMode(enum.Enum):
+ """Mode of validation to leave blocking :meth:`exec`"""
+
+ AUTO = "auto"
+ """Automatically ends the interactive mode once
+ the user terminates the last ROI shape."""
+
+ ENTER = "enter"
+ """Ends the interactive mode when the *Enter* key is pressed."""
+
+ AUTO_ENTER = "auto_enter"
+ """Ends the interactive mode when reaching max ROIs or
+ when the *Enter* key is pressed.
+ """
+
+ NONE = "none"
+ """Do not provide the user a way to end the interactive mode.
+
+ The end of :meth:`exec` is done through :meth:`quit` or timeout.
+ """
+
+ def getValidationMode(self):
+ """Returns the interactive mode validation in use.
+
+ :rtype: ValidationMode
+ """
+ return self.__validationMode
+
+ def setValidationMode(self, mode):
+ """Set the way to perform interactive mode validation.
+
+ See :class:`ValidationMode` enumeration for the supported
+ validation modes.
+
+ :param ValidationMode mode: The interactive mode validation to use.
+ """
+ assert isinstance(mode, self.ValidationMode)
+ if mode != self.__validationMode:
+ self.__validationMode = mode
+
+ if self.isExec():
+ if self.isMaxRois() and self.getValidationMode() in (
+ self.ValidationMode.AUTO,
+ self.ValidationMode.AUTO_ENTER,
+ ):
+ self.quit()
+
+ self.__updateMessage()
+
+ def eventFilter(self, obj, event):
+ if event.type() == qt.QEvent.Hide:
+ self.quit()
+
+ if event.type() == qt.QEvent.KeyPress:
+ key = event.key()
+ if key in (
+ qt.Qt.Key_Return,
+ qt.Qt.Key_Enter,
+ ) and self.getValidationMode() in (
+ self.ValidationMode.ENTER,
+ self.ValidationMode.AUTO_ENTER,
+ ):
+ # Stop on return key pressed
+ self.quit()
+ return True # Stop further handling of this keys
+
+ if key in (qt.Qt.Key_Delete, qt.Qt.Key_Backspace) or (
+ key == qt.Qt.Key_Z and event.modifiers() & qt.Qt.ControlModifier
+ ):
+ rois = self.getRois()
+ if rois: # Something to undo
+ self.removeRoi(rois[-1])
+ # Stop further handling of keys if something was undone
+ return True
+
+ return super(InteractiveRegionOfInterestManager, self).eventFilter(obj, event)
+
+ # Message API
+
+ def getMessage(self):
+ """Returns the current status message.
+
+ This message is meant to be displayed in a status bar.
+
+ :rtype: str
+ """
+ if self.__timeoutEndTime is None:
+ return self.__message
+ else:
+ remaining = self.__timeoutEndTime - time.time()
+ return self.__message + (" - %d seconds remaining" % max(1, int(remaining)))
+
+ # Listen to ROI updates
+
+ def __added(self, *args, **kwargs):
+ """Handle new ROI added"""
+ max_ = self.getMaxRois()
+ if max_ is not None:
+ # When reaching max number of ROIs, redo last one
+ while len(self.getRois()) > max_:
+ self.removeRoi(self.getRois()[-2])
+
+ self.__updateMessage()
+ if self.isMaxRois() and self.getValidationMode() in (
+ self.ValidationMode.AUTO,
+ self.ValidationMode.AUTO_ENTER,
+ ):
+ self.quit()
+
+ def __aboutToBeRemoved(self, *args, **kwargs):
+ """Handle removal of a ROI"""
+ # RegionOfInterest not removed yet
+ self.__updateMessage(nbrois=len(self.getRois()) - 1)
+
+ def __started(self, roiKind):
+ """Handle interactive mode started"""
+ self.__updateMessage()
+
+ def __finished(self):
+ """Handle interactive mode finished"""
+ self.__updateMessage()
+
+ def __updateMessage(self, nbrois=None):
+ """Update message"""
+ if not self.isExec():
+ message = "Done"
+
+ elif not self.isStarted():
+ message = "Use %s ROI edition mode" % self.__execClass
+
+ else:
+ if nbrois is None:
+ nbrois = len(self.getRois())
+
+ name = self.__execClass._getShortName()
+
+ max_ = self.getMaxRois()
+ if max_ is None:
+ message = "Select %ss (%d selected)" % (name, nbrois)
+
+ elif max_ <= 1:
+ message = "Select a %s" % name
+ else:
+ message = "Select %d/%d %ss" % (nbrois, max_, name)
+
+ if (
+ self.getValidationMode() == self.ValidationMode.ENTER
+ and self.isMaxRois()
+ ):
+ message += " - Press Enter to confirm"
+
+ if message != self.__message:
+ self.__message = message
+ # Use getMessage to add timeout message
+ self.sigMessageChanged.emit(self.getMessage())
+
+ # Handle blocking call
+
+ def __timeoutUpdate(self):
+ """Handle update of timeout"""
+ if (
+ self.__timeoutEndTime is not None
+ and (self.__timeoutEndTime - time.time()) > 0
+ ):
+ self.sigMessageChanged.emit(self.getMessage())
+ else: # Stop interactive mode and message timer
+ timer = self.sender()
+ if timer is not None:
+ timer.stop()
+ self.__timeoutEndTime = None
+ self.quit()
+
+ def isExec(self):
+ """Returns True if :meth:`exec` is currently running.
+
+ :rtype: bool"""
+ return self.__execClass is not None
+
+ def exec(self, roiClass, timeout=0):
+ """Block until ROI selection is done or timeout is elapsed.
+
+ :meth:`quit` also ends this blocking call.
+
+ :param class roiClass: The class of the ROI which have to be created.
+ See `silx.gui.plot.items.roi`.
+ :param int timeout: Maximum duration in seconds to block.
+ Default: No timeout
+ :return: The list of ROIs
+ :rtype: List[RegionOfInterest]
+ """
+ plot = self.parent()
+ if plot is None:
+ return
+
+ self.__execClass = roiClass
+
+ plot.installEventFilter(self)
+
+ if timeout > 0:
+ self.__timeoutEndTime = time.time() + timeout
+ timer = qt.QTimer(self)
+ timer.timeout.connect(self.__timeoutUpdate)
+ timer.start(1000)
+
+ rois = super(InteractiveRegionOfInterestManager, self).exec(roiClass)
+
+ timer.stop()
+ self.__timeoutEndTime = None
+
+ else:
+ rois = super(InteractiveRegionOfInterestManager, self).exec(roiClass)
+
+ plot.removeEventFilter(self)
+
+ self.__execClass = None
+ self.__updateMessage()
+
+ return rois
+
+ def exec_(self, roiClass, timeout=0): # Qt5-like compatibility
+ return self.exec(roiClass, timeout)
+
+
+class _DeleteRegionOfInterestToolButton(qt.QToolButton):
+ """Tool button deleting a ROI object
+
+ :param parent: See QWidget
+ :param RegionOfInterest roi: The ROI to delete
+ """
+
+ def __init__(self, parent, roi):
+ super(_DeleteRegionOfInterestToolButton, self).__init__(parent)
+ self.setIcon(icons.getQIcon("remove"))
+ self.setToolTip("Remove this ROI")
+ self.__roiRef = roi if roi is None else weakref.ref(roi)
+ self.clicked.connect(self.__clicked)
+
+ def __clicked(self, checked):
+ """Handle button clicked"""
+ roi = None if self.__roiRef is None else self.__roiRef()
+ if roi is not None:
+ manager = roi.parent()
+ if manager is not None:
+ manager.removeRoi(roi)
+ self.__roiRef = None
+
+
+class RegionOfInterestTableWidget(qt.QTableWidget):
+ """Widget displaying the ROIs of a :class:`RegionOfInterestManager`"""
+
+ # Columns indices of the different displayed information
+ (
+ _LABEL_VISIBLE_COL,
+ _EDITABLE_COL,
+ _KIND_COL,
+ _COORDINATES_COL,
+ _DELETE_COL,
+ ) = range(5)
+
+ def __init__(self, parent=None):
+ super(RegionOfInterestTableWidget, self).__init__(parent)
+ self._roiManagerRef = None
+
+ headers = ["Label", "Edit", "Kind", "Coordinates", ""]
+ self.setColumnCount(len(headers))
+ self.setHorizontalHeaderLabels(headers)
+
+ horizontalHeader = self.horizontalHeader()
+ horizontalHeader.setDefaultAlignment(qt.Qt.AlignLeft)
+
+ horizontalHeader.setSectionResizeMode(0, qt.QHeaderView.Interactive)
+ horizontalHeader.setSectionResizeMode(1, qt.QHeaderView.ResizeToContents)
+ horizontalHeader.setSectionResizeMode(2, qt.QHeaderView.ResizeToContents)
+ horizontalHeader.setSectionResizeMode(3, qt.QHeaderView.Stretch)
+ horizontalHeader.setSectionResizeMode(4, qt.QHeaderView.ResizeToContents)
+
+ verticalHeader = self.verticalHeader()
+ verticalHeader.setVisible(False)
+
+ self.setSelectionMode(qt.QAbstractItemView.NoSelection)
+ self.setFocusPolicy(qt.Qt.NoFocus)
+
+ self.itemChanged.connect(self.__itemChanged)
+
+ def __itemChanged(self, item):
+ """Handle QTableWidget item updates"""
+ column = item.column()
+ roi = item.data(qt.Qt.UserRole)
+ if roi is None:
+ return
+
+ if column == 0:
+ # First collect information from item, then update ROI
+ # Otherwise, this causes issues
+ checked = item.checkState() == qt.Qt.Checked
+ text = item.text()
+ roi.setVisible(checked)
+ roi.setName(text)
+ elif column == 1:
+ roi.setEditable(item.checkState() == qt.Qt.Checked)
+ elif column in (2, 3, 4):
+ pass # TODO
+ else:
+ logger.error("Unhandled column %d", column)
+
+ def setRegionOfInterestManager(self, manager):
+ """Set the :class:`RegionOfInterestManager` object to sync with
+
+ :param RegionOfInterestManager manager:
+ """
+ assert manager is None or isinstance(manager, RegionOfInterestManager)
+
+ previousManager = self.getRegionOfInterestManager()
+
+ if previousManager is not None:
+ previousManager.sigRoiAdded.disconnect(self.__roiAdded)
+ previousManager.sigRoiAboutToBeRemoved.disconnect(
+ self.__roiAboutToBeRemoved
+ )
+ for roi in previousManager.getRois():
+ self.__disconnectRoi(roi)
+
+ self.setRowCount(0)
+
+ self._roiManagerRef = weakref.ref(manager)
+
+ self._sync()
+
+ if manager is not None:
+ for roi in manager.getRois():
+ self.__connectRoi(roi)
+ manager.sigRoiAdded.connect(self.__roiAdded)
+ manager.sigRoiAboutToBeRemoved.connect(self.__roiAboutToBeRemoved)
+
+ def _getReadableRoiDescription(self, roi):
+ """Returns modelisation of a ROI as a readable sequence of values.
+
+ :rtype: str
+ """
+ text = str(roi)
+ try:
+ # Extract the params from syntax "CLASSNAME(PARAMS)"
+ elements = text.split("(", 1)
+ if len(elements) != 2:
+ return text
+ result = elements[1]
+ result = result.strip()
+ if not result.endswith(")"):
+ return text
+ result = result[0:-1]
+ # Capitalize each words
+ result = result.title()
+ return result
+ except Exception:
+ logger.debug("Backtrace", exc_info=True)
+ return text
+
+ def __connectRoi(self, roi: RegionOfInterest):
+ """Start listening ROI signals"""
+ roi.sigItemChanged.connect(self.__roiItemChanged)
+ roi.sigRegionChanged.connect(self.__roiRegionChanged)
+
+ def __disconnectRoi(self, roi: RegionOfInterest):
+ """Stop listening ROI signals"""
+ roi.sigItemChanged.disconnect(self.__roiItemChanged)
+ roi.sigRegionChanged.disconnect(self.__roiRegionChanged)
+
+ def __getRoiRow(self, roi: RegionOfInterest) -> int:
+ """Returns row index of given region of interest
+
+ :raises ValueError: If region of interest is not in the list
+ """
+ manager = self.getRegionOfInterestManager()
+ if manager is None:
+ return
+ return manager.getRois().index(roi)
+
+ def __roiAdded(self, roi: RegionOfInterest):
+ """Handle new ROI added to the manager"""
+ self.__connectRoi(roi)
+ self._sync()
+
+ def __roiAboutToBeRemoved(self, roi: RegionOfInterest):
+ """Handle removing a ROI from the manager"""
+ self.__disconnectRoi(roi)
+ self.removeRow(self.__getRoiRow(roi))
+
+ def __roiItemChanged(self, event: ItemChangedType):
+ """Handle ROI sigItemChanged events"""
+ roi = self.sender()
+ if roi is None:
+ return
+
+ try:
+ row = self.__getRoiRow(roi)
+ except ValueError:
+ return
+
+ if event == ItemChangedType.VISIBLE:
+ item = self.item(row, self._LABEL_VISIBLE_COL)
+ item.setCheckState(qt.Qt.Checked if roi.isVisible() else qt.Qt.Unchecked)
+ return
+
+ if event == ItemChangedType.NAME:
+ item = self.item(row, self._LABEL_VISIBLE_COL)
+ item.setText(roi.getName())
+ return
+
+ if event == ItemChangedType.EDITABLE:
+ item = self.item(row, self._EDITABLE_COL)
+ item.setCheckState(qt.Qt.Checked if roi.isEditable() else qt.Qt.Unchecked)
+ return
+
+ def __roiRegionChanged(self):
+ """Handle change of ROI coordinates"""
+ roi = self.sender()
+ if roi is None:
+ return
+
+ item = self.item(self.__getRoiRow(roi), self._COORDINATES_COL)
+ if item is None:
+ return
+
+ text = self._getReadableRoiDescription(roi)
+ item.setText(text)
+
+ def _sync(self):
+ """Update widget content according to ROI manger"""
+ manager = self.getRegionOfInterestManager()
+
+ if manager is None:
+ self.setRowCount(0)
+ return
+
+ rois = manager.getRois()
+
+ self.setRowCount(len(rois))
+ for index, roi in enumerate(rois):
+ baseFlags = qt.Qt.ItemIsSelectable | qt.Qt.ItemIsEnabled
+
+ # Label and visible
+ item = qt.QTableWidgetItem()
+ item.setFlags(baseFlags | qt.Qt.ItemIsEditable | qt.Qt.ItemIsUserCheckable)
+ item.setData(qt.Qt.UserRole, roi)
+ item.setText(roi.getName())
+ item.setCheckState(qt.Qt.Checked if roi.isVisible() else qt.Qt.Unchecked)
+ self.setItem(index, self._LABEL_VISIBLE_COL, item)
+
+ # Editable
+ item = qt.QTableWidgetItem()
+ item.setFlags(baseFlags | qt.Qt.ItemIsUserCheckable)
+ item.setData(qt.Qt.UserRole, roi)
+ item.setCheckState(qt.Qt.Checked if roi.isEditable() else qt.Qt.Unchecked)
+ self.setItem(index, self._EDITABLE_COL, item)
+ item.setTextAlignment(qt.Qt.AlignCenter)
+ item.setText(None)
+
+ # Kind
+ label = roi._getShortName()
+ if label is None:
+ # Default value if kind is not overrided
+ label = roi.__class__.__name__
+ item = qt.QTableWidgetItem(label.capitalize())
+ item.setFlags(baseFlags)
+ self.setItem(index, self._KIND_COL, item)
+
+ # Coordinates
+ item = qt.QTableWidgetItem()
+ item.setFlags(baseFlags)
+ text = self._getReadableRoiDescription(roi)
+ item.setText(text)
+ self.setItem(index, self._COORDINATES_COL, item)
+
+ # Delete
+ widget = qt.QWidget(self)
+ delBtn = _DeleteRegionOfInterestToolButton(widget, roi)
+ layout = qt.QHBoxLayout()
+ layout.setContentsMargins(2, 2, 2, 2)
+ layout.setSpacing(0)
+ widget.setLayout(layout)
+ layout.addStretch(1)
+ layout.addWidget(delBtn)
+ layout.addStretch(1)
+ self.setCellWidget(index, self._DELETE_COL, widget)
+
+ def getRegionOfInterestManager(self):
+ """Returns the :class:`RegionOfInterestManager` this widget supervise.
+
+ It returns None if not sync with an :class:`RegionOfInterestManager`.
+
+ :rtype: RegionOfInterestManager
+ """
+ return None if self._roiManagerRef is None else self._roiManagerRef()
diff --git a/src/silx/gui/plot/tools/test/__init__.py b/src/silx/gui/plot/tools/test/__init__.py
new file mode 100644
index 0000000..2e682d7
--- /dev/null
+++ b/src/silx/gui/plot/tools/test/__init__.py
@@ -0,0 +1,23 @@
+# /*##########################################################################
+#
+# Copyright (c) 2018 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
diff --git a/src/silx/gui/plot/tools/test/testCurveLegendsWidget.py b/src/silx/gui/plot/tools/test/testCurveLegendsWidget.py
new file mode 100644
index 0000000..9f1a184
--- /dev/null
+++ b/src/silx/gui/plot/tools/test/testCurveLegendsWidget.py
@@ -0,0 +1,117 @@
+# /*##########################################################################
+#
+# Copyright (c) 2018 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+__authors__ = ["T. Vincent"]
+__license__ = "MIT"
+__date__ = "02/08/2018"
+
+
+from silx.gui import qt
+from silx.utils.testutils import ParametricTestCase
+from silx.gui.utils.testutils import TestCaseQt
+from silx.gui.plot import PlotWindow
+from silx.gui.plot.tools import CurveLegendsWidget
+
+
+class TestCurveLegendsWidget(TestCaseQt, ParametricTestCase):
+ """Tests for CurveLegendsWidget class"""
+
+ def setUp(self):
+ super(TestCurveLegendsWidget, self).setUp()
+ self.plot = PlotWindow()
+
+ self.legends = CurveLegendsWidget.CurveLegendsWidget()
+ self.legends.setPlotWidget(self.plot)
+
+ dock = qt.QDockWidget()
+ dock.setWindowTitle("Curve Legends")
+ dock.setWidget(self.legends)
+ self.plot.addTabbedDockWidget(dock)
+
+ self.plot.show()
+ self.qWaitForWindowExposed(self.plot)
+
+ def tearDown(self):
+ del self.legends
+ self.qapp.processEvents()
+ self.plot.setAttribute(qt.Qt.WA_DeleteOnClose)
+ self.plot.close()
+ del self.plot
+ super(TestCurveLegendsWidget, self).tearDown()
+
+ def _assertNbLegends(self, count):
+ """Check the number of legends in the CurveLegendsWidget"""
+ children = self.legends.findChildren(CurveLegendsWidget._LegendWidget)
+ self.assertEqual(len(children), count)
+
+ def testAddRemoveCurves(self):
+ """Test CurveLegendsWidget while adding/removing curves"""
+ self.plot.addCurve((0, 1), (1, 2), legend="a")
+ self._assertNbLegends(1)
+ self.plot.addCurve((0, 1), (2, 3), legend="b")
+ self._assertNbLegends(2)
+
+ # Detached/attach
+ self.legends.setPlotWidget(None)
+ self._assertNbLegends(0)
+
+ self.legends.setPlotWidget(self.plot)
+ self._assertNbLegends(2)
+
+ self.plot.clear()
+ self._assertNbLegends(0)
+
+ def testUpdateCurves(self):
+ """Test CurveLegendsWidget while updating curves"""
+ self.plot.addCurve((0, 1), (1, 2), legend="a")
+ self._assertNbLegends(1)
+ self.plot.addCurve((0, 1), (2, 3), legend="b")
+ self._assertNbLegends(2)
+
+ # Activate curve
+ self.plot.setActiveCurve("a")
+ self.qapp.processEvents()
+ self.plot.setActiveCurve("b")
+ self.qapp.processEvents()
+
+ # Change curve style
+ curve = self.plot.getCurve("a")
+ curve.setLineWidth(2)
+ for linestyle in (
+ ":",
+ "",
+ "--",
+ "-",
+ (0.0, (5.0, 5.0)),
+ (5.0, (10.0, 2.0, 2.0, 5.0)),
+ ):
+ with self.subTest(linestyle=linestyle):
+ curve.setLineStyle(linestyle)
+ self.qapp.processEvents()
+ self.qWait(1000)
+
+ for symbol in ("o", "d", "", "s"):
+ with self.subTest(symbol=symbol):
+ curve.setSymbol(symbol)
+ self.qapp.processEvents()
+ self.qWait(1000)
diff --git a/src/silx/gui/plot/tools/test/testProfile.py b/src/silx/gui/plot/tools/test/testProfile.py
new file mode 100644
index 0000000..61b95a6
--- /dev/null
+++ b/src/silx/gui/plot/tools/test/testProfile.py
@@ -0,0 +1,604 @@
+# /*##########################################################################
+#
+# Copyright (c) 2018-2023 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+__authors__ = ["T. Vincent"]
+__license__ = "MIT"
+__date__ = "28/06/2018"
+
+
+import contextlib
+import numpy
+import logging
+
+from silx.gui import qt
+
+from silx.gui.utils.testutils import TestCaseQt
+from silx.utils.testutils import ParametricTestCase
+from silx.gui.plot import PlotWindow, Plot1D, Plot2D, Profile
+from silx.gui.plot.StackView import StackView
+from silx.gui.plot.tools.profile import rois
+from silx.gui.plot.tools.profile import editors
+from silx.gui.plot.items import roi as roi_items
+from silx.gui.plot.tools.profile import manager
+from silx.gui import plot as silx_plot
+
+_logger = logging.getLogger(__name__)
+
+
+class TestRois(TestCaseQt):
+ def test_init(self):
+ """Check that the constructor is not called twice"""
+ roi = rois.ProfileImageVerticalLineROI()
+ if qt.BINDING == "PyQt5":
+ # the profile ROI + the shape
+ self.assertEqual(roi.receivers(roi.sigRegionChanged), 2)
+
+
+class TestInteractions(TestCaseQt):
+ @contextlib.contextmanager
+ def defaultPlot(self):
+ try:
+ widget = silx_plot.PlotWidget()
+ widget.show()
+ self.qWaitForWindowExposed(widget)
+ yield widget
+ finally:
+ widget.close()
+ widget = None
+ self.qWait()
+
+ @contextlib.contextmanager
+ def imagePlot(self):
+ try:
+ widget = silx_plot.Plot2D()
+ image = numpy.arange(10 * 10).reshape(10, -1)
+ widget.addImage(image)
+ widget.show()
+ self.qWaitForWindowExposed(widget)
+ yield widget
+ finally:
+ widget.close()
+ widget = None
+ self.qWait()
+
+ @contextlib.contextmanager
+ def scatterPlot(self):
+ try:
+ widget = silx_plot.ScatterView()
+
+ nbX, nbY = 7, 5
+ yy = numpy.atleast_2d(numpy.ones(nbY)).T
+ xx = numpy.atleast_2d(numpy.ones(nbX))
+ positionX = numpy.linspace(10, 50, nbX) * yy
+ positionX = positionX.reshape(nbX * nbY)
+ positionY = numpy.atleast_2d(numpy.linspace(20, 60, nbY)).T * xx
+ positionY = positionY.reshape(nbX * nbY)
+ values = numpy.arange(nbX * nbY)
+
+ widget.setData(positionX, positionY, values)
+ widget.resetZoom()
+ widget.show()
+ self.qWaitForWindowExposed(widget)
+ yield widget.getPlotWidget()
+ finally:
+ widget.close()
+ widget = None
+ self.qWait()
+
+ @contextlib.contextmanager
+ def stackPlot(self):
+ try:
+ widget = silx_plot.StackView()
+ image = numpy.arange(10 * 10).reshape(10, -1)
+ cube = numpy.array([image, image, image])
+ widget.setStack(cube)
+ widget.resetZoom()
+ widget.show()
+ self.qWaitForWindowExposed(widget)
+ yield widget.getPlotWidget()
+ finally:
+ widget.close()
+ widget = None
+ self.qWait()
+
+ def waitPendingOperations(self, proflie):
+ for _ in range(10):
+ if not proflie.hasPendingOperations():
+ return
+ self.qWait(100)
+ _logger.error("The profile manager still have pending operations")
+
+ def genericRoiTest(self, plot, roiClass):
+ profileManager = manager.ProfileManager(plot, plot)
+ profileManager.setItemType(image=True, scatter=True)
+
+ try:
+ action = profileManager.createProfileAction(roiClass, plot)
+ action.triggered[bool].emit(True)
+ widget = plot.getWidgetHandle()
+
+ # Do the mouse interaction
+ pos1 = widget.width() * 0.4, widget.height() * 0.4
+ self.mouseMove(widget, pos=pos1)
+ self.mouseClick(widget, qt.Qt.LeftButton, pos=pos1)
+
+ if issubclass(roiClass, roi_items.LineROI):
+ pos2 = widget.width() * 0.6, widget.height() * 0.6
+ self.mouseMove(widget, pos=pos2)
+ self.mouseClick(widget, qt.Qt.LeftButton, pos=pos2)
+
+ self.waitPendingOperations(profileManager)
+
+ # Test that something was computed
+ if issubclass(roiClass, rois._ProfileCrossROI):
+ self.assertEqual(profileManager._computedProfiles, 2)
+ elif issubclass(roiClass, roi_items.LineROI):
+ self.assertGreaterEqual(profileManager._computedProfiles, 1)
+ else:
+ self.assertEqual(profileManager._computedProfiles, 1)
+
+ # Test the created ROIs
+ profileRois = profileManager.getRoiManager().getRois()
+ if issubclass(roiClass, rois._ProfileCrossROI):
+ self.assertEqual(len(profileRois), 3)
+ else:
+ self.assertEqual(len(profileRois), 1)
+ # The first one should be the expected one
+ roi = profileRois[0]
+
+ # Test that something was displayed
+ if issubclass(roiClass, rois._ProfileCrossROI):
+ profiles = roi._getLines()
+ window = profiles[0].getProfileWindow()
+ self.assertIsNotNone(window)
+ window = profiles[1].getProfileWindow()
+ self.assertIsNotNone(window)
+ else:
+ window = roi.getProfileWindow()
+ self.assertIsNotNone(window)
+ finally:
+ profileManager.clearProfile()
+
+ def testImageActions(self):
+ roiClasses = [
+ rois.ProfileImageHorizontalLineROI,
+ rois.ProfileImageVerticalLineROI,
+ rois.ProfileImageLineROI,
+ rois.ProfileImageCrossROI,
+ ]
+ with self.imagePlot() as plot:
+ for roiClass in roiClasses:
+ with self.subTest(roiClass=roiClass):
+ self.genericRoiTest(plot, roiClass)
+
+ def testScatterActions(self):
+ roiClasses = [
+ rois.ProfileScatterHorizontalLineROI,
+ rois.ProfileScatterVerticalLineROI,
+ rois.ProfileScatterLineROI,
+ rois.ProfileScatterCrossROI,
+ rois.ProfileScatterHorizontalSliceROI,
+ rois.ProfileScatterVerticalSliceROI,
+ rois.ProfileScatterCrossSliceROI,
+ ]
+ with self.scatterPlot() as plot:
+ for roiClass in roiClasses:
+ with self.subTest(roiClass=roiClass):
+ self.genericRoiTest(plot, roiClass)
+
+ def testStackActions(self):
+ roiClasses = [
+ rois.ProfileImageStackHorizontalLineROI,
+ rois.ProfileImageStackVerticalLineROI,
+ rois.ProfileImageStackLineROI,
+ rois.ProfileImageStackCrossROI,
+ ]
+ with self.stackPlot() as plot:
+ for roiClass in roiClasses:
+ with self.subTest(roiClass=roiClass):
+ self.genericRoiTest(plot, roiClass)
+
+ def genericEditorTest(self, plot, roi, editor):
+ if isinstance(editor, editors._NoProfileRoiEditor):
+ pass
+ elif isinstance(editor, editors._DefaultImageStackProfileRoiEditor):
+ # GUI to ROI
+ editor._lineWidth.setValue(2)
+ self.assertEqual(roi.getProfileLineWidth(), 2)
+ editor._methodsButton.setMethod("sum")
+ self.assertEqual(roi.getProfileMethod(), "sum")
+ editor._profileDim.setDimension(1)
+ self.assertEqual(roi.getProfileType(), "1D")
+ # ROI to GUI
+ roi.setProfileLineWidth(3)
+ self.assertEqual(editor._lineWidth.value(), 3)
+ roi.setProfileMethod("mean")
+ self.assertEqual(editor._methodsButton.getMethod(), "mean")
+ roi.setProfileType("2D")
+ self.assertEqual(editor._profileDim.getDimension(), 2)
+ elif isinstance(editor, editors._DefaultImageProfileRoiEditor):
+ # GUI to ROI
+ editor._lineWidth.setValue(2)
+ self.assertEqual(roi.getProfileLineWidth(), 2)
+ editor._methodsButton.setMethod("sum")
+ self.assertEqual(roi.getProfileMethod(), "sum")
+ # ROI to GUI
+ roi.setProfileLineWidth(3)
+ self.assertEqual(editor._lineWidth.value(), 3)
+ roi.setProfileMethod("mean")
+ self.assertEqual(editor._methodsButton.getMethod(), "mean")
+ elif isinstance(editor, editors._DefaultScatterProfileRoiEditor):
+ # GUI to ROI
+ editor._nPoints.setValue(100)
+ self.assertEqual(roi.getNPoints(), 100)
+ # ROI to GUI
+ roi.setNPoints(200)
+ self.assertEqual(editor._nPoints.value(), 200)
+ else:
+ assert False
+
+ def testEditors(self):
+ roiClasses = [
+ (rois.ProfileImageHorizontalLineROI, editors._DefaultImageProfileRoiEditor),
+ (rois.ProfileImageVerticalLineROI, editors._DefaultImageProfileRoiEditor),
+ (rois.ProfileImageLineROI, editors._DefaultImageProfileRoiEditor),
+ (rois.ProfileImageCrossROI, editors._DefaultImageProfileRoiEditor),
+ (
+ rois.ProfileScatterHorizontalLineROI,
+ editors._DefaultScatterProfileRoiEditor,
+ ),
+ (
+ rois.ProfileScatterVerticalLineROI,
+ editors._DefaultScatterProfileRoiEditor,
+ ),
+ (rois.ProfileScatterLineROI, editors._DefaultScatterProfileRoiEditor),
+ (rois.ProfileScatterCrossROI, editors._DefaultScatterProfileRoiEditor),
+ (rois.ProfileScatterHorizontalSliceROI, editors._NoProfileRoiEditor),
+ (rois.ProfileScatterVerticalSliceROI, editors._NoProfileRoiEditor),
+ (rois.ProfileScatterCrossSliceROI, editors._NoProfileRoiEditor),
+ (
+ rois.ProfileImageStackHorizontalLineROI,
+ editors._DefaultImageStackProfileRoiEditor,
+ ),
+ (
+ rois.ProfileImageStackVerticalLineROI,
+ editors._DefaultImageStackProfileRoiEditor,
+ ),
+ (rois.ProfileImageStackLineROI, editors._DefaultImageStackProfileRoiEditor),
+ (
+ rois.ProfileImageStackCrossROI,
+ editors._DefaultImageStackProfileRoiEditor,
+ ),
+ ]
+ with self.defaultPlot() as plot:
+ profileManager = manager.ProfileManager(plot, plot)
+ editorAction = profileManager.createEditorAction(parent=plot)
+ for roiClass, editorClass in roiClasses:
+ with self.subTest(roiClass=roiClass):
+ roi = roiClass()
+ roi._setProfileManager(profileManager)
+ try:
+ # Force widget creation
+ menu = qt.QMenu(plot)
+ menu.addAction(editorAction)
+ widgets = editorAction.createdWidgets()
+ self.assertGreater(len(widgets), 0)
+
+ editorAction.setProfileRoi(roi)
+ editorWidget = editorAction._getEditor(widgets[0])
+ self.assertIsInstance(editorWidget, editorClass)
+ self.genericEditorTest(plot, roi, editorWidget)
+ finally:
+ editorAction.setProfileRoi(None)
+ menu.deleteLater()
+ menu = None
+ self.qapp.processEvents()
+
+
+class TestProfileToolBar(TestCaseQt, ParametricTestCase):
+ """Tests for ProfileToolBar widget."""
+
+ def setUp(self):
+ super(TestProfileToolBar, self).setUp()
+ self.plot = PlotWindow()
+ self.toolBar = Profile.ProfileToolBar(plot=self.plot)
+ self.plot.addToolBar(self.toolBar)
+
+ self.plot.show()
+ self.qWaitForWindowExposed(self.plot)
+
+ self.mouseMove(self.plot) # Move to center
+ self.qapp.processEvents()
+
+ def tearDown(self):
+ self.qapp.processEvents()
+ profileManager = self.toolBar.getProfileManager()
+ profileManager.clearProfile()
+ profileManager = None
+ self.plot.setAttribute(qt.Qt.WA_DeleteOnClose)
+ self.plot.close()
+ del self.plot
+ del self.toolBar
+
+ super(TestProfileToolBar, self).tearDown()
+
+ def testAlignedProfile(self):
+ """Test horizontal and vertical profile, without and with image"""
+ # Use Plot backend widget to submit mouse events
+ widget = self.plot.getWidgetHandle()
+ for method in ("sum", "mean"):
+ with self.subTest(method=method):
+ # 2 positions to use for mouse events
+ pos1 = widget.width() * 0.4, widget.height() * 0.4
+ pos2 = widget.width() * 0.6, widget.height() * 0.6
+
+ for action in (self.toolBar.hLineAction, self.toolBar.vLineAction):
+ with self.subTest(mode=action.text()):
+ # Trigger tool button for mode
+ action.trigger()
+ # Without image
+ self.mouseMove(widget, pos=pos1)
+ self.mouseClick(widget, qt.Qt.LeftButton, pos=pos1)
+
+ # with image
+ self.plot.addImage(numpy.arange(100 * 100).reshape(100, -1))
+ self.mousePress(widget, qt.Qt.LeftButton, pos=pos1)
+ self.mouseMove(widget, pos=pos2)
+ self.mouseRelease(widget, qt.Qt.LeftButton, pos=pos2)
+
+ self.mouseMove(widget)
+ self.mouseClick(widget, qt.Qt.LeftButton)
+
+ manager = self.toolBar.getProfileManager()
+ for _ in range(20):
+ self.qWait(200)
+ if not manager.hasPendingOperations():
+ break
+
+ def testDiagonalProfile(self):
+ """Test diagonal profile, without and with image"""
+ # Use Plot backend widget to submit mouse events
+ widget = self.plot.getWidgetHandle()
+
+ self.plot.addImage(numpy.arange(100 * 100).reshape(100, -1))
+
+ for method in ("sum", "mean"):
+ with self.subTest(method=method):
+ # 2 positions to use for mouse events
+ pos1 = widget.width() * 0.4, widget.height() * 0.4
+ pos2 = widget.width() * 0.6, widget.height() * 0.6
+
+ # Trigger tool button for diagonal profile mode
+ self.toolBar.lineAction.trigger()
+
+ # draw profile line
+ widget.setFocus(qt.Qt.OtherFocusReason)
+ self.mouseMove(widget, pos=pos1)
+ self.qWait(100)
+ self.mousePress(widget, qt.Qt.LeftButton, pos=pos1)
+ self.qWait(100)
+ self.mouseMove(widget, pos=pos2)
+ self.qWait(100)
+ self.mouseRelease(widget, qt.Qt.LeftButton, pos=pos2)
+ self.qWait(100)
+
+ manager = self.toolBar.getProfileManager()
+
+ for _ in range(20):
+ self.qWait(200)
+ if not manager.hasPendingOperations():
+ break
+
+ roi = manager.getCurrentRoi()
+ self.assertIsNotNone(roi)
+ roi.setProfileLineWidth(3)
+ roi.setProfileMethod(method)
+
+ for _ in range(20):
+ self.qWait(200)
+ if not manager.hasPendingOperations():
+ break
+
+ curveItem = (
+ roi.getProfileWindow().getCurrentPlotWidget().getAllCurves()[0]
+ )
+ if method == "sum":
+ self.assertTrue(curveItem.getData()[1].max() > 10000)
+ elif method == "mean":
+ self.assertTrue(curveItem.getData()[1].max() < 10000)
+
+ # Remove the ROI so the profile window is also removed
+ roiManager = manager.getRoiManager()
+ roiManager.removeRoi(roi)
+ self.qWait(100)
+
+
+class TestProfile3DToolBar(TestCaseQt):
+ """Tests for Profile3DToolBar widget."""
+
+ def setUp(self):
+ super(TestProfile3DToolBar, self).setUp()
+ self.plot = StackView()
+ self.plot.show()
+ self.qWaitForWindowExposed(self.plot)
+
+ self.plot.setStack(
+ numpy.array(
+ [
+ [[0, 1, 2], [3, 4, 5]],
+ [[6, 7, 8], [9, 10, 11]],
+ [[12, 13, 14], [15, 16, 17]],
+ ]
+ )
+ )
+
+ def tearDown(self):
+ profileManager = self.plot.getProfileToolbar().getProfileManager()
+ profileManager.clearProfile()
+ profileManager = None
+ self.plot.setAttribute(qt.Qt.WA_DeleteOnClose)
+ self.plot.close()
+ self.plot = None
+
+ super(TestProfile3DToolBar, self).tearDown()
+
+ def testMethodProfile2D(self):
+ """Test that the profile can have a different method if we want to
+ compute then in 1D or in 2D"""
+
+ toolBar = self.plot.getProfileToolbar()
+
+ toolBar.vLineAction.trigger()
+ plot2D = self.plot.getPlotWidget().getWidgetHandle()
+ pos1 = plot2D.width() * 0.5, plot2D.height() * 0.5
+ self.mouseClick(plot2D, qt.Qt.LeftButton, pos=pos1)
+
+ manager = toolBar.getProfileManager()
+ roi = manager.getCurrentRoi()
+ roi.setProfileMethod("mean")
+ roi.setProfileType("2D")
+ roi.setProfileLineWidth(3)
+
+ for _ in range(20):
+ self.qWait(200)
+ if not manager.hasPendingOperations():
+ break
+
+ # check 2D 'mean' profile
+ profilePlot = roi.getProfileWindow().getCurrentPlotWidget()
+ data = profilePlot.getAllImages()[0].getData()
+ expected = numpy.array([[1, 4], [7, 10], [13, 16]])
+ numpy.testing.assert_almost_equal(data, expected)
+
+ def testMethodSumLine(self):
+ """Simple interaction test to make sure the sum is correctly computed"""
+ toolBar = self.plot.getProfileToolbar()
+
+ toolBar.lineAction.trigger()
+ plot2D = self.plot.getPlotWidget().getWidgetHandle()
+ pos1 = plot2D.width() * 0.5, plot2D.height() * 0.2
+ pos2 = plot2D.width() * 0.5, plot2D.height() * 0.8
+
+ self.mouseMove(plot2D, pos=pos1)
+ self.mousePress(plot2D, qt.Qt.LeftButton, pos=pos1)
+ self.mouseMove(plot2D, pos=pos2)
+ self.mouseRelease(plot2D, qt.Qt.LeftButton, pos=pos2)
+
+ manager = toolBar.getProfileManager()
+ roi = manager.getCurrentRoi()
+ roi.setProfileMethod("sum")
+ roi.setProfileType("2D")
+ roi.setProfileLineWidth(3)
+
+ for _ in range(20):
+ self.qWait(200)
+ if not manager.hasPendingOperations():
+ break
+
+ # check 2D 'sum' profile
+ profilePlot = roi.getProfileWindow().getCurrentPlotWidget()
+ data = profilePlot.getAllImages()[0].getData()
+ expected = numpy.array([[3, 12], [21, 30], [39, 48]])
+ numpy.testing.assert_almost_equal(data, expected)
+
+
+class TestGetProfilePlot(TestCaseQt):
+ def setUp(self):
+ self.plot = None
+ super(TestGetProfilePlot, self).setUp()
+
+ def tearDown(self):
+ if self.plot is not None:
+ manager = self.plot.getProfileToolbar().getProfileManager()
+ manager.clearProfile()
+ manager = None
+ self.plot.setAttribute(qt.Qt.WA_DeleteOnClose)
+ self.plot.close()
+ self.plot = None
+
+ super(TestGetProfilePlot, self).tearDown()
+
+ def testProfile1D(self):
+ self.plot = Plot2D()
+ self.plot.show()
+ self.qWaitForWindowExposed(self.plot)
+ self.plot.addImage([[0, 1], [2, 3]])
+
+ toolBar = self.plot.getProfileToolbar()
+
+ manager = toolBar.getProfileManager()
+ roiManager = manager.getRoiManager()
+
+ roi = rois.ProfileImageHorizontalLineROI()
+ roi.setPosition(0.5)
+ roiManager.addRoi(roi)
+ roiManager.setCurrentRoi(roi)
+
+ for _ in range(20):
+ self.qWait(200)
+ if not manager.hasPendingOperations():
+ break
+
+ profileWindow = roi.getProfileWindow()
+ self.assertIsInstance(roi.getProfileWindow(), qt.QMainWindow)
+ self.assertIsInstance(profileWindow.getCurrentPlotWidget(), Plot1D)
+
+ def testProfile2D(self):
+ """Test that the profile plot associated to a stack view is either a
+ Plot1D or a plot 2D instance."""
+ self.plot = StackView()
+ self.plot.show()
+ self.qWaitForWindowExposed(self.plot)
+
+ self.plot.setStack(numpy.array([[[0, 1], [2, 3]], [[4, 5], [6, 7]]]))
+
+ toolBar = self.plot.getProfileToolbar()
+
+ manager = toolBar.getProfileManager()
+ roiManager = manager.getRoiManager()
+
+ roi = rois.ProfileImageStackHorizontalLineROI()
+ roi.setPosition(0.5)
+ roi.setProfileType("2D")
+ roiManager.addRoi(roi)
+ roiManager.setCurrentRoi(roi)
+
+ for _ in range(20):
+ self.qWait(200)
+ if not manager.hasPendingOperations():
+ break
+
+ profileWindow = roi.getProfileWindow()
+ self.assertIsInstance(roi.getProfileWindow(), qt.QMainWindow)
+ self.assertIsInstance(profileWindow.getCurrentPlotWidget(), Plot2D)
+
+ roi.setProfileType("1D")
+
+ for _ in range(20):
+ self.qWait(200)
+ if not manager.hasPendingOperations():
+ break
+
+ profileWindow = roi.getProfileWindow()
+ self.assertIsInstance(roi.getProfileWindow(), qt.QMainWindow)
+ self.assertIsInstance(profileWindow.getCurrentPlotWidget(), Plot1D)
diff --git a/src/silx/gui/plot/tools/test/testRoiCore.py b/src/silx/gui/plot/tools/test/testRoiCore.py
new file mode 100644
index 0000000..e7f6d8a
--- /dev/null
+++ b/src/silx/gui/plot/tools/test/testRoiCore.py
@@ -0,0 +1,535 @@
+# /*##########################################################################
+#
+# Copyright (c) 2018-2020 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+__authors__ = ["T. Vincent"]
+__license__ = "MIT"
+__date__ = "28/06/2018"
+
+
+import numpy.testing
+
+from silx.gui import qt
+from silx.utils.testutils import ParametricTestCase
+from silx.gui.utils.testutils import TestCaseQt, SignalListener
+from silx.gui.plot import PlotWindow
+import silx.gui.plot.items.roi as roi_items
+from silx.gui.plot.tools import roi
+
+
+class TestRegionOfInterestManager(TestCaseQt, ParametricTestCase):
+ """Tests for RegionOfInterestManager class"""
+
+ def setUp(self):
+ super(TestRegionOfInterestManager, self).setUp()
+ self.plot = PlotWindow()
+
+ self.roiTableWidget = roi.RegionOfInterestTableWidget()
+ dock = qt.QDockWidget()
+ dock.setWidget(self.roiTableWidget)
+ self.plot.addDockWidget(qt.Qt.BottomDockWidgetArea, dock)
+
+ self.plot.show()
+ self.qWaitForWindowExposed(self.plot)
+
+ def tearDown(self):
+ del self.roiTableWidget
+ self.qapp.processEvents()
+ self.plot.setAttribute(qt.Qt.WA_DeleteOnClose)
+ self.plot.close()
+ del self.plot
+ super(TestRegionOfInterestManager, self).tearDown()
+
+ def test(self):
+ """Test ROI of different shapes"""
+ tests = ( # shape, points=[list of (x, y), list of (x, y)]
+ (roi_items.PointROI, numpy.array(([(10.0, 15.0)], [(20.0, 25.0)]))),
+ (
+ roi_items.RectangleROI,
+ numpy.array((((1.0, 10.0), (11.0, 20.0)), ((2.0, 3.0), (12.0, 13.0)))),
+ ),
+ (
+ roi_items.PolygonROI,
+ numpy.array(
+ (
+ ((0.0, 1.0), (0.0, 10.0), (10.0, 0.0)),
+ ((5.0, 6.0), (5.0, 16.0), (15.0, 6.0)),
+ )
+ ),
+ ),
+ (
+ roi_items.LineROI,
+ numpy.array(
+ (((10.0, 20.0), (10.0, 30.0)), ((30.0, 40.0), (30.0, 50.0)))
+ ),
+ ),
+ (
+ roi_items.HorizontalLineROI,
+ numpy.array(
+ (((10.0, 20.0), (10.0, 30.0)), ((30.0, 40.0), (30.0, 50.0)))
+ ),
+ ),
+ (
+ roi_items.VerticalLineROI,
+ numpy.array(
+ (((10.0, 20.0), (10.0, 30.0)), ((30.0, 40.0), (30.0, 50.0)))
+ ),
+ ),
+ (
+ roi_items.HorizontalLineROI,
+ numpy.array(
+ (((10.0, 20.0), (10.0, 30.0)), ((30.0, 40.0), (30.0, 50.0)))
+ ),
+ ),
+ )
+
+ for roiClass, points in tests:
+ with self.subTest(roiClass=roiClass):
+ manager = roi.RegionOfInterestManager(self.plot)
+ self.roiTableWidget.setRegionOfInterestManager(manager)
+ manager.start(roiClass)
+
+ self.assertEqual(manager.getRois(), ())
+
+ finishListener = SignalListener()
+ manager.sigInteractiveModeFinished.connect(finishListener)
+
+ changedListener = SignalListener()
+ manager.sigRoiChanged.connect(changedListener)
+
+ # Add a point
+ r = roiClass()
+ r.setFirstShapePoints(points[0])
+ manager.addRoi(r)
+ self.qapp.processEvents()
+ self.assertTrue(len(manager.getRois()), 1)
+ self.assertEqual(changedListener.callCount(), 1)
+
+ # Remove it
+ manager.removeRoi(manager.getRois()[0])
+ self.assertEqual(manager.getRois(), ())
+ self.assertEqual(changedListener.callCount(), 2)
+
+ # Add two point
+ r = roiClass()
+ r.setFirstShapePoints(points[0])
+ manager.addRoi(r)
+ self.qapp.processEvents()
+ r = roiClass()
+ r.setFirstShapePoints(points[1])
+ manager.addRoi(r)
+ self.qapp.processEvents()
+ self.assertTrue(len(manager.getRois()), 2)
+ self.assertEqual(changedListener.callCount(), 4)
+
+ # Reset it
+ result = manager.clear()
+ self.assertTrue(result)
+ self.assertEqual(manager.getRois(), ())
+ self.assertEqual(changedListener.callCount(), 5)
+
+ changedListener.clear()
+
+ # Add two point
+ r = roiClass()
+ r.setFirstShapePoints(points[0])
+ manager.addRoi(r)
+ self.qapp.processEvents()
+ r = roiClass()
+ r.setFirstShapePoints(points[1])
+ manager.addRoi(r)
+ self.qapp.processEvents()
+ self.assertTrue(len(manager.getRois()), 2)
+ self.assertEqual(changedListener.callCount(), 2)
+
+ # stop
+ result = manager.stop()
+ self.assertTrue(result)
+ self.assertTrue(len(manager.getRois()), 1)
+ self.qapp.processEvents()
+ self.assertEqual(finishListener.callCount(), 1)
+
+ manager.clear()
+
+ def testRoiDisplay(self):
+ rois = []
+
+ # Line
+ item = roi_items.LineROI()
+ startPoint = numpy.array([1, 2])
+ endPoint = numpy.array([3, 4])
+ item.setEndPoints(startPoint, endPoint)
+ rois.append(item)
+ # Horizontal line
+ item = roi_items.HorizontalLineROI()
+ item.setPosition(15)
+ rois.append(item)
+ # Vertical line
+ item = roi_items.VerticalLineROI()
+ item.setPosition(15)
+ rois.append(item)
+ # Point
+ item = roi_items.PointROI()
+ point = numpy.array([1, 2])
+ item.setPosition(point)
+ rois.append(item)
+ # Rectangle
+ item = roi_items.RectangleROI()
+ origin = numpy.array([0, 0])
+ size = numpy.array([10, 20])
+ item.setGeometry(origin=origin, size=size)
+ rois.append(item)
+ # Polygon
+ item = roi_items.PolygonROI()
+ points = numpy.array([[10, 10], [12, 10], [50, 1]])
+ item.setPoints(points)
+ rois.append(item)
+ # Degenerated polygon: No points
+ item = roi_items.PolygonROI()
+ points = numpy.empty((0, 2))
+ item.setPoints(points)
+ rois.append(item)
+ # Degenerated polygon: A single point
+ item = roi_items.PolygonROI()
+ points = numpy.array([[5, 10]])
+ item.setPoints(points)
+ rois.append(item)
+ # Degenerated arc: it's a point
+ item = roi_items.ArcROI()
+ center = numpy.array([10, 20])
+ innerRadius, outerRadius, startAngle, endAngle = 0, 0, 0, 0
+ item.setGeometry(center, innerRadius, outerRadius, startAngle, endAngle)
+ rois.append(item)
+ # Degenerated arc: it's a line
+ item = roi_items.ArcROI()
+ center = numpy.array([10, 20])
+ innerRadius, outerRadius, startAngle, endAngle = 0, 100, numpy.pi, numpy.pi
+ item.setGeometry(center, innerRadius, outerRadius, startAngle, endAngle)
+ rois.append(item)
+ # Special arc: it's a donut
+ item = roi_items.ArcROI()
+ center = numpy.array([10, 20])
+ innerRadius, outerRadius, startAngle, endAngle = 1, 100, numpy.pi, 3 * numpy.pi
+ item.setGeometry(center, innerRadius, outerRadius, startAngle, endAngle)
+ rois.append(item)
+ # Arc
+ item = roi_items.ArcROI()
+ center = numpy.array([10, 20])
+ innerRadius, outerRadius, startAngle, endAngle = (
+ 1,
+ 100,
+ numpy.pi * 0.5,
+ numpy.pi,
+ )
+ item.setGeometry(center, innerRadius, outerRadius, startAngle, endAngle)
+ rois.append(item)
+ # Horizontal Range
+ item = roi_items.HorizontalRangeROI()
+ item.setRange(-1, 3)
+ rois.append(item)
+
+ manager = roi.RegionOfInterestManager(self.plot)
+ self.roiTableWidget.setRegionOfInterestManager(manager)
+ for item in rois:
+ with self.subTest(roi=str(item)):
+ manager.addRoi(item)
+ self.qapp.processEvents()
+ item.setEditable(True)
+ self.qapp.processEvents()
+ item.setEditable(False)
+ self.qapp.processEvents()
+ manager.removeRoi(item)
+ self.qapp.processEvents()
+
+ def testSelectionProxy(self):
+ item1 = roi_items.PointROI()
+ item1.setSelectable(True)
+ item2 = roi_items.PointROI()
+ item2.setSelectable(True)
+ item1.setFocusProxy(item2)
+ manager = roi.RegionOfInterestManager(self.plot)
+ manager.setCurrentRoi(item1)
+ self.assertIs(manager.getCurrentRoi(), item2)
+
+ def testRemovedSelection(self):
+ item1 = roi_items.PointROI()
+ item1.setSelectable(True)
+ manager = roi.RegionOfInterestManager(self.plot)
+ manager.addRoi(item1)
+ manager.setCurrentRoi(item1)
+ manager.removeRoi(item1)
+ self.assertIs(manager.getCurrentRoi(), None)
+
+ def testInitROIWithParent(self):
+ manager = roi.RegionOfInterestManager(self.plot)
+ item = roi_items.PointROI(manager)
+ manager.addRoi(item)
+ self.qapp.processEvents()
+ manager.removeRoi(item)
+ self.qapp.processEvents()
+
+ def testMaxROI(self):
+ """Test Max ROI"""
+ origin1 = numpy.array([1.0, 10.0])
+ size1 = numpy.array([10.0, 10.0])
+ origin2 = numpy.array([2.0, 3.0])
+ size2 = numpy.array([10.0, 10.0])
+
+ manager = roi.InteractiveRegionOfInterestManager(self.plot)
+ self.roiTableWidget.setRegionOfInterestManager(manager)
+ self.assertEqual(manager.getRois(), ())
+
+ changedListener = SignalListener()
+ manager.sigRoiChanged.connect(changedListener)
+
+ # Add two point
+ item = roi_items.RectangleROI()
+ item.setGeometry(origin=origin1, size=size1)
+ manager.addRoi(item)
+ item = roi_items.RectangleROI()
+ item.setGeometry(origin=origin2, size=size2)
+ manager.addRoi(item)
+ self.qapp.processEvents()
+ self.assertEqual(changedListener.callCount(), 2)
+ self.assertEqual(len(manager.getRois()), 2)
+
+ # Try to set max ROI to 1 while there is 2 ROIs
+ with self.assertRaises(ValueError):
+ manager.setMaxRois(1)
+
+ manager.clear()
+ self.assertEqual(len(manager.getRois()), 0)
+ self.assertEqual(changedListener.callCount(), 3)
+
+ # Set max limit to 1
+ manager.setMaxRois(1)
+
+ # Add a point
+ item = roi_items.RectangleROI()
+ item.setGeometry(origin=origin1, size=size1)
+ manager.addRoi(item)
+ self.qapp.processEvents()
+ self.assertEqual(changedListener.callCount(), 4)
+
+ # Add a 2nd point while max ROI is 1
+ item = roi_items.RectangleROI()
+ item.setGeometry(origin=origin1, size=size1)
+ manager.addRoi(item)
+ self.qapp.processEvents()
+ self.assertEqual(changedListener.callCount(), 6)
+ self.assertEqual(len(manager.getRois()), 1)
+
+ def testChangeInteractionMode(self):
+ """Test change of interaction mode"""
+ manager = roi.RegionOfInterestManager(self.plot)
+ self.roiTableWidget.setRegionOfInterestManager(manager)
+ manager.start(roi_items.PointROI)
+
+ interactiveModeToolBar = self.plot.getInteractiveModeToolBar()
+ panAction = interactiveModeToolBar.getPanModeAction()
+
+ for roiClass in manager.getSupportedRoiClasses():
+ with self.subTest(roiClass=roiClass):
+ # Change to pan mode
+ panAction.trigger()
+
+ # Change to interactive ROI mode
+ action = manager.getInteractionModeAction(roiClass)
+ action.trigger()
+
+ self.assertEqual(roiClass, manager.getCurrentInteractionModeRoiClass())
+
+ manager.clear()
+
+ def testLineInteraction(self):
+ """This test make sure that a ROI based on handles can be edited with
+ the mouse."""
+ xlimit = self.plot.getXAxis().getLimits()
+ ylimit = self.plot.getYAxis().getLimits()
+ points = numpy.array([xlimit, ylimit]).T
+ center = numpy.mean(points, axis=0)
+
+ # Create the line
+ manager = roi.RegionOfInterestManager(self.plot)
+ item = roi_items.LineROI()
+ item.setEndPoints(points[0], points[1])
+ item.setEditable(True)
+ manager.addRoi(item)
+ self.qapp.processEvents()
+
+ # Drag the center
+ widget = self.plot.getWidgetHandle()
+ mx, my = self.plot.dataToPixel(*center)
+ self.mouseMove(widget, pos=(mx, my))
+ self.mousePress(widget, qt.Qt.LeftButton, pos=(mx, my))
+ self.mouseMove(widget, pos=(mx, my + 25))
+ self.mouseMove(widget, pos=(mx, my + 50))
+ self.mouseRelease(widget, qt.Qt.LeftButton, pos=(mx, my + 50))
+
+ result = numpy.array(item.getEndPoints())
+ # x location is still the same
+ numpy.testing.assert_allclose(points[:, 0], result[:, 0], atol=0.5)
+ # size is still the same
+ numpy.testing.assert_allclose(
+ points[1] - points[0], result[1] - result[0], atol=0.5
+ )
+ # But Y is not the same
+ self.assertNotEqual(points[0, 1], result[0, 1])
+ self.assertNotEqual(points[1, 1], result[1, 1])
+ item = None
+ manager.clear()
+ self.qapp.processEvents()
+
+ def testPlotWhenCleared(self):
+ """PlotWidget.clear should clean up the available ROIs"""
+ manager = roi.RegionOfInterestManager(self.plot)
+ item = roi_items.LineROI()
+ item.setEndPoints((0, 0), (1, 1))
+ item.setEditable(True)
+ manager.addRoi(item)
+ self.qWait()
+ try:
+ # Make sure the test setup is fine
+ self.assertNotEqual(len(manager.getRois()), 0)
+ self.assertNotEqual(len(self.plot.getItems()), 0)
+
+ # Call clear and test the expected state
+ self.plot.clear()
+ self.assertEqual(len(manager.getRois()), 0)
+ self.assertEqual(len(self.plot.getItems()), 0)
+ finally:
+ # Clean up
+ manager.clear()
+
+ def testPlotWhenRoiRemoved(self):
+ """Make sure there is no remaining items in the plot when a ROI is removed"""
+ manager = roi.RegionOfInterestManager(self.plot)
+ item = roi_items.LineROI()
+ item.setEndPoints((0, 0), (1, 1))
+ item.setEditable(True)
+ manager.addRoi(item)
+ self.qWait()
+ try:
+ # Make sure the test setup is fine
+ self.assertNotEqual(len(manager.getRois()), 0)
+ self.assertNotEqual(len(self.plot.getItems()), 0)
+
+ # Call clear and test the expected state
+ manager.removeRoi(item)
+ self.assertEqual(len(manager.getRois()), 0)
+ self.assertEqual(len(self.plot.getItems()), 0)
+ finally:
+ # Clean up
+ manager.clear()
+
+ def testArcRoiSwitchMode(self):
+ """Make sure we can switch mode by clicking on the ROI"""
+ xlimit = self.plot.getXAxis().getLimits()
+ ylimit = self.plot.getYAxis().getLimits()
+ points = numpy.array([xlimit, ylimit]).T
+ center = numpy.mean(points, axis=0)
+ size = numpy.abs(points[1] - points[0])
+
+ # Create the line
+ manager = roi.RegionOfInterestManager(self.plot)
+ item = roi_items.ArcROI()
+ item.setGeometry(center, size[1] / 10, size[1] / 2, 0, 3)
+ item.setEditable(True)
+ item.setSelectable(True)
+ manager.addRoi(item)
+ self.qapp.processEvents()
+
+ # Initial state
+ self.assertIs(item.getInteractionMode(), roi_items.ArcROI.ThreePointMode)
+ self.qWait(500)
+
+ # Click on the center
+ widget = self.plot.getWidgetHandle()
+ mx, my = self.plot.dataToPixel(*center)
+
+ # Select the ROI
+ self.mouseMove(widget, pos=(mx, my))
+ self.mouseClick(widget, qt.Qt.LeftButton, pos=(mx, my))
+ self.qWait(500)
+ self.assertIs(item.getInteractionMode(), roi_items.ArcROI.ThreePointMode)
+
+ # Change the mode
+ self.mouseMove(widget, pos=(mx, my))
+ self.mouseClick(widget, qt.Qt.LeftButton, pos=(mx, my))
+ self.qWait(500)
+ self.assertIs(item.getInteractionMode(), roi_items.ArcROI.PolarMode)
+
+ manager.clear()
+ self.qapp.processEvents()
+
+ def testBandRoiSwitchMode(self):
+ """Make sure we can switch mode by clicking on the ROI"""
+ xlimit = self.plot.getXAxis().getLimits()
+ ylimit = self.plot.getYAxis().getLimits()
+ xcenter = 0.5 * (xlimit[0] + xlimit[1])
+ ycenter = 0.5 * (ylimit[0] + ylimit[1])
+
+ # Create the line
+ manager = roi.RegionOfInterestManager(self.plot)
+ item = roi_items.BandROI()
+ item.setGeometry(
+ (xlimit[0], ycenter),
+ (xlimit[1], ycenter),
+ 20,
+ )
+ item.setEditable(True)
+ item.setSelectable(True)
+ manager.addRoi(item)
+ self.qapp.processEvents()
+
+ # Initial state
+ assert item.getInteractionMode() is roi_items.BandROI.BoundedMode
+ self.qWait(500)
+
+ # Click on the center
+ widget = self.plot.getWidgetHandle()
+ mx, my = self.plot.dataToPixel(xcenter, ycenter)
+
+ # Select the ROI
+ self.mouseMove(widget, pos=(mx, my))
+ self.mouseClick(widget, qt.Qt.LeftButton, pos=(mx, my))
+ self.qWait(500)
+ assert item.getInteractionMode() is roi_items.BandROI.BoundedMode
+
+ # Change the mode
+ self.mouseMove(widget, pos=(mx, my))
+ self.mouseClick(widget, qt.Qt.LeftButton, pos=(mx, my))
+ self.qWait(500)
+ assert item.getInteractionMode() is roi_items.BandROI.UnboundedMode
+
+ # Set available modes that exclude the current one
+ item.setAvailableInteractionModes([roi_items.BandROI.BoundedMode])
+ assert item.getInteractionMode() is roi_items.BandROI.BoundedMode
+
+ # Clicking does not change the mode since there is only one
+ self.mouseMove(widget, pos=(mx, my))
+ self.mouseClick(widget, qt.Qt.LeftButton, pos=(mx, my))
+ self.qWait(500)
+ assert item.getInteractionMode() is roi_items.BandROI.BoundedMode
+
+ manager.clear()
+ self.qapp.processEvents()
diff --git a/src/silx/gui/plot/tools/test/testRoiItems.py b/src/silx/gui/plot/tools/test/testRoiItems.py
new file mode 100644
index 0000000..9bd9690
--- /dev/null
+++ b/src/silx/gui/plot/tools/test/testRoiItems.py
@@ -0,0 +1,313 @@
+# /*##########################################################################
+#
+# Copyright (c) 2018-2020 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+__authors__ = ["T. Vincent"]
+__license__ = "MIT"
+__date__ = "28/06/2018"
+
+
+import pytest
+import numpy.testing
+
+import silx.gui.plot.items.roi as roi_items
+
+
+def testLine_geometry(qapp):
+ item = roi_items.LineROI()
+ startPoint = numpy.array([1, 2])
+ endPoint = numpy.array([3, 4])
+ item.setEndPoints(startPoint, endPoint)
+ numpy.testing.assert_allclose(item.getEndPoints()[0], startPoint)
+ numpy.testing.assert_allclose(item.getEndPoints()[1], endPoint)
+
+
+def testHLine_geometry(qapp):
+ item = roi_items.HorizontalLineROI()
+ item.setPosition(15)
+ assert item.getPosition() == 15
+
+
+def testVLine_geometry(qapp):
+ item = roi_items.VerticalLineROI()
+ item.setPosition(15)
+ assert item.getPosition() == 15
+
+
+def testPoint_geometry(qapp):
+ point = numpy.array([1, 2])
+ item = roi_items.PointROI()
+ item.setPosition(point)
+ numpy.testing.assert_allclose(item.getPosition(), point)
+
+
+def testRectangle_originGeometry(qapp):
+ origin = numpy.array([0, 0])
+ size = numpy.array([10, 20])
+ center = numpy.array([5, 10])
+ item = roi_items.RectangleROI()
+ item.setGeometry(origin=origin, size=size)
+ numpy.testing.assert_allclose(item.getOrigin(), origin)
+ numpy.testing.assert_allclose(item.getSize(), size)
+ numpy.testing.assert_allclose(item.getCenter(), center)
+
+
+def testRectangle_centerGeometry(qapp):
+ origin = numpy.array([0, 0])
+ size = numpy.array([10, 20])
+ center = numpy.array([5, 10])
+ item = roi_items.RectangleROI()
+ item.setGeometry(center=center, size=size)
+ numpy.testing.assert_allclose(item.getOrigin(), origin)
+ numpy.testing.assert_allclose(item.getSize(), size)
+ numpy.testing.assert_allclose(item.getCenter(), center)
+
+
+def testRectangle_setCenterGeometry(qapp):
+ origin = numpy.array([0, 0])
+ size = numpy.array([10, 20])
+ item = roi_items.RectangleROI()
+ item.setGeometry(origin=origin, size=size)
+ newCenter = numpy.array([0, 0])
+ item.setCenter(newCenter)
+ expectedOrigin = numpy.array([-5, -10])
+ numpy.testing.assert_allclose(item.getOrigin(), expectedOrigin)
+ numpy.testing.assert_allclose(item.getCenter(), newCenter)
+ numpy.testing.assert_allclose(item.getSize(), size)
+
+
+def testRectangle_setOriginGeometry(qapp):
+ origin = numpy.array([0, 0])
+ size = numpy.array([10, 20])
+ item = roi_items.RectangleROI()
+ item.setGeometry(origin=origin, size=size)
+ newOrigin = numpy.array([10, 10])
+ item.setOrigin(newOrigin)
+ expectedCenter = numpy.array([15, 20])
+ numpy.testing.assert_allclose(item.getOrigin(), newOrigin)
+ numpy.testing.assert_allclose(item.getCenter(), expectedCenter)
+ numpy.testing.assert_allclose(item.getSize(), size)
+
+
+def testCircle_geometry(qapp):
+ center = numpy.array([0, 0])
+ radius = 10.0
+ item = roi_items.CircleROI()
+ item.setGeometry(center=center, radius=radius)
+ numpy.testing.assert_allclose(item.getCenter(), center)
+ numpy.testing.assert_allclose(item.getRadius(), radius)
+
+
+def testCircle_setCenter(qapp):
+ center = numpy.array([0, 0])
+ radius = 10.0
+ item = roi_items.CircleROI()
+ item.setGeometry(center=center, radius=radius)
+ newCenter = numpy.array([-10, 0])
+ item.setCenter(newCenter)
+ numpy.testing.assert_allclose(item.getCenter(), newCenter)
+ numpy.testing.assert_allclose(item.getRadius(), radius)
+
+
+def testCircle_setRadius(qapp):
+ center = numpy.array([0, 0])
+ radius = 10.0
+ item = roi_items.CircleROI()
+ item.setGeometry(center=center, radius=radius)
+ newRadius = 5.1
+ item.setRadius(newRadius)
+ numpy.testing.assert_allclose(item.getCenter(), center)
+ numpy.testing.assert_allclose(item.getRadius(), newRadius)
+
+
+def testCircle_contains(qapp):
+ center = numpy.array([2, -1])
+ radius = 1.0
+ item = roi_items.CircleROI()
+ item.setGeometry(center=center, radius=radius)
+ assert item.contains([1, -1])
+ assert not item.contains([0, 0])
+ assert item.contains([2, 0])
+ assert not item.contains([3.01, -1])
+
+
+def testEllipse_contains(qapp):
+ center = numpy.array([-2, 0])
+ item = roi_items.EllipseROI()
+ item.setCenter(center)
+ item.setOrientation(numpy.pi / 4.0)
+ item.setMajorRadius(2)
+ item.setMinorRadius(1)
+ print(item.getMinorRadius(), item.getMajorRadius())
+ assert not item.contains([0, 0])
+ assert item.contains([-1, 1])
+ assert item.contains([-3, 0])
+ assert item.contains([-2, 0])
+ assert item.contains([-2, 1])
+ assert not item.contains([-4, 1])
+
+
+def testRectangle_isIn(qapp):
+ origin = numpy.array([0, 0])
+ size = numpy.array([10, 20])
+ item = roi_items.RectangleROI()
+ item.setGeometry(origin=origin, size=size)
+ assert item.contains(position=(0, 0))
+ assert item.contains(position=(2, 14))
+ assert not item.contains(position=(14, 12))
+
+
+def testPolygon_emptyGeometry(qapp):
+ points = numpy.empty((0, 2))
+ item = roi_items.PolygonROI()
+ item.setPoints(points)
+ numpy.testing.assert_allclose(item.getPoints(), points)
+
+
+def testPolygon_geometry(qapp):
+ points = numpy.array([[10, 10], [12, 10], [50, 1]])
+ item = roi_items.PolygonROI()
+ item.setPoints(points)
+ numpy.testing.assert_allclose(item.getPoints(), points)
+
+
+def testPolygon_isIn(qapp):
+ points = numpy.array([[0, 0], [0, 10], [5, 10]])
+ item = roi_items.PolygonROI()
+ item.setPoints(points)
+ assert item.contains((0, 0))
+ assert not item.contains((6, 2))
+ assert not item.contains((-2, 5))
+ assert not item.contains((2, -1))
+ assert not item.contains((8, 1))
+ assert item.contains((1, 8))
+
+
+def testArc_getToSetGeometry(qapp):
+ """Test that we can use getGeometry as input to setGeometry"""
+ item = roi_items.ArcROI()
+ item.setFirstShapePoints(numpy.array([[5, 10], [50, 100]]))
+ item.setGeometry(*item.getGeometry())
+
+
+def testArc_degenerated_point(qapp):
+ item = roi_items.ArcROI()
+ center = numpy.array([10, 20])
+ innerRadius, outerRadius, startAngle, endAngle = 0, 0, 0, 0
+ item.setGeometry(center, innerRadius, outerRadius, startAngle, endAngle)
+
+
+def testArc_degenerated_line(qapp):
+ item = roi_items.ArcROI()
+ center = numpy.array([10, 20])
+ innerRadius, outerRadius, startAngle, endAngle = 0, 100, numpy.pi, numpy.pi
+ item.setGeometry(center, innerRadius, outerRadius, startAngle, endAngle)
+
+
+def testArc_special_circle(qapp):
+ item = roi_items.ArcROI()
+ center = numpy.array([10, 20])
+ innerRadius, outerRadius, startAngle, endAngle = 0, 100, numpy.pi, 3 * numpy.pi
+ item.setGeometry(center, innerRadius, outerRadius, startAngle, endAngle)
+ numpy.testing.assert_allclose(item.getCenter(), center)
+ assert item.getInnerRadius() == pytest.approx(innerRadius)
+ assert item.getOuterRadius() == pytest.approx(outerRadius)
+ assert item.getStartAngle() == pytest.approx(item.getEndAngle() - numpy.pi * 2.0)
+ assert item.isClosed()
+
+
+def testArc_special_donut(qapp):
+ item = roi_items.ArcROI()
+ center = numpy.array([10, 20])
+ innerRadius, outerRadius, startAngle, endAngle = 1, 100, numpy.pi, 3 * numpy.pi
+ item.setGeometry(center, innerRadius, outerRadius, startAngle, endAngle)
+ numpy.testing.assert_allclose(item.getCenter(), center)
+ assert item.getInnerRadius() == pytest.approx(innerRadius)
+ assert item.getOuterRadius() == pytest.approx(outerRadius)
+ assert item.getStartAngle() == pytest.approx(item.getEndAngle() - numpy.pi * 2.0)
+ assert item.isClosed()
+
+
+def testArc_clockwiseGeometry(qapp):
+ """Test that we can use getGeometry as input to setGeometry"""
+ item = roi_items.ArcROI()
+ center = numpy.array([10, 20])
+ innerRadius, outerRadius, startAngle, endAngle = 1, 100, numpy.pi * 0.5, numpy.pi
+ item.setGeometry(center, innerRadius, outerRadius, startAngle, endAngle)
+ numpy.testing.assert_allclose(item.getCenter(), center)
+ assert item.getInnerRadius() == pytest.approx(innerRadius)
+ assert item.getOuterRadius() == pytest.approx(outerRadius)
+ assert item.getStartAngle() == pytest.approx(startAngle)
+ assert item.getEndAngle() == pytest.approx(endAngle)
+ assert not item.isClosed()
+
+
+def testArc_anticlockwiseGeometry(qapp):
+ """Test that we can use getGeometry as input to setGeometry"""
+ item = roi_items.ArcROI()
+ center = numpy.array([10, 20])
+ innerRadius, outerRadius, startAngle, endAngle = (
+ 1,
+ 100,
+ numpy.pi * 0.5,
+ -numpy.pi * 0.5,
+ )
+ item.setGeometry(center, innerRadius, outerRadius, startAngle, endAngle)
+ numpy.testing.assert_allclose(item.getCenter(), center)
+ assert item.getInnerRadius() == pytest.approx(innerRadius)
+ assert item.getOuterRadius() == pytest.approx(outerRadius)
+ assert item.getStartAngle() == pytest.approx(startAngle)
+ assert item.getEndAngle() == pytest.approx(endAngle)
+ assert not item.isClosed()
+
+
+def testArc_position(qapp):
+ """Test validity of getPosition"""
+ item = roi_items.ArcROI()
+ center = numpy.array([10, 20])
+ innerRadius, outerRadius, startAngle, endAngle = 1, 100, numpy.pi * 0.5, numpy.pi
+ item.setGeometry(center, innerRadius, outerRadius, startAngle, endAngle)
+ assert item.getPosition(roi_items.ArcROI.Role.START) == pytest.approx((10.0, 70.5))
+ assert item.getPosition(roi_items.ArcROI.Role.STOP) == pytest.approx((-40.5, 20.0))
+ assert item.getPosition(roi_items.ArcROI.Role.MIDDLE) == pytest.approx(
+ (-25.71, 55.71), abs=0.1
+ )
+ assert item.getPosition(roi_items.ArcROI.Role.CENTER) == pytest.approx(
+ (10.0, 20), abs=0.1
+ )
+
+
+def testHRange_geometry(qapp):
+ item = roi_items.HorizontalRangeROI()
+ vmin = 1
+ vmax = 3
+ item.setRange(vmin, vmax)
+ assert item.getMin() == pytest.approx(vmin)
+ assert item.getMax() == pytest.approx(vmax)
+ assert item.getCenter() == pytest.approx(2)
+
+
+def testBand_getToSetGeometry(qapp):
+ """Test that we can use getGeometry as input to setGeometry"""
+ item = roi_items.BandROI()
+ item.setFirstShapePoints(numpy.array([[5, 10], [50, 100]]))
+ item.setGeometry(*item.getGeometry())
diff --git a/src/silx/gui/plot/tools/test/testScatterProfileToolBar.py b/src/silx/gui/plot/tools/test/testScatterProfileToolBar.py
new file mode 100644
index 0000000..29c9ad0
--- /dev/null
+++ b/src/silx/gui/plot/tools/test/testScatterProfileToolBar.py
@@ -0,0 +1,185 @@
+# /*##########################################################################
+#
+# Copyright (c) 2018 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+__authors__ = ["T. Vincent"]
+__license__ = "MIT"
+__date__ = "28/06/2018"
+
+
+import numpy
+
+from silx.gui import qt
+from silx.utils.testutils import ParametricTestCase
+from silx.gui.utils.testutils import TestCaseQt
+from silx.gui.plot import PlotWindow
+from silx.gui.plot.tools.profile import manager
+from silx.gui.plot.tools.profile import core
+from silx.gui.plot.tools.profile import rois
+
+
+class TestScatterProfileToolBar(TestCaseQt, ParametricTestCase):
+ """Tests for ScatterProfileToolBar class"""
+
+ def setUp(self):
+ super(TestScatterProfileToolBar, self).setUp()
+ self.plot = PlotWindow()
+
+ self.manager = manager.ProfileManager(plot=self.plot)
+ self.manager.setItemType(scatter=True)
+ self.manager.setActiveItemTracking(True)
+
+ self.plot.show()
+ self.qWaitForWindowExposed(self.plot)
+
+ def tearDown(self):
+ del self.manager
+ self.qapp.processEvents()
+ self.plot.setAttribute(qt.Qt.WA_DeleteOnClose)
+ self.plot.close()
+ del self.plot
+ super(TestScatterProfileToolBar, self).tearDown()
+
+ def testHorizontalProfile(self):
+ """Test ScatterProfileToolBar horizontal profile"""
+
+ roiManager = self.manager.getRoiManager()
+
+ # Add a scatter plot
+ self.plot.addScatter(
+ x=(0.0, 1.0, 1.0, 0.0), y=(0.0, 0.0, 1.0, 1.0), value=(0.0, 1.0, 2.0, 3.0)
+ )
+ self.plot.resetZoom(dataMargins=(0.1, 0.1, 0.1, 0.1))
+ self.qapp.processEvents()
+
+ # Set a ROI profile
+ roi = rois.ProfileScatterHorizontalLineROI()
+ roi.setPosition(0.5)
+ roi.setNPoints(8)
+ roiManager.addRoi(roi)
+
+ # Wait for async interpolator init
+ for _ in range(20):
+ self.qWait(200)
+ if not self.manager.hasPendingOperations():
+ break
+ self.qapp.processEvents()
+
+ window = roi.getProfileWindow()
+ self.assertIsNotNone(window)
+ data = window.getProfile()
+ self.assertIsInstance(data, core.CurveProfileData)
+ self.assertEqual(len(data.coords), 8)
+
+ # Check that profile has same limits than Plot
+ xLimits = self.plot.getXAxis().getLimits()
+ self.assertEqual(data.coords[0], xLimits[0])
+ self.assertEqual(data.coords[-1], xLimits[1])
+
+ # Clear the profile
+ self.manager.clearProfile()
+ self.qapp.processEvents()
+ self.assertIsNone(roi.getProfileWindow())
+
+ def testVerticalProfile(self):
+ """Test ScatterProfileToolBar vertical profile"""
+
+ roiManager = self.manager.getRoiManager()
+
+ # Add a scatter plot
+ self.plot.addScatter(
+ x=(0.0, 1.0, 1.0, 0.0), y=(0.0, 0.0, 1.0, 1.0), value=(0.0, 1.0, 2.0, 3.0)
+ )
+ self.plot.resetZoom(dataMargins=(0.1, 0.1, 0.1, 0.1))
+ self.qapp.processEvents()
+
+ # Set a ROI profile
+ roi = rois.ProfileScatterVerticalLineROI()
+ roi.setPosition(0.5)
+ roi.setNPoints(8)
+ roiManager.addRoi(roi)
+
+ # Wait for async interpolator init
+ for _ in range(10):
+ self.qWait(200)
+ if not self.manager.hasPendingOperations():
+ break
+
+ window = roi.getProfileWindow()
+ self.assertIsNotNone(window)
+ data = window.getProfile()
+ self.assertIsInstance(data, core.CurveProfileData)
+ self.assertEqual(len(data.coords), 8)
+
+ # Check that profile has same limits than Plot
+ yLimits = self.plot.getYAxis().getLimits()
+ self.assertEqual(data.coords[0], yLimits[0])
+ self.assertEqual(data.coords[-1], yLimits[1])
+
+ # Check that profile limits are updated when changing limits
+ self.plot.getYAxis().setLimits(yLimits[0] + 1, yLimits[1] + 10)
+
+ # Wait for async interpolator init
+ for _ in range(10):
+ self.qWait(200)
+ if not self.manager.hasPendingOperations():
+ break
+
+ yLimits = self.plot.getYAxis().getLimits()
+ data = window.getProfile()
+ self.assertEqual(data.coords[0], yLimits[0])
+ self.assertEqual(data.coords[-1], yLimits[1])
+
+ # Clear the profile
+ self.manager.clearProfile()
+ self.qapp.processEvents()
+ self.assertIsNone(roi.getProfileWindow())
+
+ def testLineProfile(self):
+ """Test ScatterProfileToolBar line profile"""
+
+ roiManager = self.manager.getRoiManager()
+
+ # Add a scatter plot
+ self.plot.addScatter(
+ x=(0.0, 1.0, 1.0, 0.0), y=(0.0, 0.0, 1.0, 1.0), value=(0.0, 1.0, 2.0, 3.0)
+ )
+ self.plot.resetZoom(dataMargins=(0.1, 0.1, 0.1, 0.1))
+ self.qapp.processEvents()
+
+ # Set a ROI profile
+ roi = rois.ProfileScatterLineROI()
+ roi.setEndPoints(numpy.array([0.0, 0.0]), numpy.array([1.0, 1.0]))
+ roi.setNPoints(8)
+ roiManager.addRoi(roi)
+
+ # Wait for async interpolator init
+ for _ in range(10):
+ self.qWait(200)
+ if not self.manager.hasPendingOperations():
+ break
+
+ window = roi.getProfileWindow()
+ self.assertIsNotNone(window)
+ data = window.getProfile()
+ self.assertIsInstance(data, core.CurveProfileData)
+ self.assertEqual(len(data.coords), 8)
diff --git a/src/silx/gui/plot/tools/test/testTools.py b/src/silx/gui/plot/tools/test/testTools.py
new file mode 100644
index 0000000..1212ead
--- /dev/null
+++ b/src/silx/gui/plot/tools/test/testTools.py
@@ -0,0 +1,135 @@
+# /*##########################################################################
+#
+# Copyright (c) 2016-2021 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""Basic tests for silx.gui.plot.tools package"""
+
+__authors__ = ["T. Vincent"]
+__license__ = "MIT"
+__date__ = "02/03/2018"
+
+
+import functools
+import numpy
+
+from silx.utils.testutils import LoggingValidator
+from silx.gui import qt
+from silx.gui.plot import PlotWindow
+from silx.gui.plot import tools
+from silx.gui.plot.test.utils import PlotWidgetTestCase
+
+
+class TestPositionInfo(PlotWidgetTestCase):
+ """Tests for PositionInfo widget."""
+
+ def _createPlot(self):
+ return PlotWindow()
+
+ def setUp(self):
+ super(TestPositionInfo, self).setUp()
+ self.mouseMove(self.plot, pos=(0, 0))
+ self.qapp.processEvents()
+ self.qWait(100)
+
+ def tearDown(self):
+ super(TestPositionInfo, self).tearDown()
+
+ def _test(self, positionWidget, converterNames, **kwargs):
+ """General test of PositionInfo.
+
+ - Add it to a toolbar and
+ - Move mouse around the center of the PlotWindow.
+ """
+ toolBar = qt.QToolBar()
+ self.plot.addToolBar(qt.Qt.BottomToolBarArea, toolBar)
+
+ toolBar.addWidget(positionWidget)
+
+ converters = positionWidget.getConverters()
+ self.assertEqual(len(converters), len(converterNames))
+ for index, name in enumerate(converterNames):
+ self.assertEqual(converters[index][0], name)
+
+ self.qapp.processEvents()
+ with LoggingValidator(tools.__name__, **kwargs):
+ # Move mouse to center
+ center = self.plot.size() / 2
+ self.mouseMove(self.plot, pos=(center.width(), center.height()))
+ # Move out
+ self.mouseMove(self.plot, pos=(1, 1))
+
+ def testDefaultConverters(self):
+ """Test PositionInfo with default converters"""
+ positionWidget = tools.PositionInfo(plot=self.plot)
+ self._test(positionWidget, ("X", "Y"))
+
+ def testCustomConverters(self):
+ """Test PositionInfo with custom converters"""
+ converters = [
+ ("Coords", lambda x, y: (int(x), int(y))),
+ ("Radius", lambda x, y: numpy.sqrt(x * x + y * y)),
+ ("Angle", lambda x, y: numpy.degrees(numpy.arctan2(y, x))),
+ ]
+ positionWidget = tools.PositionInfo(plot=self.plot, converters=converters)
+ self._test(positionWidget, ("Coords", "Radius", "Angle"))
+
+ def testFailingConverters(self):
+ """Test PositionInfo with failing custom converters"""
+
+ def raiseException(x, y):
+ raise RuntimeError()
+
+ positionWidget = tools.PositionInfo(
+ plot=self.plot, converters=[("Exception", raiseException)]
+ )
+ self._test(positionWidget, ["Exception"], error=2)
+
+ def testUpdate(self):
+ """Test :meth:`PositionInfo.updateInfo`"""
+ calls = []
+
+ def update(calls, x, y): # Get number of calls
+ calls.append((x, y))
+ return len(calls)
+
+ positionWidget = tools.PositionInfo(
+ plot=self.plot,
+ converters=[("Call count", functools.partial(update, calls))],
+ )
+
+ positionWidget.updateInfo()
+ self.assertEqual(len(calls), 1)
+
+
+class TestPlotToolsToolbars(PlotWidgetTestCase):
+ """Tests toolbars from silx.gui.plot.tools"""
+
+ def test(self):
+ """ "Add all toolbars"""
+ for tbClass in (
+ tools.InteractiveModeToolBar,
+ tools.ImageToolBar,
+ tools.CurveToolBar,
+ tools.OutputToolBar,
+ ):
+ tb = tbClass(parent=self.plot, plot=self.plot)
+ self.plot.addToolBar(tb)
diff --git a/src/silx/gui/plot/tools/toolbars.py b/src/silx/gui/plot/tools/toolbars.py
new file mode 100644
index 0000000..7f38f1c
--- /dev/null
+++ b/src/silx/gui/plot/tools/toolbars.py
@@ -0,0 +1,357 @@
+# /*##########################################################################
+#
+# Copyright (c) 2018-2023 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""This module provides toolbars that work with :class:`PlotWidget`.
+"""
+
+__authors__ = ["T. Vincent"]
+__license__ = "MIT"
+__date__ = "01/03/2018"
+
+
+from ... import qt
+from .. import actions
+from ..PlotWidget import PlotWidget
+from .. import PlotToolButtons
+
+
+class InteractiveModeToolBar(qt.QToolBar):
+ """Toolbar with interactive mode actions
+
+ :param parent: See :class:`QWidget`
+ :param silx.gui.plot.PlotWidget plot: PlotWidget to control
+ :param str title: Title of the toolbar.
+ """
+
+ def __init__(self, parent=None, plot=None, title="Plot Interaction"):
+ super(InteractiveModeToolBar, self).__init__(title, parent)
+
+ assert isinstance(plot, PlotWidget)
+
+ self._zoomModeAction = actions.mode.ZoomModeAction(parent=self, plot=plot)
+ self.addAction(self._zoomModeAction)
+
+ self._panModeAction = actions.mode.PanModeAction(parent=self, plot=plot)
+ self.addAction(self._panModeAction)
+
+ def getZoomModeAction(self):
+ """Returns the zoom mode QAction.
+
+ :rtype: PlotAction
+ """
+ return self._zoomModeAction
+
+ def getPanModeAction(self):
+ """Returns the pan mode QAction
+
+ :rtype: PlotAction
+ """
+ return self._panModeAction
+
+
+class OutputToolBar(qt.QToolBar):
+ """Toolbar providing icons to copy, save and print a PlotWidget
+
+ :param parent: See :class:`QWidget`
+ :param silx.gui.plot.PlotWidget plot: PlotWidget to control
+ :param str title: Title of the toolbar.
+ """
+
+ def __init__(self, parent=None, plot=None, title="Plot Output"):
+ super(OutputToolBar, self).__init__(title, parent)
+
+ assert isinstance(plot, PlotWidget)
+
+ self._copyAction = actions.io.CopyAction(parent=self, plot=plot)
+ self.addAction(self._copyAction)
+
+ self._saveAction = actions.io.SaveAction(parent=self, plot=plot)
+ self.addAction(self._saveAction)
+
+ self._printAction = actions.io.PrintAction(parent=self, plot=plot)
+ self.addAction(self._printAction)
+
+ def getCopyAction(self):
+ """Returns the QAction performing copy to clipboard of the PlotWidget
+
+ :rtype: PlotAction
+ """
+ return self._copyAction
+
+ def getSaveAction(self):
+ """Returns the QAction performing save to file of the PlotWidget
+
+ :rtype: PlotAction
+ """
+ return self._saveAction
+
+ def getPrintAction(self):
+ """Returns the QAction performing printing of the PlotWidget
+
+ :rtype: PlotAction
+ """
+ return self._printAction
+
+
+class ImageToolBar(qt.QToolBar):
+ """Toolbar providing PlotAction suited when displaying images
+
+ :param parent: See :class:`QWidget`
+ :param silx.gui.plot.PlotWidget plot: PlotWidget to control
+ :param str title: Title of the toolbar.
+ """
+
+ def __init__(self, parent=None, plot=None, title="Image"):
+ super(ImageToolBar, self).__init__(title, parent)
+
+ assert isinstance(plot, PlotWidget)
+
+ self._resetZoomAction = actions.control.ResetZoomAction(parent=self, plot=plot)
+ self.addAction(self._resetZoomAction)
+
+ self._colormapAction = actions.control.ColormapAction(parent=self, plot=plot)
+ self.addAction(self._colormapAction)
+
+ self._keepDataAspectRatioButton = PlotToolButtons.AspectToolButton(
+ parent=self, plot=plot
+ )
+ self.addWidget(self._keepDataAspectRatioButton)
+
+ self._yAxisInvertedButton = PlotToolButtons.YAxisOriginToolButton(
+ parent=self, plot=plot
+ )
+ self.addWidget(self._yAxisInvertedButton)
+
+ def getResetZoomAction(self):
+ """Returns the QAction to reset the zoom.
+
+ :rtype: PlotAction
+ """
+ return self._resetZoomAction
+
+ def getColormapAction(self):
+ """Returns the QAction to control the colormap.
+
+ :rtype: PlotAction
+ """
+ return self._colormapAction
+
+ def getKeepDataAspectRatioButton(self):
+ """Returns the QToolButton controlling data aspect ratio.
+
+ :rtype: QToolButton
+ """
+ return self._keepDataAspectRatioButton
+
+ def getYAxisInvertedButton(self):
+ """Returns the QToolButton controlling Y axis orientation.
+
+ :rtype: QToolButton
+ """
+ return self._yAxisInvertedButton
+
+
+class CurveToolBar(qt.QToolBar):
+ """Toolbar providing PlotAction suited when displaying curves
+
+ :param parent: See :class:`QWidget`
+ :param silx.gui.plot.PlotWidget plot: PlotWidget to control
+ :param str title: Title of the toolbar.
+ """
+
+ def __init__(self, parent=None, plot=None, title="Image"):
+ super(CurveToolBar, self).__init__(title, parent)
+
+ assert isinstance(plot, PlotWidget)
+
+ self._resetZoomAction = actions.control.ResetZoomAction(parent=self, plot=plot)
+ self.addAction(self._resetZoomAction)
+
+ self._xAxisAutoScaleAction = actions.control.XAxisAutoScaleAction(
+ parent=self, plot=plot
+ )
+ self.addAction(self._xAxisAutoScaleAction)
+
+ self._yAxisAutoScaleAction = actions.control.YAxisAutoScaleAction(
+ parent=self, plot=plot
+ )
+ self.addAction(self._yAxisAutoScaleAction)
+
+ self._xAxisLogarithmicAction = actions.control.XAxisLogarithmicAction(
+ parent=self, plot=plot
+ )
+ self.addAction(self._xAxisLogarithmicAction)
+
+ self._yAxisLogarithmicAction = actions.control.YAxisLogarithmicAction(
+ parent=self, plot=plot
+ )
+ self.addAction(self._yAxisLogarithmicAction)
+
+ self._gridAction = actions.control.GridAction(parent=self, plot=plot)
+ self.addAction(self._gridAction)
+
+ self._curveStyleAction = actions.control.CurveStyleAction(
+ parent=self, plot=plot
+ )
+ self.addAction(self._curveStyleAction)
+
+ def getResetZoomAction(self):
+ """Returns the QAction to reset the zoom.
+
+ :rtype: PlotAction
+ """
+ return self._resetZoomAction
+
+ def getXAxisAutoScaleAction(self):
+ """Returns the QAction to toggle X axis autoscale.
+
+ :rtype: PlotAction
+ """
+ return self._xAxisAutoScaleAction
+
+ def getYAxisAutoScaleAction(self):
+ """Returns the QAction to toggle Y axis autoscale.
+
+ :rtype: PlotAction
+ """
+ return self._yAxisAutoScaleAction
+
+ def getXAxisLogarithmicAction(self):
+ """Returns the QAction to toggle X axis log/linear scale.
+
+ :rtype: PlotAction
+ """
+ return self._xAxisLogarithmicAction
+
+ def getYAxisLogarithmicAction(self):
+ """Returns the QAction to toggle Y axis log/linear scale.
+
+ :rtype: PlotAction
+ """
+ return self._yAxisLogarithmicAction
+
+ def getGridAction(self):
+ """Returns the action to toggle the plot grid.
+
+ :rtype: PlotAction
+ """
+ return self._gridAction
+
+ def getCurveStyleAction(self):
+ """Returns the QAction to change the style of all curves.
+
+ :rtype: PlotAction
+ """
+ return self._curveStyleAction
+
+
+class ScatterToolBar(qt.QToolBar):
+ """Toolbar providing PlotAction suited when displaying scatter plot
+
+ :param parent: See :class:`QWidget`
+ :param silx.gui.plot.PlotWidget plot: PlotWidget to control
+ :param str title: Title of the toolbar.
+ """
+
+ def __init__(self, parent=None, plot=None, title="Scatter Tools"):
+ super(ScatterToolBar, self).__init__(title, parent)
+
+ assert isinstance(plot, PlotWidget)
+
+ self._resetZoomAction = actions.control.ResetZoomAction(parent=self, plot=plot)
+ self.addAction(self._resetZoomAction)
+
+ self._xAxisLogarithmicAction = actions.control.XAxisLogarithmicAction(
+ parent=self, plot=plot
+ )
+ self.addAction(self._xAxisLogarithmicAction)
+
+ self._yAxisLogarithmicAction = actions.control.YAxisLogarithmicAction(
+ parent=self, plot=plot
+ )
+ self.addAction(self._yAxisLogarithmicAction)
+
+ self._keepDataAspectRatioButton = PlotToolButtons.AspectToolButton(
+ parent=self, plot=plot
+ )
+ self.addWidget(self._keepDataAspectRatioButton)
+
+ self._gridAction = actions.control.GridAction(parent=self, plot=plot)
+ self.addAction(self._gridAction)
+
+ self._colormapAction = actions.control.ColormapAction(parent=self, plot=plot)
+ self.addAction(self._colormapAction)
+
+ self._visualizationToolButton = PlotToolButtons.ScatterVisualizationToolButton(
+ parent=self, plot=plot
+ )
+ self.addWidget(self._visualizationToolButton)
+
+ def getResetZoomAction(self):
+ """Returns the QAction to reset the zoom.
+
+ :rtype: PlotAction
+ """
+ return self._resetZoomAction
+
+ def getXAxisLogarithmicAction(self):
+ """Returns the QAction to toggle X axis log/linear scale.
+
+ :rtype: PlotAction
+ """
+ return self._xAxisLogarithmicAction
+
+ def getYAxisLogarithmicAction(self):
+ """Returns the QAction to toggle Y axis log/linear scale.
+
+ :rtype: PlotAction
+ """
+ return self._yAxisLogarithmicAction
+
+ def getGridAction(self):
+ """Returns the action to toggle the plot grid.
+
+ :rtype: PlotAction
+ """
+ return self._gridAction
+
+ def getColormapAction(self):
+ """Returns the QAction to control the colormap.
+
+ :rtype: PlotAction
+ """
+ return self._colormapAction
+
+ def getKeepDataAspectRatioButton(self):
+ """Returns the QToolButton controlling data aspect ratio.
+
+ :rtype: QToolButton
+ """
+ return self._keepDataAspectRatioButton
+
+ def getScatterVisualizationToolButton(self):
+ """Returns the QToolButton controlling the visualization mode.
+
+ :rtype: ScatterVisualizationToolButton
+ """
+ return self._visualizationToolButton
diff --git a/src/silx/gui/plot/utils/__init__.py b/src/silx/gui/plot/utils/__init__.py
new file mode 100644
index 0000000..61e45b4
--- /dev/null
+++ b/src/silx/gui/plot/utils/__init__.py
@@ -0,0 +1,29 @@
+# /*##########################################################################
+#
+# Copyright (c) 2016-2017 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""Utils module for plot.
+"""
+
+__authors__ = ["V. Valls"]
+__license__ = "MIT"
+__date__ = "29/06/2017"
diff --git a/src/silx/gui/plot/utils/axis.py b/src/silx/gui/plot/utils/axis.py
new file mode 100644
index 0000000..4c6bcef
--- /dev/null
+++ b/src/silx/gui/plot/utils/axis.py
@@ -0,0 +1,400 @@
+# /*##########################################################################
+#
+# Copyright (c) 2017-2021 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""This module contains utils class for axes management.
+"""
+
+__authors__ = ["V. Valls"]
+__license__ = "MIT"
+__date__ = "20/11/2018"
+
+import functools
+import logging
+from contextlib import contextmanager
+import weakref
+import silx.utils.weakref as silxWeakref
+from silx.gui.plot.items.axis import Axis, XAxis, YAxis
+from ...qt.inspect import isValid as _isQObjectValid
+
+
+_logger = logging.getLogger(__name__)
+
+
+class SyncAxes(object):
+ """Synchronize a set of plot axes together.
+
+ It is created with the expected axes and starts to synchronize them.
+
+ It can be customized to synchronize limits, scale, and direction of axes
+ together. By default everything is synchronized.
+
+ The API :meth:`start` and :meth:`stop` can be used to enable/disable the
+ synchronization while this object is still alive.
+
+ If this object is destroyed the synchronization stop.
+
+ .. versionadded:: 0.6
+ """
+
+ def __init__(
+ self,
+ axes,
+ syncLimits=True,
+ syncScale=True,
+ syncDirection=True,
+ syncCenter=False,
+ syncZoom=False,
+ filterHiddenPlots=False,
+ ):
+ """
+ Constructor
+
+ :param list(Axis) axes: A list of axes to synchronize together
+ :param bool syncLimits: Synchronize axes limits
+ :param bool syncScale: Synchronize axes scale
+ :param bool syncDirection: Synchronize axes direction
+ :param bool syncCenter: Synchronize the center of the axes in the center
+ of the plots
+ :param bool syncZoom: Synchronize the zoom of the plot
+ :param bool filterHiddenPlots: True to avoid updating hidden plots.
+ Default: False.
+ """
+ object.__init__(self)
+
+ def implies(x, y):
+ return bool(y**x)
+
+ assert implies(syncZoom, not syncLimits)
+ assert implies(syncCenter, not syncLimits)
+ assert implies(syncLimits, not syncCenter)
+ assert implies(syncLimits, not syncZoom)
+
+ self.__filterHiddenPlots = filterHiddenPlots
+ self.__locked = False
+ self.__axisRefs = []
+ self.__syncLimits = syncLimits
+ self.__syncScale = syncScale
+ self.__syncDirection = syncDirection
+ self.__syncCenter = syncCenter
+ self.__syncZoom = syncZoom
+ self.__callbacks = None
+ self.__lastMainAxis = None
+
+ for axis in axes:
+ self.addAxis(axis)
+
+ self.start()
+
+ def start(self):
+ """Start synchronizing axes together.
+
+ The first axis is used as the reference for the first synchronization.
+ After that, any changes to any axes will be used to synchronize other
+ axes.
+ """
+ if self.isSynchronizing():
+ raise RuntimeError("Axes already synchronized")
+ self.__callbacks = {}
+
+ axes = self.__getAxes()
+
+ # register callback for further sync
+ for axis in axes:
+ self.__connectAxes(axis)
+ self.synchronize()
+
+ def isSynchronizing(self):
+ """Returns true if events are connected to the axes to synchronize them
+ all together
+
+ :rtype: bool
+ """
+ return self.__callbacks is not None
+
+ def __connectAxes(self, axis):
+ refAxis = weakref.ref(axis)
+ callbacks = []
+ if self.__syncLimits:
+ # the weakref is needed to be able ignore self references
+ callback = silxWeakref.WeakMethodProxy(self.__axisLimitsChanged)
+ callback = functools.partial(callback, refAxis)
+ sig = axis.sigLimitsChanged
+ sig.connect(callback)
+ callbacks.append(("sigLimitsChanged", callback))
+ elif self.__syncCenter and self.__syncZoom:
+ # the weakref is needed to be able ignore self references
+ callback = silxWeakref.WeakMethodProxy(self.__axisCenterAndZoomChanged)
+ callback = functools.partial(callback, refAxis)
+ sig = axis.sigLimitsChanged
+ sig.connect(callback)
+ callbacks.append(("sigLimitsChanged", callback))
+ elif self.__syncZoom:
+ raise NotImplementedError()
+ elif self.__syncCenter:
+ # the weakref is needed to be able ignore self references
+ callback = silxWeakref.WeakMethodProxy(self.__axisCenterChanged)
+ callback = functools.partial(callback, refAxis)
+ sig = axis.sigLimitsChanged
+ sig.connect(callback)
+ callbacks.append(("sigLimitsChanged", callback))
+ if self.__syncScale:
+ # the weakref is needed to be able ignore self references
+ callback = silxWeakref.WeakMethodProxy(self.__axisScaleChanged)
+ callback = functools.partial(callback, refAxis)
+ sig = axis.sigScaleChanged
+ sig.connect(callback)
+ callbacks.append(("sigScaleChanged", callback))
+ if self.__syncDirection:
+ # the weakref is needed to be able ignore self references
+ callback = silxWeakref.WeakMethodProxy(self.__axisInvertedChanged)
+ callback = functools.partial(callback, refAxis)
+ sig = axis.sigInvertedChanged
+ sig.connect(callback)
+ callbacks.append(("sigInvertedChanged", callback))
+
+ if self.__filterHiddenPlots:
+ # the weakref is needed to be able ignore self references
+ callback = silxWeakref.WeakMethodProxy(self.__axisVisibilityChanged)
+ callback = functools.partial(callback, refAxis)
+ plot = axis._getPlot()
+ plot.sigVisibilityChanged.connect(callback)
+ callbacks.append(("sigVisibilityChanged", callback))
+
+ self.__callbacks[refAxis] = callbacks
+
+ def __disconnectAxes(self, axis):
+ if axis is not None and _isQObjectValid(axis):
+ ref = weakref.ref(axis)
+ callbacks = self.__callbacks.pop(ref)
+ for sigName, callback in callbacks:
+ if sigName == "sigVisibilityChanged":
+ obj = axis._getPlot()
+ else:
+ obj = axis
+ if obj is not None:
+ sig = getattr(obj, sigName)
+ sig.disconnect(callback)
+
+ def addAxis(self, axis):
+ """Add a new axes to synchronize.
+
+ :param ~silx.gui.plot.items.Axis axis: The axis to synchronize
+ """
+ self.__axisRefs.append(weakref.ref(axis))
+ if self.isSynchronizing():
+ self.__connectAxes(axis)
+ # This could be done faster as only this axis have to be fixed
+ self.synchronize()
+
+ def removeAxis(self, axis):
+ """Remove an axis from the synchronized axes.
+
+ :param ~silx.gui.plot.items.Axis axis: The axis to remove
+ """
+ ref = weakref.ref(axis)
+ self.__axisRefs.remove(ref)
+ if self.isSynchronizing():
+ self.__disconnectAxes(axis)
+
+ def synchronize(self, mainAxis=None):
+ """Synchronize programatically all the axes.
+
+ :param ~silx.gui.plot.items.Axis mainAxis:
+ The axis to take as reference (Default: the first axis).
+ """
+ # sync the current state
+ axes = self.__getAxes()
+ if len(axes) == 0:
+ return
+
+ if mainAxis is None:
+ mainAxis = axes[0]
+
+ refMainAxis = weakref.ref(mainAxis)
+ if self.__syncLimits:
+ self.__axisLimitsChanged(refMainAxis, *mainAxis.getLimits())
+ elif self.__syncCenter and self.__syncZoom:
+ self.__axisCenterAndZoomChanged(refMainAxis, *mainAxis.getLimits())
+ elif self.__syncCenter:
+ self.__axisCenterChanged(refMainAxis, *mainAxis.getLimits())
+ if self.__syncScale:
+ self.__axisScaleChanged(refMainAxis, mainAxis.getScale())
+ if self.__syncDirection:
+ self.__axisInvertedChanged(refMainAxis, mainAxis.isInverted())
+
+ def stop(self):
+ """Stop the synchronization of the axes"""
+ if not self.isSynchronizing():
+ raise RuntimeError("Axes not synchronized")
+ for ref in list(self.__callbacks.keys()):
+ axis = ref()
+ self.__disconnectAxes(axis)
+ self.__callbacks = None
+
+ def __del__(self):
+ """Destructor"""
+ # clean up references
+ if self.__callbacks is not None:
+ self.stop()
+
+ def __getAxes(self):
+ """Returns list of existing axes.
+
+ :rtype: List[Axis]
+ """
+ axes = [ref() for ref in self.__axisRefs]
+ return [axis for axis in axes if axis is not None]
+
+ @contextmanager
+ def __inhibitSignals(self):
+ self.__locked = True
+ yield
+ self.__locked = False
+
+ def __axesToUpdate(self, changedAxis):
+ for axis in self.__getAxes():
+ if axis is changedAxis:
+ continue
+ if self.__filterHiddenPlots:
+ plot = axis._getPlot()
+ if not plot.isVisible():
+ continue
+ yield axis
+
+ def __axisVisibilityChanged(self, changedAxis, isVisible):
+ if not isVisible:
+ return
+ if self.__locked:
+ return
+ changedAxis = changedAxis()
+ if self.__lastMainAxis is None:
+ self.__lastMainAxis = self.__axisRefs[0]
+ mainAxis = self.__lastMainAxis
+ mainAxis = mainAxis()
+ self.synchronize(mainAxis=mainAxis)
+ # force back the main axis
+ self.__lastMainAxis = weakref.ref(mainAxis)
+
+ def __getAxesCenter(self, axis, vmin, vmax):
+ """Returns the value displayed in the center of this axis range.
+
+ :rtype: float
+ """
+ scale = axis.getScale()
+ if scale == Axis.LINEAR:
+ center = (vmin + vmax) * 0.5
+ else:
+ raise NotImplementedError("Log scale not implemented")
+ return center
+
+ def __getRangeInPixel(self, axis):
+ """Returns the size of the axis in pixel"""
+ bounds = axis._getPlot().getPlotBoundsInPixels()
+ # bounds: left, top, width, height
+ if isinstance(axis, XAxis):
+ return bounds[2]
+ elif isinstance(axis, YAxis):
+ return bounds[3]
+ else:
+ assert False
+
+ def __getLimitsFromCenter(self, axis, pos, pixelSize=None):
+ """Returns the limits to apply to this axis to move the `pos` into the
+ center of this axis.
+
+ :param Axis axis:
+ :param float pos: Position in the center of the computed limits
+ :param Union[None,float] pixelSize: Pixel size to apply to compute the
+ limits. If `None` the current pixel size is applyed.
+ """
+ scale = axis.getScale()
+ if scale == Axis.LINEAR:
+ if pixelSize is None:
+ # Use the current pixel size of the axis
+ limits = axis.getLimits()
+ valueRange = limits[0] - limits[1]
+ a = pos - valueRange * 0.5
+ b = pos + valueRange * 0.5
+ else:
+ pixelRange = self.__getRangeInPixel(axis)
+ a = pos - pixelRange * 0.5 * pixelSize
+ b = pos + pixelRange * 0.5 * pixelSize
+
+ else:
+ raise NotImplementedError("Log scale not implemented")
+ if a > b:
+ return b, a
+ return a, b
+
+ def __axisLimitsChanged(self, changedAxis, vmin, vmax):
+ if self.__locked:
+ return
+ self.__lastMainAxis = changedAxis
+ changedAxis = changedAxis()
+ with self.__inhibitSignals():
+ for axis in self.__axesToUpdate(changedAxis):
+ axis.setLimits(vmin, vmax)
+
+ def __axisCenterAndZoomChanged(self, changedAxis, vmin, vmax):
+ if self.__locked:
+ return
+ self.__lastMainAxis = changedAxis
+ changedAxis = changedAxis()
+ with self.__inhibitSignals():
+ center = self.__getAxesCenter(changedAxis, vmin, vmax)
+ pixelRange = self.__getRangeInPixel(changedAxis)
+ if pixelRange == 0:
+ return
+ pixelSize = (vmax - vmin) / pixelRange
+ for axis in self.__axesToUpdate(changedAxis):
+ vmin, vmax = self.__getLimitsFromCenter(axis, center, pixelSize)
+ axis.setLimits(vmin, vmax)
+
+ def __axisCenterChanged(self, changedAxis, vmin, vmax):
+ if self.__locked:
+ return
+ self.__lastMainAxis = changedAxis
+ changedAxis = changedAxis()
+ with self.__inhibitSignals():
+ center = self.__getAxesCenter(changedAxis, vmin, vmax)
+ for axis in self.__axesToUpdate(changedAxis):
+ vmin, vmax = self.__getLimitsFromCenter(axis, center)
+ axis.setLimits(vmin, vmax)
+
+ def __axisScaleChanged(self, changedAxis, scale):
+ if self.__locked:
+ return
+ self.__lastMainAxis = changedAxis
+ changedAxis = changedAxis()
+ with self.__inhibitSignals():
+ for axis in self.__axesToUpdate(changedAxis):
+ axis.setScale(scale)
+
+ def __axisInvertedChanged(self, changedAxis, isInverted):
+ if self.__locked:
+ return
+ self.__lastMainAxis = changedAxis
+ changedAxis = changedAxis()
+ with self.__inhibitSignals():
+ for axis in self.__axesToUpdate(changedAxis):
+ axis.setInverted(isInverted)
diff --git a/src/silx/gui/plot/utils/intersections.py b/src/silx/gui/plot/utils/intersections.py
new file mode 100644
index 0000000..faf6641
--- /dev/null
+++ b/src/silx/gui/plot/utils/intersections.py
@@ -0,0 +1,106 @@
+# /*##########################################################################
+#
+# Copyright (c) 2017-2018 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""This module contains utils class for axes management.
+"""
+
+__authors__ = [
+ "H. Payno",
+]
+__license__ = "MIT"
+__date__ = "18/05/2020"
+
+
+import numpy
+
+
+def lines_intersection(line1_pt1, line1_pt2, line2_pt1, line2_pt2):
+ """
+ line segment intersection using vectors (Computer Graphics by F.S. Hill)
+
+ :param tuple line1_pt1:
+ :param tuple line1_pt2:
+ :param tuple line2_pt1:
+ :param tuple line2_pt2:
+ :return: Union[None,numpy.array]
+ """
+ dir_line1 = line1_pt2[0] - line1_pt1[0], line1_pt2[1] - line1_pt1[1]
+ dir_line2 = line2_pt2[0] - line2_pt1[0], line2_pt2[1] - line2_pt1[1]
+ dp = line1_pt1 - line2_pt1
+
+ def perp(a):
+ b = numpy.empty_like(a)
+ b[0] = -a[1]
+ b[1] = a[0]
+ return b
+
+ dap = perp(dir_line1)
+ denom = numpy.dot(dap, dir_line2)
+ num = numpy.dot(dap, dp)
+ if denom == 0:
+ return None
+ return (
+ (num / denom.astype(float)) * dir_line2[0] + line2_pt1[0],
+ (num / denom.astype(float)) * dir_line2[1] + line2_pt1[1],
+ )
+
+
+def segments_intersection(seg1_start_pt, seg1_end_pt, seg2_start_pt, seg2_end_pt):
+ """
+ Compute intersection between two segments
+
+ :param seg1_start_pt:
+ :param seg1_end_pt:
+ :param seg2_start_pt:
+ :param seg2_end_pt:
+ :return: numpy.array if an intersection exists, else None
+ :rtype: Union[None,numpy.array]
+ """
+ intersection = lines_intersection(
+ line1_pt1=seg1_start_pt,
+ line1_pt2=seg1_end_pt,
+ line2_pt1=seg2_start_pt,
+ line2_pt2=seg2_end_pt,
+ )
+ if intersection is not None:
+ max_x_seg1 = max(seg1_start_pt[0], seg1_end_pt[0])
+ max_x_seg2 = max(seg2_start_pt[0], seg2_end_pt[0])
+ max_y_seg1 = max(seg1_start_pt[1], seg1_end_pt[1])
+ max_y_seg2 = max(seg2_start_pt[1], seg2_end_pt[1])
+
+ min_x_seg1 = min(seg1_start_pt[0], seg1_end_pt[0])
+ min_x_seg2 = min(seg2_start_pt[0], seg2_end_pt[0])
+ min_y_seg1 = min(seg1_start_pt[1], seg1_end_pt[1])
+ min_y_seg2 = min(seg2_start_pt[1], seg2_end_pt[1])
+
+ min_tmp_x = max(min_x_seg1, min_x_seg2)
+ max_tmp_x = min(max_x_seg1, max_x_seg2)
+ min_tmp_y = max(min_y_seg1, min_y_seg2)
+ max_tmp_y = min(max_y_seg1, max_y_seg2)
+ if (
+ min_tmp_x <= intersection[0] <= max_tmp_x
+ and min_tmp_y <= intersection[1] <= max_tmp_y
+ ):
+ return intersection
+ else:
+ return None