diff options
Diffstat (limited to 'src/silx/gui/plot')
135 files changed, 63409 insertions, 0 deletions
diff --git a/src/silx/gui/plot/AlphaSlider.py b/src/silx/gui/plot/AlphaSlider.py new file mode 100644 index 0000000..da55b1e --- /dev/null +++ b/src/silx/gui/plot/AlphaSlider.py @@ -0,0 +1,300 @@ +# coding: utf-8 +# /*########################################################################## +# +# Copyright (c) 2017-2021 European Synchrotron Radiation Facility +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. +# +# ###########################################################################*/ +"""This module defines slider widgets interacting with the transparency +of an image on a :class:`PlotWidget` + +Classes: +-------- + +- :class:`BaseAlphaSlider` (abstract class) +- :class:`NamedImageAlphaSlider` +- :class:`ActiveImageAlphaSlider` + +Example: +-------- + +This widget can, for instance, be added to a plot toolbar. + +.. code-block:: python + + import numpy + from silx.gui import qt + from silx.gui.plot import PlotWidget + from silx.gui.plot.ImageAlphaSlider import NamedImageAlphaSlider + + app = qt.QApplication([]) + pw = PlotWidget() + + img0 = numpy.arange(200*150).reshape((200, 150)) + pw.addImage(img0, legend="my background", z=0, origin=(50, 50)) + + x, y = numpy.meshgrid(numpy.linspace(-10, 10, 200), + numpy.linspace(-10, 5, 150), + indexing="ij") + img1 = numpy.asarray(numpy.sin(x * y) / (x * y), + dtype='float32') + + pw.addImage(img1, legend="my data", z=1, + replace=False) + + alpha_slider = NamedImageAlphaSlider(parent=pw, + plot=pw, + legend="my data") + alpha_slider.setOrientation(qt.Qt.Horizontal) + + toolbar = qt.QToolBar("plot", pw) + toolbar.addWidget(alpha_slider) + pw.addToolBar(toolbar) + + pw.show() + app.exec() + +""" + +__authors__ = ["P. Knobel"] +__license__ = "MIT" +__date__ = "24/03/2017" + +import logging + +from silx.gui import qt + +_logger = logging.getLogger(__name__) + + +class BaseAlphaSlider(qt.QSlider): + """Slider widget to be used in a plot toolbar to control the + transparency of a plot primitive (image, scatter or curve). + + Internally, the slider stores its state as an integer between + 0 and 255. This is the value emitted by the :attr:`valueChanged` + signal. + + The method :meth:`getAlpha` returns the corresponding opacity/alpha + as a float between 0. and 1. (with a step of :math:`\frac{1}{255}`). + + You must subclass this class and implement :meth:`getItem`. + """ + sigAlphaChanged = qt.Signal(float) + """Emits the alpha value when the slider's value changes, + as a float between 0. and 1.""" + + def __init__(self, parent=None, plot=None): + """ + + :param parent: Parent QWidget + :param plot: Parent plot widget + """ + assert plot is not None + super(BaseAlphaSlider, self).__init__(parent) + + self.plot = plot + + self.setRange(0, 255) + + # if already connected to an item, use its alpha as initial value + if self.getItem() is None: + self.setValue(255) + self.setEnabled(False) + else: + alpha = self.getItem().getAlpha() + self.setValue(round(255*alpha)) + + self.valueChanged.connect(self._valueChanged) + + def getItem(self): + """You must implement this class to define which item + to work on. It must return an item that inherits + :class:`silx.gui.plot.items.core.AlphaMixIn`. + + :return: Item on which to operate, or None + :rtype: :class:`silx.plot.items.Item` + """ + raise NotImplementedError( + "BaseAlphaSlider must be subclassed to " + + "implement getItem()") + + def getAlpha(self): + """Get the opacity, as a float between 0. and 1. + + :return: Alpha value in [0., 1.] + :rtype: float + """ + return self.value() / 255. + + def _valueChanged(self, value): + self._updateItem() + self.sigAlphaChanged.emit(value / 255.) + + def _updateItem(self): + """Update the item's alpha channel. + """ + item = self.getItem() + if item is not None: + item.setAlpha(self.getAlpha()) + + +class ActiveImageAlphaSlider(BaseAlphaSlider): + """Slider widget to be used in a plot toolbar to control the + transparency of the **active image**. + + :param parent: Parent QWidget + :param plot: Plot on which to operate + + See documentation of :class:`BaseAlphaSlider` + """ + def __init__(self, parent=None, plot=None): + """ + + :param parent: Parent QWidget + :param plot: Plot widget on which to operate + """ + super(ActiveImageAlphaSlider, self).__init__(parent, plot) + plot.sigActiveImageChanged.connect(self._activeImageChanged) + + def getItem(self): + return self.plot.getActiveImage() + + def _activeImageChanged(self, previous, new): + """Activate or deactivate slider depending on presence of a new + active image. + Apply transparency value to new active image. + + :param previous: Legend of previous active image, or None + :param new: Legend of new active image, or None + """ + if new is not None and not self.isEnabled(): + self.setEnabled(True) + elif new is None and self.isEnabled(): + self.setEnabled(False) + + self._updateItem() + + +class NamedItemAlphaSlider(BaseAlphaSlider): + """Slider widget to be used in a plot toolbar to control the + transparency of an item (defined by its kind and legend). + + :param parent: Parent QWidget + :param plot: Plot on which to operate + :param str kind: Kind of item whose transparency is to be + controlled: "scatter", "image" or "curve". + :param str legend: Legend of item whose transparency is to be + controlled. + """ + def __init__(self, parent=None, plot=None, + kind=None, legend=None): + self._item_legend = legend + self._item_kind = kind + + super(NamedItemAlphaSlider, self).__init__(parent, plot) + + self._updateState() + plot.sigContentChanged.connect(self._onContentChanged) + + def _onContentChanged(self, action, kind, legend): + if legend == self._item_legend and kind == self._item_kind: + if action == "add": + self.setEnabled(True) + elif action == "remove": + self.setEnabled(False) + + def _updateState(self): + """Enable or disable widget based on item's availability.""" + if self.getItem() is not None: + self.setEnabled(True) + else: + self.setEnabled(False) + + def getItem(self): + """Return plot item currently associated to this widget (can be + a curve, an image, a scatter...) + + :rtype: subclass of :class:`silx.gui.plot.items.Item`""" + if self._item_legend is None or self._item_kind is None: + return None + return self.plot._getItem(kind=self._item_kind, + legend=self._item_legend) + + def setLegend(self, legend): + """Associate a different item (of the same kind) to the slider. + + :param legend: New legend of item whose transparency is to be + controlled. + """ + self._item_legend = legend + self._updateState() + + def getLegend(self): + """Return legend of the item currently controlled by this slider. + + :return: Image legend associated to the slider + """ + return self._item_kind + + def setItemKind(self, legend): + """Associate a different item (of the same kind) to the slider. + + :param legend: New legend of item whose transparency is to be + controlled. + """ + self._item_legend = legend + self._updateState() + + def getItemKind(self): + """Return kind of the item currently controlled by this slider. + + :return: Item kind ("image", "scatter"...) + :rtype: str on None + """ + return self._item_kind + + +class NamedImageAlphaSlider(NamedItemAlphaSlider): + """Slider widget to be used in a plot toolbar to control the + transparency of an image (defined by its legend). + + :param parent: Parent QWidget + :param plot: Plot on which to operate + :param str legend: Legend of image whose transparency is to be + controlled. + """ + def __init__(self, parent=None, plot=None, legend=None): + NamedItemAlphaSlider.__init__(self, parent, plot, + kind="image", legend=legend) + + +class NamedScatterAlphaSlider(NamedItemAlphaSlider): + """Slider widget to be used in a plot toolbar to control the + transparency of a scatter (defined by its legend). + + :param parent: Parent QWidget + :param plot: Plot on which to operate + :param str legend: Legend of scatter whose transparency is to be + controlled. + """ + def __init__(self, parent=None, plot=None, legend=None): + NamedItemAlphaSlider.__init__(self, parent, plot, + kind="scatter", legend=legend) diff --git a/src/silx/gui/plot/ColorBar.py b/src/silx/gui/plot/ColorBar.py new file mode 100644 index 0000000..8cafc06 --- /dev/null +++ b/src/silx/gui/plot/ColorBar.py @@ -0,0 +1,883 @@ +# coding: utf-8 +# /*########################################################################## +# +# Copyright (c) 2016-2021 European Synchrotron Radiation Facility +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. +# +# ###########################################################################*/ +"""Module containing several widgets associated to a colormap. +""" + +__authors__ = ["H. Payno", "T. Vincent"] +__license__ = "MIT" +__date__ = "24/04/2018" + + +import logging +import weakref +import numpy + +from ._utils import ticklayout +from .. import qt +from ..qt import inspect as qt_inspect +from silx.gui import colors +from silx.math.colormap import LogarithmicNormalization + + +_logger = logging.getLogger(__name__) + + +class ColorBarWidget(qt.QWidget): + """Colorbar widget displaying a colormap + + It uses a description of colormap as dict compatible with :class:`Plot`. + + .. image:: img/linearColorbar.png + :width: 80px + :align: center + + To run the following sample code, a QApplication must be initialized. + + >>> from silx.gui.plot import Plot2D + >>> from silx.gui.plot.ColorBar import ColorBarWidget + + >>> plot = Plot2D() # Create a plot widget + >>> plot.show() + + >>> colorbar = ColorBarWidget(plot=plot, legend='Colormap') # Associate the colorbar with it + >>> colorbar.show() + + Initializer parameters: + + :param parent: See :class:`QWidget` + :param plot: PlotWidget the colorbar is attached to (optional) + :param str legend: the label to set to the colorbar + """ + sigVisibleChanged = qt.Signal(bool) + """Emitted when the property `visible` have changed.""" + + def __init__(self, parent=None, plot=None, legend=None): + self._isConnected = False + self._plotRef = None + self._colormap = None + self._data = None + + super(ColorBarWidget, self).__init__(parent) + + self.__buildGUI() + self.setLegend(legend) + self.setPlot(plot) + + def __buildGUI(self): + self.setLayout(qt.QHBoxLayout()) + + # create color scale widget + self._colorScale = ColorScaleBar(parent=self, + colormap=None) + self.layout().addWidget(self._colorScale) + + # legend (is the right group) + self.legend = _VerticalLegend('', self) + self.layout().addWidget(self.legend) + + self.layout().setSizeConstraint(qt.QLayout.SetMinAndMaxSize) + + def getPlot(self): + """Returns the :class:`Plot` associated to this widget or None""" + return None if self._plotRef is None else self._plotRef() + + def setPlot(self, plot): + """Associate a plot to the ColorBar + + :param plot: the plot to associate with the colorbar. + If None will remove any connection with a previous plot. + """ + self._disconnectPlot() + self._plotRef = None if plot is None else weakref.ref(plot) + self._connectPlot() + + def _disconnectPlot(self): + """Disconnect from Plot signals""" + if self._isConnected: + self._isConnected = False + plot = self.getPlot() + if plot is not None and qt_inspect.isValid(plot): + plot.sigActiveImageChanged.disconnect( + self._activeImageChanged) + plot.sigActiveScatterChanged.disconnect( + self._activeScatterChanged) + plot.sigPlotSignal.disconnect(self._defaultColormapChanged) + + def _connectPlot(self): + """Connect to Plot signals""" + plot = self.getPlot() + if plot is not None and not self._isConnected: + activeImageLegend = plot.getActiveImage(just_legend=True) + activeScatterLegend = plot._getActiveItem( + kind='scatter', just_legend=True) + if activeImageLegend is None and activeScatterLegend is None: + # Show plot default colormap + self._syncWithDefaultColormap() + elif activeImageLegend is not None: # Show active image colormap + self._activeImageChanged(None, activeImageLegend) + elif activeScatterLegend is not None: # Show active scatter colormap + self._activeScatterChanged(None, activeScatterLegend) + + plot.sigActiveImageChanged.connect(self._activeImageChanged) + plot.sigActiveScatterChanged.connect(self._activeScatterChanged) + plot.sigPlotSignal.connect(self._defaultColormapChanged) + self._isConnected = True + + def setVisible(self, isVisible): + qt.QWidget.setVisible(self, isVisible) + self.sigVisibleChanged.emit(isVisible) + + def showEvent(self, event): + self._connectPlot() + + def hideEvent(self, event): + self._disconnectPlot() + + def getColormap(self): + """Returns the colormap displayed in the colorbar. + + :rtype: ~silx.gui.colors.Colormap + """ + return self.getColorScaleBar().getColormap() + + def setColormap(self, colormap, data=None): + """Set the colormap to be displayed. + + :param ~silx.gui.colors.Colormap colormap: + The colormap to apply on the ColorBarWidget + :param Union[numpy.ndarray,~silx.gui.plot.items.ColormapMixin] data: + The data to display or item, needed if the colormap require an autoscale + """ + self._data = data + self.getColorScaleBar().setColormap(colormap=colormap, + data=data) + if self._colormap is not None: + self._colormap.sigChanged.disconnect(self._colormapHasChanged) + self._colormap = colormap + if self._colormap is not None: + self._colormap.sigChanged.connect(self._colormapHasChanged) + + def _colormapHasChanged(self): + """handler of the Colormap.sigChanged signal + """ + assert self._colormap is not None + self.setColormap(colormap=self._colormap, + data=self._data) + + def setLegend(self, legend): + """Set the legend displayed along the colorbar + + :param str legend: The label + """ + if legend is None or legend == "": + self.legend.hide() + self.legend.setText("") + else: + assert type(legend) is str + self.legend.show() + self.legend.setText(legend) + + def getLegend(self): + """ + Returns the legend displayed along the colorbar + + :return: return the legend displayed along the colorbar + :rtype: str + """ + return self.legend.text() + + def _activeScatterChanged(self, previous, legend): + """Handle plot active scatter changed""" + plot = self.getPlot() + + # Do not handle active scatter while there is an image + if plot.getActiveImage() is not None: + return + + if legend is None: # No active scatter, display no colormap + self.setColormap(colormap=None) + return + + # Sync with active scatter + scatter = plot._getActiveItem(kind='scatter') + + self.setColormap(colormap=scatter.getColormap(), + data=scatter) + + def _activeImageChanged(self, previous, legend): + """Handle plot active image changed""" + plot = self.getPlot() + + if legend is None: # No active image, try with active scatter + activeScatterLegend = plot._getActiveItem( + kind='scatter', just_legend=True) + # No more active image, use active scatter if any + self._activeScatterChanged(None, activeScatterLegend) + else: + # Sync with active image + image = plot.getActiveImage() + + # RGB(A) image, display default colormap + array = image.getData(copy=False) + if array.ndim != 2: + self.setColormap(colormap=None) + return + + # data image, sync with image colormap + # do we need the copy here : used in the case we are changing + # vmin and vmax but should have already be done by the plot + self.setColormap(colormap=image.getColormap(), data=image) + + def _defaultColormapChanged(self, event): + """Handle plot default colormap changed""" + if event['event'] == 'defaultColormapChanged': + plot = self.getPlot() + if (plot is not None and + plot.getActiveImage() is None and + plot._getActiveItem(kind='scatter') is None): + # No active item, take default colormap update into account + self._syncWithDefaultColormap() + + def _syncWithDefaultColormap(self): + """Update colorbar according to plot default colormap""" + self.setColormap(self.getPlot().getDefaultColormap()) + + def getColorScaleBar(self): + """ + + :return: return the :class:`ColorScaleBar` used to display ColorScale + and ticks""" + return self._colorScale + + +class _VerticalLegend(qt.QLabel): + """Display vertically the given text + """ + def __init__(self, text, parent=None): + """ + + :param text: the legend + :param parent: the Qt parent if any + """ + qt.QLabel.__init__(self, text, parent) + self.setLayout(qt.QVBoxLayout()) + self.layout().setContentsMargins(0, 0, 0, 0) + + def paintEvent(self, event): + painter = qt.QPainter(self) + painter.setFont(self.font()) + + painter.translate(0, self.rect().height()) + painter.rotate(270) + newRect = qt.QRect(0, 0, self.rect().height(), self.rect().width()) + + painter.drawText(newRect, qt.Qt.AlignHCenter, self.text()) + + fm = qt.QFontMetrics(self.font()) + preferedHeight = fm.width(self.text()) + preferedWidth = fm.height() + self.setFixedWidth(preferedWidth) + self.setMinimumHeight(preferedHeight) + + +class ColorScaleBar(qt.QWidget): + """This class is making the composition of a :class:`_ColorScale` and a + :class:`_TickBar`. + + It is the simplest widget displaying ticks and colormap gradient. + + .. image:: img/colorScaleBar.png + :width: 150px + :align: center + + To run the following sample code, a QApplication must be initialized. + + >>> colormap = Colormap(name='gray', + ... norm='log', + ... vmin=1, + ... vmax=100000, + ... ) + >>> colorscale = ColorScaleBar(parent=None, + ... colormap=colormap ) + >>> colorscale.show() + + Initializer parameters : + + :param colormap: the colormap to be displayed + :param parent: the Qt parent if any + :param displayTicksValues: display the ticks value or only the '-' + """ + + _TEXT_MARGIN = 5 + """The tick bar need a margin to display all labels at the correct place. + So the ColorScale should have the same margin in order for both to fit""" + + def __init__(self, parent=None, colormap=None, data=None, + displayTicksValues=True): + super(ColorScaleBar, self).__init__(parent) + + self.minVal = None + """Value set to the _minLabel""" + self.maxVal = None + """Value set to the _maxLabel""" + + self.setLayout(qt.QGridLayout()) + + # create the left side group (ColorScale) + self.colorScale = _ColorScale(colormap=colormap, + data=data, + parent=self, + margin=ColorScaleBar._TEXT_MARGIN) + if colormap: + vmin, vmax = colormap.getColormapRange(data) + normalizer = colormap._getNormalizer() + else: + vmin, vmax = colors.DEFAULT_MIN_LIN, colors.DEFAULT_MAX_LIN + normalizer = None + + self.tickbar = _TickBar(vmin=vmin, + vmax=vmax, + normalizer=normalizer, + parent=self, + displayValues=displayTicksValues, + margin=ColorScaleBar._TEXT_MARGIN) + + self.layout().addWidget(self.tickbar, 1, 0, 1, 1, qt.Qt.AlignRight) + self.layout().addWidget(self.colorScale, 1, 1, qt.Qt.AlignLeft) + + self.layout().setContentsMargins(0, 0, 0, 0) + self.layout().setSpacing(0) + + # max label + self._maxLabel = qt.QLabel(str(1.0), parent=self) + self._maxLabel.setToolTip(str(0.0)) + self.layout().addWidget(self._maxLabel, 0, 0, 1, 2, qt.Qt.AlignRight) + + # min label + self._minLabel = qt.QLabel(str(0.0), parent=self) + self._minLabel.setToolTip(str(0.0)) + self.layout().addWidget(self._minLabel, 2, 0, 1, 2, qt.Qt.AlignRight) + + self.layout().setSizeConstraint(qt.QLayout.SetMinAndMaxSize) + self.layout().setColumnStretch(0, 1) + self.layout().setRowStretch(1, 1) + + def getTickBar(self): + """ + + :return: the instanciation of the :class:`_TickBar` + """ + return self.tickbar + + def getColorScale(self): + """ + + :return: the instanciation of the :class:`_ColorScale` + """ + return self.colorScale + + def getColormap(self): + """ + + :returns: the colormap. + :rtype: :class:`.Colormap` + """ + return self.colorScale.getColormap() + + def setColormap(self, colormap, data=None): + """Set the new colormap to be displayed + + :param Colormap colormap: the colormap to set + :param Union[numpy.ndarray,~silx.gui.plot.items.Item] data: + The data or item to display, needed if the colormap requires an autoscale + """ + self.colorScale.setColormap(colormap, data) + + if colormap is not None: + vmin, vmax = colormap.getColormapRange(data) + normalizer = colormap._getNormalizer() + else: + vmin, vmax = None, None + normalizer = None + + self.tickbar.update(vmin=vmin, + vmax=vmax, + normalizer=normalizer) + self._setMinMaxLabels(vmin, vmax) + + def setMinMaxVisible(self, val=True): + """Change visibility of the min label and the max label + + :param val: if True, set the labels visible, otherwise set it not visible + """ + self._minLabel.setVisible(val) + self._maxLabel.setVisible(val) + + def _updateMinMax(self): + """Update the min and max label if we are in the case of the + configuration 'minMaxValueOnly'""" + if self.minVal is None: + text, tooltip = '', '' + else: + if self.minVal == 0 or 0 <= numpy.log10(abs(self.minVal)) < 7: + text = '%.7g' % self.minVal + else: + text = '%.2e' % self.minVal + tooltip = repr(self.minVal) + + self._minLabel.setText(text) + self._minLabel.setToolTip(tooltip) + + if self.maxVal is None: + text, tooltip = '', '' + else: + if self.maxVal == 0 or 0 <= numpy.log10(abs(self.maxVal)) < 7: + text = '%.7g' % self.maxVal + else: + text = '%.2e' % self.maxVal + tooltip = repr(self.maxVal) + + self._maxLabel.setText(text) + self._maxLabel.setToolTip(tooltip) + + def _setMinMaxLabels(self, minVal, maxVal): + """Change the value of the min and max labels to be displayed. + + :param minVal: the minimal value of the TickBar (not str) + :param maxVal: the maximal value of the TickBar (not str) + """ + # bad hack to try to display has much information as possible + self.minVal = minVal + self.maxVal = maxVal + self._updateMinMax() + + def resizeEvent(self, event): + qt.QWidget.resizeEvent(self, event) + self._updateMinMax() + + +class _ColorScale(qt.QWidget): + """Widget displaying the colormap colorScale. + + Show matching value between the gradient color (from the colormap) at mouse + position and value. + + .. image:: img/colorScale.png + :width: 20px + :align: center + + + To run the following sample code, a QApplication must be initialized. + + >>> colormap = Colormap(name='viridis', + ... norm='log', + ... vmin=1, + ... vmax=100000, + ... ) + >>> colorscale = ColorScale(parent=None, + ... colormap=colormap) + >>> colorscale.show() + + Initializer parameters : + + :param colormap: the colormap to be displayed + :param parent: the Qt parent if any + :param int margin: the top and left margin to apply. + :param Union[None,numpy.ndarray,~silx.gui.plot.items.ColormapMixin] data: + The data or item to use for getting the range for autoscale colormap. + + .. warning:: Value drawing will be + done at the center of ticks. So if no margin is done your values + drawing might not be fully done for extrems values. + """ + + _NB_CONTROL_POINTS = 256 + + def __init__(self, colormap, parent=None, margin=5, data=None): + qt.QWidget.__init__(self, parent) + self._colormap = None + self.margin = margin + self.setColormap(colormap, data) + + self.setLayout(qt.QVBoxLayout()) + self.setSizePolicy(qt.QSizePolicy.Fixed, qt.QSizePolicy.Expanding) + # needed to get the mouse event without waiting for button click + self.setMouseTracking(True) + self.setMargin(margin) + self.setContentsMargins(0, 0, 0, 0) + + self.setMinimumHeight(self._NB_CONTROL_POINTS // 2 + 2 * self.margin) + self.setFixedWidth(25) + + def setColormap(self, colormap, data=None): + """Set the new colormap to be displayed + + :param dict colormap: the colormap to set + :param Union[None,numpy.ndarray,~silx.gui.plot.items.ColormapMixin] data: + Optional data for which to compute colormap range. + """ + self._colormap = colormap + self.setEnabled(colormap is not None) + + if colormap is None: + self.vmin, self.vmax = None, None + else: + assert colormap.getNormalization() in colors.Colormap.NORMALIZATIONS + self.vmin, self.vmax = self._colormap.getColormapRange(data=data) + self._updateColorGradient() + self.update() + + def getColormap(self): + """Returns the colormap + + :rtype: :class:`.Colormap` + """ + return None if self._colormap is None else self._colormap + + def _updateColorGradient(self): + """Compute the color gradient""" + colormap = self.getColormap() + if colormap is None: + return + + indices = numpy.linspace(0., 1., self._NB_CONTROL_POINTS) + colors = colormap.getNColors(nbColors=self._NB_CONTROL_POINTS) + self._gradient = qt.QLinearGradient(0, 1, 0, 0) + self._gradient.setCoordinateMode(qt.QGradient.StretchToDeviceMode) + self._gradient.setStops( + [(i, qt.QColor(*color)) for i, color in zip(indices, colors)] + ) + + def paintEvent(self, event): + """""" + painter = qt.QPainter(self) + if self.getColormap() is not None: + painter.setBrush(self._gradient) + penColor = self.palette().color(qt.QPalette.Active, + qt.QPalette.WindowText) + else: + penColor = self.palette().color(qt.QPalette.Disabled, + qt.QPalette.WindowText) + painter.setPen(penColor) + + painter.drawRect(qt.QRect( + 0, + self.margin, + self.width() - 1, + self.height() - 2 * self.margin - 1)) + + def mouseMoveEvent(self, event): + tooltip = str(self.getValueFromRelativePosition( + self._getRelativePosition(event.y()))) + qt.QToolTip.showText(event.globalPos(), tooltip, self) + super(_ColorScale, self).mouseMoveEvent(event) + + def _getRelativePosition(self, yPixel): + """yPixel : pixel position into _ColorScale widget reference + """ + # widgets are bottom-top referencial but we display in top-bottom referential + return 1. - (yPixel - self.margin) / float(self.height() - 2 * self.margin) + + def getValueFromRelativePosition(self, value): + """Return the value in the colorMap from a relative position in the + ColorScaleBar (y) + + :param value: float value in [0, 1] + :return: the value in [colormap['vmin'], colormap['vmax']] + """ + colormap = self.getColormap() + if colormap is None: + return + + value = numpy.clip(value, 0., 1.) + normalizer = colormap._getNormalizer() + normMin, normMax = normalizer.apply([self.vmin, self.vmax], self.vmin, self.vmax) + + return normalizer.revert( + normMin + (normMax - normMin) * value, self.vmin, self.vmax) + + def setMargin(self, margin): + """Define the margin to fit with a TickBar object. + This is needed since we can only paint on the viewport of the widget. + Didn't work with a simple setContentsMargins + + :param int margin: the margin to apply on the top and bottom. + """ + self.margin = int(margin) + self.update() + + +class _TickBar(qt.QWidget): + """Bar grouping the ticks displayed + + To run the following sample code, a QApplication must be initialized. + + >>> bar = _TickBar(1, 1000, norm='log', parent=None, displayValues=True) + >>> bar.show() + + .. image:: img/tickbar.png + :width: 40px + :align: center + + :param int vmin: smaller value of the range of values + :param int vmax: higher value of the range of values + :param normalizer: Normalization object. + :param parent: the Qt parent if any + :param bool displayValues: if True display the values close to the tick, + Otherwise only signal it by '-' + :param int nticks: the number of tick we want to display. Should be an + unsigned int ot None. If None, let the Tick bar find the optimal + number of ticks from the tick density. + :param int margin: margin to set on the top and bottom + """ + _WIDTH_DISP_VAL = 45 + """widget width when displayed with ticks labels""" + _WIDTH_NO_DISP_VAL = 10 + """widget width when displayed without ticks labels""" + _FONT_SIZE = 10 + """font size for ticks labels""" + _LINE_WIDTH = 10 + """width of the line to mark a tick""" + + DEFAULT_TICK_DENSITY = 0.015 + + def __init__(self, vmin, vmax, normalizer, parent=None, displayValues=True, + nticks=None, margin=5): + super(_TickBar, self).__init__(parent) + self.margin = margin + self._nticks = None + self.ticks = () + self.subTicks = () + self._forcedDisplayType = None + self.ticksDensity = _TickBar.DEFAULT_TICK_DENSITY + + self._vmin = vmin + self._vmax = vmax + self._normalizer = normalizer + self.displayValues = displayValues + self.setTicksNumber(nticks) + + self.setMargin(margin) + self.setContentsMargins(0, 0, 0, 0) + + self._resetWidth() + + def setTicksValuesVisible(self, val): + self.displayValues = val + self._resetWidth() + + def _resetWidth(self): + width = self._WIDTH_DISP_VAL if self.displayValues else self._WIDTH_NO_DISP_VAL + self.setFixedWidth(width) + + def update(self, vmin, vmax, normalizer): + self._vmin = vmin + self._vmax = vmax + self._normalizer = normalizer + self.computeTicks() + qt.QWidget.update(self) + + def setMargin(self, margin): + """Define the margin to fit with a _ColorScale object. + This is needed since we can only paint on the viewport of the widget + + :param int margin: the margin to apply on the top and bottom. + """ + self.margin = margin + + def setTicksNumber(self, nticks): + """Set the number of ticks to display. + + :param nticks: the number of tick to be display. Should be an + unsigned int ot None. If None, let the :class:`_TickBar` find the + optimal number of ticks from the tick density. + """ + self._nticks = nticks + self.computeTicks() + qt.QWidget.update(self) + + def setTicksDensity(self, density): + """If you let :class:`_TickBar` deal with the number of ticks + (nticks=None) then you can specify a ticks density to be displayed. + """ + if density < 0.0: + raise ValueError('Density should be a positive value') + self.ticksDensity = density + + def computeTicks(self): + """This function compute ticks values labels. It is called at each + update and each resize event. + Deal only with linear and log scale. + """ + nticks = self._nticks + if nticks is None: + nticks = self._getOptimalNbTicks() + + if self._vmin == self._vmax: + # No range: no ticks + self.ticks = () + self.subTicks = () + elif isinstance(self._normalizer, LogarithmicNormalization): + self._computeTicksLog(nticks) + else: # Fallback: use linear + self._computeTicksLin(nticks) + + # update the form + font = qt.QFont() + font.setPixelSize(_TickBar._FONT_SIZE) + + self.form = self._getFormat(font) + + def _computeTicksLog(self, nticks): + logMin = numpy.log10(self._vmin) + logMax = numpy.log10(self._vmax) + lowBound, highBound, spacing, self._nfrac = ticklayout.niceNumbersForLog10(logMin, + logMax, + nticks) + self.ticks = numpy.power(10., numpy.arange(lowBound, highBound, spacing)) + if spacing == 1: + self.subTicks = ticklayout.computeLogSubTicks(ticks=self.ticks, + lowBound=numpy.power(10., lowBound), + highBound=numpy.power(10., highBound)) + else: + self.subTicks = [] + + def resizeEvent(self, event): + qt.QWidget.resizeEvent(self, event) + self.computeTicks() + + def _computeTicksLin(self, nticks): + _min, _max, _spacing, self._nfrac = ticklayout.niceNumbers(self._vmin, + self._vmax, + nticks) + + self.ticks = numpy.arange(_min, _max, _spacing) + self.subTicks = [] + + def _getOptimalNbTicks(self): + return max(2, int(round(self.ticksDensity * self.rect().height()))) + + def paintEvent(self, event): + painter = qt.QPainter(self) + font = painter.font() + font.setPixelSize(_TickBar._FONT_SIZE) + painter.setFont(font) + + # paint ticks + for val in self.ticks: + self._paintTick(val, painter, majorTick=True) + + # paint subticks + for val in self.subTicks: + self._paintTick(val, painter, majorTick=False) + + def _getRelativePosition(self, val): + """Return the relative position of val according to min and max value + """ + if self._normalizer is None: + return 0. + normMin, normMax, normVal = self._normalizer.apply( + [self._vmin, self._vmax, val], + self._vmin, + self._vmax) + + if normMin == normMax: + return 0. + else: + return 1. - (normVal - normMin) / (normMax - normMin) + + def _paintTick(self, val, painter, majorTick=True): + """ + + :param bool majorTick: if False will never draw text and will set a line + with a smaller width + """ + fm = qt.QFontMetrics(painter.font()) + viewportHeight = self.rect().height() - self.margin * 2 - 1 + relativePos = self._getRelativePosition(val) + height = int(viewportHeight * relativePos + self.margin) + lineWidth = _TickBar._LINE_WIDTH + if majorTick is False: + lineWidth /= 2 + + painter.drawLine(qt.QLine(int(self.width() - lineWidth), + height, + self.width(), + height)) + + if self.displayValues and majorTick is True: + painter.drawText(qt.QPoint(0, int(height + fm.height() / 2)), + self.form.format(val)) + + def setDisplayType(self, disType): + """Set the type of display we want to set for ticks labels + + :param str disType: The type of display we want to set. disType values + can be : + + - 'std' for standard, meaning only a formatting on the number of + digits is done + - 'e' for scientific display + - None to let the _TickBar guess the best display for this kind of data. + """ + if disType not in (None, 'std', 'e'): + raise ValueError("display type not recognized, value should be in (None, 'std', 'e'") + self._forcedDisplayType = disType + + def _getStandardFormat(self): + return "{0:.%sf}" % self._nfrac + + def _getFormat(self, font): + if self._forcedDisplayType is None: + return self._guessType(font) + elif self._forcedDisplayType == 'std': + return self._getStandardFormat() + elif self._forcedDisplayType == 'e': + return self._getScientificForm() + else: + err = 'Forced type for display %s is not recognized' % self._forcedDisplayType + raise ValueError(err) + + def _getScientificForm(self): + return "{0:.0e}" + + def _guessType(self, font): + """Try fo find the better format to display the tick's labels + + :param QFont font: the font we want to use during the painting + """ + form = self._getStandardFormat() + + fm = qt.QFontMetrics(font) + width = 0 + for tick in self.ticks: + width = max(fm.boundingRect(form.format(tick)).width(), width) + + # if the length of the string are too long we are moving to scientific + # display + if width > _TickBar._WIDTH_DISP_VAL - _TickBar._LINE_WIDTH: + return self._getScientificForm() + else: + return form diff --git a/src/silx/gui/plot/Colormap.py b/src/silx/gui/plot/Colormap.py new file mode 100644 index 0000000..22fea7f --- /dev/null +++ b/src/silx/gui/plot/Colormap.py @@ -0,0 +1,42 @@ +# coding: utf-8 +# /*########################################################################## +# +# Copyright (c) 2015-2018 European Synchrotron Radiation Facility +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. +# +# ###########################################################################*/ +"""Deprecated module providing the Colormap object +""" + +__authors__ = ["T. Vincent", "H.Payno"] +__license__ = "MIT" +__date__ = "27/11/2020" + +import silx.utils.deprecation + +silx.utils.deprecation.deprecated_warning("Module", + name="silx.gui.plot.Colormap", + reason="moved", + replacement="silx.gui.colors.Colormap", + since_version="0.8.0", + only_once=True, + skip_backtrace_count=1) + +from ..colors import * # noqa diff --git a/src/silx/gui/plot/ColormapDialog.py b/src/silx/gui/plot/ColormapDialog.py new file mode 100644 index 0000000..7c66cb8 --- /dev/null +++ b/src/silx/gui/plot/ColormapDialog.py @@ -0,0 +1,43 @@ +# coding: utf-8 +# /*########################################################################## +# +# Copyright (c) 2004-2018 European Synchrotron Radiation Facility +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. +# +# ###########################################################################*/ +"""Deprecated module providing ColormapDialog.""" + +from __future__ import absolute_import + +__authors__ = ["T. Vincent", "H.Payno"] +__license__ = "MIT" +__date__ = "24/04/2018" + +import silx.utils.deprecation + +silx.utils.deprecation.deprecated_warning("Module", + name="silx.gui.plot.ColormapDialog", + reason="moved", + replacement="silx.gui.dialog.ColormapDialog", + since_version="0.8.0", + only_once=True, + skip_backtrace_count=1) + +from ..dialog.ColormapDialog import * # noqa diff --git a/src/silx/gui/plot/Colors.py b/src/silx/gui/plot/Colors.py new file mode 100644 index 0000000..277e104 --- /dev/null +++ b/src/silx/gui/plot/Colors.py @@ -0,0 +1,90 @@ +# coding: utf-8 +# /*########################################################################## +# +# Copyright (c) 2004-2017 European Synchrotron Radiation Facility +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. +# +# ###########################################################################*/ +"""Color conversion function, color dictionary and colormap tools.""" + +from __future__ import absolute_import + +__authors__ = ["V.A. Sole", "T. Vincent"] +__license__ = "MIT" +__date__ = "14/06/2018" + +import silx.utils.deprecation + +silx.utils.deprecation.deprecated_warning("Module", + name="silx.gui.plot.Colors", + reason="moved", + replacement="silx.gui.colors", + since_version="0.8.0", + only_once=True, + skip_backtrace_count=1) + +from ..colors import * # noqa + + +@silx.utils.deprecation.deprecated(replacement='silx.gui.colors.Colormap.applyColormap') +def applyColormapToData(data, + name='gray', + normalization='linear', + autoscale=True, + vmin=0., + vmax=1., + colors=None): + """Apply a colormap to the data and returns the RGBA image + + This supports data of any dimensions (not only of dimension 2). + The returned array will have one more dimension (with 4 entries) + than the input data to store the RGBA channels + corresponding to each bin in the array. + + :param numpy.ndarray data: The data to convert. + :param str name: Name of the colormap (default: 'gray'). + :param str normalization: Colormap mapping: 'linear' or 'log'. + :param bool autoscale: Whether to use data min/max (True, default) + or [vmin, vmax] range (False). + :param float vmin: The minimum value of the range to use if + 'autoscale' is False. + :param float vmax: The maximum value of the range to use if + 'autoscale' is False. + :param numpy.ndarray colors: Only used if name is None. + Custom colormap colors as Nx3 or Nx4 RGB or RGBA arrays + :return: The computed RGBA image + :rtype: numpy.ndarray of uint8 + """ + colormap = Colormap(name=name, + normalization=normalization, + vmin=vmin, + vmax=vmax, + colors=colors) + return colormap.applyToData(data) + + +@silx.utils.deprecation.deprecated(replacement='silx.gui.colors.Colormap.getSupportedColormaps') +def getSupportedColormaps(): + """Get the supported colormap names as a tuple of str. + + The list should at least contain and start by: + ('gray', 'reversed gray', 'temperature', 'red', 'green', 'blue') + """ + return Colormap.getSupportedColormaps() diff --git a/src/silx/gui/plot/CompareImages.py b/src/silx/gui/plot/CompareImages.py new file mode 100644 index 0000000..857fc79 --- /dev/null +++ b/src/silx/gui/plot/CompareImages.py @@ -0,0 +1,1259 @@ +# coding: utf-8 +# /*########################################################################## +# +# Copyright (c) 2018-2021 European Synchrotron Radiation Facility +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. +# +# ###########################################################################*/ +"""A widget dedicated to compare 2 images. +""" + +__authors__ = ["V. Valls"] +__license__ = "MIT" +__date__ = "23/07/2018" + + +import enum +import logging +import numpy +import weakref +import collections +import math + +import silx.image.bilinear +from silx.gui import qt +from silx.gui import plot +from silx.gui import icons +from silx.gui.colors import Colormap +from silx.gui.plot import tools +from silx.utils.weakref import WeakMethodProxy + +_logger = logging.getLogger(__name__) + +from silx.opencl import ocl +if ocl is not None: + try: + from silx.opencl import sift + except ImportError: + # sift module is not available (e.g., in official Debian packages) + sift = None +else: # No OpenCL device or no pyopencl + sift = None + + +@enum.unique +class VisualizationMode(enum.Enum): + """Enum for each visualization mode available.""" + ONLY_A = 'a' + ONLY_B = 'b' + VERTICAL_LINE = 'vline' + HORIZONTAL_LINE = 'hline' + COMPOSITE_RED_BLUE_GRAY = "rbgchannel" + COMPOSITE_RED_BLUE_GRAY_NEG = "rbgnegchannel" + COMPOSITE_A_MINUS_B = "aminusb" + + +@enum.unique +class AlignmentMode(enum.Enum): + """Enum for each alignment mode available.""" + ORIGIN = 'origin' + CENTER = 'center' + STRETCH = 'stretch' + AUTO = 'auto' + + +AffineTransformation = collections.namedtuple("AffineTransformation", + ["tx", "ty", "sx", "sy", "rot"]) +"""Contains a 2D affine transformation: translation, scale and rotation""" + + +class CompareImagesToolBar(qt.QToolBar): + """ToolBar containing specific tools to custom the configuration of a + :class:`CompareImages` widget + + Use :meth:`setCompareWidget` to connect this toolbar to a specific + :class:`CompareImages` widget. + + :param Union[qt.QWidget,None] parent: Parent of this widget. + """ + def __init__(self, parent=None): + qt.QToolBar.__init__(self, parent) + + self.__compareWidget = None + + menu = qt.QMenu(self) + self.__visualizationToolButton = qt.QToolButton(self) + self.__visualizationToolButton.setMenu(menu) + self.__visualizationToolButton.setPopupMode(qt.QToolButton.InstantPopup) + self.addWidget(self.__visualizationToolButton) + self.__visualizationGroup = qt.QActionGroup(self) + self.__visualizationGroup.setExclusive(True) + self.__visualizationGroup.triggered.connect(self.__visualizationModeChanged) + + icon = icons.getQIcon("compare-mode-a") + action = qt.QAction(icon, "Display the first image only", self) + action.setIconVisibleInMenu(True) + action.setCheckable(True) + action.setShortcut(qt.QKeySequence(qt.Qt.Key_A)) + action.setProperty("mode", VisualizationMode.ONLY_A) + menu.addAction(action) + self.__aModeAction = action + self.__visualizationGroup.addAction(action) + + icon = icons.getQIcon("compare-mode-b") + action = qt.QAction(icon, "Display the second image only", self) + action.setIconVisibleInMenu(True) + action.setCheckable(True) + action.setShortcut(qt.QKeySequence(qt.Qt.Key_B)) + action.setProperty("mode", VisualizationMode.ONLY_B) + menu.addAction(action) + self.__bModeAction = action + self.__visualizationGroup.addAction(action) + + icon = icons.getQIcon("compare-mode-vline") + action = qt.QAction(icon, "Vertical compare mode", self) + action.setIconVisibleInMenu(True) + action.setCheckable(True) + action.setShortcut(qt.QKeySequence(qt.Qt.Key_V)) + action.setProperty("mode", VisualizationMode.VERTICAL_LINE) + menu.addAction(action) + self.__vlineModeAction = action + self.__visualizationGroup.addAction(action) + + icon = icons.getQIcon("compare-mode-hline") + action = qt.QAction(icon, "Horizontal compare mode", self) + action.setIconVisibleInMenu(True) + action.setCheckable(True) + action.setShortcut(qt.QKeySequence(qt.Qt.Key_H)) + action.setProperty("mode", VisualizationMode.HORIZONTAL_LINE) + menu.addAction(action) + self.__hlineModeAction = action + self.__visualizationGroup.addAction(action) + + icon = icons.getQIcon("compare-mode-rb-channel") + action = qt.QAction(icon, "Blue/red compare mode (additive mode)", self) + action.setIconVisibleInMenu(True) + action.setCheckable(True) + action.setShortcut(qt.QKeySequence(qt.Qt.Key_C)) + action.setProperty("mode", VisualizationMode.COMPOSITE_RED_BLUE_GRAY) + menu.addAction(action) + self.__brChannelModeAction = action + self.__visualizationGroup.addAction(action) + + icon = icons.getQIcon("compare-mode-rbneg-channel") + action = qt.QAction(icon, "Yellow/cyan compare mode (subtractive mode)", self) + action.setIconVisibleInMenu(True) + action.setCheckable(True) + action.setShortcut(qt.QKeySequence(qt.Qt.Key_W)) + action.setProperty("mode", VisualizationMode.COMPOSITE_RED_BLUE_GRAY_NEG) + menu.addAction(action) + self.__ycChannelModeAction = action + self.__visualizationGroup.addAction(action) + + icon = icons.getQIcon("compare-mode-a-minus-b") + action = qt.QAction(icon, "Raw A minus B compare mode", self) + action.setIconVisibleInMenu(True) + action.setCheckable(True) + action.setShortcut(qt.QKeySequence(qt.Qt.Key_W)) + action.setProperty("mode", VisualizationMode.COMPOSITE_A_MINUS_B) + menu.addAction(action) + self.__ycChannelModeAction = action + self.__visualizationGroup.addAction(action) + + menu = qt.QMenu(self) + self.__alignmentToolButton = qt.QToolButton(self) + self.__alignmentToolButton.setMenu(menu) + self.__alignmentToolButton.setPopupMode(qt.QToolButton.InstantPopup) + self.addWidget(self.__alignmentToolButton) + self.__alignmentGroup = qt.QActionGroup(self) + self.__alignmentGroup.setExclusive(True) + self.__alignmentGroup.triggered.connect(self.__alignmentModeChanged) + + icon = icons.getQIcon("compare-align-origin") + action = qt.QAction(icon, "Align images on their upper-left pixel", self) + action.setProperty("mode", AlignmentMode.ORIGIN) + action.setIconVisibleInMenu(True) + action.setCheckable(True) + self.__originAlignAction = action + menu.addAction(action) + self.__alignmentGroup.addAction(action) + + icon = icons.getQIcon("compare-align-center") + action = qt.QAction(icon, "Center images", self) + action.setProperty("mode", AlignmentMode.CENTER) + action.setIconVisibleInMenu(True) + action.setCheckable(True) + self.__centerAlignAction = action + menu.addAction(action) + self.__alignmentGroup.addAction(action) + + icon = icons.getQIcon("compare-align-stretch") + action = qt.QAction(icon, "Stretch the second image on the first one", self) + action.setProperty("mode", AlignmentMode.STRETCH) + action.setIconVisibleInMenu(True) + action.setCheckable(True) + self.__stretchAlignAction = action + menu.addAction(action) + self.__alignmentGroup.addAction(action) + + icon = icons.getQIcon("compare-align-auto") + action = qt.QAction(icon, "Auto-alignment of the second image", self) + action.setProperty("mode", AlignmentMode.AUTO) + action.setIconVisibleInMenu(True) + action.setCheckable(True) + self.__autoAlignAction = action + menu.addAction(action) + if sift is None: + action.setEnabled(False) + action.setToolTip("Sift module is not available") + self.__alignmentGroup.addAction(action) + + icon = icons.getQIcon("compare-keypoints") + action = qt.QAction(icon, "Display/hide alignment keypoints", self) + action.setCheckable(True) + action.triggered.connect(self.__keypointVisibilityChanged) + self.addAction(action) + self.__displayKeypoints = action + + def setCompareWidget(self, widget): + """ + Connect this tool bar to a specific :class:`CompareImages` widget. + + :param Union[None,CompareImages] widget: The widget to connect with. + """ + compareWidget = self.getCompareWidget() + if compareWidget is not None: + compareWidget.sigConfigurationChanged.disconnect(self.__updateSelectedActions) + compareWidget = widget + if compareWidget is None: + self.__compareWidget = None + else: + self.__compareWidget = weakref.ref(compareWidget) + if compareWidget is not None: + widget.sigConfigurationChanged.connect(self.__updateSelectedActions) + self.__updateSelectedActions() + + def getCompareWidget(self): + """Returns the connected widget. + + :rtype: CompareImages + """ + if self.__compareWidget is None: + return None + else: + return self.__compareWidget() + + def __updateSelectedActions(self): + """ + Update the state of this tool bar according to the state of the + connected :class:`CompareImages` widget. + """ + widget = self.getCompareWidget() + if widget is None: + return + + mode = widget.getVisualizationMode() + action = None + for a in self.__visualizationGroup.actions(): + actionMode = a.property("mode") + if mode == actionMode: + action = a + break + old = self.__visualizationGroup.blockSignals(True) + if action is not None: + # Check this action + action.setChecked(True) + else: + action = self.__visualizationGroup.checkedAction() + if action is not None: + # Uncheck this action + action.setChecked(False) + self.__updateVisualizationMenu() + self.__visualizationGroup.blockSignals(old) + + mode = widget.getAlignmentMode() + action = None + for a in self.__alignmentGroup.actions(): + actionMode = a.property("mode") + if mode == actionMode: + action = a + break + old = self.__alignmentGroup.blockSignals(True) + if action is not None: + # Check this action + action.setChecked(True) + else: + action = self.__alignmentGroup.checkedAction() + if action is not None: + # Uncheck this action + action.setChecked(False) + self.__updateAlignmentMenu() + self.__alignmentGroup.blockSignals(old) + + def __visualizationModeChanged(self, selectedAction): + """Called when user requesting changes of the visualization mode. + """ + self.__updateVisualizationMenu() + widget = self.getCompareWidget() + if widget is not None: + mode = selectedAction.property("mode") + widget.setVisualizationMode(mode) + + def __updateVisualizationMenu(self): + """Update the state of the action containing visualization menu. + """ + selectedAction = self.__visualizationGroup.checkedAction() + if selectedAction is not None: + self.__visualizationToolButton.setText(selectedAction.text()) + self.__visualizationToolButton.setIcon(selectedAction.icon()) + self.__visualizationToolButton.setToolTip(selectedAction.toolTip()) + else: + self.__visualizationToolButton.setText("") + self.__visualizationToolButton.setIcon(qt.QIcon()) + self.__visualizationToolButton.setToolTip("") + + def __alignmentModeChanged(self, selectedAction): + """Called when user requesting changes of the alignment mode. + """ + self.__updateAlignmentMenu() + widget = self.getCompareWidget() + if widget is not None: + mode = selectedAction.property("mode") + widget.setAlignmentMode(mode) + + def __updateAlignmentMenu(self): + """Update the state of the action containing alignment menu. + """ + selectedAction = self.__alignmentGroup.checkedAction() + if selectedAction is not None: + self.__alignmentToolButton.setText(selectedAction.text()) + self.__alignmentToolButton.setIcon(selectedAction.icon()) + self.__alignmentToolButton.setToolTip(selectedAction.toolTip()) + else: + self.__alignmentToolButton.setText("") + self.__alignmentToolButton.setIcon(qt.QIcon()) + self.__alignmentToolButton.setToolTip("") + + def __keypointVisibilityChanged(self): + """Called when action managing keypoints visibility changes""" + widget = self.getCompareWidget() + if widget is not None: + keypointsVisible = self.__displayKeypoints.isChecked() + widget.setKeypointsVisible(keypointsVisible) + + +class CompareImagesStatusBar(qt.QStatusBar): + """StatusBar containing specific information contained in a + :class:`CompareImages` widget + + Use :meth:`setCompareWidget` to connect this toolbar to a specific + :class:`CompareImages` widget. + + :param Union[qt.QWidget,None] parent: Parent of this widget. + """ + def __init__(self, parent=None): + qt.QStatusBar.__init__(self, parent) + self.setSizeGripEnabled(False) + self.layout().setSpacing(0) + self.__compareWidget = None + self._label1 = qt.QLabel(self) + self._label1.setFrameShape(qt.QFrame.WinPanel) + self._label1.setFrameShadow(qt.QFrame.Sunken) + self._label2 = qt.QLabel(self) + self._label2.setFrameShape(qt.QFrame.WinPanel) + self._label2.setFrameShadow(qt.QFrame.Sunken) + self._transform = qt.QLabel(self) + self._transform.setFrameShape(qt.QFrame.WinPanel) + self._transform.setFrameShadow(qt.QFrame.Sunken) + self.addWidget(self._label1) + self.addWidget(self._label2) + self.addWidget(self._transform) + self._pos = None + self._updateStatusBar() + + def setCompareWidget(self, widget): + """ + Connect this tool bar to a specific :class:`CompareImages` widget. + + :param Union[None,CompareImages] widget: The widget to connect with. + """ + compareWidget = self.getCompareWidget() + if compareWidget is not None: + compareWidget.getPlot().sigPlotSignal.disconnect(self.__plotSignalReceived) + compareWidget.sigConfigurationChanged.disconnect(self.__dataChanged) + compareWidget = widget + if compareWidget is None: + self.__compareWidget = None + else: + self.__compareWidget = weakref.ref(compareWidget) + if compareWidget is not None: + compareWidget.getPlot().sigPlotSignal.connect(self.__plotSignalReceived) + compareWidget.sigConfigurationChanged.connect(self.__dataChanged) + + def getCompareWidget(self): + """Returns the connected widget. + + :rtype: CompareImages + """ + if self.__compareWidget is None: + return None + else: + return self.__compareWidget() + + def __plotSignalReceived(self, event): + """Called when old style signals at emmited from the plot.""" + if event["event"] == "mouseMoved": + x, y = event["x"], event["y"] + self.__mouseMoved(x, y) + + def __mouseMoved(self, x, y): + """Called when mouse move over the plot.""" + self._pos = x, y + self._updateStatusBar() + + def __dataChanged(self): + """Called when internal data from the connected widget changes.""" + self._updateStatusBar() + + def _formatData(self, data): + """Format pixel of an image. + + It supports intensity, RGB, and RGBA. + + :param Union[int,float,numpy.ndarray,str]: Value of a pixel + :rtype: str + """ + if data is None: + return "No data" + if isinstance(data, (int, numpy.integer)): + return "%d" % data + if isinstance(data, (float, numpy.floating)): + return "%f" % data + if isinstance(data, numpy.ndarray): + # RGBA value + if data.shape == (3,): + return "R:%d G:%d B:%d" % (data[0], data[1], data[2]) + elif data.shape == (4,): + return "R:%d G:%d B:%d A:%d" % (data[0], data[1], data[2], data[3]) + _logger.debug("Unsupported data format %s. Cast it to string.", type(data)) + return str(data) + + def _updateStatusBar(self): + """Update the content of the status bar""" + widget = self.getCompareWidget() + if widget is None: + self._label1.setText("Image1: NA") + self._label2.setText("Image2: NA") + self._transform.setVisible(False) + else: + transform = widget.getTransformation() + self._transform.setVisible(transform is not None) + if transform is not None: + has_notable_translation = not numpy.isclose(transform.tx, 0.0, atol=0.01) \ + or not numpy.isclose(transform.ty, 0.0, atol=0.01) + has_notable_scale = not numpy.isclose(transform.sx, 1.0, atol=0.01) \ + or not numpy.isclose(transform.sy, 1.0, atol=0.01) + has_notable_rotation = not numpy.isclose(transform.rot, 0.0, atol=0.01) + + strings = [] + if has_notable_translation: + strings.append("Translation") + if has_notable_scale: + strings.append("Scale") + if has_notable_rotation: + strings.append("Rotation") + if strings == []: + has_translation = not numpy.isclose(transform.tx, 0.0) \ + or not numpy.isclose(transform.ty, 0.0) + has_scale = not numpy.isclose(transform.sx, 1.0) \ + or not numpy.isclose(transform.sy, 1.0) + has_rotation = not numpy.isclose(transform.rot, 0.0) + if has_translation or has_scale or has_rotation: + text = "No big changes" + else: + text = "No changes" + else: + text = "+".join(strings) + self._transform.setText("Align: " + text) + + strings = [] + if not numpy.isclose(transform.ty, 0.0): + strings.append("Translation x: %0.3fpx" % transform.tx) + if not numpy.isclose(transform.ty, 0.0): + strings.append("Translation y: %0.3fpx" % transform.ty) + if not numpy.isclose(transform.sx, 1.0): + strings.append("Scale x: %0.3f" % transform.sx) + if not numpy.isclose(transform.sy, 1.0): + strings.append("Scale y: %0.3f" % transform.sy) + if not numpy.isclose(transform.rot, 0.0): + strings.append("Rotation: %0.3fdeg" % (transform.rot * 180 / numpy.pi)) + if strings == []: + text = "No transformation" + else: + text = "\n".join(strings) + self._transform.setToolTip(text) + + if self._pos is None: + self._label1.setText("Image1: NA") + self._label2.setText("Image2: NA") + else: + data1, data2 = widget.getRawPixelData(self._pos[0], self._pos[1]) + if isinstance(data1, str): + self._label1.setToolTip(data1) + text1 = "NA" + else: + self._label1.setToolTip("") + text1 = self._formatData(data1) + if isinstance(data2, str): + self._label2.setToolTip(data2) + text2 = "NA" + else: + self._label2.setToolTip("") + text2 = self._formatData(data2) + self._label1.setText("Image1: %s" % text1) + self._label2.setText("Image2: %s" % text2) + + +class CompareImages(qt.QMainWindow): + """Widget providing tools to compare 2 images. + + .. image:: img/CompareImages.png + + :param Union[qt.QWidget,None] parent: Parent of this widget. + :param backend: The backend to use, in: + 'matplotlib' (default), 'mpl', 'opengl', 'gl', 'none' + or a :class:`BackendBase.BackendBase` class + :type backend: str or :class:`BackendBase.BackendBase` + """ + + VisualizationMode = VisualizationMode + """Available visualization modes""" + + AlignmentMode = AlignmentMode + """Available alignment modes""" + + sigConfigurationChanged = qt.Signal() + """Emitted when the configuration of the widget (visualization mode, + alignement mode...) have changed.""" + + def __init__(self, parent=None, backend=None): + qt.QMainWindow.__init__(self, parent) + self._resetZoomActive = True + self._colormap = Colormap() + """Colormap shared by all modes, except the compose images (rgb image)""" + self._colormapKeyPoints = Colormap('spring') + """Colormap used for sift keypoints""" + + if parent is None: + self.setWindowTitle('Compare images') + else: + self.setWindowFlags(qt.Qt.Widget) + + self.__transformation = None + self.__raw1 = None + self.__raw2 = None + self.__data1 = None + self.__data2 = None + self.__previousSeparatorPosition = None + + self.__plot = plot.PlotWidget(parent=self, backend=backend) + self.__plot.setDefaultColormap(self._colormap) + self.__plot.getXAxis().setLabel('Columns') + self.__plot.getYAxis().setLabel('Rows') + if silx.config.DEFAULT_PLOT_IMAGE_Y_AXIS_ORIENTATION == 'downward': + self.__plot.getYAxis().setInverted(True) + + self.__plot.setKeepDataAspectRatio(True) + self.__plot.sigPlotSignal.connect(self.__plotSlot) + self.__plot.setAxesDisplayed(False) + + self.setCentralWidget(self.__plot) + + legend = VisualizationMode.VERTICAL_LINE.name + self.__plot.addXMarker( + 0, + legend=legend, + text='', + draggable=True, + color='blue', + constraint=WeakMethodProxy(self.__separatorConstraint)) + self.__vline = self.__plot._getMarker(legend) + + legend = VisualizationMode.HORIZONTAL_LINE.name + self.__plot.addYMarker( + 0, + legend=legend, + text='', + draggable=True, + color='blue', + constraint=WeakMethodProxy(self.__separatorConstraint)) + self.__hline = self.__plot._getMarker(legend) + + # default values + self.__visualizationMode = "" + self.__alignmentMode = "" + self.__keypointsVisible = True + + self.setAlignmentMode(AlignmentMode.ORIGIN) + self.setVisualizationMode(VisualizationMode.VERTICAL_LINE) + self.setKeypointsVisible(False) + + # Toolbars + + self._createToolBars(self.__plot) + if self._interactiveModeToolBar is not None: + self.addToolBar(self._interactiveModeToolBar) + if self._imageToolBar is not None: + self.addToolBar(self._imageToolBar) + if self._compareToolBar is not None: + self.addToolBar(self._compareToolBar) + + # Statusbar + + self._createStatusBar(self.__plot) + if self._statusBar is not None: + self.setStatusBar(self._statusBar) + + def _createStatusBar(self, plot): + self._statusBar = CompareImagesStatusBar(self) + self._statusBar.setCompareWidget(self) + + def _createToolBars(self, plot): + """Create tool bars displayed by the widget""" + toolBar = tools.InteractiveModeToolBar(parent=self, plot=plot) + self._interactiveModeToolBar = toolBar + toolBar = tools.ImageToolBar(parent=self, plot=plot) + self._imageToolBar = toolBar + toolBar = CompareImagesToolBar(self) + toolBar.setCompareWidget(self) + self._compareToolBar = toolBar + + def getPlot(self): + """Returns the plot which is used to display the images. + + :rtype: silx.gui.plot.PlotWidget + """ + return self.__plot + + def getColormap(self): + """ + + :return: colormap used for compare image + :rtype: silx.gui.colors.Colormap + """ + return self._colormap + + def getRawPixelData(self, x, y): + """Return the raw pixel of each image data from axes positions. + + If the coordinate is outside of the image it returns None element in + the tuple. + + The pixel is reach from the raw data image without filter or + transformation. But the coordinate x and y are in the reference of the + current displayed mode. + + :param float x: X-coordinate of the pixel in the current displayed plot + :param float y: Y-coordinate of the pixel in the current displayed plot + :return: A tuple of for each images containing pixel information. It + could be a scalar value or an array in case of RGB/RGBA informations. + It also could be a string containing information is some cases. + :rtype: Tuple(Union[int,float,numpy.ndarray,str],Union[int,float,numpy.ndarray,str]) + """ + data2 = None + alignmentMode = self.__alignmentMode + raw1, raw2 = self.__raw1, self.__raw2 + if alignmentMode == AlignmentMode.ORIGIN: + x1 = x + y1 = y + x2 = x + y2 = y + elif alignmentMode == AlignmentMode.CENTER: + yy = max(raw1.shape[0], raw2.shape[0]) + xx = max(raw1.shape[1], raw2.shape[1]) + x1 = x - (xx - raw1.shape[1]) * 0.5 + x2 = x - (xx - raw2.shape[1]) * 0.5 + y1 = y - (yy - raw1.shape[0]) * 0.5 + y2 = y - (yy - raw2.shape[0]) * 0.5 + elif alignmentMode == AlignmentMode.STRETCH: + x1 = x + y1 = y + x2 = x * raw2.shape[1] / raw1.shape[1] + y2 = x * raw2.shape[1] / raw1.shape[1] + elif alignmentMode == AlignmentMode.AUTO: + x1 = x + y1 = y + # Not implemented + data2 = "Not implemented with sift" + else: + assert(False) + + x1, y1 = int(x1), int(y1) + if raw1 is None or y1 < 0 or y1 >= raw1.shape[0] or x1 < 0 or x1 >= raw1.shape[1]: + data1 = None + else: + data1 = raw1[y1, x1] + + if data2 is None: + x2, y2 = int(x2), int(y2) + if raw2 is None or y2 < 0 or y2 >= raw2.shape[0] or x2 < 0 or x2 >= raw2.shape[1]: + data2 = None + else: + data2 = raw2[y2, x2] + + return data1, data2 + + def setVisualizationMode(self, mode): + """Set the visualization mode. + + :param str mode: New visualization to display the image comparison + """ + if self.__visualizationMode == mode: + return + previousMode = self.getVisualizationMode() + self.__visualizationMode = mode + mode = self.getVisualizationMode() + self.__vline.setVisible(mode == VisualizationMode.VERTICAL_LINE) + self.__hline.setVisible(mode == VisualizationMode.HORIZONTAL_LINE) + visModeRawDisplay = (VisualizationMode.ONLY_A, + VisualizationMode.ONLY_B, + VisualizationMode.VERTICAL_LINE, + VisualizationMode.HORIZONTAL_LINE) + updateColormap = not(previousMode in visModeRawDisplay and + mode in visModeRawDisplay) + self.__updateData(updateColormap=updateColormap) + self.sigConfigurationChanged.emit() + + def getVisualizationMode(self): + """Returns the current interaction mode.""" + return self.__visualizationMode + + def setAlignmentMode(self, mode): + """Set the alignment mode. + + :param str mode: New alignement to apply to images + """ + if self.__alignmentMode == mode: + return + self.__alignmentMode = mode + self.__updateData(updateColormap=False) + self.sigConfigurationChanged.emit() + + def getAlignmentMode(self): + """Returns the current selected alignemnt mode.""" + return self.__alignmentMode + + def setKeypointsVisible(self, isVisible): + """Set keypoints visibility. + + :param bool isVisible: If True, keypoints are displayed (if some) + """ + if self.__keypointsVisible == isVisible: + return + self.__keypointsVisible = isVisible + self.__updateKeyPoints() + self.sigConfigurationChanged.emit() + + def __setDefaultAlignmentMode(self): + """Reset the alignemnt mode to the default value""" + self.setAlignmentMode(AlignmentMode.ORIGIN) + + def __plotSlot(self, event): + """Handle events from the plot""" + if event['event'] in ('markerMoving', 'markerMoved'): + mode = self.getVisualizationMode() + legend = mode.name + if event['label'] == legend: + if mode == VisualizationMode.VERTICAL_LINE: + value = int(float(str(event['xdata']))) + elif mode == VisualizationMode.HORIZONTAL_LINE: + value = int(float(str(event['ydata']))) + else: + assert(False) + if self.__previousSeparatorPosition != value: + self.__separatorMoved(value) + self.__previousSeparatorPosition = value + + def __separatorConstraint(self, x, y): + """Manage contains on the separators to clamp them inside the images.""" + if self.__data1 is None: + return 0, 0 + x = int(x) + if x < 0: + x = 0 + elif x > self.__data1.shape[1]: + x = self.__data1.shape[1] + y = int(y) + if y < 0: + y = 0 + elif y > self.__data1.shape[0]: + y = self.__data1.shape[0] + return x, y + + def __updateSeparators(self): + """Redraw images according to the current state of the separators. + """ + mode = self.getVisualizationMode() + if mode == VisualizationMode.VERTICAL_LINE: + pos = self.__vline.getXPosition() + self.__separatorMoved(pos) + self.__previousSeparatorPosition = pos + elif mode == VisualizationMode.HORIZONTAL_LINE: + pos = self.__hline.getYPosition() + self.__separatorMoved(pos) + self.__previousSeparatorPosition = pos + else: + self.__image1.setOrigin((0, 0)) + self.__image2.setOrigin((0, 0)) + + def __separatorMoved(self, pos): + """Called when vertical or horizontal separators have moved. + + Update the displayed images. + """ + if self.__data1 is None: + return + + mode = self.getVisualizationMode() + if mode == VisualizationMode.VERTICAL_LINE: + pos = int(pos) + if pos <= 0: + pos = 0 + elif pos >= self.__data1.shape[1]: + pos = self.__data1.shape[1] + data1 = self.__data1[:, 0:pos] + data2 = self.__data2[:, pos:] + self.__image1.setData(data1, copy=False) + self.__image2.setData(data2, copy=False) + self.__image2.setOrigin((pos, 0)) + elif mode == VisualizationMode.HORIZONTAL_LINE: + pos = int(pos) + if pos <= 0: + pos = 0 + elif pos >= self.__data1.shape[0]: + pos = self.__data1.shape[0] + data1 = self.__data1[0:pos, :] + data2 = self.__data2[pos:, :] + self.__image1.setData(data1, copy=False) + self.__image2.setData(data2, copy=False) + self.__image2.setOrigin((0, pos)) + else: + assert(False) + + def setData(self, image1, image2, updateColormap=True): + """Set images to compare. + + Images can contains floating-point or integer values, or RGB and RGBA + values, but should have comparable intensities. + + RGB and RGBA images are provided as an array as `[width,height,channels]` + of usigned integer 8-bits or floating-points between 0.0 to 1.0. + + :param numpy.ndarray image1: The first image + :param numpy.ndarray image2: The second image + """ + self.__raw1 = image1 + self.__raw2 = image2 + self.__updateData(updateColormap=updateColormap) + if self.isAutoResetZoom(): + self.__plot.resetZoom() + + def setImage1(self, image1, updateColormap=True): + """Set image1 to be compared. + + Images can contains floating-point or integer values, or RGB and RGBA + values, but should have comparable intensities. + + RGB and RGBA images are provided as an array as `[width,height,channels]` + of usigned integer 8-bits or floating-points between 0.0 to 1.0. + + :param numpy.ndarray image1: The first image + """ + self.__raw1 = image1 + self.__updateData(updateColormap=updateColormap) + if self.isAutoResetZoom(): + self.__plot.resetZoom() + + def setImage2(self, image2, updateColormap=True): + """Set image2 to be compared. + + Images can contains floating-point or integer values, or RGB and RGBA + values, but should have comparable intensities. + + RGB and RGBA images are provided as an array as `[width,height,channels]` + of usigned integer 8-bits or floating-points between 0.0 to 1.0. + + :param numpy.ndarray image2: The second image + """ + self.__raw2 = image2 + self.__updateData(updateColormap=updateColormap) + if self.isAutoResetZoom(): + self.__plot.resetZoom() + + def __updateKeyPoints(self): + """Update the displayed keypoints using cached keypoints. + """ + if self.__keypointsVisible: + data = self.__matching_keypoints + else: + data = [], [], [] + self.__plot.addScatter(x=data[0], + y=data[1], + z=1, + value=data[2], + colormap=self._colormapKeyPoints, + legend="keypoints") + + def __updateData(self, updateColormap): + """Compute aligned image when the alignment mode changes. + + This function cache input images which are used when + vertical/horizontal separators moves. + """ + raw1, raw2 = self.__raw1, self.__raw2 + if raw1 is None or raw2 is None: + return + + alignmentMode = self.getAlignmentMode() + self.__transformation = None + + if alignmentMode == AlignmentMode.ORIGIN: + yy = max(raw1.shape[0], raw2.shape[0]) + xx = max(raw1.shape[1], raw2.shape[1]) + size = yy, xx + data1 = self.__createMarginImage(raw1, size, transparent=True) + data2 = self.__createMarginImage(raw2, size, transparent=True) + self.__matching_keypoints = [0.0], [0.0], [1.0] + elif alignmentMode == AlignmentMode.CENTER: + yy = max(raw1.shape[0], raw2.shape[0]) + xx = max(raw1.shape[1], raw2.shape[1]) + size = yy, xx + data1 = self.__createMarginImage(raw1, size, transparent=True, center=True) + data2 = self.__createMarginImage(raw2, size, transparent=True, center=True) + self.__matching_keypoints = ([data1.shape[1] // 2], + [data1.shape[0] // 2], + [1.0]) + elif alignmentMode == AlignmentMode.STRETCH: + data1 = raw1 + data2 = self.__rescaleImage(raw2, data1.shape) + self.__matching_keypoints = ([0, data1.shape[1], data1.shape[1], 0], + [0, 0, data1.shape[0], data1.shape[0]], + [1.0, 1.0, 1.0, 1.0]) + elif alignmentMode == AlignmentMode.AUTO: + # TODO: sift implementation do not support RGBA images + yy = max(raw1.shape[0], raw2.shape[0]) + xx = max(raw1.shape[1], raw2.shape[1]) + size = yy, xx + data1 = self.__createMarginImage(raw1, size) + data2 = self.__createMarginImage(raw2, size) + self.__matching_keypoints = [0.0], [0.0], [1.0] + try: + data1, data2 = self.__createSiftData(data1, data2) + if data2 is None: + raise ValueError("Unexpected None value") + except Exception as e: + # TODO: Display it on the GUI + _logger.error(e) + self.__setDefaultAlignmentMode() + return + else: + assert(False) + + mode = self.getVisualizationMode() + if mode == VisualizationMode.COMPOSITE_RED_BLUE_GRAY_NEG: + data1 = self.__composeImage(data1, data2, mode) + data2 = numpy.empty((0, 0)) + elif mode == VisualizationMode.COMPOSITE_RED_BLUE_GRAY: + data1 = self.__composeImage(data1, data2, mode) + data2 = numpy.empty((0, 0)) + elif mode == VisualizationMode.COMPOSITE_A_MINUS_B: + data1 = self.__composeImage(data1, data2, mode) + data2 = numpy.empty((0, 0)) + elif mode == VisualizationMode.ONLY_A: + data2 = numpy.empty((0, 0)) + elif mode == VisualizationMode.ONLY_B: + data1 = numpy.empty((0, 0)) + + self.__data1, self.__data2 = data1, data2 + self.__plot.addImage(data1, z=0, legend="image1", resetzoom=False) + self.__plot.addImage(data2, z=0, legend="image2", resetzoom=False) + self.__image1 = self.__plot.getImage("image1") + self.__image2 = self.__plot.getImage("image2") + self.__updateKeyPoints() + + # Set the separator into the middle + if self.__previousSeparatorPosition is None: + value = self.__data1.shape[1] // 2 + self.__vline.setPosition(value, 0) + value = self.__data1.shape[0] // 2 + self.__hline.setPosition(0, value) + self.__updateSeparators() + if updateColormap: + self.__updateColormap() + + def __updateColormap(self): + # TODO: The colormap histogram will still be wrong + mode1 = self.__getImageMode(self.__data1) + mode2 = self.__getImageMode(self.__data2) + if mode1 == "intensity" and mode1 == mode2: + if self.__data1.size == 0: + vmin = self.__data2.min() + vmax = self.__data2.max() + elif self.__data2.size == 0: + vmin = self.__data1.min() + vmax = self.__data1.max() + else: + vmin = min(self.__data1.min(), self.__data2.min()) + vmax = max(self.__data1.max(), self.__data2.max()) + colormap = self.getColormap() + colormap.setVRange(vmin=vmin, vmax=vmax) + self.__image1.setColormap(colormap) + self.__image2.setColormap(colormap) + + def __getImageMode(self, image): + """Returns a value identifying the way the image is stored in the + array. + + :param numpy.ndarray image: Image to check + :rtype: str + """ + if len(image.shape) == 2: + return "intensity" + elif len(image.shape) == 3: + if image.shape[2] == 3: + return "rgb" + elif image.shape[2] == 4: + return "rgba" + raise TypeError("'image' argument is not an image.") + + def __rescaleImage(self, image, shape): + """Rescale an image to the requested shape. + + :rtype: numpy.ndarray + """ + mode = self.__getImageMode(image) + if mode == "intensity": + data = self.__rescaleArray(image, shape) + elif mode == "rgb": + data = numpy.empty((shape[0], shape[1], 3), dtype=image.dtype) + for c in range(3): + data[:, :, c] = self.__rescaleArray(image[:, :, c], shape) + elif mode == "rgba": + data = numpy.empty((shape[0], shape[1], 4), dtype=image.dtype) + for c in range(4): + data[:, :, c] = self.__rescaleArray(image[:, :, c], shape) + return data + + def __composeImage(self, data1, data2, mode): + """Returns an RBG image containing composition of data1 and data2 in 2 + different channels + + :param numpy.ndarray data1: First image + :param numpy.ndarray data1: Second image + :param VisualizationMode mode: Composition mode. + :rtype: numpy.ndarray + """ + assert(data1.shape[0:2] == data2.shape[0:2]) + if mode == VisualizationMode.COMPOSITE_A_MINUS_B: + # TODO: this calculation has no interest of generating a 'composed' + # rgb image, this could be moved in an other function or doc + # should be modified + _type = data1.dtype + result = data1.astype(numpy.float64) - data2.astype(numpy.float64) + return result + mode1 = self.__getImageMode(data1) + if mode1 in ["rgb", "rgba"]: + intensity1 = self.__luminosityImage(data1) + vmin1, vmax1 = 0.0, 1.0 + else: + intensity1 = data1 + vmin1, vmax1 = data1.min(), data1.max() + + mode2 = self.__getImageMode(data2) + if mode2 in ["rgb", "rgba"]: + intensity2 = self.__luminosityImage(data2) + vmin2, vmax2 = 0.0, 1.0 + else: + intensity2 = data2 + vmin2, vmax2 = data2.min(), data2.max() + + vmin, vmax = min(vmin1, vmin2) * 1.0, max(vmax1, vmax2) * 1.0 + shape = data1.shape + result = numpy.empty((shape[0], shape[1], 3), dtype=numpy.uint8) + a = (intensity1 - vmin) * (1.0 / (vmax - vmin)) * 255.0 + b = (intensity2 - vmin) * (1.0 / (vmax - vmin)) * 255.0 + if mode == VisualizationMode.COMPOSITE_RED_BLUE_GRAY: + result[:, :, 0] = a + result[:, :, 1] = (a + b) / 2 + result[:, :, 2] = b + elif mode == VisualizationMode.COMPOSITE_RED_BLUE_GRAY_NEG: + result[:, :, 0] = 255 - b + result[:, :, 1] = 255 - (a + b) / 2 + result[:, :, 2] = 255 - a + return result + + def __luminosityImage(self, image): + """Returns the luminosity channel from an RBG(A) image. + The alpha channel is ignored. + + :rtype: numpy.ndarray + """ + mode = self.__getImageMode(image) + assert(mode in ["rgb", "rgba"]) + is_uint8 = image.dtype.type == numpy.uint8 + # luminosity + image = 0.21 * image[..., 0] + 0.72 * image[..., 1] + 0.07 * image[..., 2] + if is_uint8: + image = image / 255.0 + return image + + def __rescaleArray(self, image, shape): + """Rescale a 2D array to the requested shape. + + :rtype: numpy.ndarray + """ + y, x = numpy.ogrid[:shape[0], :shape[1]] + y, x = y * 1.0 * (image.shape[0] - 1) / (shape[0] - 1), x * 1.0 * (image.shape[1] - 1) / (shape[1] - 1) + b = silx.image.bilinear.BilinearImage(image) + # TODO: could be optimized using strides + x2d = numpy.zeros_like(y) + x + y2d = numpy.zeros_like(x) + y + result = b.map_coordinates((y2d, x2d)) + return result + + def __createMarginImage(self, image, size, transparent=False, center=False): + """Returns a new image with margin to respect the requested size. + + :rtype: numpy.ndarray + """ + assert(image.shape[0] <= size[0]) + assert(image.shape[1] <= size[1]) + if image.shape == size: + return image + mode = self.__getImageMode(image) + + if center: + pos0 = size[0] // 2 - image.shape[0] // 2 + pos1 = size[1] // 2 - image.shape[1] // 2 + else: + pos0, pos1 = 0, 0 + + if mode == "intensity": + data = numpy.zeros(size, dtype=image.dtype) + data[pos0:pos0 + image.shape[0], pos1:pos1 + image.shape[1]] = image + # TODO: It is maybe possible to put NaN on the margin + else: + if transparent: + data = numpy.zeros((size[0], size[1], 4), dtype=numpy.uint8) + else: + data = numpy.zeros((size[0], size[1], 3), dtype=numpy.uint8) + depth = min(data.shape[2], image.shape[2]) + data[pos0:pos0 + image.shape[0], pos1:pos1 + image.shape[1], 0:depth] = image[:, :, 0:depth] + if transparent and depth == 3: + data[pos0:pos0 + image.shape[0], pos1:pos1 + image.shape[1], 3] = 255 + return data + + def __toAffineTransformation(self, sift_result): + """Returns an affine transformation from the sift result. + + :param dict sift_result: Result of sift when using `all_result=True` + :rtype: AffineTransformation + """ + offset = sift_result["offset"] + matrix = sift_result["matrix"] + + tx = offset[0] + ty = offset[1] + a = matrix[0, 0] + b = matrix[0, 1] + c = matrix[1, 0] + d = matrix[1, 1] + rot = math.atan2(-b, a) + sx = (-1.0 if a < 0 else 1.0) * math.sqrt(a**2 + b**2) + sy = (-1.0 if d < 0 else 1.0) * math.sqrt(c**2 + d**2) + return AffineTransformation(tx, ty, sx, sy, rot) + + def getTransformation(self): + """Retuns the affine transformation applied to the second image to align + it to the first image. + + This result is only valid for sift alignment. + + :rtype: Union[None,AffineTransformation] + """ + return self.__transformation + + def __createSiftData(self, image, second_image): + """Generate key points and aligned images from 2 images. + + If no keypoints matches, unaligned data are anyway returns. + + :rtype: Tuple(numpy.ndarray,numpy.ndarray) + """ + devicetype = "GPU" + + # Compute base image + sift_ocl = sift.SiftPlan(template=image, devicetype=devicetype) + keypoints = sift_ocl(image) + + # Check image compatibility + second_keypoints = sift_ocl(second_image) + mp = sift.MatchPlan() + match = mp(keypoints, second_keypoints) + _logger.info("Number of Keypoints within image 1: %i" % keypoints.size) + _logger.info(" within image 2: %i" % second_keypoints.size) + + self.__matching_keypoints = (match[:].x[:, 0], + match[:].y[:, 0], + match[:].scale[:, 0]) + matching_keypoints = match.shape[0] + _logger.info("Matching keypoints: %i" % matching_keypoints) + if matching_keypoints == 0: + return image, second_image + + # TODO: Problem here is we have to compute 2 time sift + # The first time to extract matching keypoints, second time + # to extract the aligned image. + + # Normalize the second image + sa = sift.LinearAlign(image, devicetype=devicetype) + data1 = image + # TODO: Create a sift issue: if data1 is RGB and data2 intensity + # it returns None, while extracting manually keypoints (above) works + result = sa.align(second_image, return_all=True) + data2 = result["result"] + self.__transformation = self.__toAffineTransformation(result) + return data1, data2 + + def setAutoResetZoom(self, activate=True): + """ + + :param bool activate: True if we want to activate the automatic + plot reset zoom when setting images. + """ + self._resetZoomActive = activate + + def isAutoResetZoom(self): + """ + + :return: True if the automatic call to resetzoom is activated + :rtype: bool + """ + return self._resetZoomActive diff --git a/src/silx/gui/plot/ComplexImageView.py b/src/silx/gui/plot/ComplexImageView.py new file mode 100644 index 0000000..4eee3b0 --- /dev/null +++ b/src/silx/gui/plot/ComplexImageView.py @@ -0,0 +1,518 @@ +# coding: utf-8 +# /*########################################################################## +# +# Copyright (c) 2017-2021 European Synchrotron Radiation Facility +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. +# +# ###########################################################################*/ +"""This module provides a widget to view 2D complex data. + +The :class:`ComplexImageView` widget is dedicated to visualize a single 2D dataset +of complex data. +""" + +from __future__ import absolute_import + +__authors__ = ["Vincent Favre-Nicolin", "T. Vincent"] +__license__ = "MIT" +__date__ = "24/04/2018" + + +import logging +import collections +import numpy + +from ...utils.deprecation import deprecated +from .. import qt, icons +from .PlotWindow import Plot2D +from . import items +from .items import ImageComplexData +from silx.gui.widgets.FloatEdit import FloatEdit + +_logger = logging.getLogger(__name__) + + +# Widgets + +class _AmplitudeRangeDialog(qt.QDialog): + """QDialog asking for the amplitude range to display.""" + + sigRangeChanged = qt.Signal(tuple) + """Signal emitted when the range has changed. + + It provides the new range as a 2-tuple: (max, delta) + """ + + def __init__(self, + parent=None, + amplitudeRange=None, + displayedRange=(None, 2)): + super(_AmplitudeRangeDialog, self).__init__(parent) + self.setWindowTitle('Set Displayed Amplitude Range') + + if amplitudeRange is not None: + amplitudeRange = min(amplitudeRange), max(amplitudeRange) + self._amplitudeRange = amplitudeRange + self._defaultDisplayedRange = displayedRange + + layout = qt.QFormLayout() + self.setLayout(layout) + + if self._amplitudeRange is not None: + min_, max_ = self._amplitudeRange + layout.addRow( + qt.QLabel('Data Amplitude Range: [%g, %g]' % (min_, max_))) + + self._maxLineEdit = FloatEdit(parent=self) + self._maxLineEdit.validator().setBottom(0.) + self._maxLineEdit.setAlignment(qt.Qt.AlignRight) + + self._maxLineEdit.editingFinished.connect(self._rangeUpdated) + layout.addRow('Displayed Max.:', self._maxLineEdit) + + self._autoscale = qt.QCheckBox('autoscale') + self._autoscale.toggled.connect(self._autoscaleCheckBoxToggled) + layout.addRow('', self._autoscale) + + self._deltaLineEdit = FloatEdit(parent=self) + self._deltaLineEdit.validator().setBottom(1.) + self._deltaLineEdit.setAlignment(qt.Qt.AlignRight) + self._deltaLineEdit.editingFinished.connect(self._rangeUpdated) + layout.addRow('Displayed delta (log10 unit):', self._deltaLineEdit) + + buttons = qt.QDialogButtonBox(self) + buttons.addButton(qt.QDialogButtonBox.Ok) + buttons.addButton(qt.QDialogButtonBox.Cancel) + buttons.accepted.connect(self.accept) + buttons.rejected.connect(self.reject) + layout.addRow(buttons) + + # Set dialog from default values + self._resetDialogToDefault() + + self.rejected.connect(self._handleRejected) + + def _resetDialogToDefault(self): + """Set Widgets of the dialog from range information + """ + max_, delta = self._defaultDisplayedRange + + if max_ is not None: # Not in autoscale + displayedMax = max_ + elif self._amplitudeRange is not None: # Autoscale with data + displayedMax = self._amplitudeRange[1] + else: # Autoscale without data + displayedMax = '' + if displayedMax == "": + self._maxLineEdit.setText("") + else: + self._maxLineEdit.setValue(displayedMax) + self._maxLineEdit.setEnabled(max_ is not None) + + self._deltaLineEdit.setValue(delta) + + self._autoscale.setChecked(self._defaultDisplayedRange[0] is None) + + def getRangeInfo(self): + """Returns the current range as a 2-tuple (max, delta (in log10))""" + if self._autoscale.isChecked(): + max_ = None + else: + maxStr = self._maxLineEdit.text() + max_ = self._maxLineEdit.value() if maxStr else None + return max_, self._deltaLineEdit.value() if self._deltaLineEdit.text() else 2 + + def _handleRejected(self): + """Reset range info to default when rejected""" + self._resetDialogToDefault() + self._rangeUpdated() + + def _rangeUpdated(self): + """Handle QLineEdit editing finised""" + self.sigRangeChanged.emit(self.getRangeInfo()) + + def _autoscaleCheckBoxToggled(self, checked): + """Handle autoscale checkbox state changes""" + if checked: # Use default values + if self._amplitudeRange is None: + max_ = '' + else: + max_ = self._amplitudeRange[1] + if max_ == "": + self._maxLineEdit.setText("") + else: + self._maxLineEdit.setValue(max_) + self._maxLineEdit.setEnabled(not checked) + self._rangeUpdated() + + +class _ComplexDataToolButton(qt.QToolButton): + """QToolButton providing choices of complex data visualization modes + + :param parent: See :class:`QToolButton` + :param plot: The :class:`ComplexImageView` to control + """ + + _MODES = collections.OrderedDict([ + (ImageComplexData.ComplexMode.ABSOLUTE, ('math-amplitude', 'Amplitude')), + (ImageComplexData.ComplexMode.SQUARE_AMPLITUDE, + ('math-square-amplitude', 'Square amplitude')), + (ImageComplexData.ComplexMode.PHASE, ('math-phase', 'Phase')), + (ImageComplexData.ComplexMode.REAL, ('math-real', 'Real part')), + (ImageComplexData.ComplexMode.IMAGINARY, + ('math-imaginary', 'Imaginary part')), + (ImageComplexData.ComplexMode.AMPLITUDE_PHASE, + ('math-phase-color', 'Amplitude and Phase')), + (ImageComplexData.ComplexMode.LOG10_AMPLITUDE_PHASE, + ('math-phase-color-log', 'Log10(Amp.) and Phase')) + ]) + + _RANGE_DIALOG_TEXT = 'Set Amplitude Range...' + + def __init__(self, parent=None, plot=None): + super(_ComplexDataToolButton, self).__init__(parent=parent) + + assert plot is not None + self._plot2DComplex = plot + + menu = qt.QMenu(self) + menu.triggered.connect(self._triggered) + self.setMenu(menu) + + for mode, info in self._MODES.items(): + icon, text = info + action = qt.QAction(icons.getQIcon(icon), text, self) + action.setData(mode) + action.setIconVisibleInMenu(True) + menu.addAction(action) + + self._rangeDialogAction = qt.QAction(self) + self._rangeDialogAction.setText(self._RANGE_DIALOG_TEXT) + menu.addAction(self._rangeDialogAction) + + self.setPopupMode(qt.QToolButton.InstantPopup) + + self._modeChanged(self._plot2DComplex.getComplexMode()) + self._plot2DComplex.sigVisualizationModeChanged.connect( + self._modeChanged) + + def _modeChanged(self, mode): + """Handle change of visualization modes""" + icon, text = self._MODES[mode] + self.setIcon(icons.getQIcon(icon)) + self.setToolTip('Display the ' + text.lower()) + self._rangeDialogAction.setEnabled( + mode == ImageComplexData.ComplexMode.LOG10_AMPLITUDE_PHASE) + + def _triggered(self, action): + """Handle triggering of menu actions""" + actionText = action.text() + + if actionText == self._RANGE_DIALOG_TEXT: # Show dialog + # Get amplitude range + data = self._plot2DComplex.getData(copy=False) + + if data.size > 0: + absolute = numpy.absolute(data) + dataRange = (numpy.nanmin(absolute), numpy.nanmax(absolute)) + else: + dataRange = None + + # Show dialog + dialog = _AmplitudeRangeDialog( + parent=self, + amplitudeRange=dataRange, + displayedRange=self._plot2DComplex._getAmplitudeRangeInfo()) + dialog.sigRangeChanged.connect(self._rangeChanged) + dialog.exec() + dialog.sigRangeChanged.disconnect(self._rangeChanged) + + else: # update mode + mode = action.data() + if isinstance(mode, ImageComplexData.ComplexMode): + self._plot2DComplex.setComplexMode(mode) + + def _rangeChanged(self, range_): + """Handle updates of range in the dialog""" + self._plot2DComplex._setAmplitudeRangeInfo(*range_) + + +class ComplexImageView(qt.QWidget): + """Display an image of complex data and allow to choose the visualization. + + :param parent: See :class:`QMainWindow` + """ + + ComplexMode = ImageComplexData.ComplexMode + """Complex Modes enumeration""" + + sigDataChanged = qt.Signal() + """Signal emitted when data has changed.""" + + sigVisualizationModeChanged = qt.Signal(object) + """Signal emitted when the visualization mode has changed. + + It provides the new visualization mode. + """ + + def __init__(self, parent=None): + super(ComplexImageView, self).__init__(parent) + if parent is None: + self.setWindowTitle('ComplexImageView') + + self._plot2D = Plot2D(self) + + layout = qt.QHBoxLayout(self) + layout.setSpacing(0) + layout.setContentsMargins(0, 0, 0, 0) + layout.addWidget(self._plot2D) + self.setLayout(layout) + + # Create and add image to the plot + self._plotImage = ImageComplexData() + self._plotImage.setName('__ComplexImageView__complex_image__') + self._plotImage.sigItemChanged.connect(self._itemChanged) + self._plot2D.addItem(self._plotImage) + self._plot2D.setActiveImage(self._plotImage.getName()) + + toolBar = qt.QToolBar('Complex', self) + toolBar.addWidget( + _ComplexDataToolButton(parent=self, plot=self)) + + self._plot2D.insertToolBar(self._plot2D.getProfileToolbar(), toolBar) + + def _itemChanged(self, event): + """Handle item changed signal""" + if event is items.ItemChangedType.DATA: + self.sigDataChanged.emit() + elif event is items.ItemChangedType.VISUALIZATION_MODE: + mode = self.getComplexMode() + self.sigVisualizationModeChanged.emit(mode) + + def getPlot(self): + """Return the PlotWidget displaying the data""" + return self._plot2D + + def setData(self, data=None, copy=True): + """Set the complex data to display. + + :param numpy.ndarray data: 2D complex data + :param bool copy: True (default) to copy the data, + False to use provided data (do not modify!). + """ + if data is None: + data = numpy.zeros((0, 0), dtype=numpy.complex64) + + previousData = self._plotImage.getComplexData(copy=False) + + self._plotImage.setData(data, copy=copy) + + if previousData.shape != data.shape: + self.getPlot().resetZoom() + + def getData(self, copy=True): + """Get the currently displayed complex data. + + :param bool copy: True (default) to return a copy of the data, + False to return internal data (do not modify!). + :return: The complex data array. + :rtype: numpy.ndarray of complex with 2 dimensions + """ + return self._plotImage.getComplexData(copy=copy) + + def getDisplayedData(self, copy=True): + """Returns the displayed data depending on the visualization mode + + WARNING: The returned data can be a uint8 RGBA image + + :param bool copy: True (default) to return a copy of the data, + False to return internal data (do not modify!) + :rtype: numpy.ndarray of float with 2 dims or RGBA image (uint8). + """ + mode = self.getComplexMode() + if mode in (self.ComplexMode.AMPLITUDE_PHASE, + self.ComplexMode.LOG10_AMPLITUDE_PHASE): + return self._plotImage.getRgbaImageData(copy=copy) + else: + return self._plotImage.getData(copy=copy) + + # Backward compatibility + + Mode = ComplexMode + + @classmethod + @deprecated(replacement='supportedComplexModes', since_version='0.11.0') + def getSupportedVisualizationModes(cls): + return cls.supportedComplexModes() + + @deprecated(replacement='setComplexMode', since_version='0.11.0') + def setVisualizationMode(self, mode): + return self.setComplexMode(mode) + + @deprecated(replacement='getComplexMode', since_version='0.11.0') + def getVisualizationMode(self): + return self.getComplexMode() + + # Image item proxy + + @staticmethod + def supportedComplexModes(): + """Returns the supported visualization modes. + + Supported visualization modes are: + + - amplitude: The absolute value provided by numpy.absolute + - phase: The phase (or argument) provided by numpy.angle + - real: Real part + - imaginary: Imaginary part + - amplitude_phase: Color-coded phase with amplitude as alpha. + - log10_amplitude_phase: + Color-coded phase with log10(amplitude) as alpha. + + :rtype: List[ComplexMode] + """ + return ImageComplexData.supportedComplexModes() + + def setComplexMode(self, mode): + """Set the mode of visualization of the complex data. + + See :meth:`supportedComplexModes` for the list of + supported modes. + + How-to change visualization mode:: + + widget = ComplexImageView() + widget.setComplexMode(ComplexImageView.ComplexMode.PHASE) + # or + widget.setComplexMode('phase') + + :param Unions[ComplexMode,str] mode: The mode to use. + """ + self._plotImage.setComplexMode(mode) + + def getComplexMode(self): + """Get the current visualization mode of the complex data. + + :rtype: ComplexMode + """ + return self._plotImage.getComplexMode() + + def _setAmplitudeRangeInfo(self, max_=None, delta=2): + """Set the amplitude range to display for 'log10_amplitude_phase' mode. + + :param max_: Max of the amplitude range. + If None it autoscales to data max. + :param float delta: Delta range in log10 to display + """ + self._plotImage._setAmplitudeRangeInfo(max_, delta) + + def _getAmplitudeRangeInfo(self): + """Returns the amplitude range to use for 'log10_amplitude_phase' mode. + + :return: (max, delta), if max is None, then it autoscales to data max + :rtype: 2-tuple""" + return self._plotImage._getAmplitudeRangeInfo() + + def setColormap(self, colormap, mode=None): + """Set the colormap to use for amplitude, phase, real or imaginary. + + WARNING: This colormap is not used when displaying both + amplitude and phase. + + :param ~silx.gui.colors.Colormap colormap: The colormap + :param ComplexMode mode: If specified, set the colormap of this specific mode + """ + self._plotImage.setColormap(colormap, mode) + + def getColormap(self, mode=None): + """Returns the colormap used to display the data. + + :param ComplexMode mode: If specified, set the colormap of this specific mode + :rtype: ~silx.gui.colors.Colormap + """ + return self._plotImage.getColormap(mode=mode) + + def getOrigin(self): + """Returns the offset from origin at which to display the image. + + :rtype: 2-tuple of float + """ + return self._plotImage.getOrigin() + + def setOrigin(self, origin): + """Set the offset from origin at which to display the image. + + :param origin: (ox, oy) Offset from origin + :type origin: float or 2-tuple of float + """ + self._plotImage.setOrigin(origin) + + def getScale(self): + """Returns the scale of the image in data coordinates. + + :rtype: 2-tuple of float + """ + return self._plotImage.getScale() + + def setScale(self, scale): + """Set the scale of the image + + :param scale: (sx, sy) Scale of the image + :type scale: float or 2-tuple of float + """ + self._plotImage.setScale(scale) + + # PlotWidget API proxy + + def getXAxis(self): + """Returns the X axis + + :rtype: :class:`.items.Axis` + """ + return self.getPlot().getXAxis() + + def getYAxis(self): + """Returns an Y axis + + :rtype: :class:`.items.Axis` + """ + return self.getPlot().getYAxis(axis='left') + + def getGraphTitle(self): + """Return the plot main title as a str.""" + return self.getPlot().getGraphTitle() + + def setGraphTitle(self, title=""): + """Set the plot main title. + + :param str title: Main title of the plot (default: '') + """ + self.getPlot().setGraphTitle(title) + + def setKeepDataAspectRatio(self, flag): + """Set whether the plot keeps data aspect ratio or not. + + :param bool flag: True to respect data aspect ratio + """ + self.getPlot().setKeepDataAspectRatio(flag) + + def isKeepDataAspectRatio(self): + """Returns whether the plot is keeping data aspect ratio or not.""" + return self.getPlot().isKeepDataAspectRatio() diff --git a/src/silx/gui/plot/CurvesROIWidget.py b/src/silx/gui/plot/CurvesROIWidget.py new file mode 100644 index 0000000..132d398 --- /dev/null +++ b/src/silx/gui/plot/CurvesROIWidget.py @@ -0,0 +1,1581 @@ +# coding: utf-8 +# /*########################################################################## +# +# Copyright (c) 2004-2021 European Synchrotron Radiation Facility +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. +# +# ###########################################################################*/ +""" +Widget to handle regions of interest (:class:`ROI`) on curves displayed in a +:class:`PlotWindow`. + +This widget is meant to work with :class:`PlotWindow`. +""" + +__authors__ = ["V.A. Sole", "T. Vincent", "H. Payno"] +__license__ = "MIT" +__date__ = "13/03/2018" + +from collections import OrderedDict +import logging +import os +import sys +import functools +import numpy +from silx.io import dictdump +from silx.utils import deprecation +from silx.utils.weakref import WeakMethodProxy +from silx.utils.proxy import docstring +from .. import icons, qt +from silx.math.combo import min_max +import weakref +from silx.gui.widgets.TableWidget import TableWidget +from . import items +from .items.roi import _RegionOfInterestBase + + +_logger = logging.getLogger(__name__) + + +class CurvesROIWidget(qt.QWidget): + """ + Widget displaying a table of ROI information. + + Implements also the following behavior: + + * if the roiTable has no ROI when showing create the default ICR one + + :param parent: See :class:`QWidget` + :param str name: The title of this widget + """ + + sigROIWidgetSignal = qt.Signal(object) + """Signal of ROIs modifications. + + Modification information if given as a dict with an 'event' key + providing the type of events. + + Type of events: + + - AddROI, DelROI, LoadROI and ResetROI with keys: 'roilist', 'roidict' + - selectionChanged with keys: 'row', 'col' 'roi', 'key', 'colheader', + 'rowheader' + """ + + sigROISignal = qt.Signal(object) + + def __init__(self, parent=None, name=None, plot=None): + super(CurvesROIWidget, self).__init__(parent) + if name is not None: + self.setWindowTitle(name) + self.__lastSigROISignal = None + """Store the last value emitted for the sigRoiSignal. In the case the + active curve change we need to add this extra step in order to make + sure we won't send twice the sigROISignal. + This come from the fact sigROISignal is connected to the + activeROIChanged signal which is emitted when raw and net counts + values are changing but are not embed in the sigROISignal. + """ + assert plot is not None + self._plotRef = weakref.ref(plot) + self._showAllMarkers = False + self.currentROI = None + + layout = qt.QVBoxLayout(self) + layout.setContentsMargins(0, 0, 0, 0) + layout.setSpacing(0) + + self.headerLabel = qt.QLabel(self) + self.headerLabel.setAlignment(qt.Qt.AlignHCenter) + self.setHeader() + layout.addWidget(self.headerLabel) + + widgetAllCheckbox = qt.QWidget(parent=self) + self._showAllCheckBox = qt.QCheckBox("show all ROI", + parent=widgetAllCheckbox) + widgetAllCheckbox.setLayout(qt.QHBoxLayout()) + spacer = qt.QWidget(parent=widgetAllCheckbox) + spacer.setSizePolicy(qt.QSizePolicy.Expanding, qt.QSizePolicy.Fixed) + widgetAllCheckbox.layout().addWidget(spacer) + widgetAllCheckbox.layout().addWidget(self._showAllCheckBox) + layout.addWidget(widgetAllCheckbox) + + self.roiTable = ROITable(self, plot=plot) + rheight = self.roiTable.horizontalHeader().sizeHint().height() + self.roiTable.setMinimumHeight(4 * rheight) + layout.addWidget(self.roiTable) + self._roiFileDir = qt.QDir.home().absolutePath() + self._showAllCheckBox.toggled.connect(self.roiTable.showAllMarkers) + + hbox = qt.QWidget(self) + hboxlayout = qt.QHBoxLayout(hbox) + hboxlayout.setContentsMargins(0, 0, 0, 0) + hboxlayout.setSpacing(0) + + hboxlayout.addStretch(0) + + self.addButton = qt.QPushButton(hbox) + self.addButton.setText("Add ROI") + self.addButton.setToolTip('Create a new ROI') + self.delButton = qt.QPushButton(hbox) + self.delButton.setText("Delete ROI") + self.addButton.setToolTip('Remove the selected ROI') + self.resetButton = qt.QPushButton(hbox) + self.resetButton.setText("Reset") + self.addButton.setToolTip('Clear all created ROIs. We only let the ' + 'default ROI') + + hboxlayout.addWidget(self.addButton) + hboxlayout.addWidget(self.delButton) + hboxlayout.addWidget(self.resetButton) + + hboxlayout.addStretch(0) + + self.loadButton = qt.QPushButton(hbox) + self.loadButton.setText("Load") + self.loadButton.setToolTip('Load ROIs from a .ini file') + self.saveButton = qt.QPushButton(hbox) + self.saveButton.setText("Save") + self.loadButton.setToolTip('Save ROIs to a .ini file') + hboxlayout.addWidget(self.loadButton) + hboxlayout.addWidget(self.saveButton) + layout.setStretchFactor(self.headerLabel, 0) + layout.setStretchFactor(self.roiTable, 1) + layout.setStretchFactor(hbox, 0) + + layout.addWidget(hbox) + + # Signal / Slot connections + self.addButton.clicked.connect(self._add) + self.delButton.clicked.connect(self._del) + self.resetButton.clicked.connect(self._reset) + + self.loadButton.clicked.connect(self._load) + self.saveButton.clicked.connect(self._save) + + self.roiTable.activeROIChanged.connect(self._emitCurrentROISignal) + + self._isConnected = False # True if connected to plot signals + self._isInit = False + + # expose API + self.getROIListAndDict = self.roiTable.getROIListAndDict + + def getPlotWidget(self): + """Returns the associated PlotWidget or None + + :rtype: Union[~silx.gui.plot.PlotWidget,None] + """ + return None if self._plotRef is None else self._plotRef() + + def showEvent(self, event): + self._visibilityChangedHandler(visible=True) + qt.QWidget.showEvent(self, event) + + @property + def roiFileDir(self): + """The directory from which to load/save ROI from/to files.""" + if not os.path.isdir(self._roiFileDir): + self._roiFileDir = qt.QDir.home().absolutePath() + return self._roiFileDir + + @roiFileDir.setter + def roiFileDir(self, roiFileDir): + self._roiFileDir = str(roiFileDir) + + def setRois(self, rois, order=None): + return self.roiTable.setRois(rois, order) + + def getRois(self, order=None): + return self.roiTable.getRois(order) + + def setMiddleROIMarkerFlag(self, flag=True): + return self.roiTable.setMiddleROIMarkerFlag(flag) + + def _add(self): + """Add button clicked handler""" + def getNextRoiName(): + rois = self.roiTable.getRois(order=None) + roisNames = [] + [roisNames.append(roiName) for roiName in rois] + nrois = len(rois) + if nrois == 0: + return "ICR" + else: + i = 1 + newroi = "newroi %d" % i + while newroi in roisNames: + i += 1 + newroi = "newroi %d" % i + return newroi + roi = ROI(name=getNextRoiName()) + + if roi.getName() == "ICR": + roi.setType("Default") + else: + roi.setType(self.getPlotWidget().getXAxis().getLabel()) + + xmin, xmax = self.getPlotWidget().getXAxis().getLimits() + fromdata = xmin + 0.25 * (xmax - xmin) + todata = xmin + 0.75 * (xmax - xmin) + if roi.isICR(): + fromdata, dummy0, todata, dummy1 = self._getAllLimits() + roi.setFrom(fromdata) + roi.setTo(todata) + self.roiTable.addRoi(roi) + + # back compatibility pymca roi signals + ddict = {} + ddict['event'] = "AddROI" + ddict['roilist'] = self.roiTable.roidict.values() + ddict['roidict'] = self.roiTable.roidict + self.sigROIWidgetSignal.emit(ddict) + # end back compatibility pymca roi signals + + def _del(self): + """Delete button clicked handler""" + self.roiTable.deleteActiveRoi() + + # back compatibility pymca roi signals + ddict = {} + ddict['event'] = "DelROI" + ddict['roilist'] = self.roiTable.roidict.values() + ddict['roidict'] = self.roiTable.roidict + self.sigROIWidgetSignal.emit(ddict) + # end back compatibility pymca roi signals + + def _reset(self): + """Reset button clicked handler""" + self.roiTable.clear() + old = self.blockSignals(True) # avoid several sigROISignal emission + self._add() + self.blockSignals(old) + + # back compatibility pymca roi signals + ddict = {} + ddict['event'] = "ResetROI" + ddict['roilist'] = self.roiTable.roidict.values() + ddict['roidict'] = self.roiTable.roidict + self.sigROIWidgetSignal.emit(ddict) + # end back compatibility pymca roi signals + + def _load(self): + """Load button clicked handler""" + dialog = qt.QFileDialog(self) + dialog.setNameFilters( + ['INI File *.ini', 'JSON File *.json', 'All *.*']) + dialog.setFileMode(qt.QFileDialog.ExistingFile) + dialog.setDirectory(self.roiFileDir) + if not dialog.exec(): + dialog.close() + return + + # pyflakes bug http://bugs.debian.org/cgi-bin/bugreport.cgi?bug=666494 + outputFile = dialog.selectedFiles()[0] + dialog.close() + + self.roiFileDir = os.path.dirname(outputFile) + self.roiTable.load(outputFile) + + # back compatibility pymca roi signals + ddict = {} + ddict['event'] = "LoadROI" + ddict['roilist'] = self.roiTable.roidict.values() + ddict['roidict'] = self.roiTable.roidict + self.sigROIWidgetSignal.emit(ddict) + # end back compatibility pymca roi signals + + def load(self, filename): + """Load ROI widget information from a file storing a dict of ROI. + + :param str filename: The file from which to load ROI + """ + self.roiTable.load(filename) + + def _save(self): + """Save button clicked handler""" + dialog = qt.QFileDialog(self) + dialog.setNameFilters(['INI File *.ini', 'JSON File *.json']) + dialog.setFileMode(qt.QFileDialog.AnyFile) + dialog.setAcceptMode(qt.QFileDialog.AcceptSave) + dialog.setDirectory(self.roiFileDir) + if not dialog.exec(): + dialog.close() + return + + outputFile = dialog.selectedFiles()[0] + extension = '.' + dialog.selectedNameFilter().split('.')[-1] + dialog.close() + + if not outputFile.endswith(extension): + outputFile += extension + + if os.path.exists(outputFile): + try: + os.remove(outputFile) + except IOError: + msg = qt.QMessageBox(self) + msg.setIcon(qt.QMessageBox.Critical) + msg.setText("Input Output Error: %s" % (sys.exc_info()[1])) + msg.exec() + return + self.roiFileDir = os.path.dirname(outputFile) + self.save(outputFile) + + def save(self, filename): + """Save current ROIs of the widget as a dict of ROI to a file. + + :param str filename: The file to which to save the ROIs + """ + self.roiTable.save(filename) + + def setHeader(self, text='ROIs'): + """Set the header text of this widget""" + self.headerLabel.setText("<b>%s<\b>" % text) + + @deprecation.deprecated(replacement="calculateRois", + reason="CamelCase convention", + since_version="0.7") + def calculateROIs(self, *args, **kw): + self.calculateRois(*args, **kw) + + def calculateRois(self, roiList=None, roiDict=None): + """Compute ROI information""" + return self.roiTable.calculateRois() + + def showAllMarkers(self, _show=True): + self.roiTable.showAllMarkers(_show) + + def _getAllLimits(self): + """Retrieve the limits based on the curves.""" + plot = self.getPlotWidget() + curves = () if plot is None else plot.getAllCurves() + if not curves: + return 1.0, 1.0, 100., 100. + + xmin, ymin = None, None + xmax, ymax = None, None + + for curve in curves: + x = curve.getXData(copy=False) + y = curve.getYData(copy=False) + if xmin is None: + xmin = x.min() + else: + xmin = min(xmin, x.min()) + if xmax is None: + xmax = x.max() + else: + xmax = max(xmax, x.max()) + if ymin is None: + ymin = y.min() + else: + ymin = min(ymin, y.min()) + if ymax is None: + ymax = y.max() + else: + ymax = max(ymax, y.max()) + + return xmin, ymin, xmax, ymax + + def showEvent(self, event): + self._visibilityChangedHandler(visible=True) + qt.QWidget.showEvent(self, event) + + def hideEvent(self, event): + self._visibilityChangedHandler(visible=False) + qt.QWidget.hideEvent(self, event) + + def _visibilityChangedHandler(self, visible): + """Handle widget's visibility updates. + + It is connected to plot signals only when visible. + """ + if visible: + # if no ROI existing yet, add the default one + if self.roiTable.rowCount() == 0: + old = self.blockSignals(True) # avoid several sigROISignal emission + self._add() + self.blockSignals(old) + self.calculateRois() + + def fillFromROIDict(self, *args, **kwargs): + self.roiTable.fillFromROIDict(*args, **kwargs) + + def _emitCurrentROISignal(self): + ddict = {} + ddict['event'] = "currentROISignal" + if self.roiTable.activeRoi is not None: + ddict['ROI'] = self.roiTable.activeRoi.toDict() + ddict['current'] = self.roiTable.activeRoi.getName() + else: + ddict['current'] = None + + if self.__lastSigROISignal != ddict: + self.__lastSigROISignal = ddict + self.sigROISignal.emit(ddict) + + @property + def currentRoi(self): + return self.roiTable.activeRoi + + +class _FloatItem(qt.QTableWidgetItem): + """ + Simple QTableWidgetItem overloading the < operator to deal with ordering + """ + def __init__(self): + qt.QTableWidgetItem.__init__(self, type=qt.QTableWidgetItem.Type) + + def __lt__(self, other): + if self.text() in ('', ROITable.INFO_NOT_FOUND): + return False + if other.text() in ('', ROITable.INFO_NOT_FOUND): + return True + return float(self.text()) < float(other.text()) + + +class ROITable(TableWidget): + """Table widget displaying ROI information. + + See :class:`QTableWidget` for constructor arguments. + + Behavior: listen at the active curve changed only when the widget is + visible. Otherwise won't compute the row and net counts... + """ + + activeROIChanged = qt.Signal() + """Signal emitted when the active roi changed or when the value of the + active roi are changing""" + + COLUMNS_INDEX = OrderedDict([ + ('ID', 0), + ('ROI', 1), + ('Type', 2), + ('From', 3), + ('To', 4), + ('Raw Counts', 5), + ('Net Counts', 6), + ('Raw Area', 7), + ('Net Area', 8), + ]) + + COLUMNS = list(COLUMNS_INDEX.keys()) + + INFO_NOT_FOUND = '????????' + + def __init__(self, parent=None, plot=None, rois=None): + super(ROITable, self).__init__(parent) + self._showAllMarkers = False + self._userIsEditingRoi = False + """bool used to avoid conflict when editing the ROI object""" + self._isConnected = False + self._roiToItems = {} + self._roiDict = {} + """dict of ROI object. Key is ROi id, value is the ROI object""" + self._markersHandler = _RoiMarkerManager() + + """ + Associate for each marker legend used when the `_showAllMarkers` option + is active a roi. + """ + self.setColumnCount(len(self.COLUMNS)) + self.setPlot(plot) + self.__setTooltip() + self.setSortingEnabled(True) + self.itemChanged.connect(self._itemChanged) + + @property + def roidict(self): + return self._getRoiDict() + + @property + def activeRoi(self): + return self._markersHandler._activeRoi + + def _getRoiDict(self): + ddict = {} + for id in self._roiDict: + ddict[self._roiDict[id].getName()] = self._roiDict[id] + return ddict + + def clear(self): + """ + .. note:: clear the interface only. keep the roidict... + """ + self._markersHandler.clear() + self._roiToItems = {} + self._roiDict = {} + + qt.QTableWidget.clear(self) + self.setRowCount(0) + self.setHorizontalHeaderLabels(self.COLUMNS) + header = self.horizontalHeader() + header.setSectionResizeMode(qt.QHeaderView.ResizeToContents) + self.sortByColumn(0, qt.Qt.AscendingOrder) + self.hideColumn(self.COLUMNS_INDEX['ID']) + + def setPlot(self, plot): + self.clear() + self.plot = plot + + def __setTooltip(self): + self.horizontalHeaderItem(self.COLUMNS_INDEX['ROI']).setToolTip( + 'Region of interest identifier') + self.horizontalHeaderItem(self.COLUMNS_INDEX['Type']).setToolTip( + 'Type of the ROI') + self.horizontalHeaderItem(self.COLUMNS_INDEX['From']).setToolTip( + 'X-value of the min point') + self.horizontalHeaderItem(self.COLUMNS_INDEX['To']).setToolTip( + 'X-value of the max point') + self.horizontalHeaderItem(self.COLUMNS_INDEX['Raw Counts']).setToolTip( + 'Estimation of the integral between y=0 and the selected curve') + self.horizontalHeaderItem(self.COLUMNS_INDEX['Net Counts']).setToolTip( + 'Estimation of the integral between the segment [maxPt, minPt] ' + 'and the selected curve') + + def setRois(self, rois, order=None): + """Set the ROIs by providing a dictionary of ROI information. + + The dictionary keys are the ROI names. + Each value is a sub-dictionary of ROI info with the following fields: + + - ``"from"``: x coordinate of the left limit, as a float + - ``"to"``: x coordinate of the right limit, as a float + - ``"type"``: type of ROI, as a string (e.g "channels", "energy") + + + :param roidict: Dictionary of ROIs + :param str order: Field used for ordering the ROIs. + One of "from", "to", "type". + None (default) for no ordering, or same order as specified + in parameter ``roidict`` if provided as an OrderedDict. + """ + assert order in [None, "from", "to", "type"] + self.clear() + + # backward compatibility since 0.10.0 + if isinstance(rois, dict): + for roiName, roi in rois.items(): + if isinstance(roi, ROI): + _roi = roi + else: + roi['name'] = roiName + _roi = ROI._fromDict(roi) + self.addRoi(_roi) + else: + for roi in rois: + assert isinstance(roi, ROI) + self.addRoi(roi) + self._updateMarkers() + + def addRoi(self, roi): + """ + + :param :class:`ROI` roi: roi to add to the table + """ + assert isinstance(roi, ROI) + self._getItem(name='ID', row=None, roi=roi) + self._roiDict[roi.getID()] = roi + self._markersHandler.add(roi, _RoiMarkerHandler(roi, self.plot)) + self._updateRoiInfo(roi.getID()) + callback = functools.partial(WeakMethodProxy(self._updateRoiInfo), + roi.getID()) + roi.sigChanged.connect(callback) + # set it as the active one + self.setActiveRoi(roi) + + def _getItem(self, name, row, roi): + if row: + item = self.item(row, self.COLUMNS_INDEX[name]) + else: + item = None + if item: + return item + else: + if name == 'ID': + assert roi + if roi.getID() in self._roiToItems: + return self._roiToItems[roi.getID()] + else: + # create a new row + row = self.rowCount() + self.setRowCount(self.rowCount() + 1) + item = qt.QTableWidgetItem(str(roi.getID()), + type=qt.QTableWidgetItem.Type) + self._roiToItems[roi.getID()] = item + elif name == 'ROI': + item = qt.QTableWidgetItem(roi.getName() if roi else '', + type=qt.QTableWidgetItem.Type) + if roi.getName().upper() in ('ICR', 'DEFAULT'): + item.setFlags(qt.Qt.ItemIsSelectable | qt.Qt.ItemIsEnabled) + else: + item.setFlags(qt.Qt.ItemIsSelectable | + qt.Qt.ItemIsEnabled | + qt.Qt.ItemIsEditable) + elif name == 'Type': + item = qt.QTableWidgetItem(type=qt.QTableWidgetItem.Type) + item.setFlags((qt.Qt.ItemIsSelectable | qt.Qt.ItemIsEnabled)) + elif name in ('To', 'From'): + item = _FloatItem() + if roi.getName().upper() in ('ICR', 'DEFAULT'): + item.setFlags(qt.Qt.ItemIsSelectable | qt.Qt.ItemIsEnabled) + else: + item.setFlags(qt.Qt.ItemIsSelectable | + qt.Qt.ItemIsEnabled | + qt.Qt.ItemIsEditable) + elif name in ('Raw Counts', 'Net Counts', 'Raw Area', 'Net Area'): + item = _FloatItem() + item.setFlags((qt.Qt.ItemIsSelectable | qt.Qt.ItemIsEnabled)) + else: + raise ValueError('item type not recognized') + + self.setItem(row, self.COLUMNS_INDEX[name], item) + return item + + def _itemChanged(self, item): + def getRoi(): + IDItem = self.item(item.row(), self.COLUMNS_INDEX['ID']) + assert IDItem + id = int(IDItem.text()) + assert id in self._roiDict + roi = self._roiDict[id] + return roi + + def signalChanged(roi): + if self.activeRoi and roi.getID() == self.activeRoi.getID(): + self.activeROIChanged.emit() + + self._userIsEditingRoi = True + if item.column() in (self.COLUMNS_INDEX['To'], self.COLUMNS_INDEX['From']): + roi = getRoi() + + if item.text() not in ('', self.INFO_NOT_FOUND): + try: + value = float(item.text()) + except ValueError: + value = 0 + changed = False + if item.column() == self.COLUMNS_INDEX['To']: + if value != roi.getTo(): + roi.setTo(value) + changed = True + else: + assert(item.column() == self.COLUMNS_INDEX['From']) + if value != roi.getFrom(): + roi.setFrom(value) + changed = True + if changed: + self._updateMarker(roi.getName()) + signalChanged(roi) + + if item.column() is self.COLUMNS_INDEX['ROI']: + roi = getRoi() + if roi.getName() != item.text(): + roi.setName(item.text()) + self._markersHandler.getMarkerHandler(roi.getID()).updateTexts() + signalChanged(roi) + + self._userIsEditingRoi = False + + def deleteActiveRoi(self): + """ + remove the current active roi + """ + activeItems = self.selectedItems() + if len(activeItems) == 0: + return + old = self.blockSignals(True) # avoid several emission of sigROISignal + roiToRm = set() + for item in activeItems: + row = item.row() + itemID = self.item(row, self.COLUMNS_INDEX['ID']) + roiToRm.add(self._roiDict[int(itemID.text())]) + [self.removeROI(roi) for roi in roiToRm] + self.blockSignals(old) + self.setActiveRoi(None) + + def removeROI(self, roi): + """ + remove the requested roi + + :param str name: the name of the roi to remove from the table + """ + if roi and roi.getID() in self._roiToItems: + item = self._roiToItems[roi.getID()] + self.removeRow(item.row()) + del self._roiToItems[roi.getID()] + + assert roi.getID() in self._roiDict + del self._roiDict[roi.getID()] + self._markersHandler.remove(roi) + + callback = functools.partial(WeakMethodProxy(self._updateRoiInfo), + roi.getID()) + roi.sigChanged.connect(callback) + + def setActiveRoi(self, roi): + """ + Define the given roi as the active one. + + .. warning:: this roi should already be registred / added to the table + + :param :class:`ROI` roi: the roi to defined as active + """ + if roi is None: + self.clearSelection() + self._markersHandler.setActiveRoi(None) + self.activeROIChanged.emit() + else: + assert isinstance(roi, ROI) + if roi and roi.getID() in self._roiToItems.keys(): + # avoid several call back to setActiveROI + old = self.blockSignals(True) + self.selectRow(self._roiToItems[roi.getID()].row()) + self.blockSignals(old) + self._markersHandler.setActiveRoi(roi) + self.activeROIChanged.emit() + + def _updateRoiInfo(self, roiID): + if self._userIsEditingRoi is True: + return + if roiID not in self._roiDict: + return + roi = self._roiDict[roiID] + if roi.isICR(): + activeCurve = self.plot.getActiveCurve() + if activeCurve: + xData = activeCurve.getXData() + if len(xData) > 0: + min, max = min_max(xData) + roi.blockSignals(True) + roi.setFrom(min) + roi.setTo(max) + roi.blockSignals(False) + + itemID = self._getItem(name='ID', roi=roi, row=None) + itemName = self._getItem(name='ROI', row=itemID.row(), roi=roi) + itemName.setText(roi.getName()) + + itemType = self._getItem(name='Type', row=itemID.row(), roi=roi) + itemType.setText(roi.getType() or self.INFO_NOT_FOUND) + + itemFrom = self._getItem(name='From', row=itemID.row(), roi=roi) + fromdata = str(roi.getFrom()) if roi.getFrom() is not None else self.INFO_NOT_FOUND + itemFrom.setText(fromdata) + + itemTo = self._getItem(name='To', row=itemID.row(), roi=roi) + todata = str(roi.getTo()) if roi.getTo() is not None else self.INFO_NOT_FOUND + itemTo.setText(todata) + + rawCounts, netCounts = roi.computeRawAndNetCounts( + curve=self.plot.getActiveCurve(just_legend=False)) + itemRawCounts = self._getItem(name='Raw Counts', row=itemID.row(), + roi=roi) + rawCounts = str(rawCounts) if rawCounts is not None else self.INFO_NOT_FOUND + itemRawCounts.setText(rawCounts) + + itemNetCounts = self._getItem(name='Net Counts', row=itemID.row(), + roi=roi) + netCounts = str(netCounts) if netCounts is not None else self.INFO_NOT_FOUND + itemNetCounts.setText(netCounts) + + rawArea, netArea = roi.computeRawAndNetArea( + curve=self.plot.getActiveCurve(just_legend=False)) + itemRawArea = self._getItem(name='Raw Area', row=itemID.row(), + roi=roi) + rawArea = str(rawArea) if rawArea is not None else self.INFO_NOT_FOUND + itemRawArea.setText(rawArea) + + itemNetArea = self._getItem(name='Net Area', row=itemID.row(), + roi=roi) + netArea = str(netArea) if netArea is not None else self.INFO_NOT_FOUND + itemNetArea.setText(netArea) + + if self.activeRoi and roi.getID() == self.activeRoi.getID(): + self.activeROIChanged.emit() + + def currentChanged(self, current, previous): + if previous and current.row() != previous.row() and current.row() >= 0: + roiItem = self.item(current.row(), + self.COLUMNS_INDEX['ID']) + + assert roiItem + self.setActiveRoi(self._roiDict[int(roiItem.text())]) + self._markersHandler.updateAllMarkers() + qt.QTableWidget.currentChanged(self, current, previous) + + @deprecation.deprecated(reason="Removed", + replacement="roidict and roidict.values()", + since_version="0.10.0") + def getROIListAndDict(self): + """ + + :return: the list of roi objects and the dictionary of roi name to roi + object. + """ + roidict = self._roiDict + return list(roidict.values()), roidict + + def calculateRois(self, roiList=None, roiDict=None): + """ + Update values of all registred rois (raw and net counts in particular) + + :param roiList: deprecated parameter + :param roiDict: deprecated parameter + """ + if roiDict: + deprecation.deprecated_warning(name='roiDict', type_='Parameter', + reason='Unused parameter', + since_version="0.10.0") + if roiList: + deprecation.deprecated_warning(name='roiList', type_='Parameter', + reason='Unused parameter', + since_version="0.10.0") + + for roiID in self._roiDict: + self._updateRoiInfo(roiID) + + def _updateMarker(self, roiID): + """Make sure the marker of the given roi name is updated""" + if self._showAllMarkers or (self.activeRoi + and self.activeRoi.getName() == roiID): + self._updateMarkers() + + def _updateMarkers(self): + if self._showAllMarkers is True: + self._markersHandler.updateMarkers() + else: + if not self.activeRoi or not self.plot: + return + assert isinstance(self.activeRoi, ROI) + markerHandler = self._markersHandler.getMarkerHandler(self.activeRoi.getID()) + if markerHandler is not None: + markerHandler.updateMarkers() + + def getRois(self, order): + """ + Return the currently defined ROIs, as an ordered dict. + + The dictionary keys are the ROI names. + Each value is a :class:`ROI` object.. + + :param order: Field used for ordering the ROIs. + One of "from", "to", "type", "netcounts", "rawcounts". + None (default) to get the same order as displayed in the widget. + :return: Ordered dictionary of ROI information + """ + + if order is None or order.lower() == "none": + ordered_roilist = list(self._roiDict.values()) + res = OrderedDict([(roi.getName(), self._roiDict[roi.getID()]) for roi in ordered_roilist]) + else: + assert order in ["from", "to", "type", "netcounts", "rawcounts"] + ordered_roilist = sorted(self._roiDict.keys(), + key=lambda roi_id: self._roiDict[roi_id].get(order)) + res = OrderedDict([(roi.getName(), self._roiDict[id]) for id in ordered_roilist]) + + return res + + def save(self, filename): + """ + Save current ROIs of the widget as a dict of ROI to a file. + + :param str filename: The file to which to save the ROIs + """ + roilist = [] + roidict = {} + for roiID, roi in self._roiDict.items(): + roilist.append(roi.toDict()) + roidict[roi.getName()] = roi.toDict() + datadict = {'ROI': {'roilist': roilist, 'roidict': roidict}} + dictdump.dump(datadict, filename) + + def load(self, filename): + """ + Load ROI widget information from a file storing a dict of ROI. + + :param str filename: The file from which to load ROI + """ + roisDict = dictdump.load(filename) + rois = [] + + # Remove rawcounts and netcounts from ROIs + for roiDict in roisDict['ROI']['roidict'].values(): + roiDict.pop('rawcounts', None) + roiDict.pop('netcounts', None) + rois.append(ROI._fromDict(roiDict)) + + self.setRois(rois) + + def showAllMarkers(self, _show=True): + """ + + :param bool _show: if true show all the markers of all the ROIs + boundaries otherwise will only show the one of + the active ROI. + """ + self._markersHandler.setShowAllMarkers(_show) + + def setMiddleROIMarkerFlag(self, flag=True): + """ + Activate or deactivate middle marker. + + This allows shifting both min and max limits at once, by dragging + a marker located in the middle. + + :param bool flag: True to activate middle ROI marker + """ + self._markersHandler._middleROIMarkerFlag = flag + + def _handleROIMarkerEvent(self, ddict): + """Handle plot signals related to marker events.""" + if ddict['event'] == 'markerMoved': + label = ddict['label'] + roiID = self._markersHandler.getRoiID(markerID=label) + if roiID is not None: + # avoid several emission of sigROISignal + old = self.blockSignals(True) + self._markersHandler.changePosition(markerID=label, + x=ddict['x']) + self.blockSignals(old) + self._updateRoiInfo(roiID) + + def showEvent(self, event): + self._visibilityChangedHandler(visible=True) + qt.QWidget.showEvent(self, event) + + def hideEvent(self, event): + self._visibilityChangedHandler(visible=False) + qt.QWidget.hideEvent(self, event) + + def _visibilityChangedHandler(self, visible): + """Handle widget's visibility updates. + + It is connected to plot signals only when visible. + """ + if visible: + assert self.plot + if self._isConnected is False: + self.plot.sigPlotSignal.connect(self._handleROIMarkerEvent) + self.plot.sigActiveCurveChanged.connect(self._activeCurveChanged) + self._isConnected = True + self.calculateRois() + else: + if self._isConnected: + self.plot.sigPlotSignal.disconnect(self._handleROIMarkerEvent) + self.plot.sigActiveCurveChanged.disconnect(self._activeCurveChanged) + self._isConnected = False + + def _activeCurveChanged(self, curve): + self.calculateRois() + + def setCountsVisible(self, visible): + """ + Display the columns relative to areas or not + + :param bool visible: True if the columns 'Raw Area' and 'Net Area' + should be visible. + """ + if visible is True: + self.showColumn(self.COLUMNS_INDEX['Raw Counts']) + self.showColumn(self.COLUMNS_INDEX['Net Counts']) + else: + self.hideColumn(self.COLUMNS_INDEX['Raw Counts']) + self.hideColumn(self.COLUMNS_INDEX['Net Counts']) + + def setAreaVisible(self, visible): + """ + Display the columns relative to areas or not + + :param bool visible: True if the columns 'Raw Area' and 'Net Area' + should be visible. + """ + if visible is True: + self.showColumn(self.COLUMNS_INDEX['Raw Area']) + self.showColumn(self.COLUMNS_INDEX['Net Area']) + else: + self.hideColumn(self.COLUMNS_INDEX['Raw Area']) + self.hideColumn(self.COLUMNS_INDEX['Net Area']) + + def fillFromROIDict(self, roilist=(), roidict=None, currentroi=None): + """ + This function API is kept for compatibility. + But `setRois` should be preferred. + + Set the ROIs by providing a list of ROI names and a dictionary + of ROI information for each ROI. + The ROI names must match an existing dictionary key. + The name list is used to provide an order for the ROIs. + The dictionary's values are sub-dictionaries containing 3 + mandatory fields: + + - ``"from"``: x coordinate of the left limit, as a float + - ``"to"``: x coordinate of the right limit, as a float + - ``"type"``: type of ROI, as a string (e.g "channels", "energy") + + :param roilist: List of ROI names (keys of roidict) + :type roilist: List + :param dict roidict: Dict of ROI information + :param currentroi: Name of the selected ROI or None (no selection) + """ + if roidict is not None: + self.setRois(roidict) + else: + self.setRois(roilist) + if currentroi: + self.setActiveRoi(currentroi) + + +_indexNextROI = 0 + + +class ROI(_RegionOfInterestBase): + """The Region Of Interest is defined by: + + - A name + - A type. The type is the label of the x axis. This can be used to apply or + not some ROI to a curve and do some post processing. + - The x coordinate of the left limit (fromdata) + - The x coordinate of the right limit (todata) + + :param str: name of the ROI + :param fromdata: left limit of the roi + :param todata: right limit of the roi + :param type: type of the ROI + """ + + sigChanged = qt.Signal() + """Signal emitted when the ROI is edited""" + + def __init__(self, name, fromdata=None, todata=None, type_=None): + _RegionOfInterestBase.__init__(self) + self.setName(name) + global _indexNextROI + self._id = _indexNextROI + _indexNextROI += 1 + + self._fromdata = fromdata + self._todata = todata + self._type = type_ or 'Default' + + self.sigItemChanged.connect(self.__itemChanged) + + def __itemChanged(self, event): + """Handle name change""" + if event == items.ItemChangedType.NAME: + self.sigChanged.emit() + + def getID(self): + """ + + :return int: the unique ID of the ROI + """ + return self._id + + def setType(self, type_): + """ + + :param str type_: + """ + if self._type != type_: + self._type = type_ + self.sigChanged.emit() + + def getType(self): + """ + + :return str: the type of the ROI. + """ + return self._type + + def setFrom(self, frm): + """ + + :param frm: set x coordinate of the left limit + """ + if self._fromdata != frm: + self._fromdata = frm + self.sigChanged.emit() + + def getFrom(self): + """ + + :return: x coordinate of the left limit + """ + return self._fromdata + + def setTo(self, to): + """ + + :param to: x coordinate of the right limit + """ + if self._todata != to: + self._todata = to + self.sigChanged.emit() + + def getTo(self): + """ + + :return: x coordinate of the right limit + """ + return self._todata + + def getMiddle(self): + """ + + :return: middle position between 'from' and 'to' values + """ + return 0.5 * (self.getFrom() + self.getTo()) + + def toDict(self): + """ + + :return: dict containing the roi parameters + """ + ddict = { + 'type': self._type, + 'name': self.getName(), + 'from': self._fromdata, + 'to': self._todata, + } + if hasattr(self, '_extraInfo'): + ddict.update(self._extraInfo) + return ddict + + @staticmethod + def _fromDict(dic): + assert 'name' in dic + roi = ROI(name=dic['name']) + roi._extraInfo = {} + for key in dic: + if key == 'from': + roi.setFrom(dic['from']) + elif key == 'to': + roi.setTo(dic['to']) + elif key == 'type': + roi.setType(dic['type']) + else: + roi._extraInfo[key] = dic[key] + + return roi + + def isICR(self): + """ + + :return: True if the ROI is the `ICR` + """ + return self.getName() == 'ICR' + + def computeRawAndNetCounts(self, curve): + """Compute the Raw and net counts in the ROI for the given curve. + + - Raw count: Points values sum of the curve in the defined Region Of + Interest. + + .. image:: img/rawCounts.png + + - Net count: Raw counts minus background + + .. image:: img/netCounts.png + + :param CurveItem curve: + :return tuple: rawCount, netCount + """ + assert isinstance(curve, items.Curve) or curve is None + + if curve is None: + return None, None + + x = curve.getXData(copy=False) + y = curve.getYData(copy=False) + + idx = numpy.nonzero((self._fromdata <= x) & + (x <= self._todata))[0] + if len(idx): + xw = x[idx] + yw = y[idx] + rawCounts = yw.sum(dtype=numpy.float64) + deltaX = xw[-1] - xw[0] + deltaY = yw[-1] - yw[0] + if deltaX > 0.0: + slope = (deltaY / deltaX) + background = yw[0] + slope * (xw - xw[0]) + netCounts = (rawCounts - + background.sum(dtype=numpy.float64)) + else: + netCounts = 0.0 + else: + rawCounts = 0.0 + netCounts = 0.0 + return rawCounts, netCounts + + def computeRawAndNetArea(self, curve): + """Compute the Raw and net counts in the ROI for the given curve. + + - Raw area: integral of the curve between the min ROI point and the + max ROI point to the y = 0 line. + + .. image:: img/rawArea.png + + - Net area: Raw counts minus background + + .. image:: img/netArea.png + + :param CurveItem curve: + :return tuple: rawArea, netArea + """ + assert isinstance(curve, items.Curve) or curve is None + + if curve is None: + return None, None + + x = curve.getXData(copy=False) + y = curve.getYData(copy=False) + + y = y[(x >= self._fromdata) & (x <= self._todata)] + x = x[(x >= self._fromdata) & (x <= self._todata)] + + if x.size == 0: + return 0.0, 0.0 + + rawArea = numpy.trapz(y, x=x) + # to speed up and avoid an intersection calculation we are taking the + # closest index to the ROI + closestXLeftIndex = (numpy.abs(x - self.getFrom())).argmin() + closestXRightIndex = (numpy.abs(x - self.getTo())).argmin() + yBackground = y[closestXLeftIndex], y[closestXRightIndex] + background = numpy.trapz(yBackground, x=x) + netArea = rawArea - background + return rawArea, netArea + + @docstring(_RegionOfInterestBase) + def contains(self, position): + return self._fromdata <= position[0] <= self._todata + + +class _RoiMarkerManager(object): + """ + Deal with all the ROI markers + """ + def __init__(self): + self._roiMarkerHandlers = {} + self._middleROIMarkerFlag = False + self._showAllMarkers = False + self._activeRoi = None + + def setActiveRoi(self, roi): + self._activeRoi = roi + self.updateAllMarkers() + + def setShowAllMarkers(self, show): + if show != self._showAllMarkers: + self._showAllMarkers = show + self.updateAllMarkers() + + def add(self, roi, markersHandler): + assert isinstance(roi, ROI) + assert isinstance(markersHandler, _RoiMarkerHandler) + if roi.getID() in self._roiMarkerHandlers: + raise ValueError('roi with the same ID already existing') + else: + self._roiMarkerHandlers[roi.getID()] = markersHandler + + def getMarkerHandler(self, roiID): + if roiID in self._roiMarkerHandlers: + return self._roiMarkerHandlers[roiID] + else: + return None + + def clear(self): + roisHandler = list(self._roiMarkerHandlers.values()) + for roiHandler in roisHandler: + self.remove(roiHandler.roi) + + def remove(self, roi): + if roi is None: + return + assert isinstance(roi, ROI) + if roi.getID() in self._roiMarkerHandlers: + self._roiMarkerHandlers[roi.getID()].clear() + del self._roiMarkerHandlers[roi.getID()] + + def hasMarker(self, markerID): + assert type(markerID) is str + return self.getMarker(markerID) is not None + + def changePosition(self, markerID, x): + markerHandler = self.getMarker(markerID) + if markerHandler is None: + raise ValueError('Marker %s not register' % markerID) + markerHandler.changePosition(markerID=markerID, x=x) + + def updateMarker(self, markerID): + markerHandler = self.getMarker(markerID) + if markerHandler is None: + raise ValueError('Marker %s not register' % markerID) + roiID = self.getRoiID(markerID) + visible = (self._activeRoi and self._activeRoi.getID() == roiID) or self._showAllMarkers is True + markerHandler.setVisible(visible) + markerHandler.updateAllMarkers() + + def updateRoiMarkers(self, roiID): + if roiID in self._roiMarkerHandlers: + visible = ((self._activeRoi and self._activeRoi.getID() == roiID) + or self._showAllMarkers is True) + _roi = self._roiMarkerHandlers[roiID]._roi() + if _roi and not _roi.isICR(): + self._roiMarkerHandlers[roiID].showMiddleMarker(self._middleROIMarkerFlag) + self._roiMarkerHandlers[roiID].setVisible(visible) + self._roiMarkerHandlers[roiID].updateMarkers() + + def getMarker(self, markerID): + assert type(markerID) is str + for marker in list(self._roiMarkerHandlers.values()): + if marker.hasMarker(markerID): + return marker + + def updateMarkers(self): + for markerHandler in list(self._roiMarkerHandlers.values()): + markerHandler.updateMarkers() + + def getRoiID(self, markerID): + for roiID, markerHandler in self._roiMarkerHandlers.items(): + if markerHandler.hasMarker(markerID): + return roiID + return None + + def setShowMiddleMarkers(self, show): + self._middleROIMarkerFlag = show + self._roiMarkerHandlers.updateAllMarkers() + + def updateAllMarkers(self): + for roiID in self._roiMarkerHandlers: + self.updateRoiMarkers(roiID) + + def getVisibleRois(self): + res = {} + for roiID, roiHandler in self._roiMarkerHandlers.items(): + markers = (roiHandler.getMarker('min'), roiHandler.getMarker('max'), + roiHandler.getMarker('middle')) + for marker in markers: + if marker.isVisible(): + if roiID not in res: + res[roiID] = [] + res[roiID].append(marker) + return res + + +class _RoiMarkerHandler(object): + """Used to deal with ROI markers used in ROITable""" + def __init__(self, roi, plot): + assert roi and isinstance(roi, ROI) + assert plot + + self._roi = weakref.ref(roi) + self._plot = weakref.ref(plot) + self._draggable = False if roi.isICR() else True + self._color = 'black' if roi.isICR() else 'blue' + self._displayMidMarker = False + self._visible = True + + @property + def draggable(self): + return self._draggable + + @property + def plot(self): + return self._plot() + + def clear(self): + if self.plot and self.roi: + self.plot.removeMarker(self._markerID('min')) + self.plot.removeMarker(self._markerID('max')) + self.plot.removeMarker(self._markerID('middle')) + + @property + def roi(self): + return self._roi() + + def setVisible(self, visible): + if visible != self._visible: + self._visible = visible + self.updateMarkers() + + def showMiddleMarker(self, visible): + if self.draggable is False and visible is True: + _logger.warning("ROI is not draggable. Won't display middle marker") + return + self._displayMidMarker = visible + self.getMarker('middle').setVisible(self._displayMidMarker) + + def updateMarkers(self): + if self.roi is None: + return + self._updateMinMarkerPos() + self._updateMaxMarkerPos() + self._updateMiddleMarkerPos() + + def _updateMinMarkerPos(self): + self.getMarker('min').setPosition(x=self.roi.getFrom(), y=None) + self.getMarker('min').setVisible(self._visible) + + def _updateMaxMarkerPos(self): + self.getMarker('max').setPosition(x=self.roi.getTo(), y=None) + self.getMarker('max').setVisible(self._visible) + + def _updateMiddleMarkerPos(self): + self.getMarker('middle').setPosition(x=self.roi.getMiddle(), y=None) + self.getMarker('middle').setVisible(self._displayMidMarker and self._visible) + + def getMarker(self, markerType): + if self.plot is None: + return None + assert markerType in ('min', 'max', 'middle') + if self.plot._getMarker(self._markerID(markerType)) is None: + assert self.roi + if markerType == 'min': + val = self.roi.getFrom() + elif markerType == 'max': + val = self.roi.getTo() + else: + val = self.roi.getMiddle() + + _color = self._color + if markerType == 'middle': + _color = 'yellow' + self.plot.addXMarker(val, + legend=self._markerID(markerType), + text=self.getMarkerName(markerType), + color=_color, + draggable=self.draggable) + return self.plot._getMarker(self._markerID(markerType)) + + def _markerID(self, markerType): + assert markerType in ('min', 'max', 'middle') + assert self.roi + return '_'.join((str(self.roi.getID()), markerType)) + + def getMarkerName(self, markerType): + assert markerType in ('min', 'max', 'middle') + assert self.roi + return ' '.join((self.roi.getName(), markerType)) + + def updateTexts(self): + self.getMarker('min').setText(self.getMarkerName('min')) + self.getMarker('max').setText(self.getMarkerName('max')) + self.getMarker('middle').setText(self.getMarkerName('middle')) + + def changePosition(self, markerID, x): + assert self.hasMarker(markerID) + markerType = self._getMarkerType(markerID) + assert markerType is not None + if self.roi is None: + return + if markerType == 'min': + self.roi.setFrom(x) + self._updateMiddleMarkerPos() + elif markerType == 'max': + self.roi.setTo(x) + self._updateMiddleMarkerPos() + else: + delta = x - 0.5 * (self.roi.getFrom() + self.roi.getTo()) + self.roi.setFrom(self.roi.getFrom() + delta) + self.roi.setTo(self.roi.getTo() + delta) + self._updateMinMarkerPos() + self._updateMaxMarkerPos() + + def hasMarker(self, marker): + return marker in (self._markerID('min'), + self._markerID('max'), + self._markerID('middle')) + + def _getMarkerType(self, markerID): + if markerID.endswith('_min'): + return 'min' + elif markerID.endswith('_max'): + return 'max' + elif markerID.endswith('_middle'): + return 'middle' + else: + return None + + +class CurvesROIDockWidget(qt.QDockWidget): + """QDockWidget with a :class:`CurvesROIWidget` connected to a PlotWindow. + + It makes the link between the :class:`CurvesROIWidget` and the PlotWindow. + + :param parent: See :class:`QDockWidget` + :param plot: :class:`.PlotWindow` instance on which to operate + :param name: See :class:`QDockWidget` + """ + sigROISignal = qt.Signal(object) + """Deprecated signal for backward compatibility with silx < 0.7. + Prefer connecting directly to :attr:`CurvesRoiWidget.sigRoiSignal` + """ + + def __init__(self, parent=None, plot=None, name=None): + super(CurvesROIDockWidget, self).__init__(name, parent) + + assert plot is not None + self.plot = plot + self.roiWidget = CurvesROIWidget(self, name, plot=plot) + """Main widget of type :class:`CurvesROIWidget`""" + + # convenience methods to offer a simpler API allowing to ignore + # the details of the underlying implementation + # (ALLÂ DEPRECATED) + self.calculateROIs = self.calculateRois = self.roiWidget.calculateRois + self.setRois = self.roiWidget.setRois + self.getRois = self.roiWidget.getRois + + self.roiWidget.sigROISignal.connect(self._forwardSigROISignal) + + self.layout().setContentsMargins(0, 0, 0, 0) + self.setWidget(self.roiWidget) + + self.setAreaVisible = self.roiWidget.roiTable.setAreaVisible + self.setCountsVisible = self.roiWidget.roiTable.setCountsVisible + + def _forwardSigROISignal(self, ddict): + # emit deprecated signal for backward compatibility (silx < 0.7) + self.sigROISignal.emit(ddict) + + def toggleViewAction(self): + """Returns a checkable action that shows or closes this widget. + + See :class:`QMainWindow`. + """ + action = super(CurvesROIDockWidget, self).toggleViewAction() + action.setIcon(icons.getQIcon('plot-roi')) + return action + + def showEvent(self, event): + """Make sure this widget is raised when it is shown + (when it is first created as a tab in PlotWindow or when it is shown + again after hiding). + """ + self.raise_() + qt.QDockWidget.showEvent(self, event) + + @property + def currentROI(self): + return self.roiWidget.currentRoi diff --git a/src/silx/gui/plot/ImageStack.py b/src/silx/gui/plot/ImageStack.py new file mode 100644 index 0000000..1588a31 --- /dev/null +++ b/src/silx/gui/plot/ImageStack.py @@ -0,0 +1,640 @@ +# coding: utf-8 +# /*########################################################################## +# +# Copyright (c) 2020-2021 European Synchrotron Radiation Facility +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. +# +# ###########################################################################*/ +"""Image stack view with data prefetch capabilty.""" + +__authors__ = ["H. Payno"] +__license__ = "MIT" +__date__ = "04/03/2019" + + +from silx.gui import icons, qt +from silx.gui.plot import Plot2D +from silx.gui.utils import concurrent +from silx.io.url import DataUrl +from silx.io.utils import get_data +from collections import OrderedDict +from silx.gui.widgets.FrameBrowser import HorizontalSliderWithBrowser +import time +import threading +import typing +import logging + +_logger = logging.getLogger(__name__) + + +class _PlotWithWaitingLabel(qt.QWidget): + """Image plot widget with an overlay 'waiting' status. + """ + + class AnimationThread(threading.Thread): + def __init__(self, label): + self.running = True + self._label = label + self.animated_icon = icons.getWaitIcon() + self.animated_icon.register(self._label) + super(_PlotWithWaitingLabel.AnimationThread, self).__init__() + + def run(self): + while self.running: + time.sleep(0.05) + icon = self.animated_icon.currentIcon() + self.future_result = concurrent.submitToQtMainThread( + self._label.setPixmap, icon.pixmap(30, state=qt.QIcon.On)) + + def stop(self): + """Stop the update thread""" + if self.running: + self.animated_icon.unregister(self._label) + self.running = False + self.join(2) + + def __init__(self, parent): + super(_PlotWithWaitingLabel, self).__init__(parent=parent) + self._autoResetZoom = True + layout = qt.QStackedLayout(self) + layout.setStackingMode(qt.QStackedLayout.StackAll) + + self._waiting_label = qt.QLabel(parent=self) + self._waiting_label.setAlignment(qt.Qt.AlignHCenter | qt.Qt.AlignVCenter) + layout.addWidget(self._waiting_label) + + self._plot = Plot2D(parent=self) + layout.addWidget(self._plot) + + self.updateThread = _PlotWithWaitingLabel.AnimationThread(self._waiting_label) + self.updateThread.start() + + def close(self) -> bool: + super(_PlotWithWaitingLabel, self).close() + self.stopUpdateThread() + + def stopUpdateThread(self): + self.updateThread.stop() + + def setAutoResetZoom(self, reset): + """ + Should we reset the zoom when adding an image (eq. when browsing) + + :param bool reset: + """ + self._autoResetZoom = reset + if self._autoResetZoom: + self._plot.resetZoom() + + def isAutoResetZoom(self): + """ + + :return: True if a reset is done when the image change + :rtype: bool + """ + return self._autoResetZoom + + def setWaiting(self, activate=True): + if activate is True: + self._plot.clear() + self._waiting_label.show() + else: + self._waiting_label.hide() + + def setData(self, data): + self.setWaiting(activate=False) + self._plot.addImage(data=data, resetzoom=self._autoResetZoom) + + def clear(self): + self._plot.clear() + self.setWaiting(False) + + def getPlotWidget(self): + return self._plot + + +class _HorizontalSlider(HorizontalSliderWithBrowser): + + sigCurrentUrlIndexChanged = qt.Signal(int) + + def __init__(self, parent): + super(_HorizontalSlider, self).__init__(parent=parent) + # connect signal / slot + self.valueChanged.connect(self._urlChanged) + + def setUrlIndex(self, index): + self.setValue(index) + self.sigCurrentUrlIndexChanged.emit(index) + + def _urlChanged(self, value): + self.sigCurrentUrlIndexChanged.emit(value) + + +class UrlList(qt.QWidget): + """List of URLs the user to select an URL""" + + sigCurrentUrlChanged = qt.Signal(str) + """Signal emitted when the active/current url change""" + + def __init__(self, parent=None): + super(UrlList, self).__init__(parent) + self.setLayout(qt.QVBoxLayout()) + self.layout().setSpacing(0) + self.layout().setContentsMargins(0, 0, 0, 0) + self._listWidget = qt.QListWidget(parent=self) + self.layout().addWidget(self._listWidget) + + # connect signal / Slot + self._listWidget.currentItemChanged.connect(self._notifyCurrentUrlChanged) + + # expose API + self.currentItem = self._listWidget.currentItem + + def setUrls(self, urls: list) -> None: + url_names = [] + [url_names.append(url.path()) for url in urls] + self._listWidget.addItems(url_names) + + def _notifyCurrentUrlChanged(self, current, previous): + if current is None: + pass + else: + self.sigCurrentUrlChanged.emit(current.text()) + + def setUrl(self, url: DataUrl) -> None: + assert isinstance(url, DataUrl) + sel_items = self._listWidget.findItems(url.path(), qt.Qt.MatchExactly) + if sel_items is None: + _logger.warning(url.path(), ' is not registered in the list.') + elif len(sel_items) > 0: + item = sel_items[0] + self._listWidget.setCurrentItem(item) + self.sigCurrentUrlChanged.emit(item.text()) + + def clear(self): + self._listWidget.clear() + + +class _ToggleableUrlSelectionTable(qt.QWidget): + + _BUTTON_ICON = qt.QStyle.SP_ToolBarHorizontalExtensionButton # noqa + + sigCurrentUrlChanged = qt.Signal(str) + """Signal emitted when the active/current url change""" + + def __init__(self, parent=None) -> None: + qt.QWidget.__init__(self, parent) + self.setLayout(qt.QGridLayout()) + self._toggleButton = qt.QPushButton(parent=self) + self.layout().addWidget(self._toggleButton, 0, 2, 1, 1) + self._toggleButton.setSizePolicy(qt.QSizePolicy.Fixed, + qt.QSizePolicy.Fixed) + + self._urlsTable = UrlList(parent=self) + self.layout().addWidget(self._urlsTable, 1, 1, 1, 2) + + # set up + self._setButtonIcon(show=True) + + # Signal / slot connection + self._toggleButton.clicked.connect(self.toggleUrlSelectionTable) + self._urlsTable.sigCurrentUrlChanged.connect(self._propagateSignal) + + # expose API + self.setUrls = self._urlsTable.setUrls + self.setUrl = self._urlsTable.setUrl + self.currentItem = self._urlsTable.currentItem + + def toggleUrlSelectionTable(self): + visible = not self.urlSelectionTableIsVisible() + self._setButtonIcon(show=visible) + self._urlsTable.setVisible(visible) + + def _setButtonIcon(self, show): + style = qt.QApplication.instance().style() + # return a QIcon + icon = style.standardIcon(self._BUTTON_ICON) + if show is False: + pixmap = icon.pixmap(32, 32).transformed(qt.QTransform().scale(-1, 1)) + icon = qt.QIcon(pixmap) + self._toggleButton.setIcon(icon) + + def urlSelectionTableIsVisible(self): + return self._urlsTable.isVisible() + + def _propagateSignal(self, url): + self.sigCurrentUrlChanged.emit(url) + + def clear(self): + self._urlsTable.clear() + + +class UrlLoader(qt.QThread): + """ + Thread use to load DataUrl + """ + def __init__(self, parent, url): + super(UrlLoader, self).__init__(parent=parent) + assert isinstance(url, DataUrl) + self.url = url + self.data = None + + def run(self): + try: + self.data = get_data(self.url) + except IOError: + self.data = None + + +class ImageStack(qt.QMainWindow): + """Widget loading on the fly images contained the given urls. + + It prefetches images close to the displayed one. + """ + + N_PRELOAD = 10 + + sigLoaded = qt.Signal(str) + """Signal emitted when new data is available""" + + sigCurrentUrlChanged = qt.Signal(str) + """Signal emitted when the current url change""" + + def __init__(self, parent=None) -> None: + super(ImageStack, self).__init__(parent) + self.__n_prefetch = ImageStack.N_PRELOAD + self._loadingThreads = [] + self.setWindowFlags(qt.Qt.Widget) + self._current_url = None + self._url_loader = UrlLoader + "class to instantiate for loading urls" + + # main widget + self._plot = _PlotWithWaitingLabel(parent=self) + self._plot.setAttribute(qt.Qt.WA_DeleteOnClose, True) + self.setWindowTitle("Image stack") + self.setCentralWidget(self._plot) + + # dock widget: url table + self._tableDockWidget = qt.QDockWidget(parent=self) + self._urlsTable = _ToggleableUrlSelectionTable(parent=self) + self._tableDockWidget.setWidget(self._urlsTable) + self._tableDockWidget.setFeatures(qt.QDockWidget.DockWidgetMovable) + self.addDockWidget(qt.Qt.RightDockWidgetArea, self._tableDockWidget) + # dock widget: qslider + self._sliderDockWidget = qt.QDockWidget(parent=self) + self._slider = _HorizontalSlider(parent=self) + self._sliderDockWidget.setWidget(self._slider) + self.addDockWidget(qt.Qt.BottomDockWidgetArea, self._sliderDockWidget) + self._sliderDockWidget.setFeatures(qt.QDockWidget.DockWidgetMovable) + + self.reset() + + # connect signal / slot + self._urlsTable.sigCurrentUrlChanged.connect(self.setCurrentUrl) + self._slider.sigCurrentUrlIndexChanged.connect(self.setCurrentUrlIndex) + + def close(self) -> bool: + self._freeLoadingThreads() + self._plot.close() + super(ImageStack, self).close() + + def setUrlLoaderClass(self, urlLoader: typing.Type[UrlLoader]) -> None: + """ + + :param urlLoader: define the class to call for loading urls. + warning: this should be a class object and not a + class instance. + """ + assert isinstance(urlLoader, type(UrlLoader)) + self._url_loader = urlLoader + + def getUrlLoaderClass(self): + """ + + :return: class to instantiate for loading urls + :rtype: typing.Type[UrlLoader] + """ + return self._url_loader + + def _freeLoadingThreads(self): + for thread in self._loadingThreads: + thread.blockSignals(True) + thread.wait(5) + self._loadingThreads.clear() + + def getPlotWidget(self) -> Plot2D: + """ + Returns the PlotWidget contained in this window + + :return: PlotWidget contained in this window + :rtype: Plot2D + """ + return self._plot.getPlotWidget() + + def reset(self) -> None: + """Clear the plot and remove any link to url""" + self._freeLoadingThreads() + self._urls = None + self._urlIndexes = None + self._urlData = OrderedDict({}) + self._current_url = None + self._plot.clear() + self._urlsTable.clear() + self._slider.setMaximum(-1) + + def _preFetch(self, urls: list) -> None: + """Pre-fetch the given urls if necessary + + :param urls: list of DataUrl to prefetch + :type: list + """ + for url in urls: + if url.path() not in self._urlData: + self._load(url) + + def _load(self, url): + """ + Launch background load of a DataUrl + + :param url: + :type: DataUrl + """ + assert isinstance(url, DataUrl) + url_path = url.path() + assert url_path in self._urlIndexes + loader = self._url_loader(parent=self, url=url) + loader.finished.connect(self._urlLoaded, qt.Qt.QueuedConnection) + self._loadingThreads.append(loader) + loader.start() + + def _urlLoaded(self) -> None: + """ + + :param url: restul of DataUrl.path() function + :return: + """ + sender = self.sender() + assert isinstance(sender, UrlLoader) + url = sender.url.path() + if url in self._urlIndexes: + self._urlData[url] = sender.data + if self.getCurrentUrl().path() == url: + self._plot.setData(self._urlData[url]) + if sender in self._loadingThreads: + self._loadingThreads.remove(sender) + self.sigLoaded.emit(url) + + def setNPrefetch(self, n: int) -> None: + """ + Define the number of url to prefetch around + + :param int n: number of url to prefetch on left and right sides. + In total n*2 DataUrl will be prefetch + """ + self.__n_prefetch = n + current_url = self.getCurrentUrl() + if current_url is not None: + self.setCurrentUrl(current_url) + + def getNPrefetch(self) -> int: + """ + + :return: number of url to prefetch on left and right sides. In total + will load 2* NPrefetch DataUrls + """ + return self.__n_prefetch + + def setUrls(self, urls: list) -> None: + """list of urls within an index. Warning: urls should contain an image + compatible with the silx.gui.plot.Plot class + + :param urls: urls we want to set in the stack. Key is the index + (position in the stack), value is the DataUrl + :type: list + """ + def createUrlIndexes(): + indexes = OrderedDict() + for index, url in enumerate(urls): + indexes[index] = url + return indexes + + urls_with_indexes = createUrlIndexes() + urlsToIndex = self._urlsToIndex(urls_with_indexes) + self.reset() + self._urls = urls_with_indexes + self._urlIndexes = urlsToIndex + + old_url_table = self._urlsTable.blockSignals(True) + self._urlsTable.setUrls(urls=list(self._urls.values())) + self._urlsTable.blockSignals(old_url_table) + + old_slider = self._slider.blockSignals(True) + self._slider.setMinimum(0) + self._slider.setMaximum(len(self._urls) - 1) + self._slider.blockSignals(old_slider) + + if self.getCurrentUrl() in self._urls: + self.setCurrentUrl(self.getCurrentUrl()) + else: + if len(self._urls.keys()) > 0: + first_url = self._urls[list(self._urls.keys())[0]] + self.setCurrentUrl(first_url) + + def getUrls(self) -> tuple: + """ + + :return: tuple of urls + :rtype: tuple + """ + return tuple(self._urlIndexes.keys()) + + def _getNextUrl(self, url: DataUrl) -> typing.Union[None, DataUrl]: + """ + return the next url in the stack + + :param url: url for which we want the next url + :type: DataUrl + :return: next url in the stack or None if `url` is the last one + :rtype: Union[None, DataUrl] + """ + assert isinstance(url, DataUrl) + if self._urls is None: + return None + else: + index = self._urlIndexes[url.path()] + indexes = list(self._urls.keys()) + res = list(filter(lambda x: x > index, indexes)) + if len(res) == 0: + return None + else: + return self._urls[res[0]] + + def _getPreviousUrl(self, url: DataUrl) -> typing.Union[None, DataUrl]: + """ + return the previous url in the stack + + :param url: url for which we want the previous url + :type: DataUrl + :return: next url in the stack or None if `url` is the last one + :rtype: Union[None, DataUrl] + """ + if self._urls is None: + return None + else: + index = self._urlIndexes[url.path()] + indexes = list(self._urls.keys()) + res = list(filter(lambda x: x < index, indexes)) + if len(res) == 0: + return None + else: + return self._urls[res[-1]] + + def _getNNextUrls(self, n: int, url: DataUrl) -> list: + """ + Deduce the next urls in the stack after `url` + + :param n: the number of url store after `url` + :type: int + :param url: url for which we want n next url + :type: DataUrl + :return: list of next urls. + :rtype: list + """ + res = [] + next_free = self._getNextUrl(url=url) + while len(res) < n and next_free is not None: + assert isinstance(next_free, DataUrl) + res.append(next_free) + next_free = self._getNextUrl(res[-1]) + return res + + def _getNPreviousUrls(self, n: int, url: DataUrl): + """ + Deduce the previous urls in the stack after `url` + + :param n: the number of url store after `url` + :type: int + :param url: url for which we want n previous url + :type: DataUrl + :return: list of previous urls. + :rtype: list + """ + res = [] + next_free = self._getPreviousUrl(url=url) + while len(res) < n and next_free is not None: + res.insert(0, next_free) + next_free = self._getPreviousUrl(res[0]) + return res + + def setCurrentUrlIndex(self, index: int): + """ + Define the url to be displayed + + :param index: url to be displayed + :type: int + """ + if index < 0: + return + if self._urls is None: + return + elif index >= len(self._urls): + raise ValueError('requested index out of bounds') + else: + return self.setCurrentUrl(self._urls[index]) + + def setCurrentUrl(self, url: typing.Union[DataUrl, str]) -> None: + """ + Define the url to be displayed + + :param url: url to be displayed + :type: DataUrl + """ + assert isinstance(url, (DataUrl, str)) + if isinstance(url, str): + url = DataUrl(path=url) + if url != self._current_url: + self._current_url = url + self.sigCurrentUrlChanged.emit(url.path()) + + old_url_table = self._urlsTable.blockSignals(True) + old_slider = self._slider.blockSignals(True) + + self._urlsTable.setUrl(url) + self._slider.setUrlIndex(self._urlIndexes[url.path()]) + if self._current_url is None: + self._plot.clear() + else: + if self._current_url.path() in self._urlData: + self._plot.setData(self._urlData[url.path()]) + else: + self._load(url) + self._notifyLoading() + self._preFetch(self._getNNextUrls(self.__n_prefetch, url)) + self._preFetch(self._getNPreviousUrls(self.__n_prefetch, url)) + self._urlsTable.blockSignals(old_url_table) + self._slider.blockSignals(old_slider) + + def getCurrentUrl(self) -> typing.Union[None, DataUrl]: + """ + + :return: url currently displayed + :rtype: Union[None, DataUrl] + """ + return self._current_url + + def getCurrentUrlIndex(self) -> typing.Union[None, int]: + """ + + :return: index of the url currently displayed + :rtype: Union[None, int] + """ + if self._current_url is None: + return None + else: + return self._urlIndexes[self._current_url.path()] + + @staticmethod + def _urlsToIndex(urls): + """util, return a dictionary with url as key and index as value""" + res = {} + for index, url in urls.items(): + res[url.path()] = index + return res + + def _notifyLoading(self): + """display a simple image of loading...""" + self._plot.setWaiting(activate=True) + + def setAutoResetZoom(self, reset): + """ + Should we reset the zoom when adding an image (eq. when browsing) + + :param bool reset: + """ + self._plot.setAutoResetZoom(reset) + + def isAutoResetZoom(self) -> bool: + """ + + :return: True if a reset is done when the image change + :rtype: bool + """ + return self._plot.isAutoResetZoom() diff --git a/src/silx/gui/plot/ImageView.py b/src/silx/gui/plot/ImageView.py new file mode 100644 index 0000000..f8b830a --- /dev/null +++ b/src/silx/gui/plot/ImageView.py @@ -0,0 +1,1057 @@ +# coding: utf-8 +# /*########################################################################## +# +# Copyright (c) 2015-2021 European Synchrotron Radiation Facility +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. +# +# ###########################################################################*/ +"""QWidget displaying a 2D image with histograms on its sides. + +The :class:`ImageView` implements this widget, and +:class:`ImageViewMainWindow` provides a main window with additional toolbar +and status bar. + +Basic usage of :class:`ImageView` is through the following methods: + +- :meth:`ImageView.getColormap`, :meth:`ImageView.setColormap` to update the + default colormap to use and update the currently displayed image. +- :meth:`ImageView.setImage` to update the displayed image. + +For an example of use, see `imageview.py` in :ref:`sample-code`. +""" + +from __future__ import division + + +__authors__ = ["T. Vincent"] +__license__ = "MIT" +__date__ = "26/04/2018" + + +import logging +import numpy +import collections +from typing import Union + +import silx +from .. import qt +from .. import colors +from .. import icons + +from . import items, PlotWindow, PlotWidget, actions +from ..colors import Colormap +from ..colors import cursorColorForColormap +from .tools import LimitsToolBar +from .Profile import ProfileToolBar +from ...utils.proxy import docstring +from ...utils.deprecation import deprecated +from ...utils.enum import Enum +from .tools.RadarView import RadarView +from .utils.axis import SyncAxes +from ..utils import blockSignals +from . import _utils +from .tools.profile import manager +from .tools.profile import rois +from .actions import PlotAction + +_logger = logging.getLogger(__name__) + + +ProfileSumResult = collections.namedtuple("ProfileResult", + ["dataXRange", "dataYRange", + 'histoH', 'histoHRange', + 'histoV', 'histoVRange', + "xCoords", "xData", + "yCoords", "yData"]) + + +def computeProfileSumOnRange(imageItem, xRange, yRange, cache=None): + """ + Compute a full vertical and horizontal profile on an image item using a + a range in the plot referential. + + Optionally takes a previous computed result to be able to skip the + computation. + + :rtype: ProfileSumResult + """ + data = imageItem.getValueData(copy=False) + origin = imageItem.getOrigin() + scale = imageItem.getScale() + height, width = data.shape + + xMin, xMax = xRange + yMin, yMax = yRange + + # Convert plot area limits to image coordinates + # and work in image coordinates (i.e., in pixels) + xMin = int((xMin - origin[0]) / scale[0]) + xMax = int((xMax - origin[0]) / scale[0]) + yMin = int((yMin - origin[1]) / scale[1]) + yMax = int((yMax - origin[1]) / scale[1]) + + if (xMin >= width or xMax < 0 or + yMin >= height or yMax < 0): + return None + + # The image is at least partly in the plot area + # Get the visible bounds in image coords (i.e., in pixels) + subsetXMin = 0 if xMin < 0 else xMin + subsetXMax = (width if xMax >= width else xMax) + 1 + subsetYMin = 0 if yMin < 0 else yMin + subsetYMax = (height if yMax >= height else yMax) + 1 + + if cache is not None: + if ((subsetXMin, subsetXMax) == cache.dataXRange and + (subsetYMin, subsetYMax) == cache.dataYRange): + # The visible area of data is the same + return cache + + # Rebuild histograms for visible area + visibleData = data[subsetYMin:subsetYMax, + subsetXMin:subsetXMax] + histoHVisibleData = numpy.nansum(visibleData, axis=0) + histoVVisibleData = numpy.nansum(visibleData, axis=1) + histoHMin = numpy.nanmin(histoHVisibleData) + histoHMax = numpy.nanmax(histoHVisibleData) + histoVMin = numpy.nanmin(histoVVisibleData) + histoVMax = numpy.nanmax(histoVVisibleData) + + # Convert to histogram curve and update plots + # Taking into account origin and scale + coords = numpy.arange(2 * histoHVisibleData.size) + xCoords = (coords + 1) // 2 + subsetXMin + xCoords = origin[0] + scale[0] * xCoords + xData = numpy.take(histoHVisibleData, coords // 2) + coords = numpy.arange(2 * histoVVisibleData.size) + yCoords = (coords + 1) // 2 + subsetYMin + yCoords = origin[1] + scale[1] * yCoords + yData = numpy.take(histoVVisibleData, coords // 2) + + result = ProfileSumResult( + dataXRange=(subsetXMin, subsetXMax), + dataYRange=(subsetYMin, subsetYMax), + histoH=histoHVisibleData, + histoHRange=(histoHMin, histoHMax), + histoV=histoVVisibleData, + histoVRange=(histoVMin, histoVMax), + xCoords=xCoords, + xData=xData, + yCoords=yCoords, + yData=yData) + + return result + + +class _SideHistogram(PlotWidget): + """ + Widget displaying one of the side profile of the ImageView. + + Implement ProfileWindow + """ + + sigClose = qt.Signal() + + sigMouseMoved = qt.Signal(float, float) + + def __init__(self, parent=None, backend=None, direction=qt.Qt.Horizontal): + super(_SideHistogram, self).__init__(parent=parent, backend=backend) + self._direction = direction + self.sigPlotSignal.connect(self._plotEvents) + self._color = "blue" + self.__profile = None + self.__profileSum = None + + def _plotEvents(self, eventDict): + """Callback for horizontal histogram plot events.""" + if eventDict['event'] == 'mouseMoved': + self.sigMouseMoved.emit(eventDict['x'], eventDict['y']) + + def setProfileColor(self, color): + self._color = color + + def setProfileSum(self, result): + self.__profileSum = result + if self.__profile is None: + self.__drawProfileSum() + + def prepareWidget(self, roi): + """Implements `ProfileWindow`""" + pass + + def setRoiProfile(self, roi): + """Implements `ProfileWindow`""" + if roi is None: + return + self._roiColor = colors.rgba(roi.getColor()) + + def getProfile(self): + """Implements `ProfileWindow`""" + return self.__profile + + def setProfile(self, data): + """Implements `ProfileWindow`""" + self.__profile = data + if data is None: + self.__drawProfileSum() + else: + self.__drawProfile() + + def __drawProfileSum(self): + """Only draw the profile sum on the plot. + + Other elements are removed + """ + profileSum = self.__profileSum + + try: + self.removeCurve('profile') + except Exception: + pass + + if profileSum is None: + try: + self.removeCurve('profilesum') + except Exception: + pass + return + + if self._direction == qt.Qt.Horizontal: + xx, yy = profileSum.xCoords, profileSum.xData + elif self._direction == qt.Qt.Vertical: + xx, yy = profileSum.yData, profileSum.yCoords + else: + assert False + + self.addCurve(xx, yy, + xlabel='', ylabel='', + legend="profilesum", + color=self._color, + linestyle='-', + selectable=False, + resetzoom=False) + + self.__updateLimits() + + def __drawProfile(self): + """Only draw the profile on the plot. + + Other elements are removed + """ + profile = self.__profile + + try: + self.removeCurve('profilesum') + except Exception: + pass + + if profile is None: + try: + self.removeCurve('profile') + except Exception: + pass + self.setProfileSum(self.__profileSum) + return + + if self._direction == qt.Qt.Horizontal: + xx, yy = profile.coords, profile.profile + elif self._direction == qt.Qt.Vertical: + xx, yy = profile.profile, profile.coords + else: + assert False + + self.addCurve(xx, + yy, + legend="profile", + color=self._roiColor, + resetzoom=False) + + self.__updateLimits() + + def __updateLimits(self): + if self.__profile: + data = self.__profile.profile + vMin = numpy.nanmin(data) + vMax = numpy.nanmax(data) + elif self.__profileSum is not None: + if self._direction == qt.Qt.Horizontal: + vMin, vMax = self.__profileSum.histoHRange + elif self._direction == qt.Qt.Vertical: + vMin, vMax = self.__profileSum.histoVRange + else: + assert False + else: + vMin, vMax = 0, 0 + + # Tune the result using the data margins + margins = self.getDataMargins() + if self._direction == qt.Qt.Horizontal: + _, _, vMin, vMax = _utils.addMarginsToLimits(margins, False, False, 0, 0, vMin, vMax) + elif self._direction == qt.Qt.Vertical: + vMin, vMax, _, _ = _utils.addMarginsToLimits(margins, False, False, vMin, vMax, 0, 0) + else: + assert False + + if self._direction == qt.Qt.Horizontal: + dataAxis = self.getYAxis() + elif self._direction == qt.Qt.Vertical: + dataAxis = self.getXAxis() + else: + assert False + + with blockSignals(dataAxis): + dataAxis.setLimits(vMin, vMax) + + +class ShowSideHistogramsAction(PlotAction): + """QAction to change visibility of side histogram of a :class:`.ImageView`. + + :param plot: :class:`.ImageView` instance on which to operate + :param parent: See :class:`QAction` + """ + + def __init__(self, plot, parent=None): + super(ShowSideHistogramsAction, self).__init__( + plot, icon='side-histograms', text='Show/hide side histograms', + tooltip='Show/hide side histogram', + triggered=self._actionTriggered, + checkable=True, parent=parent) + + def _actionTriggered(self, checked=False): + if self.plot.isSideHistogramDisplayed() != checked: + self.plot.setSideHistogramDisplayed(checked) + + +class AggregationModeAction(qt.QWidgetAction): + """Action providing few filters to the image""" + + sigAggregationModeChanged = qt.Signal() + + def __init__(self, parent): + qt.QWidgetAction.__init__(self, parent) + + toolButton = qt.QToolButton(parent) + + filterAction = qt.QAction(self) + filterAction.setText("No filter") + filterAction.setCheckable(True) + filterAction.setChecked(True) + filterAction.setProperty("aggregation", items.ImageDataAggregated.Aggregation.NONE) + densityNoFilterAction = filterAction + + filterAction = qt.QAction(self) + filterAction.setText("Max filter") + filterAction.setCheckable(True) + filterAction.setProperty("aggregation", items.ImageDataAggregated.Aggregation.MAX) + densityMaxFilterAction = filterAction + + filterAction = qt.QAction(self) + filterAction.setText("Mean filter") + filterAction.setCheckable(True) + filterAction.setProperty("aggregation", items.ImageDataAggregated.Aggregation.MEAN) + densityMeanFilterAction = filterAction + + filterAction = qt.QAction(self) + filterAction.setText("Min filter") + filterAction.setCheckable(True) + filterAction.setProperty("aggregation", items.ImageDataAggregated.Aggregation.MIN) + densityMinFilterAction = filterAction + + densityGroup = qt.QActionGroup(self) + densityGroup.setExclusive(True) + densityGroup.addAction(densityNoFilterAction) + densityGroup.addAction(densityMaxFilterAction) + densityGroup.addAction(densityMeanFilterAction) + densityGroup.addAction(densityMinFilterAction) + densityGroup.triggered.connect(self._aggregationModeChanged) + self.__densityGroup = densityGroup + + filterMenu = qt.QMenu(toolButton) + filterMenu.addAction(densityNoFilterAction) + filterMenu.addAction(densityMaxFilterAction) + filterMenu.addAction(densityMeanFilterAction) + filterMenu.addAction(densityMinFilterAction) + + toolButton.setPopupMode(qt.QToolButton.InstantPopup) + toolButton.setMenu(filterMenu) + toolButton.setText("Data filters") + toolButton.setToolTip("Enable/disable filter on the image") + icon = icons.getQIcon("aggregation-mode") + toolButton.setIcon(icon) + toolButton.setText("Pixel aggregation filter") + + self.setDefaultWidget(toolButton) + + def _aggregationModeChanged(self): + self.sigAggregationModeChanged.emit() + + def setAggregationMode(self, mode): + """Set an Aggregated enum from ImageDataAggregated""" + for a in self.__densityGroup.actions(): + if a.property("aggregation") is mode: + a.setChecked(True) + + def getAggregationMode(self): + """Returns an Aggregated enum from ImageDataAggregated""" + densityAction = self.__densityGroup.checkedAction() + if densityAction is None: + return items.ImageDataAggregated.Aggregation.NONE + return densityAction.property("aggregation") + + +class ImageView(PlotWindow): + """Display a single image with horizontal and vertical histograms. + + Use :meth:`setImage` to control the displayed image. + This class also provides the :class:`silx.gui.plot.Plot` API. + + The :class:`ImageView` inherits from :class:`.PlotWindow` (which provides + the toolbars) and also exposes :class:`.PlotWidget` API for further + plot control (plot title, axes labels, aspect ratio, ...). + + :param parent: The parent of this widget or None. + :param backend: The backend to use for the plot (default: matplotlib). + See :class:`.PlotWidget` for the list of supported backend. + :type backend: str or :class:`BackendBase.BackendBase` + """ + + HISTOGRAMS_COLOR = 'blue' + """Color to use for the side histograms.""" + + HISTOGRAMS_HEIGHT = 200 + """Height in pixels of the side histograms.""" + + IMAGE_MIN_SIZE = 200 + """Minimum size in pixels of the image area.""" + + # Qt signals + valueChanged = qt.Signal(float, float, object) + """Signals that the data value under the cursor has changed. + + It provides: row, column, data value. + + When the cursor is over an histogram, either row or column is Nan + and the provided data value is the histogram value + (i.e., the sum along the corresponding row/column). + Row and columns are either Nan or integer values. + """ + + class ProfileWindowBehavior(Enum): + """ImageView's profile window behavior options""" + + POPUP = 'popup' + """All profiles are displayed in pop-up windows""" + + EMBEDDED = 'embedded' + """Horizontal, vertical and cross profiles are displayed in + sides widgets, others are displayed in pop-up windows. + """ + + def __init__(self, parent=None, backend=None): + self._imageLegend = '__ImageView__image' + str(id(self)) + self._cache = None # Store currently visible data information + + super(ImageView, self).__init__(parent=parent, backend=backend, + resetzoom=True, autoScale=False, + logScale=False, grid=False, + curveStyle=False, colormap=True, + aspectRatio=True, yInverted=True, + copy=True, save=True, print_=True, + control=False, position=False, + roi=False, mask=True) + + # Enable mask synchronisation to use it in profiles + maskToolsWidget = self.getMaskToolsDockWidget().widget() + maskToolsWidget.setItemMaskUpdated(True) + + self.__showSideHistogramsAction = ShowSideHistogramsAction(self, self) + self.__showSideHistogramsAction.setChecked(True) + + self.__aggregationModeAction = AggregationModeAction(self) + self.__aggregationModeAction.sigAggregationModeChanged.connect(self._aggregationModeChanged) + + if parent is None: + self.setWindowTitle('ImageView') + + if silx.config.DEFAULT_PLOT_IMAGE_Y_AXIS_ORIENTATION == 'downward': + self.getYAxis().setInverted(True) + + self._initWidgets(backend) + + toolBar = self.toolBar() + toolBar.addAction(self.__showSideHistogramsAction) + toolBar.addAction(self.__aggregationModeAction) + + self.__profileWindowBehavior = self.ProfileWindowBehavior.POPUP + self.__profile = ProfileToolBar(plot=self) + self.addToolBar(self.__profile) + + def _initWidgets(self, backend): + """Set-up layout and plots.""" + self._histoHPlot = _SideHistogram(backend=backend, parent=self, direction=qt.Qt.Horizontal) + widgetHandle = self._histoHPlot.getWidgetHandle() + widgetHandle.setMinimumHeight(self.HISTOGRAMS_HEIGHT) + widgetHandle.setMaximumHeight(self.HISTOGRAMS_HEIGHT) + self._histoHPlot.setInteractiveMode('zoom') + self._histoHPlot.setDataMargins(0., 0., 0.1, 0.1) + self._histoHPlot.sigMouseMoved.connect(self._mouseMovedOnHistoH) + self._histoHPlot.setProfileColor(self.HISTOGRAMS_COLOR) + + self._histoVPlot = _SideHistogram(backend=backend, parent=self, direction=qt.Qt.Vertical) + widgetHandle = self._histoVPlot.getWidgetHandle() + widgetHandle.setMinimumWidth(self.HISTOGRAMS_HEIGHT) + widgetHandle.setMaximumWidth(self.HISTOGRAMS_HEIGHT) + self._histoVPlot.setInteractiveMode('zoom') + self._histoVPlot.setDataMargins(0.1, 0.1, 0., 0.) + self._histoVPlot.sigMouseMoved.connect(self._mouseMovedOnHistoV) + self._histoVPlot.setProfileColor(self.HISTOGRAMS_COLOR) + + self.setPanWithArrowKeys(True) + self.setInteractiveMode('zoom') # Color set in setColormap + self.sigPlotSignal.connect(self._imagePlotCB) + self.sigActiveImageChanged.connect(self._activeImageChangedSlot) + + self._radarView = RadarView(parent=self) + self._radarView.setPlotWidget(self) + + self.__syncXAxis = SyncAxes([self.getXAxis(), self._histoHPlot.getXAxis()]) + self.__syncYAxis = SyncAxes([self.getYAxis(), self._histoVPlot.getYAxis()]) + + self.__setCentralWidget() + + def __setCentralWidget(self): + """Set central widget with all its content""" + layout = qt.QGridLayout() + layout.addWidget(self.getWidgetHandle(), 0, 0) + layout.addWidget(self._histoVPlot, 0, 1) + layout.addWidget(self._histoHPlot, 1, 0) + layout.addWidget(self._radarView, 1, 1, 1, 2) + layout.addWidget(self.getColorBarWidget(), 0, 2) + + self._radarView.setMinimumWidth(self.IMAGE_MIN_SIZE) + self._radarView.setMinimumHeight(self.HISTOGRAMS_HEIGHT) + self._histoHPlot.setMinimumWidth(self.IMAGE_MIN_SIZE) + self._histoVPlot.setMinimumHeight(self.HISTOGRAMS_HEIGHT) + + layout.setColumnStretch(0, 1) + layout.setColumnStretch(1, 0) + layout.setRowStretch(0, 1) + layout.setRowStretch(1, 0) + + layout.setSpacing(0) + layout.setContentsMargins(0, 0, 0, 0) + + centralWidget = qt.QWidget(self) + centralWidget.setLayout(layout) + self.setCentralWidget(centralWidget) + + @docstring(PlotWidget) + def setBackend(self, backend): + # Use PlotWidget here since we override PlotWindow behavior + PlotWidget.setBackend(self, backend) + self.__setCentralWidget() + + def _dirtyCache(self): + self._cache = None + + def getAggregationModeAction(self): + return self.__aggregationModeAction + + def _aggregationModeChanged(self): + item = self._getItem("image", self._imageLegend) + if item is None: + return + aggregationMode = self.__aggregationModeAction.getAggregationMode() + if aggregationMode is not None and isinstance(item, items.ImageDataAggregated): + item.setAggregationMode(aggregationMode) + else: + # It means the item type have to be changed + self.removeImage(self._imageLegend) + image = item.getData(copy=False) + if image is None: + return + origin = item.getOrigin() + scale = item.getScale() + self.setImage(image, origin, scale, copy=False, resetzoom=False) + + def getShowSideHistogramsAction(self): + return self.__showSideHistogramsAction + + def setSideHistogramDisplayed(self, show): + """Display or not the side histograms""" + if self.isSideHistogramDisplayed() == show: + return + self._histoHPlot.setVisible(show) + self._histoVPlot.setVisible(show) + self._radarView.setVisible(show) + self.__showSideHistogramsAction.setChecked(show) + if show: + # Probably have to be computed + self._updateHistograms() + + def isSideHistogramDisplayed(self): + """True if the side histograms are displayed""" + return self._histoHPlot.isVisible() + + def _updateHistograms(self): + """Update histograms content using current active image.""" + if not self.isSideHistogramDisplayed(): + # The histogram computation can be skipped + return + + activeImage = self.getActiveImage() + if activeImage is not None: + xRange = self.getXAxis().getLimits() + yRange = self.getYAxis().getLimits() + result = computeProfileSumOnRange(activeImage, xRange, yRange, self._cache) + self._cache = result + self._histoHPlot.setProfileSum(result) + self._histoVPlot.setProfileSum(result) + + # Plots event listeners + + def _imagePlotCB(self, eventDict): + """Callback for imageView plot events.""" + if eventDict['event'] == 'mouseMoved': + activeImage = self.getActiveImage() + if activeImage is not None: + data = activeImage.getData(copy=False) + height, width = data.shape[0:2] + + # Get corresponding coordinate in image + origin = activeImage.getOrigin() + scale = activeImage.getScale() + if (eventDict['x'] >= origin[0] and + eventDict['y'] >= origin[1]): + x = int((eventDict['x'] - origin[0]) / scale[0]) + y = int((eventDict['y'] - origin[1]) / scale[1]) + + if x >= 0 and x < width and y >= 0 and y < height: + self.valueChanged.emit(float(x), float(y), + data[y][x]) + + elif eventDict['event'] == 'limitsChanged': + self._updateHistograms() + + def _mouseMovedOnHistoH(self, x, y): + if self._cache is None: + return + activeImage = self.getActiveImage() + if activeImage is None: + return + + xOrigin = activeImage.getOrigin()[0] + xScale = activeImage.getScale()[0] + + minValue = xOrigin + xScale * self._cache.dataXRange[0] + + if x >= minValue: + data = self._cache.histoH + column = int((x - minValue) / xScale) + if column >= 0 and column < data.shape[0]: + self.valueChanged.emit( + float('nan'), + float(column + self._cache.dataXRange[0]), + data[column]) + + def _mouseMovedOnHistoV(self, x, y): + if self._cache is None: + return + activeImage = self.getActiveImage() + if activeImage is None: + return + + yOrigin = activeImage.getOrigin()[1] + yScale = activeImage.getScale()[1] + + minValue = yOrigin + yScale * self._cache.dataYRange[0] + + if y >= minValue: + data = self._cache.histoV + row = int((y - minValue) / yScale) + if row >= 0 and row < data.shape[0]: + self.valueChanged.emit( + float(row + self._cache.dataYRange[0]), + float('nan'), + data[row]) + + def _activeImageChangedSlot(self, previous, legend): + """Handle Plot active image change. + + Resets side histograms cache + """ + self._dirtyCache() + self._updateHistograms() + + def setProfileWindowBehavior(self, behavior: Union[str, ProfileWindowBehavior]): + """Set where profile widgets are displayed. + + :param ProfileWindowBehavior behavior: + - 'popup': All profiles are displayed in pop-up windows + - 'embedded': Horizontal, vertical and cross profiles are displayed in + sides widgets, others are displayed in pop-up windows. + """ + behavior = self.ProfileWindowBehavior.from_value(behavior) + if behavior is not self.getProfileWindowBehavior(): + manager = self.__profile.getProfileManager() + manager.clearProfile() + manager.requestUpdateAllProfile() + + if behavior is self.ProfileWindowBehavior.EMBEDDED: + horizontalProfileWindow = self._histoHPlot + verticalProfileWindow = self._histoVPlot + else: + horizontalProfileWindow = None + verticalProfileWindow = None + + manager.setSpecializedProfileWindow( + rois.ProfileImageHorizontalLineROI, horizontalProfileWindow + ) + manager.setSpecializedProfileWindow( + rois.ProfileImageVerticalLineROI, verticalProfileWindow + ) + self.__profileWindowBehavior = behavior + + def getProfileWindowBehavior(self) -> ProfileWindowBehavior: + """Returns current profile display behavior. + + See :meth:`setProfileWindowBehavior` and :class:`ProfileWindowBehavior` + """ + return self.__profileWindowBehavior + + def getProfileToolBar(self): + """"Returns profile tools attached to this plot. + + :rtype: silx.gui.plot.PlotTools.ProfileToolBar + """ + return self.__profile + + @property + @deprecated(replacement="getProfileToolBar()") + def profile(self): + return self.getProfileToolBar() + + def getHistogram(self, axis): + """Return the histogram and corresponding row or column extent. + + The returned value when an histogram is available is a dict with keys: + + - 'data': numpy array of the histogram values. + - 'extent': (start, end) row or column index. + end index is not included in the histogram. + + :param str axis: 'x' for horizontal, 'y' for vertical + :return: The histogram and its extent as a dict or None. + :rtype: dict + """ + assert axis in ('x', 'y') + if self._cache is None: + return None + else: + if axis == 'x': + return dict( + data=numpy.array(self._cache.histoH, copy=True), + extent=self._cache.dataXRange) + else: + return dict( + data=numpy.array(self._cache.histoV, copy=True), + extent=(self._cache.dataYRange)) + + def radarView(self): + """Get the lower right radarView widget.""" + return self._radarView + + def setRadarView(self, radarView): + """Change the lower right radarView widget. + + :param RadarView radarView: Widget subclassing RadarView to replace + the lower right corner widget. + """ + self._radarView = radarView + self._radarView.setPlotWidget(self) + self.centralWidget().layout().addWidget(self._radarView, 1, 1) + + # High-level API + + def getColormap(self): + """Get the default colormap description. + + :return: A description of the current colormap. + See :meth:`setColormap` for details. + :rtype: dict + """ + return self.getDefaultColormap() + + def setColormap(self, colormap=None, normalization=None, + autoscale=None, vmin=None, vmax=None, colors=None): + """Set the default colormap and update active image. + + Parameters that are not provided are taken from the current colormap. + + The colormap parameter can also be a dict with the following keys: + + - *name*: string. The colormap to use: + 'gray', 'reversed gray', 'temperature', 'red', 'green', 'blue'. + - *normalization*: string. The mapping to use for the colormap: + either 'linear' or 'log'. + - *autoscale*: bool. Whether to use autoscale (True) + or range provided by keys 'vmin' and 'vmax' (False). + - *vmin*: float. The minimum value of the range to use if 'autoscale' + is False. + - *vmax*: float. The maximum value of the range to use if 'autoscale' + is False. + - *colors*: optional. Nx3 or Nx4 array of float in [0, 1] or uint8. + List of RGB or RGBA colors to use (only if name is None) + + :param colormap: Name of the colormap in + 'gray', 'reversed gray', 'temperature', 'red', 'green', 'blue'. + Or the description of the colormap as a dict. + :type colormap: dict or str. + :param str normalization: Colormap mapping: 'linear' or 'log'. + :param bool autoscale: Whether to use autoscale (True) + or [vmin, vmax] range (False). + :param float vmin: The minimum value of the range to use if + 'autoscale' is False. + :param float vmax: The maximum value of the range to use if + 'autoscale' is False. + :param numpy.ndarray colors: Only used if name is None. + Custom colormap colors as Nx3 or Nx4 RGB or RGBA arrays + """ + cmap = self.getDefaultColormap() + + if isinstance(colormap, Colormap): + # Replace colormap + cmap = colormap + + self.setDefaultColormap(cmap) + + # Update active image colormap + activeImage = self.getActiveImage() + if isinstance(activeImage, items.ColormapMixIn): + activeImage.setColormap(cmap) + + elif isinstance(colormap, dict): + # Support colormap parameter as a dict + assert normalization is None + assert autoscale is None + assert vmin is None + assert vmax is None + assert colors is None + cmap._setFromDict(colormap) + + else: + if colormap is not None: + cmap.setName(colormap) + if normalization is not None: + cmap.setNormalization(normalization) + if autoscale: + cmap.setVRange(None, None) + else: + if vmin is not None: + cmap.setVMin(vmin) + if vmax is not None: + cmap.setVMax(vmax) + if colors is not None: + cmap.setColormapLUT(colors) + + cursorColor = cursorColorForColormap(cmap.getName()) + self.setInteractiveMode('zoom', color=cursorColor) + + def setImage(self, image, origin=(0, 0), scale=(1., 1.), + copy=True, reset=None, resetzoom=True): + """Set the image to display. + + :param image: A 2D array representing the image or None to empty plot. + :type image: numpy.ndarray-like with 2 dimensions or None. + :param origin: The (x, y) position of the origin of the image. + Default: (0, 0). + The origin is the lower left corner of the image when + the Y axis is not inverted. + :type origin: Tuple of 2 floats: (origin x, origin y). + :param scale: The scale factor to apply to the image on X and Y axes. + Default: (1, 1). + It is the size of a pixel in the coordinates of the axes. + Scales must be positive numbers. + :type scale: Tuple of 2 floats: (scale x, scale y). + :param bool copy: Whether to copy image data (default) or not. + :param bool reset: Deprecated. Alias for `resetzoom`. + :param bool resetzoom: Whether to reset zoom and ROI (default) or not. + """ + self._dirtyCache() + + if reset is not None: + resetzoom = reset + + assert len(origin) == 2 + assert len(scale) == 2 + assert scale[0] > 0 + assert scale[1] > 0 + + if image is None: + self.remove(self._imageLegend, kind='image') + return + + data = numpy.array(image, order='C', copy=copy) + if data.size == 0: + self.remove(self._imageLegend, kind='image') + return + + assert data.ndim == 2 or (data.ndim == 3 and data.shape[2] in (3, 4)) + + aggregation = self.getAggregationModeAction().getAggregationMode() + if data.ndim != 2 and aggregation is not None: + # RGB/A with aggregation is not supported + aggregation = items.ImageDataAggregated.Aggregation.NONE + + if aggregation is items.ImageDataAggregated.Aggregation.NONE: + self.addImage(data, + legend=self._imageLegend, + origin=origin, scale=scale, + colormap=self.getColormap(), + resetzoom=False) + else: + item = self._getItem("image", self._imageLegend) + if isinstance(item, items.ImageDataAggregated): + item.setData(data) + item.setOrigin(origin) + item.setScale(scale) + else: + if isinstance(item, items.ImageDataAggregated): + imageItem = item + wasCreated = False + else: + if item is not None: + self.removeImage(self._imageLegend) + imageItem = items.ImageDataAggregated() + imageItem.setName(self._imageLegend) + imageItem.setColormap(self.getColormap()) + wasCreated = True + imageItem.setData(data) + imageItem.setOrigin(origin) + imageItem.setScale(scale) + imageItem.setAggregationMode(aggregation) + if wasCreated: + self.addItem(imageItem) + + self.setActiveImage(self._imageLegend) + self._updateHistograms() + if resetzoom: + self.resetZoom() + + +# ImageViewMainWindow ######################################################### + +class ImageViewMainWindow(ImageView): + """:class:`ImageView` with additional toolbars + + Adds extra toolbar and a status bar to :class:`ImageView`. + """ + def __init__(self, parent=None, backend=None): + self._dataInfo = None + super(ImageViewMainWindow, self).__init__(parent, backend) + self.setWindowFlags(qt.Qt.Window) + + self.getXAxis().setLabel('X') + self.getYAxis().setLabel('Y') + self.setGraphTitle('Image') + + # Add toolbars and status bar + self.addToolBar(qt.Qt.BottomToolBarArea, LimitsToolBar(plot=self)) + + menu = self.menuBar().addMenu('File') + menu.addAction(self.getOutputToolBar().getSaveAction()) + menu.addAction(self.getOutputToolBar().getPrintAction()) + menu.addSeparator() + action = menu.addAction('Quit') + action.triggered[bool].connect(qt.QApplication.instance().quit) + + menu = self.menuBar().addMenu('Edit') + menu.addAction(self.getOutputToolBar().getCopyAction()) + menu.addSeparator() + menu.addAction(self.getResetZoomAction()) + menu.addAction(self.getColormapAction()) + menu.addAction(actions.control.KeepAspectRatioAction(self, self)) + menu.addAction(actions.control.YAxisInvertedAction(self, self)) + menu.addAction(self.getShowSideHistogramsAction()) + + self.__profileMenu = self.menuBar().addMenu('Profile') + self.__updateProfileMenu() + + # Connect to ImageView's signal + self.valueChanged.connect(self._statusBarSlot) + + def __updateProfileMenu(self): + """Update actions available in 'Profile' menu""" + profile = self.getProfileToolBar() + self.__profileMenu.clear() + self.__profileMenu.addAction(profile.hLineAction) + self.__profileMenu.addAction(profile.vLineAction) + self.__profileMenu.addAction(profile.crossAction) + self.__profileMenu.addAction(profile.lineAction) + self.__profileMenu.addAction(profile.clearAction) + + def _formatValueToString(self, value): + try: + if isinstance(value, numpy.ndarray): + if len(value) == 4: + return "RGBA: %.3g, %.3g, %.3g, %.3g" % (value[0], value[1], value[2], value[3]) + elif len(value) == 3: + return "RGB: %.3g, %.3g, %.3g" % (value[0], value[1], value[2]) + else: + return "Value: %g" % value + except Exception: + _logger.error("Error while formatting pixel value", exc_info=True) + pass + return "Value: %s" % value + + def _statusBarSlot(self, row, column, value): + """Update status bar with coordinates/value from plots.""" + if numpy.isnan(row): + msg = 'Column: %d, Sum: %g' % (int(column), value) + elif numpy.isnan(column): + msg = 'Row: %d, Sum: %g' % (int(row), value) + else: + msg_value = self._formatValueToString(value) + msg = 'Position: (%d, %d), %s' % (int(row), int(column), msg_value) + if self._dataInfo is not None: + msg = self._dataInfo + ', ' + msg + + self.statusBar().showMessage(msg) + + @docstring(ImageView) + def setProfileWindowBehavior(self, behavior: str): + super().setProfileWindowBehavior(behavior) + self.__updateProfileMenu() + + @docstring(ImageView) + def setImage(self, image, *args, **kwargs): + if hasattr(image, 'dtype') and hasattr(image, 'shape'): + assert image.ndim == 2 or (image.ndim == 3 and image.shape[2] in (3, 4)) + height, width = image.shape[0:2] + dataInfo = 'Data: %dx%d (%s)' % (width, height, str(image.dtype)) + else: + dataInfo = None + + if self._dataInfo != dataInfo: + self._dataInfo = dataInfo + self.statusBar().showMessage(self._dataInfo) + + # Set the new image in ImageView widget + super(ImageViewMainWindow, self).setImage(image, *args, **kwargs) diff --git a/src/silx/gui/plot/Interaction.py b/src/silx/gui/plot/Interaction.py new file mode 100644 index 0000000..6213889 --- /dev/null +++ b/src/silx/gui/plot/Interaction.py @@ -0,0 +1,350 @@ +# coding: utf-8 +# /*########################################################################## +# +# Copyright (c) 2014-2020 European Synchrotron Radiation Facility +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. +# +# ###########################################################################*/ +"""This module provides an implementation of state machines for interaction. + +Sample code of a state machine with two states ('idle' and 'active') +with transitions on left button press/release: + +.. code-block:: python + + from silx.gui.plot.Interaction import * + + class SampleStateMachine(StateMachine): + + class Idle(State): + def onPress(self, x, y, btn): + if btn == LEFT_BTN: + self.goto('active') + + class Active(State): + def enterState(self): + print('Enabled') # Handle enter active state here + + def leaveState(self): + print('Disabled') # Handle leave active state here + + def onRelease(self, x, y, btn): + if btn == LEFT_BTN: + self.goto('idle') + + def __init__(self): + # State machine has 2 states + states = { + 'idle': SampleStateMachine.Idle, + 'active': SampleStateMachine.Active + } + super(TwoStates, self).__init__(states, 'idle') + # idle is the initial state + + stateMachine = SampleStateMachine() + + # Triggers a transition to the Active state: + stateMachine.handleEvent('press', 0, 0, LEFT_BTN) + + # Triggers a transition to the Idle state: + stateMachine.handleEvent('release', 0, 0, LEFT_BTN) + +See :class:`ClickOrDrag` for another example of a state machine. + +See `Renaud Blanch, Michel Beaudouin-Lafon. +Programming Rich Interactions using the Hierarchical State Machine Toolkit. +In Proceedings of AVI 2006. p 51-58. +<http://iihm.imag.fr/en/publication/BB06a/>`_ +for a discussion of using (hierarchical) state machines for interaction. +""" + +__authors__ = ["T. Vincent"] +__license__ = "MIT" +__date__ = "18/02/2016" + + +import weakref + + +# state machine ############################################################### + +class State(object): + """Base class for the states of a state machine. + + This class is meant to be subclassed. + """ + + def __init__(self, machine): + """State instances should be created by the :class:`StateMachine`. + + They are not intended to be used outside this context. + + :param machine: The state machine instance this state belongs to. + :type machine: StateMachine + """ + self._machineRef = weakref.ref(machine) # Prevent cyclic reference + + @property + def machine(self): + """The state machine this state belongs to. + + Useful to access data or methods that are shared across states. + """ + machine = self._machineRef() + if machine is not None: + return machine + else: + raise RuntimeError("Associated StateMachine is not valid") + + def goto(self, state, *args, **kwargs): + """Performs a transition to a new state. + + Extra arguments are passed to the :meth:`enterState` method of the + new state. + + :param str state: The name of the state to go to. + """ + self.machine._goto(state, *args, **kwargs) + + def enterState(self, *args, **kwargs): + """Called when the state machine enters this state. + + Arguments are those provided to the :meth:`goto` method that + triggered the transition to this state. + """ + pass + + def leaveState(self): + """Called when the state machine leaves this state + (i.e., when :meth:`goto` is called). + """ + pass + + def validate(self): + """Called externally to validate the current interaction in case of a + creation. + """ + pass + +class StateMachine(object): + """State machine controller. + + This is the entry point of a state machine. + It is in charge of dispatching received event and handling the + current active state. + """ + + def __init__(self, states, initState, *args, **kwargs): + """Create a state machine controller with an initial state. + + Extra arguments are passed to the :meth:`enterState` method + of the initState. + + :param states: All states of the state machine + :type states: dict of: {str name: State subclass} + :param str initState: Key of the initial state in states + """ + self.states = states + + self.state = self.states[initState](self) + self.state.enterState(*args, **kwargs) + + def _goto(self, state, *args, **kwargs): + self.state.leaveState() + self.state = self.states[state](self) + self.state.enterState(*args, **kwargs) + + def handleEvent(self, eventName, *args, **kwargs): + """Process an event with the state machine. + + This method looks up for an event handler in the current state + and then in the :class:`StateMachine` instance. + Handler are looked up as 'onEventName' method. + If a handler is found, it is called with the provided extra + arguments, and this method returns the return value of the + handler. + If no handler is found, this method returns None. + + :param str eventName: Name of the event to handle + :returns: The return value of the handler or None + """ + handlerName = 'on' + eventName[0].upper() + eventName[1:] + try: + handler = getattr(self.state, handlerName) + except AttributeError: + try: + handler = getattr(self, handlerName) + except AttributeError: + handler = None + if handler is not None: + return handler(*args, **kwargs) + + def validate(self): + """Called externally to validate the current interaction in case of a + creation. + """ + self.state.validate() + + +# clickOrDrag ################################################################# + +LEFT_BTN = 'left' +"""Left mouse button.""" + +RIGHT_BTN = 'right' +"""Right mouse button.""" + +MIDDLE_BTN = 'middle' +"""Middle mouse button.""" + + +class ClickOrDrag(StateMachine): + """State machine for left and right click and left drag interaction. + + It is intended to be used through subclassing by overriding + :meth:`click`, :meth:`beginDrag`, :meth:`drag` and :meth:`endDrag`. + + :param Set[str] clickButtons: Set of buttons that provides click interaction + :param Set[str] dragButtons: Set of buttons that provides drag interaction + """ + + DRAG_THRESHOLD_SQUARE_DIST = 5 ** 2 + + class Idle(State): + def onPress(self, x, y, btn): + if btn in self.machine.dragButtons: + self.goto('clickOrDrag', x, y, btn) + return True + elif btn in self.machine.clickButtons: + self.goto('click', x, y, btn) + return True + + class Click(State): + def enterState(self, x, y, btn): + self.initPos = x, y + self.button = btn + + def onMove(self, x, y): + dx2 = (x - self.initPos[0]) ** 2 + dy2 = (y - self.initPos[1]) ** 2 + if (dx2 + dy2) >= self.machine.DRAG_THRESHOLD_SQUARE_DIST: + self.goto('idle') + + def onRelease(self, x, y, btn): + if btn == self.button: + self.machine.click(x, y, btn) + self.goto('idle') + + class ClickOrDrag(State): + def enterState(self, x, y, btn): + self.initPos = x, y + self.button = btn + + def onMove(self, x, y): + dx2 = (x - self.initPos[0]) ** 2 + dy2 = (y - self.initPos[1]) ** 2 + if (dx2 + dy2) >= self.machine.DRAG_THRESHOLD_SQUARE_DIST: + self.goto('drag', self.initPos, (x, y), self.button) + + def onRelease(self, x, y, btn): + if btn == self.button: + if btn in self.machine.clickButtons: + self.machine.click(x, y, btn) + self.goto('idle') + + class Drag(State): + def enterState(self, initPos, curPos, btn): + self.initPos = initPos + self.button = btn + self.machine.beginDrag(*initPos, btn) + self.machine.drag(*curPos, btn) + + def onMove(self, x, y): + self.machine.drag(x, y, self.button) + + def onRelease(self, x, y, btn): + if btn == self.button: + self.machine.endDrag(self.initPos, (x, y), btn) + self.goto('idle') + + def __init__(self, + clickButtons=(LEFT_BTN, RIGHT_BTN), + dragButtons=(LEFT_BTN,)): + states = { + 'idle': self.Idle, + 'click': self.Click, + 'clickOrDrag': self.ClickOrDrag, + 'drag': self.Drag + } + self.__clickButtons = set(clickButtons) + self.__dragButtons = set(dragButtons) + super(ClickOrDrag, self).__init__(states, 'idle') + + clickButtons = property(lambda self: self.__clickButtons, + doc="Buttons with click interaction (Set[int])") + + dragButtons = property(lambda self: self.__dragButtons, + doc="Buttons with drag interaction (Set[int])") + + def click(self, x, y, btn): + """Called upon a button supporting click. + + Override in subclass. + + :param int x: X mouse position in pixels. + :param int y: Y mouse position in pixels. + :param str btn: The mouse button which was clicked. + """ + pass + + def beginDrag(self, x, y, btn): + """Called at the beginning of a drag gesture with mouse button pressed. + + Override in subclass. + + :param int x: X mouse position in pixels. + :param int y: Y mouse position in pixels. + :param str btn: The mouse button for which a drag is starting. + """ + pass + + def drag(self, x, y, btn): + """Called on mouse moved during a drag gesture. + + Override in subclass. + + :param int x: X mouse position in pixels. + :param int y: Y mouse position in pixels. + :param str btn: The mouse button for which a drag is in progress. + """ + pass + + def endDrag(self, startPoint, endPoint, btn): + """Called at the end of a drag gesture when the mouse button is released. + + Override in subclass. + + :param List[int] startPoint: + (x, y) mouse position in pixels at the beginning of the drag. + :param List[int] endPoint: + (x, y) mouse position in pixels at the end of the drag. + :param str btn: The mouse button for which a drag is done. + """ + pass diff --git a/src/silx/gui/plot/ItemsSelectionDialog.py b/src/silx/gui/plot/ItemsSelectionDialog.py new file mode 100644 index 0000000..c0504b0 --- /dev/null +++ b/src/silx/gui/plot/ItemsSelectionDialog.py @@ -0,0 +1,286 @@ +# coding: utf-8 +# /*########################################################################## +# +# Copyright (c) 2017-2021 European Synchrotron Radiation Facility +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. +# +# ###########################################################################*/ +"""This module provides a dialog widget to select plot items. + +.. autoclass:: ItemsSelectionDialog + +""" + +__authors__ = ["P. Knobel"] +__license__ = "MIT" +__date__ = "28/06/2017" + +import logging + +from silx.gui import qt +from silx.gui.plot.PlotWidget import PlotWidget + +_logger = logging.getLogger(__name__) + + +class KindsSelector(qt.QListWidget): + """List widget allowing to select plot item kinds + ("curve", "scatter", "image"...) + """ + sigSelectedKindsChanged = qt.Signal(list) + + def __init__(self, parent=None, kinds=None): + """ + + :param parent: Parent QWidget or None + :param tuple(str) kinds: Sequence of kinds. If None, the default + behavior is to provide a checkbox for all possible item kinds. + """ + qt.QListWidget.__init__(self, parent) + + self.plot_item_kinds = [] + + self.setAvailableKinds(kinds if kinds is not None else PlotWidget.ITEM_KINDS) + + self.setSelectionMode(qt.QAbstractItemView.ExtendedSelection) + self.selectAll() + + self.itemSelectionChanged.connect(self.emitSigKindsSelectionChanged) + + def emitSigKindsSelectionChanged(self): + self.sigSelectedKindsChanged.emit(self.selectedKinds) + + @property + def selectedKinds(self): + """Tuple of all selected kinds (as strings).""" + # check for updates when self.itemSelectionChanged + return [item.text() for item in self.selectedItems()] + + def setAvailableKinds(self, kinds): + """Set a list of kinds to be displayed. + + :param list[str] kinds: Sequence of kinds + """ + self.plot_item_kinds = kinds + + self.clear() + for kind in self.plot_item_kinds: + item = qt.QListWidgetItem(self) + item.setText(kind) + self.addItem(item) + + def selectAll(self): + """Select all available kinds.""" + if self.selectionMode() in [qt.QAbstractItemView.SingleSelection, + qt.QAbstractItemView.NoSelection]: + raise RuntimeError("selectAll requires a multiple selection mode") + for i in range(self.count()): + self.item(i).setSelected(True) + + +class PlotItemsSelector(qt.QTableWidget): + """Table widget displaying the legend and kind of all + plot items corresponding to a list of specified kinds. + + Selected plot items are provided as property :attr:`selectedPlotItems`. + You can be warned of selection changes by listening to signal + :attr:`itemSelectionChanged`. + """ + def __init__(self, parent=None, plot=None): + if plot is None or not isinstance(plot, PlotWidget): + raise AttributeError("parameter plot is required") + self.plot = plot + """:class:`PlotWidget` instance""" + + self.plot_item_kinds = None + """List of plot item kinds (strings)""" + + qt.QTableWidget.__init__(self, parent) + + self.setColumnCount(2) + + self.setSelectionBehavior(qt.QTableWidget.SelectRows) + + def _clear(self): + self.clear() + self.setHorizontalHeaderLabels(["legend", "type"]) + + def setAllKindsFilter(self): + """Display all kinds of plot items.""" + self.setKindsFilter(PlotWidget.ITEM_KINDS) + + def setKindsFilter(self, kinds): + """Set list of all kinds of plot items to be displayed. + + :param list[str] kinds: Sequence of kinds + """ + if not set(kinds) <= set(PlotWidget.ITEM_KINDS): + raise KeyError("Illegal plot item kinds: %s" % + set(kinds) - set(PlotWidget.ITEM_KINDS)) + self.plot_item_kinds = kinds + + self.updatePlotItems() + + def updatePlotItems(self): + self._clear() + + # respect order of kinds as set in method setKindsFilter + itemsAndKind = [] + for kind in self.plot_item_kinds: + itemClasses = self.plot._KIND_TO_CLASSES[kind] + for item in self.plot.getItems(): + if isinstance(item, itemClasses) and item.isVisible(): + itemsAndKind.append((item, kind)) + + self.setRowCount(len(itemsAndKind)) + + for index, (item, kind) in enumerate(itemsAndKind): + legend_twitem = qt.QTableWidgetItem(item.getName()) + self.setItem(index, 0, legend_twitem) + + kind_twitem = qt.QTableWidgetItem(kind) + self.setItem(index, 1, kind_twitem) + + @property + def selectedPlotItems(self): + """List of all selected items""" + selection_model = self.selectionModel() + selected_rows_idx = selection_model.selectedRows() + selected_rows = [idx.row() for idx in selected_rows_idx] + + items = [] + for row in selected_rows: + legend = self.item(row, 0).text() + kind = self.item(row, 1).text() + item = self.plot._getItem(kind, legend) + if item is not None: + items.append(item) + + return items + + +class ItemsSelectionDialog(qt.QDialog): + """This widget is a modal dialog allowing to select one or more plot + items, in a table displaying their legend and kind. + + Public methods: + + - :meth:`getSelectedItems` + - :meth:`setAvailableKinds` + - :meth:`setItemsSelectionMode` + + This widget inherits QDialog and therefore implements the usual + dialog methods, e.g. :meth:`exec`. + + A trivial usage example would be:: + + isd = ItemsSelectionDialog(plot=my_plot_widget) + isd.setItemsSelectionMode(qt.QTableWidget.SingleSelection) + result = isd.exec() + if result: + for item in isd.getSelectedItems(): + print(item.getName(), type(item)) + else: + print("Selection cancelled") + """ + def __init__(self, parent=None, plot=None): + if plot is None or not isinstance(plot, PlotWidget): + raise AttributeError("parameter plot is required") + qt.QDialog.__init__(self, parent) + + self.setWindowTitle("Plot items selector") + + kind_selector_label = qt.QLabel("Filter item kinds:", self) + item_selector_label = qt.QLabel("Select items:", self) + + self.kind_selector = KindsSelector(self) + self.kind_selector.setToolTip( + "select one or more item kinds to show them in the item list") + + self.item_selector = PlotItemsSelector(self, plot) + self.item_selector.setToolTip("select items") + + self.item_selector.setKindsFilter(self.kind_selector.selectedKinds) + self.kind_selector.sigSelectedKindsChanged.connect( + self.item_selector.setKindsFilter + ) + + okb = qt.QPushButton("OK", self) + okb.clicked.connect(self.accept) + + cancelb = qt.QPushButton("Cancel", self) + cancelb.clicked.connect(self.reject) + + layout = qt.QGridLayout(self) + layout.addWidget(kind_selector_label, 0, 0) + layout.addWidget(item_selector_label, 0, 1) + layout.addWidget(self.kind_selector, 1, 0) + layout.addWidget(self.item_selector, 1, 1) + layout.addWidget(okb, 2, 0) + layout.addWidget(cancelb, 2, 1) + + self.setLayout(layout) + + def getSelectedItems(self): + """Return a list of selected plot items + + :return: List of selected plot items + :rtype: list[silx.gui.plot.items.Item]""" + return self.item_selector.selectedPlotItems + + def setAvailableKinds(self, kinds): + """Set a list of kinds to be displayed. + + :param list[str] kinds: Sequence of kinds + """ + self.kind_selector.setAvailableKinds(kinds) + + def selectAllKinds(self): + self.kind_selector.selectAll() + + def setItemsSelectionMode(self, mode): + """Set selection mode for plot item (single item selection, + multiple...). + + :param mode: One of :class:`QTableWidget` selection modes + """ + if mode == self.item_selector.SingleSelection: + self.item_selector.setToolTip( + "Select one item by clicking on it.") + elif mode == self.item_selector.MultiSelection: + self.item_selector.setToolTip( + "Select one or more items by clicking with the left mouse" + " button.\nYou can unselect items by clicking them again.\n" + "Multiple items can be toggled by dragging the mouse over them.") + elif mode == self.item_selector.ExtendedSelection: + self.item_selector.setToolTip( + "Select one or more items. You can select multiple items " + "by keeping the Ctrl key pushed when clicking.\nYou can " + "select a range of items by clicking on the first and " + "last while keeping the Shift key pushed.") + elif mode == self.item_selector.ContiguousSelection: + self.item_selector.setToolTip( + "Select one item by clicking on it. If you press the Shift" + " key while clicking on a second item,\nall items between " + "the two will be selected.") + elif mode == self.item_selector.NoSelection: + raise ValueError("The NoSelection mode is not allowed " + "in this context.") + self.item_selector.setSelectionMode(mode) diff --git a/src/silx/gui/plot/LegendSelector.py b/src/silx/gui/plot/LegendSelector.py new file mode 100755 index 0000000..d439387 --- /dev/null +++ b/src/silx/gui/plot/LegendSelector.py @@ -0,0 +1,1039 @@ +# coding: utf-8 +# /*########################################################################## +# +# Copyright (c) 2004-2021 European Synchrotron Radiation Facility +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. +# +# ###########################################################################*/ +"""Widget displaying curves legends and allowing to operate on curves. + +This widget is meant to work with :class:`PlotWindow`. +""" + +__authors__ = ["V.A. Sole", "T. Rueter", "T. Vincent"] +__license__ = "MIT" +__data__ = "16/10/2017" + + +import logging +import weakref + +import numpy + +from .. import qt, colors +from ..widgets.LegendIconWidget import LegendIconWidget +from . import items + + +_logger = logging.getLogger(__name__) + + +class LegendIcon(LegendIconWidget): + """Object displaying a curve linestyle and symbol. + + :param QWidget parent: See :class:`QWidget` + :param Union[~silx.gui.plot.items.Curve,None] curve: + Curve with which to synchronize + """ + + def __init__(self, parent=None, curve=None): + super(LegendIcon, self).__init__(parent) + self._curveRef = None + self.setCurve(curve) + + def getCurve(self): + """Returns curve associated to this widget + + :rtype: Union[~silx.gui.plot.items.Curve,None] + """ + return None if self._curveRef is None else self._curveRef() + + def setCurve(self, curve): + """Set the curve with which to synchronize this widget. + + :param curve: Union[~silx.gui.plot.items.Curve,None] + """ + assert curve is None or isinstance(curve, items.Curve) + + previousCurve = self.getCurve() + if curve == previousCurve: + return + + if previousCurve is not None: + previousCurve.sigItemChanged.disconnect(self._curveChanged) + + self._curveRef = None if curve is None else weakref.ref(curve) + + if curve is not None: + curve.sigItemChanged.connect(self._curveChanged) + + self._update() + + def _update(self): + """Update widget according to current curve state. + """ + curve = self.getCurve() + if curve is None: + _logger.error('Curve no more exists') + self.setEnabled(False) + return + + style = curve.getCurrentStyle() + + self.setEnabled(curve.isVisible()) + self.setSymbol(style.getSymbol()) + self.setLineWidth(style.getLineWidth()) + self.setLineStyle(style.getLineStyle()) + + color = style.getColor() + if numpy.array(color, copy=False).ndim != 1: + # array of colors, use transparent black + color = 0., 0., 0., 0. + color = colors.rgba(color) # Make sure it is float in [0, 1] + alpha = curve.getAlpha() + color = qt.QColor.fromRgbF( + color[0], color[1], color[2], color[3] * alpha) + self.setLineColor(color) + self.setSymbolColor(color) + self.update() # TODO this should not be needed + + def _curveChanged(self, event): + """Handle update of curve item + + :param event: Kind of change + """ + if event in (items.ItemChangedType.VISIBLE, + items.ItemChangedType.SYMBOL, + items.ItemChangedType.SYMBOL_SIZE, + items.ItemChangedType.LINE_WIDTH, + items.ItemChangedType.LINE_STYLE, + items.ItemChangedType.COLOR, + items.ItemChangedType.ALPHA, + items.ItemChangedType.HIGHLIGHTED, + items.ItemChangedType.HIGHLIGHTED_STYLE): + self._update() + + +class LegendModel(qt.QAbstractListModel): + """Data model of curve legends. + + It holds the information of the curve: + + - color + - line width + - line style + - visibility of the lines + - symbol + - visibility of the symbols + """ + iconColorRole = qt.Qt.UserRole + 0 + iconLineWidthRole = qt.Qt.UserRole + 1 + iconLineStyleRole = qt.Qt.UserRole + 2 + showLineRole = qt.Qt.UserRole + 3 + iconSymbolRole = qt.Qt.UserRole + 4 + showSymbolRole = qt.Qt.UserRole + 5 + + def __init__(self, legendList=None, parent=None): + super(LegendModel, self).__init__(parent) + if legendList is None: + legendList = [] + self.legendList = [] + self.insertLegendList(0, legendList) + self._palette = qt.QPalette() + + def __getitem__(self, idx): + if idx >= len(self.legendList): + raise IndexError('list index out of range') + return self.legendList[idx] + + def rowCount(self, modelIndex=None): + return len(self.legendList) + + def flags(self, index): + return (qt.Qt.ItemIsEditable | + qt.Qt.ItemIsEnabled | + qt.Qt.ItemIsSelectable) + + def data(self, modelIndex, role): + if modelIndex.isValid: + idx = modelIndex.row() + else: + return None + if idx >= len(self.legendList): + raise IndexError('list index out of range') + + item = self.legendList[idx] + isActive = item[1].get("active", False) + if role == qt.Qt.DisplayRole: + # Data to be rendered in the form of text + legend = str(item[0]) + return legend + elif role == qt.Qt.SizeHintRole: + # size = qt.QSize(200,50) + _logger.warning('LegendModel -- size hint role not implemented') + return qt.QSize() + elif role == qt.Qt.TextAlignmentRole: + alignment = qt.Qt.AlignVCenter | qt.Qt.AlignLeft + return alignment + elif role == qt.Qt.BackgroundRole: + # Background color, must be QBrush + if isActive: + brush = self._palette.brush(qt.QPalette.Normal, qt.QPalette.Highlight) + elif idx % 2: + brush = qt.QBrush(qt.QColor(240, 240, 240)) + else: + brush = qt.QBrush(qt.Qt.white) + return brush + elif role == qt.Qt.ForegroundRole: + # ForegroundRole color, must be QBrush + if isActive: + brush = self._palette.brush(qt.QPalette.Normal, qt.QPalette.HighlightedText) + else: + brush = self._palette.brush(qt.QPalette.Normal, qt.QPalette.WindowText) + return brush + elif role == qt.Qt.CheckStateRole: + return bool(item[2]) # item[2] == True + elif role == qt.Qt.ToolTipRole or role == qt.Qt.StatusTipRole: + return '' + elif role == self.iconColorRole: + return item[1]['color'] + elif role == self.iconLineWidthRole: + return item[1]['linewidth'] + elif role == self.iconLineStyleRole: + return item[1]['linestyle'] + elif role == self.iconSymbolRole: + return item[1]['symbol'] + elif role == self.showLineRole: + return item[3] + elif role == self.showSymbolRole: + return item[4] + else: + _logger.info('Unkown role requested: %s', str(role)) + return None + + def setData(self, modelIndex, value, role): + if modelIndex.isValid: + idx = modelIndex.row() + else: + return None + if idx >= len(self.legendList): + # raise IndexError('list index out of range') + _logger.warning( + 'setData -- List index out of range, idx: %d', idx) + return None + + item = self.legendList[idx] + try: + if role == qt.Qt.DisplayRole: + # Set legend + item[0] = str(value) + elif role == self.iconColorRole: + item[1]['color'] = qt.QColor(value) + elif role == self.iconLineWidthRole: + item[1]['linewidth'] = int(value) + elif role == self.iconLineStyleRole: + item[1]['linestyle'] = str(value) + elif role == self.iconSymbolRole: + item[1]['symbol'] = str(value) + elif role == qt.Qt.CheckStateRole: + item[2] = value + elif role == self.showLineRole: + item[3] = value + elif role == self.showSymbolRole: + item[4] = value + except ValueError: + _logger.warning('Conversion failed:\n\tvalue: %s\n\trole: %s', + str(value), str(role)) + # Can that be right? Read docs again.. + self.dataChanged.emit(modelIndex, modelIndex) + return True + + def insertLegendList(self, row, llist): + """ + :param int row: Determines after which row the items are inserted + :param llist: Carries the new legend information + :type llist: List + """ + modelIndex = self.createIndex(row, 0) + count = len(llist) + super(LegendModel, self).beginInsertRows(modelIndex, + row, + row + count) + head = self.legendList[0:row] + tail = self.legendList[row:] + new = [] + for (legend, icon) in llist: + linestyle = icon.get('linestyle', None) + if LegendIconWidget.isEmptyLineStyle(linestyle): + # Curve had no line, give it one and hide it + # So when toggle line, it will display a solid line + showLine = False + icon['linestyle'] = '-' + else: + showLine = True + + symbol = icon.get('symbol', None) + if LegendIconWidget.isEmptySymbol(symbol): + # Curve had no symbol, give it one and hide it + # So when toggle symbol, it will display 'o' + showSymbol = False + icon['symbol'] = 'o' + else: + showSymbol = True + + selected = icon.get('selected', True) + item = [legend, + icon, + selected, + showLine, + showSymbol] + new.append(item) + self.legendList = head + new + tail + super(LegendModel, self).endInsertRows() + return True + + def insertRows(self, row, count, modelIndex=qt.QModelIndex()): + raise NotImplementedError('Use LegendModel.insertLegendList instead') + + def removeRow(self, row): + return self.removeRows(row, 1) + + def removeRows(self, row, count, modelIndex=qt.QModelIndex()): + length = len(self.legendList) + if length == 0: + # Nothing to do.. + return True + if row < 0 or row >= length: + raise IndexError('Index out of range -- ' + + 'idx: %d, len: %d' % (row, length)) + if count == 0: + return False + super(LegendModel, self).beginRemoveRows(modelIndex, + row, + row + count) + del(self.legendList[row:row + count]) + super(LegendModel, self).endRemoveRows() + return True + + def setEditor(self, event, editor): + """ + :param str event: String that identifies the editor + :param editor: Widget used to change data in the underlying model + :type editor: QWidget + """ + if event not in self.eventList: + raise ValueError('setEditor -- Event must be in %s' % + str(self.eventList)) + self.editorDict[event] = editor + + +class LegendListItemWidget(qt.QItemDelegate): + """Object displaying a single item (i.e., a row) in the list.""" + + # Notice: LegendListItem does NOT inherit + # from QObject, it cannot emit signals! + + def __init__(self, parent=None, itemType=0): + super(LegendListItemWidget, self).__init__(parent) + + # Dictionary to render checkboxes + self.cbDict = {} + self.labelDict = {} + self.iconDict = {} + + # Keep checkbox and legend to get sizeHint + self.checkbox = qt.QCheckBox() + self.legend = qt.QLabel() + self.icon = LegendIcon() + + # Context Menu and Editors + self.contextMenu = None + + def paint(self, painter, option, modelIndex): + """ + Here be docs.. + + :param QPainter painter: + :param QStyleOptionViewItem option: + :param QModelIndex modelIndex: + """ + painter.save() + rect = option.rect + + # Calculate the icon rectangle + iconSize = self.icon.sizeHint() + # Calculate icon position + x = rect.left() + 2 + y = rect.top() + int(.5 * (rect.height() - iconSize.height())) + iconRect = qt.QRect(qt.QPoint(x, y), iconSize) + + # Calculate label rectangle + legendSize = qt.QSize(rect.width() - iconSize.width() - 30, + rect.height()) + # Calculate label position + x = rect.left() + iconRect.width() + y = rect.top() + labelRect = qt.QRect(qt.QPoint(x, y), legendSize) + labelRect.translate(qt.QPoint(10, 0)) + + # Calculate the checkbox rectangle + x = rect.right() - 30 + y = rect.top() + chBoxRect = qt.QRect(qt.QPoint(x, y), rect.bottomRight()) + + # Remember the rectangles + idx = modelIndex.row() + self.cbDict[idx] = chBoxRect + self.iconDict[idx] = iconRect + self.labelDict[idx] = labelRect + + # Draw background first! + if option.state & qt.QStyle.State_MouseOver: + backgroundBrush = option.palette.highlight() + else: + backgroundBrush = modelIndex.data(qt.Qt.BackgroundRole) + painter.fillRect(rect, backgroundBrush) + + # Draw label + legendText = modelIndex.data(qt.Qt.DisplayRole) + textBrush = modelIndex.data(qt.Qt.ForegroundRole) + textAlign = modelIndex.data(qt.Qt.TextAlignmentRole) + painter.setBrush(textBrush) + painter.setFont(self.legend.font()) + painter.setPen(textBrush.color()) + painter.drawText(labelRect, textAlign, legendText) + + # Draw icon + iconColor = modelIndex.data(LegendModel.iconColorRole) + iconLineWidth = modelIndex.data(LegendModel.iconLineWidthRole) + iconLineStyle = modelIndex.data(LegendModel.iconLineStyleRole) + iconSymbol = modelIndex.data(LegendModel.iconSymbolRole) + icon = LegendIcon() + icon.resize(iconRect.size()) + icon.move(iconRect.topRight()) + icon.showSymbol = modelIndex.data(LegendModel.showSymbolRole) + icon.showLine = modelIndex.data(LegendModel.showLineRole) + icon.setSymbolColor(iconColor) + icon.setLineColor(iconColor) + icon.setLineWidth(iconLineWidth) + icon.setLineStyle(iconLineStyle) + icon.setSymbol(iconSymbol) + icon.symbolOutlineBrush = backgroundBrush + icon.paint(painter, iconRect, option.palette) + + # Draw the checkbox + if modelIndex.data(qt.Qt.CheckStateRole): + checkState = qt.Qt.Checked + else: + checkState = qt.Qt.Unchecked + + self.drawCheck( + painter, qt.QStyleOptionViewItem(), chBoxRect, checkState) + + painter.restore() + + def editorEvent(self, event, model, option, modelIndex): + # From the docs: + # Mouse events are sent to editorEvent() + # even if they don't start editing of the item. + if event.button() == qt.Qt.RightButton and self.contextMenu: + self.contextMenu.exec(event.globalPos(), modelIndex) + return True + elif event.button() == qt.Qt.LeftButton: + # Check if checkbox was clicked + idx = modelIndex.row() + cbRect = self.cbDict[idx] + if cbRect.contains(event.pos()): + # Toggle checkbox + model.setData(modelIndex, + not modelIndex.data(qt.Qt.CheckStateRole), + qt.Qt.CheckStateRole) + event.ignore() + return True + else: + return super(LegendListItemWidget, self).editorEvent( + event, model, option, modelIndex) + + def createEditor(self, parent, option, idx): + _logger.info('### Editor request ###') + + def sizeHint(self, option, idx): + # return qt.QSize(68,24) + iconSize = self.icon.sizeHint() + legendSize = self.legend.sizeHint() + checkboxSize = self.checkbox.sizeHint() + height = max([iconSize.height(), + legendSize.height(), + checkboxSize.height()]) + 4 + width = iconSize.width() + legendSize.width() + checkboxSize.width() + return qt.QSize(width, height) + + +class LegendListView(qt.QListView): + """Widget displaying a list of curve legends, line style and symbol.""" + + sigLegendSignal = qt.Signal(object) + """Signal emitting a dict when an action is triggered by the user.""" + + __mouseClickedEvent = 'mouseClicked' + __checkBoxClickedEvent = 'checkBoxClicked' + __legendClickedEvent = 'legendClicked' + + def __init__(self, parent=None, model=None, contextMenu=None): + super(LegendListView, self).__init__(parent) + self.__lastButton = None + self.__lastClickPos = None + self.__lastModelIdx = None + # Set default delegate + self.setItemDelegate(LegendListItemWidget()) + # Set default editors + # self.setSizePolicy(qt.QSizePolicy.MinimumExpanding, + # qt.QSizePolicy.MinimumExpanding) + # Set edit triggers by hand using self.edit(QModelIndex) + # in mousePressEvent (better to control than signals) + self.setEditTriggers(qt.QAbstractItemView.NoEditTriggers) + + # Control layout + # self.setBatchSize(2) + # self.setLayoutMode(qt.QListView.Batched) + # self.setFlow(qt.QListView.LeftToRight) + + # Control selection + self.setSelectionMode(qt.QAbstractItemView.NoSelection) + + if model is None: + model = LegendModel(parent=self) + self.setModel(model) + self.setContextMenu(contextMenu) + + def setLegendList(self, legendList, row=None): + if row is not None: + model = self.model() + model.insertLegendList(row, legendList) + elif len(legendList) != self.model().rowCount(): + self.clear() + model = self.model() + model.insertLegendList(0, legendList) + else: + model = self.model() + for i, (new_legend, icon) in enumerate(legendList): + modelIndex = model.index(i) + legend = str(modelIndex.data(qt.Qt.DisplayRole)) + if new_legend != legend: + model.setData(modelIndex, new_legend, qt.Qt.DisplayRole) + + color = modelIndex.data(LegendModel.iconColorRole) + new_color = icon.get('color', None) + if new_color != color: + model.setData(modelIndex, new_color, LegendModel.iconColorRole) + + linewidth = modelIndex.data(LegendModel.iconLineWidthRole) + new_linewidth = icon.get('linewidth', 1.0) + if new_linewidth != linewidth: + model.setData(modelIndex, new_linewidth, LegendModel.iconLineWidthRole) + + linestyle = modelIndex.data(LegendModel.iconLineStyleRole) + new_linestyle = icon.get('linestyle', None) + visible = not LegendIconWidget.isEmptyLineStyle(new_linestyle) + model.setData(modelIndex, visible, LegendModel.showLineRole) + if new_linestyle != linestyle: + model.setData(modelIndex, new_linestyle, LegendModel.iconLineStyleRole) + + symbol = modelIndex.data(LegendModel.iconSymbolRole) + new_symbol = icon.get('symbol', None) + visible = not LegendIconWidget.isEmptySymbol(new_symbol) + model.setData(modelIndex, visible, LegendModel.showSymbolRole) + if new_symbol != symbol: + model.setData(modelIndex, new_symbol, LegendModel.iconSymbolRole) + + selected = modelIndex.data(qt.Qt.CheckStateRole) + new_selected = icon.get('selected', True) + if new_selected != selected: + model.setData(modelIndex, new_selected, qt.Qt.CheckStateRole) + _logger.debug('LegendListView.setLegendList(legendList) finished') + + def clear(self): + model = self.model() + model.removeRows(0, model.rowCount()) + _logger.debug('LegendListView.clear() finished') + + def setContextMenu(self, contextMenu=None): + delegate = self.itemDelegate() + if isinstance(delegate, LegendListItemWidget) and self.model(): + if contextMenu is None: + delegate.contextMenu = LegendListContextMenu(self.model()) + delegate.contextMenu.sigContextMenu.connect( + self._contextMenuSlot) + else: + delegate.contextMenu = contextMenu + + def __getitem__(self, idx): + model = self.model() + try: + item = model[idx] + except ValueError: + item = None + return item + + def _contextMenuSlot(self, ddict): + self.sigLegendSignal.emit(ddict) + + def mousePressEvent(self, event): + self.__lastButton = event.button() + self.__lastPosition = event.pos() + super(LegendListView, self).mousePressEvent(event) + # call _handleMouseClick after editing was handled + # If right click (context menu) is aborted, no + # signal is emitted.. + self._handleMouseClick(self.indexAt(self.__lastPosition)) + + def mouseDoubleClickEvent(self, event): + self.__lastButton = event.button() + self.__lastPosition = event.pos() + super(LegendListView, self).mouseDoubleClickEvent(event) + # call _handleMouseClick after editing was handled + # If right click (context menu) is aborted, no + # signal is emitted.. + self._handleMouseClick(self.indexAt(self.__lastPosition)) + + def mouseMoveEvent(self, event): + # LegendListView.mouseMoveEvent is overwritten + # to suppress unwanted behavior in the delegate. + pass + + def mouseReleaseEvent(self, event): + # LegendListView.mouseReleaseEvent is overwritten + # to subpress unwanted behavior in the delegate. + pass + + def _handleMouseClick(self, modelIndex): + """ + Distinguish between mouse click on Legend + and mouse click on CheckBox by setting the + currentCheckState attribute in LegendListItem. + + Emits signal sigLegendSignal(ddict) + + :param QModelIndex modelIndex: index of the clicked item + """ + _logger.debug('self._handleMouseClick called') + if self.__lastButton not in [qt.Qt.LeftButton, + qt.Qt.RightButton]: + return + if not modelIndex.isValid(): + _logger.debug('_handleMouseClick -- Invalid QModelIndex') + return + # model = self.model() + idx = modelIndex.row() + + delegate = self.itemDelegate() + cbClicked = False + if isinstance(delegate, LegendListItemWidget): + for cbRect in delegate.cbDict.values(): + if cbRect.contains(self.__lastPosition): + cbClicked = True + break + + # TODO: Check for doubleclicks on legend/icon and spawn editors + + ddict = { + 'legend': str(modelIndex.data(qt.Qt.DisplayRole)), + 'icon': { + 'linewidth': str(modelIndex.data( + LegendModel.iconLineWidthRole)), + 'linestyle': str(modelIndex.data( + LegendModel.iconLineStyleRole)), + 'symbol': str(modelIndex.data(LegendModel.iconSymbolRole)) + }, + 'selected': modelIndex.data(qt.Qt.CheckStateRole), + 'type': str(modelIndex.data()) + } + if self.__lastButton == qt.Qt.RightButton: + _logger.debug('Right clicked') + ddict['button'] = "right" + ddict['event'] = self.__mouseClickedEvent + elif cbClicked: + _logger.debug('CheckBox clicked') + ddict['button'] = "left" + ddict['event'] = self.__checkBoxClickedEvent + else: + _logger.debug('Legend clicked') + ddict['button'] = "left" + ddict['event'] = self.__legendClickedEvent + _logger.debug(' idx: %d\n ddict: %s', idx, str(ddict)) + self.sigLegendSignal.emit(ddict) + + +class LegendListContextMenu(qt.QMenu): + """Contextual menu associated to items in a :class:`LegendListView`.""" + + sigContextMenu = qt.Signal(object) + """Signal emitting a dict upon contextual menu actions.""" + + def __init__(self, model): + super(LegendListContextMenu, self).__init__(parent=None) + self.model = model + + self.addAction('Set Active', self.setActiveAction) + self.addAction('Map to left', self.mapToLeftAction) + self.addAction('Map to right', self.mapToRightAction) + + self._pointsAction = self.addAction( + 'Points', self.togglePointsAction) + self._pointsAction.setCheckable(True) + + self._linesAction = self.addAction('Lines', self.toggleLinesAction) + self._linesAction.setCheckable(True) + + self.addAction('Remove curve', self.removeItemAction) + self.addAction('Rename curve', self.renameItemAction) + + def exec(self, pos, idx): + self.__currentIdx = idx + + # Set checkable action state + modelIndex = self.currentIdx() + self._pointsAction.setChecked( + modelIndex.data(LegendModel.showSymbolRole)) + self._linesAction.setChecked( + modelIndex.data(LegendModel.showLineRole)) + + super(LegendListContextMenu, self).popup(pos) + + def exec_(self, pos, idx): # Qt5-like compatibility + return self.exec(pos, idx) + + def currentIdx(self): + return self.__currentIdx + + def mapToLeftAction(self): + _logger.debug('LegendListContextMenu.mapToLeftAction called') + modelIndex = self.currentIdx() + legend = str(modelIndex.data(qt.Qt.DisplayRole)) + ddict = { + 'legend': legend, + 'label': legend, + 'selected': modelIndex.data(qt.Qt.CheckStateRole), + 'type': str(modelIndex.data()), + 'event': "mapToLeft" + } + self.sigContextMenu.emit(ddict) + + def mapToRightAction(self): + _logger.debug('LegendListContextMenu.mapToRightAction called') + modelIndex = self.currentIdx() + legend = str(modelIndex.data(qt.Qt.DisplayRole)) + ddict = { + 'legend': legend, + 'label': legend, + 'selected': modelIndex.data(qt.Qt.CheckStateRole), + 'type': str(modelIndex.data()), + 'event': "mapToRight" + } + self.sigContextMenu.emit(ddict) + + def removeItemAction(self): + _logger.debug('LegendListContextMenu.removeCurveAction called') + modelIndex = self.currentIdx() + legend = str(modelIndex.data(qt.Qt.DisplayRole)) + ddict = { + 'legend': legend, + 'label': legend, + 'selected': modelIndex.data(qt.Qt.CheckStateRole), + 'type': str(modelIndex.data()), + 'event': "removeCurve" + } + self.model.removeRow(modelIndex.row()) + self.sigContextMenu.emit(ddict) + + def renameItemAction(self): + _logger.debug('LegendListContextMenu.renameCurveAction called') + modelIndex = self.currentIdx() + legend = str(modelIndex.data(qt.Qt.DisplayRole)) + ddict = { + 'legend': legend, + 'label': legend, + 'selected': modelIndex.data(qt.Qt.CheckStateRole), + 'type': str(modelIndex.data()), + 'event': "renameCurve" + } + self.sigContextMenu.emit(ddict) + + def toggleLinesAction(self): + modelIndex = self.currentIdx() + legend = str(modelIndex.data(qt.Qt.DisplayRole)) + ddict = { + 'legend': legend, + 'label': legend, + 'selected': modelIndex.data(qt.Qt.CheckStateRole), + 'type': str(modelIndex.data()), + } + linestyle = modelIndex.data(LegendModel.iconLineStyleRole) + visible = not modelIndex.data(LegendModel.showLineRole) + _logger.debug('toggleLinesAction -- lines visible: %s', str(visible)) + ddict['event'] = "toggleLine" + ddict['line'] = visible + ddict['linestyle'] = linestyle if visible else '' + self.model.setData(modelIndex, visible, LegendModel.showLineRole) + self.sigContextMenu.emit(ddict) + + def togglePointsAction(self): + modelIndex = self.currentIdx() + legend = str(modelIndex.data(qt.Qt.DisplayRole)) + ddict = { + 'legend': legend, + 'label': legend, + 'selected': modelIndex.data(qt.Qt.CheckStateRole), + 'type': str(modelIndex.data()), + } + flag = modelIndex.data(LegendModel.showSymbolRole) + symbol = modelIndex.data(LegendModel.iconSymbolRole) + visible = not flag or LegendIconWidget.isEmptySymbol(symbol) + _logger.debug( + 'togglePointsAction -- Symbols visible: %s', str(visible)) + + ddict['event'] = "togglePoints" + ddict['points'] = visible + ddict['symbol'] = symbol if visible else '' + self.model.setData(modelIndex, visible, LegendModel.showSymbolRole) + self.sigContextMenu.emit(ddict) + + def setActiveAction(self): + modelIndex = self.currentIdx() + legend = str(modelIndex.data(qt.Qt.DisplayRole)) + _logger.debug('setActiveAction -- active curve: %s', legend) + ddict = { + 'legend': legend, + 'label': legend, + 'selected': modelIndex.data(qt.Qt.CheckStateRole), + 'type': str(modelIndex.data()), + 'event': "setActiveCurve", + } + self.sigContextMenu.emit(ddict) + + +class RenameCurveDialog(qt.QDialog): + """Dialog box to input the name of a curve.""" + + def __init__(self, parent=None, current="", curves=()): + super(RenameCurveDialog, self).__init__(parent) + self.setWindowTitle("Rename Curve %s" % current) + self.curves = curves + layout = qt.QVBoxLayout(self) + self.lineEdit = qt.QLineEdit(self) + self.lineEdit.setText(current) + self.hbox = qt.QWidget(self) + self.hboxLayout = qt.QHBoxLayout(self.hbox) + self.hboxLayout.addStretch(1) + self.okButton = qt.QPushButton(self.hbox) + self.okButton.setText('OK') + self.hboxLayout.addWidget(self.okButton) + self.cancelButton = qt.QPushButton(self.hbox) + self.cancelButton.setText('Cancel') + self.hboxLayout.addWidget(self.cancelButton) + self.hboxLayout.addStretch(1) + layout.addWidget(self.lineEdit) + layout.addWidget(self.hbox) + self.okButton.clicked.connect(self.preAccept) + self.cancelButton.clicked.connect(self.reject) + + def preAccept(self): + text = str(self.lineEdit.text()) + addedText = "" + if len(text): + if text not in self.curves: + self.accept() + return + else: + addedText = "Curve already exists." + text = "Invalid Curve Name" + msg = qt.QMessageBox(self) + msg.setIcon(qt.QMessageBox.Critical) + msg.setWindowTitle(text) + text += "\n%s" % addedText + msg.setText(text) + msg.exec() + + def getText(self): + return str(self.lineEdit.text()) + + +class LegendsDockWidget(qt.QDockWidget): + """QDockWidget with a :class:`LegendSelector` connected to a PlotWindow. + + It makes the link between the LegendListView widget and the PlotWindow. + + :param parent: See :class:`QDockWidget` + :param plot: :class:`.PlotWindow` instance on which to operate + """ + + def __init__(self, parent=None, plot=None): + assert plot is not None + self._plotRef = weakref.ref(plot) + self._isConnected = False # True if widget connected to plot signals + + super(LegendsDockWidget, self).__init__("Legends", parent) + + self._legendWidget = LegendListView() + + self.layout().setContentsMargins(0, 0, 0, 0) + self.setWidget(self._legendWidget) + + self.visibilityChanged.connect( + self._visibilityChangedHandler) + + self._legendWidget.sigLegendSignal.connect(self._legendSignalHandler) + + @property + def plot(self): + """The :class:`.PlotWindow` this widget is attached to.""" + return self._plotRef() + + def renameCurve(self, oldLegend, newLegend): + """Change the name of a curve using remove and addCurve + + :param str oldLegend: The legend of the curve to be changed + :param str newLegend: The new legend of the curve + """ + is_active = self.plot.getActiveCurve(just_legend=True) == oldLegend + curve = self.plot.getCurve(oldLegend) + self.plot.remove(oldLegend, kind='curve') + self.plot.addCurve(curve.getXData(copy=False), + curve.getYData(copy=False), + legend=newLegend, + info=curve.getInfo(), + color=curve.getColor(), + symbol=curve.getSymbol(), + linewidth=curve.getLineWidth(), + linestyle=curve.getLineStyle(), + xlabel=curve.getXLabel(), + ylabel=curve.getYLabel(), + xerror=curve.getXErrorData(copy=False), + yerror=curve.getYErrorData(copy=False), + z=curve.getZValue(), + selectable=curve.isSelectable(), + fill=curve.isFill(), + resetzoom=False) + if is_active: + self.plot.setActiveCurve(newLegend) + + def _legendSignalHandler(self, ddict): + """Handles events from the LegendListView signal""" + _logger.debug("Legend signal ddict = %s", str(ddict)) + + if ddict['event'] == "legendClicked": + if ddict['button'] == "left": + self.plot.setActiveCurve(ddict['legend']) + + elif ddict['event'] == "removeCurve": + self.plot.removeCurve(ddict['legend']) + + elif ddict['event'] == "renameCurve": + curveList = self.plot.getAllCurves(just_legend=True) + oldLegend = ddict['legend'] + dialog = RenameCurveDialog(self.plot, oldLegend, curveList) + ret = dialog.exec() + if ret: + newLegend = dialog.getText() + self.renameCurve(oldLegend, newLegend) + + elif ddict['event'] == "setActiveCurve": + self.plot.setActiveCurve(ddict['legend']) + + elif ddict['event'] == "checkBoxClicked": + self.plot.hideCurve(ddict['legend'], not ddict['selected']) + + elif ddict['event'] in ["mapToRight", "mapToLeft"]: + legend = ddict['legend'] + curve = self.plot.getCurve(legend) + yaxis = 'right' if ddict['event'] == 'mapToRight' else 'left' + self.plot.addCurve(x=curve.getXData(copy=False), + y=curve.getYData(copy=False), + legend=curve.getName(), + info=curve.getInfo(), + yaxis=yaxis) + + elif ddict['event'] == "togglePoints": + legend = ddict['legend'] + curve = self.plot.getCurve(legend) + symbol = ddict['symbol'] if ddict['points'] else '' + self.plot.addCurve(x=curve.getXData(copy=False), + y=curve.getYData(copy=False), + legend=curve.getName(), + info=curve.getInfo(), + symbol=symbol) + + elif ddict['event'] == "toggleLine": + legend = ddict['legend'] + curve = self.plot.getCurve(legend) + linestyle = ddict['linestyle'] if ddict['line'] else '' + self.plot.addCurve(x=curve.getXData(copy=False), + y=curve.getYData(copy=False), + legend=curve.getName(), + info=curve.getInfo(), + linestyle=linestyle) + + else: + _logger.debug("unhandled event %s", str(ddict['event'])) + + def updateLegends(self, *args): + """Sync the LegendSelector widget displayed info with the plot. + """ + legendList = [] + for curve in self.plot.getAllCurves(withhidden=True): + legend = curve.getName() + # Use active color if curve is active + isActive = legend == self.plot.getActiveCurve(just_legend=True) + style = curve.getCurrentStyle() + color = style.getColor() + if numpy.array(color, copy=False).ndim != 1: + # array of colors, use transparent black + color = 0., 0., 0., 0. + + curveInfo = { + 'color': qt.QColor.fromRgbF(*color), + 'linewidth': style.getLineWidth(), + 'linestyle': style.getLineStyle(), + 'symbol': style.getSymbol(), + 'selected': not self.plot.isCurveHidden(legend), + 'active': isActive} + legendList.append((legend, curveInfo)) + + self._legendWidget.setLegendList(legendList) + + def _visibilityChangedHandler(self, visible): + if visible: + self.updateLegends() + if not self._isConnected: + self.plot.sigContentChanged.connect(self.updateLegends) + self.plot.sigActiveCurveChanged.connect(self.updateLegends) + self._isConnected = True + else: + if self._isConnected: + self.plot.sigContentChanged.disconnect(self.updateLegends) + self.plot.sigActiveCurveChanged.disconnect(self.updateLegends) + self._isConnected = False + + def showEvent(self, event): + """Make sure this widget is raised when it is shown + (when it is first created as a tab in PlotWindow or when it is shown + again after hiding). + """ + self.raise_() diff --git a/src/silx/gui/plot/LimitsHistory.py b/src/silx/gui/plot/LimitsHistory.py new file mode 100644 index 0000000..a323548 --- /dev/null +++ b/src/silx/gui/plot/LimitsHistory.py @@ -0,0 +1,83 @@ +# coding: utf-8 +# /*########################################################################## +# +# Copyright (c) 2017 European Synchrotron Radiation Facility +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. +# +# ###########################################################################*/ +"""This module provides handling of :class:`PlotWidget` limits history. +""" + +__authors__ = ["T. Vincent"] +__license__ = "MIT" +__date__ = "19/07/2017" + + +from .. import qt + + +class LimitsHistory(qt.QObject): + """Class handling history of limits of a :class:`PlotWidget`. + + :param PlotWidget parent: The plot widget this object is bound to. + """ + + def __init__(self, parent): + self._history = [] + super(LimitsHistory, self).__init__(parent) + self.setParent(parent) + + def setParent(self, parent): + """See :meth:`QObject.setParent`. + + :param PlotWidget parent: The PlotWidget this object is bound to. + """ + self.clear() # Clear history when changing parent + super(LimitsHistory, self).setParent(parent) + + def push(self): + """Append current limits to the history.""" + plot = self.parent() + xmin, xmax = plot.getXAxis().getLimits() + ymin, ymax = plot.getYAxis(axis='left').getLimits() + y2min, y2max = plot.getYAxis(axis='right').getLimits() + self._history.append((xmin, xmax, ymin, ymax, y2min, y2max)) + + def pop(self): + """Restore previously limits stored in the history. + + :return: True if limits were restored, False if history was empty. + :rtype: bool + """ + plot = self.parent() + if self._history: + limits = self._history.pop(-1) + plot.setLimits(*limits) + return True + else: + plot.resetZoom() + return False + + def clear(self): + """Clear stored limits states.""" + self._history = [] + + def __len__(self): + return len(self._history) diff --git a/src/silx/gui/plot/MaskToolsWidget.py b/src/silx/gui/plot/MaskToolsWidget.py new file mode 100644 index 0000000..522be48 --- /dev/null +++ b/src/silx/gui/plot/MaskToolsWidget.py @@ -0,0 +1,919 @@ +# coding: utf-8 +# /*########################################################################## +# +# Copyright (c) 2017-2021 European Synchrotron Radiation Facility +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. +# +# ###########################################################################*/ +"""Widget providing a set of tools to draw masks on a PlotWidget. + +This widget is meant to work with :class:`silx.gui.plot.PlotWidget`. + +- :class:`ImageMask`: Handle mask bitmap update and history +- :class:`MaskToolsWidget`: GUI for :class:`Mask` +- :class:`MaskToolsDockWidget`: DockWidget to integrate in :class:`PlotWindow` +""" +from __future__ import division + +__authors__ = ["T. Vincent", "P. Knobel"] +__license__ = "MIT" +__date__ = "08/12/2020" + +import os +import sys +import numpy +import logging +import collections +import h5py + +from silx.image import shapes +from silx.io.utils import NEXUS_HDF5_EXT, is_dataset +from silx.gui.dialog.DatasetDialog import DatasetDialog + +from ._BaseMaskToolsWidget import BaseMask, BaseMaskToolsWidget, BaseMaskToolsDockWidget +from . import items +from ..colors import cursorColorForColormap, rgba +from .. import qt +from ..utils import LockReentrant + +from silx.third_party.EdfFile import EdfFile +from silx.third_party.TiffIO import TiffIO + +import fabio + +_logger = logging.getLogger(__name__) + +_HDF5_EXT_STR = ' '.join(['*' + ext for ext in NEXUS_HDF5_EXT]) + + +def _selectDataset(filename, mode=DatasetDialog.SaveMode): + """Open a dialog to prompt the user to select a dataset in + a hdf5 file. + + :param str filename: name of an existing HDF5 file + :param mode: DatasetDialog.SaveMode or DatasetDialog.LoadMode + :rtype: str + :return: Name of selected dataset + """ + dialog = DatasetDialog() + dialog.addFile(filename) + dialog.setWindowTitle("Select a 2D dataset") + dialog.setMode(mode) + if not dialog.exec(): + return None + return dialog.getSelectedDataUrl().data_path() + + +class ImageMask(BaseMask): + """A 2D mask field with update operations. + + Coords follows (row, column) convention and are in mask array coords. + + This is meant for internal use by :class:`MaskToolsWidget`. + """ + + def __init__(self, image=None): + """ + + :param image: :class:`silx.gui.plot.items.ImageBase` instance + """ + BaseMask.__init__(self, image) + self.reset(shape=(0, 0)) # Init the mask with a 2D shape + + def getDataValues(self): + """Return image data as a 2D or 3D array (if it is a RGBA image). + + :rtype: 2D or 3D numpy.ndarray + """ + return self._dataItem.getData(copy=False) + + def save(self, filename, kind): + """Save current mask in a file + + :param str filename: The file where to save to mask + :param str kind: The kind of file to save in 'edf', 'tif', 'npy', 'h5' + or 'msk' (if FabIO is installed) + :raise Exception: Raised if the file writing fail + """ + if kind == 'edf': + edfFile = EdfFile(filename, access="w+") + header = {"program_name": "silx-mask", "masked_value": "nonzero"} + edfFile.WriteImage(header, self.getMask(copy=False), Append=0) + + elif kind == 'tif': + tiffFile = TiffIO(filename, mode='w') + tiffFile.writeImage(self.getMask(copy=False), software='silx') + + elif kind == 'npy': + try: + numpy.save(filename, self.getMask(copy=False)) + except IOError: + raise RuntimeError("Mask file can't be written") + + elif ("." + kind) in NEXUS_HDF5_EXT: + self._saveToHdf5(filename, self.getMask(copy=False)) + + elif kind == 'msk': + try: + data = self.getMask(copy=False) + image = fabio.fabioimage.FabioImage(data=data) + image = image.convert(fabio.fit2dmaskimage.Fit2dMaskImage) + image.save(filename) + except Exception: + _logger.debug("Backtrace", exc_info=True) + raise RuntimeError("Mask file can't be written") + else: + raise ValueError("Format '%s' is not supported" % kind) + + @staticmethod + def _saveToHdf5(filename, mask): + """Save a mask array to a HDF5 file. + + :param str filename: name of an existing HDF5 file + :param numpy.ndarray mask: Mask array. + :returns: True if operation succeeded, False otherwise. + """ + if not os.path.exists(filename): + # create new file + with h5py.File(filename, "w") as _h5f: + pass + dataPath = _selectDataset(filename) + if dataPath is None: + return False + with h5py.File(filename, "a") as h5f: + existing_ds = h5f.get(dataPath) + if existing_ds is not None: + reply = qt.QMessageBox.question( + None, + "Confirm overwrite", + "Do you want to overwrite an existing dataset?", + qt.QMessageBox.Yes | qt.QMessageBox.No) + if reply != qt.QMessageBox.Yes: + return False + del h5f[dataPath] + try: + h5f.create_dataset(dataPath, data=mask) + except Exception: + return False + return True + + # Drawing operations + def updateRectangle(self, level, row, col, height, width, mask=True): + """Mask/Unmask a rectangle of the given mask level. + + :param int level: Mask level to update. + :param int row: Starting row of the rectangle + :param int col: Starting column of the rectangle + :param int height: + :param int width: + :param bool mask: True to mask (default), False to unmask. + """ + assert 0 < level < 256 + if row + height <= 0 or col + width <= 0: + return # Rectangle outside image, avoid negative indices + selection = self._mask[max(0, row):row + height + 1, + max(0, col):col + width + 1] + if mask: + selection[:,:] = level + else: + selection[selection == level] = 0 + self._notify() + + def updatePolygon(self, level, vertices, mask=True): + """Mask/Unmask a polygon of the given mask level. + + :param int level: Mask level to update. + :param vertices: Nx2 array of polygon corners as (row, col) + :param bool mask: True to mask (default), False to unmask. + """ + fill = shapes.polygon_fill_mask(vertices, self._mask.shape) + if mask: + self._mask[fill != 0] = level + else: + self._mask[numpy.logical_and(fill != 0, + self._mask == level)] = 0 + self._notify() + + def updatePoints(self, level, rows, cols, mask=True): + """Mask/Unmask points with given coordinates. + + :param int level: Mask level to update. + :param rows: Rows of selected points + :type rows: 1D numpy.ndarray + :param cols: Columns of selected points + :type cols: 1D numpy.ndarray + :param bool mask: True to mask (default), False to unmask. + """ + valid = numpy.logical_and( + numpy.logical_and(rows >= 0, cols >= 0), + numpy.logical_and(rows < self._mask.shape[0], + cols < self._mask.shape[1])) + rows, cols = rows[valid], cols[valid] + + if mask: + self._mask[rows, cols] = level + else: + inMask = self._mask[rows, cols] == level + self._mask[rows[inMask], cols[inMask]] = 0 + self._notify() + + def updateDisk(self, level, crow, ccol, radius, mask=True): + """Mask/Unmask a disk of the given mask level. + + :param int level: Mask level to update. + :param int crow: Disk center row. + :param int ccol: Disk center column. + :param float radius: Radius of the disk in mask array unit + :param bool mask: True to mask (default), False to unmask. + """ + rows, cols = shapes.circle_fill(crow, ccol, radius) + self.updatePoints(level, rows, cols, mask) + + def updateEllipse(self, level, crow, ccol, radius_r, radius_c, mask=True): + """Mask/Unmask an ellipse of the given mask level. + + :param int level: Mask level to update. + :param int crow: Row of the center of the ellipse + :param int ccol: Column of the center of the ellipse + :param float radius_r: Radius of the ellipse in the row + :param float radius_c: Radius of the ellipse in the column + :param bool mask: True to mask (default), False to unmask. + """ + rows, cols = shapes.ellipse_fill(crow, ccol, radius_r, radius_c) + self.updatePoints(level, rows, cols, mask) + + def updateLine(self, level, row0, col0, row1, col1, width, mask=True): + """Mask/Unmask a line of the given mask level. + + :param int level: Mask level to update. + :param int row0: Row of the starting point. + :param int col0: Column of the starting point. + :param int row1: Row of the end point. + :param int col1: Column of the end point. + :param int width: Width of the line in mask array unit. + :param bool mask: True to mask (default), False to unmask. + """ + rows, cols = shapes.draw_line(row0, col0, row1, col1, width) + self.updatePoints(level, rows, cols, mask) + + +class MaskToolsWidget(BaseMaskToolsWidget): + """Widget with tools for drawing mask on an image in a PlotWidget.""" + + _maxLevelNumber = 255 + + def __init__(self, parent=None, plot=None): + super(MaskToolsWidget, self).__init__(parent, plot, + mask=ImageMask()) + self._origin = (0., 0.) # Mask origin in plot + self._scale = (1., 1.) # Mask scale in plot + self._z = 1 # Mask layer in plot + self._data = numpy.zeros((0, 0), dtype=numpy.uint8) # Store image + + self.__itemMaskUpdatedLock = LockReentrant() + self.__itemMaskUpdated = False + + def __maskStateChanged(self) -> None: + """Handle mask commit to update item mask""" + item = self._mask.getDataItem() + if item is not None: + with self.__itemMaskUpdatedLock: + item.setMaskData(self._mask.getMask(copy=True), copy=False) + + def setItemMaskUpdated(self, enabled: bool) -> None: + """Toggle item mask and mask tool synchronisation. + + :param bool enabled: True to synchronise. Default: False + """ + enabled = bool(enabled) + if enabled != self.__itemMaskUpdated: + if self.__itemMaskUpdated: + self._mask.sigStateChanged.disconnect(self.__maskStateChanged) + self.__itemMaskUpdated = enabled + if self.__itemMaskUpdated: + # Synchronize item and tool mask + self._setMaskedImage(self._mask.getDataItem()) + self._mask.sigStateChanged.connect(self.__maskStateChanged) + + def isItemMaskUpdated(self) -> bool: + """Returns whether or not item and mask tool masks are synchronised. + + :rtype: bool + """ + return self.__itemMaskUpdated + + def setSelectionMask(self, mask, copy=True): + """Set the mask to a new array. + + :param numpy.ndarray mask: + The array to use for the mask or None to reset the mask. + :type mask: numpy.ndarray of uint8 of dimension 2, C-contiguous. + Array of other types are converted. + :param bool copy: True (the default) to copy the array, + False to use it as is if possible. + :return: None if failed, shape of mask as 2-tuple if successful. + The mask can be cropped or padded to fit active image, + the returned shape is that of the active image. + """ + if mask is None: + self.resetSelectionMask() + return self._data.shape[:2] + + mask = numpy.array(mask, copy=False, dtype=numpy.uint8) + if len(mask.shape) != 2: + _logger.error('Not an image, shape: %d', len(mask.shape)) + return None + + # Handle mask with single level + if self.multipleMasks() == 'single': + mask = numpy.array(mask != 0, dtype=numpy.uint8) + + # if mask has not changed, do nothing + if numpy.array_equal(mask, self.getSelectionMask()): + return mask.shape + + if self._data.shape[0:2] == (0, 0) or mask.shape == self._data.shape[0:2]: + self._mask.setMask(mask, copy=copy) + self._mask.commit() + return mask.shape + else: + _logger.warning('Mask has not the same size as current image.' + ' Mask will be cropped or padded to fit image' + ' dimensions. %s != %s', + str(mask.shape), str(self._data.shape)) + resizedMask = numpy.zeros(self._data.shape[0:2], + dtype=numpy.uint8) + height = min(self._data.shape[0], mask.shape[0]) + width = min(self._data.shape[1], mask.shape[1]) + resizedMask[:height,:width] = mask[:height,:width] + self._mask.setMask(resizedMask, copy=False) + self._mask.commit() + return resizedMask.shape + + # Handle mask refresh on the plot + def _updatePlotMask(self): + """Update mask image in plot""" + mask = self.getSelectionMask(copy=False) + if mask is not None: + # get the mask from the plot + maskItem = self.plot.getImage(self._maskName) + mustBeAdded = maskItem is None + if mustBeAdded: + maskItem = items.MaskImageData() + maskItem.setName(self._maskName) + # update the items + maskItem.setData(mask, copy=False) + maskItem.setColormap(self._colormap) + maskItem.setOrigin(self._origin) + maskItem.setScale(self._scale) + maskItem.setZValue(self._z) + + if mustBeAdded: + self.plot.addItem(maskItem) + + elif self.plot.getImage(self._maskName): + self.plot.remove(self._maskName, kind='image') + + def showEvent(self, event): + try: + self.plot.sigActiveImageChanged.disconnect( + self._activeImageChangedAfterCare) + except (RuntimeError, TypeError): + pass + + # Sync with current active image + self._setMaskedImage(self.plot.getActiveImage()) + self.plot.sigActiveImageChanged.connect(self._activeImageChanged) + + def hideEvent(self, event): + try: + self.plot.sigActiveImageChanged.disconnect( + self._activeImageChanged) + except (RuntimeError, TypeError): + pass + + image = self.getMaskedItem() + if image is not None: + try: + image.sigItemChanged.disconnect(self.__imageChanged) + except (RuntimeError, TypeError): + pass # TODO should not happen + + if self.isMaskInteractionActivated(): + # Disable drawing tool + self.browseAction.trigger() + + if self.isItemMaskUpdated(): # No "after-care" + self._data = numpy.zeros((0, 0), dtype=numpy.uint8) + self._mask.setDataItem(None) + self._mask.reset() + + if self.plot.getImage(self._maskName): + self.plot.remove(self._maskName, kind='image') + + elif self.getSelectionMask(copy=False) is not None: + self.plot.sigActiveImageChanged.connect( + self._activeImageChangedAfterCare) + + def _activeImageChanged(self, previous, current): + """Reacts upon active image change. + + Only handle change of active image items here. + """ + if previous != current: + image = self.plot.getActiveImage() + if image is not None and image.getName() == self._maskName: + image = None # Active image is the mask + self._setMaskedImage(image) + + def _setOverlayColorForImage(self, image): + """Set the color of overlay adapted to image + + :param image: :class:`.items.ImageBase` object to set color for. + """ + if isinstance(image, items.ColormapMixIn): + colormap = image.getColormap() + self._defaultOverlayColor = rgba( + cursorColorForColormap(colormap['name'])) + else: + self._defaultOverlayColor = rgba('black') + + def _activeImageChangedAfterCare(self, *args): + """Check synchro of active image and mask when mask widget is hidden. + + If active image has no more the same size as the mask, the mask is + removed, otherwise it is adjusted to origin, scale and z. + """ + activeImage = self.plot.getActiveImage() + if activeImage is None or activeImage.getName() == self._maskName: + # No active image or active image is the mask... + self._data = numpy.zeros((0, 0), dtype=numpy.uint8) + self._mask.setDataItem(None) + self._mask.reset() + + if self.plot.getImage(self._maskName): + self.plot.remove(self._maskName, kind='image') + + self.plot.sigActiveImageChanged.disconnect( + self._activeImageChangedAfterCare) + else: + self._setOverlayColorForImage(activeImage) + self._setMaskColors(self.levelSpinBox.value(), + self.transparencySlider.value() / + self.transparencySlider.maximum()) + + self._origin = activeImage.getOrigin() + self._scale = activeImage.getScale() + self._z = activeImage.getZValue() + 1 + self._data = activeImage.getData(copy=False) + if self._data.shape[:2] != self._mask.getMask(copy=False).shape: + # Image has not the same size, remove mask and stop listening + if self.plot.getImage(self._maskName): + self.plot.remove(self._maskName, kind='image') + + self.plot.sigActiveImageChanged.disconnect( + self._activeImageChangedAfterCare) + else: + # Refresh in case origin, scale, z changed + self._mask.setDataItem(activeImage) + self._updatePlotMask() + + def _setMaskedImage(self, image): + """Change the image that is used a reference to author the mask""" + previous = self.getMaskedItem() + if previous is not None and self.isVisible(): + # Disconnect from previous image + try: + previous.sigItemChanged.disconnect(self.__imageChanged) + except (RuntimeError, TypeError): + pass # TODO fixme should not happen + + # Set the image + self._mask.setDataItem(image) + + if image is None: # No image, disable mask + self.setEnabled(False) + + self._data = numpy.zeros((0, 0), dtype=numpy.uint8) + self._mask.reset() + self._mask.commit() + + self._updateInteractiveMode() + + else: # Update and connect to image's sigItemChanged + if self.isItemMaskUpdated(): + if image.getMaskData(copy=False) is None: + # Image item has no mask: use current mask from the tool + image.setMaskData( + self.getSelectionMask(copy=False), copy=True) + else: # Image item has a mask: set it in tool + self.setSelectionMask( + image.getMaskData(copy=False), copy=True) + self._mask.resetHistory() + self.__imageUpdated() + if self.isVisible(): + image.sigItemChanged.connect(self.__imageChanged) + + def __imageChanged(self, event): + """Reacts upon image item changes""" + image = self._mask.getDataItem() + if image is None: + _logger.error("Mask is not attached to an image") + return + + if event in (items.ItemChangedType.COLORMAP, + items.ItemChangedType.DATA, + items.ItemChangedType.POSITION, + items.ItemChangedType.SCALE, + items.ItemChangedType.VISIBLE, + items.ItemChangedType.ZVALUE): + self.__imageUpdated() + + elif (event == items.ItemChangedType.MASK and + self.isItemMaskUpdated() and + not self.__itemMaskUpdatedLock.locked()): + # Update mask from the image item unless mask tool is updating it + self.setSelectionMask(image.getMaskData(copy=False), copy=True) + + def __imageUpdated(self): + """Synchronize mask with current state of the image""" + image = self._mask.getDataItem() + if image is None: + _logger.error("No active image while expecting one") + return + + self._setOverlayColorForImage(image) + + self._setMaskColors(self.levelSpinBox.value(), + self.transparencySlider.value() / + self.transparencySlider.maximum()) + + self._origin = image.getOrigin() + self._scale = image.getScale() + self._z = image.getZValue() + 1 + self._data = image.getData(copy=False) + self._mask.setDataItem(image) + if self._data.shape[:2] != self._mask.getMask(copy=False).shape: + self._mask.reset(self._data.shape[:2]) + self._mask.commit() + else: + # Refresh in case origin, scale, z changed + self._updatePlotMask() + + # Visible and with data + self.setEnabled(image.isVisible() and self._data.size != 0) + + # Threshold tools only available for data with colormap + self.thresholdGroup.setEnabled(self._data.ndim == 2) + + self._updateInteractiveMode() + + # Handle whole mask operations + def load(self, filename): + """Load a mask from an image file. + + :param str filename: File name from which to load the mask + :raise Exception: An exception in case of failure + :raise RuntimeWarning: In case the mask was applied but with some + import changes to notice + """ + _, extension = os.path.splitext(filename) + extension = extension.lower()[1:] + + if extension == "npy": + try: + mask = numpy.load(filename) + except IOError: + _logger.error("Can't load filename '%s'", filename) + _logger.debug("Backtrace", exc_info=True) + raise RuntimeError('File "%s" is not a numpy file.', filename) + elif extension in ["tif", "tiff"]: + try: + image = TiffIO(filename, mode="r") + mask = image.getImage(0) + except Exception as e: + _logger.error("Can't load filename %s", filename) + _logger.debug("Backtrace", exc_info=True) + raise e + elif extension == "edf": + try: + mask = EdfFile(filename, access='r').GetData(0) + except Exception as e: + _logger.error("Can't load filename %s", filename) + _logger.debug("Backtrace", exc_info=True) + raise e + elif extension == "msk": + try: + mask = fabio.open(filename).data + except Exception as e: + _logger.error("Can't load fit2d mask file") + _logger.debug("Backtrace", exc_info=True) + raise e + elif ("." + extension) in NEXUS_HDF5_EXT: + mask = self._loadFromHdf5(filename) + if mask is None: + raise IOError("Could not load mask from HDF5 dataset") + else: + msg = "Extension '%s' is not supported." + raise RuntimeError(msg % extension) + + effectiveMaskShape = self.setSelectionMask(mask, copy=False) + if effectiveMaskShape is None: + return + if mask.shape != effectiveMaskShape: + msg = 'Mask was resized from %s to %s' + msg = msg % (str(mask.shape), str(effectiveMaskShape)) + raise RuntimeWarning(msg) + + def _loadMask(self): + """Open load mask dialog""" + dialog = qt.QFileDialog(self) + dialog.setWindowTitle("Load Mask") + dialog.setModal(1) + + extensions = collections.OrderedDict() + extensions["EDF files"] = "*.edf" + extensions["TIFF files"] = "*.tif *.tiff" + extensions["NumPy binary files"] = "*.npy" + extensions["HDF5 files"] = _HDF5_EXT_STR + # Fit2D mask is displayed anyway fabio is here or not + # to show to the user that the option exists + extensions["Fit2D mask files"] = "*.msk" + + filters = [] + filters.append("All supported files (%s)" % " ".join(extensions.values())) + for name, extension in extensions.items(): + filters.append("%s (%s)" % (name, extension)) + filters.append("All files (*)") + + dialog.setNameFilters(filters) + dialog.setFileMode(qt.QFileDialog.ExistingFile) + dialog.setDirectory(self.maskFileDir) + if not dialog.exec(): + dialog.close() + return + + filename = dialog.selectedFiles()[0] + dialog.close() + + # Update the directory according to the user selection + self.maskFileDir = os.path.dirname(filename) + + try: + self.load(filename) + except RuntimeWarning as e: + message = e.args[0] + msg = qt.QMessageBox(self) + msg.setIcon(qt.QMessageBox.Warning) + msg.setText("Mask loaded but an operation was applied.\n" + message) + msg.exec() + except Exception as e: + message = e.args[0] + msg = qt.QMessageBox(self) + msg.setIcon(qt.QMessageBox.Critical) + msg.setText("Cannot load mask from file. " + message) + msg.exec() + + @staticmethod + def _loadFromHdf5(filename): + """Load a mask array from a HDF5 file. + + :param str filename: name of an existing HDF5 file + :returns: AÂ mask as a numpy array, or None if the interactive dialog + was cancelled + """ + dataPath = _selectDataset(filename, mode=DatasetDialog.LoadMode) + if dataPath is None: + return None + + with h5py.File(filename, "r") as h5f: + dataset = h5f.get(dataPath) + if not is_dataset(dataset): + raise IOError("%s is not a dataset" % dataPath) + mask = dataset[()] + return mask + + def _saveMask(self): + """Open Save mask dialog""" + dialog = qt.QFileDialog(self) + dialog.setWindowTitle("Save Mask") + dialog.setOption(dialog.DontUseNativeDialog) + dialog.setModal(1) + hdf5Filter = 'HDF5 (%s)' % _HDF5_EXT_STR + filters = [ + 'EDF (*.edf)', + 'TIFF (*.tif)', + 'NumPy binary file (*.npy)', + hdf5Filter, + # Fit2D mask is displayed anyway fabio is here or not + # to show to the user that the option exists + 'Fit2D mask (*.msk)', + ] + dialog.setNameFilters(filters) + dialog.setFileMode(qt.QFileDialog.AnyFile) + dialog.setAcceptMode(qt.QFileDialog.AcceptSave) + dialog.setDirectory(self.maskFileDir) + + def onFilterSelection(filt_): + # disable overwrite confirmation for HDF5, + # because we append the data to existing files + if filt_ == hdf5Filter: + dialog.setOption(dialog.DontConfirmOverwrite) + else: + dialog.setOption(dialog.DontConfirmOverwrite, False) + + dialog.filterSelected.connect(onFilterSelection) + if not dialog.exec(): + dialog.close() + return + + nameFilter = dialog.selectedNameFilter() + filename = dialog.selectedFiles()[0] + dialog.close() + + if "HDF5" in nameFilter: + has_allowed_ext = False + for ext in NEXUS_HDF5_EXT: + if (len(filename) > len(ext) and + filename[-len(ext):].lower() == ext.lower()): + has_allowed_ext = True + extension = ext + if not has_allowed_ext: + extension = ".h5" + filename += ".h5" + else: + # convert filter name to extension name with the . + extension = nameFilter.split()[-1][2:-1] + if not filename.lower().endswith(extension): + filename += extension + + if os.path.exists(filename) and "HDF5" not in nameFilter: + try: + os.remove(filename) + except IOError as e: + msg = qt.QMessageBox(self) + msg.setWindowTitle("Removing existing file") + msg.setIcon(qt.QMessageBox.Critical) + + if hasattr(e, "strerror"): + strerror = e.strerror + else: + strerror = sys.exc_info()[1] + msg.setText("Cannot save.\n" + "Input Output Error: %s" % strerror) + msg.exec() + return + + # Update the directory according to the user selection + self.maskFileDir = os.path.dirname(filename) + + try: + self.save(filename, extension[1:]) + except Exception as e: + msg = qt.QMessageBox(self) + msg.setWindowTitle("Saving mask file") + msg.setIcon(qt.QMessageBox.Critical) + + if hasattr(e, "strerror"): + strerror = e.strerror + else: + strerror = sys.exc_info()[1] + msg.setText("Cannot save file %s\n%s" % (filename, strerror)) + msg.exec() + + def resetSelectionMask(self): + """Reset the mask""" + self._mask.reset(shape=self._data.shape[:2]) + self._mask.commit() + + def _plotDrawEvent(self, event): + """Handle draw events from the plot""" + if (self._drawingMode is None or + event['event'] not in ('drawingProgress', 'drawingFinished')): + return + + if not len(self._data): + return + + level = self.levelSpinBox.value() + + if self._drawingMode == 'rectangle': + if event['event'] == 'drawingFinished': + # Convert from plot to array coords + doMask = self._isMasking() + ox, oy = self._origin + sx, sy = self._scale + + height = int(abs(event['height'] / sy)) + width = int(abs(event['width'] / sx)) + + row = int((event['y'] - oy) / sy) + if sy < 0: + row -= height + + col = int((event['x'] - ox) / sx) + if sx < 0: + col -= width + + self._mask.updateRectangle( + level, + row=row, + col=col, + height=height, + width=width, + mask=doMask) + self._mask.commit() + + elif self._drawingMode == 'ellipse': + if event['event'] == 'drawingFinished': + doMask = self._isMasking() + # Convert from plot to array coords + center = (event['points'][0] - self._origin) / self._scale + size = event['points'][1] / self._scale + center = center.astype(numpy.int64) # (row, col) + self._mask.updateEllipse(level, center[1], center[0], size[1], size[0], doMask) + self._mask.commit() + + elif self._drawingMode == 'polygon': + if event['event'] == 'drawingFinished': + doMask = self._isMasking() + # Convert from plot to array coords + vertices = (event['points'] - self._origin) / self._scale + vertices = vertices.astype(numpy.int64)[:, (1, 0)] # (row, col) + self._mask.updatePolygon(level, vertices, doMask) + self._mask.commit() + + elif self._drawingMode == 'pencil': + doMask = self._isMasking() + # convert from plot to array coords + col, row = (event['points'][-1] - self._origin) / self._scale + col, row = int(col), int(row) + brushSize = self._getPencilWidth() + + if self._lastPencilPos != (row, col): + if self._lastPencilPos is not None: + # Draw the line + self._mask.updateLine( + level, + self._lastPencilPos[0], self._lastPencilPos[1], + row, col, + brushSize, + doMask) + + # Draw the very first, or last point + self._mask.updateDisk(level, row, col, brushSize / 2., doMask) + + if event['event'] == 'drawingFinished': + self._mask.commit() + self._lastPencilPos = None + else: + self._lastPencilPos = row, col + else: + _logger.error("Drawing mode %s unsupported", self._drawingMode) + + def _loadRangeFromColormapTriggered(self): + """Set range from active image colormap range""" + activeImage = self.plot.getActiveImage() + if (isinstance(activeImage, items.ColormapMixIn) and + activeImage.getName() != self._maskName): + # Update thresholds according to colormap + colormap = activeImage.getColormap() + if colormap['autoscale']: + min_ = numpy.nanmin(activeImage.getData(copy=False)) + max_ = numpy.nanmax(activeImage.getData(copy=False)) + else: + min_, max_ = colormap['vmin'], colormap['vmax'] + self.minLineEdit.setText(str(min_)) + self.maxLineEdit.setText(str(max_)) + + +class MaskToolsDockWidget(BaseMaskToolsDockWidget): + """:class:`MaskToolsWidget` embedded in a QDockWidget. + + For integration in a :class:`PlotWindow`. + + :param parent: See :class:`QDockWidget` + :param plot: The PlotWidget this widget is operating on + :paran str name: The title of this widget + """ + + def __init__(self, parent=None, plot=None, name='Mask'): + widget = MaskToolsWidget(plot=plot) + super(MaskToolsDockWidget, self).__init__(parent, name, widget) diff --git a/src/silx/gui/plot/PlotActions.py b/src/silx/gui/plot/PlotActions.py new file mode 100644 index 0000000..dd16221 --- /dev/null +++ b/src/silx/gui/plot/PlotActions.py @@ -0,0 +1,67 @@ +# coding: utf-8 +# /*########################################################################## +# +# Copyright (c) 2004-2017 European Synchrotron Radiation Facility +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. +# +# ###########################################################################*/ +"""Depracted module linking old PlotAction with the actions.xxx""" + + +__author__ = ["V.A. Sole", "T. Vincent"] +__license__ = "MIT" +__date__ = "01/06/2017" + +from silx.utils.deprecation import deprecated_warning + +deprecated_warning(type_='module', + name=__file__, + reason='PlotActions refactoring', + replacement='plot.actions', + since_version='0.6') + +from .actions import PlotAction + +from .actions.io import CopyAction +from .actions.io import PrintAction +from .actions.io import SaveAction + +from .actions.control import ColormapAction +from .actions.control import CrosshairAction +from .actions.control import CurveStyleAction +from .actions.control import GridAction +from .actions.control import KeepAspectRatioAction +from .actions.control import PanWithArrowKeysAction +from .actions.control import ResetZoomAction +from .actions.control import XAxisAutoScaleAction +from .actions.control import XAxisLogarithmicAction +from .actions.control import YAxisAutoScaleAction +from .actions.control import YAxisLogarithmicAction +from .actions.control import YAxisInvertedAction +from .actions.control import ZoomInAction +from .actions.control import ZoomOutAction + +from .actions.medfilt import MedianFilter1DAction +from .actions.medfilt import MedianFilter2DAction +from .actions.medfilt import MedianFilterAction + +from .actions.histogram import PixelIntensitiesHistoAction + +from .actions.fit import FitAction diff --git a/src/silx/gui/plot/PlotEvents.py b/src/silx/gui/plot/PlotEvents.py new file mode 100644 index 0000000..83f253c --- /dev/null +++ b/src/silx/gui/plot/PlotEvents.py @@ -0,0 +1,166 @@ +# coding: utf-8 +# /*########################################################################## +# +# Copyright (c) 2004-2016 European Synchrotron Radiation Facility +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. +# +# ###########################################################################*/ +"""Functions to prepare events to be sent to Plot callback.""" + +__author__ = ["V.A. Sole", "T. Vincent"] +__license__ = "MIT" +__date__ = "18/02/2016" + + +import numpy as np + + +def prepareDrawingSignal(event, type_, points, parameters=None): + """See Plot documentation for content of events""" + assert event in ('drawingProgress', 'drawingFinished') + + if parameters is None: + parameters = {} + + eventDict = {} + eventDict['event'] = event + eventDict['type'] = type_ + points = np.array(points, dtype=np.float32) + points.shape = -1, 2 + eventDict['points'] = points + eventDict['xdata'] = points[:, 0] + eventDict['ydata'] = points[:, 1] + if type_ in ('rectangle',): + eventDict['x'] = eventDict['xdata'].min() + eventDict['y'] = eventDict['ydata'].min() + eventDict['width'] = eventDict['xdata'].max() - eventDict['x'] + eventDict['height'] = eventDict['ydata'].max() - eventDict['y'] + eventDict['parameters'] = parameters.copy() + return eventDict + + +def prepareMouseSignal(eventType, button, xData, yData, xPixel, yPixel): + """See Plot documentation for content of events""" + assert eventType in ('mouseMoved', 'mouseClicked', 'mouseDoubleClicked') + assert button in (None, 'left', 'middle', 'right') + + return {'event': eventType, + 'x': xData, + 'y': yData, + 'xpixel': xPixel, + 'ypixel': yPixel, + 'button': button} + + +def prepareHoverSignal(label, type_, posData, posPixel, draggable, selectable): + """See Plot documentation for content of events""" + return {'event': 'hover', + 'label': label, + 'type': type_, + 'x': posData[0], + 'y': posData[1], + 'xpixel': posPixel[0], + 'ypixel': posPixel[1], + 'draggable': draggable, + 'selectable': selectable} + + +def prepareMarkerSignal(eventType, button, label, type_, + draggable, selectable, + posDataMarker, + posPixelCursor=None, posDataCursor=None): + """See Plot documentation for content of events""" + if eventType == 'markerClicked': + assert posPixelCursor is not None + assert posDataCursor is None + + posDataCursor = list(posDataMarker) + if hasattr(posDataCursor[0], "__len__"): + posDataCursor[0] = posDataCursor[0][-1] + if hasattr(posDataCursor[1], "__len__"): + posDataCursor[1] = posDataCursor[1][-1] + + elif eventType == 'markerMoving': + assert posPixelCursor is not None + assert posDataCursor is not None + + elif eventType == 'markerMoved': + assert posPixelCursor is None + assert posDataCursor is None + + posDataCursor = posDataMarker + else: + raise NotImplementedError("Unknown event type {0}".format(eventType)) + + eventDict = {'event': eventType, + 'button': button, + 'label': label, + 'type': type_, + 'x': posDataCursor[0], + 'y': posDataCursor[1], + 'xdata': posDataMarker[0], + 'ydata': posDataMarker[1], + 'draggable': draggable, + 'selectable': selectable} + + if eventType in ('markerMoving', 'markerClicked'): + eventDict['xpixel'] = posPixelCursor[0] + eventDict['ypixel'] = posPixelCursor[1] + + return eventDict + + +def prepareImageSignal(button, label, type_, col, row, + x, y, xPixel, yPixel): + """See Plot documentation for content of events""" + return {'event': 'imageClicked', + 'button': button, + 'label': label, + 'type': type_, + 'col': col, + 'row': row, + 'x': x, + 'y': y, + 'xpixel': xPixel, + 'ypixel': yPixel} + + +def prepareCurveSignal(button, label, type_, xData, yData, + x, y, xPixel, yPixel): + """See Plot documentation for content of events""" + return {'event': 'curveClicked', + 'button': button, + 'label': label, + 'type': type_, + 'xdata': xData, + 'ydata': yData, + 'x': x, + 'y': y, + 'xpixel': xPixel, + 'ypixel': yPixel} + + +def prepareLimitsChangedSignal(sourceObj, xRange, yRange, y2Range): + """See Plot documentation for content of events""" + return {'event': 'limitsChanged', + 'source': id(sourceObj), + 'xdata': xRange, + 'ydata': yRange, + 'y2data': y2Range} diff --git a/src/silx/gui/plot/PlotInteraction.py b/src/silx/gui/plot/PlotInteraction.py new file mode 100644 index 0000000..6ebe6b1 --- /dev/null +++ b/src/silx/gui/plot/PlotInteraction.py @@ -0,0 +1,1746 @@ +# coding: utf-8 +# /*########################################################################## +# +# Copyright (c) 2014-2021 European Synchrotron Radiation Facility +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. +# +# ###########################################################################*/ +"""Implementation of the interaction for the :class:`Plot`.""" + +__authors__ = ["T. Vincent"] +__license__ = "MIT" +__date__ = "15/02/2019" + + +import math +import numpy +import time +import weakref + +from .. import colors +from .. import qt +from . import items +from .Interaction import (ClickOrDrag, LEFT_BTN, RIGHT_BTN, MIDDLE_BTN, + State, StateMachine) +from .PlotEvents import (prepareCurveSignal, prepareDrawingSignal, + prepareHoverSignal, prepareImageSignal, + prepareMarkerSignal, prepareMouseSignal) + +from .backends.BackendBase import (CURSOR_POINTING, CURSOR_SIZE_HOR, + CURSOR_SIZE_VER, CURSOR_SIZE_ALL) + +from ._utils import (FLOAT32_SAFE_MIN, FLOAT32_MINPOS, FLOAT32_SAFE_MAX, + applyZoomToPlot) + + +# Base class ################################################################## + +class _PlotInteraction(object): + """Base class for interaction handler. + + It provides a weakref to the plot and methods to set/reset overlay. + """ + def __init__(self, plot): + """Init. + + :param plot: The plot to apply modifications to. + """ + self._needReplot = False + self._selectionAreas = set() + self._plot = weakref.ref(plot) # Avoid cyclic-ref + + @property + def plot(self): + plot = self._plot() + assert plot is not None + return plot + + def setSelectionArea(self, points, fill, color, name='', shape='polygon'): + """Set a polygon selection area overlaid on the plot. + Multiple simultaneous areas are supported through the name parameter. + + :param points: The 2D coordinates of the points of the polygon + :type points: An iterable of (x, y) coordinates + :param str fill: The fill mode: 'hatch', 'solid' or 'none' + :param color: RGBA color to use or None to disable display + :type color: list or tuple of 4 float in the range [0, 1] + :param name: The key associated with this selection area + :param str shape: Shape of the area in 'polygon', 'polylines' + """ + assert shape in ('polygon', 'polylines') + + if color is None: + return + + points = numpy.asarray(points) + + # TODO Not very nice, but as is for now + legend = '__SELECTION_AREA__' + name + + fill = fill != 'none' # TODO not very nice either + + greyed = colors.greyed(color)[0] + if greyed < 0.5: + color2 = "white" + else: + color2 = "black" + + self.plot.addShape(points[:, 0], points[:, 1], legend=legend, + replace=False, + shape=shape, fill=fill, + color=color, linebgcolor=color2, linestyle="--", + overlay=True) + + self._selectionAreas.add(legend) + + def resetSelectionArea(self): + """Remove all selection areas set by setSelectionArea.""" + for legend in self._selectionAreas: + self.plot.remove(legend, kind='item') + self._selectionAreas = set() + + +# Zoom/Pan #################################################################### + +class _ZoomOnWheel(ClickOrDrag, _PlotInteraction): + """:class:`ClickOrDrag` state machine with zooming on mouse wheel. + + Base class for :class:`Pan` and :class:`Zoom` + """ + + _DOUBLE_CLICK_TIMEOUT = 0.4 + + class Idle(ClickOrDrag.Idle): + def onWheel(self, x, y, angle): + scaleF = 1.1 if angle > 0 else 1. / 1.1 + applyZoomToPlot(self.machine.plot, scaleF, (x, y)) + + def click(self, x, y, btn): + """Handle clicks by sending events + + :param int x: Mouse X position in pixels + :param int y: Mouse Y position in pixels + :param btn: Clicked mouse button + """ + if btn == LEFT_BTN: + lastClickTime, lastClickPos = self._lastClick + + # Signal mouse double clicked event first + if (time.time() - lastClickTime) <= self._DOUBLE_CLICK_TIMEOUT: + # Use position of first click + eventDict = prepareMouseSignal('mouseDoubleClicked', 'left', + *lastClickPos) + self.plot.notify(**eventDict) + + self._lastClick = 0., None + else: + # Signal mouse clicked event + dataPos = self.plot.pixelToData(x, y) + assert dataPos is not None + eventDict = prepareMouseSignal('mouseClicked', 'left', + dataPos[0], dataPos[1], + x, y) + self.plot.notify(**eventDict) + + self._lastClick = time.time(), (dataPos[0], dataPos[1], x, y) + + elif btn == RIGHT_BTN: + # Signal mouse clicked event + dataPos = self.plot.pixelToData(x, y) + assert dataPos is not None + eventDict = prepareMouseSignal('mouseClicked', 'right', + dataPos[0], dataPos[1], + x, y) + self.plot.notify(**eventDict) + + def __init__(self, plot, **kwargs): + """Init. + + :param plot: The plot to apply modifications to. + """ + self._lastClick = 0., None + + _PlotInteraction.__init__(self, plot) + ClickOrDrag.__init__(self, **kwargs) + + +# Pan ######################################################################### + +class Pan(_ZoomOnWheel): + """Pan plot content and zoom on wheel state machine.""" + + def _pixelToData(self, x, y): + xData, yData = self.plot.pixelToData(x, y) + _, y2Data = self.plot.pixelToData(x, y, axis='right') + return xData, yData, y2Data + + def beginDrag(self, x, y, btn): + self._previousDataPos = self._pixelToData(x, y) + + def drag(self, x, y, btn): + xData, yData, y2Data = self._pixelToData(x, y) + lastX, lastY, lastY2 = self._previousDataPos + + xMin, xMax = self.plot.getXAxis().getLimits() + yMin, yMax = self.plot.getYAxis().getLimits() + y2Min, y2Max = self.plot.getYAxis(axis='right').getLimits() + + if self.plot.getXAxis()._isLogarithmic(): + try: + dx = math.log10(xData) - math.log10(lastX) + newXMin = pow(10., (math.log10(xMin) - dx)) + newXMax = pow(10., (math.log10(xMax) - dx)) + except (ValueError, OverflowError): + newXMin, newXMax = xMin, xMax + + # Makes sure both values stays in positive float32 range + if newXMin < FLOAT32_MINPOS or newXMax > FLOAT32_SAFE_MAX: + newXMin, newXMax = xMin, xMax + else: + dx = xData - lastX + newXMin, newXMax = xMin - dx, xMax - dx + + # Makes sure both values stays in float32 range + if newXMin < FLOAT32_SAFE_MIN or newXMax > FLOAT32_SAFE_MAX: + newXMin, newXMax = xMin, xMax + + if self.plot.getYAxis()._isLogarithmic(): + try: + dy = math.log10(yData) - math.log10(lastY) + newYMin = pow(10., math.log10(yMin) - dy) + newYMax = pow(10., math.log10(yMax) - dy) + + dy2 = math.log10(y2Data) - math.log10(lastY2) + newY2Min = pow(10., math.log10(y2Min) - dy2) + newY2Max = pow(10., math.log10(y2Max) - dy2) + except (ValueError, OverflowError): + newYMin, newYMax = yMin, yMax + newY2Min, newY2Max = y2Min, y2Max + + # Makes sure y and y2 stays in positive float32 range + if (newYMin < FLOAT32_MINPOS or newYMax > FLOAT32_SAFE_MAX or + newY2Min < FLOAT32_MINPOS or newY2Max > FLOAT32_SAFE_MAX): + newYMin, newYMax = yMin, yMax + newY2Min, newY2Max = y2Min, y2Max + else: + dy = yData - lastY + dy2 = y2Data - lastY2 + newYMin, newYMax = yMin - dy, yMax - dy + newY2Min, newY2Max = y2Min - dy2, y2Max - dy2 + + # Makes sure y and y2 stays in float32 range + if (newYMin < FLOAT32_SAFE_MIN or + newYMax > FLOAT32_SAFE_MAX or + newY2Min < FLOAT32_SAFE_MIN or + newY2Max > FLOAT32_SAFE_MAX): + newYMin, newYMax = yMin, yMax + newY2Min, newY2Max = y2Min, y2Max + + self.plot.setLimits(newXMin, newXMax, + newYMin, newYMax, + newY2Min, newY2Max) + + self._previousDataPos = self._pixelToData(x, y) + + def endDrag(self, startPos, endPos, btn): + del self._previousDataPos + + def cancel(self): + pass + + +# Zoom ######################################################################## + +class Zoom(_ZoomOnWheel): + """Zoom-in/out state machine. + + Zoom-in on selected area, zoom-out on right click, + and zoom on mouse wheel. + """ + + SURFACE_THRESHOLD = 5 + + def __init__(self, plot, color): + self.color = color + + super(Zoom, self).__init__(plot) + self.plot.getLimitsHistory().clear() + + def _areaWithAspectRatio(self, x0, y0, x1, y1): + _plotLeft, _plotTop, plotW, plotH = self.plot.getPlotBoundsInPixels() + + areaX0, areaY0, areaX1, areaY1 = x0, y0, x1, y1 + + if plotH != 0.: + plotRatio = plotW / float(plotH) + width, height = math.fabs(x1 - x0), math.fabs(y1 - y0) + + if height != 0. and width != 0.: + if width / height > plotRatio: + areaHeight = width / plotRatio + areaX0, areaX1 = x0, x1 + center = 0.5 * (y0 + y1) + areaY0 = center - numpy.sign(y1 - y0) * 0.5 * areaHeight + areaY1 = center + numpy.sign(y1 - y0) * 0.5 * areaHeight + else: + areaWidth = height * plotRatio + areaY0, areaY1 = y0, y1 + center = 0.5 * (x0 + x1) + areaX0 = center - numpy.sign(x1 - x0) * 0.5 * areaWidth + areaX1 = center + numpy.sign(x1 - x0) * 0.5 * areaWidth + + return areaX0, areaY0, areaX1, areaY1 + + def beginDrag(self, x, y, btn): + dataPos = self.plot.pixelToData(x, y) + assert dataPos is not None + self.x0, self.y0 = x, y + + def drag(self, x1, y1, btn): + if self.color is None: + return # Do not draw zoom area + + dataPos = self.plot.pixelToData(x1, y1) + assert dataPos is not None + + if self.plot.isKeepDataAspectRatio(): + area = self._areaWithAspectRatio(self.x0, self.y0, x1, y1) + areaX0, areaY0, areaX1, areaY1 = area + areaPoints = ((areaX0, areaY0), + (areaX1, areaY0), + (areaX1, areaY1), + (areaX0, areaY1)) + areaPoints = numpy.array([self.plot.pixelToData( + x, y, check=False) for (x, y) in areaPoints]) + + if self.color != 'video inverted': + areaColor = list(self.color) + areaColor[3] *= 0.25 + else: + areaColor = [1., 1., 1., 1.] + + self.setSelectionArea(areaPoints, + fill='none', + color=areaColor, + name="zoomedArea") + + corners = ((self.x0, self.y0), + (self.x0, y1), + (x1, y1), + (x1, self.y0)) + corners = numpy.array([self.plot.pixelToData(x, y, check=False) + for (x, y) in corners]) + + self.setSelectionArea(corners, fill='none', color=self.color) + + def _zoom(self, x0, y0, x1, y1): + """Zoom to the rectangle view x0,y0 x1,y1. + """ + startPos = x0, y0 + endPos = x1, y1 + + # Store current zoom state in stack + self.plot.getLimitsHistory().push() + + if self.plot.isKeepDataAspectRatio(): + x0, y0, x1, y1 = self._areaWithAspectRatio(x0, y0, x1, y1) + + # Convert to data space and set limits + x0, y0 = self.plot.pixelToData(x0, y0, check=False) + + dataPos = self.plot.pixelToData( + startPos[0], startPos[1], axis="right", check=False) + y2_0 = dataPos[1] + + x1, y1 = self.plot.pixelToData(x1, y1, check=False) + + dataPos = self.plot.pixelToData( + endPos[0], endPos[1], axis="right", check=False) + y2_1 = dataPos[1] + + xMin, xMax = min(x0, x1), max(x0, x1) + yMin, yMax = min(y0, y1), max(y0, y1) + y2Min, y2Max = min(y2_0, y2_1), max(y2_0, y2_1) + + self.plot.setLimits(xMin, xMax, yMin, yMax, y2Min, y2Max) + + def endDrag(self, startPos, endPos, btn): + x0, y0 = startPos + x1, y1 = endPos + + if abs(x0 - x1) * abs(y0 - y1) >= self.SURFACE_THRESHOLD: + # Avoid empty zoom area + self._zoom(x0, y0, x1, y1) + + self.resetSelectionArea() + + def cancel(self): + if isinstance(self.state, self.states['drag']): + self.resetSelectionArea() + + +# Select ###################################################################### + +class Select(StateMachine, _PlotInteraction): + """Base class for drawing selection areas.""" + + def __init__(self, plot, parameters, states, state): + """Init a state machine. + + :param plot: The plot to apply changes to. + :param dict parameters: A dict of parameters such as color. + :param dict states: The states of the state machine. + :param str state: The name of the initial state. + """ + _PlotInteraction.__init__(self, plot) + self.parameters = parameters + StateMachine.__init__(self, states, state) + + def onWheel(self, x, y, angle): + scaleF = 1.1 if angle > 0 else 1. / 1.1 + applyZoomToPlot(self.plot, scaleF, (x, y)) + + @property + def color(self): + return self.parameters.get('color', None) + + +class SelectPolygon(Select): + """Drawing selection polygon area state machine.""" + + DRAG_THRESHOLD_DIST = 4 + + class Idle(State): + def onPress(self, x, y, btn): + if btn == LEFT_BTN: + self.goto('select', x, y) + return True + + class Select(State): + def enterState(self, x, y): + dataPos = self.machine.plot.pixelToData(x, y) + assert dataPos is not None + self._firstPos = dataPos + self.points = [dataPos, dataPos] + + self.updateFirstPoint() + + def updateFirstPoint(self): + """Update drawing first point, using self._firstPos""" + x, y = self.machine.plot.dataToPixel(*self._firstPos, check=False) + + offset = self.machine.getDragThreshold() + points = [(x - offset, y - offset), + (x - offset, y + offset), + (x + offset, y + offset), + (x + offset, y - offset)] + points = [self.machine.plot.pixelToData(xpix, ypix, check=False) + for xpix, ypix in points] + self.machine.setSelectionArea(points, fill=None, + color=self.machine.color, + name='first_point') + + def updateSelectionArea(self): + """Update drawing selection area using self.points""" + self.machine.setSelectionArea(self.points, + fill='hatch', + color=self.machine.color) + eventDict = prepareDrawingSignal('drawingProgress', + 'polygon', + self.points, + self.machine.parameters) + self.machine.plot.notify(**eventDict) + + def validate(self): + if len(self.points) > 2: + self.closePolygon() + else: + # It would be nice to have a cancel event. + # The plot is not aware that the interaction was cancelled + self.machine.cancel() + + def closePolygon(self): + self.machine.resetSelectionArea() + self.points[-1] = self.points[0] + eventDict = prepareDrawingSignal('drawingFinished', + 'polygon', + self.points, + self.machine.parameters) + self.machine.plot.notify(**eventDict) + self.goto('idle') + + def onWheel(self, x, y, angle): + self.machine.onWheel(x, y, angle) + self.updateFirstPoint() + + def onRelease(self, x, y, btn): + if btn == LEFT_BTN: + # checking if the position is close to the first point + # if yes : closing the "loop" + firstPos = self.machine.plot.dataToPixel(*self._firstPos, + check=False) + dx, dy = abs(firstPos[0] - x), abs(firstPos[1] - y) + + threshold = self.machine.getDragThreshold() + + # Only allow to close polygon after first point + if len(self.points) > 2 and dx <= threshold and dy <= threshold: + self.closePolygon() + return False + + # Update polygon last point not too close to previous one + dataPos = self.machine.plot.pixelToData(x, y) + assert dataPos is not None + self.updateSelectionArea() + + # checking that the new points isnt the same (within range) + # of the previous one + # This has to be done because sometimes the mouse release event + # is caught right after entering the Select state (i.e : press + # in Idle state, but with a slightly different position that + # the mouse press. So we had the two first vertices that were + # almost identical. + previousPos = self.machine.plot.dataToPixel(*self.points[-2], + check=False) + dx, dy = abs(previousPos[0] - x), abs(previousPos[1] - y) + if dx >= threshold or dy >= threshold: + self.points.append(dataPos) + else: + self.points[-1] = dataPos + + return True + return False + + def onMove(self, x, y): + firstPos = self.machine.plot.dataToPixel(*self._firstPos, + check=False) + dx, dy = abs(firstPos[0] - x), abs(firstPos[1] - y) + threshold = self.machine.getDragThreshold() + + if dx <= threshold and dy <= threshold: + x, y = firstPos # Snap to first point + + dataPos = self.machine.plot.pixelToData(x, y) + assert dataPos is not None + self.points[-1] = dataPos + self.updateSelectionArea() + + def __init__(self, plot, parameters): + states = { + 'idle': SelectPolygon.Idle, + 'select': SelectPolygon.Select + } + super(SelectPolygon, self).__init__(plot, parameters, + states, 'idle') + + def cancel(self): + if isinstance(self.state, self.states['select']): + self.resetSelectionArea() + + def getDragThreshold(self): + """Return dragging ratio with device to pixel ratio applied. + + :rtype: float + """ + ratio = self.plot.window().windowHandle().devicePixelRatio() + return self.DRAG_THRESHOLD_DIST * ratio + + +class Select2Points(Select): + """Base class for drawing selection based on 2 input points.""" + class Idle(State): + def onPress(self, x, y, btn): + if btn == LEFT_BTN: + self.goto('start', x, y) + return True + + class Start(State): + def enterState(self, x, y): + self.machine.beginSelect(x, y) + + def onMove(self, x, y): + self.goto('select', x, y) + + def onRelease(self, x, y, btn): + if btn == LEFT_BTN: + self.goto('select', x, y) + return True + + class Select(State): + def enterState(self, x, y): + self.onMove(x, y) + + def onMove(self, x, y): + self.machine.select(x, y) + + def onRelease(self, x, y, btn): + if btn == LEFT_BTN: + self.machine.endSelect(x, y) + self.goto('idle') + + def __init__(self, plot, parameters): + states = { + 'idle': Select2Points.Idle, + 'start': Select2Points.Start, + 'select': Select2Points.Select + } + super(Select2Points, self).__init__(plot, parameters, + states, 'idle') + + def beginSelect(self, x, y): + pass + + def select(self, x, y): + pass + + def endSelect(self, x, y): + pass + + def cancelSelect(self): + pass + + def cancel(self): + if isinstance(self.state, self.states['select']): + self.cancelSelect() + + +class SelectEllipse(Select2Points): + """Drawing ellipse selection area state machine.""" + def beginSelect(self, x, y): + self.center = self.plot.pixelToData(x, y) + assert self.center is not None + + def _getEllipseSize(self, pointInEllipse): + """ + Returns the size from the center to the bounding box of the ellipse. + + :param Tuple[float,float] pointInEllipse: A point of the ellipse + :rtype: Tuple[float,float] + """ + x = abs(self.center[0] - pointInEllipse[0]) + y = abs(self.center[1] - pointInEllipse[1]) + if x == 0 or y == 0: + return x, y + # Ellipse definitions + # e: eccentricity + # a: length fron center to bounding box width + # b: length fron center to bounding box height + # Equations + # (1) b < a + # (2) For x,y a point in the ellipse: x^2/a^2 + y^2/b^2 = 1 + # (3) b = a * sqrt(1-e^2) + # (4) e = sqrt(a^2 - b^2) / a + + # The eccentricity of the ellipse defined by a,b=x,y is the same + # as the one we are searching for. + swap = x < y + if swap: + x, y = y, x + e = math.sqrt(x**2 - y**2) / x + # From (2) using (3) to replace b + # a^2 = x^2 + y^2 / (1-e^2) + a = math.sqrt(x**2 + y**2 / (1.0 - e**2)) + b = a * math.sqrt(1 - e**2) + if swap: + a, b = b, a + return a, b + + def select(self, x, y): + dataPos = self.plot.pixelToData(x, y) + assert dataPos is not None + width, height = self._getEllipseSize(dataPos) + + # Circle used for circle preview + nbpoints = 27. + angles = numpy.arange(nbpoints) * numpy.pi * 2.0 / nbpoints + circleShape = numpy.array((numpy.cos(angles) * width, + numpy.sin(angles) * height)).T + circleShape += numpy.array(self.center) + + self.setSelectionArea(circleShape, + shape="polygon", + fill='hatch', + color=self.color) + + eventDict = prepareDrawingSignal('drawingProgress', + 'ellipse', + (self.center, (width, height)), + self.parameters) + self.plot.notify(**eventDict) + + def endSelect(self, x, y): + self.resetSelectionArea() + + dataPos = self.plot.pixelToData(x, y) + assert dataPos is not None + width, height = self._getEllipseSize(dataPos) + + eventDict = prepareDrawingSignal('drawingFinished', + 'ellipse', + (self.center, (width, height)), + self.parameters) + self.plot.notify(**eventDict) + + def cancelSelect(self): + self.resetSelectionArea() + + +class SelectRectangle(Select2Points): + """Drawing rectangle selection area state machine.""" + def beginSelect(self, x, y): + self.startPt = self.plot.pixelToData(x, y) + assert self.startPt is not None + + def select(self, x, y): + dataPos = self.plot.pixelToData(x, y) + assert dataPos is not None + + self.setSelectionArea((self.startPt, + (self.startPt[0], dataPos[1]), + dataPos, + (dataPos[0], self.startPt[1])), + fill='hatch', + color=self.color) + + eventDict = prepareDrawingSignal('drawingProgress', + 'rectangle', + (self.startPt, dataPos), + self.parameters) + self.plot.notify(**eventDict) + + def endSelect(self, x, y): + self.resetSelectionArea() + + dataPos = self.plot.pixelToData(x, y) + assert dataPos is not None + + eventDict = prepareDrawingSignal('drawingFinished', + 'rectangle', + (self.startPt, dataPos), + self.parameters) + self.plot.notify(**eventDict) + + def cancelSelect(self): + self.resetSelectionArea() + + +class SelectLine(Select2Points): + """Drawing line selection area state machine.""" + def beginSelect(self, x, y): + self.startPt = self.plot.pixelToData(x, y) + assert self.startPt is not None + + def select(self, x, y): + dataPos = self.plot.pixelToData(x, y) + assert dataPos is not None + + self.setSelectionArea((self.startPt, dataPos), + fill='hatch', + color=self.color) + + eventDict = prepareDrawingSignal('drawingProgress', + 'line', + (self.startPt, dataPos), + self.parameters) + self.plot.notify(**eventDict) + + def endSelect(self, x, y): + self.resetSelectionArea() + + dataPos = self.plot.pixelToData(x, y) + assert dataPos is not None + + eventDict = prepareDrawingSignal('drawingFinished', + 'line', + (self.startPt, dataPos), + self.parameters) + self.plot.notify(**eventDict) + + def cancelSelect(self): + self.resetSelectionArea() + + +class Select1Point(Select): + """Base class for drawing selection area based on one input point.""" + class Idle(State): + def onPress(self, x, y, btn): + if btn == LEFT_BTN: + self.goto('select', x, y) + return True + + class Select(State): + def enterState(self, x, y): + self.onMove(x, y) + + def onMove(self, x, y): + self.machine.select(x, y) + + def onRelease(self, x, y, btn): + if btn == LEFT_BTN: + self.machine.endSelect(x, y) + self.goto('idle') + + def onWheel(self, x, y, angle): + self.machine.onWheel(x, y, angle) # Call select default wheel + self.machine.select(x, y) + + def __init__(self, plot, parameters): + states = { + 'idle': Select1Point.Idle, + 'select': Select1Point.Select + } + super(Select1Point, self).__init__(plot, parameters, states, 'idle') + + def select(self, x, y): + pass + + def endSelect(self, x, y): + pass + + def cancelSelect(self): + pass + + def cancel(self): + if isinstance(self.state, self.states['select']): + self.cancelSelect() + + +class SelectHLine(Select1Point): + """Drawing a horizontal line selection area state machine.""" + def _hLine(self, y): + """Return points in data coords of the segment visible in the plot. + + Supports non-orthogonal axes. + """ + left, _top, width, _height = self.plot.getPlotBoundsInPixels() + + dataPos1 = self.plot.pixelToData(left, y, check=False) + dataPos2 = self.plot.pixelToData(left + width, y, check=False) + return dataPos1, dataPos2 + + def select(self, x, y): + points = self._hLine(y) + self.setSelectionArea(points, fill='hatch', color=self.color) + + eventDict = prepareDrawingSignal('drawingProgress', + 'hline', + points, + self.parameters) + self.plot.notify(**eventDict) + + def endSelect(self, x, y): + self.resetSelectionArea() + + eventDict = prepareDrawingSignal('drawingFinished', + 'hline', + self._hLine(y), + self.parameters) + self.plot.notify(**eventDict) + + def cancelSelect(self): + self.resetSelectionArea() + + +class SelectVLine(Select1Point): + """Drawing a vertical line selection area state machine.""" + def _vLine(self, x): + """Return points in data coords of the segment visible in the plot. + + Supports non-orthogonal axes. + """ + _left, top, _width, height = self.plot.getPlotBoundsInPixels() + + dataPos1 = self.plot.pixelToData(x, top, check=False) + dataPos2 = self.plot.pixelToData(x, top + height, check=False) + return dataPos1, dataPos2 + + def select(self, x, y): + points = self._vLine(x) + self.setSelectionArea(points, fill='hatch', color=self.color) + + eventDict = prepareDrawingSignal('drawingProgress', + 'vline', + points, + self.parameters) + self.plot.notify(**eventDict) + + def endSelect(self, x, y): + self.resetSelectionArea() + + eventDict = prepareDrawingSignal('drawingFinished', + 'vline', + self._vLine(x), + self.parameters) + self.plot.notify(**eventDict) + + def cancelSelect(self): + self.resetSelectionArea() + + +class DrawFreeHand(Select): + """Interaction for drawing pencil. It display the preview of the pencil + before pressing the mouse. + """ + + class Idle(State): + def onPress(self, x, y, btn): + if btn == LEFT_BTN: + self.goto('select', x, y) + return True + + def onMove(self, x, y): + self.machine.updatePencilShape(x, y) + + def onLeave(self): + self.machine.cancel() + + class Select(State): + def enterState(self, x, y): + self.__isOut = False + self.machine.setFirstPoint(x, y) + + def onMove(self, x, y): + self.machine.updatePencilShape(x, y) + self.machine.select(x, y) + + def onRelease(self, x, y, btn): + if btn == LEFT_BTN: + if self.__isOut: + self.machine.resetSelectionArea() + self.machine.endSelect(x, y) + self.goto('idle') + + def onEnter(self): + self.__isOut = False + + def onLeave(self): + self.__isOut = True + + def __init__(self, plot, parameters): + # Circle used for pencil preview + angle = numpy.arange(13.) * numpy.pi * 2.0 / 13. + size = parameters.get('width', 1.) * 0.5 + self._circle = size * numpy.array((numpy.cos(angle), + numpy.sin(angle))).T + + states = { + 'idle': DrawFreeHand.Idle, + 'select': DrawFreeHand.Select + } + super(DrawFreeHand, self).__init__(plot, parameters, states, 'idle') + + @property + def width(self): + return self.parameters.get('width', None) + + def setFirstPoint(self, x, y): + self._points = [] + self.select(x, y) + + def updatePencilShape(self, x, y): + center = self.plot.pixelToData(x, y, check=False) + assert center is not None + + polygon = center + self._circle + + self.setSelectionArea(polygon, fill='none', color=self.color) + + def select(self, x, y): + pos = self.plot.pixelToData(x, y, check=False) + if len(self._points) > 0: + if self._points[-1] == pos: + # Skip same points + return + self._points.append(pos) + eventDict = prepareDrawingSignal('drawingProgress', + 'polylines', + self._points, + self.parameters) + self.plot.notify(**eventDict) + + def endSelect(self, x, y): + pos = self.plot.pixelToData(x, y, check=False) + if len(self._points) > 0: + if self._points[-1] != pos: + # Append if different + self._points.append(pos) + + eventDict = prepareDrawingSignal('drawingFinished', + 'polylines', + self._points, + self.parameters) + self.plot.notify(**eventDict) + self._points = None + + def cancelSelect(self): + self.resetSelectionArea() + + def cancel(self): + self.resetSelectionArea() + + +class SelectFreeLine(ClickOrDrag, _PlotInteraction): + """Base class for drawing free lines with tools such as pencil.""" + + def __init__(self, plot, parameters): + """Init a state machine. + + :param plot: The plot to apply changes to. + :param dict parameters: A dict of parameters such as color. + """ + # self.DRAG_THRESHOLD_SQUARE_DIST = 1 # Disable first move threshold + self._points = [] + ClickOrDrag.__init__(self) + _PlotInteraction.__init__(self, plot) + self.parameters = parameters + + def onWheel(self, x, y, angle): + scaleF = 1.1 if angle > 0 else 1. / 1.1 + applyZoomToPlot(self.plot, scaleF, (x, y)) + + @property + def color(self): + return self.parameters.get('color', None) + + def click(self, x, y, btn): + if btn == LEFT_BTN: + self._processEvent(x, y, isLast=True) + + def beginDrag(self, x, y, btn): + self._processEvent(x, y, isLast=False) + + def drag(self, x, y, btn): + self._processEvent(x, y, isLast=False) + + def endDrag(self, startPos, endPos, btn): + x, y = endPos + self._processEvent(x, y, isLast=True) + + def cancel(self): + self.resetSelectionArea() + self._points = [] + + def _processEvent(self, x, y, isLast): + dataPos = self.plot.pixelToData(x, y, check=False) + isNewPoint = not self._points or dataPos != self._points[-1] + + if isNewPoint: + self._points.append(dataPos) + + if isNewPoint or isLast: + eventDict = prepareDrawingSignal( + 'drawingFinished' if isLast else 'drawingProgress', + 'polylines', + self._points, + self.parameters) + self.plot.notify(**eventDict) + + if not isLast: + self.setSelectionArea(self._points, fill='none', color=self.color, + shape='polylines') + else: + self.cancel() + + +# ItemInteraction ############################################################# + +class ItemsInteraction(ClickOrDrag, _PlotInteraction): + """Interaction with items (markers, curves and images). + + This class provides selection and dragging of plot primitives + that support those interaction. + It is also meant to be combined with the zoom interaction. + """ + + class Idle(ClickOrDrag.Idle): + def __init__(self, *args, **kw): + super(ItemsInteraction.Idle, self).__init__(*args, **kw) + self._hoverMarker = None + + def onWheel(self, x, y, angle): + scaleF = 1.1 if angle > 0 else 1. / 1.1 + applyZoomToPlot(self.machine.plot, scaleF, (x, y)) + + def onMove(self, x, y): + marker = self.machine.plot._getMarkerAt(x, y) + + if marker is not None: + dataPos = self.machine.plot.pixelToData(x, y) + assert dataPos is not None + eventDict = prepareHoverSignal( + marker.getName(), 'marker', + dataPos, (x, y), + marker.isDraggable(), + marker.isSelectable()) + self.machine.plot.notify(**eventDict) + + if marker != self._hoverMarker: + self._hoverMarker = marker + + if marker is None: + self.machine.plot.setGraphCursorShape() + + elif marker.isDraggable(): + if isinstance(marker, items.YMarker): + self.machine.plot.setGraphCursorShape(CURSOR_SIZE_VER) + elif isinstance(marker, items.XMarker): + self.machine.plot.setGraphCursorShape(CURSOR_SIZE_HOR) + else: + self.machine.plot.setGraphCursorShape(CURSOR_SIZE_ALL) + + elif marker.isSelectable(): + self.machine.plot.setGraphCursorShape(CURSOR_POINTING) + else: + self.machine.plot.setGraphCursorShape() + + return True + + def __init__(self, plot): + self._pan = Pan(plot) + + _PlotInteraction.__init__(self, plot) + ClickOrDrag.__init__(self, + clickButtons=(LEFT_BTN, RIGHT_BTN), + dragButtons=(LEFT_BTN, MIDDLE_BTN)) + + def click(self, x, y, btn): + """Handle mouse click + + :param x: X position of the mouse in pixels + :param y: Y position of the mouse in pixels + :param btn: Pressed button id + :return: True if click is catched by an item, False otherwise + """ + # Signal mouse clicked event + dataPos = self.plot.pixelToData(x, y) + assert dataPos is not None + eventDict = prepareMouseSignal('mouseClicked', btn, + dataPos[0], dataPos[1], + x, y) + self.plot.notify(**eventDict) + + eventDict = self._handleClick(x, y, btn) + if eventDict is not None: + self.plot.notify(**eventDict) + + def _handleClick(self, x, y, btn): + """Perform picking and prepare event if click is handled here + + :param x: X position of the mouse in pixels + :param y: Y position of the mouse in pixels + :param btn: Pressed button id + :return: event description to send of None if not handling event. + :rtype: dict or None + """ + + if btn == LEFT_BTN: + result = self.plot._pickTopMost(x, y, lambda i: i.isSelectable()) + if result is None: + return None + + item = result.getItem() + + if isinstance(item, items.MarkerBase): + xData, yData = item.getPosition() + if xData is None: + xData = [0, 1] + if yData is None: + yData = [0, 1] + + eventDict = prepareMarkerSignal('markerClicked', + 'left', + item.getName(), + 'marker', + item.isDraggable(), + item.isSelectable(), + (xData, yData), + (x, y), None) + return eventDict + + elif isinstance(item, items.Curve): + dataPos = self.plot.pixelToData(x, y) + assert dataPos is not None + + xData = item.getXData(copy=False) + yData = item.getYData(copy=False) + + indices = result.getIndices(copy=False) + eventDict = prepareCurveSignal('left', + item.getName(), + 'curve', + xData[indices], + yData[indices], + dataPos[0], dataPos[1], + x, y) + return eventDict + + elif isinstance(item, items.ImageBase): + dataPos = self.plot.pixelToData(x, y) + assert dataPos is not None + + indices = result.getIndices(copy=False) + row, column = indices[0][0], indices[1][0] + eventDict = prepareImageSignal('left', + item.getName(), + 'image', + column, row, + dataPos[0], dataPos[1], + x, y) + return eventDict + + return None + + def _signalMarkerMovingEvent(self, eventType, marker, x, y): + assert marker is not None + + xData, yData = marker.getPosition() + if xData is None: + xData = [0, 1] + if yData is None: + yData = [0, 1] + + posDataCursor = self.plot.pixelToData(x, y) + assert posDataCursor is not None + + eventDict = prepareMarkerSignal(eventType, + 'left', + marker.getName(), + 'marker', + marker.isDraggable(), + marker.isSelectable(), + (xData, yData), + (x, y), + posDataCursor) + self.plot.notify(**eventDict) + + @staticmethod + def __isDraggableItem(item): + return isinstance(item, items.DraggableMixIn) and item.isDraggable() + + def __terminateDrag(self): + """Finalize a drag operation by reseting to initial state""" + self.plot.setGraphCursorShape() + self.draggedItemRef = None + + def beginDrag(self, x, y, btn): + """Handle begining of drag interaction + + :param x: X position of the mouse in pixels + :param y: Y position of the mouse in pixels + :param str btn: The mouse button for which a drag is starting. + :return: True if drag is catched by an item, False otherwise + """ + if btn == LEFT_BTN: + self._lastPos = self.plot.pixelToData(x, y) + assert self._lastPos is not None + + result = self.plot._pickTopMost(x, y, self.__isDraggableItem) + item = result.getItem() if result is not None else None + + self.draggedItemRef = None if item is None else weakref.ref(item) + + if item is None: + self.__terminateDrag() + return False + + if isinstance(item, items.MarkerBase): + self._signalMarkerMovingEvent('markerMoving', item, x, y) + item._startDrag() + + return True + elif btn == MIDDLE_BTN: + self._pan.beginDrag(x, y, btn) + return True + + def drag(self, x, y, btn): + if btn == LEFT_BTN: + dataPos = self.plot.pixelToData(x, y) + assert dataPos is not None + + item = None if self.draggedItemRef is None else self.draggedItemRef() + if item is not None: + item.drag(self._lastPos, dataPos) + + if isinstance(item, items.MarkerBase): + self._signalMarkerMovingEvent('markerMoving', item, x, y) + + self._lastPos = dataPos + elif btn == MIDDLE_BTN: + self._pan.drag(x, y, btn) + + def endDrag(self, startPos, endPos, btn): + if btn == LEFT_BTN: + item = None if self.draggedItemRef is None else self.draggedItemRef() + if isinstance(item, items.MarkerBase): + posData = list(item.getPosition()) + if posData[0] is None: + posData[0] = 1. + if posData[1] is None: + posData[1] = 1. + + eventDict = prepareMarkerSignal( + 'markerMoved', + 'left', + item.getLegend(), + 'marker', + item.isDraggable(), + item.isSelectable(), + posData) + self.plot.notify(**eventDict) + item._endDrag() + + self.__terminateDrag() + elif btn == MIDDLE_BTN: + self._pan.endDrag(startPos, endPos, btn) + + def cancel(self): + self._pan.cancel() + self.__terminateDrag() + + +class ItemsInteractionForCombo(ItemsInteraction): + """Interaction with items to combine through :class:`FocusManager`. + """ + + class Idle(ItemsInteraction.Idle): + @staticmethod + def __isItemSelectableOrDraggable(item): + return (item.isSelectable() or ( + isinstance(item, items.DraggableMixIn) and item.isDraggable())) + + def onPress(self, x, y, btn): + if btn == LEFT_BTN: + result = self.machine.plot._pickTopMost( + x, y, self.__isItemSelectableOrDraggable) + if result is not None: # Request focus and handle interaction + self.goto('clickOrDrag', x, y, btn) + return True + else: # Do not request focus + return False + else: + return super().onPress(x, y, btn) + + +# FocusManager ################################################################ + +class FocusManager(StateMachine): + """Manages focus across multiple event handlers + + On press an event handler can acquire focus. + By default it looses focus when all buttons are released. + """ + class Idle(State): + def onPress(self, x, y, btn): + if btn == LEFT_BTN: + for eventHandler in self.machine.eventHandlers: + requestFocus = eventHandler.handleEvent('press', x, y, btn) + if requestFocus: + self.goto('focus', eventHandler, btn) + break + + def _processEvent(self, *args): + for eventHandler in self.machine.eventHandlers: + consumeEvent = eventHandler.handleEvent(*args) + if consumeEvent: + break + + def onMove(self, x, y): + self._processEvent('move', x, y) + + def onRelease(self, x, y, btn): + if btn == LEFT_BTN: + self._processEvent('release', x, y, btn) + + def onWheel(self, x, y, angle): + self._processEvent('wheel', x, y, angle) + + class Focus(State): + def enterState(self, eventHandler, btn): + self.eventHandler = eventHandler + self.focusBtns = {btn} + + def validate(self): + self.eventHandler.validate() + self.goto('idle') + + def onPress(self, x, y, btn): + if btn == LEFT_BTN: + self.focusBtns.add(btn) + self.eventHandler.handleEvent('press', x, y, btn) + + def onMove(self, x, y): + self.eventHandler.handleEvent('move', x, y) + + def onRelease(self, x, y, btn): + if btn == LEFT_BTN: + self.focusBtns.discard(btn) + requestFocus = self.eventHandler.handleEvent('release', x, y, btn) + if len(self.focusBtns) == 0 and not requestFocus: + self.goto('idle') + + def onWheel(self, x, y, angleInDegrees): + self.eventHandler.handleEvent('wheel', x, y, angleInDegrees) + + def __init__(self, eventHandlers=()): + self.eventHandlers = list(eventHandlers) + + states = { + 'idle': FocusManager.Idle, + 'focus': FocusManager.Focus + } + super(FocusManager, self).__init__(states, 'idle') + + def cancel(self): + for handler in self.eventHandlers: + handler.cancel() + + +class ZoomAndSelect(ItemsInteraction): + """Combine Zoom and ItemInteraction state machine. + + :param plot: The Plot to which this interaction is attached + :param color: The color to use for the zoom area bounding box + """ + + def __init__(self, plot, color): + super(ZoomAndSelect, self).__init__(plot) + self._zoom = Zoom(plot, color) + self._doZoom = False + + @property + def color(self): + """Color of the zoom area""" + return self._zoom.color + + def click(self, x, y, btn): + """Handle mouse click + + :param x: X position of the mouse in pixels + :param y: Y position of the mouse in pixels + :param btn: Pressed button id + :return: True if click is catched by an item, False otherwise + """ + eventDict = self._handleClick(x, y, btn) + + if eventDict is not None: + # Signal mouse clicked event + dataPos = self.plot.pixelToData(x, y) + assert dataPos is not None + clickedEventDict = prepareMouseSignal('mouseClicked', btn, + dataPos[0], dataPos[1], + x, y) + self.plot.notify(**clickedEventDict) + + self.plot.notify(**eventDict) + + else: + self._zoom.click(x, y, btn) + + def beginDrag(self, x, y, btn): + """Handle start drag and switching between zoom and item drag. + + :param x: X position in pixels + :param y: Y position in pixels + :param str btn: The mouse button for which a drag is starting. + """ + self._doZoom = not super(ZoomAndSelect, self).beginDrag(x, y, btn) + if self._doZoom: + self._zoom.beginDrag(x, y, btn) + + def drag(self, x, y, btn): + """Handle drag, eventually forwarding to zoom. + + :param x: X position in pixels + :param y: Y position in pixels + :param str btn: The mouse button for which a drag is in progress. + """ + if self._doZoom: + return self._zoom.drag(x, y, btn) + else: + return super(ZoomAndSelect, self).drag(x, y, btn) + + def endDrag(self, startPos, endPos, btn): + """Handle end of drag, eventually forwarding to zoom. + + :param startPos: (x, y) position at the beginning of the drag + :param endPos: (x, y) position at the end of the drag + :param str btn: The mouse button for which a drag is done. + """ + if self._doZoom: + return self._zoom.endDrag(startPos, endPos, btn) + else: + return super(ZoomAndSelect, self).endDrag(startPos, endPos, btn) + + +class PanAndSelect(ItemsInteraction): + """Combine Pan and ItemInteraction state machine. + + :param plot: The Plot to which this interaction is attached + """ + + def __init__(self, plot): + super(PanAndSelect, self).__init__(plot) + self._pan = Pan(plot) + self._doPan = False + + def click(self, x, y, btn): + """Handle mouse click + + :param x: X position of the mouse in pixels + :param y: Y position of the mouse in pixels + :param btn: Pressed button id + :return: True if click is catched by an item, False otherwise + """ + eventDict = self._handleClick(x, y, btn) + + if eventDict is not None: + # Signal mouse clicked event + dataPos = self.plot.pixelToData(x, y) + assert dataPos is not None + clickedEventDict = prepareMouseSignal('mouseClicked', btn, + dataPos[0], dataPos[1], + x, y) + self.plot.notify(**clickedEventDict) + + self.plot.notify(**eventDict) + + else: + self._pan.click(x, y, btn) + + def beginDrag(self, x, y, btn): + """Handle start drag and switching between zoom and item drag. + + :param x: X position in pixels + :param y: Y position in pixels + :param str btn: The mouse button for which a drag is starting. + """ + self._doPan = not super(PanAndSelect, self).beginDrag(x, y, btn) + if self._doPan: + self._pan.beginDrag(x, y, btn) + + def drag(self, x, y, btn): + """Handle drag, eventually forwarding to zoom. + + :param x: X position in pixels + :param y: Y position in pixels + :param str btn: The mouse button for which a drag is in progress. + """ + if self._doPan: + return self._pan.drag(x, y, btn) + else: + return super(PanAndSelect, self).drag(x, y, btn) + + def endDrag(self, startPos, endPos, btn): + """Handle end of drag, eventually forwarding to zoom. + + :param startPos: (x, y) position at the beginning of the drag + :param endPos: (x, y) position at the end of the drag + :param str btn: The mouse button for which a drag is done. + """ + if self._doPan: + return self._pan.endDrag(startPos, endPos, btn) + else: + return super(PanAndSelect, self).endDrag(startPos, endPos, btn) + + +# Interaction mode control #################################################### + +# Mapping of draw modes: event handler +_DRAW_MODES = { + 'polygon': SelectPolygon, + 'rectangle': SelectRectangle, + 'ellipse': SelectEllipse, + 'line': SelectLine, + 'vline': SelectVLine, + 'hline': SelectHLine, + 'polylines': SelectFreeLine, + 'pencil': DrawFreeHand, + } + + +class DrawMode(FocusManager): + """Interactive mode for draw and select""" + + def __init__(self, plot, shape, label, color, width): + eventHandlerClass = _DRAW_MODES[shape] + parameters = { + 'shape': shape, + 'label': label, + 'color': color, + 'width': width, + } + super().__init__(( + Pan(plot, clickButtons=(), dragButtons=(MIDDLE_BTN,)), + eventHandlerClass(plot, parameters))) + + def getDescription(self): + """Returns the dict describing this interactive mode""" + params = self.eventHandlers[1].parameters.copy() + params['mode'] = 'draw' + return params + + +class DrawSelectMode(FocusManager): + """Interactive mode for draw and select""" + + def __init__(self, plot, shape, label, color, width): + eventHandlerClass = _DRAW_MODES[shape] + self._pan = Pan(plot) + self._panStart = None + parameters = { + 'shape': shape, + 'label': label, + 'color': color, + 'width': width, + } + super().__init__(( + ItemsInteractionForCombo(plot), + eventHandlerClass(plot, parameters))) + + def handleEvent(self, eventName, *args, **kwargs): + # Hack to add pan interaction to select-draw + # See issue Refactor PlotWidget interaction #3292 + if eventName == 'press' and args[2] == MIDDLE_BTN: + self._panStart = args[:2] + self._pan.beginDrag(*args) + return # Consume middle click events + elif eventName == 'release' and args[2] == MIDDLE_BTN: + self._panStart = None + self._pan.endDrag(self._panStart, args[:2], MIDDLE_BTN) + return # Consume middle click events + elif self._panStart is not None and eventName == 'move': + x, y = args[:2] + self._pan.drag(x, y, MIDDLE_BTN) + + super().handleEvent(eventName, *args, **kwargs) + + def getDescription(self): + """Returns the dict describing this interactive mode""" + params = self.eventHandlers[1].parameters.copy() + params['mode'] = 'select-draw' + return params + + +class PlotInteraction(object): + """Proxy to currently use state machine for interaction. + + This allows to switch interactive mode. + + :param plot: The :class:`Plot` to apply interaction to + """ + + _DRAW_MODES = { + 'polygon': SelectPolygon, + 'rectangle': SelectRectangle, + 'ellipse': SelectEllipse, + 'line': SelectLine, + 'vline': SelectVLine, + 'hline': SelectHLine, + 'polylines': SelectFreeLine, + 'pencil': DrawFreeHand, + } + + def __init__(self, plot): + self._plot = weakref.ref(plot) # Avoid cyclic-ref + + self.zoomOnWheel = True + """True to enable zoom on wheel, False otherwise.""" + + # Default event handler + self._eventHandler = ItemsInteraction(plot) + + def getInteractiveMode(self): + """Returns the current interactive mode as a dict. + + The returned dict contains at least the key 'mode'. + Mode can be: 'draw', 'pan', 'select', 'select-draw', 'zoom'. + It can also contains extra keys (e.g., 'color') specific to a mode + as provided to :meth:`setInteractiveMode`. + """ + if isinstance(self._eventHandler, ZoomAndSelect): + return {'mode': 'zoom', 'color': self._eventHandler.color} + + elif isinstance(self._eventHandler, (DrawMode, DrawSelectMode)): + return self._eventHandler.getDescription() + + elif isinstance(self._eventHandler, PanAndSelect): + return {'mode': 'pan'} + + else: + return {'mode': 'select'} + + def validate(self): + """Validate the current interaction if possible + + If was designed to close the polygon interaction. + """ + self._eventHandler.validate() + + def setInteractiveMode(self, mode, color='black', + shape='polygon', label=None, width=None): + """Switch the interactive mode. + + :param str mode: The name of the interactive mode. + In 'draw', 'pan', 'select', 'select-draw', 'zoom'. + :param color: Only for 'draw' and 'zoom' modes. + Color to use for drawing selection area. Default black. + If None, selection area is not drawn. + :type color: Color description: The name as a str or + a tuple of 4 floats or None. + :param str shape: Only for 'draw' mode. The kind of shape to draw. + In 'polygon', 'rectangle', 'line', 'vline', 'hline', + 'polylines'. + Default is 'polygon'. + :param str label: Only for 'draw' mode. + :param float width: Width of the pencil. Only for draw pencil mode. + """ + assert mode in ('draw', 'pan', 'select', 'select-draw', 'zoom') + + plot = self._plot() + assert plot is not None + + if isinstance(color, numpy.ndarray) or color not in (None, 'video inverted'): + color = colors.rgba(color) + + if mode in ('draw', 'select-draw'): + self._eventHandler.cancel() + handlerClass = DrawMode if mode == 'draw' else DrawSelectMode + self._eventHandler = handlerClass(plot, shape, label, color, width) + + elif mode == 'pan': + # Ignores color, shape and label + self._eventHandler.cancel() + self._eventHandler = PanAndSelect(plot) + + elif mode == 'zoom': + # Ignores shape and label + self._eventHandler.cancel() + self._eventHandler = ZoomAndSelect(plot, color) + + else: # Default mode: interaction with plot objects + # Ignores color, shape and label + self._eventHandler.cancel() + self._eventHandler = ItemsInteraction(plot) + + def handleEvent(self, event, *args, **kwargs): + """Forward event to current interactive mode state machine.""" + if not self.zoomOnWheel and event == 'wheel': + return # Discard wheel events + self._eventHandler.handleEvent(event, *args, **kwargs) diff --git a/src/silx/gui/plot/PlotToolButtons.py b/src/silx/gui/plot/PlotToolButtons.py new file mode 100644 index 0000000..3970896 --- /dev/null +++ b/src/silx/gui/plot/PlotToolButtons.py @@ -0,0 +1,592 @@ +# coding: utf-8 +# /*########################################################################## +# +# Copyright (c) 2004-2020 European Synchrotron Radiation Facility +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. +# +# ###########################################################################*/ +"""This module provides a set of QToolButton to use with +:class:`~silx.gui.plot.PlotWidget`. + +The following QToolButton are available: + +- :class:`.AspectToolButton` +- :class:`.YAxisOriginToolButton` +- :class:`.ProfileToolButton` +- :class:`.SymbolToolButton` + +""" + +__authors__ = ["V. Valls", "H. Payno"] +__license__ = "MIT" +__date__ = "27/06/2017" + + +import functools +import logging +import weakref + +from .. import icons +from .. import qt +from ... import config + +from .items import SymbolMixIn, Scatter + + +_logger = logging.getLogger(__name__) + + +class PlotToolButton(qt.QToolButton): + """A QToolButton connected to a :class:`~silx.gui.plot.PlotWidget`. + """ + + def __init__(self, parent=None, plot=None): + super(PlotToolButton, self).__init__(parent) + self._plotRef = None + if plot is not None: + self.setPlot(plot) + + def plot(self): + """ + Returns the plot connected to the widget. + """ + return None if self._plotRef is None else self._plotRef() + + def setPlot(self, plot): + """ + Set the plot connected to the widget + + :param plot: :class:`.PlotWidget` instance on which to operate. + """ + previousPlot = self.plot() + + if previousPlot is plot: + return + if previousPlot is not None: + self._disconnectPlot(previousPlot) + + if plot is None: + self._plotRef = None + else: + self._plotRef = weakref.ref(plot) + self._connectPlot(plot) + + def _connectPlot(self, plot): + """ + Called when the plot is connected to the widget + + :param plot: :class:`.PlotWidget` instance + """ + pass + + def _disconnectPlot(self, plot): + """ + Called when the plot is disconnected from the widget + + :param plot: :class:`.PlotWidget` instance + """ + pass + + +class AspectToolButton(PlotToolButton): + """Tool button to switch keep aspect ratio of a plot""" + + STATE = None + """Lazy loaded states used to feed AspectToolButton""" + + def __init__(self, parent=None, plot=None): + if self.STATE is None: + self.STATE = {} + # dont keep ratio + self.STATE[False, "icon"] = icons.getQIcon('shape-ellipse-solid') + self.STATE[False, "state"] = "Aspect ratio is not kept" + self.STATE[False, "action"] = "Do no keep data aspect ratio" + # keep ratio + self.STATE[True, "icon"] = icons.getQIcon('shape-circle-solid') + self.STATE[True, "state"] = "Aspect ratio is kept" + self.STATE[True, "action"] = "Keep data aspect ratio" + + super(AspectToolButton, self).__init__(parent=parent, plot=plot) + + keepAction = self._createAction(True) + keepAction.triggered.connect(self.keepDataAspectRatio) + keepAction.setIconVisibleInMenu(True) + + dontKeepAction = self._createAction(False) + dontKeepAction.triggered.connect(self.dontKeepDataAspectRatio) + dontKeepAction.setIconVisibleInMenu(True) + + menu = qt.QMenu(self) + menu.addAction(keepAction) + menu.addAction(dontKeepAction) + self.setMenu(menu) + self.setPopupMode(qt.QToolButton.InstantPopup) + + def _createAction(self, keepAspectRatio): + icon = self.STATE[keepAspectRatio, "icon"] + text = self.STATE[keepAspectRatio, "action"] + return qt.QAction(icon, text, self) + + def _connectPlot(self, plot): + plot.sigSetKeepDataAspectRatio.connect(self._keepDataAspectRatioChanged) + self._keepDataAspectRatioChanged(plot.isKeepDataAspectRatio()) + + def _disconnectPlot(self, plot): + plot.sigSetKeepDataAspectRatio.disconnect(self._keepDataAspectRatioChanged) + + def keepDataAspectRatio(self): + """Configure the plot to keep the aspect ratio""" + plot = self.plot() + if plot is not None: + # This will trigger _keepDataAspectRatioChanged + plot.setKeepDataAspectRatio(True) + + def dontKeepDataAspectRatio(self): + """Configure the plot to not keep the aspect ratio""" + plot = self.plot() + if plot is not None: + # This will trigger _keepDataAspectRatioChanged + plot.setKeepDataAspectRatio(False) + + def _keepDataAspectRatioChanged(self, aspectRatio): + """Handle Plot set keep aspect ratio signal""" + icon, toolTip = self.STATE[aspectRatio, "icon"], self.STATE[aspectRatio, "state"] + self.setIcon(icon) + self.setToolTip(toolTip) + + +class YAxisOriginToolButton(PlotToolButton): + """Tool button to switch the Y axis orientation of a plot.""" + + STATE = None + """Lazy loaded states used to feed YAxisOriginToolButton""" + + def __init__(self, parent=None, plot=None): + if self.STATE is None: + self.STATE = {} + # is down + self.STATE[False, "icon"] = icons.getQIcon('plot-ydown') + self.STATE[False, "state"] = "Y-axis is oriented downward" + self.STATE[False, "action"] = "Orient Y-axis downward" + # keep ration + self.STATE[True, "icon"] = icons.getQIcon('plot-yup') + self.STATE[True, "state"] = "Y-axis is oriented upward" + self.STATE[True, "action"] = "Orient Y-axis upward" + + super(YAxisOriginToolButton, self).__init__(parent=parent, plot=plot) + + upwardAction = self._createAction(True) + upwardAction.triggered.connect(self.setYAxisUpward) + upwardAction.setIconVisibleInMenu(True) + + downwardAction = self._createAction(False) + downwardAction.triggered.connect(self.setYAxisDownward) + downwardAction.setIconVisibleInMenu(True) + + menu = qt.QMenu(self) + menu.addAction(upwardAction) + menu.addAction(downwardAction) + self.setMenu(menu) + self.setPopupMode(qt.QToolButton.InstantPopup) + + def _createAction(self, isUpward): + icon = self.STATE[isUpward, "icon"] + text = self.STATE[isUpward, "action"] + return qt.QAction(icon, text, self) + + def _connectPlot(self, plot): + yAxis = plot.getYAxis() + yAxis.sigInvertedChanged.connect(self._yAxisInvertedChanged) + self._yAxisInvertedChanged(yAxis.isInverted()) + + def _disconnectPlot(self, plot): + plot.getYAxis().sigInvertedChanged.disconnect(self._yAxisInvertedChanged) + + def setYAxisUpward(self): + """Configure the plot to use y-axis upward""" + plot = self.plot() + if plot is not None: + # This will trigger _yAxisInvertedChanged + plot.getYAxis().setInverted(False) + + def setYAxisDownward(self): + """Configure the plot to use y-axis downward""" + plot = self.plot() + if plot is not None: + # This will trigger _yAxisInvertedChanged + plot.getYAxis().setInverted(True) + + def _yAxisInvertedChanged(self, inverted): + """Handle Plot set y axis inverted signal""" + isUpward = not inverted + icon, toolTip = self.STATE[isUpward, "icon"], self.STATE[isUpward, "state"] + self.setIcon(icon) + self.setToolTip(toolTip) + + +class ProfileOptionToolButton(PlotToolButton): + """Button to define option on the profile""" + sigMethodChanged = qt.Signal(str) + + def __init__(self, parent=None, plot=None): + PlotToolButton.__init__(self, parent=parent, plot=plot) + + self.STATE = {} + # is down + self.STATE['sum', "icon"] = icons.getQIcon('math-sigma') + self.STATE['sum', "state"] = "Compute profile sum" + self.STATE['sum', "action"] = "Compute profile sum" + # keep ration + self.STATE['mean', "icon"] = icons.getQIcon('math-mean') + self.STATE['mean', "state"] = "Compute profile mean" + self.STATE['mean', "action"] = "Compute profile mean" + + self.sumAction = self._createAction('sum') + self.sumAction.triggered.connect(self.setSum) + self.sumAction.setIconVisibleInMenu(True) + self.sumAction.setCheckable(True) + self.sumAction.setChecked(True) + + self.meanAction = self._createAction('mean') + self.meanAction.triggered.connect(self.setMean) + self.meanAction.setIconVisibleInMenu(True) + self.meanAction.setCheckable(True) + + menu = qt.QMenu(self) + menu.addAction(self.sumAction) + menu.addAction(self.meanAction) + self.setMenu(menu) + self.setPopupMode(qt.QToolButton.InstantPopup) + self._method = 'mean' + self._update() + + def _createAction(self, method): + icon = self.STATE[method, "icon"] + text = self.STATE[method, "action"] + return qt.QAction(icon, text, self) + + def setSum(self): + self.setMethod('sum') + + def _update(self): + icon = self.STATE[self._method, "icon"] + toolTip = self.STATE[self._method, "state"] + self.setIcon(icon) + self.setToolTip(toolTip) + self.sumAction.setChecked(self._method == "sum") + self.meanAction.setChecked(self._method == "mean") + + def setMean(self): + self.setMethod('mean') + + def setMethod(self, method): + """Set the method to use. + + :param str method: Either 'sum' or 'mean' + """ + if method != self._method: + if method in ('sum', 'mean'): + self._method = method + self.sigMethodChanged.emit(self._method) + self._update() + else: + _logger.warning( + "Unsupported method '%s'. Setting ignored.", method) + + def getMethod(self): + """Returns the current method in use (See :meth:`setMethod`). + + :rtype: str + """ + return self._method + + +class ProfileToolButton(PlotToolButton): + """Button used in Profile3DToolbar to switch between 2D profile + and 1D profile.""" + STATE = None + """Lazy loaded states used to feed ProfileToolButton""" + + sigDimensionChanged = qt.Signal(int) + + def __init__(self, parent=None, plot=None): + if self.STATE is None: + self.STATE = { + (1, "icon"): icons.getQIcon('profile1D'), + (1, "state"): "1D profile is computed on visible image", + (1, "action"): "1D profile on visible image", + (2, "icon"): icons.getQIcon('profile2D'), + (2, "state"): "2D profile is computed, one 1D profile for each image in the stack", + (2, "action"): "2D profile on image stack"} + # Compute 1D profile + # Compute 2D profile + + super(ProfileToolButton, self).__init__(parent=parent, plot=plot) + + self._dimension = 1 + + profile1DAction = self._createAction(1) + profile1DAction.triggered.connect(self.computeProfileIn1D) + profile1DAction.setIconVisibleInMenu(True) + profile1DAction.setCheckable(True) + profile1DAction.setChecked(True) + self._profile1DAction = profile1DAction + + profile2DAction = self._createAction(2) + profile2DAction.triggered.connect(self.computeProfileIn2D) + profile2DAction.setIconVisibleInMenu(True) + profile2DAction.setCheckable(True) + self._profile2DAction = profile2DAction + + menu = qt.QMenu(self) + menu.addAction(profile1DAction) + menu.addAction(profile2DAction) + self.setMenu(menu) + self.setPopupMode(qt.QToolButton.InstantPopup) + menu.setTitle('Select profile dimension') + self.computeProfileIn1D() + + def _createAction(self, profileDimension): + icon = self.STATE[profileDimension, "icon"] + text = self.STATE[profileDimension, "action"] + return qt.QAction(icon, text, self) + + def _profileDimensionChanged(self, profileDimension): + """Update icon in toolbar, emit number of dimensions for profile""" + self.setIcon(self.STATE[profileDimension, "icon"]) + self.setToolTip(self.STATE[profileDimension, "state"]) + self._dimension = profileDimension + self.sigDimensionChanged.emit(profileDimension) + self._profile1DAction.setChecked(profileDimension == 1) + self._profile2DAction.setChecked(profileDimension == 2) + + def computeProfileIn1D(self): + self._profileDimensionChanged(1) + + def computeProfileIn2D(self): + self._profileDimensionChanged(2) + + def setDimension(self, dimension): + """Set the selected dimension""" + assert dimension in [1, 2] + if self._dimension == dimension: + return + if dimension == 1: + self.computeProfileIn1D() + elif dimension == 2: + self.computeProfileIn2D() + else: + _logger.warning("Unsupported dimension '%s'. Setting ignored.", dimension) + + def getDimension(self): + """Get the selected dimension. + + :rtype: int (1 or 2) + """ + return self._dimension + + +class _SymbolToolButtonBase(PlotToolButton): + """Base class for PlotToolButton setting marker and size. + + :param parent: See QWidget + :param plot: The `~silx.gui.plot.PlotWidget` to control + """ + + def __init__(self, parent=None, plot=None): + super(_SymbolToolButtonBase, self).__init__(parent=parent, plot=plot) + + def _addSizeSliderToMenu(self, menu): + """Add a slider to set size to the given menu + + :param QMenu menu: + """ + slider = qt.QSlider(qt.Qt.Horizontal) + slider.setRange(1, 20) + slider.setValue(int(config.DEFAULT_PLOT_SYMBOL_SIZE)) + slider.setTracking(False) + slider.valueChanged.connect(self._sizeChanged) + widgetAction = qt.QWidgetAction(menu) + widgetAction.setDefaultWidget(slider) + menu.addAction(widgetAction) + + def _addSymbolsToMenu(self, menu): + """Add symbols to the given menu + + :param QMenu menu: + """ + for marker, name in zip(SymbolMixIn.getSupportedSymbols(), + SymbolMixIn.getSupportedSymbolNames()): + action = qt.QAction(name, menu) + action.setCheckable(False) + action.triggered.connect( + functools.partial(self._markerChanged, marker)) + menu.addAction(action) + + def _sizeChanged(self, value): + """Manage slider value changed + + :param int value: Marker size + """ + plot = self.plot() + if plot is None: + return + + for item in plot.getItems(): + if isinstance(item, SymbolMixIn): + item.setSymbolSize(value) + + def _markerChanged(self, marker): + """Manage change of marker. + + :param str marker: Letter describing the marker + """ + plot = self.plot() + if plot is None: + return + + for item in plot.getItems(): + if isinstance(item, SymbolMixIn): + item.setSymbol(marker) + + +class SymbolToolButton(_SymbolToolButtonBase): + """A tool button with a drop-down menu to control symbol size and marker. + + :param parent: See QWidget + :param plot: The `~silx.gui.plot.PlotWidget` to control + """ + + def __init__(self, parent=None, plot=None): + super(SymbolToolButton, self).__init__(parent=parent, plot=plot) + + self.setToolTip('Set symbol size and marker') + self.setIcon(icons.getQIcon('plot-symbols')) + + menu = qt.QMenu(self) + self._addSizeSliderToMenu(menu) + menu.addSeparator() + self._addSymbolsToMenu(menu) + + self.setMenu(menu) + self.setPopupMode(qt.QToolButton.InstantPopup) + + +class ScatterVisualizationToolButton(_SymbolToolButtonBase): + """QToolButton to select the visualization mode of scatter plot + + :param parent: See QWidget + :param plot: The `~silx.gui.plot.PlotWidget` to control + """ + + def __init__(self, parent=None, plot=None): + super(ScatterVisualizationToolButton, self).__init__( + parent=parent, plot=plot) + + self.setToolTip( + 'Set scatter visualization mode, symbol marker and size') + self.setIcon(icons.getQIcon('eye')) + + menu = qt.QMenu(self) + + # Add visualization modes + + for mode in Scatter.supportedVisualizations(): + if mode is not Scatter.Visualization.BINNED_STATISTIC: + name = mode.value.capitalize() + action = qt.QAction(name, menu) + action.setCheckable(False) + action.triggered.connect( + functools.partial(self._visualizationChanged, mode, None)) + menu.addAction(action) + + if Scatter.Visualization.BINNED_STATISTIC in Scatter.supportedVisualizations(): + reductions = Scatter.supportedVisualizationParameterValues( + Scatter.VisualizationParameter.BINNED_STATISTIC_FUNCTION) + if reductions: + submenu = menu.addMenu('Binned Statistic') + for reduction in reductions: + name = reduction.capitalize() + action = qt.QAction(name, menu) + action.setCheckable(False) + action.triggered.connect(functools.partial( + self._visualizationChanged, + Scatter.Visualization.BINNED_STATISTIC, + {Scatter.VisualizationParameter.BINNED_STATISTIC_FUNCTION: reduction})) + submenu.addAction(action) + + submenu.addSeparator() + binsmenu = submenu.addMenu('N Bins') + + slider = qt.QSlider(qt.Qt.Horizontal) + slider.setRange(10, 1000) + slider.setValue(100) + slider.setTracking(False) + slider.valueChanged.connect(self._binningChanged) + widgetAction = qt.QWidgetAction(binsmenu) + widgetAction.setDefaultWidget(slider) + binsmenu.addAction(widgetAction) + + menu.addSeparator() + + submenu = menu.addMenu(icons.getQIcon('plot-symbols'), "Symbol") + self._addSymbolsToMenu(submenu) + + submenu = menu.addMenu(icons.getQIcon('plot-symbols'), "Symbol Size") + self._addSizeSliderToMenu(submenu) + + self.setMenu(menu) + self.setPopupMode(qt.QToolButton.InstantPopup) + + def _visualizationChanged(self, mode, parameters=None): + """Handle change of visualization mode. + + :param ScatterVisualizationMixIn.Visualization mode: + The visualization mode to use for scatter + :param Union[dict,None] parameters: + Dict of VisualizationParameter: parameter_value to set + with the visualization. + """ + plot = self.plot() + if plot is None: + return + + for item in plot.getItems(): + if isinstance(item, Scatter): + if parameters: + for parameter, value in parameters.items(): + item.setVisualizationParameter(parameter, value) + item.setVisualization(mode) + + def _binningChanged(self, value): + """Handle change of binning. + + :param int value: The number of bin on each dimension. + """ + plot = self.plot() + if plot is None: + return + + for item in plot.getItems(): + if isinstance(item, Scatter): + item.setVisualizationParameter( + Scatter.VisualizationParameter.BINNED_STATISTIC_SHAPE, + (value, value)) + item.setVisualization(Scatter.Visualization.BINNED_STATISTIC) diff --git a/src/silx/gui/plot/PlotTools.py b/src/silx/gui/plot/PlotTools.py new file mode 100644 index 0000000..5929473 --- /dev/null +++ b/src/silx/gui/plot/PlotTools.py @@ -0,0 +1,43 @@ +# coding: utf-8 +# /*########################################################################## +# +# Copyright (c) 2016-2018 European Synchrotron Radiation Facility +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. +# +# ###########################################################################*/ +"""Set of widgets to associate with a :class:'PlotWidget'. +""" + +from __future__ import absolute_import + +__authors__ = ["T. Vincent"] +__license__ = "MIT" +__date__ = "01/03/2018" + + +from ...utils.deprecation import deprecated_warning + +deprecated_warning(type_='module', + name=__file__, + reason='Plot tools refactoring', + replacement='silx.gui.plot.tools', + since_version='0.8') + +from .tools import PositionInfo, LimitsToolBar # noqa diff --git a/src/silx/gui/plot/PlotWidget.py b/src/silx/gui/plot/PlotWidget.py new file mode 100755 index 0000000..6cb5ef5 --- /dev/null +++ b/src/silx/gui/plot/PlotWidget.py @@ -0,0 +1,3628 @@ +# coding: utf-8 +# /*########################################################################## +# +# Copyright (c) 2004-2021 European Synchrotron Radiation Facility +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. +# ###########################################################################*/ +"""Qt widget providing plot API for 1D and 2D data. + +The :class:`PlotWidget` implements the plot API initially provided in PyMca. +""" + +from __future__ import division + + +__authors__ = ["V.A. Sole", "T. Vincent"] +__license__ = "MIT" +__date__ = "21/12/2018" + +import logging + +_logger = logging.getLogger(__name__) + + +from collections import OrderedDict, namedtuple +from contextlib import contextmanager +import datetime as dt +import itertools +import typing +import warnings + +import numpy + +import silx +from silx.utils.weakref import WeakMethodProxy +from silx.utils.property import classproperty +from silx.utils.deprecation import deprecated, deprecated_warning +try: + # Import matplotlib now to init matplotlib our way + import silx.gui.utils.matplotlib # noqa +except ImportError: + _logger.debug("matplotlib not available") + +from ..colors import Colormap +from .. import colors +from . import PlotInteraction +from . import PlotEvents +from .LimitsHistory import LimitsHistory +from . import _utils + +from . import items +from .items.curve import CurveStyle +from .items.axis import TickMode # noqa + +from .. import qt +from ._utils.panzoom import ViewConstraints +from ...gui.plot._utils.dtime_ticklayout import timestamp + + + +_COLORDICT = colors.COLORDICT +_COLORLIST = silx.config.DEFAULT_PLOT_CURVE_COLORS + +""" +Object returned when requesting the data range. +""" +_PlotDataRange = namedtuple('PlotDataRange', + ['x', 'y', 'yright']) + + +class _PlotWidgetSelection(qt.QObject): + """Object managing a :class:`PlotWidget` selection. + + It is a wrapper over :class:`PlotWidget`'s active items API. + + :param PlotWidget parent: + """ + + sigCurrentItemChanged = qt.Signal(object, object) + """This signal is emitted whenever the current item changes. + + It provides the current and previous items. + """ + + sigSelectedItemsChanged = qt.Signal() + """Signal emitted whenever the list of selected items changes.""" + + def __init__(self, parent): + assert isinstance(parent, PlotWidget) + super(_PlotWidgetSelection, self).__init__(parent=parent) + + # Init history + self.__history = [ # Store active items from most recent to oldest + item for item in (parent.getActiveCurve(), + parent.getActiveImage(), + parent.getActiveScatter()) + if item is not None] + + self.__current = self.__mostRecentActiveItem() + + parent.sigActiveImageChanged.connect(self._activeImageChanged) + parent.sigActiveCurveChanged.connect(self._activeCurveChanged) + parent.sigActiveScatterChanged.connect(self._activeScatterChanged) + + def __mostRecentActiveItem(self) -> typing.Optional[items.Item]: + """Returns most recent active item.""" + return self.__history[0] if len(self.__history) >= 1 else None + + def getSelectedItems(self) -> typing.Tuple[items.Item]: + """Returns the list of currently selected items in the :class:`PlotWidget`. + + The list is given from most recently current item to oldest one.""" + plot = self.parent() + if plot is None: + return () + + active = tuple(self.__history) + + current = self.getCurrentItem() + if current is not None and current not in active: + # Current might not be an active item, if so add it + active = (current,) + active + + return active + + def getCurrentItem(self) -> typing.Optional[items.Item]: + """Returns the current item in the :class:`PlotWidget` or None. """ + return self.__current + + def setCurrentItem(self, item: typing.Optional[items.Item]): + """Set the current item in the :class:`PlotWidget`. + + :param item: + The new item to select or None to clear the selection. + :raise ValueError: If the item is not the :class:`PlotWidget` + """ + previous = self.getCurrentItem() + if previous is item: + return + + previousSelected = self.getSelectedItems() + + if item is None: + self.__current = None + + # Reset all PlotWidget active items + plot = self.parent() + if plot is not None: + for kind in PlotWidget._ACTIVE_ITEM_KINDS: + if plot._getActiveItem(kind) is not None: + plot._setActiveItem(kind, None) + + elif isinstance(item, items.Item): + plot = self.parent() + if plot is None or item.getPlot() is not plot: + raise ValueError( + "Item is not in the PlotWidget: %s" % str(item)) + self.__current = item + + kind = plot._itemKind(item) + + # Clean-up history to be safe + self.__history = [item for item in self.__history + if PlotWidget._itemKind(item) != kind] + + # Sync active item if needed + if (kind in plot._ACTIVE_ITEM_KINDS and + item is not plot._getActiveItem(kind)): + plot._setActiveItem(kind, item.getName()) + else: + raise ValueError("Not an Item: %s" % str(item)) + + self.sigCurrentItemChanged.emit(previous, item) + + if previousSelected != self.getSelectedItems(): + self.sigSelectedItemsChanged.emit() + + def __activeItemChanged(self, + kind: str, + previous: typing.Optional[str], + legend: typing.Optional[str]): + """Set current item from kind and legend""" + if previous == legend: + return # No-op for update of item + + plot = self.parent() + if plot is None: + return + + previousSelected = self.getSelectedItems() + + # Remove items of this kind from the history + self.__history = [item for item in self.__history + if PlotWidget._itemKind(item) != kind] + + # Retrieve current item + if legend is None: # Use most recent active item + currentItem = self.__mostRecentActiveItem() + else: + currentItem = plot._getItem(kind=kind, legend=legend) + if currentItem is None: # Fallback in case something went wrong + currentItem = self.__mostRecentActiveItem() + + # Update history + if currentItem is not None: + while currentItem in self.__history: + self.__history.remove(currentItem) + self.__history.insert(0, currentItem) + + if currentItem != self.__current: + previousItem = self.__current + self.__current = currentItem + self.sigCurrentItemChanged.emit(previousItem, currentItem) + + if previousSelected != self.getSelectedItems(): + self.sigSelectedItemsChanged.emit() + + def _activeImageChanged(self, previous, current): + """Handle active image change""" + self.__activeItemChanged('image', previous, current) + + def _activeCurveChanged(self, previous, current): + """Handle active curve change""" + self.__activeItemChanged('curve', previous, current) + + def _activeScatterChanged(self, previous, current): + """Handle active scatter change""" + self.__activeItemChanged('scatter', previous, current) + + +class PlotWidget(qt.QMainWindow): + """Qt Widget providing a 1D/2D plot. + + This widget is a QMainWindow. + This class implements the plot API initially provided in PyMca. + + Supported backends: + + - 'matplotlib' and 'mpl': Matplotlib with Qt. + - 'opengl' and 'gl': OpenGL backend (requires PyOpenGL and OpenGL >= 2.1) + - 'none': No backend, to run headless for testing purpose. + + :param parent: The parent of this widget or None (default). + :param backend: The backend to use, in: + 'matplotlib' (default), 'mpl', 'opengl', 'gl', 'none' + or a :class:`BackendBase.BackendBase` class + :type backend: str or :class:`BackendBase.BackendBase` + """ + + # TODO: Can be removed for silx 0.10 + @classproperty + @deprecated(replacement="silx.config.DEFAULT_PLOT_BACKEND", since_version="0.8", skip_backtrace_count=2) + def DEFAULT_BACKEND(self): + """Class attribute setting the default backend for all instances.""" + return silx.config.DEFAULT_PLOT_BACKEND + + colorList = _COLORLIST + colorDict = _COLORDICT + + sigPlotSignal = qt.Signal(object) + """Signal for all events of the plot. + + The signal information is provided as a dict. + See the :ref:`plot signal documentation page <plot_signal>` for + information about the content of the dict + """ + + sigSetKeepDataAspectRatio = qt.Signal(bool) + """Signal emitted when plot keep aspect ratio has changed""" + + sigSetGraphGrid = qt.Signal(str) + """Signal emitted when plot grid has changed""" + + sigSetGraphCursor = qt.Signal(bool) + """Signal emitted when plot crosshair cursor has changed""" + + sigSetPanWithArrowKeys = qt.Signal(bool) + """Signal emitted when pan with arrow keys has changed""" + + _sigAxesVisibilityChanged = qt.Signal(bool) + """Signal emitted when the axes visibility changed""" + + sigContentChanged = qt.Signal(str, str, str) + """Signal emitted when the content of the plot is changed. + + It provides the following information: + + - action: The change of the plot: 'add' or 'remove' + - kind: The kind of primitive changed: + 'curve', 'image', 'scatter', 'histogram', 'item' or 'marker' + - legend: The legend of the primitive changed. + """ + + sigActiveCurveChanged = qt.Signal(object, object) + """Signal emitted when the active curve has changed. + + It provides the following information: + + - previous: The legend of the previous active curve or None + - legend: The legend of the new active curve or None if no curve is active + """ + + sigActiveImageChanged = qt.Signal(object, object) + """Signal emitted when the active image has changed. + + It provides the following information: + + - previous: The legend of the previous active image or None + - legend: The legend of the new active image or None if no image is active + """ + + sigActiveScatterChanged = qt.Signal(object, object) + """Signal emitted when the active Scatter has changed. + + It provides the following information: + + - previous: The legend of the previous active scatter or None + - legend: The legend of the new active image or None if no image is active + """ + + sigInteractiveModeChanged = qt.Signal(object) + """Signal emitted when the interactive mode has changed + + It provides the source as passed to :meth:`setInteractiveMode`. + """ + + sigItemAdded = qt.Signal(items.Item) + """Signal emitted when an item was just added to the plot + + It provides the added item. + """ + + sigItemAboutToBeRemoved = qt.Signal(items.Item) + """Signal emitted right before an item is removed from the plot. + + It provides the item that will be removed. + """ + + sigItemRemoved = qt.Signal(items.Item) + """Signal emitted right after an item was removed from the plot. + + It provides the item that was removed. + """ + + sigVisibilityChanged = qt.Signal(bool) + """Signal emitted when the widget becomes visible (or invisible). + This happens when the widget is hidden or shown. + + It provides the visible state. + """ + + _sigDefaultContextMenu = qt.Signal(qt.QMenu) + """Signal emitted when the default context menu of the plot is feed. + + It provides the menu which will be displayed. + """ + + def __init__(self, parent=None, backend=None): + self._autoreplot = False + self._dirty = False + self._cursorInPlot = False + self.__muteActiveItemChanged = False + + self._panWithArrowKeys = True + self._viewConstrains = None + + super(PlotWidget, self).__init__(parent) + if parent is not None: + # behave as a widget + self.setWindowFlags(qt.Qt.Widget) + else: + self.setWindowTitle('PlotWidget') + + # Init the backend + self._backend = self.__getBackendClass(backend)(self, self) + + self.setCallback() # set _callback + + # Items handling + self._content = OrderedDict() + self._contentToUpdate = [] # Used as an OrderedSet + + self._dataRange = None + + # line types + self._styleList = ['-', '--', '-.', ':'] + self._colorIndex = 0 + self._styleIndex = 0 + + self._activeCurveSelectionMode = "atmostone" + self._activeCurveStyle = CurveStyle(color='#000000') + self._activeLegend = {'curve': None, 'image': None, + 'scatter': None} + + # plot colors (updated later to sync backend) + self._foregroundColor = 0., 0., 0., 1. + self._gridColor = .7, .7, .7, 1. + self._backgroundColor = 1., 1., 1., 1. + self._dataBackgroundColor = None + + # default properties + self._cursorConfiguration = None + + self._xAxis = items.XAxis(self) + self._yAxis = items.YAxis(self) + self._yRightAxis = items.YRightAxis(self, self._yAxis) + + self._grid = None + self._graphTitle = '' + self.__graphCursorShape = 'default' + + # Set axes margins + self.__axesDisplayed = True + self.__axesMargins = 0., 0., 0., 0. + self.setAxesMargins(.15, .1, .1, .15) + + self.setGraphTitle() + self.setGraphXLabel() + self.setGraphYLabel() + self.setGraphYLabel('', axis='right') + + self.setDefaultColormap() # Init default colormap + + self.setDefaultPlotPoints(silx.config.DEFAULT_PLOT_CURVE_SYMBOL_MODE) + self.setDefaultPlotLines(True) + + self._limitsHistory = LimitsHistory(self) + + self._eventHandler = PlotInteraction.PlotInteraction(self) + self._eventHandler.setInteractiveMode('zoom', color=(0., 0., 0., 1.)) + self._previousDefaultMode = "zoom", True + + self._pressedButtons = [] # Currently pressed mouse buttons + + self._defaultDataMargins = (0., 0., 0., 0.) + + # Only activate autoreplot at the end + # This avoids errors when loaded in Qt designer + self._dirty = False + self._autoreplot = True + + widget = self.getWidgetHandle() + if widget is not None: + self.setCentralWidget(widget) + else: + _logger.info("PlotWidget backend does not support widget") + + self.setFocusPolicy(qt.Qt.StrongFocus) + self.setFocus(qt.Qt.OtherFocusReason) + + # Set default limits + self.setGraphXLimits(0., 100.) + self.setGraphYLimits(0., 100., axis='right') + self.setGraphYLimits(0., 100., axis='left') + + # Sync backend colors with default ones + self._foregroundColorsUpdated() + self._backgroundColorsUpdated() + + # selection handling + self.__selection = None + + def __getBackendClass(self, backend): + """Returns backend class corresponding to backend. + + If multiple backends are provided, the first available one is used. + + :param Union[str,BackendBase,List[Union[str,BackendBase]]] backend: + The name of the backend or its class or an iterable of those. + :rtype: BackendBase + :raise ValueError: In case the backend is not supported + :raise RuntimeError: If a backend is not available + """ + if backend is None: + backend = silx.config.DEFAULT_PLOT_BACKEND + + if callable(backend): + return backend + + elif isinstance(backend, str): + backend = backend.lower() + if backend in ('matplotlib', 'mpl'): + try: + from .backends.BackendMatplotlib import \ + BackendMatplotlibQt as backendClass + except ImportError: + _logger.debug("Backtrace", exc_info=True) + raise RuntimeError("matplotlib backend is not available") + + elif backend in ('gl', 'opengl'): + from ..utils.glutils import isOpenGLAvailable + checkOpenGL = isOpenGLAvailable(version=(2, 1), runtimeCheck=False) + if not checkOpenGL: + _logger.debug("OpenGL check failed") + raise RuntimeError( + "OpenGL backend is not available: %s" % checkOpenGL.error) + + try: + from .backends.BackendOpenGL import \ + BackendOpenGL as backendClass + except ImportError: + _logger.debug("Backtrace", exc_info=True) + raise RuntimeError("OpenGL backend is not available") + + elif backend == 'none': + from .backends.BackendBase import BackendBase as backendClass + + else: + raise ValueError("Backend not supported %s" % backend) + + return backendClass + + elif isinstance(backend, (tuple, list)): + for b in backend: + try: + return self.__getBackendClass(b) + except RuntimeError: + pass + else: # No backend was found + raise RuntimeError("None of the request backends are available") + + raise ValueError("Backend not supported %s" % str(backend)) + + def selection(self): + """Returns the selection hander""" + if self.__selection is None: # Lazy initialization + self.__selection = _PlotWidgetSelection(parent=self) + return self.__selection + + # TODO: Can be removed for silx 0.10 + @staticmethod + @deprecated(replacement="silx.config.DEFAULT_PLOT_BACKEND", since_version="0.8", skip_backtrace_count=2) + def setDefaultBackend(backend): + """Set system wide default plot backend. + + .. versionadded:: 0.6 + + :param backend: The backend to use, in: + 'matplotlib' (default), 'mpl', 'opengl', 'gl', 'none' + or a :class:`BackendBase.BackendBase` class + """ + silx.config.DEFAULT_PLOT_BACKEND = backend + + def setBackend(self, backend): + """Set the backend to use for rendering. + + Supported backends: + + - 'matplotlib' and 'mpl': Matplotlib with Qt. + - 'opengl' and 'gl': OpenGL backend (requires PyOpenGL and OpenGL >= 2.1) + - 'none': No backend, to run headless for testing purpose. + + :param Union[str,BackendBase,List[Union[str,BackendBase]]] backend: + The backend to use, in: + 'matplotlib' (default), 'mpl', 'opengl', 'gl', 'none', + a :class:`BackendBase.BackendBase` class. + If multiple backends are provided, the first available one is used. + :raises ValueError: Unsupported backend descriptor + :raises RuntimeError: Error while loading a backend + """ + backend = self.__getBackendClass(backend)(self, self) + + # First save state that is stored in the backend + xaxis = self.getXAxis() + xmin, xmax = xaxis.getLimits() + ymin, ymax = self.getYAxis(axis='left').getLimits() + y2min, y2max = self.getYAxis(axis='right').getLimits() + isKeepDataAspectRatio = self.isKeepDataAspectRatio() + xTimeZone = xaxis.getTimeZone() + isXAxisTimeSeries = xaxis.getTickMode() == TickMode.TIME_SERIES + + isYAxisInverted = self.getYAxis().isInverted() + + # Remove all items from previous backend + for item in self.getItems(): + item._removeBackendRenderer(self._backend) + + # Switch backend + self._backend = backend + widget = self._backend.getWidgetHandle() + self.setCentralWidget(widget) + if widget is None: + _logger.info("PlotWidget backend does not support widget") + + # Mark as newly dirty + self._dirty = False + self._setDirtyPlot() + + # Synchronize/restore state + self._foregroundColorsUpdated() + self._backgroundColorsUpdated() + + self._backend.setGraphCursorShape(self.getGraphCursorShape()) + crosshairConfig = self.getGraphCursor() + if crosshairConfig is None: + self._backend.setGraphCursor(False, 'black', 1, '-') + else: + self._backend.setGraphCursor(True, *crosshairConfig) + + self._backend.setGraphTitle(self.getGraphTitle()) + self._backend.setGraphGrid(self.getGraphGrid()) + if self.isAxesDisplayed(): + self._backend.setAxesMargins(*self.getAxesMargins()) + else: + self._backend.setAxesMargins(0., 0., 0., 0.) + + # Set axes + xaxis = self.getXAxis() + self._backend.setGraphXLabel(xaxis.getLabel()) + self._backend.setXAxisTimeZone(xTimeZone) + self._backend.setXAxisTimeSeries(isXAxisTimeSeries) + self._backend.setXAxisLogarithmic( + xaxis.getScale() == items.Axis.LOGARITHMIC) + + for axis in ('left', 'right'): + self._backend.setGraphYLabel(self.getYAxis(axis).getLabel(), axis) + self._backend.setYAxisInverted(isYAxisInverted) + self._backend.setYAxisLogarithmic( + self.getYAxis().getScale() == items.Axis.LOGARITHMIC) + + # Finally restore aspect ratio and limits + self._backend.setKeepDataAspectRatio(isKeepDataAspectRatio) + self.setLimits(xmin, xmax, ymin, ymax, y2min, y2max) + + # Mark all items for update with new backend + for item in self.getItems(): + item._updated() + + def getBackend(self): + """Returns the backend currently used by :class:`PlotWidget`. + + :rtype: ~silx.gui.plot.backend.BackendBase.BackendBase + """ + return self._backend + + def _getDirtyPlot(self): + """Return the plot dirty flag. + + If False, the plot has not changed since last replot. + If True, the full plot need to be redrawn. + If 'overlay', only the overlay has changed since last replot. + + It can be accessed by backend to check the dirty state. + + :return: False, True, 'overlay' + """ + return self._dirty + + # Default Qt context menu + + def contextMenuEvent(self, event): + """Override QWidget.contextMenuEvent to implement the context menu""" + menu = qt.QMenu(self) + from .actions.control import ZoomBackAction # Avoid cyclic import + zoomBackAction = ZoomBackAction(plot=self, parent=menu) + menu.addAction(zoomBackAction) + + mode = self.getInteractiveMode() + if "shape" in mode and mode["shape"] == "polygon": + from .actions.control import ClosePolygonInteractionAction # Avoid cyclic import + action = ClosePolygonInteractionAction(plot=self, parent=menu) + menu.addAction(action) + + self._sigDefaultContextMenu.emit(menu) + + # Make sure the plot is updated, especially when the plot is in + # draw interaction mode + menu.aboutToHide.connect(self.__simulateMouseMove) + + menu.exec(event.globalPos()) + + def _setDirtyPlot(self, overlayOnly=False): + """Mark the plot as needing redraw + + :param bool overlayOnly: True to redraw only the overlay, + False to redraw everything + """ + wasDirty = self._dirty + + if not self._dirty and overlayOnly: + self._dirty = 'overlay' + else: + self._dirty = True + + if self._autoreplot and not wasDirty and self.isVisible(): + self._backend.postRedisplay() + + def _foregroundColorsUpdated(self): + """Handle change of foreground/grid color""" + if self._gridColor is None: + gridColor = self._foregroundColor + else: + gridColor = self._gridColor + self._backend.setForegroundColors( + self._foregroundColor, gridColor) + self._setDirtyPlot() + + def getForegroundColor(self): + """Returns the RGBA colors used to display the foreground of this widget + + :rtype: qt.QColor + """ + return qt.QColor.fromRgbF(*self._foregroundColor) + + def setForegroundColor(self, color): + """Set the foreground color of this widget. + + :param Union[List[int],List[float],QColor] color: + The new RGB(A) color. + """ + color = colors.rgba(color) + if self._foregroundColor != color: + self._foregroundColor = color + self._foregroundColorsUpdated() + + def getGridColor(self): + """Returns the RGBA colors used to display the grid lines + + An invalid QColor is returned if there is no grid color, + in which case the foreground color is used. + + :rtype: qt.QColor + """ + if self._gridColor is None: + return qt.QColor() # An invalid color + else: + return qt.QColor.fromRgbF(*self._gridColor) + + def setGridColor(self, color): + """Set the grid lines color + + :param Union[List[int],List[float],QColor,None] color: + The new RGB(A) color. + """ + if isinstance(color, qt.QColor) and not color.isValid(): + color = None + if color is not None: + color = colors.rgba(color) + if self._gridColor != color: + self._gridColor = color + self._foregroundColorsUpdated() + + def _backgroundColorsUpdated(self): + """Handle change of background/data background color""" + if self._dataBackgroundColor is None: + dataBGColor = self._backgroundColor + else: + dataBGColor = self._dataBackgroundColor + self._backend.setBackgroundColors( + self._backgroundColor, dataBGColor) + self._setDirtyPlot() + + def getBackgroundColor(self): + """Returns the RGBA colors used to display the background of this widget. + + :rtype: qt.QColor + """ + return qt.QColor.fromRgbF(*self._backgroundColor) + + def setBackgroundColor(self, color): + """Set the background color of this widget. + + :param Union[List[int],List[float],QColor] color: + The new RGB(A) color. + """ + color = colors.rgba(color) + if self._backgroundColor != color: + self._backgroundColor = color + self._backgroundColorsUpdated() + + def getDataBackgroundColor(self): + """Returns the RGBA colors used to display the background of the plot + view displaying the data. + + An invalid QColor is returned if there is no data background color. + + :rtype: qt.QColor + """ + if self._dataBackgroundColor is None: + # An invalid color + return qt.QColor() + else: + return qt.QColor.fromRgbF(*self._dataBackgroundColor) + + def setDataBackgroundColor(self, color): + """Set the background color of the plot area. + + Set to None or an invalid QColor to use the background color. + + :param Union[List[int],List[float],QColor,None] color: + The new RGB(A) color. + """ + if isinstance(color, qt.QColor) and not color.isValid(): + color = None + if color is not None: + color = colors.rgba(color) + if self._dataBackgroundColor != color: + self._dataBackgroundColor = color + self._backgroundColorsUpdated() + + dataBackgroundColor = qt.Property( + qt.QColor, getDataBackgroundColor, setDataBackgroundColor + ) + + backgroundColor = qt.Property(qt.QColor, getBackgroundColor, setBackgroundColor) + + foregroundColor = qt.Property(qt.QColor, getForegroundColor, setForegroundColor) + + gridColor = qt.Property(qt.QColor, getGridColor, setGridColor) + + def showEvent(self, event): + if self._autoreplot and self._dirty: + self._backend.postRedisplay() + super(PlotWidget, self).showEvent(event) + self.sigVisibilityChanged.emit(True) + + def hideEvent(self, event): + super(PlotWidget, self).hideEvent(event) + self.sigVisibilityChanged.emit(False) + + def _invalidateDataRange(self): + """ + Notifies this PlotWidget instance that the range has changed + and will have to be recomputed. + """ + self._dataRange = None + + def _updateDataRange(self): + """ + Recomputes the range of the data displayed on this PlotWidget. + """ + xMin = yMinLeft = yMinRight = float('nan') + xMax = yMaxLeft = yMaxRight = float('nan') + + for item in self.getItems(): + if item.isVisible(): + bounds = item.getBounds() + if bounds is not None: + with warnings.catch_warnings(): + warnings.simplefilter('ignore', category=RuntimeWarning) + # Ignore All-NaN slice encountered + xMin = numpy.nanmin([xMin, bounds[0]]) + xMax = numpy.nanmax([xMax, bounds[1]]) + # Take care of right axis + if (isinstance(item, items.YAxisMixIn) and + item.getYAxis() == 'right'): + with warnings.catch_warnings(): + warnings.simplefilter('ignore', category=RuntimeWarning) + # Ignore All-NaN slice encountered + yMinRight = numpy.nanmin([yMinRight, bounds[2]]) + yMaxRight = numpy.nanmax([yMaxRight, bounds[3]]) + else: + with warnings.catch_warnings(): + warnings.simplefilter('ignore', category=RuntimeWarning) + # Ignore All-NaN slice encountered + yMinLeft = numpy.nanmin([yMinLeft, bounds[2]]) + yMaxLeft = numpy.nanmax([yMaxLeft, bounds[3]]) + + def lGetRange(x, y): + return None if numpy.isnan(x) and numpy.isnan(y) else (x, y) + xRange = lGetRange(xMin, xMax) + yLeftRange = lGetRange(yMinLeft, yMaxLeft) + yRightRange = lGetRange(yMinRight, yMaxRight) + + self._dataRange = _PlotDataRange(x=xRange, + y=yLeftRange, + yright=yRightRange) + + def getDataRange(self): + """ + Returns this PlotWidget's data range. + + :return: a namedtuple with the following members : + x, y (left y axis), yright. Each member is a tuple (min, max) + or None if no data is associated with the axis. + :rtype: namedtuple + """ + if self._dataRange is None: + self._updateDataRange() + return self._dataRange + + # Content management + + _KIND_TO_CLASSES = { + 'curve': (items.Curve,), + 'image': (items.ImageBase,), + 'scatter': (items.Scatter,), + 'marker': (items.MarkerBase,), + 'item': (items.Shape, + items.BoundingRect, + items.XAxisExtent, + items.YAxisExtent), + 'histogram': (items.Histogram,), + } + """Mapping kind to item classes of this kind""" + + @classmethod + def _itemKind(cls, item): + """Returns the "kind" of a given item + + :param Item item: The item get the kind + :rtype: str + """ + for kind, itemClasses in cls._KIND_TO_CLASSES.items(): + if isinstance(item, itemClasses): + return kind + raise ValueError('Unsupported item type %s' % type(item)) + + def _notifyContentChanged(self, item): + self.notify('contentChanged', action='add', + kind=self._itemKind(item), legend=item.getName()) + + def _itemRequiresUpdate(self, item): + """Called by items in the plot for asynchronous update + + :param Item item: The item that required update + """ + assert item.getPlot() == self + # Put item at the end of the list + if item in self._contentToUpdate: + self._contentToUpdate.remove(item) + self._contentToUpdate.append(item) + self._setDirtyPlot(overlayOnly=item.isOverlay()) + + def addItem(self, item=None, *args, **kwargs): + """Add an item to the plot content. + + :param ~silx.gui.plot.items.Item item: The item to add. + :raises ValueError: If item is already in the plot. + """ + if not isinstance(item, items.Item): + deprecated_warning( + 'Function', + 'addItem', + replacement='addShape', + since_version='0.13') + if item is None and not args: # Only kwargs + return self.addShape(**kwargs) + else: + return self.addShape(item, *args, **kwargs) + + assert not args and not kwargs + if item in self.getItems(): + raise ValueError('Item already in the plot') + + # Add item to plot + self._content[(item.getName(), self._itemKind(item))] = item + item._setPlot(self) + self._itemRequiresUpdate(item) + if isinstance(item, items.DATA_ITEMS): + self._invalidateDataRange() # TODO handle this automatically + + self._notifyContentChanged(item) + self.sigItemAdded.emit(item) + + def removeItem(self, item): + """Remove the item from the plot. + + :param ~silx.gui.plot.items.Item item: Item to remove from the plot. + :raises ValueError: If item is not in the plot. + """ + if not isinstance(item, items.Item): # Previous method usage + deprecated_warning( + 'Function', + 'removeItem', + replacement='remove(legend, kind="item")', + since_version='0.13') + if item is None: + return + self.remove(item, kind='item') + return + + if item not in self.getItems(): + raise ValueError('Item not in the plot') + + self.sigItemAboutToBeRemoved.emit(item) + + kind = self._itemKind(item) + + if kind in self._ACTIVE_ITEM_KINDS: + if self._getActiveItem(kind) == item: + # Reset active item + self._setActiveItem(kind, None) + + # Remove item from plot + self._content.pop((item.getName(), kind)) + if item in self._contentToUpdate: + self._contentToUpdate.remove(item) + if item.isVisible(): + self._setDirtyPlot(overlayOnly=item.isOverlay()) + if item.getBounds() is not None: + self._invalidateDataRange() + item._removeBackendRenderer(self._backend) + item._setPlot(None) + + if (kind == 'curve' and not self.getAllCurves(just_legend=True, + withhidden=True)): + self._resetColorAndStyle() + + self.sigItemRemoved.emit(item) + + self.notify('contentChanged', action='remove', + kind=kind, legend=item.getName()) + + def discardItem(self, item) -> bool: + """Remove the item from the plot. + + Same as :meth:`removeItem` but do not raise an exception. + + :param ~silx.gui.plot.items.Item item: Item to remove from the plot. + :returns: True if the item was present, False otherwise. + """ + try: + self.removeItem(item) + except ValueError: + return False + else: + return True + + @deprecated(replacement='addItem', since_version='0.13') + def _add(self, item): + return self.addItem(item) + + @deprecated(replacement='removeItem', since_version='0.13') + def _remove(self, item): + return self.removeItem(item) + + def getItems(self): + """Returns the list of items in the plot + + :rtype: List[silx.gui.plot.items.Item] + """ + return tuple(self._content.values()) + + @contextmanager + def _muteActiveItemChangedSignal(self): + self.__muteActiveItemChanged = True + yield + self.__muteActiveItemChanged = False + + # Add + + # add * input arguments management: + # If an arg is set, then use it. + # Else: + # If a curve with the same legend exists, then use its arg value + # Else, use a default value. + # Store used value. + # This value is used when curve is updated either internally or by user. + + def addCurve(self, x, y, legend=None, info=None, + replace=False, + color=None, symbol=None, + linewidth=None, linestyle=None, + xlabel=None, ylabel=None, yaxis=None, + xerror=None, yerror=None, z=None, selectable=None, + fill=None, resetzoom=True, + histogram=None, copy=True, + baseline=None): + """Add a 1D curve given by x an y to the graph. + + Curves are uniquely identified by their legend. + To add multiple curves, call :meth:`addCurve` multiple times with + different legend argument. + To replace an existing curve, call :meth:`addCurve` with the + existing curve legend. + If you want to display the curve values as an histogram see the + histogram parameter or :meth:`addHistogram`. + + When curve parameters are not provided, if a curve with the + same legend is displayed in the plot, its parameters are used. + + :param numpy.ndarray x: The data corresponding to the x coordinates. + If you attempt to plot an histogram you can set edges values in x. + In this case len(x) = len(y) + 1. + If x contains datetime objects the XAxis tickMode is set to + TickMode.TIME_SERIES. + :param numpy.ndarray y: The data corresponding to the y coordinates + :param str legend: The legend to be associated to the curve (or None) + :param info: User-defined information associated to the curve + :param bool replace: True to delete already existing curves + (the default is False) + :param color: color(s) to be used + :type color: str ("#RRGGBB") or (npoints, 4) unsigned byte array or + one of the predefined color names defined in colors.py + :param str symbol: Symbol to be drawn at each (x, y) position:: + + - 'o' circle + - '.' point + - ',' pixel + - '+' cross + - 'x' x-cross + - 'd' diamond + - 's' square + - None (the default) to use default symbol + + :param float linewidth: The width of the curve in pixels (Default: 1). + :param str linestyle: Type of line:: + + - ' ' no line + - '-' solid line + - '--' dashed line + - '-.' dash-dot line + - ':' dotted line + - None (the default) to use default line style + + :param str xlabel: Label to show on the X axis when the curve is active + or None to keep default axis label. + :param str ylabel: Label to show on the Y axis when the curve is active + or None to keep default axis label. + :param str yaxis: The Y axis this curve is attached to. + Either 'left' (the default) or 'right' + :param xerror: Values with the uncertainties on the x values + :type xerror: A float, or a numpy.ndarray of float32. + If it is an array, it can either be a 1D array of + same length as the data or a 2D array with 2 rows + of same length as the data: row 0 for positive errors, + row 1 for negative errors. + :param yerror: Values with the uncertainties on the y values + :type yerror: A float, or a numpy.ndarray of float32. See xerror. + :param int z: Layer on which to draw the curve (default: 1) + This allows to control the overlay. + :param bool selectable: Indicate if the curve can be selected. + (Default: True) + :param bool fill: True to fill the curve, False otherwise (default). + :param bool resetzoom: True (the default) to reset the zoom. + :param str histogram: if not None then the curve will be draw as an + histogram. The step for each values of the curve can be set to the + left, center or right of the original x curve values. + If histogram is not None and len(x) == len(y)+1 then x is directly + take as edges of the histogram. + Type of histogram:: + + - None (default) + - 'left' + - 'right' + - 'center' + :param bool copy: True make a copy of the data (default), + False to use provided arrays. + :param baseline: curve baseline + :type: Union[None,float,numpy.ndarray] + :returns: The key string identify this curve + """ + # This is an histogram, use addHistogram + if histogram is not None: + histoLegend = self.addHistogram(histogram=y, + edges=x, + legend=legend, + color=color, + fill=fill, + align=histogram, + copy=copy) + histo = self.getHistogram(histoLegend) + + histo.setInfo(info) + if linewidth is not None: + histo.setLineWidth(linewidth) + if linestyle is not None: + histo.setLineStyle(linestyle) + if xlabel is not None: + _logger.warning( + 'addCurve: Histogram does not support xlabel argument') + if ylabel is not None: + _logger.warning( + 'addCurve: Histogram does not support ylabel argument') + if yaxis is not None: + histo.setYAxis(yaxis) + if z is not None: + histo.setZValue(z) + if selectable is not None: + _logger.warning( + 'addCurve: Histogram does not support selectable argument') + + return + + legend = 'Unnamed curve 1.1' if legend is None else str(legend) + + # Check if curve was previously active + wasActive = self.getActiveCurve(just_legend=True) == legend + + if replace: + self._resetColorAndStyle() + + # Create/Update curve object + curve = self.getCurve(legend) + mustBeAdded = curve is None + if curve is None: + # No previous curve, create a default one and add it to the plot + curve = items.Curve() if histogram is None else items.Histogram() + curve.setName(legend) + # Set default color, linestyle and symbol + default_color, default_linestyle = self._getColorAndStyle() + curve.setColor(default_color) + curve.setLineStyle(default_linestyle) + curve.setSymbol(self._defaultPlotPoints) + curve._setBaseline(baseline=baseline) + + # Do not emit sigActiveCurveChanged, + # it will be sent once with _setActiveItem + with self._muteActiveItemChangedSignal(): + # Override previous/default values with provided ones + curve.setInfo(info) + if color is not None: + curve.setColor(color) + if symbol is not None: + curve.setSymbol(symbol) + if linewidth is not None: + curve.setLineWidth(linewidth) + if linestyle is not None: + curve.setLineStyle(linestyle) + if xlabel is not None: + curve._setXLabel(xlabel) + if ylabel is not None: + curve._setYLabel(ylabel) + if yaxis is not None: + curve.setYAxis(yaxis) + if z is not None: + curve.setZValue(z) + if selectable is not None: + curve._setSelectable(selectable) + if fill is not None: + curve.setFill(fill) + + # Set curve data + # If errors not provided, reuse previous ones + # TODO: Issue if size of data change but not that of errors + if xerror is None: + xerror = curve.getXErrorData(copy=False) + if yerror is None: + yerror = curve.getYErrorData(copy=False) + + # Convert x to timestamps so that the internal representation + # remains floating points. The user is expected to set the axis' + # tickMode to TickMode.TIME_SERIES and, if necessary, set the axis + # to the correct time zone. + if len(x) > 0 and isinstance(x[0], dt.datetime): + x = [timestamp(d) for d in x] + + curve.setData(x, y, xerror, yerror, baseline=baseline, copy=copy) + + if replace: # Then remove all other curves + for c in self.getAllCurves(withhidden=True): + if c is not curve: + self.removeItem(c) + + if mustBeAdded: + self.addItem(curve) + else: + self._notifyContentChanged(curve) + + if wasActive: + self.setActiveCurve(curve.getName()) + elif self.getActiveCurveSelectionMode() == "legacy": + if self.getActiveCurve(just_legend=True) is None: + if len(self.getAllCurves(just_legend=True, + withhidden=False)) == 1: + if curve.isVisible(): + self.setActiveCurve(curve.getName()) + + if resetzoom: + # We ask for a zoom reset in order to handle the plot scaling + # if the user does not want that, autoscale of the different + # axes has to be set to off. + self.resetZoom() + + return legend + + def addHistogram(self, + histogram, + edges, + legend=None, + color=None, + fill=None, + align='center', + resetzoom=True, + copy=True, + z=None, + baseline=None): + """Add an histogram to the graph. + + This is NOT computing the histogram, this method takes as parameter + already computed histogram values. + + Histogram are uniquely identified by their legend. + To add multiple histograms, call :meth:`addHistogram` multiple times + with different legend argument. + + When histogram parameters are not provided, if an histogram with the + same legend is displayed in the plot, its parameters are used. + + :param numpy.ndarray histogram: The values of the histogram. + :param numpy.ndarray edges: + The bin edges of the histogram. + If histogram and edges have the same length, the bin edges + are computed according to the align parameter. + :param str legend: + The legend to be associated to the histogram (or None) + :param color: color to be used + :type color: str ("#RRGGBB") or RGB unsigned byte array or + one of the predefined color names defined in colors.py + :param bool fill: True to fill the curve, False otherwise (default). + :param str align: + In case histogram values and edges have the same length N, + the N+1 bin edges are computed according to the alignment in: + 'center' (default), 'left', 'right'. + :param bool resetzoom: True (the default) to reset the zoom. + :param bool copy: True make a copy of the data (default), + False to use provided arrays. + :param int z: Layer on which to draw the histogram + :param baseline: histogram baseline + :type: Union[None,float,numpy.ndarray] + :returns: The key string identify this histogram + """ + legend = 'Unnamed histogram' if legend is None else str(legend) + + # Create/Update histogram object + histo = self.getHistogram(legend) + mustBeAdded = histo is None + if histo is None: + # No previous histogram, create a default one and + # add it to the plot + histo = items.Histogram() + histo.setName(legend) + histo.setColor(self._getColorAndStyle()[0]) + + # Override previous/default values with provided ones + if color is not None: + histo.setColor(color) + if fill is not None: + histo.setFill(fill) + if z is not None: + histo.setZValue(z=z) + + # Set histogram data + histo.setData(histogram=histogram, edges=edges, baseline=baseline, + align=align, copy=copy) + + if mustBeAdded: + self.addItem(histo) + else: + self._notifyContentChanged(histo) + + if resetzoom: + # We ask for a zoom reset in order to handle the plot scaling + # if the user does not want that, autoscale of the different + # axes has to be set to off. + self.resetZoom() + + return legend + + def addImage(self, data, legend=None, info=None, + replace=False, + z=None, + selectable=None, draggable=None, + colormap=None, pixmap=None, + xlabel=None, ylabel=None, + origin=None, scale=None, + resetzoom=True, copy=True): + """Add a 2D dataset or an image to the plot. + + It displays either an array of data using a colormap or a RGB(A) image. + + Images are uniquely identified by their legend. + To add multiple images, call :meth:`addImage` multiple times with + different legend argument. + To replace/update an existing image, call :meth:`addImage` with the + existing image legend. + + When image parameters are not provided, if an image with the + same legend is displayed in the plot, its parameters are used. + + :param numpy.ndarray data: + (nrows, ncolumns) data or + (nrows, ncolumns, RGBA) ubyte array + Note: boolean values are converted to int8. + :param str legend: The legend to be associated to the image (or None) + :param info: User-defined information associated to the image + :param bool replace: + True to delete already existing images (Default: False). + :param int z: Layer on which to draw the image (default: 0) + This allows to control the overlay. + :param bool selectable: Indicate if the image can be selected. + (default: False) + :param bool draggable: Indicate if the image can be moved. + (default: False) + :param colormap: Colormap object to use (or None). + This is ignored if data is a RGB(A) image. + :type colormap: Union[~silx.gui.colors.Colormap, dict] + :param pixmap: Pixmap representation of the data (if any) + :type pixmap: (nrows, ncolumns, RGBA) ubyte array or None (default) + :param str xlabel: X axis label to show when this curve is active, + or None to keep default axis label. + :param str ylabel: Y axis label to show when this curve is active, + or None to keep default axis label. + :param origin: (origin X, origin Y) of the data. + It is possible to pass a single float if both + coordinates are equal. + Default: (0., 0.) + :type origin: float or 2-tuple of float + :param scale: (scale X, scale Y) of the data. + It is possible to pass a single float if both + coordinates are equal. + Default: (1., 1.) + :type scale: float or 2-tuple of float + :param bool resetzoom: True (the default) to reset the zoom. + :param bool copy: True make a copy of the data (default), + False to use provided arrays. + :returns: The key string identify this image + """ + legend = "Unnamed Image 1.1" if legend is None else str(legend) + + # Check if image was previously active + wasActive = self.getActiveImage(just_legend=True) == legend + + data = numpy.array(data, copy=False) + assert data.ndim in (2, 3) + + image = self.getImage(legend) + if image is not None and image.getData(copy=False).ndim != data.ndim: + # Update a data image with RGBA image or the other way around: + # Remove previous image + # In this case, we don't retrieve defaults from the previous image + self.removeItem(image) + image = None + + mustBeAdded = image is None + if image is None: + # No previous image, create a default one and add it to the plot + if data.ndim == 2: + image = items.ImageData() + image.setColormap(self.getDefaultColormap()) + else: + image = items.ImageRgba() + image.setName(legend) + + # Do not emit sigActiveImageChanged, + # it will be sent once with _setActiveItem + with self._muteActiveItemChangedSignal(): + # Override previous/default values with provided ones + image.setInfo(info) + if origin is not None: + image.setOrigin(origin) + if scale is not None: + image.setScale(scale) + if z is not None: + image.setZValue(z) + if selectable is not None: + image._setSelectable(selectable) + if draggable is not None: + image._setDraggable(draggable) + if colormap is not None and isinstance(image, items.ColormapMixIn): + if isinstance(colormap, dict): + image.setColormap(Colormap._fromDict(colormap)) + else: + assert isinstance(colormap, Colormap) + image.setColormap(colormap) + if xlabel is not None: + image._setXLabel(xlabel) + if ylabel is not None: + image._setYLabel(ylabel) + + if data.ndim == 2: + image.setData(data, alternative=pixmap, copy=copy) + else: # RGB(A) image + if pixmap is not None: + _logger.warning( + 'addImage: pixmap argument ignored when data is RGB(A)') + image.setData(data, copy=copy) + + if replace: + for img in self.getAllImages(): + if img is not image: + self.removeItem(img) + + if mustBeAdded: + self.addItem(image) + else: + self._notifyContentChanged(image) + + if len(self.getAllImages()) == 1 or wasActive: + self.setActiveImage(legend) + + if resetzoom: + # We ask for a zoom reset in order to handle the plot scaling + # if the user does not want that, autoscale of the different + # axes has to be set to off. + self.resetZoom() + + return legend + + def addScatter(self, x, y, value, legend=None, colormap=None, + info=None, symbol=None, xerror=None, yerror=None, + z=None, copy=True): + """Add a (x, y, value) scatter to the graph. + + Scatters are uniquely identified by their legend. + To add multiple scatters, call :meth:`addScatter` multiple times with + different legend argument. + To replace/update an existing scatter, call :meth:`addScatter` with the + existing scatter legend. + + When scatter parameters are not provided, if a scatter with the + same legend is displayed in the plot, its parameters are used. + + :param numpy.ndarray x: The data corresponding to the x coordinates. + :param numpy.ndarray y: The data corresponding to the y coordinates + :param numpy.ndarray value: The data value associated with each point + :param str legend: The legend to be associated to the scatter (or None) + :param ~silx.gui.colors.Colormap colormap: + Colormap object to be used for the scatter (or None) + :param info: User-defined information associated to the curve + :param str symbol: Symbol to be drawn at each (x, y) position:: + + - 'o' circle + - '.' point + - ',' pixel + - '+' cross + - 'x' x-cross + - 'd' diamond + - 's' square + - None (the default) to use default symbol + + :param xerror: Values with the uncertainties on the x values + :type xerror: A float, or a numpy.ndarray of float32. + If it is an array, it can either be a 1D array of + same length as the data or a 2D array with 2 rows + of same length as the data: row 0 for positive errors, + row 1 for negative errors. + :param yerror: Values with the uncertainties on the y values + :type yerror: A float, or a numpy.ndarray of float32. See xerror. + :param int z: Layer on which to draw the scatter (default: 1) + This allows to control the overlay. + + :param bool copy: True make a copy of the data (default), + False to use provided arrays. + :returns: The key string identify this scatter + """ + legend = 'Unnamed scatter 1.1' if legend is None else str(legend) + + # Check if scatter was previously active + wasActive = self._getActiveItem(kind='scatter', + just_legend=True) == legend + + # Create/Update curve object + scatter = self._getItem(kind='scatter', legend=legend) + mustBeAdded = scatter is None + if scatter is None: + # No previous scatter, create a default one and add it to the plot + scatter = items.Scatter() + scatter.setName(legend) + scatter.setColormap(self.getDefaultColormap()) + + # Do not emit sigActiveScatterChanged, + # it will be sent once with _setActiveItem + with self._muteActiveItemChangedSignal(): + # Override previous/default values with provided ones + scatter.setInfo(info) + if symbol is not None: + scatter.setSymbol(symbol) + if z is not None: + scatter.setZValue(z) + if colormap is not None: + if isinstance(colormap, dict): + scatter.setColormap(Colormap._fromDict(colormap)) + else: + assert isinstance(colormap, Colormap) + scatter.setColormap(colormap) + + # Set scatter data + # If errors not provided, reuse previous ones + if xerror is None: + xerror = scatter.getXErrorData(copy=False) + if xerror is not None and len(xerror) != len(x): + xerror = None + if yerror is None: + yerror = scatter.getYErrorData(copy=False) + if yerror is not None and len(yerror) != len(y): + yerror = None + + scatter.setData(x, y, value, xerror, yerror, copy=copy) + + if mustBeAdded: + self.addItem(scatter) + else: + self._notifyContentChanged(scatter) + + scatters = [item for item in self.getItems() + if isinstance(item, items.Scatter) and item.isVisible()] + if len(scatters) == 1 or wasActive: + self._setActiveItem('scatter', scatter.getName()) + + return legend + + def addShape(self, xdata, ydata, legend=None, info=None, + replace=False, + shape="polygon", color='black', fill=True, + overlay=False, z=None, linestyle="-", linewidth=1.0, + linebgcolor=None): + """Add an item (i.e. a shape) to the plot. + + Items are uniquely identified by their legend. + To add multiple items, call :meth:`addItem` multiple times with + different legend argument. + To replace/update an existing item, call :meth:`addItem` with the + existing item legend. + + :param numpy.ndarray xdata: The X coords of the points of the shape + :param numpy.ndarray ydata: The Y coords of the points of the shape + :param str legend: The legend to be associated to the item + :param info: User-defined information associated to the item + :param bool replace: True (default) to delete already existing images + :param str shape: Type of item to be drawn in + hline, polygon (the default), rectangle, vline, + polylines + :param str color: Color of the item, e.g., 'blue', 'b', '#FF0000' + (Default: 'black') + :param bool fill: True (the default) to fill the shape + :param bool overlay: True if item is an overlay (Default: False). + This allows for rendering optimization if this + item is changed often. + :param int z: Layer on which to draw the item (default: 2) + :param str linestyle: Style of the line. + Only relevant for line markers where X or Y is None. + Value in: + + - ' ' no line + - '-' solid line + - '--' dashed line + - '-.' dash-dot line + - ':' dotted line + :param float linewidth: Width of the line. + Only relevant for line markers where X or Y is None. + :param str linebgcolor: Background color of the line, e.g., 'blue', 'b', + '#FF0000'. It is used to draw dotted line using a second color. + :returns: The key string identify this item + """ + # expected to receive the same parameters as the signal + + legend = "Unnamed Item 1.1" if legend is None else str(legend) + + z = int(z) if z is not None else 2 + + if replace: + self.remove(kind='item') + else: + self.remove(legend, kind='item') + + item = items.Shape(shape) + item.setName(legend) + item.setInfo(info) + item.setColor(color) + item.setFill(fill) + item.setOverlay(overlay) + item.setZValue(z) + item.setPoints(numpy.array((xdata, ydata)).T) + item.setLineStyle(linestyle) + item.setLineWidth(linewidth) + item.setLineBgColor(linebgcolor) + + self.addItem(item) + + return legend + + def addXMarker(self, x, legend=None, + text=None, + color=None, + selectable=False, + draggable=False, + constraint=None, + yaxis='left'): + """Add a vertical line marker to the plot. + + Markers are uniquely identified by their legend. + As opposed to curves, images and items, two calls to + :meth:`addXMarker` without legend argument adds two markers with + different identifying legends. + + :param x: Position of the marker on the X axis in data coordinates + :type x: Union[None, float] + :param str legend: Legend associated to the marker to identify it + :param str text: Text to display on the marker. + :param str color: Color of the marker, e.g., 'blue', 'b', '#FF0000' + (Default: 'black') + :param bool selectable: Indicate if the marker can be selected. + (default: False) + :param bool draggable: Indicate if the marker can be moved. + (default: False) + :param constraint: A function filtering marker displacement by + dragging operations or None for no filter. + This function is called each time a marker is + moved. + This parameter is only used if draggable is True. + :type constraint: None or a callable that takes the coordinates of + the current cursor position in the plot as input + and that returns the filtered coordinates. + :param str yaxis: The Y axis this marker belongs to in: 'left', 'right' + :return: The key string identify this marker + """ + return self._addMarker(x=x, y=None, legend=legend, + text=text, color=color, + selectable=selectable, draggable=draggable, + symbol=None, constraint=constraint, + yaxis=yaxis) + + def addYMarker(self, y, + legend=None, + text=None, + color=None, + selectable=False, + draggable=False, + constraint=None, + yaxis='left'): + """Add a horizontal line marker to the plot. + + Markers are uniquely identified by their legend. + As opposed to curves, images and items, two calls to + :meth:`addYMarker` without legend argument adds two markers with + different identifying legends. + + :param float y: Position of the marker on the Y axis in data + coordinates + :param str legend: Legend associated to the marker to identify it + :param str text: Text to display next to the marker. + :param str color: Color of the marker, e.g., 'blue', 'b', '#FF0000' + (Default: 'black') + :param bool selectable: Indicate if the marker can be selected. + (default: False) + :param bool draggable: Indicate if the marker can be moved. + (default: False) + :param constraint: A function filtering marker displacement by + dragging operations or None for no filter. + This function is called each time a marker is + moved. + This parameter is only used if draggable is True. + :type constraint: None or a callable that takes the coordinates of + the current cursor position in the plot as input + and that returns the filtered coordinates. + :param str yaxis: The Y axis this marker belongs to in: 'left', 'right' + :return: The key string identify this marker + """ + return self._addMarker(x=None, y=y, legend=legend, + text=text, color=color, + selectable=selectable, draggable=draggable, + symbol=None, constraint=constraint, + yaxis=yaxis) + + def addMarker(self, x, y, legend=None, + text=None, + color=None, + selectable=False, + draggable=False, + symbol='+', + constraint=None, + yaxis='left'): + """Add a point marker to the plot. + + Markers are uniquely identified by their legend. + As opposed to curves, images and items, two calls to + :meth:`addMarker` without legend argument adds two markers with + different identifying legends. + + :param float x: Position of the marker on the X axis in data + coordinates + :param float y: Position of the marker on the Y axis in data + coordinates + :param str legend: Legend associated to the marker to identify it + :param str text: Text to display next to the marker + :param str color: Color of the marker, e.g., 'blue', 'b', '#FF0000' + (Default: 'black') + :param bool selectable: Indicate if the marker can be selected. + (default: False) + :param bool draggable: Indicate if the marker can be moved. + (default: False) + :param str symbol: Symbol representing the marker in:: + + - 'o' circle + - '.' point + - ',' pixel + - '+' cross (the default) + - 'x' x-cross + - 'd' diamond + - 's' square + + :param constraint: A function filtering marker displacement by + dragging operations or None for no filter. + This function is called each time a marker is + moved. + This parameter is only used if draggable is True. + :type constraint: None or a callable that takes the coordinates of + the current cursor position in the plot as input + and that returns the filtered coordinates. + :param str yaxis: The Y axis this marker belongs to in: 'left', 'right' + :return: The key string identify this marker + """ + if x is None: + xmin, xmax = self._xAxis.getLimits() + x = 0.5 * (xmax + xmin) + + if y is None: + ymin, ymax = self._yAxis.getLimits() + y = 0.5 * (ymax + ymin) + + return self._addMarker(x=x, y=y, legend=legend, + text=text, color=color, + selectable=selectable, draggable=draggable, + symbol=symbol, constraint=constraint, + yaxis=yaxis) + + def _addMarker(self, x, y, legend, + text, color, + selectable, draggable, + symbol, constraint, + yaxis=None): + """Common method for adding point, vline and hline marker. + + See :meth:`addMarker` for argument documentation. + """ + assert (x, y) != (None, None) + + if legend is None: # Find an unused legend + markerLegends = [item.getName() for item in self.getItems() + if isinstance(item, items.MarkerBase)] + for index in itertools.count(): + legend = "Unnamed Marker %d" % index + if legend not in markerLegends: + break # Keep this legend + legend = str(legend) + + if x is None: + markerClass = items.YMarker + elif y is None: + markerClass = items.XMarker + else: + markerClass = items.Marker + + # Create/Update marker object + marker = self._getMarker(legend) + if marker is not None and not isinstance(marker, markerClass): + _logger.warning('Adding marker with same legend' + ' but different type replaces it') + self.removeItem(marker) + marker = None + + mustBeAdded = marker is None + if marker is None: + # No previous marker, create one + marker = markerClass() + marker.setName(legend) + + if text is not None: + marker.setText(text) + if color is not None: + marker.setColor(color) + if selectable is not None: + marker._setSelectable(selectable) + if draggable is not None: + marker._setDraggable(draggable) + if symbol is not None: + marker.setSymbol(symbol) + marker.setYAxis(yaxis) + + # TODO to improve, but this ensure constraint is applied + marker.setPosition(x, y) + if constraint is not None: + marker._setConstraint(constraint) + marker.setPosition(x, y) + + if mustBeAdded: + self.addItem(marker) + else: + self._notifyContentChanged(marker) + + return legend + + # Hide + + def isCurveHidden(self, legend): + """Returns True if the curve associated to legend is hidden, else False + + :param str legend: The legend key identifying the curve + :return: True if the associated curve is hidden, False otherwise + """ + curve = self._getItem('curve', legend) + return curve is not None and not curve.isVisible() + + def hideCurve(self, legend, flag=True): + """Show/Hide the curve associated to legend. + + Even when hidden, the curve is kept in the list of curves. + + :param str legend: The legend associated to the curve to be hidden + :param bool flag: True (default) to hide the curve, False to show it + """ + curve = self._getItem('curve', legend) + if curve is None: + _logger.warning('Curve not in plot: %s', legend) + return + + isVisible = not flag + if isVisible != curve.isVisible(): + curve.setVisible(isVisible) + + # Remove + + ITEM_KINDS = 'curve', 'image', 'scatter', 'item', 'marker', 'histogram' + """List of supported kind of items in the plot.""" + + _ACTIVE_ITEM_KINDS = 'curve', 'scatter', 'image' + """List of item's kind which have a active item.""" + + def remove(self, legend=None, kind=ITEM_KINDS): + """Remove one or all element(s) of the given legend and kind. + + Examples: + + - ``remove()`` clears the plot + - ``remove(kind='curve')`` removes all curves from the plot + - ``remove('myCurve', kind='curve')`` removes the curve with + legend 'myCurve' from the plot. + - ``remove('myImage, kind='image')`` removes the image with + legend 'myImage' from the plot. + - ``remove('myImage')`` removes elements (for instance curve, image, + item and marker) with legend 'myImage'. + + :param str legend: The legend associated to the element to remove, + or None to remove + :param kind: The kind of elements to remove from the plot. + See :attr:`ITEM_KINDS`. + By default, it removes all kind of elements. + :type kind: str or tuple of str to specify multiple kinds. + """ + if kind == 'all': # Replace all by tuple of all kinds + kind = self.ITEM_KINDS + + if kind in self.ITEM_KINDS: # Kind is a str, make it a tuple + kind = (kind,) + + for aKind in kind: + assert aKind in self.ITEM_KINDS + + if legend is None: # This is a clear + # Clear each given kind + for aKind in kind: + for item in self.getItems(): + if (isinstance(item, self._KIND_TO_CLASSES[aKind]) and + item.getPlot() is self): # Make sure item is still in the plot + self.removeItem(item) + + else: # This is removing a single element + # Remove each given kind + for aKind in kind: + item = self._getItem(aKind, legend) + if item is not None: + self.removeItem(item) + + def removeCurve(self, legend): + """Remove the curve associated to legend from the graph. + + :param str legend: The legend associated to the curve to be deleted + """ + if legend is None: + return + self.remove(legend, kind='curve') + + def removeImage(self, legend): + """Remove the image associated to legend from the graph. + + :param str legend: The legend associated to the image to be deleted + """ + if legend is None: + return + self.remove(legend, kind='image') + + def removeMarker(self, legend): + """Remove the marker associated to legend from the graph. + + :param str legend: The legend associated to the marker to be deleted + """ + if legend is None: + return + self.remove(legend, kind='marker') + + # Clear + + def clear(self): + """Remove everything from the plot.""" + for item in self.getItems(): + if item.getPlot() is self: # Make sure item is still in the plot + self.removeItem(item) + + def clearCurves(self): + """Remove all the curves from the plot.""" + self.remove(kind='curve') + + def clearImages(self): + """Remove all the images from the plot.""" + self.remove(kind='image') + + def clearItems(self): + """Remove all the items from the plot. """ + self.remove(kind='item') + + def clearMarkers(self): + """Remove all the markers from the plot.""" + self.remove(kind='marker') + + # Interaction + + def getGraphCursor(self): + """Returns the state of the crosshair cursor. + + See :meth:`setGraphCursor`. + + :return: None if the crosshair cursor is not active, + else a tuple (color, linewidth, linestyle). + """ + return self._cursorConfiguration + + def setGraphCursor(self, flag=False, color='black', + linewidth=1, linestyle='-'): + """Toggle the display of a crosshair cursor and set its attributes. + + :param bool flag: Toggle the display of a crosshair cursor. + The crosshair cursor is hidden by default. + :param color: The color to use for the crosshair. + :type color: A string (either a predefined color name in colors.py + or "#RRGGBB")) or a 4 columns unsigned byte array + (Default: black). + :param int linewidth: The width of the lines of the crosshair + (Default: 1). + :param str linestyle: Type of line:: + + - ' ' no line + - '-' solid line (the default) + - '--' dashed line + - '-.' dash-dot line + - ':' dotted line + """ + if flag: + self._cursorConfiguration = color, linewidth, linestyle + else: + self._cursorConfiguration = None + + self._backend.setGraphCursor(flag=flag, color=color, + linewidth=linewidth, linestyle=linestyle) + self._setDirtyPlot() + self.notify('setGraphCursor', + state=self._cursorConfiguration is not None) + + def pan(self, direction, factor=0.1): + """Pan the graph in the given direction by the given factor. + + Warning: Pan of right Y axis not implemented! + + :param str direction: One of 'up', 'down', 'left', 'right'. + :param float factor: Proportion of the range used to pan the graph. + Must be strictly positive. + """ + assert direction in ('up', 'down', 'left', 'right') + assert factor > 0. + + if direction in ('left', 'right'): + xFactor = factor if direction == 'right' else - factor + xMin, xMax = self._xAxis.getLimits() + + xMin, xMax = _utils.applyPan(xMin, xMax, xFactor, + self._xAxis.getScale() == self._xAxis.LOGARITHMIC) + self._xAxis.setLimits(xMin, xMax) + + else: # direction in ('up', 'down') + sign = -1. if self._yAxis.isInverted() else 1. + yFactor = sign * (factor if direction == 'up' else -factor) + yMin, yMax = self._yAxis.getLimits() + yIsLog = self._yAxis.getScale() == self._yAxis.LOGARITHMIC + + yMin, yMax = _utils.applyPan(yMin, yMax, yFactor, yIsLog) + self._yAxis.setLimits(yMin, yMax) + + y2Min, y2Max = self._yRightAxis.getLimits() + + y2Min, y2Max = _utils.applyPan(y2Min, y2Max, yFactor, yIsLog) + self._yRightAxis.setLimits(y2Min, y2Max) + + # Active Curve/Image + + def isActiveCurveHandling(self): + """Returns True if active curve selection is enabled. + + :rtype: bool + """ + return self.getActiveCurveSelectionMode() != 'none' + + def setActiveCurveHandling(self, flag=True): + """Enable/Disable active curve selection. + + :param bool flag: True to enable 'atmostone' active curve selection, + False to disable active curve selection. + """ + self.setActiveCurveSelectionMode('atmostone' if flag else 'none') + + def getActiveCurveStyle(self): + """Returns the current style applied to active curve + + :rtype: CurveStyle + """ + return self._activeCurveStyle + + def setActiveCurveStyle(self, + color=None, + linewidth=None, + linestyle=None, + symbol=None, + symbolsize=None): + """Set the style of active curve + + :param color: Color + :param Union[str,None] linestyle: Style of the line + :param Union[float,None] linewidth: Width of the line + :param Union[str,None] symbol: Symbol of the markers + :param Union[float,None] symbolsize: Size of the symbols + """ + self._activeCurveStyle = CurveStyle(color=color, + linewidth=linewidth, + linestyle=linestyle, + symbol=symbol, + symbolsize=symbolsize) + curve = self.getActiveCurve() + if curve is not None: + curve.setHighlightedStyle(self.getActiveCurveStyle()) + + @deprecated(replacement="getActiveCurveStyle", since_version="0.9") + def getActiveCurveColor(self): + """Get the color used to display the currently active curve. + + See :meth:`setActiveCurveColor`. + """ + return self._activeCurveStyle.getColor() + + @deprecated(replacement="setActiveCurveStyle", since_version="0.9") + def setActiveCurveColor(self, color="#000000"): + """Set the color to use to display the currently active curve. + + :param str color: Color of the active curve, + e.g., 'blue', 'b', '#FF0000' (Default: 'black') + """ + if color is None: + color = "black" + if color in self.colorDict: + color = self.colorDict[color] + self.setActiveCurveStyle(color=color) + + def getActiveCurve(self, just_legend=False): + """Return the currently active curve. + + It returns None in case of not having an active curve. + + :param bool just_legend: True to get the legend of the curve, + False (the default) to get the curve data + and info. + :return: Active curve's legend or corresponding + :class:`.items.Curve` + :rtype: str or :class:`.items.Curve` or None + """ + if not self.isActiveCurveHandling(): + return None + + return self._getActiveItem(kind='curve', just_legend=just_legend) + + def setActiveCurve(self, legend): + """Make the curve associated to legend the active curve. + + :param legend: The legend associated to the curve + or None to have no active curve. + :type legend: str or None + """ + if not self.isActiveCurveHandling(): + return + if legend is None and self.getActiveCurveSelectionMode() == "legacy": + _logger.info( + 'setActiveCurve(None) ignored due to active curve selection mode') + return + + return self._setActiveItem(kind='curve', legend=legend) + + def setActiveCurveSelectionMode(self, mode): + """Sets the current selection mode. + + :param str mode: The active curve selection mode to use. + It can be: 'legacy', 'atmostone' or 'none'. + """ + assert mode in ('legacy', 'atmostone', 'none') + + if mode != self._activeCurveSelectionMode: + self._activeCurveSelectionMode = mode + if mode == 'none': # reset active curve + self._setActiveItem(kind='curve', legend=None) + + elif mode == 'legacy' and self.getActiveCurve() is None: + # Select an active curve + curves = self.getAllCurves(just_legend=False, + withhidden=False) + if len(curves) == 1: + if curves[0].isVisible(): + self.setActiveCurve(curves[0].getName()) + + def getActiveCurveSelectionMode(self): + """Returns the current selection mode. + + It can be "atmostone", "legacy" or "none". + + :rtype: str + """ + return self._activeCurveSelectionMode + + def getActiveImage(self, just_legend=False): + """Returns the currently active image. + + It returns None in case of not having an active image. + + :param bool just_legend: True to get the legend of the image, + False (the default) to get the image data + and info. + :return: Active image's legend or corresponding image object + :rtype: str, :class:`.items.ImageData`, :class:`.items.ImageRgba` + or None + """ + return self._getActiveItem(kind='image', just_legend=just_legend) + + def setActiveImage(self, legend): + """Make the image associated to legend the active image. + + :param str legend: The legend associated to the image + or None to have no active image. + """ + return self._setActiveItem(kind='image', legend=legend) + + def getActiveScatter(self, just_legend=False): + """Returns the currently active scatter. + + It returns None in case of not having an active scatter. + + :param bool just_legend: True to get the legend of the scatter, + False (the default) to get the scatter data + and info. + :return: Active scatter's legend or corresponding scatter object + :rtype: str, :class:`.items.Scatter` or None + """ + return self._getActiveItem(kind='scatter', just_legend=just_legend) + + def setActiveScatter(self, legend): + """Make the scatter associated to legend the active scatter. + + :param str legend: The legend associated to the scatter + or None to have no active scatter. + """ + return self._setActiveItem(kind='scatter', legend=legend) + + def _getActiveItem(self, kind, just_legend=False): + """Return the currently active item of that kind if any + + :param str kind: Type of item: 'curve', 'scatter' or 'image' + :param bool just_legend: True to get the legend, + False (default) to get the item + :return: legend or item or None if no active item + """ + assert kind in self._ACTIVE_ITEM_KINDS + + if self._activeLegend[kind] is None: + return None + + item = self._getItem(kind, self._activeLegend[kind]) + if item is None: + return None + + return item.getName() if just_legend else item + + def _setActiveItem(self, kind, legend): + """Make the curve associated to legend the active curve. + + :param str kind: Type of item: 'curve' or 'image' + :param legend: The legend associated to the curve + or None to have no active curve. + :type legend: str or None + """ + assert kind in self._ACTIVE_ITEM_KINDS + + xLabel = None + yLabel = None + yRightLabel = None + + oldActiveItem = self._getActiveItem(kind=kind) + + if oldActiveItem is not None: # Stop listening previous active image + oldActiveItem.sigItemChanged.disconnect(self._activeItemChanged) + + # Curve specific: Reset highlight of previous active curve + if kind == 'curve' and oldActiveItem is not None: + oldActiveItem.setHighlighted(False) + + if legend is None: + self._activeLegend[kind] = None + else: + legend = str(legend) + item = self._getItem(kind, legend) + if item is None: + _logger.warning("This %s does not exist: %s", kind, legend) + self._activeLegend[kind] = None + else: + self._activeLegend[kind] = legend + + # Curve specific: handle highlight + if kind == 'curve': + item.setHighlightedStyle(self.getActiveCurveStyle()) + item.setHighlighted(True) + + if isinstance(item, items.LabelsMixIn): + if item.getXLabel() is not None: + xLabel = item.getXLabel() + if item.getYLabel() is not None: + if (isinstance(item, items.YAxisMixIn) and + item.getYAxis() == 'right'): + yRightLabel = item.getYLabel() + else: + yLabel = item.getYLabel() + + # Start listening new active item + item.sigItemChanged.connect(self._activeItemChanged) + + # Store current labels and update plot + self._xAxis._setCurrentLabel(xLabel) + self._yAxis._setCurrentLabel(yLabel) + self._yRightAxis._setCurrentLabel(yRightLabel) + + self._setDirtyPlot() + + activeLegend = self._activeLegend[kind] + if oldActiveItem is not None or activeLegend is not None: + if oldActiveItem is None: + oldActiveLegend = None + else: + oldActiveLegend = oldActiveItem.getName() + self.notify( + 'active' + kind[0].upper() + kind[1:] + 'Changed', + updated=oldActiveLegend != activeLegend, + previous=oldActiveLegend, + legend=activeLegend) + + return activeLegend + + def _activeItemChanged(self, type_): + """Listen for active item changed signal and broadcast signal + + :param item.ItemChangedType type_: The type of item change + """ + if not self.__muteActiveItemChanged: + item = self.sender() + if item is not None: + kind = self._itemKind(item) + self.notify( + 'active' + kind[0].upper() + kind[1:] + 'Changed', + updated=False, + previous=item.getName(), + legend=item.getName()) + + # Getters + + def getAllCurves(self, just_legend=False, withhidden=False): + """Returns all curves legend or info and data. + + It returns an empty list in case of not having any curve. + + If just_legend is False, it returns a list of :class:`items.Curve` + objects describing the curves. + If just_legend is True, it returns a list of curves' legend. + + :param bool just_legend: True to get the legend of the curves, + False (the default) to get the curves' data + and info. + :param bool withhidden: False (default) to skip hidden curves. + :return: list of curves' legend or :class:`.items.Curve` + :rtype: list of str or list of :class:`.items.Curve` + """ + curves = [item for item in self.getItems() if + isinstance(item, items.Curve) and + (withhidden or item.isVisible())] + return [curve.getName() for curve in curves] if just_legend else curves + + def getCurve(self, legend=None): + """Get the object describing a specific curve. + + It returns None in case no matching curve is found. + + :param str legend: + The legend identifying the curve. + If not provided or None (the default), the active curve is returned + or if there is no active curve, the latest updated curve that is + not hidden is returned if there are curves in the plot. + :return: None or :class:`.items.Curve` object + """ + return self._getItem(kind='curve', legend=legend) + + def getAllImages(self, just_legend=False): + """Returns all images legend or objects. + + It returns an empty list in case of not having any image. + + If just_legend is False, it returns a list of :class:`items.ImageBase` + objects describing the images. + If just_legend is True, it returns a list of legends. + + :param bool just_legend: True to get the legend of the images, + False (the default) to get the images' + object. + :return: list of images' legend or :class:`.items.ImageBase` + :rtype: list of str or list of :class:`.items.ImageBase` + """ + images = [item for item in self.getItems() + if isinstance(item, items.ImageBase)] + return [image.getName() for image in images] if just_legend else images + + def getImage(self, legend=None): + """Get the object describing a specific image. + + It returns None in case no matching image is found. + + :param str legend: + The legend identifying the image. + If not provided or None (the default), the active image is returned + or if there is no active image, the latest updated image + is returned if there are images in the plot. + :return: None or :class:`.items.ImageBase` object + """ + return self._getItem(kind='image', legend=legend) + + def getScatter(self, legend=None): + """Get the object describing a specific scatter. + + It returns None in case no matching scatter is found. + + :param str legend: + The legend identifying the scatter. + If not provided or None (the default), the active scatter is + returned or if there is no active scatter, the latest updated + scatter is returned if there are scatters in the plot. + :return: None or :class:`.items.Scatter` object + """ + return self._getItem(kind='scatter', legend=legend) + + def getHistogram(self, legend=None): + """Get the object describing a specific histogram. + + It returns None in case no matching histogram is found. + + :param str legend: + The legend identifying the histogram. + If not provided or None (the default), the latest updated scatter + is returned if there are histograms in the plot. + :return: None or :class:`.items.Histogram` object + """ + return self._getItem(kind='histogram', legend=legend) + + @deprecated(replacement='getItems', since_version='0.13') + def _getItems(self, kind=ITEM_KINDS, just_legend=False, withhidden=False): + """Retrieve all items of a kind in the plot + + :param kind: The kind of elements to retrieve from the plot. + See :attr:`ITEM_KINDS`. + By default, it removes all kind of elements. + :type kind: str or tuple of str to specify multiple kinds. + :param str kind: Type of item: 'curve' or 'image' + :param bool just_legend: True to get the legend of the curves, + False (the default) to get the curves' data + and info. + :param bool withhidden: False (default) to skip hidden curves. + :return: list of legends or item objects + """ + if kind == 'all': # Replace all by tuple of all kinds + kind = self.ITEM_KINDS + + if kind in self.ITEM_KINDS: # Kind is a str, make it a tuple + kind = (kind,) + + for aKind in kind: + assert aKind in self.ITEM_KINDS + + output = [] + for item in self.getItems(): + type_ = self._itemKind(item) + if type_ in kind and (withhidden or item.isVisible()): + output.append(item.getName() if just_legend else item) + return output + + def _getItem(self, kind, legend=None): + """Get an item from the plot: either an image or a curve. + + Returns None if no match found. + + :param str kind: Type of item to retrieve, + see :attr:`ITEM_KINDS`. + :param str legend: Legend of the item or + None to get active or last item + :return: Object describing the item or None + """ + assert kind in self.ITEM_KINDS + + if legend is not None: + return self._content.get((legend, kind), None) + else: + if kind in self._ACTIVE_ITEM_KINDS: + item = self._getActiveItem(kind=kind) + if item is not None: # Return active item if available + return item + # Return last visible item if any + itemClasses = self._KIND_TO_CLASSES[kind] + allItems = [item for item in self.getItems() + if isinstance(item, itemClasses) and item.isVisible()] + return allItems[-1] if allItems else None + + # Limits + + def _notifyLimitsChanged(self, emitSignal=True): + """Send an event when plot area limits are changed.""" + xRange = self._xAxis.getLimits() + yRange = self._yAxis.getLimits() + y2Range = self._yRightAxis.getLimits() + if emitSignal: + axes = self.getXAxis(), self.getYAxis(), self.getYAxis(axis="right") + ranges = xRange, yRange, y2Range + for axis, limits in zip(axes, ranges): + axis.sigLimitsChanged.emit(*limits) + event = PlotEvents.prepareLimitsChangedSignal( + id(self.getWidgetHandle()), xRange, yRange, y2Range) + self.notify(**event) + + def getLimitsHistory(self): + """Returns the object handling the history of limits of the plot""" + return self._limitsHistory + + def getGraphXLimits(self): + """Get the graph X (bottom) limits. + + :return: Minimum and maximum values of the X axis + """ + return self._backend.getGraphXLimits() + + def setGraphXLimits(self, xmin, xmax): + """Set the graph X (bottom) limits. + + :param float xmin: minimum bottom axis value + :param float xmax: maximum bottom axis value + """ + self._xAxis.setLimits(xmin, xmax) + + def getGraphYLimits(self, axis='left'): + """Get the graph Y limits. + + :param str axis: The axis for which to get the limits: + Either 'left' or 'right' + :return: Minimum and maximum values of the X axis + """ + assert axis in ('left', 'right') + yAxis = self._yAxis if axis == 'left' else self._yRightAxis + return yAxis.getLimits() + + def setGraphYLimits(self, ymin, ymax, axis='left'): + """Set the graph Y limits. + + :param float ymin: minimum bottom axis value + :param float ymax: maximum bottom axis value + :param str axis: The axis for which to get the limits: + Either 'left' or 'right' + """ + assert axis in ('left', 'right') + yAxis = self._yAxis if axis == 'left' else self._yRightAxis + return yAxis.setLimits(ymin, ymax) + + def setLimits(self, xmin, xmax, ymin, ymax, y2min=None, y2max=None): + """Set the limits of the X and Y axes at once. + + If y2min or y2max is None, the right Y axis limits are not updated. + + :param float xmin: minimum bottom axis value + :param float xmax: maximum bottom axis value + :param float ymin: minimum left axis value + :param float ymax: maximum left axis value + :param float y2min: minimum right axis value or None (the default) + :param float y2max: maximum right axis value or None (the default) + """ + # Deal with incorrect values + axis = self.getXAxis() + xmin, xmax = axis._checkLimits(xmin, xmax) + axis = self.getYAxis() + ymin, ymax = axis._checkLimits(ymin, ymax) + + if y2min is None or y2max is None: + # if one limit is None, both are ignored + y2min, y2max = None, None + else: + axis = self.getYAxis(axis="right") + y2min, y2max = axis._checkLimits(y2min, y2max) + + if self._viewConstrains: + view = self._viewConstrains.normalize(xmin, xmax, ymin, ymax) + xmin, xmax, ymin, ymax = view + + self._backend.setLimits(xmin, xmax, ymin, ymax, y2min, y2max) + self._setDirtyPlot() + self._notifyLimitsChanged() + + def _getViewConstraints(self): + """Return the plot object managing constaints on the plot view. + + :rtype: ViewConstraints + """ + if self._viewConstrains is None: + self._viewConstrains = ViewConstraints() + return self._viewConstrains + + # Title and labels + + def getGraphTitle(self): + """Return the plot main title as a str.""" + return self._graphTitle + + def setGraphTitle(self, title=""): + """Set the plot main title. + + :param str title: Main title of the plot (default: '') + """ + self._graphTitle = str(title) + self._backend.setGraphTitle(title) + self._setDirtyPlot() + + def getGraphXLabel(self): + """Return the current X axis label as a str.""" + return self._xAxis.getLabel() + + def setGraphXLabel(self, label="X"): + """Set the plot X axis label. + + The provided label can be temporarily replaced by the X label of the + active curve if any. + + :param str label: The X axis label (default: 'X') + """ + self._xAxis.setLabel(label) + + def getGraphYLabel(self, axis='left'): + """Return the current Y axis label as a str. + + :param str axis: The Y axis for which to get the label (left or right) + """ + assert axis in ('left', 'right') + yAxis = self._yAxis if axis == 'left' else self._yRightAxis + return yAxis.getLabel() + + def setGraphYLabel(self, label="Y", axis='left'): + """Set the plot Y axis label. + + The provided label can be temporarily replaced by the Y label of the + active curve if any. + + :param str label: The Y axis label (default: 'Y') + :param str axis: The Y axis for which to set the label (left or right) + """ + assert axis in ('left', 'right') + yAxis = self._yAxis if axis == 'left' else self._yRightAxis + return yAxis.setLabel(label) + + # Axes + + def getXAxis(self): + """Returns the X axis + + .. versionadded:: 0.6 + + :rtype: :class:`.items.Axis` + """ + return self._xAxis + + def getYAxis(self, axis="left"): + """Returns an Y axis + + .. versionadded:: 0.6 + + :param str axis: The Y axis to return + ('left' or 'right'). + :rtype: :class:`.items.Axis` + """ + assert(axis in ["left", "right"]) + return self._yAxis if axis == "left" else self._yRightAxis + + def setAxesDisplayed(self, displayed: bool): + """Display or not the axes. + + :param bool displayed: If `True` axes are displayed. If `False` axes + are not anymore visible and the margin used for them is removed. + """ + if displayed != self.__axesDisplayed: + self.__axesDisplayed = displayed + if displayed: + self._backend.setAxesMargins(*self.__axesMargins) + else: + self._backend.setAxesMargins(0., 0., 0., 0.) + self._setDirtyPlot() + self._sigAxesVisibilityChanged.emit(displayed) + + def isAxesDisplayed(self) -> bool: + """Returns whether or not axes are currently displayed + + :rtype: bool + """ + return self.__axesDisplayed + + def setAxesMargins( + self, left: float, top: float, right: float, bottom: float): + """Set ratios of margins surrounding data plot area. + + All ratios must be within [0., 1.]. + Sums of ratios of opposed side must be < 1. + + :param float left: Left-side margin ratio. + :param float top: Top margin ratio + :param float right: Right-side margin ratio + :param float bottom: Bottom margin ratio + :raises ValueError: + """ + for value in (left, top, right, bottom): + if value < 0. or value > 1.: + raise ValueError("Margin ratios must be within [0., 1.]") + if left + right >= 1. or top + bottom >= 1.: + raise ValueError("Sum of ratios of opposed sides >= 1") + margins = left, top, right, bottom + + if margins != self.__axesMargins: + self.__axesMargins = margins + if self.isAxesDisplayed(): # Only apply if axes are displayed + self._backend.setAxesMargins(*margins) + self._setDirtyPlot() + + def getAxesMargins(self): + """Returns ratio of margins surrounding data plot area. + + :return: (left, top, right, bottom) + :rtype: List[float] + """ + return self.__axesMargins + + def setYAxisInverted(self, flag=True): + """Set the Y axis orientation. + + :param bool flag: True for Y axis going from top to bottom, + False for Y axis going from bottom to top + """ + self._yAxis.setInverted(flag) + + def isYAxisInverted(self): + """Return True if Y axis goes from top to bottom, False otherwise.""" + return self._yAxis.isInverted() + + def isXAxisLogarithmic(self): + """Return True if X axis scale is logarithmic, False if linear.""" + return self._xAxis._isLogarithmic() + + def setXAxisLogarithmic(self, flag): + """Set the bottom X axis scale (either linear or logarithmic). + + :param bool flag: True to use a logarithmic scale, False for linear. + """ + self._xAxis._setLogarithmic(flag) + + def isYAxisLogarithmic(self): + """Return True if Y axis scale is logarithmic, False if linear.""" + return self._yAxis._isLogarithmic() + + def setYAxisLogarithmic(self, flag): + """Set the Y axes scale (either linear or logarithmic). + + :param bool flag: True to use a logarithmic scale, False for linear. + """ + self._yAxis._setLogarithmic(flag) + + def isXAxisAutoScale(self): + """Return True if X axis is automatically adjusting its limits.""" + return self._xAxis.isAutoScale() + + def setXAxisAutoScale(self, flag=True): + """Set the X axis limits adjusting behavior of :meth:`resetZoom`. + + :param bool flag: True to resize limits automatically, + False to disable it. + """ + self._xAxis.setAutoScale(flag) + + def isYAxisAutoScale(self): + """Return True if Y axes are automatically adjusting its limits.""" + return self._yAxis.isAutoScale() + + def setYAxisAutoScale(self, flag=True): + """Set the Y axis limits adjusting behavior of :meth:`resetZoom`. + + :param bool flag: True to resize limits automatically, + False to disable it. + """ + self._yAxis.setAutoScale(flag) + + def isKeepDataAspectRatio(self): + """Returns whether the plot is keeping data aspect ratio or not.""" + return self._backend.isKeepDataAspectRatio() + + def setKeepDataAspectRatio(self, flag=True): + """Set whether the plot keeps data aspect ratio or not. + + :param bool flag: True to respect data aspect ratio + """ + flag = bool(flag) + if flag == self.isKeepDataAspectRatio(): + return + self._backend.setKeepDataAspectRatio(flag=flag) + self._setDirtyPlot() + self._forceResetZoom() + self.notify('setKeepDataAspectRatio', state=flag) + + def getGraphGrid(self): + """Return the current grid mode, either None, 'major' or 'both'. + + See :meth:`setGraphGrid`. + """ + return self._grid + + def setGraphGrid(self, which=True): + """Set the type of grid to display. + + :param which: None or False to disable the grid, + 'major' or True for grid on major ticks (the default), + 'both' for grid on both major and minor ticks. + :type which: str of bool + """ + assert which in (None, True, False, 'both', 'major') + if not which: + which = None + elif which is True: + which = 'major' + self._grid = which + self._backend.setGraphGrid(which) + self._setDirtyPlot() + self.notify('setGraphGrid', which=str(which)) + + # Defaults + + def isDefaultPlotPoints(self): + """Return True if the default Curve symbol is set and False if not.""" + return self._defaultPlotPoints == silx.config.DEFAULT_PLOT_SYMBOL + + def setDefaultPlotPoints(self, flag): + """Set the default symbol of all curves. + + When called, this reset the symbol of all existing curves. + + :param bool flag: True to use 'o' as the default curve symbol, + False to use no symbol. + """ + self._defaultPlotPoints = silx.config.DEFAULT_PLOT_SYMBOL if flag else '' + + # Reset symbol of all curves + curves = self.getAllCurves(just_legend=False, withhidden=True) + + if curves: + for curve in curves: + curve.setSymbol(self._defaultPlotPoints) + + def isDefaultPlotLines(self): + """Return True for line as default line style, False for no line.""" + return self._plotLines + + def setDefaultPlotLines(self, flag): + """Toggle the use of lines as the default curve line style. + + :param bool flag: True to use a line as the default line style, + False to use no line as the default line style. + """ + self._plotLines = bool(flag) + + linestyle = '-' if self._plotLines else ' ' + + # Reset linestyle of all curves + curves = self.getAllCurves(withhidden=True) + + if curves: + for curve in curves: + curve.setLineStyle(linestyle) + + def getDefaultColormap(self): + """Return the default colormap used by :meth:`addImage`. + + :rtype: ~silx.gui.colors.Colormap + """ + return self._defaultColormap + + def setDefaultColormap(self, colormap=None): + """Set the default colormap used by :meth:`addImage`. + + Setting the default colormap do not change any currently displayed + image. + It only affects future calls to :meth:`addImage` without the colormap + parameter. + + :param ~silx.gui.colors.Colormap colormap: + The description of the default colormap, or + None to set the colormap to a linear + autoscale gray colormap. + """ + if colormap is None: + colormap = Colormap(name=silx.config.DEFAULT_COLORMAP_NAME, + normalization='linear', + vmin=None, + vmax=None) + if isinstance(colormap, dict): + self._defaultColormap = Colormap._fromDict(colormap) + else: + assert isinstance(colormap, Colormap) + self._defaultColormap = colormap + self.notify('defaultColormapChanged') + + @staticmethod + def getSupportedColormaps(): + """Get the supported colormap names as a tuple of str. + + The list contains at least: + ('gray', 'reversed gray', 'temperature', 'red', 'green', 'blue', + 'magma', 'inferno', 'plasma', 'viridis') + """ + return Colormap.getSupportedColormaps() + + def _resetColorAndStyle(self): + self._colorIndex = 0 + self._styleIndex = 0 + + def _getColorAndStyle(self): + color = self.colorList[self._colorIndex] + style = self._styleList[self._styleIndex] + + # Loop over color and then styles + self._colorIndex += 1 + if self._colorIndex >= len(self.colorList): + self._colorIndex = 0 + self._styleIndex = (self._styleIndex + 1) % len(self._styleList) + + # If color is the one of active curve, take the next one + if colors.rgba(color) == self.getActiveCurveStyle().getColor(): + color, style = self._getColorAndStyle() + + if not self._plotLines: + style = ' ' + + return color, style + + # Misc. + + def getWidgetHandle(self): + """Return the widget the plot is displayed in. + + This widget is owned by the backend. + """ + return self._backend.getWidgetHandle() + + def notify(self, event, **kwargs): + """Send an event to the listeners and send signals. + + Event are passed to the registered callback as a dict with an 'event' + key for backward compatibility with PyMca. + + :param str event: The type of event + :param kwargs: The information of the event. + """ + eventDict = kwargs.copy() + eventDict['event'] = event + self.sigPlotSignal.emit(eventDict) + + if event == 'setKeepDataAspectRatio': + self.sigSetKeepDataAspectRatio.emit(kwargs['state']) + elif event == 'setGraphGrid': + self.sigSetGraphGrid.emit(kwargs['which']) + elif event == 'setGraphCursor': + self.sigSetGraphCursor.emit(kwargs['state']) + elif event == 'contentChanged': + self.sigContentChanged.emit( + kwargs['action'], kwargs['kind'], kwargs['legend']) + elif event == 'activeCurveChanged': + self.sigActiveCurveChanged.emit( + kwargs['previous'], kwargs['legend']) + elif event == 'activeImageChanged': + self.sigActiveImageChanged.emit( + kwargs['previous'], kwargs['legend']) + elif event == 'activeScatterChanged': + self.sigActiveScatterChanged.emit( + kwargs['previous'], kwargs['legend']) + elif event == 'interactiveModeChanged': + self.sigInteractiveModeChanged.emit(kwargs['source']) + + eventDict = kwargs.copy() + eventDict['event'] = event + self._callback(eventDict) + + def setCallback(self, callbackFunction=None): + """Attach a listener to the backend. + + Limitation: Only one listener at a time. + + :param callbackFunction: function accepting a dictionary as input + to handle the graph events + If None (default), use a default listener. + """ + # TODO allow multiple listeners + # allow register listener by event type + if callbackFunction is None: + callbackFunction = WeakMethodProxy(self.graphCallback) + self._callback = callbackFunction + + def graphCallback(self, ddict=None): + """This callback is going to receive all the events from the plot. + + Those events will consist on a dictionary and among the dictionary + keys the key 'event' is mandatory to describe the type of event. + This default implementation only handles setting the active curve. + """ + + if ddict is None: + ddict = {} + _logger.debug("Received dict keys = %s", str(ddict.keys())) + _logger.debug(str(ddict)) + if ddict['event'] in ["legendClicked", "curveClicked"]: + if ddict['button'] == "left": + self.setActiveCurve(ddict['label']) + qt.QToolTip.showText(self.cursor().pos(), ddict['label']) + elif ddict['event'] == 'mouseClicked' and ddict['button'] == 'left': + self.setActiveCurve(None) + + def saveGraph(self, filename, fileFormat=None, dpi=None): + """Save a snapshot of the plot. + + Supported file formats depends on the backend in use. + The following file formats are always supported: "png", "svg". + The matplotlib backend supports more formats: + "pdf", "ps", "eps", "tiff", "jpeg", "jpg". + + :param filename: Destination + :type filename: str, StringIO or BytesIO + :param str fileFormat: String specifying the format + :return: False if cannot save the plot, True otherwise + """ + if fileFormat is None: + if not hasattr(filename, 'lower'): + _logger.warning( + 'saveGraph cancelled, cannot define file format.') + return False + else: + fileFormat = (filename.split(".")[-1]).lower() + + supportedFormats = ("png", "svg", "pdf", "ps", "eps", + "tif", "tiff", "jpeg", "jpg") + + if fileFormat not in supportedFormats: + _logger.warning('Unsupported format %s', fileFormat) + return False + else: + self._backend.saveGraph(filename, + fileFormat=fileFormat, + dpi=dpi) + return True + + def getDataMargins(self): + """Get the default data margin ratios, see :meth:`setDataMargins`. + + :return: The margin ratios for each side (xMin, xMax, yMin, yMax). + :rtype: A 4-tuple of floats. + """ + return self._defaultDataMargins + + def setDataMargins(self, xMinMargin=0., xMaxMargin=0., + yMinMargin=0., yMaxMargin=0.): + """Set the default data margins to use in :meth:`resetZoom`. + + Set the default ratios of margins (as floats) to add around the data + inside the plot area for each side. + """ + self._defaultDataMargins = (xMinMargin, xMaxMargin, + yMinMargin, yMaxMargin) + + def getAutoReplot(self): + """Return True if replot is automatically handled, False otherwise. + + See :meth`setAutoReplot`. + """ + return self._autoreplot + + def setAutoReplot(self, autoreplot=True): + """Set automatic replot mode. + + When enabled, the plot is redrawn automatically when changed. + When disabled, the plot is not redrawn when its content change. + Instead, it :meth:`replot` must be called. + + :param bool autoreplot: True to enable it (default), + False to disable it. + """ + self._autoreplot = bool(autoreplot) + + # If the plot is dirty before enabling autoreplot, + # then _backend.postRedisplay will never be called from _setDirtyPlot + if self._autoreplot and self._getDirtyPlot(): + self._backend.postRedisplay() + + @contextmanager + def _paintContext(self): + """This context MUST surround backend rendering. + + It is in charge of performing required PlotWidget operations + """ + for item in self._contentToUpdate: + item._update(self._backend) + + self._contentToUpdate = [] + yield + self._dirty = False # reset dirty flag + + def replot(self): + """Request to draw the plot.""" + self._backend.replot() + + def _forceResetZoom(self, dataMargins=None): + """Reset the plot limits to the bounds of the data and redraw the plot. + + This method forces a reset zoom and does not check axis autoscale. + + Extra margins can be added around the data inside the plot area + (see :meth:`setDataMargins`). + Margins are given as one ratio of the data range per limit of the + data (xMin, xMax, yMin and yMax limits). + For log scale, extra margins are applied in log10 of the data. + + :param dataMargins: Ratios of margins to add around the data inside + the plot area for each side (default: no margins). + :type dataMargins: A 4-tuple of float as (xMin, xMax, yMin, yMax). + """ + if dataMargins is None: + dataMargins = self._defaultDataMargins + + # Get data range + ranges = self.getDataRange() + xmin, xmax = (1., 100.) if ranges.x is None else ranges.x + ymin, ymax = (1., 100.) if ranges.y is None else ranges.y + if ranges.yright is None: + ymin2, ymax2 = ymin, ymax + else: + ymin2, ymax2 = ranges.yright + if ranges.y is None: + ymin, ymax = ranges.yright + + # Add margins around data inside the plot area + newLimits = list(_utils.addMarginsToLimits( + dataMargins, + self._xAxis._isLogarithmic(), + self._yAxis._isLogarithmic(), + xmin, xmax, ymin, ymax, ymin2, ymax2)) + + if self.isKeepDataAspectRatio(): + # Use limits with margins to keep ratio + xmin, xmax, ymin, ymax = newLimits[:4] + + # Compute bbox wth figure aspect ratio + plotWidth, plotHeight = self.getPlotBoundsInPixels()[2:] + if plotWidth > 0 and plotHeight > 0: + plotRatio = plotHeight / plotWidth + dataRatio = (ymax - ymin) / (xmax - xmin) + if dataRatio < plotRatio: + # Increase y range + ycenter = 0.5 * (ymax + ymin) + yrange = (xmax - xmin) * plotRatio + newLimits[2] = ycenter - 0.5 * yrange + newLimits[3] = ycenter + 0.5 * yrange + + elif dataRatio > plotRatio: + # Increase x range + xcenter = 0.5 * (xmax + xmin) + xrange_ = (ymax - ymin) / plotRatio + newLimits[0] = xcenter - 0.5 * xrange_ + newLimits[1] = xcenter + 0.5 * xrange_ + + self.setLimits(*newLimits) + + def resetZoom(self, dataMargins=None): + """Reset the plot limits to the bounds of the data and redraw the plot. + + It automatically scale limits of axes that are in autoscale mode + (see :meth:`getXAxis`, :meth:`getYAxis` and :meth:`Axis.setAutoScale`). + It keeps current limits on axes that are not in autoscale mode. + + Extra margins can be added around the data inside the plot area + (see :meth:`setDataMargins`). + Margins are given as one ratio of the data range per limit of the + data (xMin, xMax, yMin and yMax limits). + For log scale, extra margins are applied in log10 of the data. + + :param dataMargins: Ratios of margins to add around the data inside + the plot area for each side (default: no margins). + :type dataMargins: A 4-tuple of float as (xMin, xMax, yMin, yMax). + """ + xLimits = self._xAxis.getLimits() + yLimits = self._yAxis.getLimits() + y2Limits = self._yRightAxis.getLimits() + + xAuto = self._xAxis.isAutoScale() + yAuto = self._yAxis.isAutoScale() + + # With log axes, autoscale if limits are <= 0 + # This avoids issues with toggling log scale with matplotlib 2.1.0 + if self._xAxis.getScale() == self._xAxis.LOGARITHMIC and xLimits[0] <= 0: + xAuto = True + if self._yAxis.getScale() == self._yAxis.LOGARITHMIC and (yLimits[0] <= 0 or y2Limits[0] <= 0): + yAuto = True + + if not xAuto and not yAuto: + _logger.debug("Nothing to autoscale") + else: # Some axes to autoscale + self._forceResetZoom(dataMargins=dataMargins) + + # Restore limits for axis not in autoscale + if not xAuto and yAuto: + self.setGraphXLimits(*xLimits) + elif xAuto and not yAuto: + if y2Limits is not None: + self.setGraphYLimits( + y2Limits[0], y2Limits[1], axis='right') + if yLimits is not None: + self.setGraphYLimits(yLimits[0], yLimits[1], axis='left') + + if (xLimits != self._xAxis.getLimits() or + yLimits != self._yAxis.getLimits() or + y2Limits != self._yRightAxis.getLimits()): + self._notifyLimitsChanged() + + # Coord conversion + + def dataToPixel(self, x=None, y=None, axis="left", check=True): + """Convert a position in data coordinates to a position in pixels. + + :param float x: The X coordinate in data space. If None (default) + the middle position of the displayed data is used. + :param float y: The Y coordinate in data space. If None (default) + the middle position of the displayed data is used. + :param str axis: The Y axis to use for the conversion + ('left' or 'right'). + :param bool check: True to return None if outside displayed area, + False to convert to pixels anyway + :returns: The corresponding position in pixels or + None if the data position is not in the displayed area and + check is True. + :rtype: A tuple of 2 floats: (xPixel, yPixel) or None. + """ + assert axis in ("left", "right") + + xmin, xmax = self._xAxis.getLimits() + yAxis = self.getYAxis(axis=axis) + ymin, ymax = yAxis.getLimits() + + if x is None: + x = 0.5 * (xmax + xmin) + if y is None: + y = 0.5 * (ymax + ymin) + + if check: + if x > xmax or x < xmin: + return None + + if y > ymax or y < ymin: + return None + + return self._backend.dataToPixel(x, y, axis=axis) + + def pixelToData(self, x, y, axis="left", check=False): + """Convert a position in pixels to a position in data coordinates. + + :param float x: The X coordinate in pixels. If None (default) + the center of the widget is used. + :param float y: The Y coordinate in pixels. If None (default) + the center of the widget is used. + :param str axis: The Y axis to use for the conversion + ('left' or 'right'). + :param bool check: Toggle checking if pixel is in plot area. + If False, this method never returns None. + :returns: The corresponding position in data space or + None if the pixel position is not in the plot area. + :rtype: A tuple of 2 floats: (xData, yData) or None. + """ + assert axis in ("left", "right") + + if x is None: + x = self.width() // 2 + if y is None: + y = self.height() // 2 + + if check: + left, top, width, height = self.getPlotBoundsInPixels() + if not (left <= x <= left + width and top <= y <= top + height): + return None + + return self._backend.pixelToData(x, y, axis) + + def getPlotBoundsInPixels(self): + """Plot area bounds in widget coordinates in pixels. + + :return: bounds as a 4-tuple of int: (left, top, width, height) + """ + return self._backend.getPlotBoundsInPixels() + + # Interaction support + + def getGraphCursorShape(self): + """Returns the current cursor shape. + + :rtype: str + """ + return self.__graphCursorShape + + def setGraphCursorShape(self, cursor=None): + """Set the cursor shape. + + :param str cursor: Name of the cursor shape + """ + self.__graphCursorShape = cursor + self._backend.setGraphCursorShape(cursor) + + @deprecated(replacement='getItems', since_version='0.13') + def _getAllMarkers(self, just_legend=False): + markers = [item for item in self.getItems() if isinstance(item, items.MarkerBase)] + if just_legend: + return [marker.getName() for marker in markers] + else: + return markers + + def _getMarkerAt(self, x, y): + """Return the most interactive marker at a location, else None + + :param float x: X position in pixels + :param float y: Y position in pixels + :rtype: None of marker object + """ + def checkDraggable(item): + return isinstance(item, items.MarkerBase) and item.isDraggable() + def checkSelectable(item): + return isinstance(item, items.MarkerBase) and item.isSelectable() + def check(item): + return isinstance(item, items.MarkerBase) + + result = self._pickTopMost(x, y, checkDraggable) + if not result: + result = self._pickTopMost(x, y, checkSelectable) + if not result: + result = self._pickTopMost(x, y, check) + marker = result.getItem() if result is not None else None + return marker + + def _getMarker(self, legend=None): + """Get the object describing a specific marker. + + It returns None in case no matching marker is found + + :param str legend: The legend of the marker to retrieve + :rtype: None of marker object + """ + return self._getItem(kind='marker', legend=legend) + + def pickItems(self, x, y, condition=None): + """Generator of picked items in the plot at given position. + + Items are returned from front to back. + + :param float x: X position in pixels + :param float y: Y position in pixels + :param callable condition: + Callable taking an item as input and returning False for items + to skip during picking. If None (default) no item is skipped. + :return: Iterable of :class:`PickingResult` objects at picked position. + Items are ordered from front to back. + """ + for item in reversed(self._backend.getItemsFromBackToFront(condition=condition)): + result = item.pick(x, y) + if result is not None: + yield result + + def _pickTopMost(self, x, y, condition=None): + """Returns top-most picked item in the plot at given position. + + Items are checked from front to back. + + :param float x: X position in pixels + :param float y: Y position in pixels + :param callable condition: + Callable taking an item as input and returning False for items + to skip during picking. If None (default) no item is skipped. + :return: :class:`PickingResult` object at picked position. + If no item is picked, it returns None + :rtype: Union[None,PickingResult] + """ + for result in self.pickItems(x, y, condition): + return result + return None + + # User event handling # + + def _isPositionInPlotArea(self, x, y): + """Project position in pixel to the closest point in the plot area + + :param float x: X coordinate in widget coordinate (in pixel) + :param float y: Y coordinate in widget coordinate (in pixel) + :return: (x, y) in widget coord (in pixel) in the plot area + """ + left, top, width, height = self.getPlotBoundsInPixels() + xPlot = numpy.clip(x, left, left + width) + yPlot = numpy.clip(y, top, top + height) + return xPlot, yPlot + + def onMousePress(self, xPixel, yPixel, btn): + """Handle mouse press event. + + :param float xPixel: X mouse position in pixels + :param float yPixel: Y mouse position in pixels + :param str btn: Mouse button in 'left', 'middle', 'right' + """ + if self._isPositionInPlotArea(xPixel, yPixel) == (xPixel, yPixel): + self._pressedButtons.append(btn) + self._eventHandler.handleEvent('press', xPixel, yPixel, btn) + + def onMouseMove(self, xPixel, yPixel): + """Handle mouse move event. + + :param float xPixel: X mouse position in pixels + :param float yPixel: Y mouse position in pixels + """ + inXPixel, inYPixel = self._isPositionInPlotArea(xPixel, yPixel) + isCursorInPlot = inXPixel == xPixel and inYPixel == yPixel + + if self._cursorInPlot != isCursorInPlot: + self._cursorInPlot = isCursorInPlot + self._eventHandler.handleEvent( + 'enter' if self._cursorInPlot else 'leave') + + if isCursorInPlot: + # Signal mouse move event + dataPos = self.pixelToData(inXPixel, inYPixel) + assert dataPos is not None + + btn = self._pressedButtons[-1] if self._pressedButtons else None + event = PlotEvents.prepareMouseSignal( + 'mouseMoved', btn, dataPos[0], dataPos[1], xPixel, yPixel) + self.notify(**event) + + # Either button was pressed in the plot or cursor is in the plot + if isCursorInPlot or self._pressedButtons: + self._eventHandler.handleEvent('move', inXPixel, inYPixel) + + def onMouseRelease(self, xPixel, yPixel, btn): + """Handle mouse release event. + + :param float xPixel: X mouse position in pixels + :param float yPixel: Y mouse position in pixels + :param str btn: Mouse button in 'left', 'middle', 'right' + """ + try: + self._pressedButtons.remove(btn) + except ValueError: + pass + else: + xPixel, yPixel = self._isPositionInPlotArea(xPixel, yPixel) + self._eventHandler.handleEvent('release', xPixel, yPixel, btn) + + def onMouseWheel(self, xPixel, yPixel, angleInDegrees): + """Handle mouse wheel event. + + :param float xPixel: X mouse position in pixels + :param float yPixel: Y mouse position in pixels + :param float angleInDegrees: Angle corresponding to wheel motion. + Positive for movement away from the user, + negative for movement toward the user. + """ + if self._isPositionInPlotArea(xPixel, yPixel) == (xPixel, yPixel): + self._eventHandler.handleEvent( + 'wheel', xPixel, yPixel, angleInDegrees) + + def onMouseLeaveWidget(self): + """Handle mouse leave widget event.""" + if self._cursorInPlot: + self._cursorInPlot = False + self._eventHandler.handleEvent('leave') + + # Interaction modes # + + def getInteractiveMode(self): + """Returns the current interactive mode as a dict. + + The returned dict contains at least the key 'mode'. + Mode can be: 'draw', 'pan', 'select', 'select-draw', 'zoom'. + It can also contains extra keys (e.g., 'color') specific to a mode + as provided to :meth:`setInteractiveMode`. + """ + return self._eventHandler.getInteractiveMode() + + def resetInteractiveMode(self): + """Reset the interactive mode to use the previous basic interactive + mode used. + + It can be one of "zoom" or "pan". + """ + mode, zoomOnWheel = self._previousDefaultMode + self.setInteractiveMode(mode=mode, zoomOnWheel=zoomOnWheel) + + def setInteractiveMode(self, mode, color='black', + shape='polygon', label=None, + zoomOnWheel=True, source=None, width=None): + """Switch the interactive mode. + + :param str mode: The name of the interactive mode. + In 'draw', 'pan', 'select', 'select-draw', 'zoom'. + :param color: Only for 'draw' and 'zoom' modes. + Color to use for drawing selection area. Default black. + :type color: Color description: The name as a str or + a tuple of 4 floats. + :param str shape: Only for 'draw' mode. The kind of shape to draw. + In 'polygon', 'rectangle', 'line', 'vline', 'hline', + 'freeline'. + Default is 'polygon'. + :param str label: Only for 'draw' mode, sent in drawing events. + :param bool zoomOnWheel: Toggle zoom on wheel support + :param source: A user-defined object (typically the caller object) + that will be send in the interactiveModeChanged event, + to identify which object required a mode change. + Default: None + :param float width: Width of the pencil. Only for draw pencil mode. + """ + self._eventHandler.setInteractiveMode(mode, color, shape, label, width) + self._eventHandler.zoomOnWheel = zoomOnWheel + if mode in ["pan", "zoom"]: + self._previousDefaultMode = mode, zoomOnWheel + + self.notify( + 'interactiveModeChanged', source=source) + + # Panning with arrow keys + + def isPanWithArrowKeys(self): + """Returns whether or not panning the graph with arrow keys is enabled. + + See :meth:`setPanWithArrowKeys`. + """ + return self._panWithArrowKeys + + def setPanWithArrowKeys(self, pan=False): + """Enable/Disable panning the graph with arrow keys. + + This grabs the keyboard. + + :param bool pan: True to enable panning, False to disable. + """ + pan = bool(pan) + panHasChanged = self._panWithArrowKeys != pan + + self._panWithArrowKeys = pan + if not self._panWithArrowKeys: + self.setFocusPolicy(qt.Qt.NoFocus) + else: + self.setFocusPolicy(qt.Qt.StrongFocus) + self.setFocus(qt.Qt.OtherFocusReason) + + if panHasChanged: + self.sigSetPanWithArrowKeys.emit(pan) + + # Dict to convert Qt arrow key code to direction str. + _ARROWS_TO_PAN_DIRECTION = { + qt.Qt.Key_Left: 'left', + qt.Qt.Key_Right: 'right', + qt.Qt.Key_Up: 'up', + qt.Qt.Key_Down: 'down' + } + + def __simulateMouseMove(self): + qapp = qt.QApplication.instance() + event = qt.QMouseEvent( + qt.QEvent.MouseMove, + self.getWidgetHandle().mapFromGlobal(qt.QCursor.pos()), + qt.Qt.NoButton, + qapp.mouseButtons(), + qapp.keyboardModifiers()) + qapp.sendEvent(self.getWidgetHandle(), event) + + def keyPressEvent(self, event): + """Key event handler handling panning on arrow keys. + + Overrides base class implementation. + """ + key = event.key() + if self._panWithArrowKeys and key in self._ARROWS_TO_PAN_DIRECTION: + self.pan(self._ARROWS_TO_PAN_DIRECTION[key], factor=0.1) + + # Send a mouse move event to the plot widget to take into account + # that even if mouse didn't move on the screen, it moved relative + # to the plotted data. + self.__simulateMouseMove() + else: + # Only call base class implementation when key is not handled. + # See QWidget.keyPressEvent for details. + super(PlotWidget, self).keyPressEvent(event) diff --git a/src/silx/gui/plot/PlotWindow.py b/src/silx/gui/plot/PlotWindow.py new file mode 100644 index 0000000..0349585 --- /dev/null +++ b/src/silx/gui/plot/PlotWindow.py @@ -0,0 +1,993 @@ +# coding: utf-8 +# /*########################################################################## +# +# Copyright (c) 2004-2021 European Synchrotron Radiation Facility +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. +# +# ###########################################################################*/ +"""A :class:`.PlotWidget` with additional toolbars. + +The :class:`PlotWindow` is a subclass of :class:`.PlotWidget`. +""" + +__authors__ = ["V.A. Sole", "T. Vincent"] +__license__ = "MIT" +__date__ = "12/04/2019" + +try: + from collections import abc +except ImportError: # Python2 support + import collections as abc +import logging +import weakref + +import silx +from silx.utils.weakref import WeakMethodProxy +from silx.utils.deprecation import deprecated +from silx.utils.proxy import docstring + +from . import PlotWidget +from . import actions +from . import items +from .actions import medfilt as actions_medfilt +from .actions import fit as actions_fit +from .actions import control as actions_control +from .actions import histogram as actions_histogram +from . import PlotToolButtons +from . import tools +from .Profile import ProfileToolBar +from .LegendSelector import LegendsDockWidget +from .CurvesROIWidget import CurvesROIDockWidget +from .MaskToolsWidget import MaskToolsDockWidget +from .StatsWidget import BasicStatsWidget +from .ColorBar import ColorBarWidget +try: + from ..console import IPythonDockWidget +except ImportError: + IPythonDockWidget = None + +from .. import qt + + +_logger = logging.getLogger(__name__) + + +class PlotWindow(PlotWidget): + """Qt Widget providing a 1D/2D plot area and additional tools. + + This widgets inherits from :class:`.PlotWidget` and provides its plot API. + + Initialiser parameters: + + :param parent: The parent of this widget or None. + :param backend: The backend to use for the plot (default: matplotlib). + See :class:`.PlotWidget` for the list of supported backend. + :type backend: str or :class:`BackendBase.BackendBase` + :param bool resetzoom: Toggle visibility of reset zoom action. + :param bool autoScale: Toggle visibility of axes autoscale actions. + :param bool logScale: Toggle visibility of axes log scale actions. + :param bool grid: Toggle visibility of grid mode action. + :param bool curveStyle: Toggle visibility of curve style action. + :param bool colormap: Toggle visibility of colormap action. + :param bool aspectRatio: Toggle visibility of aspect ratio button. + :param bool yInverted: Toggle visibility of Y axis direction button. + :param bool copy: Toggle visibility of copy action. + :param bool save: Toggle visibility of save action. + :param bool print_: Toggle visibility of print action. + :param bool control: True to display an Options button with a sub-menu + to show legends, toggle crosshair and pan with arrows. + (Default: False) + :param position: True to display widget with (x, y) mouse position + (Default: False). + It also supports a list of (name, funct(x, y)->value) + to customize the displayed values. + See :class:`~silx.gui.plot.tools.PositionInfo`. + :param bool roi: Toggle visibilty of ROI action. + :param bool mask: Toggle visibilty of mask action. + :param bool fit: Toggle visibilty of fit action. + """ + + def __init__(self, parent=None, backend=None, + resetzoom=True, autoScale=True, logScale=True, grid=True, + curveStyle=True, colormap=True, + aspectRatio=True, yInverted=True, + copy=True, save=True, print_=True, + control=False, position=False, + roi=True, mask=True, fit=False): + super(PlotWindow, self).__init__(parent=parent, backend=backend) + if parent is None: + self.setWindowTitle('PlotWindow') + + self._dockWidgets = [] + + # lazy loaded dock widgets + self._legendsDockWidget = None + self._curvesROIDockWidget = None + self._maskToolsDockWidget = None + self._consoleDockWidget = None + self._statsDockWidget = None + + # Create color bar, hidden by default for backward compatibility + self._colorbar = ColorBarWidget(parent=self, plot=self) + + # Init actions + self.group = qt.QActionGroup(self) + self.group.setExclusive(False) + + self.resetZoomAction = self.group.addAction( + actions.control.ResetZoomAction(self, parent=self)) + self.resetZoomAction.setVisible(resetzoom) + self.addAction(self.resetZoomAction) + + self.zoomInAction = actions.control.ZoomInAction(self, parent=self) + self.addAction(self.zoomInAction) + + self.zoomOutAction = actions.control.ZoomOutAction(self, parent=self) + self.addAction(self.zoomOutAction) + + self.xAxisAutoScaleAction = self.group.addAction( + actions.control.XAxisAutoScaleAction(self, parent=self)) + self.xAxisAutoScaleAction.setVisible(autoScale) + self.addAction(self.xAxisAutoScaleAction) + + self.yAxisAutoScaleAction = self.group.addAction( + actions.control.YAxisAutoScaleAction(self, parent=self)) + self.yAxisAutoScaleAction.setVisible(autoScale) + self.addAction(self.yAxisAutoScaleAction) + + self.xAxisLogarithmicAction = self.group.addAction( + actions.control.XAxisLogarithmicAction(self, parent=self)) + self.xAxisLogarithmicAction.setVisible(logScale) + self.addAction(self.xAxisLogarithmicAction) + + self.yAxisLogarithmicAction = self.group.addAction( + actions.control.YAxisLogarithmicAction(self, parent=self)) + self.yAxisLogarithmicAction.setVisible(logScale) + self.addAction(self.yAxisLogarithmicAction) + + self.gridAction = self.group.addAction( + actions.control.GridAction(self, gridMode='both', parent=self)) + self.gridAction.setVisible(grid) + self.addAction(self.gridAction) + + self.curveStyleAction = self.group.addAction( + actions.control.CurveStyleAction(self, parent=self)) + self.curveStyleAction.setVisible(curveStyle) + self.addAction(self.curveStyleAction) + + self.colormapAction = self.group.addAction( + actions.control.ColormapAction(self, parent=self)) + self.colormapAction.setVisible(colormap) + self.addAction(self.colormapAction) + + self.colorbarAction = self.group.addAction( + actions_control.ColorBarAction(self, parent=self)) + self.colorbarAction.setVisible(False) + self.addAction(self.colorbarAction) + self._colorbar.setVisible(False) + + self.keepDataAspectRatioButton = PlotToolButtons.AspectToolButton( + parent=self, plot=self) + self.keepDataAspectRatioButton.setVisible(aspectRatio) + + self.yAxisInvertedButton = PlotToolButtons.YAxisOriginToolButton( + parent=self, plot=self) + self.yAxisInvertedButton.setVisible(yInverted) + + self.group.addAction(self.getRoiAction()) + self.getRoiAction().setVisible(roi) + + self.group.addAction(self.getMaskAction()) + self.getMaskAction().setVisible(mask) + + self._intensityHistoAction = self.group.addAction( + actions_histogram.PixelIntensitiesHistoAction(self, parent=self)) + self._intensityHistoAction.setVisible(False) + + self._medianFilter2DAction = self.group.addAction( + actions_medfilt.MedianFilter2DAction(self, parent=self)) + self._medianFilter2DAction.setVisible(False) + + self._medianFilter1DAction = self.group.addAction( + actions_medfilt.MedianFilter1DAction(self, parent=self)) + self._medianFilter1DAction.setVisible(False) + + self.fitAction = self.group.addAction(actions_fit.FitAction(self, parent=self)) + self.fitAction.setVisible(fit) + self.addAction(self.fitAction) + + # lazy loaded actions needed by the controlButton menu + self._consoleAction = None + self._statsAction = None + self._panWithArrowKeysAction = None + self._crosshairAction = None + + # Make colorbar background white + self._colorbar.setAutoFillBackground(True) + self._sigAxesVisibilityChanged.connect(self._updateColorBarBackground) + self._updateColorBarBackground() + + if control: # Create control button only if requested + self.controlButton = qt.QToolButton() + self.controlButton.setText("Options") + self.controlButton.setToolButtonStyle(qt.Qt.ToolButtonTextBesideIcon) + self.controlButton.setAutoRaise(True) + self.controlButton.setPopupMode(qt.QToolButton.InstantPopup) + menu = qt.QMenu(self) + menu.aboutToShow.connect(self._customControlButtonMenu) + self.controlButton.setMenu(menu) + + self._positionWidget = None + if position: # Add PositionInfo widget to the bottom of the plot + if isinstance(position, abc.Iterable): + # Use position as a set of converters + converters = position + else: + converters = None + self._positionWidget = tools.PositionInfo( + plot=self, converters=converters) + # Set a snapping mode that is consistent with legacy one + self._positionWidget.setSnappingMode( + tools.PositionInfo.SNAPPING_CROSSHAIR | + tools.PositionInfo.SNAPPING_ACTIVE_ONLY | + tools.PositionInfo.SNAPPING_SYMBOLS_ONLY | + tools.PositionInfo.SNAPPING_CURVE | + tools.PositionInfo.SNAPPING_SCATTER) + + self.__setCentralWidget() + + # Creating the toolbar also create actions for toolbuttons + self._interactiveModeToolBar = tools.InteractiveModeToolBar( + parent=self, plot=self) + self.addToolBar(self._interactiveModeToolBar) + + self._toolbar = self._createToolBar(title='Plot', parent=self) + self.addToolBar(self._toolbar) + + self._outputToolBar = tools.OutputToolBar(parent=self, plot=self) + self._outputToolBar.getCopyAction().setVisible(copy) + self._outputToolBar.getSaveAction().setVisible(save) + self._outputToolBar.getPrintAction().setVisible(print_) + self.addToolBar(self._outputToolBar) + + # Activate shortcuts in PlotWindow widget: + for toolbar in (self._interactiveModeToolBar, self._outputToolBar): + for action in toolbar.actions(): + self.addAction(action) + + def __setCentralWidget(self): + """Set central widget to host plot backend, colorbar, and bottom bar""" + gridLayout = qt.QGridLayout() + gridLayout.setSpacing(0) + gridLayout.setContentsMargins(0, 0, 0, 0) + gridLayout.addWidget(self.getWidgetHandle(), 0, 0) + gridLayout.addWidget(self._colorbar, 0, 1) + gridLayout.setRowStretch(0, 1) + gridLayout.setColumnStretch(0, 1) + centralWidget = qt.QWidget(self) + centralWidget.setLayout(gridLayout) + + if hasattr(self, "controlButton") or self._positionWidget is not None: + hbox = qt.QHBoxLayout() + hbox.setContentsMargins(0, 0, 0, 0) + + if hasattr(self, "controlButton"): + hbox.addWidget(self.controlButton) + + if self._positionWidget is not None: + hbox.addWidget(self._positionWidget) + + hbox.addStretch(1) + bottomBar = qt.QWidget(centralWidget) + bottomBar.setLayout(hbox) + + gridLayout.addWidget(bottomBar, 1, 0, 1, -1) + + self.setCentralWidget(centralWidget) + + @docstring(PlotWidget) + def setBackend(self, backend): + super(PlotWindow, self).setBackend(backend) + self.__setCentralWidget() # Recreate PlotWindow's central widget + + @docstring(PlotWidget) + def setBackgroundColor(self, color): + super(PlotWindow, self).setBackgroundColor(color) + self._updateColorBarBackground() + + @docstring(PlotWidget) + def setDataBackgroundColor(self, color): + super(PlotWindow, self).setDataBackgroundColor(color) + self._updateColorBarBackground() + + @docstring(PlotWidget) + def setForegroundColor(self, color): + super(PlotWindow, self).setForegroundColor(color) + self._updateColorBarBackground() + + def _updateColorBarBackground(self): + """Update the colorbar background according to the state of the plot""" + if self.isAxesDisplayed(): + color = self.getBackgroundColor() + else: + color = self.getDataBackgroundColor() + if not color.isValid(): + # If no color defined, use the background one + color = self.getBackgroundColor() + + foreground = self.getForegroundColor() + + palette = self._colorbar.palette() + palette.setColor(qt.QPalette.Window, color) + palette.setColor(qt.QPalette.WindowText, foreground) + palette.setColor(qt.QPalette.Text, foreground) + self._colorbar.setPalette(palette) + + def getInteractiveModeToolBar(self): + """Returns QToolBar controlling interactive mode. + + :rtype: QToolBar + """ + return self._interactiveModeToolBar + + def getOutputToolBar(self): + """Returns QToolBar containing save, copy and print actions + + :rtype: QToolBar + """ + return self._outputToolBar + + @property + @deprecated(replacement="getPositionInfoWidget()", since_version="0.8.0") + def positionWidget(self): + return self.getPositionInfoWidget() + + def getPositionInfoWidget(self): + """Returns the widget displaying current cursor position information + + :rtype: ~silx.gui.plot.tools.PositionInfo + """ + return self._positionWidget + + def getSelectionMask(self): + """Return the current mask handled by :attr:`maskToolsDockWidget`. + + :return: The array of the mask with dimension of the 'active' image. + If there is no active image, an empty array is returned. + :rtype: 2D numpy.ndarray of uint8 + """ + return self.getMaskToolsDockWidget().getSelectionMask() + + def setSelectionMask(self, mask): + """Set the mask handled by :attr:`maskToolsDockWidget`. + + If the provided mask has not the same dimension as the 'active' + image, it will by cropped or padded. + + :param mask: The array to use for the mask. + :type mask: numpy.ndarray of uint8 of dimension 2, C-contiguous. + Array of other types are converted. + :return: True if success, False if failed + """ + return bool(self.getMaskToolsDockWidget().setSelectionMask(mask)) + + def _toggleConsoleVisibility(self, isChecked=False): + """Create IPythonDockWidget if needed, + show it or hide it.""" + # create widget if needed (first call) + if self._consoleDockWidget is None: + available_vars = {"plt": weakref.proxy(self)} + banner = "The variable 'plt' is available. Use the 'whos' " + banner += "and 'help(plt)' commands for more information.\n\n" + self._consoleDockWidget = IPythonDockWidget( + available_vars=available_vars, + custom_banner=banner, + parent=self) + self.addTabbedDockWidget(self._consoleDockWidget) + # self._consoleDockWidget.setVisible(True) + self._consoleDockWidget.toggleViewAction().toggled.connect( + self.getConsoleAction().setChecked) + + self._consoleDockWidget.setVisible(isChecked) + + def _toggleStatsVisibility(self, isChecked=False): + self.getStatsWidget().parent().setVisible(isChecked) + + def _createToolBar(self, title, parent): + """Create a QToolBar from the QAction of the PlotWindow. + + :param str title: The title of the QMenu + :param qt.QWidget parent: See :class:`QToolBar` + """ + toolbar = qt.QToolBar(title, parent) + + # Order widgets with actions + objects = self.group.actions() + + # Add push buttons to list + index = objects.index(self.colormapAction) + objects.insert(index + 1, self.keepDataAspectRatioButton) + objects.insert(index + 2, self.yAxisInvertedButton) + + for obj in objects: + if isinstance(obj, qt.QAction): + toolbar.addAction(obj) + else: + # Add action for toolbutton in order to allow changing + # visibility (see doc QToolBar.addWidget doc) + if obj is self.keepDataAspectRatioButton: + self.keepDataAspectRatioAction = toolbar.addWidget(obj) + elif obj is self.yAxisInvertedButton: + self.yAxisInvertedAction = toolbar.addWidget(obj) + else: + raise RuntimeError() + return toolbar + + def toolBar(self): + """Return a QToolBar from the QAction of the PlotWindow. + """ + return self._toolbar + + def menu(self, title='Plot', parent=None): + """Return a QMenu from the QAction of the PlotWindow. + + :param str title: The title of the QMenu + :param parent: See :class:`QMenu` + """ + menu = qt.QMenu(title, parent) + for action in self.group.actions(): + menu.addAction(action) + return menu + + def _customControlButtonMenu(self): + """Display Options button sub-menu.""" + controlMenu = self.controlButton.menu() + controlMenu.clear() + controlMenu.addAction(self.getLegendsDockWidget().toggleViewAction()) + controlMenu.addAction(self.getRoiAction()) + controlMenu.addAction(self.getStatsAction()) + controlMenu.addAction(self.getMaskAction()) + controlMenu.addAction(self.getConsoleAction()) + + controlMenu.addSeparator() + controlMenu.addAction(self.getCrosshairAction()) + controlMenu.addAction(self.getPanWithArrowKeysAction()) + + def addTabbedDockWidget(self, dock_widget): + """Add a dock widget as a new tab if there are already dock widgets + in the plot. When the first tab is added, the area is chosen + depending on the plot geometry: + if the window is much wider than it is high, the right dock area + is used, else the bottom dock area is used. + + :param dock_widget: Instance of :class:`QDockWidget` to be added. + """ + if dock_widget not in self._dockWidgets: + self._dockWidgets.append(dock_widget) + if len(self._dockWidgets) == 1: + # The first created dock widget must be added to a Widget area + width = self.centralWidget().width() + height = self.centralWidget().height() + if width > (1.25 * height): + area = qt.Qt.RightDockWidgetArea + else: + area = qt.Qt.BottomDockWidgetArea + self.addDockWidget(area, dock_widget) + else: + # Other dock widgets are added as tabs to the same widget area + self.tabifyDockWidget(self._dockWidgets[0], + dock_widget) + + def removeDockWidget(self, dockwidget): + """Removes the *dockwidget* from the main window layout and hides it. + + Note that the *dockwidget* is *not* deleted. + + :param QDockWidget dockwidget: + """ + if dockwidget in self._dockWidgets: + self._dockWidgets.remove(dockwidget) + super(PlotWindow, self).removeDockWidget(dockwidget) + + def _handleFirstDockWidgetShow(self, visible): + """Handle QDockWidget.visibilityChanged + + It calls :meth:`addTabbedDockWidget` for the `sender` widget. + This allows to call `addTabbedDockWidget` lazily. + + It disconnect itself from the signal once done. + + :param bool visible: + """ + if visible: + dockWidget = self.sender() + dockWidget.visibilityChanged.disconnect( + self._handleFirstDockWidgetShow) + self.addTabbedDockWidget(dockWidget) + + def getColorBarWidget(self): + """Returns the embedded :class:`ColorBarWidget` widget. + + :rtype: ColorBarWidget + """ + return self._colorbar + + # getters for dock widgets + + def getLegendsDockWidget(self): + """DockWidget with Legend panel""" + if self._legendsDockWidget is None: + self._legendsDockWidget = LegendsDockWidget(plot=self) + self._legendsDockWidget.hide() + self._legendsDockWidget.visibilityChanged.connect( + self._handleFirstDockWidgetShow) + return self._legendsDockWidget + + def getCurvesRoiDockWidget(self): + # Undocumented for a "soft deprecation" in version 0.7.0 + # (still used internally for lazy loading) + if self._curvesROIDockWidget is None: + self._curvesROIDockWidget = CurvesROIDockWidget( + plot=self, name='Regions Of Interest') + self._curvesROIDockWidget.hide() + self._curvesROIDockWidget.visibilityChanged.connect( + self._handleFirstDockWidgetShow) + return self._curvesROIDockWidget + + def getCurvesRoiWidget(self): + """Return the :class:`CurvesROIWidget`. + + :class:`silx.gui.plot.CurvesROIWidget.CurvesROIWidget` offers a getter + and a setter for the ROI data: + + - :meth:`CurvesROIWidget.getRois` + - :meth:`CurvesROIWidget.setRois` + """ + return self.getCurvesRoiDockWidget().roiWidget + + def getMaskToolsDockWidget(self): + """DockWidget with image mask panel (lazy-loaded).""" + if self._maskToolsDockWidget is None: + self._maskToolsDockWidget = MaskToolsDockWidget( + plot=self, name='Mask') + self._maskToolsDockWidget.hide() + self._maskToolsDockWidget.visibilityChanged.connect( + self._handleFirstDockWidgetShow) + return self._maskToolsDockWidget + + def getStatsWidget(self): + """Returns a BasicStatsWidget connected to this plot + + :rtype: BasicStatsWidget + """ + if self._statsDockWidget is None: + self._statsDockWidget = qt.QDockWidget() + self._statsDockWidget.setWindowTitle("Curves stats") + self._statsDockWidget.layout().setContentsMargins(0, 0, 0, 0) + statsWidget = BasicStatsWidget(parent=self, plot=self) + self._statsDockWidget.setWidget(statsWidget) + statsWidget.sigVisibilityChanged.connect( + self.getStatsAction().setChecked) + self._statsDockWidget.hide() + self._statsDockWidget.visibilityChanged.connect( + self._handleFirstDockWidgetShow) + return self._statsDockWidget.widget() + + # getters for actions + @property + @deprecated(replacement="getInteractiveModeToolBar().getZoomModeAction()", + since_version="0.8.0") + def zoomModeAction(self): + return self.getInteractiveModeToolBar().getZoomModeAction() + + @property + @deprecated(replacement="getInteractiveModeToolBar().getPanModeAction()", + since_version="0.8.0") + def panModeAction(self): + return self.getInteractiveModeToolBar().getPanModeAction() + + def getConsoleAction(self): + """QAction handling the IPython console activation. + + By default, it is connected to a method that initializes the + console widget the first time the user clicks the "Console" menu + button. The following clicks, after initialization is done, + will toggle the visibility of the console widget. + + :rtype: QAction + """ + if self._consoleAction is None: + self._consoleAction = qt.QAction('Console', self) + self._consoleAction.setCheckable(True) + if IPythonDockWidget is not None: + self._consoleAction.toggled.connect(self._toggleConsoleVisibility) + else: + self._consoleAction.setEnabled(False) + return self._consoleAction + + def getCrosshairAction(self): + """Action toggling crosshair cursor mode. + + :rtype: actions.PlotAction + """ + if self._crosshairAction is None: + self._crosshairAction = actions.control.CrosshairAction(self, color='red') + return self._crosshairAction + + def getMaskAction(self): + """QAction toggling image mask dock widget + + :rtype: QAction + """ + return self.getMaskToolsDockWidget().toggleViewAction() + + def getPanWithArrowKeysAction(self): + """Action toggling pan with arrow keys. + + :rtype: actions.PlotAction + """ + if self._panWithArrowKeysAction is None: + self._panWithArrowKeysAction = actions.control.PanWithArrowKeysAction(self) + return self._panWithArrowKeysAction + + def getStatsAction(self): + if self._statsAction is None: + self._statsAction = qt.QAction('Curves stats', self) + self._statsAction.setCheckable(True) + self._statsAction.setChecked(self.getStatsWidget().parent().isVisible()) + self._statsAction.toggled.connect(self._toggleStatsVisibility) + return self._statsAction + + def getRoiAction(self): + """QAction toggling curve ROI dock widget + + :rtype: QAction + """ + return self.getCurvesRoiDockWidget().toggleViewAction() + + def getResetZoomAction(self): + """Action resetting the zoom + + :rtype: actions.PlotAction + """ + return self.resetZoomAction + + def getZoomInAction(self): + """Action to zoom in + + :rtype: actions.PlotAction + """ + return self.zoomInAction + + def getZoomOutAction(self): + """Action to zoom out + + :rtype: actions.PlotAction + """ + return self.zoomOutAction + + def getXAxisAutoScaleAction(self): + """Action to toggle the X axis autoscale on zoom reset + + :rtype: actions.PlotAction + """ + return self.xAxisAutoScaleAction + + def getYAxisAutoScaleAction(self): + """Action to toggle the Y axis autoscale on zoom reset + + :rtype: actions.PlotAction + """ + return self.yAxisAutoScaleAction + + def getXAxisLogarithmicAction(self): + """Action to toggle logarithmic X axis + + :rtype: actions.PlotAction + """ + return self.xAxisLogarithmicAction + + def getYAxisLogarithmicAction(self): + """Action to toggle logarithmic Y axis + + :rtype: actions.PlotAction + """ + return self.yAxisLogarithmicAction + + def getGridAction(self): + """Action to toggle the grid visibility in the plot + + :rtype: actions.PlotAction + """ + return self.gridAction + + def getCurveStyleAction(self): + """Action to change curve line and markers styles + + :rtype: actions.PlotAction + """ + return self.curveStyleAction + + def getColormapAction(self): + """Action open a colormap dialog to change active image + and default colormap. + + :rtype: actions.PlotAction + """ + return self.colormapAction + + def getKeepDataAspectRatioButton(self): + """Button to toggle aspect ratio preservation + + :rtype: PlotToolButtons.AspectToolButton + """ + return self.keepDataAspectRatioButton + + def getKeepDataAspectRatioAction(self): + """Action associated to keepDataAspectRatioButton. + Use this to change the visibility of keepDataAspectRatioButton in the + toolbar (See :meth:`QToolBar.addWidget` documentation). + + :rtype: actions.PlotAction + """ + return self.keepDataAspectRatioAction + + def getYAxisInvertedButton(self): + """Button to switch the Y axis orientation + + :rtype: PlotToolButtons.YAxisOriginToolButton + """ + return self.yAxisInvertedButton + + def getYAxisInvertedAction(self): + """Action associated to yAxisInvertedButton. + Use this to change the visibility yAxisInvertedButton in the toolbar. + (See :meth:`QToolBar.addWidget` documentation). + + :rtype: actions.PlotAction + """ + return self.yAxisInvertedAction + + def getIntensityHistogramAction(self): + """Action toggling the histogram intensity Plot widget + + :rtype: actions.PlotAction + """ + return self._intensityHistoAction + + def getCopyAction(self): + """Action to copy plot snapshot to clipboard + + :rtype: actions.PlotAction + """ + return self.getOutputToolBar().getCopyAction() + + def getSaveAction(self): + """Action to save plot + + :rtype: actions.PlotAction + """ + return self.getOutputToolBar().getSaveAction() + + def getPrintAction(self): + """Action to print plot + + :rtype: actions.PlotAction + """ + return self.getOutputToolBar().getPrintAction() + + def getFitAction(self): + """Action to fit selected curve + + :rtype: actions.PlotAction + """ + return self.fitAction + + def getMedianFilter1DAction(self): + """Action toggling the 1D median filter + + :rtype: actions.PlotAction + """ + return self._medianFilter1DAction + + def getMedianFilter2DAction(self): + """Action toggling the 2D median filter + + :rtype: actions.PlotAction + """ + return self._medianFilter2DAction + + def getColorBarAction(self): + """Action toggling the colorbar show/hide action + + .. warning:: to show/hide the plot colorbar call directly the ColorBar + widget using getColorBarWidget() + + :rtype: actions.PlotAction + """ + return self.colorbarAction + + +class Plot1D(PlotWindow): + """PlotWindow with tools specific for curves. + + This widgets provides the plot API of :class:`.PlotWidget`. + + :param parent: The parent of this widget + :param backend: The backend to use for the plot (default: matplotlib). + See :class:`.PlotWidget` for the list of supported backend. + :type backend: str or :class:`BackendBase.BackendBase` + """ + + def __init__(self, parent=None, backend=None): + super(Plot1D, self).__init__(parent=parent, backend=backend, + resetzoom=True, autoScale=True, + logScale=True, grid=True, + curveStyle=True, colormap=False, + aspectRatio=False, yInverted=False, + copy=True, save=True, print_=True, + control=True, position=True, + roi=True, mask=False, fit=True) + if parent is None: + self.setWindowTitle('Plot1D') + self.getXAxis().setLabel('X') + self.getYAxis().setLabel('Y') + action = self.getFitAction() + action.setXRangeUpdatedOnZoom(True) + action.setFittedItemUpdatedFromActiveCurve(True) + + +class Plot2D(PlotWindow): + """PlotWindow with a toolbar specific for images. + + This widgets provides the plot API of :~:`.PlotWidget`. + + :param parent: The parent of this widget + :param backend: The backend to use for the plot (default: matplotlib). + See :class:`.PlotWidget` for the list of supported backend. + :type backend: str or :class:`BackendBase.BackendBase` + """ + + def __init__(self, parent=None, backend=None): + # List of information to display at the bottom of the plot + posInfo = [ + ('X', lambda x, y: x), + ('Y', lambda x, y: y), + ('Data', WeakMethodProxy(self._getImageValue)), + ('Dims', WeakMethodProxy(self._getImageDims)), + ] + + super(Plot2D, self).__init__(parent=parent, backend=backend, + resetzoom=True, autoScale=False, + logScale=False, grid=False, + curveStyle=False, colormap=True, + aspectRatio=True, yInverted=True, + copy=True, save=True, print_=True, + control=False, position=posInfo, + roi=False, mask=True) + if parent is None: + self.setWindowTitle('Plot2D') + self.getXAxis().setLabel('Columns') + self.getYAxis().setLabel('Rows') + + if silx.config.DEFAULT_PLOT_IMAGE_Y_AXIS_ORIENTATION == 'downward': + self.getYAxis().setInverted(True) + + self.profile = ProfileToolBar(plot=self) + self.addToolBar(self.profile) + + self.colorbarAction.setVisible(True) + self.getColorBarWidget().setVisible(True) + + # Put colorbar action after colormap action + actions = self.toolBar().actions() + for action in actions: + if action is self.getColormapAction(): + break + + self.sigActiveImageChanged.connect(self.__activeImageChanged) + + def __activeImageChanged(self, previous, legend): + """Handle change of active image + + :param Union[str,None] previous: Legend of previous active image + :param Union[str,None] legend: Legend of current active image + """ + if previous is not None: + item = self.getImage(previous) + if item is not None: + item.sigItemChanged.disconnect(self.__imageChanged) + + if legend is not None: + item = self.getImage(legend) + item.sigItemChanged.connect(self.__imageChanged) + + positionInfo = self.getPositionInfoWidget() + if positionInfo is not None: + positionInfo.updateInfo() + + def __imageChanged(self, event): + """Handle update of active image item + + :param event: Type of changed event + """ + if event == items.ItemChangedType.DATA: + positionInfo = self.getPositionInfoWidget() + if positionInfo is not None: + positionInfo.updateInfo() + + def _getImageValue(self, x, y): + """Get status bar value of top most image at position (x, y) + + :param float x: X position in plot coordinates + :param float y: Y position in plot coordinates + :return: The value at that point or '-' + """ + pickedMask = None + for picked in self.pickItems( + *self.dataToPixel(x, y, check=False), + lambda item: isinstance(item, items.ImageBase)): + if isinstance(picked.getItem(), items.MaskImageData): + if pickedMask is None: # Use top-most if many masks + pickedMask = picked + else: + image = picked.getItem() + + indices = picked.getIndices(copy=False) + if indices is not None: + row, col = indices[0][0], indices[1][0] + value = image.getData(copy=False)[row, col] + + if pickedMask is not None: # Check if masked + maskItem = pickedMask.getItem() + indices = pickedMask.getIndices() + row, col = indices[0][0], indices[1][0] + if maskItem.getData(copy=False)[row, col] != 0: + return value, "Masked" + return value + + return '-' # No image picked + + def _getImageDims(self, *args): + activeImage = self.getActiveImage() + if (activeImage is not None and + activeImage.getData(copy=False) is not None): + dims = activeImage.getData(copy=False).shape[1::-1] + return 'x'.join(str(dim) for dim in dims) + else: + return '-' + + def getProfileToolbar(self): + """Profile tools attached to this plot + + See :class:`silx.gui.plot.Profile.ProfileToolBar` + """ + return self.profile + + @deprecated(replacement="getProfilePlot", since_version="0.5.0") + def getProfileWindow(self): + return self.getProfilePlot() + + def getProfilePlot(self): + """Return plot window used to display profile curve. + + :return: :class:`Plot1D` + """ + return self.profile.getProfilePlot() diff --git a/src/silx/gui/plot/PrintPreviewToolButton.py b/src/silx/gui/plot/PrintPreviewToolButton.py new file mode 100644 index 0000000..30967e4 --- /dev/null +++ b/src/silx/gui/plot/PrintPreviewToolButton.py @@ -0,0 +1,388 @@ +# coding: utf-8 +# /*########################################################################## +# +# Copyright (c) 2017-2021 European Synchrotron Radiation Facility +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. +# +# ###########################################################################*/ +""" +This modules provides tool buttons to send the content of a plot to a +print preview page. +The plot content can then be moved on the page and resized prior to printing. + +Classes +------- + +- :class:`PrintPreviewToolButton` +- :class:`SingletonPrintPreviewToolButton` + +Examples +-------- + +Simple example +++++++++++++++ + +.. code-block:: python + + from silx.gui import qt + from silx.gui.plot import PlotWidget + from silx.gui.plot.PrintPreviewToolButton import PrintPreviewToolButton + import numpy + + app = qt.QApplication([]) + + pw = PlotWidget() + toolbar = qt.QToolBar(pw) + toolbutton = PrintPreviewToolButton(parent=toolbar, plot=pw) + pw.addToolBar(toolbar) + toolbar.addWidget(toolbutton) + pw.show() + + x = numpy.arange(1000) + y = x / numpy.sin(x) + pw.addCurve(x, y) + + app.exec() + +Singleton example ++++++++++++++++++ + +This example illustrates how to print the content of several different +plots on the same page. The plots all instantiate a +:class:`SingletonPrintPreviewToolButton`, which relies on a singleton widget +(:class:`silx.gui.widgets.PrintPreview.SingletonPrintPreviewDialog`). + +.. image:: img/printPreviewMultiPlot.png + +.. code-block:: python + + from silx.gui import qt + from silx.gui.plot import PlotWidget + from silx.gui.plot.PrintPreviewToolButton import SingletonPrintPreviewToolButton + import numpy + + app = qt.QApplication([]) + + plot_widgets = [] + + for i in range(3): + pw = PlotWidget() + toolbar = qt.QToolBar(pw) + toolbutton = SingletonPrintPreviewToolButton(parent=toolbar, + plot=pw) + pw.addToolBar(toolbar) + toolbar.addWidget(toolbutton) + pw.show() + plot_widgets.append(pw) + + x = numpy.arange(1000) + + plot_widgets[0].addCurve(x, numpy.sin(x * 2 * numpy.pi / 1000)) + plot_widgets[1].addCurve(x, numpy.cos(x * 2 * numpy.pi / 1000)) + plot_widgets[2].addCurve(x, numpy.tan(x * 2 * numpy.pi / 1000)) + + app.exec() + +""" +from __future__ import absolute_import + +import logging +from io import StringIO + +from .. import qt +from .. import icons +from . import PlotWidget +from ..widgets.PrintPreview import PrintPreviewDialog, SingletonPrintPreviewDialog +from ..widgets.PrintGeometryDialog import PrintGeometryDialog +from silx.utils.deprecation import deprecated + +__authors__ = ["P. Knobel"] +__license__ = "MIT" +__date__ = "20/12/2018" + +_logger = logging.getLogger(__name__) +# _logger.setLevel(logging.DEBUG) + + +class PrintPreviewToolButton(qt.QToolButton): + """QToolButton to open a :class:`PrintPreviewDialog` (if not already open) + and add the current plot to its page to be printed. + + :param parent: See :class:`QAction` + :param plot: :class:`.PlotWidget` instance on which to operate + """ + def __init__(self, parent=None, plot=None): + super(PrintPreviewToolButton, self).__init__(parent) + + if not isinstance(plot, PlotWidget): + raise TypeError("plot parameter must be a PlotWidget") + self._plot = plot + + self.setIcon(icons.getQIcon('document-print')) + + printGeomAction = qt.QAction("Print geometry", self) + printGeomAction.setToolTip("Define a print geometry prior to sending " + "the plot to the print preview dialog") + printGeomAction.setIcon(icons.getQIcon('shape-rectangle')) + printGeomAction.triggered.connect(self._setPrintConfiguration) + + printPreviewAction = qt.QAction("Print preview", self) + printPreviewAction.setToolTip("Send plot to the print preview dialog") + printPreviewAction.setIcon(icons.getQIcon('document-print')) + printPreviewAction.triggered.connect(self._plotToPrintPreview) + + menu = qt.QMenu(self) + menu.addAction(printGeomAction) + menu.addAction(printPreviewAction) + self.setMenu(menu) + self.setPopupMode(qt.QToolButton.InstantPopup) + + self._printPreviewDialog = None + self._printConfigurationDialog = None + + self._printGeometry = {"xOffset": 0.1, + "yOffset": 0.1, + "width": 0.9, + "height": 0.9, + "units": "page", + "keepAspectRatio": True} + + @property + def printPreviewDialog(self): + """Lazy loaded :class:`PrintPreviewDialog`""" + # if changes are made here, don't forget making them in + # SingletonPrintPreviewToolButton.printPreviewDialog as well + if self._printPreviewDialog is None: + self._printPreviewDialog = PrintPreviewDialog(self.parent()) + return self._printPreviewDialog + + def getTitle(self): + """Implement this method to fetch the title in the plot. + + :return: Title to be printed above the plot, or None (no title added) + :rtype: str or None + """ + return None + + def getCommentAndPosition(self): + """Implement this method to fetch the legend to be printed below the + figure and its position. + + :return: Legend to be printed below the figure and its position: + "CENTER", "LEFT" or "RIGHT" + :rtype: (str, str) or (None, None) + """ + return None, None + + @property + @deprecated(since_version="0.10", + replacement="getPlot()") + def plot(self): + return self._plot + + def getPlot(self): + """Return the :class:`.PlotWidget` associated with this tool button. + + :rtype: :class:`.PlotWidget` + """ + return self._plot + + def _plotToPrintPreview(self): + """Grab the plot widget and send it to the print preview dialog. + Make sure the print preview dialog is shown and raised.""" + if not self.printPreviewDialog.ensurePrinterIsSet(): + return + + comment, commentPosition = self.getCommentAndPosition() + + if qt.HAS_SVG: + svgRenderer, viewBox = self._getSvgRendererAndViewbox() + self.printPreviewDialog.addSvgItem(svgRenderer, + title=self.getTitle(), + comment=comment, + commentPosition=commentPosition, + viewBox=viewBox, + keepRatio=self._printGeometry["keepAspectRatio"]) + else: + _logger.warning("Missing QtSvg library, using a raster image") + pixmap = self._plot.centralWidget().grab() + self.printPreviewDialog.addPixmap(pixmap, + title=self.getTitle(), + comment=comment, + commentPosition=commentPosition) + self.printPreviewDialog.show() + self.printPreviewDialog.raise_() + + def _getSvgRendererAndViewbox(self): + """Return a SVG renderer displaying the plot and its viewbox + (interactively specified by the user the first time this is called). + + The size of the renderer is adjusted to the printer configuration + and to the geometry configuration (width, height, ratio) specified + by the user.""" + imgData = StringIO() + assert self._plot.saveGraph(imgData, fileFormat="svg"), \ + "Unable to save graph" + imgData.flush() + imgData.seek(0) + svgData = imgData.read() + + svgRenderer = qt.QSvgRenderer() + + viewbox = self._getViewBox() + + svgRenderer.setViewBox(viewbox) + + xml_stream = qt.QXmlStreamReader(svgData.encode(errors="replace")) + + # This is for PyMca compatibility, to share a print preview with PyMca plots + svgRenderer._viewBox = viewbox + svgRenderer._svgRawData = svgData.encode(errors="replace") + svgRenderer._svgRendererData = xml_stream + + if not svgRenderer.load(xml_stream): + raise RuntimeError("Cannot interpret svg data") + + return svgRenderer, viewbox + + def _getViewBox(self): + """ + """ + printer = self.printPreviewDialog.printer + dpix = printer.logicalDpiX() + dpiy = printer.logicalDpiY() + availableWidth = printer.width() + availableHeight = printer.height() + + config = self._printGeometry + width = config['width'] + height = config['height'] + xOffset = config['xOffset'] + yOffset = config['yOffset'] + units = config['units'] + keepAspectRatio = config['keepAspectRatio'] + aspectRatio = self._getPlotAspectRatio() + + # convert the offsets to dots + if units.lower() in ['inch', 'inches']: + xOffset = xOffset * dpix + yOffset = yOffset * dpiy + if width is not None: + width = width * dpix + if height is not None: + height = height * dpiy + elif units.lower() in ['cm', 'centimeters']: + xOffset = (xOffset / 2.54) * dpix + yOffset = (yOffset / 2.54) * dpiy + if width is not None: + width = (width / 2.54) * dpix + if height is not None: + height = (height / 2.54) * dpiy + else: + # page units + xOffset = availableWidth * xOffset + yOffset = availableHeight * yOffset + if width is not None: + width = availableWidth * width + if height is not None: + height = availableHeight * height + + availableWidth -= xOffset + availableHeight -= yOffset + + if width is not None: + if (availableWidth + 0.1) < width: + txt = "Available width %f is less than requested width %f" % \ + (availableWidth, width) + raise ValueError(txt) + if height is not None: + if (availableHeight + 0.1) < height: + txt = "Available height %f is less than requested height %f" % \ + (availableHeight, height) + raise ValueError(txt) + + if keepAspectRatio: + bodyWidth = width or availableWidth + bodyHeight = bodyWidth * aspectRatio + + if bodyHeight > availableHeight: + bodyHeight = availableHeight + bodyWidth = bodyHeight / aspectRatio + + else: + bodyWidth = width or availableWidth + bodyHeight = height or availableHeight + + return qt.QRectF(xOffset, + yOffset, + bodyWidth, + bodyHeight) + + def _setPrintConfiguration(self): + """Open a dialog to prompt the user to adjust print + geometry parameters.""" + self.printPreviewDialog.ensurePrinterIsSet() + if self._printConfigurationDialog is None: + self._printConfigurationDialog = PrintGeometryDialog(self.parent()) + + self._printConfigurationDialog.setPrintGeometry(self._printGeometry) + if self._printConfigurationDialog.exec(): + self._printGeometry = self._printConfigurationDialog.getPrintGeometry() + + def _getPlotAspectRatio(self): + widget = self._plot.centralWidget() + graphWidth = float(widget.width()) + graphHeight = float(widget.height()) + return graphHeight / graphWidth + + +class SingletonPrintPreviewToolButton(PrintPreviewToolButton): + """This class is similar to its parent class :class:`PrintPreviewToolButton` + but it uses a singleton print preview widget. + + This allows for several plots to send their content to the + same print page, and for users to arrange them.""" + def __init__(self, parent=None, plot=None): + PrintPreviewToolButton.__init__(self, parent, plot) + + @property + def printPreviewDialog(self): + if self._printPreviewDialog is None: + self._printPreviewDialog = SingletonPrintPreviewDialog(self.parent()) + return self._printPreviewDialog + + +if __name__ == '__main__': + import numpy + app = qt.QApplication([]) + + pw = PlotWidget() + toolbar = qt.QToolBar(pw) + toolbutton = PrintPreviewToolButton(parent=toolbar, + plot=pw) + pw.addToolBar(toolbar) + toolbar.addWidget(toolbutton) + pw.show() + + x = numpy.arange(1000) + y = x / numpy.sin(x) + pw.addCurve(x, y) + + app.exec() diff --git a/src/silx/gui/plot/Profile.py b/src/silx/gui/plot/Profile.py new file mode 100644 index 0000000..7565155 --- /dev/null +++ b/src/silx/gui/plot/Profile.py @@ -0,0 +1,352 @@ +# coding: utf-8 +# /*########################################################################## +# +# Copyright (c) 2004-2021 European Synchrotron Radiation Facility +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. +# +# ###########################################################################*/ +"""Utility functions, toolbars and actions to create profile on images +and stacks of images""" + + +__authors__ = ["V.A. Sole", "T. Vincent", "P. Knobel", "H. Payno"] +__license__ = "MIT" +__date__ = "12/04/2019" + + +import weakref + +from .. import qt +from . import actions +from .tools.profile import core +from .tools.profile import manager +from .tools.profile import rois +from silx.gui.widgets.MultiModeAction import MultiModeAction + +from silx.utils.deprecation import deprecated +from silx.utils.deprecation import deprecated_warning +from .tools import roi as roi_mdl +from silx.gui.plot import items + + +@deprecated(replacement="silx.gui.plot.tools.profile.createProfile", since_version="0.13.0") +def createProfile(roiInfo, currentData, origin, scale, lineWidth, method): + return core.createProfile(roiInfo, currentData, origin, + scale, lineWidth, method) + + +class _CustomProfileManager(manager.ProfileManager): + """This custom profile manager uses a single predefined profile window + if it is specified. Else the behavior is the same as the default + ProfileManager """ + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.__profileWindow = None + self.__specializedProfileWindows = {} + + def setSpecializedProfileWindow(self, roiClass, profileWindow): + """Set a profile window for a given class or ROI. + + Setting profileWindow to None removes the roiClass from the list. + + :param roiClass: + :param profileWindow: + """ + if profileWindow is None: + self.__specializedProfileWindows.pop(roiClass, None) + else: + self.__specializedProfileWindows[roiClass] = profileWindow + + def setProfileWindow(self, profileWindow): + self.__profileWindow = profileWindow + + def createProfileWindow(self, plot, roi): + for roiClass, specializedProfileWindow in self.__specializedProfileWindows.items(): + if isinstance(roi, roiClass): + return specializedProfileWindow + + if self.__profileWindow is not None: + return self.__profileWindow + else: + return super(_CustomProfileManager, self).createProfileWindow(plot, roi) + + def clearProfileWindow(self, profileWindow): + for specializedProfileWindow in self.__specializedProfileWindows.values(): + if profileWindow is specializedProfileWindow: + profileWindow.setProfile(None) + return + + if self.__profileWindow is not None: + self.__profileWindow.setProfile(None) + else: + return super(_CustomProfileManager, self).clearProfileWindow(profileWindow) + + +class ProfileToolBar(qt.QToolBar): + """QToolBar providing profile tools operating on a :class:`PlotWindow`. + + Attributes: + + - plot: Associated :class:`PlotWindow` on which the profile line is drawn. + - actionGroup: :class:`QActionGroup` of available actions. + + To run the following sample code, a QApplication must be initialized. + First, create a PlotWindow and add a :class:`ProfileToolBar`. + + >>> from silx.gui.plot import PlotWindow + >>> from silx.gui.plot.Profile import ProfileToolBar + + >>> plot = PlotWindow() # Create a PlotWindow + >>> toolBar = ProfileToolBar(plot=plot) # Create a profile toolbar + >>> plot.addToolBar(toolBar) # Add it to plot + >>> plot.show() # To display the PlotWindow with the profile toolbar + + :param plot: :class:`PlotWindow` instance on which to operate. + :param profileWindow: Plot widget instance where to + display the profile curve or None to create one. + :param str title: See :class:`QToolBar`. + :param parent: See :class:`QToolBar`. + """ + + def __init__(self, parent=None, plot=None, profileWindow=None, + title=None): + super(ProfileToolBar, self).__init__(title, parent) + assert plot is not None + + if title is not None: + deprecated_warning("Attribute", + name="title", + reason="removed", + since_version="0.13.0", + only_once=True, + skip_backtrace_count=1) + + self._plotRef = weakref.ref(plot) + + # If a profileWindow is defined, + # It will be used to display all the profiles + self._manager = self.createProfileManager(self, plot) + self._manager.setProfileWindow(profileWindow) + self._manager.setDefaultColorFromCursorColor(True) + self._manager.setItemType(image=True) + self._manager.setActiveItemTracking(True) + + # Actions + self._browseAction = actions.mode.ZoomModeAction(plot, parent=self) + self._browseAction.setVisible(False) + self.freeLineAction = None + self._createProfileActions() + self._editor = self._manager.createEditorAction(self) + + # ActionGroup + self.actionGroup = qt.QActionGroup(self) + self.actionGroup.addAction(self._browseAction) + self.actionGroup.addAction(self.hLineAction) + self.actionGroup.addAction(self.vLineAction) + self.actionGroup.addAction(self.lineAction) + self.actionGroup.addAction(self._editor) + + modes = MultiModeAction(self) + modes.addAction(self.hLineAction) + modes.addAction(self.vLineAction) + modes.addAction(self.lineAction) + if self.freeLineAction is not None: + modes.addAction(self.freeLineAction) + modes.addAction(self.crossAction) + self.__multiAction = modes + + # Add actions to ToolBar + self.addAction(self._browseAction) + self.addAction(modes) + self.addAction(self._editor) + self.addAction(self.clearAction) + + plot.sigActiveImageChanged.connect(self._activeImageChanged) + self._activeImageChanged() + + def createProfileManager(self, parent, plot): + return _CustomProfileManager(parent, plot) + + def _createProfileActions(self): + self.hLineAction = self._manager.createProfileAction(rois.ProfileImageHorizontalLineROI, self) + self.vLineAction = self._manager.createProfileAction(rois.ProfileImageVerticalLineROI, self) + self.lineAction = self._manager.createProfileAction(rois.ProfileImageLineROI, self) + self.freeLineAction = self._manager.createProfileAction(rois.ProfileImageDirectedLineROI, self) + self.crossAction = self._manager.createProfileAction(rois.ProfileImageCrossROI, self) + self.clearAction = self._manager.createClearAction(self) + + def getPlotWidget(self): + """The :class:`.PlotWidget` associated to the toolbar.""" + return self._plotRef() + + @property + @deprecated(since_version="0.13.0", replacement="getPlotWidget()") + def plot(self): + return self.getPlotWidget() + + def _setRoiActionEnabled(self, itemKind, enabled): + for action in self.__multiAction.getMenu().actions(): + if not isinstance(action, roi_mdl.CreateRoiModeAction): + continue + roiClass = action.getRoiClass() + if issubclass(itemKind, roiClass.ITEM_KIND): + action.setEnabled(enabled) + + def _activeImageChanged(self, previous=None, legend=None): + """Handle active image change to toggle actions""" + if legend is None: + self._setRoiActionEnabled(items.ImageStack, False) + self._setRoiActionEnabled(items.ImageBase, False) + else: + plot = self.getPlotWidget() + image = plot.getActiveImage() + # Disable for empty image + enabled = image.getData(copy=False).size > 0 + self._setRoiActionEnabled(type(image), enabled) + + @property + @deprecated(since_version="0.6.0") + def browseAction(self): + return self._browseAction + + @property + @deprecated(replacement="getProfilePlot", since_version="0.5.0") + def profileWindow(self): + return self.getProfilePlot() + + def getProfileManager(self): + """Return the manager of the profiles. + + :rtype: ProfileManager + """ + return self._manager + + @deprecated(since_version="0.13.0") + def getProfilePlot(self): + """Return plot widget in which the profile curve or the + profile image is plotted. + """ + window = self.getProfileMainWindow() + if window is None: + return None + return window.getCurrentPlotWidget() + + @deprecated(replacement="getProfileManager().getCurrentRoi().getProfileWindow()", since_version="0.13.0") + def getProfileMainWindow(self): + """Return window containing the profile curve widget. + + This can return None if no profile was computed. + """ + roi = self._manager.getCurrentRoi() + if roi is None: + return None + return roi.getProfileWindow() + + @property + @deprecated(since_version="0.13.0") + def overlayColor(self): + """This method does nothing anymore. But could be implemented if needed. + + It was used to set color to use for the ROI. + + If set to None (the default), the overlay color is adapted to the + active image colormap and changes if the active image colormap changes. + """ + pass + + @overlayColor.setter + @deprecated(since_version="0.13.0") + def overlayColor(self, color): + """This method does nothing anymore. But could be implemented if needed. + """ + pass + + def clearProfile(self): + """Remove profile curve and profile area.""" + self._manager.clearProfile() + + @deprecated(since_version="0.13.0") + def updateProfile(self): + """This method does nothing anymore. But could be implemented if needed. + + It was used to update the displayed profile and profile ROI. + + This uses the current active image of the plot and the current ROI. + """ + pass + + @deprecated(replacement="clearProfile()", since_version="0.13.0") + def hideProfileWindow(self): + """Hide profile window. + """ + self.clearProfile() + + @deprecated(since_version="0.13.0") + def setProfileMethod(self, method): + assert method in ('sum', 'mean') + roi = self._manager.getCurrentRoi() + if roi is None: + raise RuntimeError("No profile ROI selected") + roi.setProfileMethod(method) + + @deprecated(since_version="0.13.0") + def getProfileMethod(self): + roi = self._manager.getCurrentRoi() + if roi is None: + raise RuntimeError("No profile ROI selected") + return roi.getProfileMethod() + + @deprecated(since_version="0.13.0") + def getProfileOptionToolAction(self): + return self._editor + + +class Profile3DToolBar(ProfileToolBar): + def __init__(self, parent=None, stackview=None, + title=None): + """QToolBar providing profile tools for an image or a stack of images. + + :param parent: the parent QWidget + :param stackview: :class:`StackView` instance on which to operate. + :param str title: See :class:`QToolBar`. + :param parent: See :class:`QToolBar`. + """ + # TODO: add param profileWindow (specify the plot used for profiles) + super(Profile3DToolBar, self).__init__(parent=parent, + plot=stackview.getPlotWidget()) + + if title is not None: + deprecated_warning("Attribute", + name="title", + reason="removed", + since_version="0.13.0", + only_once=True, + skip_backtrace_count=1) + + self.stackView = stackview + """:class:`StackView` instance""" + + def _createProfileActions(self): + self.hLineAction = self._manager.createProfileAction(rois.ProfileImageStackHorizontalLineROI, self) + self.vLineAction = self._manager.createProfileAction(rois.ProfileImageStackVerticalLineROI, self) + self.lineAction = self._manager.createProfileAction(rois.ProfileImageStackLineROI, self) + self.crossAction = self._manager.createProfileAction(rois.ProfileImageStackCrossROI, self) + self.clearAction = self._manager.createClearAction(self) diff --git a/src/silx/gui/plot/ProfileMainWindow.py b/src/silx/gui/plot/ProfileMainWindow.py new file mode 100644 index 0000000..ce56cfd --- /dev/null +++ b/src/silx/gui/plot/ProfileMainWindow.py @@ -0,0 +1,110 @@ +# coding: utf-8 +# /*########################################################################## +# +# Copyright (c) 2017-2020 European Synchrotron Radiation Facility +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. +# +# ###########################################################################*/ +"""This module contains a QMainWindow class used to display profile plots. +""" + +__authors__ = ["P. Knobel"] +__license__ = "MIT" +__date__ = "21/02/2017" + +import silx.utils.deprecation +from silx.gui import qt +from .tools.profile.manager import ProfileWindow + +silx.utils.deprecation.deprecated_warning("Module", + name="silx.gui.plot.ProfileMainWindow", + reason="moved", + replacement="silx.gui.plot.tools.profile.manager.ProfileWindow", + since_version="0.13.0", + only_once=True, + skip_backtrace_count=1) + +class ProfileMainWindow(ProfileWindow): + """QMainWindow providing 2 plot widgets specialized in + 1D and 2D plotting, with different toolbars. + + Only one of the plots is visible at any given time. + + :param qt.QWidget parent: The parent of this widget or None (default). + :param Union[str,Class] backend: The backend to use, in: + 'matplotlib' (default), 'mpl', 'opengl', 'gl', 'none' + or a :class:`BackendBase.BackendBase` class + """ + + sigProfileDimensionsChanged = qt.Signal(int) + """This signal is emitted when :meth:`setProfileDimensions` is called. + It carries the number of dimensions for the profile data (1 or 2). + It can be used to be notified that the profile plot widget has changed. + + Note: This signal should be removed. + """ + + sigProfileMethodChanged = qt.Signal(str) + """Emitted when the method to compute the profile changed (for now can be + sum or mean) + + Note: This signal should be removed. + """ + + def __init__(self, parent=None, backend=None): + ProfileWindow.__init__(self, parent=parent, backend=backend) + # by default, profile is assumed to be a 1D curve + self._profileType = None + + def setProfileType(self, profileType): + """Set which profile plot widget (1D or 2D) is to be used + + Note: This method should be removed. + + :param str profileType: Type of profile data, + "1D" for a curve or "2D" for an image + """ + self._profileType = profileType + if self._profileType == "1D": + self._showPlot1D() + elif self._profileType == "2D": + self._showPlot2D() + else: + raise ValueError("Profile type must be '1D' or '2D'") + self.sigProfileDimensionsChanged.emit(profileType) + + def getPlot(self): + """Return the profile plot widget which is currently in use. + This can be the 2D profile plot or the 1D profile plot. + + Note: This method should be removed. + """ + return self.getCurrentPlotWidget() + + def setProfileMethod(self, method): + """ + Note: This method should be removed. + + :param str method: method to manage the 'width' in the profile + (computing mean or sum). + """ + assert method in ('sum', 'mean') + self._method = method + self.sigProfileMethodChanged.emit(self._method) diff --git a/src/silx/gui/plot/ROIStatsWidget.py b/src/silx/gui/plot/ROIStatsWidget.py new file mode 100644 index 0000000..32a1395 --- /dev/null +++ b/src/silx/gui/plot/ROIStatsWidget.py @@ -0,0 +1,780 @@ +# coding: utf-8 +# /*########################################################################## +# +# Copyright (c) 2016-2021 European Synchrotron Radiation Facility +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. +# +# ###########################################################################*/ +"""This module provides widget for displaying statistics relative to a +Region of interest and an item +""" + + +__authors__ = ["H. Payno"] +__license__ = "MIT" +__date__ = "22/07/2019" + + +from contextlib import contextmanager +from silx.gui import qt +from silx.gui import icons +from silx.gui.plot.StatsWidget import _StatsWidgetBase, StatsTable, _Container +from silx.gui.plot.StatsWidget import UpdateModeWidget, UpdateMode +from silx.gui.widgets.TableWidget import TableWidget +from silx.gui.plot.items.roi import RegionOfInterest +from silx.gui.plot import items as plotitems +from silx.gui.plot.items.core import ItemChangedType +from silx.gui.plot3d import items as plot3ditems +from silx.gui.plot.CurvesROIWidget import ROI +from silx.gui.plot import stats as statsmdl +from collections import OrderedDict +from silx.utils.proxy import docstring +import silx.gui.plot.items.marker +import silx.gui.plot.items.shape +import functools +import logging + +_logger = logging.getLogger(__name__) + + +class _GetROIItemCoupleDialog(qt.QDialog): + """ + Dialog used to know which plot item and which roi he wants + """ + _COMPATIBLE_KINDS = ('curve', 'image', 'scatter', 'histogram') + + def __init__(self, parent=None, plot=None, rois=None): + qt.QDialog.__init__(self, parent=parent) + assert plot is not None + assert rois is not None + self._plot = plot + self._rois = rois + + self.setLayout(qt.QVBoxLayout()) + + # define the selection widget + self._selection_widget = qt.QWidget() + self._selection_widget.setLayout(qt.QHBoxLayout()) + self._kindCB = qt.QComboBox(parent=self) + self._selection_widget.layout().addWidget(self._kindCB) + self._itemCB = qt.QComboBox(parent=self) + self._selection_widget.layout().addWidget(self._itemCB) + self._roiCB = qt.QComboBox(parent=self) + self._selection_widget.layout().addWidget(self._roiCB) + self.layout().addWidget(self._selection_widget) + + # define modal buttons + types = qt.QDialogButtonBox.Ok | qt.QDialogButtonBox.Cancel + self._buttonsModal = qt.QDialogButtonBox(parent=self) + self._buttonsModal.setStandardButtons(types) + self.layout().addWidget(self._buttonsModal) + self._buttonsModal.accepted.connect(self.accept) + self._buttonsModal.rejected.connect(self.reject) + + # connect signal / slot + self._kindCB.currentIndexChanged.connect(self._updateValidItemAndRoi) + + def _getCompatibleRois(self, kind): + """Return compatible rois for the given item kind""" + def is_compatible(roi, kind): + if isinstance(roi, RegionOfInterest): + return kind in ('image', 'scatter') + elif isinstance(roi, ROI): + return kind in ('curve', 'histogram') + else: + raise ValueError('kind not managed') + return list(filter(lambda x: is_compatible(x, kind), self._rois)) + + def exec(self): + self._kindCB.clear() + self._itemCB.clear() + # filter kind without any items + self._valid_kinds = {} + # key is item type, value kinds + self._valid_rois = {} + # key is item type, value rois + self._kind_name_to_roi = {} + # key is (kind, roi name) value is roi + self._kind_name_to_item = {} + # key is (kind, legend name) value is item + for kind in _GetROIItemCoupleDialog._COMPATIBLE_KINDS: + def getItems(kind): + output = [] + for item in self._plot.getItems(): + type_ = self._plot._itemKind(item) + if type_ in kind and item.isVisible(): + output.append(item) + return output + + items = getItems(kind=kind) + rois = self._getCompatibleRois(kind=kind) + if len(items) > 0 and len(rois) > 0: + self._valid_kinds[kind] = items + self._valid_rois[kind] = rois + for roi in rois: + name = roi.getName() + self._kind_name_to_roi[(kind, name)] = roi + for item in items: + self._kind_name_to_item[(kind, item.getLegend())] = item + + # filter roi according to kinds + if len(self._valid_kinds) == 0: + _logger.warning('no couple item/roi detected for displaying stats') + return self.reject() + + for kind in self._valid_kinds: + self._kindCB.addItem(kind) + self._updateValidItemAndRoi() + + return qt.QDialog.exec(self) + + def exec_(self): # Qt5 compatibility + return self.exec() + + def _updateValidItemAndRoi(self, *args, **kwargs): + self._itemCB.clear() + self._roiCB.clear() + kind = self._kindCB.currentText() + for roi in self._valid_rois[kind]: + self._roiCB.addItem(roi.getName()) + for item in self._valid_kinds[kind]: + self._itemCB.addItem(item.getLegend()) + + def getROI(self): + kind = self._kindCB.currentText() + roi_name = self._roiCB.currentText() + return self._kind_name_to_roi[(kind, roi_name)] + + def getItem(self): + kind = self._kindCB.currentText() + item_name = self._itemCB.currentText() + return self._kind_name_to_item[(kind, item_name)] + + +class ROIStatsItemHelper(object): + """Item utils to associate a plot item and a roi + + Display on one row statistics regarding the couple + (Item (plot item) / roi). + + :param Item plot_item: item for which we want statistics + :param Union[ROI,RegionOfInterest]: region of interest to use for + statistics. + """ + def __init__(self, plot_item, roi): + self._plot_item = plot_item + self._roi = roi + + @property + def roi(self): + """roi""" + return self._roi + + def roi_name(self): + if isinstance(self._roi, ROI): + return self._roi.getName() + elif isinstance(self._roi, RegionOfInterest): + return self._roi.getName() + else: + raise TypeError('Unmanaged roi type') + + @property + def roi_kind(self): + """roi class""" + return self._roi.__class__ + + # TODO: should call a util function from the wrapper ? + def item_kind(self): + """item kind""" + if isinstance(self._plot_item, plotitems.Curve): + return 'curve' + elif isinstance(self._plot_item, plotitems.ImageData): + return 'image' + elif isinstance(self._plot_item, plotitems.Scatter): + return 'scatter' + elif isinstance(self._plot_item, plotitems.Histogram): + return 'histogram' + elif isinstance(self._plot_item, (plot3ditems.ImageData, + plot3ditems.ScalarField3D)): + return 'image' + elif isinstance(self._plot_item, (plot3ditems.Scatter2D, + plot3ditems.Scatter3D)): + return 'scatter' + + @property + def item_legend(self): + """legend of the plot Item""" + return self._plot_item.getLegend() + + def id_key(self): + """unique key to represent the couple (item, roi)""" + return (self.item_kind(), self.item_legend, self.roi_kind, + self.roi_name()) + + +class _StatsROITable(_StatsWidgetBase, TableWidget): + """ + Table sued to display some statistics regarding a couple (item/roi) + """ + _LEGEND_HEADER_DATA = 'legend' + + _KIND_HEADER_DATA = 'kind' + + _ROI_HEADER_DATA = 'roi' + + sigUpdateModeChanged = qt.Signal(object) + """Signal emitted when the update mode changed""" + + def __init__(self, parent, plot): + TableWidget.__init__(self, parent) + _StatsWidgetBase.__init__(self, statsOnVisibleData=False, + displayOnlyActItem=False) + self.__region_edition_callback = {} + """We need to keep trace of the roi signals connection because + the roi emits the sigChanged during roi edition""" + self._items = {} + self.setRowCount(0) + self.setColumnCount(3) + + # Init headers + headerItem = qt.QTableWidgetItem(self._LEGEND_HEADER_DATA.title()) + headerItem.setData(qt.Qt.UserRole, self._LEGEND_HEADER_DATA) + self.setHorizontalHeaderItem(0, headerItem) + headerItem = qt.QTableWidgetItem(self._KIND_HEADER_DATA.title()) + headerItem.setData(qt.Qt.UserRole, self._KIND_HEADER_DATA) + self.setHorizontalHeaderItem(1, headerItem) + headerItem = qt.QTableWidgetItem(self._ROI_HEADER_DATA.title()) + headerItem.setData(qt.Qt.UserRole, self._ROI_HEADER_DATA) + self.setHorizontalHeaderItem(2, headerItem) + + self.setSortingEnabled(True) + self.setPlot(plot) + + self.__plotItemToItems = {} + """Key is plotItem, values is list of __RoiStatsItemWidget""" + self.__roiToItems = {} + """Key is roi, values is list of __RoiStatsItemWidget""" + self.__roisKeyToRoi = {} + + def add(self, item): + assert isinstance(item, ROIStatsItemHelper) + if item.id_key() in self._items: + _logger.warning("Item %s is already present", item.id_key()) + return None + self._items[item.id_key()] = item + self._addItem(item) + return item + + def _addItem(self, item): + """ + Add a _RoiStatsItemWidget item to the table. + + :param item: + :return: True if successfully added. + """ + if not isinstance(item, ROIStatsItemHelper): + # skipped because also receive all new plot item (Marker...) that + # we don't want to manage in this case. + return + # plotItem = item.getItem() + # roi = item.getROI() + kind = item.item_kind() + if kind not in statsmdl.BASIC_COMPATIBLE_KINDS: + _logger.info("Item has not a supported type: %s", item) + return False + + # register the roi and the kind + self._registerPlotItem(item) + self._registerROI(item) + + # Prepare table items + tableItems = [ + qt.QTableWidgetItem(), # Legend + qt.QTableWidgetItem(), # Kind + qt.QTableWidgetItem()] # roi + + for column in range(3, self.columnCount()): + header = self.horizontalHeaderItem(column) + name = header.data(qt.Qt.UserRole) + + formatter = self._statsHandler.formatters[name] + if formatter: + tableItem = formatter.tabWidgetItemClass() + else: + tableItem = qt.QTableWidgetItem() + + tooltip = self._statsHandler.stats[name].getToolTip(kind=kind) + if tooltip is not None: + tableItem.setToolTip(tooltip) + + tableItems.append(tableItem) + + # Disable sorting while adding table items + with self._disableSorting(): + # Add a row to the table + self.setRowCount(self.rowCount() + 1) + + # Add table items to the last row + row = self.rowCount() - 1 + for column, tableItem in enumerate(tableItems): + tableItem.setData(qt.Qt.UserRole, _Container(item)) + tableItem.setFlags( + qt.Qt.ItemIsEnabled | qt.Qt.ItemIsSelectable) + self.setItem(row, column, tableItem) + + # Update table items content + self._updateStats(item, data_changed=True) + + # Listen for item changes + # Using queued connection to avoid issue with sender + # being that of the signal calling the signal + item._plot_item.sigItemChanged.connect(self._plotItemChanged, + qt.Qt.QueuedConnection) + return True + + def _removeAllItems(self): + for row in range(self.rowCount()): + tableItem = self.item(row, 0) + # item = self._tableItemToItem(tableItem) + # item.sigItemChanged.disconnect(self._plotItemChanged) + self.clearContents() + self.setRowCount(0) + + def clear(self): + self._removeAllItems() + + def setStats(self, statsHandler): + """Set which stats to display and the associated formatting. + + :param StatsHandler statsHandler: + Set the statistics to be displayed and how to format them using + """ + self._removeAllItems() + _StatsWidgetBase.setStats(self, statsHandler) + + self.setRowCount(0) + self.setColumnCount(len(self._statsHandler.stats) + 3) # + legend, kind and roi # noqa + + for index, stat in enumerate(self._statsHandler.stats.values()): + headerItem = qt.QTableWidgetItem(stat.name.capitalize()) + headerItem.setData(qt.Qt.UserRole, stat.name) + if stat.description is not None: + headerItem.setToolTip(stat.description) + self.setHorizontalHeaderItem(3 + index, headerItem) + + horizontalHeader = self.horizontalHeader() + horizontalHeader.setSectionResizeMode(qt.QHeaderView.ResizeToContents) + + self._updateItemObserve() + + def _updateItemObserve(self, *args): + pass + + def _dataChanged(self, item): + pass + + def _updateStats(self, item, data_changed=False, roi_changed=False): + assert isinstance(item, ROIStatsItemHelper) + plotItem = item._plot_item + roi = item._roi + if item is None: + return + plot = self.getPlot() + if plot is None: + _logger.info("Plot not available") + return + + row = self._itemToRow(item) + if row is None: + _logger.error("This item is not in the table: %s", str(item)) + return + + statsHandler = self.getStatsHandler() + if statsHandler is not None: + stats = statsHandler.calculate(plotItem, plot, + onlimits=self._statsOnVisibleData, + roi=roi, data_changed=data_changed, + roi_changed=roi_changed) + else: + stats = {} + + with self._disableSorting(): + for name, tableItem in self._itemToTableItems(item).items(): + if name == self._LEGEND_HEADER_DATA: + text = self._plotWrapper.getLabel(plotItem) + tableItem.setText(text) + elif name == self._KIND_HEADER_DATA: + tableItem.setText(self._plotWrapper.getKind(plotItem)) + elif name == self._ROI_HEADER_DATA: + name = roi.getName() + tableItem.setText(name) + else: + value = stats.get(name) + if value is None: + _logger.error("Value not found for: %s", name) + tableItem.setText('-') + else: + tableItem.setText(str(value)) + + @contextmanager + def _disableSorting(self): + """Context manager that disables table sorting + + Previous state is restored when leaving + """ + sorting = self.isSortingEnabled() + if sorting: + self.setSortingEnabled(False) + yield + if sorting: + self.setSortingEnabled(sorting) + + def _itemToRow(self, item): + """Find the row corresponding to a plot item + + :param item: The plot item + :return: The corresponding row index + :rtype: Union[int,None] + """ + for row in range(self.rowCount()): + tableItem = self.item(row, 0) + if self._tableItemToItem(tableItem) == item: + return row + return None + + def _tableItemToItem(self, tableItem): + """Find the plot item corresponding to a table item + + :param QTableWidgetItem tableItem: + :rtype: QObject + """ + container = tableItem.data(qt.Qt.UserRole) + return container() + + def _itemToTableItems(self, item): + """Find all table items corresponding to a plot item + + :param item: The plot item + :return: An ordered dict of column name to QTableWidgetItem mapping + for the given plot item. + :rtype: OrderedDict + """ + result = OrderedDict() + row = self._itemToRow(item) + if row is not None: + for column in range(self.columnCount()): + tableItem = self.item(row, column) + if self._tableItemToItem(tableItem) != item: + _logger.error("Table item/plot item mismatch") + else: + header = self.horizontalHeaderItem(column) + name = header.data(qt.Qt.UserRole) + result[name] = tableItem + return result + + def _plotItemToItems(self, plotItem): + """Return all _RoiStatsItemWidget associated to the plotItem + Needed for updating on itemChanged signal + """ + if plotItem in self.__plotItemToItems: + return [] + else: + return self.__plotItemToItems[plotItem] + + def _registerPlotItem(self, item): + if item._plot_item not in self.__plotItemToItems: + self.__plotItemToItems[item._plot_item] = set() + self.__plotItemToItems[item._plot_item].add(item) + + def _roiToItems(self, roi): + """Return all _RoiStatsItemWidget associated to the roi + Needed for updating on roiChanged signal + """ + if roi in self.__roiToItems: + return [] + else: + return self.__roiToItems[roi] + + def _registerROI(self, item): + if item._roi not in self.__roiToItems: + self.__roiToItems[item._roi] = set() + # TODO: normalize also sig name + if isinstance(item._roi, RegionOfInterest): + # item connection within sigRegionChanged should only be + # stopped during the region edition + self.__region_edition_callback[item._roi] = functools.partial( + self._updateAllStats, False, True) + item._roi.sigRegionChanged.connect(self.__region_edition_callback[item._roi]) + item._roi.sigEditingStarted.connect(functools.partial( + self._startFiltering, item._roi)) + item._roi.sigEditingFinished.connect(functools.partial( + self._endFiltering, item._roi)) + else: + item._roi.sigChanged.connect(functools.partial( + self._updateAllStats, False, True)) + self.__roiToItems[item._roi].add(item) + + def _startFiltering(self, roi): + roi.sigRegionChanged.disconnect(self.__region_edition_callback[roi]) + + def _endFiltering(self, roi): + roi.sigRegionChanged.connect(self.__region_edition_callback[roi]) + self._updateAllStats(roi_changed=True) + + def unregisterROI(self, roi): + if roi in self.__roiToItems: + del self.__roiToItems[roi] + if isinstance(roi, RegionOfInterest): + roi.sigRegionEditionStarted.disconnect(functools.partial( + self._startFiltering, roi)) + roi.sigRegionEditionFinished.disconnect(functools.partial( + self._startFiltering, roi)) + try: + roi.sigRegionChanged.disconnect(self._updateAllStats) + except: + pass + else: + roi.sigChanged.disconnect(self._updateAllStats) + + def _plotItemChanged(self, event): + """Handle modifications of the items. + + :param event: + """ + if event is ItemChangedType.DATA: + if self.getUpdateMode() is UpdateMode.MANUAL: + return + if self._skipPlotItemChangedEvent(event) is True: + return + else: + sender = self.sender() + for item in self.__plotItemToItems[sender]: + # TODO: get all concerned items + self._updateStats(item, data_changed=True) + # deal with stat items visibility + if event is ItemChangedType.VISIBLE: + if len(self._itemToTableItems(item).items()) > 0: + item_0 = list(self._itemToTableItems(item).values())[0] + row_index = item_0.row() + self.setRowHidden(row_index, not item.isVisible()) + + def _removeItem(self, itemKey): + if isinstance(itemKey, (silx.gui.plot.items.marker.Marker, + silx.gui.plot.items.shape.Shape)): + return + if itemKey not in self._items: + _logger.warning('key not recognized. Won\'t remove any item') + return + item = self._items[itemKey] + row = self._itemToRow(item) + if row is None: + kind = self._plotWrapper.getKind(item) + if kind in statsmdl.BASIC_COMPATIBLE_KINDS: + _logger.error("Removing item that is not in table: %s", str(item)) + return + item._plot_item.sigItemChanged.disconnect(self._plotItemChanged) + self.removeRow(row) + del self._items[itemKey] + + def _updateAllStats(self, is_request=False, roi_changed=False): + """Update stats for all rows in the table + + :param bool is_request: True if come from a manual request + """ + if (self.getUpdateMode() is UpdateMode.MANUAL and + not is_request and not roi_changed): + return + + with self._disableSorting(): + for row in range(self.rowCount()): + tableItem = self.item(row, 0) + item = self._tableItemToItem(tableItem) + self._updateStats(item, roi_changed=roi_changed, + data_changed=is_request) + + def _plotCurrentChanged(self, *args): + pass + + def _getRoi(self, kind, name): + """return the roi fitting the requirement kind, name. This information + is enough to be sure it is unique (in the widget)""" + for roi in self.__roiToItems: + roiName = roi.getName() + if isinstance(roi, kind) and name == roiName: + return roi + return None + + def _getPlotItem(self, kind, legend): + """return the plotItem fitting the requirement kind, legend. + This information is enough to be sure it is unique (in the widget)""" + for plotItem in self.__plotItemToItems: + if legend == plotItem.getLegend() and self._plotWrapper.getKind(plotItem) == kind: + return plotItem + return None + + +class ROIStatsWidget(qt.QMainWindow): + """ + Widget used to define stats item for a couple(roi, plotItem). + Stats will be computing on a given item (curve, image...) in the given + region of interest. + + It also provide an interface for adding and removing items. + + .. snapshotqt:: img/ROIStatsWidget.png + :width: 300px + :align: center + + from silx.gui import qt + from silx.gui.plot import Plot2D + from silx.gui.plot.ROIStatsWidget import ROIStatsWidget + from silx.gui.plot.items.roi import RectangleROI + import numpy + plot = Plot2D() + plot.addImage(numpy.arange(10000).reshape(100, 100), legend='img') + plot.show() + rectangleROI = RectangleROI() + rectangleROI.setGeometry(origin=(0, 100), size=(20, 20)) + rectangleROI.setName('Initial ROI') + widget = ROIStatsWidget(plot=plot) + widget.setStats([('sum', numpy.sum), ('mean', numpy.mean)]) + widget.registerROI(rectangleROI) + widget.addItem(roi=rectangleROI, plotItem=plot.getImage('img')) + widget.show() + + :param Union[qt.QWidget,None] parent: parent qWidget + :param PlotWindow plot: plot widget containing the items + :param stats: stats to display + :param tuple rois: tuple of rois to manage + """ + + def __init__(self, parent=None, plot=None, stats=None, rois=None): + qt.QMainWindow.__init__(self, parent) + + toolbar = qt.QToolBar(self) + icon = icons.getQIcon('add') + self._rois = list(rois) if rois is not None else [] + self._addAction = qt.QAction(icon, 'add item/roi', toolbar) + self._addAction.triggered.connect(self._addRoiStatsItem) + icon = icons.getQIcon('rm') + self._removeAction = qt.QAction(icon, 'remove item/roi', toolbar) + self._removeAction.triggered.connect(self._removeCurrentRow) + + toolbar.addAction(self._addAction) + toolbar.addAction(self._removeAction) + self.addToolBar(toolbar) + + self._plot = plot + self._statsROITable = _StatsROITable(parent=self, plot=self._plot) + self.setStats(stats=stats) + self.setCentralWidget(self._statsROITable) + self.setWindowFlags(qt.Qt.Widget) + + # expose API + self._setUpdateMode = self._statsROITable.setUpdateMode + self._updateAllStats = self._statsROITable._updateAllStats + + # setup + self._statsROITable.setSelectionBehavior(qt.QTableWidget.SelectRows) + + def registerROI(self, roi): + """For now there is no direct link between roi and plot. That is why + we need to add/register them to be able to associate them""" + self._rois.append(roi) + + def setPlot(self, plot): + """Define the plot to interact with + + :param Union[PlotWidget,SceneWidget,None] plot: + The plot containing the items on which statistics are applied + """ + self._plot = plot + + def getPlot(self): + return self._plot + + @docstring(_StatsROITable) + def setStats(self, stats): + if stats is not None: + self._statsROITable.setStats(statsHandler=stats) + + @docstring(_StatsROITable) + def getStatsHandler(self): + """ + + :return: + """ + return self._statsROITable.getStatsHandler() + + def _addRoiStatsItem(self): + """Ask the user what couple ROI / item he want to display""" + dialog = _GetROIItemCoupleDialog(parent=self, plot=self._plot, + rois=self._rois) + if dialog.exec(): + self.addItem(roi=dialog.getROI(), plotItem=dialog.getItem()) + + def addItem(self, plotItem, roi): + """ + Add a row of statitstic regarding the couple (plotItem, roi) + + :param Item plotItem: item to use for statistics + :param roi: region of interest to limit the statistic. + :type: Union[ROI, RegionOfInterest] + :return: None of failed to add the item + :rtype: Union[None,ROIStatsItemHelper] + """ + statsItem = ROIStatsItemHelper(roi=roi, plot_item=plotItem) + return self._statsROITable.add(item=statsItem) + + def removeItem(self, plotItem, roi): + """ + Remove the row associated to the couple (plotItem, roi) + + :param Item plotItem: item to use for statistics + :param roi: region of interest to limit the statistic. + :type: Union[ROI,RegionOfInterest] + """ + statsItem = ROIStatsItemHelper(roi=roi, plot_item=plotItem) + self._statsROITable._removeItem(itemKey=statsItem.id_key()) + + def _removeCurrentRow(self): + def is1DKind(kind): + if kind in ('curve', 'histogram', 'scatter'): + return True + else: + return False + + currentRow = self._statsROITable.currentRow() + item_kind = self._statsROITable.item(currentRow, 1).text() + item_legend = self._statsROITable.item(currentRow, 0).text() + + roi_name = self._statsROITable.item(currentRow, 2).text() + roi_kind = ROI if is1DKind(item_kind) else RegionOfInterest + roi = self._statsROITable._getRoi(kind=roi_kind, name=roi_name) + if roi is None: + _logger.warning('failed to retrieve the roi you want to remove') + return False + plot_item = self._statsROITable._getPlotItem(kind=item_kind, + legend=item_legend) + if plot_item is None: + _logger.warning('failed to retrieve the plot item you want to' + 'remove') + return False + return self.removeItem(plotItem=plot_item, roi=roi) diff --git a/src/silx/gui/plot/ScatterMaskToolsWidget.py b/src/silx/gui/plot/ScatterMaskToolsWidget.py new file mode 100644 index 0000000..c242dfc --- /dev/null +++ b/src/silx/gui/plot/ScatterMaskToolsWidget.py @@ -0,0 +1,621 @@ +# coding: utf-8 +# /*########################################################################## +# +# Copyright (c) 2018-2021 European Synchrotron Radiation Facility +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. +# +# ###########################################################################*/ +"""Widget providing a set of tools to draw masks on a PlotWidget. + +This widget is meant to work with a modified :class:`silx.gui.plot.PlotWidget` + +- :class:`ScatterMask`: Handle scatter mask update and history +- :class:`ScatterMaskToolsWidget`: GUI for :class:`ScatterMask` +- :class:`ScatterMaskToolsDockWidget`: DockWidget to integrate in :class:`PlotWindow` +""" + +from __future__ import division + +__authors__ = ["P. Knobel"] +__license__ = "MIT" +__date__ = "15/02/2019" + + +import math +import logging +import os +import numpy +import sys + +from .. import qt +from ...math.combo import min_max +from ...image import shapes + +from .items import ItemChangedType, Scatter +from ._BaseMaskToolsWidget import BaseMask, BaseMaskToolsWidget, BaseMaskToolsDockWidget +from ..colors import cursorColorForColormap, rgba + + +_logger = logging.getLogger(__name__) + + +class ScatterMask(BaseMask): + """A 1D mask for scatter data. + """ + def __init__(self, scatter=None): + """ + + :param scatter: :class:`silx.gui.plot.items.Scatter` instance + """ + BaseMask.__init__(self, scatter) + + def _getXY(self): + x = self._dataItem.getXData(copy=False) + y = self._dataItem.getYData(copy=False) + return x, y + + def getDataValues(self): + """Return scatter data values as a 1D array. + + :rtype: 1D numpy.ndarray + """ + return self._dataItem.getValueData(copy=False) + + def save(self, filename, kind): + if kind == 'npy': + try: + numpy.save(filename, self.getMask(copy=False)) + except IOError: + raise RuntimeError("Mask file can't be written") + elif kind in ["csv", "txt"]: + try: + numpy.savetxt(filename, self.getMask(copy=False)) + except IOError: + raise RuntimeError("Mask file can't be written") + + def updatePoints(self, level, indices, mask=True): + """Mask/Unmask points with given indices. + + :param int level: Mask level to update. + :param indices: Sequence or 1D array of indices of points to be + updated + :param bool mask: True to mask (default), False to unmask. + """ + if mask: + self._mask[indices] = level + else: + # unmask only where mask level is the specified value + indices_stencil = numpy.zeros_like(self._mask, dtype=bool) + indices_stencil[indices] = True + self._mask[numpy.logical_and(self._mask == level, indices_stencil)] = 0 + self._notify() + + # update shapes + def updatePolygon(self, level, vertices, mask=True): + """Mask/Unmask a polygon of the given mask level. + + :param int level: Mask level to update. + :param vertices: Nx2 array of polygon corners as (y, x) or (row, col) + :param bool mask: True to mask (default), False to unmask. + """ + polygon = shapes.Polygon(vertices) + x, y = self._getXY() + + # TODO: this could be optimized if necessary + indices_in_polygon = [idx for idx in range(len(x)) if + polygon.is_inside(y[idx], x[idx])] + + self.updatePoints(level, indices_in_polygon, mask) + + def updateRectangle(self, level, y, x, height, width, mask=True): + """Mask/Unmask data inside a rectangle + + :param int level: Mask level to update. + :param float y: Y coordinate of bottom left corner of the rectangle + :param float x: X coordinate of bottom left corner of the rectangle + :param float height: + :param float width: + :param bool mask: True to mask (default), False to unmask. + """ + vertices = [(y, x), + (y + height, x), + (y + height, x + width), + (y, x + width)] + self.updatePolygon(level, vertices, mask) + + def updateDisk(self, level, cy, cx, radius, mask=True): + """Mask/Unmask a disk of the given mask level. + + :param int level: Mask level to update. + :param float cy: Disk center (y). + :param float cx: Disk center (x). + :param float radius: Radius of the disk in mask array unit + :param bool mask: True to mask (default), False to unmask. + """ + x, y = self._getXY() + stencil = (y - cy)**2 + (x - cx)**2 < radius**2 + self.updateStencil(level, stencil, mask) + + def updateEllipse(self, level, crow, ccol, radius_r, radius_c, mask=True): + """Mask/Unmask an ellipse of the given mask level. + + :param int level: Mask level to update. + :param int crow: Row of the center of the ellipse + :param int ccol: Column of the center of the ellipse + :param float radius_r: Radius of the ellipse in the row + :param float radius_c: Radius of the ellipse in the column + :param bool mask: True to mask (default), False to unmask. + """ + def is_inside(px, py): + return (px - ccol)**2 / radius_c**2 + (py - crow)**2 / radius_r**2 <= 1.0 + x, y = self._getXY() + indices_inside = [idx for idx in range(len(x)) if is_inside(x[idx], y[idx])] + self.updatePoints(level, indices_inside, mask) + + def updateLine(self, level, y0, x0, y1, x1, width, mask=True): + """Mask/Unmask points inside a rectangle defined by a line (two + end points) and a width. + + :param int level: Mask level to update. + :param float y0: Row of the starting point. + :param float x0: Column of the starting point. + :param float row1: Row of the end point. + :param float col1: Column of the end point. + :param float width: Width of the line. + :param bool mask: True to mask (default), False to unmask. + """ + # theta is the angle between the horizontal and the line + theta = math.atan((y1 - y0) / (x1 - x0)) if x1 - x0 else 0 + w_over_2_sin_theta = width / 2. * math.sin(theta) + w_over_2_cos_theta = width / 2. * math.cos(theta) + + vertices = [(y0 - w_over_2_cos_theta, x0 + w_over_2_sin_theta), + (y0 + w_over_2_cos_theta, x0 - w_over_2_sin_theta), + (y1 + w_over_2_cos_theta, x1 - w_over_2_sin_theta), + (y1 - w_over_2_cos_theta, x1 + w_over_2_sin_theta)] + + self.updatePolygon(level, vertices, mask) + + +class ScatterMaskToolsWidget(BaseMaskToolsWidget): + """Widget with tools for masking data points on a scatter in a + :class:`PlotWidget`.""" + + def __init__(self, parent=None, plot=None): + super(ScatterMaskToolsWidget, self).__init__(parent, plot, + mask=ScatterMask()) + self._z = 2 # Mask layer in plot + self._data_scatter = None + """plot Scatter item for data""" + + self._data_extent = None + """Maximum extent of the data i.e., max(xMax-xMin, yMax-yMin)""" + + self._mask_scatter = None + """plot Scatter item for representing the mask""" + + def setSelectionMask(self, mask, copy=True): + """Set the mask to a new array. + + :param numpy.ndarray mask: + The array to use for the mask or None to reset the mask. + :type mask: numpy.ndarray of uint8, C-contiguous. + Array of other types are converted. + :param bool copy: True (the default) to copy the array, + False to use it as is if possible. + :return: None if failed, shape of mask as 1-tuple if successful. + The mask can be cropped or padded to fit active scatter, + the returned shape is that of the scatter data. + """ + if self._data_scatter is None: + # this can happen if the mask tools widget has never been shown + self._data_scatter = self.plot._getActiveItem(kind="scatter") + if self._data_scatter is None: + return None + self._adjustColorAndBrushSize(self._data_scatter) + + if mask is None: + self.resetSelectionMask() + return self._data_scatter.getXData(copy=False).shape + + mask = numpy.array(mask, copy=False, dtype=numpy.uint8) + + if self._data_scatter.getXData(copy=False).shape == (0,) \ + or mask.shape == self._data_scatter.getXData(copy=False).shape: + self._mask.setMask(mask, copy=copy) + self._mask.commit() + return mask.shape + else: + raise ValueError("Mask does not have the same shape as the data") + + # Handle mask refresh on the plot + + def _updatePlotMask(self): + """Update mask image in plot""" + mask = self.getSelectionMask(copy=False) + if mask is not None: + self.plot.addScatter(self._data_scatter.getXData(), + self._data_scatter.getYData(), + mask, + legend=self._maskName, + colormap=self._colormap, + z=self._z) + self._mask_scatter = self.plot._getItem(kind="scatter", + legend=self._maskName) + self._mask_scatter.setSymbolSize( + self._data_scatter.getSymbolSize() + 2.0) + self._mask_scatter.sigItemChanged.connect(self.__maskScatterChanged) + elif self.plot._getItem(kind="scatter", + legend=self._maskName) is not None: + self.plot.remove(self._maskName, kind='scatter') + + def __maskScatterChanged(self, event): + """Handles update of mask scatter""" + if (event is ItemChangedType.VISUALIZATION_MODE and + self._mask_scatter is not None): + self._mask_scatter.setVisualization(Scatter.Visualization.POINTS) + + # track widget visibility and plot active image changes + + def showEvent(self, event): + try: + self.plot.sigActiveScatterChanged.disconnect( + self._activeScatterChangedAfterCare) + except (RuntimeError, TypeError): + pass + self._activeScatterChanged(None, None) # Init mask + enable/disable widget + self.plot.sigActiveScatterChanged.connect(self._activeScatterChanged) + + def hideEvent(self, event): + try: + # if the method is not connected this raises a TypeError and there is no way + # to know the connected slots + self.plot.sigActiveScatterChanged.disconnect(self._activeScatterChanged) + except (RuntimeError, TypeError): + _logger.info(sys.exc_info()[1]) + if not self.browseAction.isChecked(): + self.browseAction.trigger() # Disable drawing tool + + if self.getSelectionMask(copy=False) is not None: + self.plot.sigActiveScatterChanged.connect( + self._activeScatterChangedAfterCare) + + def _adjustColorAndBrushSize(self, activeScatter): + colormap = activeScatter.getColormap() + self._defaultOverlayColor = rgba(cursorColorForColormap(colormap['name'])) + self._setMaskColors(self.levelSpinBox.value(), + self.transparencySlider.value() / + self.transparencySlider.maximum()) + self._z = activeScatter.getZValue() + 1 + self._data_scatter = activeScatter + + # Adjust brush size to data range + xData = self._data_scatter.getXData(copy=False) + yData = self._data_scatter.getYData(copy=False) + # Adjust brush size to data range + if xData.size > 0 and yData.size > 0: + xMin, xMax = min_max(xData) + yMin, yMax = min_max(yData) + self._data_extent = max(xMax - xMin, yMax - yMin) + else: + self._data_extent = None + + def _activeScatterChangedAfterCare(self, previous, next): + """Check synchro of active scatter and mask when mask widget is hidden. + + If active image has no more the same size as the mask, the mask is + removed, otherwise it is adjusted to z. + """ + # check that content changed was the active scatter + activeScatter = self.plot._getActiveItem(kind="scatter") + + if activeScatter is None or activeScatter.getName() == self._maskName: + # No active scatter or active scatter is the mask... + self.plot.sigActiveScatterChanged.disconnect( + self._activeScatterChangedAfterCare) + self._data_extent = None + self._data_scatter = None + + else: + self._adjustColorAndBrushSize(activeScatter) + + if self._data_scatter.getXData(copy=False).shape != self._mask.getMask(copy=False).shape: + # scatter has not the same size, remove mask and stop listening + if self.plot._getItem(kind="scatter", legend=self._maskName): + self.plot.remove(self._maskName, kind='scatter') + + self.plot.sigActiveScatterChanged.disconnect( + self._activeScatterChangedAfterCare) + self._data_extent = None + self._data_scatter = None + + else: + # Refresh in case z changed + self._mask.setDataItem(self._data_scatter) + self._updatePlotMask() + + def _activeScatterChanged(self, previous, next): + """Update widget and mask according to active scatter changes""" + activeScatter = self.plot._getActiveItem(kind="scatter") + + if activeScatter is None or activeScatter.getName() == self._maskName: + # No active scatter or active scatter is the mask... + self.setEnabled(False) + + self._data_scatter = None + self._data_extent = None + self._mask.reset() + self._mask.commit() + + else: # There is an active scatter + self.setEnabled(True) + self._adjustColorAndBrushSize(activeScatter) + + self._mask.setDataItem(self._data_scatter) + if self._data_scatter.getXData(copy=False).shape != self._mask.getMask(copy=False).shape: + self._mask.reset(self._data_scatter.getXData(copy=False).shape) + self._mask.commit() + else: + # Refresh in case z changed + self._updatePlotMask() + + self._updateInteractiveMode() + + # Handle whole mask operations + + def load(self, filename): + """Load a mask from an image file. + + :param str filename: File name from which to load the mask + :raise Exception: An exception in case of failure + :raise RuntimeWarning: In case the mask was applied but with some + import changes to notice + """ + _, extension = os.path.splitext(filename) + extension = extension.lower()[1:] + if extension == "npy": + try: + mask = numpy.load(filename) + except IOError: + _logger.error("Can't load filename '%s'", filename) + _logger.debug("Backtrace", exc_info=True) + raise RuntimeError('File "%s" is not a numpy file.', + filename) + elif extension in ["txt", "csv"]: + try: + mask = numpy.loadtxt(filename) + except IOError: + _logger.error("Can't load filename '%s'", filename) + _logger.debug("Backtrace", exc_info=True) + raise RuntimeError('File "%s" is not a numpy txt file.', + filename) + else: + msg = "Extension '%s' is not supported." + raise RuntimeError(msg % extension) + + self.setSelectionMask(mask, copy=False) + + def _loadMask(self): + """Open load mask dialog""" + dialog = qt.QFileDialog(self) + dialog.setWindowTitle("Load Mask") + dialog.setModal(1) + filters = [ + 'NumPy binary file (*.npy)', + 'CSV text file (*.csv)', + ] + dialog.setNameFilters(filters) + dialog.setFileMode(qt.QFileDialog.ExistingFile) + dialog.setDirectory(self.maskFileDir) + if not dialog.exec(): + dialog.close() + return + + filename = dialog.selectedFiles()[0] + dialog.close() + + # Update the directory according to the user selection + self.maskFileDir = os.path.dirname(filename) + + try: + self.load(filename) + # except RuntimeWarning as e: + # message = e.args[0] + # msg = qt.QMessageBox(self) + # msg.setIcon(qt.QMessageBox.Warning) + # msg.setText("Mask loaded but an operation was applied.\n" + message) + # msg.exec() + except Exception as e: + message = e.args[0] + msg = qt.QMessageBox(self) + msg.setIcon(qt.QMessageBox.Critical) + msg.setText("Cannot load mask from file. " + message) + msg.exec() + + def _saveMask(self): + """Open Save mask dialog""" + dialog = qt.QFileDialog(self) + dialog.setWindowTitle("Save Mask") + dialog.setModal(1) + filters = [ + 'NumPy binary file (*.npy)', + 'CSV text file (*.csv)', + ] + dialog.setNameFilters(filters) + dialog.setFileMode(qt.QFileDialog.AnyFile) + dialog.setAcceptMode(qt.QFileDialog.AcceptSave) + dialog.setDirectory(self.maskFileDir) + if not dialog.exec(): + dialog.close() + return + + # convert filter name to extension name with the . + extension = dialog.selectedNameFilter().split()[-1][2:-1] + filename = dialog.selectedFiles()[0] + dialog.close() + + if not filename.lower().endswith(extension): + filename += extension + + if os.path.exists(filename): + try: + os.remove(filename) + except IOError as e: + msg = qt.QMessageBox(self) + msg.setWindowTitle("Removing existing file") + msg.setIcon(qt.QMessageBox.Critical) + + if hasattr(e, "strerror"): + strerror = e.strerror + else: + strerror = sys.exc_info()[1] + msg.setText("Cannot save.\n" + "Input Output Error: %s" % strerror) + msg.exec() + return + + # Update the directory according to the user selection + self.maskFileDir = os.path.dirname(filename) + + try: + self.save(filename, extension[1:]) + except Exception as e: + msg = qt.QMessageBox(self) + msg.setWindowTitle("Saving mask file") + msg.setIcon(qt.QMessageBox.Critical) + + if hasattr(e, "strerror"): + strerror = e.strerror + else: + strerror = sys.exc_info()[1] + msg.setText("Cannot save file %s\n%s" % (filename, strerror)) + msg.exec() + + def resetSelectionMask(self): + """Reset the mask""" + self._mask.reset( + shape=self._data_scatter.getXData(copy=False).shape) + self._mask.commit() + + def _getPencilWidth(self): + """Returns the width of the pencil to use in data coordinates` + + :rtype: float + """ + width = super(ScatterMaskToolsWidget, self)._getPencilWidth() + if self._data_extent is not None: + width *= 0.01 * self._data_extent + return width + + def _plotDrawEvent(self, event): + """Handle draw events from the plot""" + if (self._drawingMode is None or + event['event'] not in ('drawingProgress', 'drawingFinished')): + return + + if not len(self._data_scatter.getXData(copy=False)): + return + + level = self.levelSpinBox.value() + + if self._drawingMode == 'rectangle': + if event['event'] == 'drawingFinished': + doMask = self._isMasking() + + self._mask.updateRectangle( + level, + y=event['y'], + x=event['x'], + height=abs(event['height']), + width=abs(event['width']), + mask=doMask) + self._mask.commit() + + elif self._drawingMode == 'ellipse': + if event['event'] == 'drawingFinished': + doMask = self._isMasking() + center = event['points'][0] + size = event['points'][1] + self._mask.updateEllipse(level, center[1], center[0], + size[1], size[0], doMask) + self._mask.commit() + + elif self._drawingMode == 'polygon': + if event['event'] == 'drawingFinished': + doMask = self._isMasking() + vertices = event['points'] + vertices = vertices[:, (1, 0)] # (y, x) + self._mask.updatePolygon(level, vertices, doMask) + self._mask.commit() + + elif self._drawingMode == 'pencil': + doMask = self._isMasking() + # convert from plot to array coords + x, y = event['points'][-1] + + brushSize = self._getPencilWidth() + + if self._lastPencilPos != (y, x): + if self._lastPencilPos is not None: + # Draw the line + self._mask.updateLine( + level, + self._lastPencilPos[0], self._lastPencilPos[1], + y, x, + brushSize, + doMask) + + # Draw the very first, or last point + self._mask.updateDisk(level, y, x, brushSize / 2., doMask) + + if event['event'] == 'drawingFinished': + self._mask.commit() + self._lastPencilPos = None + else: + self._lastPencilPos = y, x + else: + _logger.error("Drawing mode %s unsupported", self._drawingMode) + + def _loadRangeFromColormapTriggered(self): + """Set range from active scatter colormap range""" + if self._data_scatter is not None: + # Update thresholds according to colormap + colormap = self._data_scatter.getColormap() + if colormap['autoscale']: + min_ = numpy.nanmin(self._data_scatter.getValueData(copy=False)) + max_ = numpy.nanmax(self._data_scatter.getValueData(copy=False)) + else: + min_, max_ = colormap['vmin'], colormap['vmax'] + self.minLineEdit.setText(str(min_)) + self.maxLineEdit.setText(str(max_)) + + +class ScatterMaskToolsDockWidget(BaseMaskToolsDockWidget): + """:class:`ScatterMaskToolsWidget` embedded in a QDockWidget. + + For integration in a :class:`PlotWindow`. + + :param parent: See :class:`QDockWidget` + :param plot: The PlotWidget this widget is operating on + :paran str name: The title of this widget + """ + def __init__(self, parent=None, plot=None, name='Mask'): + widget = ScatterMaskToolsWidget(plot=plot) + super(ScatterMaskToolsDockWidget, self).__init__(parent, name, widget) diff --git a/src/silx/gui/plot/ScatterView.py b/src/silx/gui/plot/ScatterView.py new file mode 100644 index 0000000..d3fd2e0 --- /dev/null +++ b/src/silx/gui/plot/ScatterView.py @@ -0,0 +1,404 @@ +# coding: utf-8 +# /*########################################################################## +# +# Copyright (c) 2018-2020 European Synchrotron Radiation Facility +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. +# +# ###########################################################################*/ +"""A widget dedicated to display scatter plots + +It is based on a :class:`~silx.gui.plot.PlotWidget` with additional tools +for scatter plots. +""" + +__authors__ = ["T. Vincent"] +__license__ = "MIT" +__date__ = "14/06/2018" + + +import logging +import weakref + +import numpy + +from . import items +from . import PlotWidget +from . import tools +from .actions import histogram as actions_histogram +from .tools.profile import ScatterProfileToolBar +from .ColorBar import ColorBarWidget +from .ScatterMaskToolsWidget import ScatterMaskToolsWidget + +from ..widgets.BoxLayoutDockWidget import BoxLayoutDockWidget +from .. import qt, icons +from ...utils.proxy import docstring +from ...utils.weakref import WeakMethodProxy + + +_logger = logging.getLogger(__name__) + + +class ScatterView(qt.QMainWindow): + """Main window with a PlotWidget and tools specific for scatter plots. + + :param parent: The parent of this widget + :param backend: The backend to use for the plot (default: matplotlib). + See :class:`~silx.gui.plot.PlotWidget` for the list of supported backend. + :type backend: Union[str,~silx.gui.plot.backends.BackendBase.BackendBase] + """ + + _SCATTER_LEGEND = ' ' + """Legend used for the scatter item""" + + def __init__(self, parent=None, backend=None): + super(ScatterView, self).__init__(parent=parent) + if parent is not None: + # behave as a widget + self.setWindowFlags(qt.Qt.Widget) + else: + self.setWindowTitle('ScatterView') + + # Create plot widget + plot = PlotWidget(parent=self, backend=backend) + self._plot = weakref.ref(plot) + + # Add an empty scatter + self.__createEmptyScatter() + + # Create colorbar widget with white background + self._colorbar = ColorBarWidget(parent=self, plot=plot) + self._colorbar.setAutoFillBackground(True) + palette = self._colorbar.palette() + palette.setColor(qt.QPalette.Window, qt.Qt.white) + self._colorbar.setPalette(palette) + + # Create PositionInfo widget + self.__lastPickingPos = None + self.__pickingCache = None + self._positionInfo = tools.PositionInfo( + plot=plot, + converters=(('X', WeakMethodProxy(self._getPickedX)), + ('Y', WeakMethodProxy(self._getPickedY)), + ('Data', WeakMethodProxy(self._getPickedValue)), + ('Index', WeakMethodProxy(self._getPickedIndex)))) + + # Combine plot, position info and colorbar into central widget + gridLayout = qt.QGridLayout() + gridLayout.setSpacing(0) + gridLayout.setContentsMargins(0, 0, 0, 0) + gridLayout.addWidget(plot, 0, 0) + gridLayout.addWidget(self._colorbar, 0, 1) + gridLayout.addWidget(self._positionInfo, 1, 0, 1, -1) + gridLayout.setRowStretch(0, 1) + gridLayout.setColumnStretch(0, 1) + centralWidget = qt.QWidget(self) + centralWidget.setLayout(gridLayout) + self.setCentralWidget(centralWidget) + + # Create mask tool dock widget + self._maskToolsWidget = ScatterMaskToolsWidget(parent=self, plot=plot) + self._maskDock = BoxLayoutDockWidget() + self._maskDock.setWindowTitle('Scatter Mask') + self._maskDock.setWidget(self._maskToolsWidget) + self._maskDock.setVisible(False) + self.addDockWidget(qt.Qt.BottomDockWidgetArea, self._maskDock) + + self._maskAction = self._maskDock.toggleViewAction() + self._maskAction.setIcon(icons.getQIcon('image-mask')) + self._maskAction.setToolTip("Display/hide mask tools") + + self._intensityHistoAction = actions_histogram.PixelIntensitiesHistoAction(plot=plot, parent=self) + + # Create toolbars + self._interactiveModeToolBar = tools.InteractiveModeToolBar( + parent=self, plot=plot) + + self._scatterToolBar = tools.ScatterToolBar( + parent=self, plot=plot) + self._scatterToolBar.addAction(self._maskAction) + self._scatterToolBar.addAction(self._intensityHistoAction) + + self._profileToolBar = ScatterProfileToolBar(parent=self, plot=plot) + + self._outputToolBar = tools.OutputToolBar(parent=self, plot=plot) + + # Activate shortcuts in PlotWindow widget: + for toolbar in (self._interactiveModeToolBar, + self._scatterToolBar, + self._profileToolBar, + self._outputToolBar): + self.addToolBar(toolbar) + for action in toolbar.actions(): + self.addAction(action) + + + def __createEmptyScatter(self): + """Create an empty scatter item that is used to display the data + + :rtype: ~silx.gui.plot.items.Scatter + """ + plot = self.getPlotWidget() + plot.addScatter(x=(), y=(), value=(), legend=self._SCATTER_LEGEND) + scatter = plot._getItem( + kind='scatter', legend=self._SCATTER_LEGEND) + # Profile is not selectable, + # so it does not interfere with profile interaction + scatter._setSelectable(False) + return scatter + + def _pickScatterData(self, x, y): + """Get data and index and value of top most scatter plot at position (x, y) + + :param float x: X position in plot coordinates + :param float y: Y position in plot coordinates + :return: The data index and value at that point or None + """ + pickingPos = x, y + if self.__lastPickingPos != pickingPos: + self.__pickingCache = None + self.__lastPickingPos = pickingPos + + plot = self.getPlotWidget() + if plot is not None: + pixelPos = plot.dataToPixel(x, y) + if pixelPos is not None: + # Start from top-most item + result = plot._pickTopMost( + pixelPos[0], pixelPos[1], + lambda item: isinstance(item, items.Scatter)) + if result is not None: + item = result.getItem() + if item.getVisualization() is items.Scatter.Visualization.BINNED_STATISTIC: + # Get highest index of closest points + selected = result.getIndices(copy=False)[::-1] + dataIndex = selected[numpy.argmin( + (item.getXData(copy=False)[selected] - x)**2 + + (item.getYData(copy=False)[selected] - y)**2)] + else: + # Get last index + # with matplotlib it should be the top-most point + dataIndex = result.getIndices(copy=False)[-1] + self.__pickingCache = ( + dataIndex, + item.getXData(copy=False)[dataIndex], + item.getYData(copy=False)[dataIndex], + item.getValueData(copy=False)[dataIndex]) + + return self.__pickingCache + + def _getPickedIndex(self, x, y): + """Get data index of top most scatter plot at position (x, y) + + :param float x: X position in plot coordinates + :param float y: Y position in plot coordinates + :return: The data index at that point or '-' + """ + picking = self._pickScatterData(x, y) + return '-' if picking is None else picking[0] + + def _getPickedX(self, x, y): + """Returns X position snapped to scatter plot when close enough + + :param float x: + :param float y: + :rtype: float + """ + picking = self._pickScatterData(x, y) + return x if picking is None else picking[1] + + def _getPickedY(self, x, y): + """Returns Y position snapped to scatter plot when close enough + + :param float x: + :param float y: + :rtype: float + """ + picking = self._pickScatterData(x, y) + return y if picking is None else picking[2] + + def _getPickedValue(self, x, y): + """Get data value of top most scatter plot at position (x, y) + + :param float x: X position in plot coordinates + :param float y: Y position in plot coordinates + :return: The data value at that point or '-' + """ + picking = self._pickScatterData(x, y) + return '-' if picking is None else picking[3] + + def _mouseInPlotArea(self, x, y): + """Clip mouse coordinates to plot area coordinates + + :param float x: X position in pixels + :param float y: Y position in pixels + :return: (x, y) in data coordinates + """ + plot = self.getPlotWidget() + left, top, width, height = plot.getPlotBoundsInPixels() + xPlot = numpy.clip(x, left, left + width - 1) + yPlot = numpy.clip(y, top, top + height - 1) + return xPlot, yPlot + + def getPlotWidget(self): + """Returns the :class:`~silx.gui.plot.PlotWidget` this window is based on. + + :rtype: ~silx.gui.plot.PlotWidget + """ + return self._plot() + + def getPositionInfoWidget(self): + """Returns the widget display mouse coordinates information. + + :rtype: ~silx.gui.plot.tools.PositionInfo + """ + return self._positionInfo + + def getMaskToolsWidget(self): + """Returns the widget controlling mask drawing + + :rtype: ~silx.gui.plot.ScatterMaskToolsWidget + """ + return self._maskToolsWidget + + def getInteractiveModeToolBar(self): + """Returns QToolBar controlling interactive mode. + + :rtype: ~silx.gui.plot.tools.InteractiveModeToolBar + """ + return self._interactiveModeToolBar + + def getScatterToolBar(self): + """Returns QToolBar providing scatter plot tools. + + :rtype: ~silx.gui.plot.tools.ScatterToolBar + """ + return self._scatterToolBar + + def getScatterProfileToolBar(self): + """Returns QToolBar providing scatter profile tools. + + :rtype: ~silx.gui.plot.tools.profile.ScatterProfileToolBar + """ + return self._profileToolBar + + def getOutputToolBar(self): + """Returns QToolBar containing save, copy and print actions + + :rtype: ~silx.gui.plot.tools.OutputToolBar + """ + return self._outputToolBar + + def setColormap(self, colormap=None): + """Set the colormap for the displayed scatter and the + default plot colormap. + + :param ~silx.gui.colors.Colormap colormap: + The description of the colormap. + """ + self.getScatterItem().setColormap(colormap) + # Resilient to call to PlotWidget API (e.g., clear) + self.getPlotWidget().setDefaultColormap(colormap) + + def getColormap(self): + """Return the colormap object in use. + + :return: Colormap currently in use + :rtype: ~silx.gui.colors.Colormap + """ + return self.getScatterItem().getColormap() + + # Control displayed scatter plot + + def setData(self, x, y, value, xerror=None, yerror=None, alpha=None, copy=True): + """Set the data of the scatter plot. + + To reset the scatter plot, set x, y and value to None. + + :param Union[numpy.ndarray,None] x: X coordinates. + :param Union[numpy.ndarray,None] y: Y coordinates. + :param Union[numpy.ndarray,None] value: + The data corresponding to the value of the data points. + :param xerror: Values with the uncertainties on the x values. + If it is an array, it can either be a 1D array of + same length as the data or a 2D array with 2 rows + of same length as the data: row 0 for positive errors, + row 1 for negative errors. + :type xerror: A float, or a numpy.ndarray of float32. + + :param yerror: Values with the uncertainties on the y values + :type yerror: A float, or a numpy.ndarray of float32. See xerror. + :param alpha: Values with the transparency (between 0 and 1) + :type alpha: A float, or a numpy.ndarray of float32 + :param bool copy: True make a copy of the data (default), + False to use provided arrays. + """ + x = () if x is None else x + y = () if y is None else y + value = () if value is None else value + + self.getScatterItem().setData( + x=x, y=y, value=value, xerror=xerror, yerror=yerror, alpha=alpha, copy=copy) + + @docstring(items.Scatter) + def getData(self, *args, **kwargs): + return self.getScatterItem().getData(*args, **kwargs) + + def getScatterItem(self): + """Returns the plot item displaying the scatter data. + + This allows to set the style of the displayed scatter. + + :rtype: ~silx.gui.plot.items.Scatter + """ + plot = self.getPlotWidget() + scatter = plot._getItem(kind='scatter', legend=self._SCATTER_LEGEND) + if scatter is None: # Resilient to call to PlotWidget API (e.g., clear) + scatter = self.__createEmptyScatter() + return scatter + + # Convenient proxies + + @docstring(PlotWidget) + def getXAxis(self, *args, **kwargs): + return self.getPlotWidget().getXAxis(*args, **kwargs) + + @docstring(PlotWidget) + def getYAxis(self, *args, **kwargs): + return self.getPlotWidget().getYAxis(*args, **kwargs) + + @docstring(PlotWidget) + def setGraphTitle(self, *args, **kwargs): + return self.getPlotWidget().setGraphTitle(*args, **kwargs) + + @docstring(PlotWidget) + def getGraphTitle(self, *args, **kwargs): + return self.getPlotWidget().getGraphTitle(*args, **kwargs) + + @docstring(PlotWidget) + def resetZoom(self, *args, **kwargs): + return self.getPlotWidget().resetZoom(*args, **kwargs) + + @docstring(ScatterMaskToolsWidget) + def getSelectionMask(self, *args, **kwargs): + return self.getMaskToolsWidget().getSelectionMask(*args, **kwargs) + + @docstring(ScatterMaskToolsWidget) + def setSelectionMask(self, *args, **kwargs): + return self.getMaskToolsWidget().setSelectionMask(*args, **kwargs) diff --git a/src/silx/gui/plot/StackView.py b/src/silx/gui/plot/StackView.py new file mode 100644 index 0000000..56793d7 --- /dev/null +++ b/src/silx/gui/plot/StackView.py @@ -0,0 +1,1254 @@ +# coding: utf-8 +# /*########################################################################## +# +# Copyright (c) 2016-2021 European Synchrotron Radiation Facility +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. +# +# ###########################################################################*/ +"""QWidget displaying a 3D volume as a stack of 2D images. + +The :class:`StackView` class implements this widget. + +Basic usage of :class:`StackView` is through the following methods: + +- :meth:`StackView.getColormap`, :meth:`StackView.setColormap` to update the + default colormap to use and update the currently displayed image. +- :meth:`StackView.setStack` to update the displayed image. + +The :class:`StackView` uses :class:`PlotWindow` and also +exposes a subset of the :class:`silx.gui.plot.Plot` API for further control +(plot title, axes labels, ...). + +The :class:`StackViewMainWindow` class implements a widget that adds a status +bar displaying the 3D index and the value under the mouse cursor. + +Example:: + + import numpy + import sys + from silx.gui import qt + from silx.gui.plot.StackView import StackViewMainWindow + + + app = qt.QApplication(sys.argv[1:]) + + # synthetic data, stack of 100 images of size 200x300 + mystack = numpy.fromfunction( + lambda i, j, k: numpy.sin(i/15.) + numpy.cos(j/4.) + 2 * numpy.sin(k/6.), + (100, 200, 300) + ) + + + sv = StackViewMainWindow() + sv.setColormap("jet", autoscale=True) + sv.setStack(mystack) + sv.setLabels(["1st dim (0-99)", "2nd dim (0-199)", + "3rd dim (0-299)"]) + sv.show() + + app.exec() + +""" + +__authors__ = ["P. Knobel", "H. Payno"] +__license__ = "MIT" +__date__ = "10/10/2018" + +import numpy +import logging + +import silx +from silx.gui import qt +from .. import icons +from . import items, PlotWindow, actions +from .items.image import ImageStack +from ..colors import Colormap +from ..colors import cursorColorForColormap +from .tools import LimitsToolBar +from .Profile import Profile3DToolBar +from ..widgets.FrameBrowser import HorizontalSliderWithBrowser + +from silx.gui.plot.actions import control as actions_control +from silx.gui.plot.actions import io as silx_io +from silx.io.nxdata import save_NXdata +from silx.utils.array_like import DatasetView, ListOfImages +from silx.math import calibration +from silx.utils.deprecation import deprecated_warning +from silx.utils.deprecation import deprecated + +import h5py +from silx.io.utils import is_dataset + +_logger = logging.getLogger(__name__) + + +class StackView(qt.QMainWindow): + """Stack view widget, to display and browse through stack of + images. + + The profile tool can be switched to "3D" mode, to compute the profile + on each image of the stack (not only the active image currently displayed) + and display the result as a slice. + + :param QWidget parent: the Qt parent, or None + :param backend: The backend to use for the plot (default: matplotlib). + See :class:`.PlotWidget` for the list of supported backend. + :type backend: str or :class:`BackendBase.BackendBase` + :param bool resetzoom: Toggle visibility of reset zoom action. + :param bool autoScale: Toggle visibility of axes autoscale actions. + :param bool logScale: Toggle visibility of axes log scale actions. + :param bool grid: Toggle visibility of grid mode action. + :param bool colormap: Toggle visibility of colormap action. + :param bool aspectRatio: Toggle visibility of aspect ratio button. + :param bool yInverted: Toggle visibility of Y axis direction button. + :param bool copy: Toggle visibility of copy action. + :param bool save: Toggle visibility of save action. + :param bool print_: Toggle visibility of print action. + :param bool control: True to display an Options button with a sub-menu + to show legends, toggle crosshair and pan with arrows. + (Default: False) + :param position: True to display widget with (x, y) mouse position + (Default: False). + It also supports a list of (name, funct(x, y)->value) + to customize the displayed values. + See :class:`silx.gui.plot.PlotTools.PositionInfo`. + :param bool mask: Toggle visibilty of mask action. + """ + # Qt signals + valueChanged = qt.Signal(object, object, object) + """Signals that the data value under the cursor has changed. + + It provides: row, column, data value. + """ + + sigPlaneSelectionChanged = qt.Signal(int) + """Signal emitted when there is a change is perspective/displayed axes. + + It provides the perspective as an integer, with the following meaning: + + - 0: axis Y is the 2nd dimension, axis XÂ is the 3rd dimension + - 1: axis Y is the 1st dimension, axis XÂ is the 3rd dimension + - 2: axis Y is the 1st dimension, axis XÂ is the 2nd dimension + """ + + sigStackChanged = qt.Signal(int) + """Signal emitted when the stack is changed. + This happens when a new volume is loaded, or when the current volume + is transposed (change in perspective). + + The signal provides the size (number of pixels) of the stack. + This will be 0 if the stack is cleared, else it will be a positive + integer. + """ + + sigFrameChanged = qt.Signal(int) + """Signal emitter when the frame number has changed. + + This signal provides the current frame number. + """ + + IMAGE_STACK_FILTER_NXDATA = 'Stack of images as NXdata (%s)' % silx_io._NEXUS_HDF5_EXT_STR + + + def __init__(self, parent=None, resetzoom=True, backend=None, + autoScale=False, logScale=False, grid=False, + colormap=True, aspectRatio=True, yinverted=True, + copy=True, save=True, print_=True, control=False, + position=None, mask=True): + qt.QMainWindow.__init__(self, parent) + if parent is not None: + # behave as a widget + self.setWindowFlags(qt.Qt.Widget) + else: + self.setWindowTitle('StackView') + + self._stack = None + """Loaded stack, as a 3D array, a 3D dataset or a list of 2D arrays.""" + self.__transposed_view = None + """View on :attr:`_stack` with the axes sorted, to have + the orthogonal dimension first""" + self._perspective = 0 + """Orthogonal dimension (depth) in :attr:`_stack`""" + + self._stackItem = ImageStack() + """Hold the item displaying the stack""" + imageLegend = '__StackView__image' + str(id(self)) + self._stackItem.setName(imageLegend) + + self.__autoscaleCmap = False + """Flag to disable/enable colormap auto-scaling + based on the min/max values of the entire 3D volume""" + self.__dimensionsLabels = ["Dimension 0", "Dimension 1", + "Dimension 2"] + """These labels are displayed on the X and Y axes. + :meth:`setLabels` updates this attribute.""" + + self._first_stack_dimension = 0 + """Used for dimension labels and combobox""" + + self._titleCallback = self._defaultTitleCallback + """Function returning the plot title based on the frame index. + It can be set to a custom function using :meth:`setTitleCallback`""" + + self.calibrations3D = (calibration.NoCalibration(), + calibration.NoCalibration(), + calibration.NoCalibration()) + + central_widget = qt.QWidget(self) + + self._plot = PlotWindow(parent=central_widget, backend=backend, + resetzoom=resetzoom, autoScale=autoScale, + logScale=logScale, grid=grid, + curveStyle=False, colormap=colormap, + aspectRatio=aspectRatio, yInverted=yinverted, + copy=copy, save=save, print_=print_, + control=control, position=position, + roi=False, mask=mask) + self._plot.addItem(self._stackItem) + self._plot.getIntensityHistogramAction().setVisible(True) + self.sigInteractiveModeChanged = self._plot.sigInteractiveModeChanged + self.sigActiveImageChanged = self._plot.sigActiveImageChanged + self.sigPlotSignal = self._plot.sigPlotSignal + + if silx.config.DEFAULT_PLOT_IMAGE_Y_AXIS_ORIENTATION == 'downward': + self._plot.getYAxis().setInverted(True) + + self._addColorBarAction() + + self._profileToolBar = Profile3DToolBar(parent=self._plot, + stackview=self) + self._plot.addToolBar(self._profileToolBar) + self._plot.getXAxis().setLabel('Columns') + self._plot.getYAxis().setLabel('Rows') + self._plot.sigPlotSignal.connect(self._plotCallback) + self._plot.getSaveAction().setFileFilter('image', self.IMAGE_STACK_FILTER_NXDATA, func=self._saveImageStack, appendToFile=True) + + self.__planeSelection = PlanesWidget(self._plot) + self.__planeSelection.sigPlaneSelectionChanged.connect(self.setPerspective) + + self._browser_label = qt.QLabel("Image index (Dim0):") + + self._browser = HorizontalSliderWithBrowser(central_widget) + self._browser.setRange(0, 0) + self._browser.valueChanged[int].connect(self.__updateFrameNumber) + self._browser.setEnabled(False) + + layout = qt.QGridLayout() + layout.setContentsMargins(0, 0, 0, 0) + layout.addWidget(self._plot, 0, 0, 1, 3) + layout.addWidget(self.__planeSelection, 1, 0) + layout.addWidget(self._browser_label, 1, 1) + layout.addWidget(self._browser, 1, 2) + + central_widget.setLayout(layout) + self.setCentralWidget(central_widget) + + # clear profile lines when the perspective changes (plane browsed changed) + self.__planeSelection.sigPlaneSelectionChanged.connect( + self._profileToolBar.clearProfile) + + def _saveImageStack(self, plot, filename, nameFilter): + """Save all images from the stack into a volume. + + :param str filename: The name of the file to write + :param str nameFilter: The selected name filter + :return: False if format is not supported or save failed, + True otherwise. + :raises: ValueError if nameFilter is invalid + """ + if not nameFilter == self.IMAGE_STACK_FILTER_NXDATA: + raise ValueError('Wrong callback') + entryPath = silx_io.SaveAction._selectWriteableOutputGroup(filename, parent=self) + if entryPath is None: + return False + return save_NXdata(filename, + nxentry_name=entryPath, + signal=self.getStack(copy=False, returnNumpyArray=True)[0], + signal_name="image_stack") + + def _addColorBarAction(self): + self._plot.getColorBarWidget().setVisible(True) + actions = self._plot.toolBar().actions() + for index, action in enumerate(actions): + if action is self._plot.getColormapAction(): + break + self._colorbarAction = actions_control.ColorBarAction(self._plot, self._plot) + self._plot.toolBar().insertAction(actions[index + 1], self._colorbarAction) + + def _plotCallback(self, eventDict): + """Callback for plot events. + + Emit :attr:`valueChanged` signal, with (x, y, value) tuple of the + cursor location in the plot.""" + if eventDict['event'] == 'mouseMoved': + activeImage = self.getActiveImage() + if activeImage is not None: + data = activeImage.getData() + height, width = data.shape + + # Get corresponding coordinate in image + origin = activeImage.getOrigin() + scale = activeImage.getScale() + x = int((eventDict['x'] - origin[0]) / scale[0]) + y = int((eventDict['y'] - origin[1]) / scale[1]) + + if 0 <= x < width and 0 <= y < height: + self.valueChanged.emit(float(x), float(y), + data[y][x]) + else: + self.valueChanged.emit(float(x), float(y), + None) + + def getPerspective(self): + """Returns the index of the dimension the stack is browsed with + + Possible values are: 0, 1, or 2. + + :rtype: int + """ + return self._perspective + + def setPerspective(self, perspective): + """Set the index of the dimension the stack is browsed with: + + - slice plane Dim1-Dim2: perspective 0 + - slice plane Dim0-Dim2: perspective 1 + - slice plane Dim0-Dim1: perspective 2 + + :param int perspective: Orthogonal dimension number (0, 1, or 2) + """ + if perspective == self._perspective: + return + else: + if perspective > 2 or perspective < 0: + raise ValueError( + "Perspective must be 0, 1 or 2, not %s" % perspective) + + self._perspective = int(perspective) + self.__createTransposedView() + self.__updateFrameNumber(self._browser.value()) + self._plot.resetZoom() + self.__updatePlotLabels() + self._updateTitle() + self._browser_label.setText("Image index (Dim%d):" % + (self._first_stack_dimension + perspective)) + + self.sigPlaneSelectionChanged.emit(perspective) + self.sigStackChanged.emit(self._stack.size if + self._stack is not None else 0) + self.__planeSelection.sigPlaneSelectionChanged.disconnect(self.setPerspective) + self.__planeSelection.setPerspective(self._perspective) + self.__planeSelection.sigPlaneSelectionChanged.connect(self.setPerspective) + + def __updatePlotLabels(self): + """Update plot axes labels depending on perspective""" + y, x = (1, 2) if self._perspective == 0 else \ + (0, 2) if self._perspective == 1 else (0, 1) + self.setGraphXLabel(self.__dimensionsLabels[x]) + self.setGraphYLabel(self.__dimensionsLabels[y]) + + def __createTransposedView(self): + """Create the new view on the stack depending on the perspective + (set orthogonal axis browsed on the viewer as first dimension) + """ + assert self._stack is not None + assert 0 <= self._perspective < 3 + + # ensure we have the stack encapsulated in an array-like object + # having a transpose() method + if isinstance(self._stack, numpy.ndarray): + self.__transposed_view = self._stack + + elif is_dataset(self._stack) or isinstance(self._stack, DatasetView): + self.__transposed_view = DatasetView(self._stack) + + elif isinstance(self._stack, ListOfImages): + self.__transposed_view = ListOfImages(self._stack) + + # transpose the array-like object if necessary + if self._perspective == 1: + self.__transposed_view = self.__transposed_view.transpose((1, 0, 2)) + elif self._perspective == 2: + self.__transposed_view = self.__transposed_view.transpose((2, 0, 1)) + + self._browser.setRange(0, self.__transposed_view.shape[0] - 1) + self._browser.setValue(0) + + # Update the item structure + self._stackItem.setStackData(self.__transposed_view, 0, copy=False) + self._stackItem.setColormap(self.getColormap()) + self._stackItem.setOrigin(self._getImageOrigin()) + self._stackItem.setScale(self._getImageScale()) + + def __updateFrameNumber(self, index): + """Update the current image. + + :param index: index of the frame to be displayed + """ + if self.__transposed_view is None: + # no data set + return + + self._stackItem.setStackPosition(index) + + self._updateTitle() + self.sigFrameChanged.emit(index) + + def _set3DScaleAndOrigin(self, calibrations): + """Set scale and origin for all 3 axes, to be used when plotting + an image. + + See setStack for parameter documentation + """ + if calibrations is None: + self.calibrations3D = (calibration.NoCalibration(), + calibration.NoCalibration(), + calibration.NoCalibration()) + else: + self.calibrations3D = [] + for i, calib in enumerate(calibrations): + if hasattr(calib, "__len__") and len(calib) == 2: + calib = calibration.LinearCalibration(calib[0], calib[1]) + elif calib is None: + calib = calibration.NoCalibration() + elif not isinstance(calib, calibration.AbstractCalibration): + raise TypeError("calibration must be a 2-tuple, None or" + + " an instance of an AbstractCalibration " + + "subclass") + elif not calib.is_affine(): + _logger.warning( + "Calibration for dimension %d is not linear, " + "it will be ignored for scaling the graph axes.", + i) + self.calibrations3D.append(calib) + + def getCalibrations(self, order='array'): + """Returns currently used calibrations for each axis + + Returned calibrations might differ from the ones that were set as + non-linear calibrations used for image axes are temporarily ignored. + + :param str order: + 'array' to sort calibrations as data array (dim0, dim1, dim2), + 'axes' to sort calibrations as currently selected x, y and z axes. + :return: Calibrations ordered depending on order + :rtype: List[~silx.math.calibration.AbstractCalibration] + """ + assert order in ('array', 'axes') + calibs = [] + + # filter out non-linear calibration for graph axes + for index, calib in enumerate(self.calibrations3D): + if index != self._perspective and not calib.is_affine(): + calib = calibration.NoCalibration() + calibs.append(calib) + + if order == 'axes': # Move 'z' axis to the end + xy_dims = [d for d in (0, 1, 2) if d != self._perspective] + calibs = [calibs[max(xy_dims)], + calibs[min(xy_dims)], + calibs[self._perspective]] + + return tuple(calibs) + + def _getImageScale(self): + """ + :return: 2-tuple (XScale, YScale) for current image view + """ + xcalib, ycalib, _zcalib = self.getCalibrations(order='axes') + return xcalib.get_slope(), ycalib.get_slope() + + def _getImageOrigin(self): + """ + :return: 2-tuple (XOrigin, YOrigin) for current image view + """ + xcalib, ycalib, _zcalib = self.getCalibrations(order='axes') + return xcalib(0), ycalib(0) + + def _getImageZ(self, index): + """ + :param idx: 0-based image index in the stack + :return: calibrated Z value corresponding to the image idx + """ + _xcalib, _ycalib, zcalib = self.getCalibrations(order='axes') + return zcalib(index) + + def _updateTitle(self): + frame_idx = self._browser.value() + self._plot.setGraphTitle(self._titleCallback(frame_idx)) + + def _defaultTitleCallback(self, index): + return "Image z=%g" % self._getImageZ(index) + + # public API, stack specific methods + def setStack(self, stack, perspective=None, reset=True, calibrations=None): + """Set the 3D stack. + + The perspective parameter is used to define which dimension of the 3D + array is to be used as frame index. The lowest remaining dimension + number is the row index of the displayed image (Y axis), and the highest + remaining dimension is the column index (X axis). + + :param stack: 3D stack, or `None` to clear plot. + :type stack: 3D numpy.ndarray, or 3D h5py.Dataset, or list/tuple of 2D + numpy arrays, or None. + :param int perspective: Dimension for the frame index: 0, 1 or 2. + Use ``None`` to keep the current perspective (default). + :param bool reset: Whether to reset zoom or not. + :param calibrations: Sequence of 3 calibration objects for each axis. + These objects can be a subclass of :class:`AbstractCalibration`, + or 2-tuples *(a, b)* where *a* is the y-intercept and *b* is the + slope of a linear calibration (:math:`x \\mapsto a + b x`) + """ + if stack is None: + self.clear() + self.sigStackChanged.emit(0) + return + + self._set3DScaleAndOrigin(calibrations) + + # stack as list of 2D arrays: must be converted into an array_like + if not isinstance(stack, numpy.ndarray): + if not is_dataset(stack): + try: + assert hasattr(stack, "__len__") + for img in stack: + assert hasattr(img, "shape") + assert len(img.shape) == 2 + except AssertionError: + raise ValueError( + "Stack must be a 3D array/dataset or a list of " + + "2D arrays.") + stack = ListOfImages(stack) + + assert len(stack.shape) == 3, "data must be 3D" + + self._stack = stack + self.__createTransposedView() + + perspective_changed = False + if perspective not in [None, self._perspective]: + perspective_changed = True + self.setPerspective(perspective) + + if self.__autoscaleCmap: + self.scaleColormapRangeToStack() + + # init plot + self._stackItem.setStackData(self.__transposed_view, 0, copy=False) + self._stackItem.setColormap(self.getColormap()) + self._stackItem.setOrigin(self._getImageOrigin()) + self._stackItem.setScale(self._getImageScale()) + self._stackItem.setVisible(True) + + # Put back the item in the plot in case it was cleared + exists = self._plot.getImage(self._stackItem.getName()) + if exists is None: + self._plot.addItem(self._stackItem) + + self._plot.setActiveImage(self._stackItem.getName()) + self.__updatePlotLabels() + self._updateTitle() + + if reset: + self._plot.resetZoom() + + # enable and init browser + self._browser.setEnabled(True) + + if not perspective_changed: # avoid double signal (see self.setPerspective) + self.sigStackChanged.emit(stack.size) + + def getStack(self, copy=True, returnNumpyArray=False): + """Get the original stack, as a 3D array or dataset. + + The output has the form: [data, params] + where params is a dictionary containing display parameters. + + :param bool copy: If True (default), then the object is copied + and returned as a numpy array. + Else, a reference to original data is returned, if possible. + If the original data is not a numpy array and parameter + returnNumpyArray is True, a copy will be made anyway. + :param bool returnNumpyArray: If True, the returned object is + guaranteed to be a numpy array. + :return: 3D stack and parameters. + :rtype: (numpy.ndarray, dict) + """ + if self._stack is None: + return None + + image = self._stackItem + colormap = image.getColormap() + + params = { + 'info': image.getInfo(), + 'origin': image.getOrigin(), + 'scale': image.getScale(), + 'z': image.getZValue(), + 'selectable': image.isSelectable(), + 'draggable': image.isDraggable(), + 'colormap': colormap, + 'xlabel': image.getXLabel(), + 'ylabel': image.getYLabel(), + } + if returnNumpyArray or copy: + return numpy.array(self._stack, copy=copy), params + + # if a list of 2D arrays was cast into a ListOfImages, + # return the original list + if isinstance(self._stack, ListOfImages): + return self._stack.images, params + + return self._stack, params + + def getCurrentView(self, copy=True, returnNumpyArray=False): + """Get the stack, as it is currently displayed. + + The first index of the returned stack is always the frame + index. If the perspective has been changed in the widget since the + data was first loaded, this will be reflected in the order of the + dimensions of the returned object. + + The output has the form: [data, params] + where params is a dictionary containing display parameters. + + :param bool copy: If True (default), then the object is copied + and returned as a numpy array. + Else, a reference to original data is returned, if possible. + If the original data is not a numpy array and parameter + `returnNumpyArray` is `True`, a copy will be made anyway. + :param bool returnNumpyArray: If `True`, the returned object is + guaranteed to be a numpy array. + :return: 3D stack and parameters. + :rtype: (numpy.ndarray, dict) + """ + image = self.getActiveImage() + if image is None: + return None + + if isinstance(image, items.ColormapMixIn): + colormap = image.getColormap() + else: + colormap = None + + params = { + 'info': image.getInfo(), + 'origin': image.getOrigin(), + 'scale': image.getScale(), + 'z': image.getZValue(), + 'selectable': image.isSelectable(), + 'draggable': image.isDraggable(), + 'colormap': colormap, + 'xlabel': image.getXLabel(), + 'ylabel': image.getYLabel(), + } + if returnNumpyArray or copy: + return numpy.array(self.__transposed_view, copy=copy), params + return self.__transposed_view, params + + def setFrameNumber(self, number): + """Set the frame selection to a specific value + + :param int number: Number of the frame + """ + self._browser.setValue(number) + + def getFrameNumber(self): + """Set the frame selection to a specific value + + :return: Index of currently displayed frame + :rtype: int + """ + return self._browser.value() + + def setFirstStackDimension(self, first_stack_dimension): + """When viewing the last 3 dimensions of an n-D array (n>3), you can + use this method to change the text in the combobox. + + For instance, for a 7-D array, first stack dim is 4, so the default + "Dim1-Dim2" text should be replaced with "Dim5-Dim6" (dimensions + numbers are 0-based). + + :param int first_stack_dim: First stack dimension (n-3) when viewing the + last 3 dimensions of an n-D array. + """ + old_state = self.__planeSelection.blockSignals(True) + self.__planeSelection.setFirstStackDimension(first_stack_dimension) + self.__planeSelection.blockSignals(old_state) + self._first_stack_dimension = first_stack_dimension + self._browser_label.setText("Image index (Dim%d):" % first_stack_dimension) + + def setTitleCallback(self, callback): + """Set a user defined function to generate the plot title based on the + image/frame index. + + The callback function must accept an integer as a its first positional + parameter and must not require any other mandatory parameter. + It must return a string. + + To switch back the default behavior, you can pass ``None``:: + + mystackview.setTitleCallback(None) + + To have no title, pass a function that returns an empty string:: + + mystackview.setTitleCallback(lambda idx: "") + + :param callback: Callback function generating the stack title based + on the frame number. + """ + + if callback is None: + self._titleCallback = self._defaultTitleCallback + elif callable(callback): + self._titleCallback = callback + else: + raise TypeError("Provided callback is not callable") + self._updateTitle() + + def clear(self): + """Clear the widget: + + - clear the plot + - clear the loaded data volume + """ + self._stack = None + self.__transposed_view = None + self._perspective = 0 + self._browser.setEnabled(False) + # reset browser range + self._browser.setRange(0, 0) + self._plot.clear() + + def setLabels(self, labels=None): + """Set the labels to be displayed on the plot axes. + + You must provide a sequence of 3 strings, corresponding to the 3 + dimensions of the original data volume. + The proper label will automatically be selected for each plot axis + when the volume is rotated (when different axes are selected as the + X and Y axes). + + :param List[str] labels: 3 labels corresponding to the 3 dimensions + of the data volumes. + """ + + default_labels = ["Dimension %d" % self._first_stack_dimension, + "Dimension %d" % (self._first_stack_dimension + 1), + "Dimension %d" % (self._first_stack_dimension + 2)] + if labels is None: + new_labels = default_labels + else: + # filter-out None + new_labels = [] + for i, label in enumerate(labels): + new_labels.append(label or default_labels[i]) + + self.__dimensionsLabels = new_labels + self.__updatePlotLabels() + + def getLabels(self): + """Return dimension labels displayed on the plot axes + + :return: List of three strings corresponding to the 3 dimensions + of the stack: (name_dim0, name_dim1, name_dim2) + """ + return self.__dimensionsLabels + + def getColormap(self): + """Get the current colormap description. + + :return: A description of the current colormap. + See :meth:`setColormap` for details. + :rtype: dict + """ + # "default" colormap used by addImage when image is added without + # specifying a special colormap + return self._plot.getDefaultColormap() + + def scaleColormapRangeToStack(self): + """Scale colormap range according to current stack data. + + If no stack has been set through :meth:`setStack`, this has no effect. + + The range scaling mode is given by current :class:`Colormap`'s + :meth:`Colormap.getAutoscaleMode`. + """ + stack = self.getStack(copy=False, returnNumpyArray=True) + if stack is None: + return # No-op + + colormap = self.getColormap() + vmin, vmax = colormap.getColormapRange(data=stack[0]) + colormap.setVRange(vmin=vmin, vmax=vmax) + + def setColormap(self, colormap=None, normalization=None, + autoscale=None, vmin=None, vmax=None, colors=None): + """Set the colormap and update active image. + + Parameters that are not provided are taken from the current colormap. + + The colormap parameter can also be a dict with the following keys: + + - *name*: string. The colormap to use: + 'gray', 'reversed gray', 'temperature', 'red', 'green', 'blue'. + - *normalization*: string. The mapping to use for the colormap: + either 'linear' or 'log'. + - *autoscale*: bool. Whether to use autoscale (True) or range + provided by keys + 'vmin' and 'vmax' (False). + - *vmin*: float. The minimum value of the range to use if 'autoscale' + is False. + - *vmax*: float. The maximum value of the range to use if 'autoscale' + is False. + - *colors*: optional. Nx3 or Nx4 array of float in [0, 1] or uint8. + List of RGB or RGBA colors to use (only if name is None) + + :param colormap: Name of the colormap in + 'gray', 'reversed gray', 'temperature', 'red', 'green', 'blue'. + Or a :class`.Colormap` object. + :type colormap: dict or str. + :param str normalization: Colormap mapping: 'linear' or 'log'. + :param bool autoscale: Whether to use autoscale or [vmin, vmax] range. + Default value of autoscale is False. This option is not compatible + with h5py datasets. + :param float vmin: The minimum value of the range to use if + 'autoscale' is False. + :param float vmax: The maximum value of the range to use if + 'autoscale' is False. + :param numpy.ndarray colors: Only used if name is None. + Custom colormap colors as Nx3 or Nx4 RGB or RGBA arrays + """ + # if is a colormap object or a dictionary + if isinstance(colormap, Colormap) or isinstance(colormap, dict): + # Support colormap parameter as a dict + errmsg = "If colormap is provided as a Colormap object, all other parameters" + errmsg += " must not be specified when calling setColormap" + assert normalization is None, errmsg + assert autoscale is None, errmsg + assert vmin is None, errmsg + assert vmax is None, errmsg + assert colors is None, errmsg + + if isinstance(colormap, dict): + reason = 'colormap parameter should now be an object' + replacement = 'Colormap()' + since_version = '0.6' + deprecated_warning(type_='function', + name='setColormap', + reason=reason, + replacement=replacement, + since_version=since_version) + _colormap = Colormap._fromDict(colormap) + else: + _colormap = colormap + else: + norm = normalization if normalization is not None else 'linear' + name = colormap if colormap is not None else 'gray' + _colormap = Colormap(name=name, + normalization=norm, + vmin=vmin, + vmax=vmax, + colors=colors) + + if autoscale is not None: + deprecated_warning( + type_='function', + name='setColormap', + reason='autoscale argument is replaced by a method', + replacement='scaleColormapRangeToStack', + since_version='0.14') + self.__autoscaleCmap = bool(autoscale) + + cursorColor = cursorColorForColormap(_colormap.getName()) + self._plot.setInteractiveMode('zoom', color=cursorColor) + + self._plot.setDefaultColormap(_colormap) + + # Update active image colormap + activeImage = self.getActiveImage() + if isinstance(activeImage, items.ColormapMixIn): + activeImage.setColormap(self.getColormap()) + + if self.__autoscaleCmap: + # scaleColormapRangeToStack needs to be called **after** + # setDefaultColormap so getColormap returns the right colormap + self.scaleColormapRangeToStack() + + + @deprecated(replacement="getPlotWidget", since_version="0.13") + def getPlot(self): + return self.getPlotWidget() + + def getPlotWidget(self): + """Return the :class:`PlotWidget`. + + This gives access to advanced plot configuration options. + Be warned that modifying the plot can cause issues, and some changes + you make to the plot could be overwritten by the :class:`StackView` + widget's internal methods and callbacks. + + :return: instance of :class:`PlotWidget` used in widget + """ + return self._plot + + def setOptionVisible(self, isVisible): + """ + Set the visibility of the browsing options. + + :param bool isVisible: True to have the options visible, else False + """ + self._browser.setVisible(isVisible) + self.__planeSelection.setVisible(isVisible) + + # proxies to PlotWidget or PlotWindow methods + def getProfileToolbar(self): + """Profile tools attached to this plot + """ + return self._profileToolBar + + def getGraphTitle(self): + """Return the plot main title as a str. + """ + return self._plot.getGraphTitle() + + def setGraphTitle(self, title=""): + """Set the plot main title. + + :param str title: Main title of the plot (default: '') + """ + return self._plot.setGraphTitle(title) + + def getGraphXLabel(self): + """Return the current horizontal axis label as a str. + """ + return self._plot.getXAxis().getLabel() + + def setGraphXLabel(self, label=None): + """Set the plot horizontal axis label. + + :param str label: The horizontal axis label + """ + if label is None: + label = self.__dimensionsLabels[1 if self._perspective == 2 else 2] + self._plot.getXAxis().setLabel(label) + + def getGraphYLabel(self, axis='left'): + """Return the current vertical axis label as a str. + + :param str axis: The Y axis for which to get the label (left or right) + """ + return self._plot.getYAxis().getLabel(axis) + + def setGraphYLabel(self, label=None, axis='left'): + """Set the vertical axis label on the plot. + + :param str label: The Y axis label + :param str axis: The Y axis for which to set the label (left or right) + """ + if label is None: + label = self.__dimensionsLabels[1 if self._perspective == 0 else 0] + self._plot.getYAxis(axis=axis).setLabel(label) + + def resetZoom(self): + """Reset the plot limits to the bounds of the data and redraw the plot. + + This method is a simple proxy to the legacy :class:`PlotWidget` method + of the same name. Using the object oriented approach is now + preferred:: + + stackview.getPlot().resetZoom() + """ + self._plot.resetZoom() + + def setYAxisInverted(self, flag=True): + """Set the Y axis orientation. + + This method is a simple proxy to the legacy :class:`PlotWidget` method + of the same name. Using the object oriented approach is now + preferred:: + + stackview.getPlot().setYAxisInverted(flag) + + :param bool flag: True for Y axis going from top to bottom, + False for Y axis going from bottom to top + """ + self._plot.setYAxisInverted(flag) + + def isYAxisInverted(self): + """Return True if Y axis goes from top to bottom, False otherwise. + + This method is a simple proxy to the legacy :class:`PlotWidget` method + of the same name. Using the object oriented approach is now + preferred:: + + stackview.getPlot().isYAxisInverted()""" + return self._plot.isYAxisInverted() + + def getSupportedColormaps(self): + """Get the supported colormap names as a tuple of str. + + The list should at least contain and start by: + ('gray', 'reversed gray', 'temperature', 'red', 'green', 'blue') + + This method is a simple proxy to the legacy :class:`PlotWidget` method + of the same name. Using the object oriented approach is now + preferred:: + + stackview.getPlot().getSupportedColormaps() + """ + return self._plot.getSupportedColormaps() + + def isKeepDataAspectRatio(self): + """Returns whether the plot is keeping data aspect ratio or not. + + This method is a simple proxy to the legacy :class:`PlotWidget` method + of the same name. Using the object oriented approach is now + preferred:: + + stackview.getPlot().isKeepDataAspectRatio()""" + return self._plot.isKeepDataAspectRatio() + + def setKeepDataAspectRatio(self, flag=True): + """Set whether the plot keeps data aspect ratio or not. + + This method is a simple proxy to the legacy :class:`PlotWidget` method + of the same name. Using the object oriented approach is now + preferred:: + + stackview.getPlot().setKeepDataAspectRatio(flag) + + :param bool flag: True to respect data aspect ratio + """ + self._plot.setKeepDataAspectRatio(flag) + + # kind of private methods, but needed by Profile + def getActiveImage(self, just_legend=False): + """Returns the stack image object. + """ + if just_legend: + return self._stackItem.getName() + return self._stackItem + + def getColorBarAction(self): + """Returns the action managing the visibility of the colorbar. + + .. warning:: to show/hide the plot colorbar call directly the ColorBar + widget using getColorBarWidget() + + :rtype: QAction + """ + return self._colorbarAction + + def remove(self, legend=None, + kind=('curve', 'image', 'item', 'marker')): + """See :meth:`Plot.Plot.remove`""" + self._plot.remove(legend, kind) + + def setInteractiveMode(self, *args, **kwargs): + """ + See :meth:`Plot.Plot.setInteractiveMode` + """ + self._plot.setInteractiveMode(*args, **kwargs) + + @deprecated(replacement="addShape", since_version="0.13") + def addItem(self, *args, **kwargs): + self.addShape(*args, **kwargs) + + def addShape(self, *args, **kwargs): + """ + See :meth:`Plot.Plot.addShape` + """ + self._plot.addShape(*args, **kwargs) + + +class PlanesWidget(qt.QWidget): + """Widget for the plane/perspective selection + + :param parent: the parent QWidget + """ + sigPlaneSelectionChanged = qt.Signal(int) + + def __init__(self, parent): + super(PlanesWidget, self).__init__(parent) + + self.setSizePolicy(qt.QSizePolicy.Minimum, qt.QSizePolicy.Minimum) + layout0 = qt.QHBoxLayout() + self.setLayout(layout0) + layout0.setContentsMargins(0, 0, 0, 0) + + layout0.addWidget(qt.QLabel("Axes selection:")) + + # By default, the first dimension (dim0) is the frame index/depth/z, + # the second dimension is the image row number/y axis + # and the third dimension is the image column index/x axis + + # 1 + # | 0 + # |/__2 + self.qcbAxisSelection = qt.QComboBox(self) + self._setCBChoices(first_stack_dimension=0) + self.qcbAxisSelection.currentIndexChanged[int].connect( + self.__planeSelectionChanged) + + layout0.addWidget(self.qcbAxisSelection) + + def __planeSelectionChanged(self, idx): + """Callback function when the combobox selection changes + + idx is the dimension number orthogonal to the slice plane, + following the convention: + + - slice plane Dim1-Dim2: perspective 0 + - slice plane Dim0-Dim2: perspective 1 + - slice plane Dim0-Dim1: perspective 2 + """ + self.sigPlaneSelectionChanged.emit(idx) + + def _setCBChoices(self, first_stack_dimension): + self.qcbAxisSelection.clear() + + dim1dim2 = 'Dim%d-Dim%d' % (first_stack_dimension + 1, + first_stack_dimension + 2) + dim0dim2 = 'Dim%d-Dim%d' % (first_stack_dimension, + first_stack_dimension + 2) + dim0dim1 = 'Dim%d-Dim%d' % (first_stack_dimension, + first_stack_dimension + 1) + + self.qcbAxisSelection.addItem(icons.getQIcon("cube-front"), dim1dim2) + self.qcbAxisSelection.addItem(icons.getQIcon("cube-bottom"), dim0dim2) + self.qcbAxisSelection.addItem(icons.getQIcon("cube-left"), dim0dim1) + + def setFirstStackDimension(self, first_stack_dim): + """When viewing the last 3 dimensions of an n-D array (n>3), you can + use this method to change the text in the combobox. + + For instance, for a 7-D array, first stack dim is 4, so the default + "Dim1-Dim2" text should be replaced with "Dim5-Dim6" (dimensions + numbers are 0-based). + + :param int first_stack_dim: First stack dimension (n-3) when viewing the + last 3 dimensions of an n-D array. + """ + self._setCBChoices(first_stack_dim) + + def setPerspective(self, perspective): + """Update the combobox selection. + + - slice plane Dim1-Dim2: perspective 0 + - slice plane Dim0-Dim2: perspective 1 + - slice plane Dim0-Dim1: perspective 2 + + :param perspective: Orthogonal dimension number (0, 1, or 2) + """ + self.qcbAxisSelection.setCurrentIndex(perspective) + + +class StackViewMainWindow(StackView): + """This class is a :class:`StackView` with a menu, an additional toolbar + to set the plot limits, and a status bar to display the value and 3D + index of the data samples hovered by the mouse cursor. + + :param QWidget parent: Parent widget, or None + """ + def __init__(self, parent=None): + self._dataInfo = None + super(StackViewMainWindow, self).__init__(parent) + self.setWindowFlags(qt.Qt.Window) + + # Add toolbars and status bar + self.addToolBar(qt.Qt.BottomToolBarArea, + LimitsToolBar(plot=self._plot)) + + self.statusBar() + + menu = self.menuBar().addMenu('File') + menu.addAction(self._plot.getOutputToolBar().getSaveAction()) + menu.addAction(self._plot.getOutputToolBar().getPrintAction()) + menu.addSeparator() + action = menu.addAction('Quit') + action.triggered[bool].connect(qt.QApplication.instance().quit) + + menu = self.menuBar().addMenu('Edit') + menu.addAction(self._plot.getOutputToolBar().getCopyAction()) + menu.addSeparator() + menu.addAction(self._plot.getResetZoomAction()) + menu.addAction(self._plot.getColormapAction()) + menu.addAction(self.getColorBarAction()) + + menu.addAction(actions.control.KeepAspectRatioAction(self._plot, self)) + menu.addAction(actions.control.YAxisInvertedAction(self._plot, self)) + + menu = self.menuBar().addMenu('Profile') + profileToolBar = self._profileToolBar + menu.addAction(profileToolBar.hLineAction) + menu.addAction(profileToolBar.vLineAction) + menu.addAction(profileToolBar.lineAction) + menu.addAction(profileToolBar.crossAction) + menu.addSeparator() + menu.addAction(profileToolBar._editor) + menu.addSeparator() + menu.addAction(profileToolBar.clearAction) + + # Connect to StackView's signal + self.valueChanged.connect(self._statusBarSlot) + + def _statusBarSlot(self, x, y, value): + """Update status bar with coordinates/value from plots.""" + # todo (after implementing calibration): + # - use floats for (x, y, z) + # - display both indices (dim0, dim1, dim2) and (x, y, z) + msg = "Cursor out of range" + if x is not None and y is not None: + img_idx = self._browser.value() + + if self._perspective == 0: + dim0, dim1, dim2 = img_idx, int(y), int(x) + elif self._perspective == 1: + dim0, dim1, dim2 = int(y), img_idx, int(x) + elif self._perspective == 2: + dim0, dim1, dim2 = int(y), int(x), img_idx + + msg = 'Position: (%d, %d, %d)' % (dim0, dim1, dim2) + if value is not None: + msg += ', Value: %g' % value + if self._dataInfo is not None: + msg = self._dataInfo + ', ' + msg + + self.statusBar().showMessage(msg) + + def setStack(self, stack, *args, **kwargs): + """Set the displayed stack. + + See :meth:`StackView.setStack` for details. + """ + if hasattr(stack, 'dtype') and hasattr(stack, 'shape'): + assert len(stack.shape) == 3 + nframes, height, width = stack.shape + self._dataInfo = 'Data: %dx%dx%d (%s)' % (nframes, height, width, + str(stack.dtype)) + self.statusBar().showMessage(self._dataInfo) + else: + self._dataInfo = None + + # Set the new stack in StackView widget + super(StackViewMainWindow, self).setStack(stack, *args, **kwargs) + self.setStatusBar(None) diff --git a/src/silx/gui/plot/StatsWidget.py b/src/silx/gui/plot/StatsWidget.py new file mode 100644 index 0000000..00f78d0 --- /dev/null +++ b/src/silx/gui/plot/StatsWidget.py @@ -0,0 +1,1658 @@ +# coding: utf-8 +# /*########################################################################## +# +# Copyright (c) 2017-2021 European Synchrotron Radiation Facility +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. +# +# ###########################################################################*/ +""" +Module containing widgets displaying stats from items of a plot. +""" + +__authors__ = ["H. Payno"] +__license__ = "MIT" +__date__ = "24/07/2018" + + +from collections import OrderedDict +from contextlib import contextmanager +import logging +import weakref +import functools +import numpy +import enum +from silx.utils.proxy import docstring +from silx.utils.enum import Enum as _Enum +from silx.gui import qt +from silx.gui import icons +from silx.gui.plot import stats as statsmdl +from silx.gui.widgets.TableWidget import TableWidget +from silx.gui.plot.stats.statshandler import StatsHandler, StatFormatter +from silx.gui.plot.items.core import ItemChangedType +from silx.gui.widgets.FlowLayout import FlowLayout +from . import PlotWidget +from . import items as plotitems + + +_logger = logging.getLogger(__name__) + + +@enum.unique +class UpdateMode(_Enum): + AUTO = 'auto' + MANUAL = 'manual' + + +# Helper class to handle specific calls to PlotWidget and SceneWidget + + +class _Wrapper(qt.QObject): + """Base class for connection with PlotWidget and SceneWidget. + + This class is used when no PlotWidget or SceneWidget is connected. + + :param plot: The plot to be used + """ + + sigItemAdded = qt.Signal(object) + """Signal emitted when a new item is added. + + It provides the added item. + """ + + sigItemRemoved = qt.Signal(object) + """Signal emitted when an item is (about to be) removed. + + It provides the removed item. + """ + + sigCurrentChanged = qt.Signal(object) + """Signal emitted when the current item has changed. + + It provides the current item. + """ + + sigVisibleDataChanged = qt.Signal() + """Signal emitted when the visible data area has changed""" + + def __init__(self, plot=None): + super(_Wrapper, self).__init__(parent=None) + self._plotRef = None if plot is None else weakref.ref(plot) + + def getPlot(self): + """Returns the plot attached to this widget""" + return None if self._plotRef is None else self._plotRef() + + def getItems(self): + """Returns the list of items in the plot + + :rtype: List[object] + """ + return () + + def getSelectedItems(self): + """Returns the list of selected items in the plot + + :rtype: List[object] + """ + return () + + def setCurrentItem(self, item): + """Set the current/active item in the plot + + :param item: The plot item to set as active/current + """ + pass + + def getLabel(self, item): + """Returns the label of the given item. + + :param item: + :rtype: str + """ + return '' + + def getKind(self, item): + """Returns the kind of an item or None if not supported + + :param item: + :rtype: Union[str,None] + """ + return None + + +class _PlotWidgetWrapper(_Wrapper): + """Class handling PlotWidget specific calls and signal connections + + See :class:`._Wrapper` for documentation + + :param PlotWidget plot: + """ + + def __init__(self, plot): + assert isinstance(plot, PlotWidget) + super(_PlotWidgetWrapper, self).__init__(plot) + plot.sigItemAdded.connect(self.sigItemAdded.emit) + plot.sigItemAboutToBeRemoved.connect(self.sigItemRemoved.emit) + plot.sigActiveCurveChanged.connect(self._activeCurveChanged) + plot.sigActiveImageChanged.connect(self._activeImageChanged) + plot.sigActiveScatterChanged.connect(self._activeScatterChanged) + plot.sigPlotSignal.connect(self._limitsChanged) + + def _activeChanged(self, kind): + """Handle change of active curve/image/scatter""" + plot = self.getPlot() + if plot is not None: + item = plot._getActiveItem(kind=kind) + if item is None or self.getKind(item) is not None: + self.sigCurrentChanged.emit(item) + + def _activeCurveChanged(self, previous, current): + self._activeChanged(kind='curve') + + def _activeImageChanged(self, previous, current): + self._activeChanged(kind='image') + + def _activeScatterChanged(self, previous, current): + self._activeChanged(kind='scatter') + + def _limitsChanged(self, event): + """Handle change of plot area limits.""" + if event['event'] == 'limitsChanged': + self.sigVisibleDataChanged.emit() + + def getItems(self): + plot = self.getPlot() + if plot is None: + return () + else: + return [item for item in plot.getItems() if item.isVisible()] + + def getSelectedItems(self): + plot = self.getPlot() + items = [] + if plot is not None: + for kind in plot._ACTIVE_ITEM_KINDS: + item = plot._getActiveItem(kind=kind) + if item is not None: + items.append(item) + return tuple(items) + + def setCurrentItem(self, item): + plot = self.getPlot() + if plot is not None: + kind = self.getKind(item) + if kind in plot._ACTIVE_ITEM_KINDS: + if plot._getActiveItem(kind) != item: + plot._setActiveItem(kind, item.getName()) + + def getLabel(self, item): + return item.getName() + + def getKind(self, item): + if isinstance(item, plotitems.Curve): + return 'curve' + elif isinstance(item, plotitems.ImageData): + return 'image' + elif isinstance(item, plotitems.Scatter): + return 'scatter' + elif isinstance(item, plotitems.Histogram): + return 'histogram' + else: + return None + + +class _SceneWidgetWrapper(_Wrapper): + """Class handling SceneWidget specific calls and signal connections + + See :class:`._Wrapper` for documentation + + :param SceneWidget plot: + """ + + def __init__(self, plot): + # Lazy-import to avoid circular imports + from ..plot3d.SceneWidget import SceneWidget + + assert isinstance(plot, SceneWidget) + super(_SceneWidgetWrapper, self).__init__(plot) + plot.getSceneGroup().sigItemAdded.connect(self.sigItemAdded) + plot.getSceneGroup().sigItemRemoved.connect(self.sigItemRemoved) + plot.selection().sigCurrentChanged.connect(self._currentChanged) + # sigVisibleDataChanged is never emitted + + def _currentChanged(self, current, previous): + self.sigCurrentChanged.emit(current) + + def getItems(self): + plot = self.getPlot() + return () if plot is None else tuple(plot.getSceneGroup().visit()) + + def getSelectedItems(self): + plot = self.getPlot() + return () if plot is None else (plot.selection().getCurrentItem(),) + + def setCurrentItem(self, item): + plot = self.getPlot() + if plot is not None: + plot.selection().setCurrentItem(item) + + def getLabel(self, item): + return item.getLabel() + + def getKind(self, item): + from ..plot3d import items as plot3ditems + + if isinstance(item, (plot3ditems.ImageData, + plot3ditems.ScalarField3D)): + return 'image' + elif isinstance(item, (plot3ditems.Scatter2D, + plot3ditems.Scatter3D)): + return 'scatter' + else: + return None + + +class _ScalarFieldViewWrapper(_Wrapper): + """Class handling ScalarFieldView specific calls and signal connections + + See :class:`._Wrapper` for documentation + + :param SceneWidget plot: + """ + + def __init__(self, plot): + # Lazy-import to avoid circular imports + from ..plot3d.ScalarFieldView import ScalarFieldView + from ..plot3d.items import ScalarField3D + + assert isinstance(plot, ScalarFieldView) + super(_ScalarFieldViewWrapper, self).__init__(plot) + self._item = ScalarField3D() + self._dataChanged() + plot.sigDataChanged.connect(self._dataChanged) + # sigItemAdded, sigItemRemoved, sigVisibleDataChanged are never emitted + + def _dataChanged(self): + plot = self.getPlot() + if plot is not None: + self._item.setData(plot.getData(copy=False), copy=False) + self.sigCurrentChanged.emit(self._item) + + def getItems(self): + plot = self.getPlot() + return () if plot is None else (self._item,) + + def getSelectedItems(self): + return self.getItems() + + def setCurrentItem(self, item): + pass + + def getLabel(self, item): + return 'Data' + + def getKind(self, item): + return 'image' + + +class _Container(object): + """Class to contain a plot item. + + This is apparently needed for compatibility with PySide2, + + :param QObject obj: + """ + def __init__(self, obj): + self._obj = obj + + def __call__(self): + return self._obj + + +class _StatsWidgetBase(object): + """ + Base class for all widgets which want to display statistics + """ + + def __init__(self, statsOnVisibleData, displayOnlyActItem): + self._displayOnlyActItem = displayOnlyActItem + self._statsOnVisibleData = statsOnVisibleData + self._statsHandler = None + self._updateMode = UpdateMode.AUTO + + self.__default_skipped_events = ( + ItemChangedType.ALPHA, + ItemChangedType.COLOR, + ItemChangedType.COLORMAP, + ItemChangedType.SYMBOL, + ItemChangedType.SYMBOL_SIZE, + ItemChangedType.LINE_WIDTH, + ItemChangedType.LINE_STYLE, + ItemChangedType.LINE_BG_COLOR, + ItemChangedType.FILL, + ItemChangedType.HIGHLIGHTED_COLOR, + ItemChangedType.HIGHLIGHTED_STYLE, + ItemChangedType.TEXT, + ItemChangedType.OVERLAY, + ItemChangedType.VISUALIZATION_MODE, + ) + + self._plotWrapper = _Wrapper() + self._dealWithPlotConnection(create=True) + + def setPlot(self, plot): + """Define the plot to interact with + + :param Union[PlotWidget,SceneWidget,None] plot: + The plot containing the items on which statistics are applied + """ + try: + import OpenGL + except ImportError: + has_opengl = False + else: + has_opengl = True + from ..plot3d.SceneWidget import SceneWidget # Lazy import + self._dealWithPlotConnection(create=False) + self.clear() + if plot is None: + self._plotWrapper = _Wrapper() + elif isinstance(plot, PlotWidget): + self._plotWrapper = _PlotWidgetWrapper(plot) + else: + if has_opengl is True: + if isinstance(plot, SceneWidget): + self._plotWrapper = _SceneWidgetWrapper(plot) + else: # Expect a ScalarFieldView + self._plotWrapper = _ScalarFieldViewWrapper(plot) + else: + _logger.warning('OpenGL not installed, %s not managed' % ('SceneWidget qnd ScalarFieldView')) + self._dealWithPlotConnection(create=True) + + def setStats(self, statsHandler): + """Set which stats to display and the associated formatting. + + :param StatsHandler statsHandler: + Set the statistics to be displayed and how to format them using + """ + if statsHandler is None: + statsHandler = StatsHandler(statFormatters=()) + elif isinstance(statsHandler, (list, tuple)): + statsHandler = StatsHandler(statsHandler) + assert isinstance(statsHandler, StatsHandler) + + self._statsHandler = statsHandler + + def getStatsHandler(self): + """Returns the :class:`StatsHandler` in use. + + :rtype: StatsHandler + """ + return self._statsHandler + + def getPlot(self): + """Returns the plot attached to this widget + + :rtype: Union[PlotWidget,SceneWidget,None] + """ + return self._plotWrapper.getPlot() + + def _dealWithPlotConnection(self, create=True): + """Manage connection to plot signals + + Note: connection on Item are managed by _addItem and _removeItem methods + """ + connections = [] # List of (signal, slot) to connect/disconnect + if self._statsOnVisibleData: + connections.append( + (self._plotWrapper.sigVisibleDataChanged, self._updateAllStats)) + + if self._displayOnlyActItem: + connections.append( + (self._plotWrapper.sigCurrentChanged, self._updateCurrentItem)) + else: + connections += [ + (self._plotWrapper.sigItemAdded, self._addItem), + (self._plotWrapper.sigItemRemoved, self._removeItem), + (self._plotWrapper.sigCurrentChanged, self._plotCurrentChanged)] + + for signal, slot in connections: + if create: + signal.connect(slot) + else: + signal.disconnect(slot) + + def _updateItemObserve(self, *args): + """Reload table depending on mode""" + raise NotImplementedError('Base class') + + def _updateCurrentItem(self, *args): + """specific callback for the sigCurrentChanged and with the + _displayOnlyActItem option.""" + raise NotImplementedError('Base class') + + def _updateStats(self, item, data_changed=False, roi_changed=False): + """Update displayed information for given plot item + + :param item: The plot item + :param bool data_changed: is the item data changed. + :param bool roi_changed: is the associated roi changed. + """ + raise NotImplementedError('Base class') + + def _updateAllStats(self): + """Update stats for all rows in the table""" + raise NotImplementedError('Base class') + + def setDisplayOnlyActiveItem(self, displayOnlyActItem): + """Toggle display off all items or only the active/selected one + + :param bool displayOnlyActItem: + True if we want to only show active item + """ + self._displayOnlyActItem = displayOnlyActItem + + def setStatsOnVisibleData(self, b): + """Toggle computation of statistics on whole data or only visible ones. + + .. warning:: When visible data is activated we will process to a simple + filtering of visible data by the user. The filtering is a + simple data sub-sampling. No interpolation is made to fit + data to boundaries. + + :param bool b: True if we want to apply statistics only on visible data + """ + if self._statsOnVisibleData != b: + self._dealWithPlotConnection(create=False) + self._statsOnVisibleData = b + self._dealWithPlotConnection(create=True) + self._updateAllStats() + + def _addItem(self, item): + """Add a plot item to the table + + If item is not supported, it is ignored. + + :param item: The plot item + :returns: True if the item is added to the widget. + :rtype: bool + """ + raise NotImplementedError('Base class') + + def _removeItem(self, item): + """Remove table items corresponding to given plot item from the table. + + :param item: The plot item + """ + raise NotImplementedError('Base class') + + def _plotCurrentChanged(self, current): + """Handle change of current item and update selection in table + + :param current: + """ + raise NotImplementedError('Base class') + + def clear(self): + """clear GUI""" + pass + + def _skipPlotItemChangedEvent(self, event): + """ + + :param ItemChangedtype event: event to filter or not + :return: True if we want to ignore this ItemChangedtype + :rtype: bool + """ + return event in self.__default_skipped_events + + def setUpdateMode(self, mode): + """Set the way to update the displayed statistics. + + :param mode: mode requested for update + :type mode: Union[str,UpdateMode] + """ + mode = UpdateMode.from_value(mode) + if mode != self._updateMode: + self._updateMode = mode + self._updateModeHasChanged() + + def getUpdateMode(self): + """Returns update mode (See :meth:`setUpdateMode`). + + :return: update mode + :rtype: UpdateMode + """ + return self._updateMode + + def _updateModeHasChanged(self): + """callback when the update mode has changed""" + pass + + +class StatsTable(_StatsWidgetBase, TableWidget): + """ + TableWidget displaying for each items contained by the Plot some + information: + + * legend + * minimal value + * maximal value + * standard deviation (std) + + :param QWidget parent: The widget's parent. + :param Union[PlotWidget,SceneWidget] plot: + :class:`PlotWidget` or :class:`SceneWidget` instance on which to operate + """ + + _LEGEND_HEADER_DATA = 'legend' + _KIND_HEADER_DATA = 'kind' + + sigUpdateModeChanged = qt.Signal(object) + """Signal emitted when the update mode changed""" + + def __init__(self, parent=None, plot=None): + TableWidget.__init__(self, parent) + _StatsWidgetBase.__init__(self, statsOnVisibleData=False, + displayOnlyActItem=False) + + # Init for _displayOnlyActItem == False + assert self._displayOnlyActItem is False + self.setSelectionBehavior(qt.QAbstractItemView.SelectRows) + self.setSelectionMode(qt.QAbstractItemView.SingleSelection) + self.currentItemChanged.connect(self._currentItemChanged) + + self.setRowCount(0) + self.setColumnCount(2) + + # Init headers + headerItem = qt.QTableWidgetItem(self._LEGEND_HEADER_DATA.title()) + headerItem.setData(qt.Qt.UserRole, self._LEGEND_HEADER_DATA) + self.setHorizontalHeaderItem(0, headerItem) + headerItem = qt.QTableWidgetItem(self._KIND_HEADER_DATA.title()) + headerItem.setData(qt.Qt.UserRole, self._KIND_HEADER_DATA) + self.setHorizontalHeaderItem(1, headerItem) + + self.setSortingEnabled(True) + self.setPlot(plot) + + @contextmanager + def _disableSorting(self): + """Context manager that disables table sorting + + Previous state is restored when leaving + """ + sorting = self.isSortingEnabled() + if sorting: + self.setSortingEnabled(False) + yield + if sorting: + self.setSortingEnabled(sorting) + + def setStats(self, statsHandler): + """Set which stats to display and the associated formatting. + + :param StatsHandler statsHandler: + Set the statistics to be displayed and how to format them using + """ + self._removeAllItems() + _StatsWidgetBase.setStats(self, statsHandler) + + self.setRowCount(0) + self.setColumnCount(len(self._statsHandler.stats) + 2) # + legend and kind + + for index, stat in enumerate(self._statsHandler.stats.values()): + headerItem = qt.QTableWidgetItem(stat.name.capitalize()) + headerItem.setData(qt.Qt.UserRole, stat.name) + if stat.description is not None: + headerItem.setToolTip(stat.description) + self.setHorizontalHeaderItem(2 + index, headerItem) + + horizontalHeader = self.horizontalHeader() + horizontalHeader.setSectionResizeMode(qt.QHeaderView.ResizeToContents) + + self._updateItemObserve() + + def setPlot(self, plot): + """Define the plot to interact with + + :param Union[PlotWidget,SceneWidget,None] plot: + The plot containing the items on which statistics are applied + """ + _StatsWidgetBase.setPlot(self, plot) + self._updateItemObserve() + + def clear(self): + """Define the plot to interact with + + :param Union[PlotWidget,SceneWidget,None] plot: + The plot containing the items on which statistics are applied + """ + self._removeAllItems() + + def _updateItemObserve(self, *args): + """Reload table depending on mode""" + self._removeAllItems() + + # Get selected or all items from the plot + if self._displayOnlyActItem: # Only selected + items = self._plotWrapper.getSelectedItems() + else: # All items + items = self._plotWrapper.getItems() + + # Add items to the plot + for item in items: + self._addItem(item) + + def _updateCurrentItem(self, *args): + """specific callback for the sigCurrentChanged and with the + _displayOnlyActItem option. + + Behavior: create the tableItems if does not exists. + If exists, update it only when we are in 'auto' mode""" + if self.getUpdateMode() is UpdateMode.MANUAL: + # when sigCurrentChanged is giving the current item + if len(args) > 0 and isinstance(args[0], (plotitems.Curve, plotitems.Histogram, plotitems.ImageData, plotitems.Scatter)): + item = args[0] + tableItems = self._itemToTableItems(item) + # if the table does not exists yet + if len(tableItems) == 0: + self._updateItemObserve() + else: + # in this case no current item + self._updateItemObserve(args) + else: + # auto mode + self._updateItemObserve(args) + + def _plotCurrentChanged(self, current): + """Handle change of current item and update selection in table + + :param current: + """ + row = self._itemToRow(current) + if row is None: + if self.currentRow() >= 0: + self.setCurrentCell(-1, -1) + elif row != self.currentRow(): + self.setCurrentCell(row, 0) + + def _tableItemToItem(self, tableItem): + """Find the plot item corresponding to a table item + + :param QTableWidgetItem tableItem: + :rtype: QObject + """ + container = tableItem.data(qt.Qt.UserRole) + return container() + + def _itemToRow(self, item): + """Find the row corresponding to a plot item + + :param item: The plot item + :return: The corresponding row index + :rtype: Union[int,None] + """ + for row in range(self.rowCount()): + tableItem = self.item(row, 0) + if self._tableItemToItem(tableItem) == item: + return row + return None + + def _itemToTableItems(self, item): + """Find all table items corresponding to a plot item + + :param item: The plot item + :return: An ordered dict of column name to QTableWidgetItem mapping + for the given plot item. + :rtype: OrderedDict + """ + result = OrderedDict() + row = self._itemToRow(item) + if row is not None: + for column in range(self.columnCount()): + tableItem = self.item(row, column) + if self._tableItemToItem(tableItem) != item: + _logger.error("Table item/plot item mismatch") + else: + header = self.horizontalHeaderItem(column) + name = header.data(qt.Qt.UserRole) + result[name] = tableItem + return result + + def _plotItemChanged(self, event): + """Handle modifications of the items. + + :param event: + """ + if self.getUpdateMode() is UpdateMode.MANUAL: + return + if self._skipPlotItemChangedEvent(event) is True: + return + else: + item = self.sender() + self._updateStats(item, data_changed=True) + # deal with stat items visibility + if event is ItemChangedType.VISIBLE: + if len(self._itemToTableItems(item).items()) > 0: + item_0 = list(self._itemToTableItems(item).values())[0] + row_index = item_0.row() + self.setRowHidden(row_index, not item.isVisible()) + + def _addItem(self, item): + """Add a plot item to the table + + If item is not supported, it is ignored. + + :param item: The plot item + :returns: True if the item is added to the widget. + :rtype: bool + """ + if self._itemToRow(item) is not None: + _logger.info("Item already present in the table") + self._updateStats(item) + return True + + kind = self._plotWrapper.getKind(item) + if kind not in statsmdl.BASIC_COMPATIBLE_KINDS: + _logger.info("Item has not a supported type: %s", item) + return False + + # Prepare table items + tableItems = [ + qt.QTableWidgetItem(), # Legend + qt.QTableWidgetItem()] # Kind + + for column in range(2, self.columnCount()): + header = self.horizontalHeaderItem(column) + name = header.data(qt.Qt.UserRole) + + formatter = self._statsHandler.formatters[name] + if formatter: + tableItem = formatter.tabWidgetItemClass() + else: + tableItem = qt.QTableWidgetItem() + + tooltip = self._statsHandler.stats[name].getToolTip(kind=kind) + if tooltip is not None: + tableItem.setToolTip(tooltip) + + tableItems.append(tableItem) + + # Disable sorting while adding table items + with self._disableSorting(): + # Add a row to the table + self.setRowCount(self.rowCount() + 1) + + # Add table items to the last row + row = self.rowCount() - 1 + for column, tableItem in enumerate(tableItems): + tableItem.setData(qt.Qt.UserRole, _Container(item)) + tableItem.setFlags( + qt.Qt.ItemIsEnabled | qt.Qt.ItemIsSelectable) + self.setItem(row, column, tableItem) + + # Update table items content + self._updateStats(item, data_changed=True) + + # Listen for item changes + # Using queued connection to avoid issue with sender + # being that of the signal calling the signal + item.sigItemChanged.connect(self._plotItemChanged, + qt.Qt.QueuedConnection) + + return True + + def _removeItem(self, item): + """Remove table items corresponding to given plot item from the table. + + :param item: The plot item + """ + row = self._itemToRow(item) + if row is None: + kind = self._plotWrapper.getKind(item) + if kind in statsmdl.BASIC_COMPATIBLE_KINDS: + _logger.error("Removing item that is not in table: %s", str(item)) + return + item.sigItemChanged.disconnect(self._plotItemChanged) + self.removeRow(row) + + def _removeAllItems(self): + """Remove content of the table""" + for row in range(self.rowCount()): + tableItem = self.item(row, 0) + item = self._tableItemToItem(tableItem) + item.sigItemChanged.disconnect(self._plotItemChanged) + self.clearContents() + self.setRowCount(0) + + def _updateStats(self, item, data_changed=False, roi_changed=False): + """Update displayed information for given plot item + + :param item: The plot item + :param bool data_changed: is the item data changed. + :param bool roi_changed: is the associated roi changed. + """ + if item is None: + return + plot = self.getPlot() + if plot is None: + _logger.info("Plot not available") + return + + row = self._itemToRow(item) + if row is None: + _logger.error("This item is not in the table: %s", str(item)) + return + + statsHandler = self.getStatsHandler() + if statsHandler is not None: + # _updateStats is call when the plot visible area change. + # to force stats update we consider roi changed + if self._statsOnVisibleData: + roi_changed = True + else: + roi_changed = False + stats = statsHandler.calculate( + item, plot, self._statsOnVisibleData, + data_changed=data_changed, roi_changed=roi_changed) + else: + stats = {} + + with self._disableSorting(): + for name, tableItem in self._itemToTableItems(item).items(): + if name == self._LEGEND_HEADER_DATA: + text = self._plotWrapper.getLabel(item) + tableItem.setText(text) + elif name == self._KIND_HEADER_DATA: + tableItem.setText(self._plotWrapper.getKind(item)) + else: + value = stats.get(name) + if value is None: + _logger.error("Value not found for: %s", name) + tableItem.setText('-') + else: + tableItem.setText(str(value)) + + def _updateAllStats(self, is_request=False): + """Update stats for all rows in the table + + :param bool is_request: True if come from a manual request + """ + if self.getUpdateMode() is UpdateMode.MANUAL and not is_request: + return + with self._disableSorting(): + for row in range(self.rowCount()): + tableItem = self.item(row, 0) + item = self._tableItemToItem(tableItem) + self._updateStats(item, data_changed=is_request) + + def _currentItemChanged(self, current, previous): + """Handle change of selection in table and sync plot selection + + :param QTableWidgetItem current: + :param QTableWidgetItem previous: + """ + if current and current.row() >= 0: + item = self._tableItemToItem(current) + self._plotWrapper.setCurrentItem(item) + + def setDisplayOnlyActiveItem(self, displayOnlyActItem): + """Toggle display off all items or only the active/selected one + + :param bool displayOnlyActItem: + True if we want to only show active item + """ + if self._displayOnlyActItem == displayOnlyActItem: + return + self._dealWithPlotConnection(create=False) + if not self._displayOnlyActItem: + self.currentItemChanged.disconnect(self._currentItemChanged) + + _StatsWidgetBase.setDisplayOnlyActiveItem(self, displayOnlyActItem) + + self._updateItemObserve() + self._dealWithPlotConnection(create=True) + + if not self._displayOnlyActItem: + self.currentItemChanged.connect(self._currentItemChanged) + self.setSelectionMode(qt.QAbstractItemView.SingleSelection) + else: + self.setSelectionMode(qt.QAbstractItemView.NoSelection) + + def _updateModeHasChanged(self): + self.sigUpdateModeChanged.emit(self._updateMode) + + +class UpdateModeWidget(qt.QWidget): + """Widget used to select the mode of update""" + sigUpdateModeChanged = qt.Signal(object) + """signal emitted when the mode for update changed""" + sigUpdateRequested = qt.Signal() + """signal emitted when an manual request for example is activate""" + + def __init__(self, parent=None): + qt.QWidget.__init__(self, parent) + self.setLayout(qt.QHBoxLayout()) + self._buttonGrp = qt.QButtonGroup(parent=self) + self._buttonGrp.setExclusive(True) + + spacer = qt.QSpacerItem(20, 20, + qt.QSizePolicy.Expanding, + qt.QSizePolicy.Minimum) + self.layout().addItem(spacer) + + self._autoRB = qt.QRadioButton('auto', parent=self) + self.layout().addWidget(self._autoRB) + self._buttonGrp.addButton(self._autoRB) + + self._manualRB = qt.QRadioButton('manual', parent=self) + self.layout().addWidget(self._manualRB) + self._buttonGrp.addButton(self._manualRB) + self._manualRB.setChecked(True) + + refresh_icon = icons.getQIcon('view-refresh') + self._updatePB = qt.QPushButton(refresh_icon, '', parent=self) + self.layout().addWidget(self._updatePB) + + # connect signal / SLOT + self._updatePB.clicked.connect(self._updateRequested) + self._manualRB.toggled.connect(self._manualButtonToggled) + self._autoRB.toggled.connect(self._autoButtonToggled) + + def _manualButtonToggled(self, checked): + if checked: + self.setUpdateMode(UpdateMode.MANUAL) + self.sigUpdateModeChanged.emit(self.getUpdateMode()) + + def _autoButtonToggled(self, checked): + if checked: + self.setUpdateMode(UpdateMode.AUTO) + self.sigUpdateModeChanged.emit(self.getUpdateMode()) + + def _updateRequested(self): + if self.getUpdateMode() is UpdateMode.MANUAL: + self.sigUpdateRequested.emit() + + def setUpdateMode(self, mode): + """Set the way to update the displayed statistics. + + :param mode: mode requested for update + :type mode: Union[str,UpdateMode] + """ + mode = UpdateMode.from_value(mode) + + if mode is UpdateMode.AUTO: + if not self._autoRB.isChecked(): + self._autoRB.setChecked(True) + elif mode is UpdateMode.MANUAL: + if not self._manualRB.isChecked(): + self._manualRB.setChecked(True) + else: + raise ValueError('mode', mode, 'is not recognized') + + def getUpdateMode(self): + """Returns update mode (See :meth:`setUpdateMode`). + + :return: the active update mode + :rtype: UpdateMode + """ + if self._manualRB.isChecked(): + return UpdateMode.MANUAL + elif self._autoRB.isChecked(): + return UpdateMode.AUTO + else: + raise RuntimeError("No mode selected") + + def showRadioButtons(self, show): + """show / hide the QRadioButtons + + :param bool show: if True make RadioButton visible + """ + self._autoRB.setVisible(show) + self._manualRB.setVisible(show) + + +class _OptionsWidget(qt.QToolBar): + + def __init__(self, parent=None, updateMode=None, displayOnlyActItem=False): + assert updateMode is not None + qt.QToolBar.__init__(self, parent) + self.setIconSize(qt.QSize(16, 16)) + + action = qt.QAction(self) + action.setIcon(icons.getQIcon("stats-active-items")) + action.setText("Active items only") + action.setToolTip("Display stats for active items only.") + action.setCheckable(True) + action.setChecked(displayOnlyActItem) + self.__displayActiveItems = action + + action = qt.QAction(self) + action.setIcon(icons.getQIcon("stats-whole-items")) + action.setText("All items") + action.setToolTip("Display stats for all available items.") + action.setCheckable(True) + self.__displayWholeItems = action + + action = qt.QAction(self) + action.setIcon(icons.getQIcon("stats-visible-data")) + action.setText("Use the visible data range") + action.setToolTip("Use the visible data range.<br/>" + "If activated the data is filtered to only use" + "visible data of the plot." + "The filtering is a data sub-sampling." + "No interpolation is made to fit data to" + "boundaries.") + action.setCheckable(True) + self.__useVisibleData = action + + action = qt.QAction(self) + action.setIcon(icons.getQIcon("stats-whole-data")) + action.setText("Use the full data range") + action.setToolTip("Use the full data range.") + action.setCheckable(True) + action.setChecked(True) + self.__useWholeData = action + + self.addAction(self.__displayWholeItems) + self.addAction(self.__displayActiveItems) + self.addSeparator() + self.addAction(self.__useVisibleData) + self.addAction(self.__useWholeData) + + self.itemSelection = qt.QActionGroup(self) + self.itemSelection.setExclusive(True) + self.itemSelection.addAction(self.__displayActiveItems) + self.itemSelection.addAction(self.__displayWholeItems) + + self.dataRangeSelection = qt.QActionGroup(self) + self.dataRangeSelection.setExclusive(True) + self.dataRangeSelection.addAction(self.__useWholeData) + self.dataRangeSelection.addAction(self.__useVisibleData) + + self.__updateStatsAction = qt.QAction(self) + self.__updateStatsAction.setIcon(icons.getQIcon("view-refresh")) + self.__updateStatsAction.setText("update statistics") + self.__updateStatsAction.setToolTip("update statistics") + self.__updateStatsAction.setCheckable(False) + self._updateStatsSep = self.addSeparator() + self.addAction(self.__updateStatsAction) + + self._setUpdateMode(mode=updateMode) + + # expose API + self.sigUpdateStats = self.__updateStatsAction.triggered + + def isActiveItemMode(self): + return self.itemSelection.checkedAction() is self.__displayActiveItems + + def setDisplayActiveItems(self, only_active): + self.__displayActiveItems.setChecked(only_active) + self.__displayWholeItems.setChecked(not only_active) + + def isVisibleDataRangeMode(self): + return self.dataRangeSelection.checkedAction() is self.__useVisibleData + + def setVisibleDataRangeModeEnabled(self, enabled): + """Enable/Disable the visible data range mode + + :param bool enabled: True to allow user to choose + stats on visible data + """ + self.__useVisibleData.setEnabled(enabled) + if not enabled: + self.__useWholeData.setChecked(True) + + def _setUpdateMode(self, mode): + self.__updateStatsAction.setVisible(mode == UpdateMode.MANUAL) + self._updateStatsSep.setVisible(mode == UpdateMode.MANUAL) + + def getUpdateStatsAction(self): + """ + + :return: the action for the automatic mode + :rtype: QAction + """ + return self.__updateStatsAction + + +class StatsWidget(qt.QWidget): + """ + Widget displaying a set of :class:`Stat` to be displayed on a + :class:`StatsTable` and to be apply on items contained in the :class:`Plot` + Also contains options to: + + * compute statistics on all the data or on visible data only + * show statistics of all items or only the active one + + :param QWidget parent: Qt parent + :param Union[PlotWidget,SceneWidget] plot: + The plot containing items on which we want statistics. + :param StatsHandler stats: + Set the statistics to be displayed and how to format them using + """ + + sigVisibilityChanged = qt.Signal(bool) + """Signal emitted when the visibility of this widget changes. + + It Provides the visibility of the widget. + """ + + NUMBER_FORMAT = '{0:.3f}' + + def __init__(self, parent=None, plot=None, stats=None): + qt.QWidget.__init__(self, parent) + self.setLayout(qt.QVBoxLayout()) + self.layout().setContentsMargins(0, 0, 0, 0) + self._options = _OptionsWidget(parent=self, updateMode=UpdateMode.MANUAL) + self.layout().addWidget(self._options) + self._statsTable = StatsTable(parent=self, plot=plot) + self._statsTable.setDisplayOnlyActiveItem(self._options.isActiveItemMode()) + self._options._setUpdateMode(mode=self._statsTable.getUpdateMode()) + self.setStats(stats) + + self.layout().addWidget(self._statsTable) + + old = self._statsTable.blockSignals(True) + self._options.itemSelection.triggered.connect( + self._optSelectionChanged) + self._options.dataRangeSelection.triggered.connect( + self._optDataRangeChanged) + self._optDataRangeChanged() + self._statsTable.blockSignals(old) + + self._statsTable.sigUpdateModeChanged.connect(self._options._setUpdateMode) + callback = functools.partial(self._getStatsTable()._updateAllStats, is_request=True) + self._options.sigUpdateStats.connect(callback) + + def _getStatsTable(self): + """Returns the :class:`StatsTable` used by this widget. + + :rtype: StatsTable + """ + return self._statsTable + + def showEvent(self, event): + self.sigVisibilityChanged.emit(True) + qt.QWidget.showEvent(self, event) + + def hideEvent(self, event): + self.sigVisibilityChanged.emit(False) + qt.QWidget.hideEvent(self, event) + + def _optSelectionChanged(self, action=None): + self._getStatsTable().setDisplayOnlyActiveItem( + self._options.isActiveItemMode()) + + def _optDataRangeChanged(self, action=None): + self._getStatsTable().setStatsOnVisibleData( + self._options.isVisibleDataRangeMode()) + + # Proxy methods + + @docstring(StatsTable) + def setStats(self, statsHandler): + return self._getStatsTable().setStats(statsHandler=statsHandler) + + @docstring(StatsTable) + def setPlot(self, plot): + self._options.setVisibleDataRangeModeEnabled( + plot is None or isinstance(plot, PlotWidget)) + return self._getStatsTable().setPlot(plot=plot) + + @docstring(StatsTable) + def getPlot(self): + return self._getStatsTable().getPlot() + + @docstring(StatsTable) + def setDisplayOnlyActiveItem(self, displayOnlyActItem): + old = self._options.blockSignals(True) + # update the options + self._options.setDisplayActiveItems(displayOnlyActItem) + self._options.blockSignals(old) + return self._getStatsTable().setDisplayOnlyActiveItem( + displayOnlyActItem=displayOnlyActItem) + + @docstring(StatsTable) + def setStatsOnVisibleData(self, b): + return self._getStatsTable().setStatsOnVisibleData(b=b) + + @docstring(StatsTable) + def getUpdateMode(self): + return self._statsTable.getUpdateMode() + + @docstring(StatsTable) + def setUpdateMode(self, mode): + self._statsTable.setUpdateMode(mode) + + +DEFAULT_STATS = StatsHandler(( + (statsmdl.StatMin(), StatFormatter()), + statsmdl.StatCoordMin(), + (statsmdl.StatMax(), StatFormatter()), + statsmdl.StatCoordMax(), + statsmdl.StatCOM(), + (('mean', numpy.mean), StatFormatter()), + (('std', numpy.std), StatFormatter()), +)) + + +class BasicStatsWidget(StatsWidget): + """ + Widget defining a simple set of :class:`Stat` to be displayed on a + :class:`StatsWidget`. + + :param QWidget parent: Qt parent + :param PlotWidget plot: + The plot containing items on which we want statistics. + :param StatsHandler stats: + Set the statistics to be displayed and how to format them using + + .. snapshotqt:: img/BasicStatsWidget.png + :width: 300px + :align: center + + from silx.gui.plot import Plot1D + from silx.gui.plot.StatsWidget import BasicStatsWidget + + plot = Plot1D() + x = range(100) + y = x + plot.addCurve(x, y, legend='curve_0') + plot.setActiveCurve('curve_0') + + widget = BasicStatsWidget(plot=plot) + widget.show() + """ + def __init__(self, parent=None, plot=None): + StatsWidget.__init__(self, parent=parent, plot=plot, + stats=DEFAULT_STATS) + + +class _BaseLineStatsWidget(_StatsWidgetBase, qt.QWidget): + """ + Widget made to display stats into a QLayout with couple (QLabel, QLineEdit) + created for each stats. + The layout can be defined prior of adding any statistic. + + :param QWidget parent: Qt parent + :param Union[PlotWidget,SceneWidget] plot: + The plot containing items on which we want statistics. + :param str kind: the kind of plotitems we want to display + :param StatsHandler stats: + Set the statistics to be displayed and how to format them using + :param bool statsOnVisibleData: compute statistics for the whole data or + only visible ones. + """ + + sigUpdateModeChanged = qt.Signal(object) + """Signal emitted when the update mode changed""" + + def __init__(self, parent=None, plot=None, kind='curve', stats=None, + statsOnVisibleData=False): + self._item_kind = kind + """The item displayed""" + self._statQlineEdit = {} + """list of legends actually displayed""" + self._n_statistics_per_line = 4 + """number of statistics displayed per line in the grid layout""" + qt.QWidget.__init__(self, parent) + _StatsWidgetBase.__init__(self, + statsOnVisibleData=statsOnVisibleData, + displayOnlyActItem=True) + self.setLayout(self._createLayout()) + self.setPlot(plot) + if stats is not None: + self.setStats(stats) + + def _addItemForStatistic(self, statistic): + assert isinstance(statistic, statsmdl.StatBase) + assert statistic.name in self._statsHandler.stats + + self.layout().setSpacing(2) + self.layout().setContentsMargins(2, 2, 2, 2) + + if isinstance(self.layout(), qt.QGridLayout): + parent = self + else: + widget = qt.QWidget(parent=self) + parent = widget + + qLabel = qt.QLabel(statistic.name + ':', parent=parent) + qLineEdit = qt.QLineEdit('', parent=parent) + qLineEdit.setReadOnly(True) + + self._addStatsWidgetsToLayout(qLabel=qLabel, qLineEdit=qLineEdit) + self._statQlineEdit[statistic.name] = qLineEdit + + def setPlot(self, plot): + """Define the plot to interact with + + :param Union[PlotWidget,SceneWidget,None] plot: + The plot containing the items on which statistics are applied + """ + _StatsWidgetBase.setPlot(self, plot) + self._updateAllStats() + + def _addStatsWidgetsToLayout(self, qLabel, qLineEdit): + raise NotImplementedError('Base class') + + def setStats(self, statsHandler): + """Set which stats to display and the associated formatting. + + :param StatsHandler statsHandler: + Set the statistics to be displayed and how to format them using + """ + _StatsWidgetBase.setStats(self, statsHandler) + for statName, stat in list(self._statsHandler.stats.items()): + self._addItemForStatistic(stat) + self._updateAllStats() + + def _activeItemChanged(self, kind, previous, current): + if self.getUpdateMode() is UpdateMode.MANUAL: + return + if kind == self._item_kind: + self._updateAllStats() + + def _updateAllStats(self): + plot = self.getPlot() + if plot is not None: + _items = self._plotWrapper.getSelectedItems() + + def kind_filter(_item): + return self._plotWrapper.getKind(_item) == self.getKind() + items = list(filter(kind_filter, _items)) + assert len(items) in (0, 1) + if len(items) == 1: + self._setItem(items[0]) + + def setKind(self, kind): + """Change the kind of active item to display + :param str kind: kind of item to display information for ('curve' ...) + """ + if self._item_kind != kind: + self._item_kind = kind + self._updateItemObserve() + + def getKind(self): + """ + :return: kind of item we want to compute statistic for + :rtype: str + """ + return self._item_kind + + def _setItem(self, item, data_changed=True): + if item is None: + for stat_name, stat_widget in self._statQlineEdit.items(): + stat_widget.setText('') + elif (self._statsHandler is not None and len( + self._statsHandler.stats) > 0): + plot = self.getPlot() + if plot is not None: + statsValDict = self._statsHandler.calculate(item, + plot, + self._statsOnVisibleData, + data_changed=data_changed) + for statName, statVal in list(statsValDict.items()): + self._statQlineEdit[statName].setText(statVal) + + def _updateItemObserve(self, *argv): + if self.getUpdateMode() is UpdateMode.MANUAL: + return + assert self._displayOnlyActItem + _items = self._plotWrapper.getSelectedItems() + + def kind_filter(_item): + return self._plotWrapper.getKind(_item) == self.getKind() + items = list(filter(kind_filter, _items)) + assert len(items) in (0, 1) + _item = items[0] if len(items) == 1 else None + self._setItem(_item, data_changed=True) + + def _updateCurrentItem(self): + self._updateItemObserve() + + def _createLayout(self): + """create an instance of the main QLayout""" + raise NotImplementedError('Base class') + + def _addItem(self, item): + raise NotImplementedError('Display only the active item') + + def _removeItem(self, item): + raise NotImplementedError('Display only the active item') + + def _plotCurrentChanged(self, current): + raise NotImplementedError('Display only the active item') + + def _updateModeHasChanged(self): + self.sigUpdateModeChanged.emit(self._updateMode) + + +class _BasicLineStatsWidget(_BaseLineStatsWidget): + def __init__(self, parent=None, plot=None, kind='curve', + stats=DEFAULT_STATS, statsOnVisibleData=False): + _BaseLineStatsWidget.__init__(self, parent=parent, kind=kind, + plot=plot, stats=stats, + statsOnVisibleData=statsOnVisibleData) + + def _createLayout(self): + return FlowLayout() + + def _addStatsWidgetsToLayout(self, qLabel, qLineEdit): + # create a mother widget to make sure both qLabel & qLineEdit will + # always be displayed side by side + widget = qt.QWidget(parent=self) + widget.setLayout(qt.QHBoxLayout()) + widget.layout().setSpacing(0) + widget.layout().setContentsMargins(0, 0, 0, 0) + + widget.layout().addWidget(qLabel) + widget.layout().addWidget(qLineEdit) + + self.layout().addWidget(widget) + + def _addOptionsWidget(self, widget): + self.layout().addWidget(widget) + + +class BasicLineStatsWidget(qt.QWidget): + """ + Widget defining a simple set of :class:`Stat` to be displayed on a + :class:`LineStatsWidget`. + + :param QWidget parent: Qt parent + :param Union[PlotWidget,SceneWidget] plot: + The plot containing items on which we want statistics. + :param str kind: the kind of plotitems we want to display + :param StatsHandler stats: + Set the statistics to be displayed and how to format them using + :param bool statsOnVisibleData: compute statistics for the whole data or + only visible ones. + """ + def __init__(self, parent=None, plot=None, kind='curve', + stats=DEFAULT_STATS, statsOnVisibleData=False): + qt.QWidget.__init__(self, parent) + self.setLayout(qt.QHBoxLayout()) + self.layout().setSpacing(0) + self.layout().setContentsMargins(0, 0, 0, 0) + self._lineStatsWidget = _BasicLineStatsWidget(parent=self, plot=plot, + kind=kind, stats=stats, + statsOnVisibleData=statsOnVisibleData) + self.layout().addWidget(self._lineStatsWidget) + + self._options = UpdateModeWidget() + self._options.setUpdateMode(self._lineStatsWidget.getUpdateMode()) + self._options.showRadioButtons(False) + self.layout().addWidget(self._options) + + # connect Signal ? SLOT + self._lineStatsWidget.sigUpdateModeChanged.connect(self._options.setUpdateMode) + self._options.sigUpdateModeChanged.connect(self._lineStatsWidget.setUpdateMode) + self._options.sigUpdateRequested.connect(self._lineStatsWidget._updateAllStats) + + def showControl(self, visible): + self._options.setVisible(visible) + + # Proxy methods + + @docstring(_BasicLineStatsWidget) + def setUpdateMode(self, mode): + self._lineStatsWidget.setUpdateMode(mode=mode) + + @docstring(_BasicLineStatsWidget) + def getUpdateMode(self): + return self._lineStatsWidget.getUpdateMode() + + @docstring(_BasicLineStatsWidget) + def setPlot(self, plot): + self._lineStatsWidget.setPlot(plot=plot) + + @docstring(_BasicLineStatsWidget) + def setStats(self, statsHandler): + self._lineStatsWidget.setStats(statsHandler=statsHandler) + + @docstring(_BasicLineStatsWidget) + def setKind(self, kind): + self._lineStatsWidget.setKind(kind=kind) + + @docstring(_BasicLineStatsWidget) + def getKind(self): + return self._lineStatsWidget.getKind() + + @docstring(_BasicLineStatsWidget) + def setStatsOnVisibleData(self, b): + self._lineStatsWidget.setStatsOnVisibleData(b) + + @docstring(UpdateModeWidget) + def showRadioButtons(self, show): + self._options.showRadioButtons(show=show) + + +class _BasicGridStatsWidget(_BaseLineStatsWidget): + def __init__(self, parent=None, plot=None, kind='curve', + stats=DEFAULT_STATS, statsOnVisibleData=False, + statsPerLine=4): + _BaseLineStatsWidget.__init__(self, parent=parent, kind=kind, + plot=plot, stats=stats, + statsOnVisibleData=statsOnVisibleData) + self._n_statistics_per_line = statsPerLine + + def _addStatsWidgetsToLayout(self, qLabel, qLineEdit): + column = len(self._statQlineEdit) % self._n_statistics_per_line + row = len(self._statQlineEdit) // self._n_statistics_per_line + self.layout().addWidget(qLabel, row, column * 2) + self.layout().addWidget(qLineEdit, row, column * 2 + 1) + + def _createLayout(self): + return qt.QGridLayout() + + +class BasicGridStatsWidget(qt.QWidget): + """ + pymca design like widget + + :param QWidget parent: Qt parent + :param Union[PlotWidget,SceneWidget] plot: + The plot containing items on which we want statistics. + :param StatsHandler stats: + Set the statistics to be displayed and how to format them using + :param str kind: the kind of plotitems we want to display + :param bool statsOnVisibleData: compute statistics for the whole data or + only visible ones. + :param int statsPerLine: number of statistic to be displayed per line + + .. snapshotqt:: img/BasicGridStatsWidget.png + :width: 600px + :align: center + + from silx.gui.plot import Plot1D + from silx.gui.plot.StatsWidget import BasicGridStatsWidget + + plot = Plot1D() + x = range(100) + y = x + plot.addCurve(x, y, legend='curve_0') + plot.setActiveCurve('curve_0') + + widget = BasicGridStatsWidget(plot=plot, kind='curve') + widget.show() + """ + + def __init__(self, parent=None, plot=None, kind='curve', + stats=DEFAULT_STATS, statsOnVisibleData=False): + qt.QWidget.__init__(self, parent) + self.setLayout(qt.QVBoxLayout()) + self.layout().setSpacing(0) + self.layout().setContentsMargins(0, 0, 0, 0) + + self._options = UpdateModeWidget() + self._options.showRadioButtons(False) + self.layout().addWidget(self._options) + + self._lineStatsWidget = _BasicGridStatsWidget(parent=self, plot=plot, + kind=kind, stats=stats, + statsOnVisibleData=statsOnVisibleData) + self.layout().addWidget(self._lineStatsWidget) + + # tune options + self._options.setUpdateMode(self._lineStatsWidget.getUpdateMode()) + + # connect Signal ? SLOT + self._lineStatsWidget.sigUpdateModeChanged.connect(self._options.setUpdateMode) + self._options.sigUpdateModeChanged.connect(self._lineStatsWidget.setUpdateMode) + self._options.sigUpdateRequested.connect(self._lineStatsWidget._updateAllStats) + + def showControl(self, visible): + self._options.setVisible(visible) + + @docstring(_BasicGridStatsWidget) + def setUpdateMode(self, mode): + self._lineStatsWidget.setUpdateMode(mode=mode) + + @docstring(_BasicGridStatsWidget) + def getUpdateMode(self): + return self._lineStatsWidget.getUpdateMode() + + @docstring(_BasicGridStatsWidget) + def setPlot(self, plot): + self._lineStatsWidget.setPlot(plot=plot) + + @docstring(_BasicGridStatsWidget) + def setStats(self, statsHandler): + self._lineStatsWidget.setStats(statsHandler=statsHandler) + + @docstring(_BasicGridStatsWidget) + def setKind(self, kind): + self._lineStatsWidget.setKind(kind=kind) + + @docstring(_BasicGridStatsWidget) + def getKind(self): + return self._lineStatsWidget.getKind() + + @docstring(_BasicGridStatsWidget) + def setStatsOnVisibleData(self, b): + self._lineStatsWidget.setStatsOnVisibleData(b) + + @docstring(UpdateModeWidget) + def showRadioButtons(self, show): + self._options.showRadioButtons(show=show) diff --git a/src/silx/gui/plot/_BaseMaskToolsWidget.py b/src/silx/gui/plot/_BaseMaskToolsWidget.py new file mode 100644 index 0000000..407ab11 --- /dev/null +++ b/src/silx/gui/plot/_BaseMaskToolsWidget.py @@ -0,0 +1,1282 @@ +# coding: utf-8 +# /*########################################################################## +# +# Copyright (c) 2017-2020 European Synchrotron Radiation Facility +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. +# +# ###########################################################################*/ +"""This module is a collection of base classes used in modules +:mod:`.MaskToolsWidget` (images) and :mod:`.ScatterMaskToolsWidget` +""" +from __future__ import division + +__authors__ = ["T. Vincent", "P. Knobel"] +__license__ = "MIT" +__date__ = "08/12/2020" + +import os +import weakref + +import numpy + +from silx.gui import qt, icons +from silx.gui.widgets.FloatEdit import FloatEdit +from silx.gui.colors import Colormap +from silx.gui.colors import rgba +from .actions.mode import PanModeAction + + +class BaseMask(qt.QObject): + """Base class for :class:`ImageMask` and :class:`ScatterMask` + + A mask field with update operations. + + A mask is an array of the same shape as some underlying data. The mask + array stores integer values in the range 0-255, to allow for 254 levels + of mask (value 0 is reserved for unmasked data). + + The mask is updated using spatial selection methods: data located inside + a selected area is masked with a specified mask level. + + """ + + sigChanged = qt.Signal() + """Signal emitted when the mask has changed""" + + sigStateChanged = qt.Signal() + """Signal emitted for each mask commit/undo/redo operation""" + + sigUndoable = qt.Signal(bool) + """Signal emitted when undo becomes possible/impossible""" + + sigRedoable = qt.Signal(bool) + """Signal emitted when redo becomes possible/impossible""" + + def __init__(self, dataItem=None): + self.historyDepth = 10 + """Maximum number of operation stored in history list for undo""" + # Init lists for undo/redo + self._history = [] + self._redo = [] + + # Store the mask + self._mask = numpy.array((), dtype=numpy.uint8) + + # Store the plot item to be masked + self._dataItem = None + if dataItem is not None: + self.setDataItem(dataItem) + self.reset(self.getDataValues().shape) + super(BaseMask, self).__init__() + + def setDataItem(self, item): + """Set a data item + + :param item: A plot item, subclass of :class:`silx.gui.plot.items.Item` + :return: + """ + self._dataItem = item + + def getDataItem(self): + """Returns current plot item the mask is on. + + :rtype: Union[~silx.gui.plot.items.Item,None] + """ + return self._dataItem + + def getDataValues(self): + """Return data values, as a numpy array with the same shape + as the mask. + + This method must be implemented in a subclass, as the way of + accessing data depends on the data item passed to :meth:`setDataItem` + + :return: Data values associated with the data item. + :rtype: numpy.ndarray + """ + raise NotImplementedError("To be implemented in subclass") + + def _notify(self): + """Notify of mask change.""" + self.sigChanged.emit() + + def getMask(self, copy=True): + """Get the current mask as a numpy array. + + :param bool copy: True (default) to get a copy of the mask. + If False, the returned array MUST not be modified. + :return: The array of the mask with dimension of the data to be masked. + :rtype: numpy.ndarray of uint8 + """ + return numpy.array(self._mask, copy=copy) + + def setMask(self, mask, copy=True): + """Set the mask to a new array. + + :param numpy.ndarray mask: The array to use for the mask. + :type mask: numpy.ndarray of uint8, C-contiguous. + Array of other types are converted. + :param bool copy: True (the default) to copy the array, + False to use it as is if possible. + """ + self._mask = numpy.array(mask, copy=copy, order='C', dtype=numpy.uint8) + self._notify() + + # History control + def resetHistory(self): + """Reset history""" + self._history = [numpy.array(self._mask, copy=True)] + self._redo = [] + self.sigUndoable.emit(False) + self.sigRedoable.emit(False) + + def commit(self): + """Append the current mask to history if changed""" + if (not self._history or self._redo or + not numpy.array_equal(self._mask, self._history[-1])): + if self._redo: + self._redo = [] # Reset redo as a new action as been performed + self.sigRedoable[bool].emit(False) + + while len(self._history) >= self.historyDepth: + self._history.pop(0) + self._history.append(numpy.array(self._mask, copy=True)) + + if len(self._history) == 2: + self.sigUndoable.emit(True) + self.sigStateChanged.emit() + + def undo(self): + """Restore previous mask if any""" + if len(self._history) > 1: + self._redo.append(self._history.pop()) + self._mask = numpy.array(self._history[-1], copy=True) + self._notify() # Do not store this change in history + + if len(self._redo) == 1: # First redo + self.sigRedoable.emit(True) + if len(self._history) == 1: # Last value in history + self.sigUndoable.emit(False) + self.sigStateChanged.emit() + + def redo(self): + """Restore previously undone modification if any""" + if self._redo: + self._mask = self._redo.pop() + self._history.append(numpy.array(self._mask, copy=True)) + self._notify() + + if not self._redo: # No more redo + self.sigRedoable.emit(False) + if len(self._history) == 2: # Something to undo + self.sigUndoable.emit(True) + self.sigStateChanged.emit() + + # Whole mask operations + + def clear(self, level): + """Set all values of the given mask level to 0. + + :param int level: Value of the mask to set to 0. + """ + assert 0 < level < 256 + self._mask[self._mask == level] = 0 + self._notify() + + def invert(self, level): + """Invert mask of the given mask level. + + 0 values become level and level values become 0. + + :param int level: The level to invert. + """ + assert 0 < level < 256 + masked = self._mask == level + self._mask[self._mask == 0] = level + self._mask[masked] = 0 + self._notify() + + def reset(self, shape=None): + """Reset the mask to zero and change its shape. + + :param shape: Shape of the new mask with the correct dimensionality + with regards to the data dimensionality, + or None to have an empty mask + :type shape: tuple of int + """ + if shape is None: + # assume dimensionality never changes + shape = (0,) * len(self._mask.shape) # empty array + shapeChanged = (shape != self._mask.shape) + self._mask = numpy.zeros(shape, dtype=numpy.uint8) + if shapeChanged: + self.resetHistory() + + self._notify() + + # To be implemented + def save(self, filename, kind): + """Save current mask in a file + + :param str filename: The file where to save to mask + :param str kind: The kind of file to save (e.g 'npy') + :raise Exception: Raised if the file writing fail + """ + raise NotImplementedError("To be implemented in subclass") + + # update thresholds + def updateStencil(self, level, stencil, mask=True): + """Mask/Unmask points from boolean mask: all elements that are True + in the boolean mask are set to ``level`` (if ``mask=True``) or 0 + (if ``mask=False``) + + :param int level: Mask level to update. + :param stencil: Boolean mask. + :type stencil: numpy.array of same dimension as the mask + :param bool mask: True to mask (default), False to unmask. + """ + if mask: + self._mask[stencil] = level + else: + self._mask[numpy.logical_and(self._mask == level, stencil)] = 0 + self._notify() + + def updateBelowThreshold(self, level, threshold, mask=True): + """Mask/unmask all points whose values are below a threshold. + + :param int level: + :param float threshold: Threshold + :param bool mask: True to mask (default), False to unmask. + """ + self.updateStencil(level, + self.getDataValues() < threshold, + mask) + + def updateBetweenThresholds(self, level, min_, max_, mask=True): + """Mask/unmask all points whose values are in a range. + + :param int level: + :param float min_: Lower threshold + :param float max_: Upper threshold + :param bool mask: True to mask (default), False to unmask. + """ + stencil = numpy.logical_and(min_ <= self.getDataValues(), + self.getDataValues() <= max_) + self.updateStencil(level, stencil, mask) + + def updateAboveThreshold(self, level, threshold, mask=True): + """Mask/unmask all points whose values are above a threshold. + + :param int level: Mask level to update. + :param float threshold: Threshold. + :param bool mask: True to mask (default), False to unmask. + """ + self.updateStencil(level, + self.getDataValues() > threshold, + mask) + + def updateNotFinite(self, level, mask=True): + """Mask/unmask all points whose values are not finite. + + :param int level: Mask level to update. + :param bool mask: True to mask (default), False to unmask. + """ + self.updateStencil(level, + numpy.logical_not(numpy.isfinite(self.getDataValues())), + mask) + + # Drawing operations: + def updateRectangle(self, level, row, col, height, width, mask=True): + """Mask/Unmask data inside a rectangle, with the given mask level. + + :param int level: Mask level to update, in range 1-255. + :param row: Starting row/y of the rectangle + :param col: Starting column/x of the rectangle + :param height: + :param width: + :param bool mask: True to mask (default), False to unmask. + """ + raise NotImplementedError("To be implemented in subclass") + + def updatePolygon(self, level, vertices, mask=True): + """Mask/Unmask data inside a polygon, with the given mask level. + + :param int level: Mask level to update. + :param vertices: Nx2 array of polygon corners as (row, col) / (y, x) + :param bool mask: True to mask (default), False to unmask. + """ + raise NotImplementedError("To be implemented in subclass") + + def updatePoints(self, level, rows, cols, mask=True): + """Mask/Unmask points with given coordinates. + + :param int level: Mask level to update. + :param rows: Rows/ordinates (y) of selected points + :type rows: 1D numpy.ndarray + :param cols: Columns/abscissa (x) of selected points + :type cols: 1D numpy.ndarray + :param bool mask: True to mask (default), False to unmask. + """ + raise NotImplementedError("To be implemented in subclass") + + def updateDisk(self, level, crow, ccol, radius, mask=True): + """Mask/Unmask data located inside a dick of the given mask level. + + :param int level: Mask level to update. + :param crow: Disk center row/ordinate (y). + :param ccol: Disk center column/abscissa. + :param float radius: Radius of the disk in mask array unit + :param bool mask: True to mask (default), False to unmask. + """ + raise NotImplementedError("To be implemented in subclass") + + def updateEllipse(self, level, crow, ccol, radius_r, radius_c, mask=True): + """Mask/Unmask a disk of the given mask level. + + :param int level: Mask level to update. + :param int crow: Row of the center of the ellipse + :param int ccol: Column of the center of the ellipse + :param float radius_r: Radius of the ellipse in the row + :param float radius_c: Radius of the ellipse in the column + :param bool mask: True to mask (default), False to unmask. + """ + raise NotImplementedError("To be implemented in subclass") + + def updateLine(self, level, row0, col0, row1, col1, width, mask=True): + """Mask/Unmask a line of the given mask level. + + :param int level: Mask level to update. + :param row0: Row/y of the starting point. + :param col0: Column/x of the starting point. + :param row1: Row/y of the end point. + :param col1: Column/x of the end point. + :param width: Width of the line in mask array unit. + :param bool mask: True to mask (default), False to unmask. + """ + raise NotImplementedError("To be implemented in subclass") + + +class BaseMaskToolsWidget(qt.QWidget): + """Base class for :class:`MaskToolsWidget` (image mask) and + :class:`scatterMaskToolsWidget`""" + + sigMaskChanged = qt.Signal() + _maxLevelNumber = 255 + + def __init__(self, parent=None, plot=None, mask=None): + """ + + :param parent: Parent QWidget + :param plot: Plot widget on which to operate + :param mask: Instance of subclass of :class:`BaseMask` + (e.g. :class:`ImageMask`) + """ + super(BaseMaskToolsWidget, self).__init__(parent) + # register if the user as force a color for the corresponding mask level + self._defaultColors = numpy.ones((self._maxLevelNumber + 1), dtype=bool) + # overlays colors set by the user + self._overlayColors = numpy.zeros((self._maxLevelNumber + 1, 3), dtype=numpy.float32) + + # as parent have to be the first argument of the widget to fit + # QtDesigner need but here plot can't be None by default. + assert plot is not None + self._plotRef = weakref.ref(plot) + self._maskName = '__MASK_TOOLS_%d' % id(self) # Legend of the mask + + self._colormap = Colormap(normalization='linear', + vmin=0, + vmax=self._maxLevelNumber) + self._defaultOverlayColor = rgba('gray') # Color of the mask + self._setMaskColors(1, 0.5) # Set the colormap LUT + + if not isinstance(mask, BaseMask): + raise TypeError("mask is not an instance of BaseMask") + self._mask = mask + + self._mask.sigChanged.connect(self._updatePlotMask) + self._mask.sigChanged.connect(self._emitSigMaskChanged) + + self._drawingMode = None # Store current drawing mode + self._lastPencilPos = None + self._multipleMasks = 'exclusive' + + self._maskFileDir = qt.QDir.home().absolutePath() + self.plot.sigInteractiveModeChanged.connect( + self._interactiveModeChanged) + + self._initWidgets() + + def _emitSigMaskChanged(self): + """Notify mask changes""" + self.sigMaskChanged.emit() + + def getMaskedItem(self): + """Returns the item that is currently being masked + + :rtype: Union[~silx.gui.plot.items.Item,None] + """ + return self._mask.getDataItem() + + def getSelectionMask(self, copy=True): + """Get the current mask as a numpy array. + + :param bool copy: True (default) to get a copy of the mask. + If False, the returned array MUST not be modified. + :return: The mask (as an array of uint8) with dimension of + the 'active' plot item. + If there is no active image or scatter, it returns None. + :rtype: Union[numpy.ndarray,None] + """ + mask = self._mask.getMask(copy=copy) + return None if mask.size == 0 else mask + + def setSelectionMask(self, mask): + """Set the mask: Must be implemented in subclass""" + raise NotImplementedError() + + def resetSelectionMask(self): + """Reset the mask: Must be implemented in subclass""" + raise NotImplementedError() + + def multipleMasks(self): + """Return the current mode of multiple masks support. + + See :meth:`setMultipleMasks` + """ + return self._multipleMasks + + def setMultipleMasks(self, mode): + """Set the mode of multiple masks support. + + Available modes: + + - 'single': Edit a single level of mask + - 'exclusive': Supports to 256 levels of non overlapping masks + + :param str mode: The mode to use + """ + assert mode in ('exclusive', 'single') + if mode != self._multipleMasks: + self._multipleMasks = mode + self._levelWidget.setVisible(self._multipleMasks != 'single') + self._clearAllBtn.setVisible(self._multipleMasks != 'single') + + def setMaskFileDirectory(self, path): + """Set the default directory to use by load/save GUI tools + + The directory is also updated by the user, if he change the location + of the dialog. + """ + self.maskFileDir = path + + def getMaskFileDirectory(self): + """Get the default directory used by load/save GUI tools""" + return self.maskFileDir + + @property + def maskFileDir(self): + """The directory from which to load/save mask from/to files.""" + if not os.path.isdir(self._maskFileDir): + self._maskFileDir = qt.QDir.home().absolutePath() + return self._maskFileDir + + @maskFileDir.setter + def maskFileDir(self, maskFileDir): + self._maskFileDir = str(maskFileDir) + + @property + def plot(self): + """The :class:`.PlotWindow` this widget is attached to.""" + plot = self._plotRef() + if plot is None: + raise RuntimeError( + 'Mask widget attached to a PlotWidget that no longer exists') + return plot + + def setDirection(self, direction=qt.QBoxLayout.LeftToRight): + """Set the direction of the layout of the widget + + :param direction: QBoxLayout direction + """ + self.layout().setDirection(direction) + + def _initWidgets(self): + """Create widgets""" + layout = qt.QBoxLayout(qt.QBoxLayout.LeftToRight) + layout.addWidget(self._initMaskGroupBox()) + layout.addWidget(self._initDrawGroupBox()) + layout.addWidget(self._initThresholdGroupBox()) + layout.addWidget(self._initOtherToolsGroupBox()) + layout.addStretch(1) + self.setLayout(layout) + + @staticmethod + def _hboxWidget(*widgets, **kwargs): + """Place widgets in widget with horizontal layout + + :param widgets: Widgets to position horizontally + :param bool stretch: True for trailing stretch (default), + False for no trailing stretch + :return: A QWidget with a QHBoxLayout + """ + stretch = kwargs.get('stretch', True) + + layout = qt.QHBoxLayout() + layout.setContentsMargins(0, 0, 0, 0) + for widget in widgets: + layout.addWidget(widget) + if stretch: + layout.addStretch(1) + widget = qt.QWidget() + widget.setLayout(layout) + return widget + + def _initTransparencyWidget(self): + """ Init the mask transparency widget """ + transparencyWidget = qt.QWidget(parent=self) + grid = qt.QGridLayout() + grid.setContentsMargins(0, 0, 0, 0) + self.transparencySlider = qt.QSlider(qt.Qt.Horizontal, parent=transparencyWidget) + self.transparencySlider.setRange(3, 10) + self.transparencySlider.setValue(8) + self.transparencySlider.setToolTip( + 'Set the transparency of the mask display') + self.transparencySlider.valueChanged.connect(self._updateColors) + grid.addWidget(qt.QLabel('Display:', parent=transparencyWidget), 0, 0) + grid.addWidget(self.transparencySlider, 0, 1, 1, 3) + grid.addWidget(qt.QLabel('<small><b>Transparent</b></small>', parent=transparencyWidget), 1, 1) + grid.addWidget(qt.QLabel('<small><b>Opaque</b></small>', parent=transparencyWidget), 1, 3) + transparencyWidget.setLayout(grid) + return transparencyWidget + + def _initMaskGroupBox(self): + """Init general mask operation widgets""" + + # Mask level + self.levelSpinBox = qt.QSpinBox() + self.levelSpinBox.setRange(1, self._maxLevelNumber) + self.levelSpinBox.setToolTip( + 'Choose which mask level is edited.\n' + 'A mask can have up to 255 non-overlapping levels.') + self.levelSpinBox.valueChanged[int].connect(self._updateColors) + self._levelWidget = self._hboxWidget(qt.QLabel('Mask level:'), + self.levelSpinBox) + # Transparency + self._transparencyWidget = self._initTransparencyWidget() + + style = qt.QApplication.style() + + def getIcon(*identifiyers): + for i in identifiyers: + if isinstance(i, str): + if qt.QIcon.hasThemeIcon(i): + return qt.QIcon.fromTheme(i) + elif isinstance(i, qt.QIcon): + return i + else: + return style.standardIcon(i) + return qt.QIcon() + + undoAction = qt.QAction(self) + undoAction.setText('Undo') + icon = getIcon("edit-undo", qt.QStyle.SP_ArrowBack) + undoAction.setIcon(icon) + undoAction.setShortcut(qt.QKeySequence.Undo) + undoAction.setToolTip('Undo last mask change <b>%s</b>' % + undoAction.shortcut().toString()) + self._mask.sigUndoable.connect(undoAction.setEnabled) + undoAction.triggered.connect(self._mask.undo) + + redoAction = qt.QAction(self) + redoAction.setText('Redo') + icon = getIcon("edit-redo", qt.QStyle.SP_ArrowForward) + redoAction.setIcon(icon) + redoAction.setShortcut(qt.QKeySequence.Redo) + redoAction.setToolTip('Redo last undone mask change <b>%s</b>' % + redoAction.shortcut().toString()) + self._mask.sigRedoable.connect(redoAction.setEnabled) + redoAction.triggered.connect(self._mask.redo) + + loadAction = qt.QAction(self) + loadAction.setText('Load...') + icon = icons.getQIcon("document-open") + loadAction.setIcon(icon) + loadAction.setToolTip('Load mask from file') + loadAction.triggered.connect(self._loadMask) + + saveAction = qt.QAction(self) + saveAction.setText('Save...') + icon = icons.getQIcon("document-save") + saveAction.setIcon(icon) + saveAction.setToolTip('Save mask to file') + saveAction.triggered.connect(self._saveMask) + + invertAction = qt.QAction(self) + invertAction.setText('Invert') + icon = icons.getQIcon("mask-invert") + invertAction.setIcon(icon) + invertAction.setShortcut(qt.Qt.CTRL + qt.Qt.Key_I) + invertAction.setToolTip('Invert current mask <b>%s</b>' % + invertAction.shortcut().toString()) + invertAction.triggered.connect(self._handleInvertMask) + + clearAction = qt.QAction(self) + clearAction.setText('Clear') + icon = icons.getQIcon("mask-clear") + clearAction.setIcon(icon) + clearAction.setShortcut(qt.QKeySequence.Delete) + clearAction.setToolTip('Clear current mask level <b>%s</b>' % + clearAction.shortcut().toString()) + clearAction.triggered.connect(self._handleClearMask) + + clearAllAction = qt.QAction(self) + clearAllAction.setText('Clear all') + icon = icons.getQIcon("mask-clear-all") + clearAllAction.setIcon(icon) + clearAllAction.setToolTip('Clear all mask levels') + clearAllAction.triggered.connect(self.resetSelectionMask) + + # Buttons group + margin1 = qt.QWidget(self) + margin1.setMinimumWidth(6) + margin2 = qt.QWidget(self) + margin2.setMinimumWidth(6) + + actions = (loadAction, saveAction, margin1, + undoAction, redoAction, margin2, + invertAction, clearAction, clearAllAction) + widgets = [] + for action in actions: + if isinstance(action, qt.QWidget): + widgets.append(action) + continue + btn = qt.QToolButton() + btn.setDefaultAction(action) + widgets.append(btn) + if action is clearAllAction: + self._clearAllBtn = btn + container = self._hboxWidget(*widgets) + container.layout().setSpacing(1) + + layout = qt.QVBoxLayout() + layout.addWidget(container) + layout.addWidget(self._levelWidget) + layout.addWidget(self._transparencyWidget) + layout.addStretch(1) + + maskGroup = qt.QGroupBox('Mask') + maskGroup.setLayout(layout) + return maskGroup + + def isMaskInteractionActivated(self): + """Returns true if any mask interaction is activated""" + return self.drawActionGroup.checkedAction() is not None + + def _initDrawGroupBox(self): + """Init drawing tools widgets""" + layout = qt.QVBoxLayout() + + self.browseAction = PanModeAction(self.plot, self.plot) + self.addAction(self.browseAction) + + # Draw tools + self.rectAction = qt.QAction(icons.getQIcon('shape-rectangle'), + 'Rectangle selection', + self) + self.rectAction.setToolTip( + 'Rectangle selection tool: (Un)Mask a rectangular region <b>R</b>') + self.rectAction.setShortcut(qt.QKeySequence(qt.Qt.Key_R)) + self.rectAction.setCheckable(True) + self.rectAction.triggered.connect(self._activeRectMode) + self.addAction(self.rectAction) + + self.ellipseAction = qt.QAction(icons.getQIcon('shape-ellipse'), + 'Circle selection', + self) + self.ellipseAction.setToolTip( + 'Rectangle selection tool: (Un)Mask a circle region <b>R</b>') + self.ellipseAction.setShortcut(qt.QKeySequence(qt.Qt.Key_R)) + self.ellipseAction.setCheckable(True) + self.ellipseAction.triggered.connect(self._activeEllipseMode) + self.addAction(self.ellipseAction) + + self.polygonAction = qt.QAction(icons.getQIcon('shape-polygon'), + 'Polygon selection', + self) + self.polygonAction.setShortcut(qt.QKeySequence(qt.Qt.Key_S)) + self.polygonAction.setToolTip( + 'Polygon selection tool: (Un)Mask a polygonal region <b>S</b><br>' + 'Left-click to place new polygon corners<br>' + 'Left-click on first corner to close the polygon') + self.polygonAction.setCheckable(True) + self.polygonAction.triggered.connect(self._activePolygonMode) + self.addAction(self.polygonAction) + + self.pencilAction = qt.QAction(icons.getQIcon('draw-pencil'), + 'Pencil tool', + self) + self.pencilAction.setShortcut(qt.QKeySequence(qt.Qt.Key_P)) + self.pencilAction.setToolTip( + 'Pencil tool: (Un)Mask using a pencil <b>P</b>') + self.pencilAction.setCheckable(True) + self.pencilAction.triggered.connect(self._activePencilMode) + self.addAction(self.pencilAction) + + self.drawActionGroup = qt.QActionGroup(self) + self.drawActionGroup.setExclusive(True) + self.drawActionGroup.addAction(self.rectAction) + self.drawActionGroup.addAction(self.ellipseAction) + self.drawActionGroup.addAction(self.polygonAction) + self.drawActionGroup.addAction(self.pencilAction) + + actions = (self.browseAction, self.rectAction, self.ellipseAction, + self.polygonAction, self.pencilAction) + drawButtons = [] + for action in actions: + btn = qt.QToolButton() + btn.setDefaultAction(action) + drawButtons.append(btn) + container = self._hboxWidget(*drawButtons) + layout.addWidget(container) + + # Mask/Unmask radio buttons + maskRadioBtn = qt.QRadioButton('Mask') + maskRadioBtn.setToolTip( + 'Drawing masks with current level. Press <b>Ctrl</b> to unmask') + maskRadioBtn.setChecked(True) + + unmaskRadioBtn = qt.QRadioButton('Unmask') + unmaskRadioBtn.setToolTip( + 'Drawing unmasks with current level. Press <b>Ctrl</b> to mask') + + self.maskStateGroup = qt.QButtonGroup() + self.maskStateGroup.addButton(maskRadioBtn, 1) + self.maskStateGroup.addButton(unmaskRadioBtn, 0) + + self.maskStateWidget = self._hboxWidget(maskRadioBtn, unmaskRadioBtn) + layout.addWidget(self.maskStateWidget) + + self.maskStateWidget.setHidden(True) + + # Pencil settings + self.pencilSetting = self._createPencilSettings(None) + self.pencilSetting.setVisible(False) + layout.addWidget(self.pencilSetting) + + layout.addStretch(1) + + drawGroup = qt.QGroupBox('Draw tools') + drawGroup.setLayout(layout) + return drawGroup + + def _createPencilSettings(self, parent=None): + pencilSetting = qt.QWidget(parent) + + self.pencilSpinBox = qt.QSpinBox(parent=pencilSetting) + self.pencilSpinBox.setRange(1, 1024) + pencilToolTip = """Set pencil drawing tool size in pixels of the image + on which to make the mask.""" + self.pencilSpinBox.setToolTip(pencilToolTip) + + self.pencilSlider = qt.QSlider(qt.Qt.Horizontal, parent=pencilSetting) + self.pencilSlider.setRange(1, 50) + self.pencilSlider.setToolTip(pencilToolTip) + + pencilLabel = qt.QLabel('Pencil size:', parent=pencilSetting) + + layout = qt.QGridLayout() + layout.addWidget(pencilLabel, 0, 0) + layout.addWidget(self.pencilSpinBox, 0, 1) + layout.addWidget(self.pencilSlider, 1, 1) + pencilSetting.setLayout(layout) + + self.pencilSpinBox.valueChanged.connect(self._pencilWidthChanged) + self.pencilSlider.valueChanged.connect(self._pencilWidthChanged) + + return pencilSetting + + def _initThresholdGroupBox(self): + """Init thresholding widgets""" + + self.belowThresholdAction = qt.QAction(icons.getQIcon('plot-roi-below'), + 'Mask below threshold', + self) + self.belowThresholdAction.setToolTip( + 'Mask image where values are below given threshold') + self.belowThresholdAction.setCheckable(True) + self.belowThresholdAction.setChecked(True) + + self.betweenThresholdAction = qt.QAction(icons.getQIcon('plot-roi-between'), + 'Mask within range', + self) + self.betweenThresholdAction.setToolTip( + 'Mask image where values are within given range') + self.betweenThresholdAction.setCheckable(True) + + self.aboveThresholdAction = qt.QAction(icons.getQIcon('plot-roi-above'), + 'Mask above threshold', + self) + self.aboveThresholdAction.setToolTip( + 'Mask image where values are above given threshold') + self.aboveThresholdAction.setCheckable(True) + + self.thresholdActionGroup = qt.QActionGroup(self) + self.thresholdActionGroup.setExclusive(True) + self.thresholdActionGroup.addAction(self.belowThresholdAction) + self.thresholdActionGroup.addAction(self.betweenThresholdAction) + self.thresholdActionGroup.addAction(self.aboveThresholdAction) + self.thresholdActionGroup.triggered.connect( + self._thresholdActionGroupTriggered) + + self.loadColormapRangeAction = qt.QAction(icons.getQIcon('view-refresh'), + 'Set min-max from colormap', + self) + self.loadColormapRangeAction.setToolTip( + 'Set min and max values from current colormap range') + self.loadColormapRangeAction.setCheckable(False) + self.loadColormapRangeAction.triggered.connect( + self._loadRangeFromColormapTriggered) + + widgets = [] + for action in self.thresholdActionGroup.actions(): + btn = qt.QToolButton() + btn.setDefaultAction(action) + widgets.append(btn) + + spacer = qt.QWidget(parent=self) + spacer.setSizePolicy(qt.QSizePolicy.Expanding, + qt.QSizePolicy.Preferred) + widgets.append(spacer) + + loadColormapRangeBtn = qt.QToolButton() + loadColormapRangeBtn.setDefaultAction(self.loadColormapRangeAction) + widgets.append(loadColormapRangeBtn) + + toolBar = self._hboxWidget(*widgets, stretch=False) + + config = qt.QGridLayout() + config.setContentsMargins(0, 0, 0, 0) + + self.minLineLabel = qt.QLabel("Min:", self) + self.minLineEdit = FloatEdit(self, value=0) + config.addWidget(self.minLineLabel, 0, 0) + config.addWidget(self.minLineEdit, 0, 1) + + self.maxLineLabel = qt.QLabel("Max:", self) + self.maxLineEdit = FloatEdit(self, value=0) + config.addWidget(self.maxLineLabel, 1, 0) + config.addWidget(self.maxLineEdit, 1, 1) + + self.applyMaskBtn = qt.QPushButton('Apply mask') + self.applyMaskBtn.clicked.connect(self._maskBtnClicked) + + layout = qt.QVBoxLayout() + layout.addWidget(toolBar) + layout.addLayout(config) + layout.addWidget(self.applyMaskBtn) + layout.addStretch(1) + + self.thresholdGroup = qt.QGroupBox('Threshold') + self.thresholdGroup.setLayout(layout) + + # Init widget state + self._thresholdActionGroupTriggered(self.belowThresholdAction) + return self.thresholdGroup + + # track widget visibility and plot active image changes + + def _initOtherToolsGroupBox(self): + layout = qt.QVBoxLayout() + + self.maskNanBtn = qt.QPushButton('Mask not finite values') + self.maskNanBtn.setToolTip('Mask Not a Number and infinite values') + self.maskNanBtn.clicked.connect(self._maskNotFiniteBtnClicked) + layout.addWidget(self.maskNanBtn) + layout.addStretch(1) + + self.otherToolGroup = qt.QGroupBox('Other tools') + self.otherToolGroup.setLayout(layout) + return self.otherToolGroup + + def changeEvent(self, event): + """Reset drawing action when disabling widget""" + if (event.type() == qt.QEvent.EnabledChange and + not self.isEnabled() and + self.drawActionGroup.checkedAction()): + # Disable drawing tool by setting interaction to zoom + self.browseAction.trigger() + + def save(self, filename, kind): + """Save current mask in a file + + :param str filename: The file where to save to mask + :param str kind: The kind of file to save in 'edf', 'tif', 'npy' + :raise Exception: Raised if the process fails + """ + self._mask.save(filename, kind) + + def getCurrentMaskColor(self): + """Returns the color of the current selected level. + + :rtype: A tuple or a python array + """ + currentLevel = self.levelSpinBox.value() + if self._defaultColors[currentLevel]: + return self._defaultOverlayColor + else: + return self._overlayColors[currentLevel].tolist() + + def _setMaskColors(self, level, alpha): + """Set-up the mask colormap to highlight current mask level. + + :param int level: The mask level to highlight + :param float alpha: Alpha level of mask in [0., 1.] + """ + assert 0 < level <= self._maxLevelNumber + + colors = numpy.empty((self._maxLevelNumber + 1, 4), dtype=numpy.float32) + + # Set color + colors[:,:3] = self._defaultOverlayColor[:3] + + # check if some colors has been directly set by the user + mask = numpy.equal(self._defaultColors, False) + colors[mask,:3] = self._overlayColors[mask,:3] + + # Set alpha + colors[:, -1] = alpha / 2. + + # Set highlighted level color + colors[level, 3] = alpha + + # Set no mask level + colors[0] = (0., 0., 0., 0.) + + self._colormap.setColormapLUT(colors) + + def resetMaskColors(self, level=None): + """Reset the mask color at the given level to be defaultColors + + :param level: + The index of the mask for which we want to reset the color. + If none we will reset color for all masks. + """ + if level is None: + self._defaultColors[level] = True + else: + self._defaultColors[:] = True + + self._updateColors() + + def setMaskColors(self, rgb, level=None): + """Set the masks color + + :param rgb: The rgb color + :param level: + The index of the mask for which we want to change the color. + If none set this color for all the masks + """ + rgb = rgba(rgb)[0:3] + if level is None: + self._overlayColors[:] = rgb + self._defaultColors[:] = False + else: + self._overlayColors[level] = rgb + self._defaultColors[level] = False + + self._updateColors() + + def getMaskColors(self): + """masks colors getter""" + return self._overlayColors + + def _updateColors(self, *args): + """Rebuild mask colormap when selected level or transparency change""" + self._setMaskColors(self.levelSpinBox.value(), + self.transparencySlider.value() / + self.transparencySlider.maximum()) + self._updatePlotMask() + self._updateInteractiveMode() + + def _pencilWidthChanged(self, width): + + old = self.pencilSpinBox.blockSignals(True) + try: + self.pencilSpinBox.setValue(width) + finally: + self.pencilSpinBox.blockSignals(old) + + old = self.pencilSlider.blockSignals(True) + try: + self.pencilSlider.setValue(width) + finally: + self.pencilSlider.blockSignals(old) + self._updateInteractiveMode() + + def _updateInteractiveMode(self): + """Update the current mode to the same if some cached data have to be + updated. It is the case for the color for example. + """ + if self._drawingMode == 'rectangle': + self._activeRectMode() + elif self._drawingMode == 'ellipse': + self._activeEllipseMode() + elif self._drawingMode == 'polygon': + self._activePolygonMode() + elif self._drawingMode == 'pencil': + self._activePencilMode() + + def _handleClearMask(self): + """Handle clear button clicked: reset current level mask""" + self._mask.clear(self.levelSpinBox.value()) + self._mask.commit() + + def _handleInvertMask(self): + """Invert the current mask level selection.""" + self._mask.invert(self.levelSpinBox.value()) + self._mask.commit() + + # Handle drawing tools UI events + + def _interactiveModeChanged(self, source): + """Handle plot interactive mode changed: + + If changed from elsewhere, disable drawing tool + """ + if source is not self: + self.pencilAction.setChecked(False) + self.rectAction.setChecked(False) + self.polygonAction.setChecked(False) + self._releaseDrawingMode() + self._updateDrawingModeWidgets() + + def _releaseDrawingMode(self): + """Release the drawing mode if is was used""" + if self._drawingMode is None: + return + self.plot.sigPlotSignal.disconnect(self._plotDrawEvent) + self._drawingMode = None + + def _activeRectMode(self): + """Handle rect action mode triggering""" + self._releaseDrawingMode() + self._drawingMode = 'rectangle' + self.plot.sigPlotSignal.connect(self._plotDrawEvent) + color = self.getCurrentMaskColor() + self.plot.setInteractiveMode( + 'draw', shape='rectangle', source=self, color=color) + self._updateDrawingModeWidgets() + + def _activeEllipseMode(self): + """Handle circle action mode triggering""" + self._releaseDrawingMode() + self._drawingMode = 'ellipse' + self.plot.sigPlotSignal.connect(self._plotDrawEvent) + color = self.getCurrentMaskColor() + self.plot.setInteractiveMode( + 'draw', shape='ellipse', source=self, color=color) + self._updateDrawingModeWidgets() + + def _activePolygonMode(self): + """Handle polygon action mode triggering""" + self._releaseDrawingMode() + self._drawingMode = 'polygon' + self.plot.sigPlotSignal.connect(self._plotDrawEvent) + color = self.getCurrentMaskColor() + self.plot.setInteractiveMode('draw', shape='polygon', source=self, color=color) + self._updateDrawingModeWidgets() + + def _getPencilWidth(self): + """Returns the width of the pencil to use in data coordinates` + + :rtype: float + """ + return self.pencilSpinBox.value() + + def _activePencilMode(self): + """Handle pencil action mode triggering""" + self._releaseDrawingMode() + self._drawingMode = 'pencil' + self.plot.sigPlotSignal.connect(self._plotDrawEvent) + color = self.getCurrentMaskColor() + width = self._getPencilWidth() + self.plot.setInteractiveMode( + 'draw', shape='pencil', source=self, color=color, width=width) + self._updateDrawingModeWidgets() + + def _updateDrawingModeWidgets(self): + self.maskStateWidget.setVisible(self._drawingMode is not None) + self.pencilSetting.setVisible(self._drawingMode == 'pencil') + + # Handle plot drawing events + + def _isMasking(self): + """Returns true if the tool is used for masking, else it is used for + unmasking. + + :rtype: bool""" + # First draw event, use current modifiers for all draw sequence + doMask = (self.maskStateGroup.checkedId() == 1) + if qt.QApplication.keyboardModifiers() & qt.Qt.ControlModifier: + doMask = not doMask + return doMask + + # Handle threshold UI events + + def _thresholdActionGroupTriggered(self, triggeredAction): + """Threshold action group listener.""" + if triggeredAction is self.belowThresholdAction: + self.minLineLabel.setVisible(True) + self.maxLineLabel.setVisible(False) + self.minLineEdit.setVisible(True) + self.maxLineEdit.setVisible(False) + self.applyMaskBtn.setText("Mask below") + elif triggeredAction is self.betweenThresholdAction: + self.minLineLabel.setVisible(True) + self.maxLineLabel.setVisible(True) + self.minLineEdit.setVisible(True) + self.maxLineEdit.setVisible(True) + self.applyMaskBtn.setText("Mask between") + elif triggeredAction is self.aboveThresholdAction: + self.minLineLabel.setVisible(False) + self.maxLineLabel.setVisible(True) + self.minLineEdit.setVisible(False) + self.maxLineEdit.setVisible(True) + self.applyMaskBtn.setText("Mask above") + self.applyMaskBtn.setToolTip(triggeredAction.toolTip()) + + def _maskBtnClicked(self): + if self.belowThresholdAction.isChecked(): + if self.minLineEdit.text(): + self._mask.updateBelowThreshold(self.levelSpinBox.value(), + self.minLineEdit.value()) + self._mask.commit() + + elif self.betweenThresholdAction.isChecked(): + if self.minLineEdit.text() and self.maxLineEdit.text(): + min_ = self.minLineEdit.value() + max_ = self.maxLineEdit.value() + self._mask.updateBetweenThresholds(self.levelSpinBox.value(), + min_, max_) + self._mask.commit() + + elif self.aboveThresholdAction.isChecked(): + if self.maxLineEdit.text(): + max_ = float(self.maxLineEdit.value()) + self._mask.updateAboveThreshold(self.levelSpinBox.value(), + max_) + self._mask.commit() + + def _maskNotFiniteBtnClicked(self): + """Handle not finite mask button clicked: mask NaNs and inf""" + self._mask.updateNotFinite( + self.levelSpinBox.value()) + self._mask.commit() + + +class BaseMaskToolsDockWidget(qt.QDockWidget): + """Base class for :class:`MaskToolsWidget` and + :class:`ScatterMaskToolsWidget`. + + For integration in a :class:`PlotWindow`. + + :param parent: See :class:`QDockWidget` + :paran str name: The title of this widget + """ + + sigMaskChanged = qt.Signal() + + def __init__(self, parent=None, name='Mask', widget=None): + super(BaseMaskToolsDockWidget, self).__init__(parent) + self.setWindowTitle(name) + + if not isinstance(widget, BaseMaskToolsWidget): + raise TypeError("BaseMaskToolsDockWidget requires a MaskToolsWidget") + self.setWidget(widget) + self.widget().sigMaskChanged.connect(self._emitSigMaskChanged) + + self.layout().setContentsMargins(0, 0, 0, 0) + self.dockLocationChanged.connect(self._dockLocationChanged) + self.topLevelChanged.connect(self._topLevelChanged) + + def _emitSigMaskChanged(self): + """Notify mask changes""" + # must be connected to self.widget().sigMaskChanged in child class + self.sigMaskChanged.emit() + + def getSelectionMask(self, copy=True): + """Get the current mask as a 2D array. + + :param bool copy: True (default) to get a copy of the mask. + If False, the returned array MUST not be modified. + :return: The array of the mask with dimension of the 'active' image. + If there is no active image, an empty array is returned. + :rtype: 2D numpy.ndarray of uint8 + """ + return self.widget().getSelectionMask(copy=copy) + + def setSelectionMask(self, mask, copy=True): + """Set the mask to a new array. + + :param numpy.ndarray mask: The array to use for the mask. + :type mask: numpy.ndarray of uint8 of dimension 2, C-contiguous. + Array of other types are converted. + :param bool copy: True (the default) to copy the array, + False to use it as is if possible. + :return: None if failed, shape of mask as 2-tuple if successful. + The mask can be cropped or padded to fit active image, + the returned shape is that of the active image. + """ + return self.widget().setSelectionMask(mask, copy=copy) + + def resetSelectionMask(self): + """Reset the mask to an array of zeros with the shape of the + current data.""" + self.widget().resetSelectionMask() + + def toggleViewAction(self): + """Returns a checkable action that shows or closes this widget. + + See :class:`QMainWindow`. + """ + action = super(BaseMaskToolsDockWidget, self).toggleViewAction() + action.setIcon(icons.getQIcon('image-mask')) + action.setToolTip("Display/hide mask tools") + return action + + def _dockLocationChanged(self, area): + if area in (qt.Qt.LeftDockWidgetArea, qt.Qt.RightDockWidgetArea): + direction = qt.QBoxLayout.TopToBottom + else: + direction = qt.QBoxLayout.LeftToRight + self.widget().setDirection(direction) + + def _topLevelChanged(self, topLevel): + if topLevel: + self.widget().setDirection(qt.QBoxLayout.LeftToRight) + self.resize(self.widget().minimumSize()) + self.adjustSize() + + def showEvent(self, event): + """Make sure this widget is raised when it is shown + (when it is first created as a tab in PlotWindow or when it is shown + again after hiding). + """ + self.raise_() diff --git a/src/silx/gui/plot/__init__.py b/src/silx/gui/plot/__init__.py new file mode 100644 index 0000000..3a141b3 --- /dev/null +++ b/src/silx/gui/plot/__init__.py @@ -0,0 +1,71 @@ +# coding: utf-8 +# /*########################################################################## +# +# Copyright (c) 2016-2018 European Synchrotron Radiation Facility +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. +# +# ###########################################################################*/ +"""This package provides a set of Qt widgets for plotting curves and images. + +The plotting API is inherited from the `PyMca <http://pymca.sourceforge.net/>`_ +plot API and is mostly compatible with it. + +Those widgets supports interaction (e.g., zoom, pan, selections). + +List of Qt widgets: + +.. currentmodule:: silx.gui.plot + +- :mod:`.PlotWidget`: A widget displaying a single plot. +- :mod:`.PlotWindow`: A :mod:`.PlotWidget` with a configurable set of tools. +- :class:`.Plot1D`: A widget with tools for curves. +- :class:`.Plot2D`: A widget with tools for images. +- :class:`.ScatterView`: A widget with tools for scatter plot. +- :class:`.ImageView`: A widget with tools for images and a side histogram. +- :class:`.StackView`: A widget with tools for a stack of images. + +By default, those widget are using matplotlib_. +They can optionally use a faster OpenGL-based rendering (beta feature), +which is enabled by setting the ``backend`` argument to ``'gl'`` +when creating the widgets (See :class:`.PlotWidget`). + +.. note:: + + This package depends on matplotlib_. + The OpenGL backend further depends on + `PyOpenGL <http://pyopengl.sourceforge.net/>`_ and OpenGL >= 2.1. + +.. _matplotlib: http://matplotlib.org/ +""" + +__authors__ = ["T. Vincent"] +__license__ = "MIT" +__date__ = "03/05/2017" + + +from .PlotWidget import PlotWidget # noqa +from .PlotWindow import PlotWindow, Plot1D, Plot2D # noqa +from .items.axis import TickMode +from .ImageView import ImageView # noqa +from .StackView import StackView # noqa +from .ScatterView import ScatterView # noqa + +__all__ = ['ImageView', 'PlotWidget', 'PlotWindow', 'Plot1D', 'Plot2D', + 'StackView', 'ScatterView', 'TickMode'] diff --git a/src/silx/gui/plot/_utils/__init__.py b/src/silx/gui/plot/_utils/__init__.py new file mode 100644 index 0000000..ed87b18 --- /dev/null +++ b/src/silx/gui/plot/_utils/__init__.py @@ -0,0 +1,92 @@ +# coding: utf-8 +# /*########################################################################## +# +# Copyright (c) 2004-2021 European Synchrotron Radiation Facility +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. +# +# ###########################################################################*/ +"""Miscellaneous utility functions for the Plot""" + +__authors__ = ["T. Vincent"] +__license__ = "MIT" +__date__ = "21/03/2017" + + +import numpy + +from .panzoom import FLOAT32_SAFE_MIN, FLOAT32_MINPOS, FLOAT32_SAFE_MAX +from .panzoom import applyZoomToPlot, applyPan, checkAxisLimits + + +def addMarginsToLimits(margins, isXLog, isYLog, + xMin, xMax, yMin, yMax, y2Min=None, y2Max=None): + """Returns updated limits by extending them with margins. + + :param margins: The ratio of the margins to add or None for no margins. + :type margins: A 4-tuple of floats as + (xMinMargin, xMaxMargin, yMinMargin, yMaxMargin) + + :return: The updated limits + :rtype: tuple of 4 or 6 floats: Either (xMin, xMax, yMin, yMax) or + (xMin, xMax, yMin, yMax, y2Min, y2Max) if y2Min and y2Max + are provided. + """ + if margins is not None: + xMinMargin, xMaxMargin, yMinMargin, yMaxMargin = margins + + if not isXLog: + xRange = xMax - xMin + xMin -= xMinMargin * xRange + xMax += xMaxMargin * xRange + + elif xMin > 0. and xMax > 0.: # Log scale + # Do not apply margins if limits < 0 + xMinLog, xMaxLog = numpy.log10(xMin), numpy.log10(xMax) + xRangeLog = xMaxLog - xMinLog + xMin = pow(10., xMinLog - xMinMargin * xRangeLog) + xMax = pow(10., xMaxLog + xMaxMargin * xRangeLog) + + if not isYLog: + yRange = yMax - yMin + yMin -= yMinMargin * yRange + yMax += yMaxMargin * yRange + elif yMin > 0. and yMax > 0.: # Log scale + # Do not apply margins if limits < 0 + yMinLog, yMaxLog = numpy.log10(yMin), numpy.log10(yMax) + yRangeLog = yMaxLog - yMinLog + yMin = pow(10., yMinLog - yMinMargin * yRangeLog) + yMax = pow(10., yMaxLog + yMaxMargin * yRangeLog) + + if y2Min is not None and y2Max is not None: + if not isYLog: + yRange = y2Max - y2Min + y2Min -= yMinMargin * yRange + y2Max += yMaxMargin * yRange + elif y2Min > 0. and y2Max > 0.: # Log scale + # Do not apply margins if limits < 0 + yMinLog, yMaxLog = numpy.log10(y2Min), numpy.log10(y2Max) + yRangeLog = yMaxLog - yMinLog + y2Min = pow(10., yMinLog - yMinMargin * yRangeLog) + y2Max = pow(10., yMaxLog + yMaxMargin * yRangeLog) + + if y2Min is None or y2Max is None: + return xMin, xMax, yMin, yMax + else: + return xMin, xMax, yMin, yMax, y2Min, y2Max diff --git a/src/silx/gui/plot/_utils/delaunay.py b/src/silx/gui/plot/_utils/delaunay.py new file mode 100644 index 0000000..49ad05f --- /dev/null +++ b/src/silx/gui/plot/_utils/delaunay.py @@ -0,0 +1,62 @@ +# coding: utf-8 +# /*########################################################################## +# +# Copyright (c) 2019 European Synchrotron Radiation Facility +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. +# +# ###########################################################################*/ +"""Wrapper over Delaunay implementation""" + +__authors__ = ["T. Vincent"] +__license__ = "MIT" +__date__ = "02/05/2019" + + +import logging +import sys + +import numpy + + +_logger = logging.getLogger(__name__) + + +def delaunay(x, y): + """Returns Delaunay instance for x, y points + + :param numpy.ndarray x: + :param numpy.ndarray y: + :rtype: Union[None,scipy.spatial.Delaunay] + """ + # Lazy-loading of Delaunay + try: + from scipy.spatial import Delaunay as _Delaunay + except ImportError: # Fallback using local Delaunay + from silx.third_party.scipy_spatial import Delaunay as _Delaunay + + points = numpy.array((x, y)).T + try: + delaunay = _Delaunay(points) + except (RuntimeError, ValueError): + _logger.error("Delaunay tesselation failed: %s", + sys.exc_info()[1]) + delaunay = None + + return delaunay diff --git a/src/silx/gui/plot/_utils/dtime_ticklayout.py b/src/silx/gui/plot/_utils/dtime_ticklayout.py new file mode 100644 index 0000000..ebf775b --- /dev/null +++ b/src/silx/gui/plot/_utils/dtime_ticklayout.py @@ -0,0 +1,442 @@ +# coding: utf-8 +# /*########################################################################## +# +# Copyright (c) 2014-2018 European Synchrotron Radiation Facility +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. +# +# ###########################################################################*/ +"""This module implements date-time labels layout on graph axes.""" + +from __future__ import absolute_import, division, unicode_literals + +__authors__ = ["P. Kenter"] +__license__ = "MIT" +__date__ = "04/04/2018" + + +import datetime as dt +import enum +import logging +import math +import time + +import dateutil.tz + +from dateutil.relativedelta import relativedelta + +from .ticklayout import niceNumGeneric + +_logger = logging.getLogger(__name__) + + +MICROSECONDS_PER_SECOND = 1000000 +SECONDS_PER_MINUTE = 60 +SECONDS_PER_HOUR = 60 * SECONDS_PER_MINUTE +SECONDS_PER_DAY = 24 * SECONDS_PER_HOUR +SECONDS_PER_YEAR = 365.25 * SECONDS_PER_DAY +SECONDS_PER_MONTH_AVERAGE = SECONDS_PER_YEAR / 12 # Seconds per average month + + +# No dt.timezone in Python 2.7 so we use dateutil.tz.tzutc +_EPOCH = dt.datetime(1970, 1, 1, tzinfo=dateutil.tz.tzutc()) + +def timestamp(dtObj): + """ Returns POSIX timestamp of a datetime objects. + + If the dtObj object has a timestamp() method (python 3.3), this is + used. Otherwise (e.g. python 2.7) it is calculated here. + + The POSIX timestamp is a floating point value of the number of seconds + since the start of an epoch (typically 1970-01-01). For details see: + https://docs.python.org/3/library/datetime.html#datetime.datetime.timestamp + + :param datetime.datetime dtObj: date-time representation. + :return: POSIX timestamp + :rtype: float + """ + if hasattr(dtObj, "timestamp"): + return dtObj.timestamp() + else: + # Back ported from Python 3.5 + if dtObj.tzinfo is None: + return time.mktime((dtObj.year, dtObj.month, dtObj.day, + dtObj.hour, dtObj.minute, dtObj.second, + -1, -1, -1)) + dtObj.microsecond / 1e6 + else: + return (dtObj - _EPOCH).total_seconds() + + +@enum.unique +class DtUnit(enum.Enum): + YEARS = 0 + MONTHS = 1 + DAYS = 2 + HOURS = 3 + MINUTES = 4 + SECONDS = 5 + MICRO_SECONDS = 6 # a fraction of a second + + +def getDateElement(dateTime, unit): + """ Picks the date element with the unit from the dateTime + + E.g. getDateElement(datetime(1970, 5, 6), DtUnit.Day) will return 6 + + :param datetime dateTime: date/time to pick from + :param DtUnit unit: The unit describing the date element. + """ + if unit == DtUnit.YEARS: + return dateTime.year + elif unit == DtUnit.MONTHS: + return dateTime.month + elif unit == DtUnit.DAYS: + return dateTime.day + elif unit == DtUnit.HOURS: + return dateTime.hour + elif unit == DtUnit.MINUTES: + return dateTime.minute + elif unit == DtUnit.SECONDS: + return dateTime.second + elif unit == DtUnit.MICRO_SECONDS: + return dateTime.microsecond + else: + raise ValueError("Unexpected DtUnit: {}".format(unit)) + + +def setDateElement(dateTime, value, unit): + """ Returns a copy of dateTime with the tickStep unit set to value + + :param datetime.datetime: date time object + :param int value: value to set + :param DtUnit unit: unit + :return: datetime.datetime + """ + intValue = int(value) + _logger.debug("setDateElement({}, {} (int={}), {})" + .format(dateTime, value, intValue, unit)) + + year = dateTime.year + month = dateTime.month + day = dateTime.day + hour = dateTime.hour + minute = dateTime.minute + second = dateTime.second + microsecond = dateTime.microsecond + + if unit == DtUnit.YEARS: + year = intValue + elif unit == DtUnit.MONTHS: + month = intValue + elif unit == DtUnit.DAYS: + day = intValue + elif unit == DtUnit.HOURS: + hour = intValue + elif unit == DtUnit.MINUTES: + minute = intValue + elif unit == DtUnit.SECONDS: + second = intValue + elif unit == DtUnit.MICRO_SECONDS: + microsecond = intValue + else: + raise ValueError("Unexpected DtUnit: {}".format(unit)) + + _logger.debug("creating date time {}" + .format((year, month, day, hour, minute, second, microsecond))) + + return dt.datetime(year, month, day, hour, minute, second, microsecond, + tzinfo=dateTime.tzinfo) + + + +def roundToElement(dateTime, unit): + """ Returns a copy of dateTime rounded to given unit + + :param datetime.datetime: date time object + :param DtUnit unit: unit + :return: datetime.datetime + """ + year = dateTime.year + month = dateTime.month + day = dateTime.day + hour = dateTime.hour + minute = dateTime.minute + second = dateTime.second + microsecond = dateTime.microsecond + + if unit.value < DtUnit.YEARS.value: + pass # Never round years + if unit.value < DtUnit.MONTHS.value: + month = 1 + if unit.value < DtUnit.DAYS.value: + day = 1 + if unit.value < DtUnit.HOURS.value: + hour = 0 + if unit.value < DtUnit.MINUTES.value: + minute = 0 + if unit.value < DtUnit.SECONDS.value: + second = 0 + if unit.value < DtUnit.MICRO_SECONDS.value: + microsecond = 0 + + result = dt.datetime(year, month, day, hour, minute, second, microsecond, + tzinfo=dateTime.tzinfo) + + return result + + +def addValueToDate(dateTime, value, unit): + """ Adds a value with unit to a dateTime. + + Uses dateutil.relativedelta.relativedelta from the standard library to do + the actual math. This function doesn't allow for fractional month or years, + so month and year are truncated to integers before adding. + + :param datetime dateTime: date time + :param float value: value to be added + :param DtUnit unit: of the value + :return: + """ + #logger.debug("addValueToDate({}, {}, {})".format(dateTime, value, unit)) + + if unit == DtUnit.YEARS: + intValue = int(value) # floats not implemented in relativeDelta(years) + return dateTime + relativedelta(years=intValue) + elif unit == DtUnit.MONTHS: + intValue = int(value) # floats not implemented in relativeDelta(mohths) + return dateTime + relativedelta(months=intValue) + elif unit == DtUnit.DAYS: + return dateTime + relativedelta(days=value) + elif unit == DtUnit.HOURS: + return dateTime + relativedelta(hours=value) + elif unit == DtUnit.MINUTES: + return dateTime + relativedelta(minutes=value) + elif unit == DtUnit.SECONDS: + return dateTime + relativedelta(seconds=value) + elif unit == DtUnit.MICRO_SECONDS: + return dateTime + relativedelta(microseconds=value) + else: + raise ValueError("Unexpected DtUnit: {}".format(unit)) + + +def bestUnit(durationInSeconds): + """ Gets the best tick spacing given a duration in seconds. + + :param durationInSeconds: time span duration in seconds + :return: DtUnit enumeration. + """ + + # Based on; https://stackoverflow.com/a/2144398/ + # If the duration is longer than two years the tick spacing will be in + # years. Else, if the duration is longer than two months, the spacing will + # be in months, Etcetera. + # + # This factor differs per unit. As a baseline it is 2, but for instance, + # for Months this needs to be higher (3>), This because it is impossible to + # have partial months so the tick spacing is always at least 1 month. A + # duration of two months would result in two ticks, which is too few. + # months would then results + + if durationInSeconds > SECONDS_PER_YEAR * 3: + return (durationInSeconds / SECONDS_PER_YEAR, DtUnit.YEARS) + elif durationInSeconds > SECONDS_PER_MONTH_AVERAGE * 3: + return (durationInSeconds / SECONDS_PER_MONTH_AVERAGE, DtUnit.MONTHS) + elif durationInSeconds > SECONDS_PER_DAY * 2: + return (durationInSeconds / SECONDS_PER_DAY, DtUnit.DAYS) + elif durationInSeconds > SECONDS_PER_HOUR * 2: + return (durationInSeconds / SECONDS_PER_HOUR, DtUnit.HOURS) + elif durationInSeconds > SECONDS_PER_MINUTE * 2: + return (durationInSeconds / SECONDS_PER_MINUTE, DtUnit.MINUTES) + elif durationInSeconds > 1 * 2: + return (durationInSeconds, DtUnit.SECONDS) + else: + return (durationInSeconds * MICROSECONDS_PER_SECOND, + DtUnit.MICRO_SECONDS) + + +NICE_DATE_VALUES = { + DtUnit.YEARS: [1, 2, 5, 10], + DtUnit.MONTHS: [1, 2, 3, 4, 6, 12], + DtUnit.DAYS: [1, 2, 3, 7, 14, 28], + DtUnit.HOURS: [1, 2, 3, 4, 6, 12], + DtUnit.MINUTES: [1, 2, 3, 5, 10, 15, 30], + DtUnit.SECONDS: [1, 2, 3, 5, 10, 15, 30], + DtUnit.MICRO_SECONDS : [1.0, 2.0, 5.0, 10.0], # floats for microsec +} + + +def bestFormatString(spacing, unit): + """ Finds the best format string given the spacing and DtUnit. + + If the spacing is a fractional number < 1 the format string will take this + into account + + :param spacing: spacing between ticks + :param DtUnit unit: + :return: Format string for use in strftime + :rtype: str + """ + isSmall = spacing < 1 + + if unit == DtUnit.YEARS: + return "%Y-m" if isSmall else "%Y" + elif unit == DtUnit.MONTHS: + return "%Y-%m-%d" if isSmall else "%Y-%m" + elif unit == DtUnit.DAYS: + return "%H:%M" if isSmall else "%Y-%m-%d" + elif unit == DtUnit.HOURS: + return "%H:%M" if isSmall else "%H:%M" + elif unit == DtUnit.MINUTES: + return "%H:%M:%S" if isSmall else "%H:%M" + elif unit == DtUnit.SECONDS: + return "%S.%f" if isSmall else "%H:%M:%S" + elif unit == DtUnit.MICRO_SECONDS: + return "%S.%f" + else: + raise ValueError("Unexpected DtUnit: {}".format(unit)) + + +def niceDateTimeElement(value, unit, isRound=False): + """ Uses the Nice Numbers algorithm to determine a nice value. + + The fractions are optimized for the unit of the date element. + """ + + niceValues = NICE_DATE_VALUES[unit] + elemValue = niceNumGeneric(value, niceValues, isRound=isRound) + + if unit == DtUnit.YEARS or unit == DtUnit.MONTHS: + elemValue = max(1, int(elemValue)) + + return elemValue + + +def findStartDate(dMin, dMax, nTicks): + """ Rounds a date down to the nearest nice number of ticks + """ + assert dMax >= dMin, \ + "dMin ({}) should come before dMax ({})".format(dMin, dMax) + + if dMin == dMax: + # Fallback when range is smaller than microsecond resolution + return dMin, 1, DtUnit.MICRO_SECONDS + + delta = dMax - dMin + lengthSec = delta.total_seconds() + _logger.debug("findStartDate: {}, {} (duration = {} sec, {} days)" + .format(dMin, dMax, lengthSec, lengthSec / SECONDS_PER_DAY)) + + length, unit = bestUnit(lengthSec) + niceLength = niceDateTimeElement(length, unit) + + _logger.debug("Length: {:8.3f} {} (nice = {})" + .format(length, unit.name, niceLength)) + + niceSpacing = niceDateTimeElement(niceLength / nTicks, unit, isRound=True) + + _logger.debug("Spacing: {:8.3f} {} (nice = {})" + .format(niceLength / nTicks, unit.name, niceSpacing)) + + dVal = getDateElement(dMin, unit) + + if unit == DtUnit.MONTHS: # TODO: better rounding? + niceVal = math.floor((dVal-1) / niceSpacing) * niceSpacing + 1 + elif unit == DtUnit.DAYS: + niceVal = math.floor((dVal-1) / niceSpacing) * niceSpacing + 1 + else: + niceVal = math.floor(dVal / niceSpacing) * niceSpacing + + _logger.debug("StartValue: dVal = {}, niceVal: {} ({})" + .format(dVal, niceVal, unit.name)) + + startDate = roundToElement(dMin, unit) + startDate = setDateElement(startDate, niceVal, unit) + + return startDate, niceSpacing, unit + + +def dateRange(dMin, dMax, step, unit, includeFirstBeyond = False): + """ Generates a range of dates + + :param datetime dMin: start date + :param datetime dMax: end date + :param int step: the step size + :param DtUnit unit: the unit of the step size + :param bool includeFirstBeyond: if True the first date later than dMax will + be included in the range. If False (the default), the last generated + datetime will always be smaller than dMax. + :return: + """ + if (unit == DtUnit.YEARS or unit == DtUnit.MONTHS or + unit == DtUnit.MICRO_SECONDS): + # No support for fractional month or year and resolution is microsecond + # In those cases, make sure the step is at least 1 + step = max(1, step) + else: + assert step > 0, "tickstep is 0" + + dateTime = dMin + while dateTime < dMax: + yield dateTime + dateTime = addValueToDate(dateTime, step, unit) + + if includeFirstBeyond: + yield dateTime + + + +def calcTicks(dMin, dMax, nTicks): + """Returns tick positions. + + :param datetime.datetime dMin: The min value on the axis + :param datetime.datetime dMax: The max value on the axis + :param int nTicks: The target number of ticks. The actual number of found + ticks may differ. + :returns: (list of datetimes, DtUnit) tuple + """ + _logger.debug("Calc calcTicks({}, {}, nTicks={})" + .format(dMin, dMax, nTicks)) + + startDate, niceSpacing, unit = findStartDate(dMin, dMax, nTicks) + + result = [] + for d in dateRange(startDate, dMax, niceSpacing, unit, + includeFirstBeyond=True): + result.append(d) + + assert result[0] <= dMin, \ + "First nice date ({}) should be <= dMin {}".format(result[0], dMin) + + assert result[-1] >= dMax, \ + "Last nice date ({}) should be >= dMax {}".format(result[-1], dMax) + + return result, niceSpacing, unit + + +def calcTicksAdaptive(dMin, dMax, axisLength, tickDensity): + """ Calls calcTicks with a variable number of ticks, depending on axisLength + """ + # At least 2 ticks + nticks = max(2, int(round(tickDensity * axisLength))) + return calcTicks(dMin, dMax, nticks) + + + + + diff --git a/src/silx/gui/plot/_utils/panzoom.py b/src/silx/gui/plot/_utils/panzoom.py new file mode 100644 index 0000000..77efd10 --- /dev/null +++ b/src/silx/gui/plot/_utils/panzoom.py @@ -0,0 +1,325 @@ +# coding: utf-8 +# /*########################################################################## +# +# Copyright (c) 2004-2021 European Synchrotron Radiation Facility +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. +# +# ###########################################################################*/ +"""Functions to apply pan and zoom on a Plot""" + +__authors__ = ["T. Vincent", "V. Valls"] +__license__ = "MIT" +__date__ = "08/08/2017" + + +import logging +import math +import numpy + + +_logger = logging.getLogger(__name__) + + +# Float 32 info ############################################################### +# Using min/max value below limits of float32 +# so operation with such value (e.g., max - min) do not overflow + +FLOAT32_SAFE_MIN = -1e37 +FLOAT32_MINPOS = numpy.finfo(numpy.float32).tiny +FLOAT32_SAFE_MAX = 1e37 +# TODO double support + + +def checkAxisLimits(vmin, vmax, isLog: bool=False, name: str=""): + """Makes sure axis range is not empty and within supported range. + + :param float vmin: Min axis value + :param float vmax: Max axis value + :return: (min, max) making sure min < max + :rtype: 2-tuple of float + """ + min_ = FLOAT32_MINPOS if isLog else FLOAT32_SAFE_MIN + vmax = numpy.clip(vmax, min_, FLOAT32_SAFE_MAX) + vmin = numpy.clip(vmin, min_, FLOAT32_SAFE_MAX) + + if vmax < vmin: + _logger.debug('%s axis: max < min, inverting limits.', name) + vmin, vmax = vmax, vmin + elif vmax == vmin: + _logger.debug('%s axis: max == min, expanding limits.', name) + if vmin == 0.: + vmin, vmax = -0.1, 0.1 + elif vmin < 0: + vmax *= 0.9 + vmin = max(vmin * 1.1, FLOAT32_SAFE_MIN) # Clip to range + else: # vmin > 0 + vmax = min(vmin * 1.1, FLOAT32_SAFE_MAX) # Clip to range + vmin *= 0.9 + + return vmin, vmax + + +def scale1DRange(min_, max_, center, scale, isLog): + """Scale a 1D range given a scale factor and an center point. + + Keeps the values in a smaller range than float32. + + :param float min_: The current min value of the range. + :param float max_: The current max value of the range. + :param float center: The center of the zoom (i.e., invariant point). + :param float scale: The scale to use for zoom + :param bool isLog: Whether using log scale or not. + :return: The zoomed range. + :rtype: tuple of 2 floats: (min, max) + """ + if isLog: + # Min and center can be < 0 when + # autoscale is off and switch to log scale + # max_ < 0 should not happen + min_ = numpy.log10(min_) if min_ > 0. else FLOAT32_MINPOS + center = numpy.log10(center) if center > 0. else FLOAT32_MINPOS + max_ = numpy.log10(max_) if max_ > 0. else FLOAT32_MINPOS + + if min_ == max_: + return min_, max_ + + offset = (center - min_) / (max_ - min_) + range_ = (max_ - min_) / scale + newMin = center - offset * range_ + newMax = center + (1. - offset) * range_ + + if isLog: + # No overflow as exponent is log10 of a float32 + newMin = pow(10., newMin) + newMax = pow(10., newMax) + newMin = numpy.clip(newMin, FLOAT32_MINPOS, FLOAT32_SAFE_MAX) + newMax = numpy.clip(newMax, FLOAT32_MINPOS, FLOAT32_SAFE_MAX) + else: + newMin = numpy.clip(newMin, FLOAT32_SAFE_MIN, FLOAT32_SAFE_MAX) + newMax = numpy.clip(newMax, FLOAT32_SAFE_MIN, FLOAT32_SAFE_MAX) + return newMin, newMax + + +def applyZoomToPlot(plot, scaleF, center=None): + """Zoom in/out plot given a scale and a center point. + + :param plot: The plot on which to apply zoom. + :param float scaleF: Scale factor of zoom. + :param center: (x, y) coords in pixel coordinates of the zoom center. + :type center: 2-tuple of float + """ + xMin, xMax = plot.getXAxis().getLimits() + yMin, yMax = plot.getYAxis().getLimits() + + if center is None: + left, top, width, height = plot.getPlotBoundsInPixels() + cx, cy = left + width // 2, top + height // 2 + else: + cx, cy = center + + dataCenterPos = plot.pixelToData(cx, cy) + assert dataCenterPos is not None + + xMin, xMax = scale1DRange(xMin, xMax, dataCenterPos[0], scaleF, + plot.getXAxis()._isLogarithmic()) + + yMin, yMax = scale1DRange(yMin, yMax, dataCenterPos[1], scaleF, + plot.getYAxis()._isLogarithmic()) + + dataPos = plot.pixelToData(cx, cy, axis="right") + assert dataPos is not None + y2Center = dataPos[1] + y2Min, y2Max = plot.getYAxis(axis="right").getLimits() + y2Min, y2Max = scale1DRange(y2Min, y2Max, y2Center, scaleF, + plot.getYAxis()._isLogarithmic()) + + plot.setLimits(xMin, xMax, yMin, yMax, y2Min, y2Max) + + +def applyPan(min_, max_, panFactor, isLog10): + """Returns a new range with applied panning. + + Moves the range according to panFactor. + If isLog10 is True, converts to log10 before moving. + + :param float min_: Min value of the data range to pan. + :param float max_: Max value of the data range to pan. + Must be >= min. + :param float panFactor: Signed proportion of the range to use for pan. + :param bool isLog10: True if log10 scale, False if linear scale. + :return: New min and max value with pan applied. + :rtype: 2-tuple of float. + """ + if isLog10 and min_ > 0.: + # Negative range and log scale can happen with matplotlib + logMin, logMax = math.log10(min_), math.log10(max_) + logOffset = panFactor * (logMax - logMin) + newMin = pow(10., logMin + logOffset) + newMax = pow(10., logMax + logOffset) + + # Takes care of out-of-range values + if newMin > 0. and newMax < float('inf'): + min_, max_ = newMin, newMax + + else: + offset = panFactor * (max_ - min_) + newMin, newMax = min_ + offset, max_ + offset + + # Takes care of out-of-range values + if newMin > - float('inf') and newMax < float('inf'): + min_, max_ = newMin, newMax + return min_, max_ + + +class _Unset(object): + """To be able to have distinction between None and unset""" + pass + + +class ViewConstraints(object): + """ + Store constraints applied on the view box and compute the resulting view box. + """ + + def __init__(self): + self._min = [None, None] + self._max = [None, None] + self._minRange = [None, None] + self._maxRange = [None, None] + + def update(self, xMin=_Unset, xMax=_Unset, + yMin=_Unset, yMax=_Unset, + minXRange=_Unset, maxXRange=_Unset, + minYRange=_Unset, maxYRange=_Unset): + """ + Update the constraints managed by the object + + The constraints are the same as the ones provided by PyQtGraph. + + :param float xMin: Minimum allowed x-axis value. + (default do not change the stat, None remove the constraint) + :param float xMax: Maximum allowed x-axis value. + (default do not change the stat, None remove the constraint) + :param float yMin: Minimum allowed y-axis value. + (default do not change the stat, None remove the constraint) + :param float yMax: Maximum allowed y-axis value. + (default do not change the stat, None remove the constraint) + :param float minXRange: Minimum allowed left-to-right span across the + view (default do not change the stat, None remove the constraint) + :param float maxXRange: Maximum allowed left-to-right span across the + view (default do not change the stat, None remove the constraint) + :param float minYRange: Minimum allowed top-to-bottom span across the + view (default do not change the stat, None remove the constraint) + :param float maxYRange: Maximum allowed top-to-bottom span across the + view (default do not change the stat, None remove the constraint) + :return: True if the constraints was changed + """ + updated = False + + minRange = [minXRange, minYRange] + maxRange = [maxXRange, maxYRange] + minPos = [xMin, yMin] + maxPos = [xMax, yMax] + + for axis in range(2): + + value = minPos[axis] + if value is not _Unset and value != self._min[axis]: + self._min[axis] = value + updated = True + + value = maxPos[axis] + if value is not _Unset and value != self._max[axis]: + self._max[axis] = value + updated = True + + value = minRange[axis] + if value is not _Unset and value != self._minRange[axis]: + self._minRange[axis] = value + updated = True + + value = maxRange[axis] + if value is not _Unset and value != self._maxRange[axis]: + self._maxRange[axis] = value + updated = True + + # Sanity checks + + for axis in range(2): + if self._maxRange[axis] is not None and self._min[axis] is not None and self._max[axis] is not None: + # max range cannot be larger than bounds + diff = self._max[axis] - self._min[axis] + self._maxRange[axis] = min(self._maxRange[axis], diff) + updated = True + + return updated + + def normalize(self, xMin, xMax, yMin, yMax, allow_scaling=True): + """Normalize a view range defined by x and y corners using predefined + containts. + + :param float xMin: Min position of the x-axis + :param float xMax: Max position of the x-axis + :param float yMin: Min position of the y-axis + :param float yMax: Max position of the y-axis + :param bool allow_scaling: Allow or not to apply scaling for the + normalization. Used according to the interaction mode. + :return: A normalized tuple of (xMin, xMax, yMin, yMax) + """ + viewRange = [[xMin, xMax], [yMin, yMax]] + + for axis in range(2): + # clamp xRange and yRange + if allow_scaling: + diff = viewRange[axis][1] - viewRange[axis][0] + delta = None + if self._maxRange[axis] is not None and diff > self._maxRange[axis]: + delta = self._maxRange[axis] - diff + elif self._minRange[axis] is not None and diff < self._minRange[axis]: + delta = self._minRange[axis] - diff + if delta is not None: + viewRange[axis][0] -= delta * 0.5 + viewRange[axis][1] += delta * 0.5 + + # clamp min and max positions + outMin = self._min[axis] is not None and viewRange[axis][0] < self._min[axis] + outMax = self._max[axis] is not None and viewRange[axis][1] > self._max[axis] + + if outMin and outMax: + if allow_scaling: + # we can clamp both sides + viewRange[axis][0] = self._min[axis] + viewRange[axis][1] = self._max[axis] + else: + # center the result + delta = viewRange[axis][1] - viewRange[axis][0] + mid = self._min[axis] + self._max[axis] - self._min[axis] + viewRange[axis][0] = mid - delta + viewRange[axis][1] = mid + delta + elif outMin: + delta = self._min[axis] - viewRange[axis][0] + viewRange[axis][0] += delta + viewRange[axis][1] += delta + elif outMax: + delta = self._max[axis] - viewRange[axis][1] + viewRange[axis][0] += delta + viewRange[axis][1] += delta + + return viewRange[0][0], viewRange[0][1], viewRange[1][0], viewRange[1][1] diff --git a/src/silx/gui/plot/_utils/setup.py b/src/silx/gui/plot/_utils/setup.py new file mode 100644 index 0000000..0271745 --- /dev/null +++ b/src/silx/gui/plot/_utils/setup.py @@ -0,0 +1,42 @@ +# coding: utf-8 +# /*########################################################################## +# +# Copyright (c) 2016-2017 European Synchrotron Radiation Facility +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. +# +# ###########################################################################*/ +__authors__ = ["T. Vincent"] +__license__ = "MIT" +__date__ = "21/03/2017" + + +from numpy.distutils.misc_util import Configuration + + +def configuration(parent_package='', top_path=None): + config = Configuration('_utils', parent_package, top_path) + config.add_subpackage('test') + return config + + +if __name__ == "__main__": + from numpy.distutils.core import setup + + setup(configuration=configuration) diff --git a/src/silx/gui/plot/_utils/test/__init__.py b/src/silx/gui/plot/_utils/test/__init__.py new file mode 100644 index 0000000..3ad225d --- /dev/null +++ b/src/silx/gui/plot/_utils/test/__init__.py @@ -0,0 +1,24 @@ +# coding: utf-8 +# /*########################################################################## +# +# Copyright (c) 2016-2018 European Synchrotron Radiation Facility +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. +# +# ###########################################################################*/ diff --git a/src/silx/gui/plot/_utils/test/test_dtime_ticklayout.py b/src/silx/gui/plot/_utils/test/test_dtime_ticklayout.py new file mode 100644 index 0000000..8d35acf --- /dev/null +++ b/src/silx/gui/plot/_utils/test/test_dtime_ticklayout.py @@ -0,0 +1,79 @@ +# coding: utf-8 +# /*########################################################################## +# +# Copyright (c) 2015-2018 European Synchrotron Radiation Facility +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. +# +# ###########################################################################*/ + +from __future__ import absolute_import, division, unicode_literals + +__authors__ = ["P. Kenter"] +__license__ = "MIT" +__date__ = "06/04/2018" + + +import datetime as dt +import unittest + + +from silx.gui.plot._utils.dtime_ticklayout import ( + calcTicks, DtUnit, SECONDS_PER_YEAR) + + +class TestTickLayout(unittest.TestCase): + """Test ticks layout algorithms""" + + def testSmallMonthlySpacing(self): + """ Tests a range that did result in a spacing of less than 1 month. + It is impossible to add fractional month so the unit must be in days + """ + from dateutil import parser + d1 = parser.parse("2017-01-03 13:15:06.000044") + d2 = parser.parse("2017-03-08 09:16:16.307584") + _ticks, _units, spacing = calcTicks(d1, d2, nTicks=4) + + self.assertEqual(spacing, DtUnit.DAYS) + + + def testNoCrash(self): + """ Creates many combinations of and number-of-ticks and end-dates; + tests that it doesn't give an exception and returns a reasonable number + of ticks. + """ + d1 = dt.datetime(2017, 1, 3, 13, 15, 6, 44) + + value = 100e-6 # Start at 100 micro sec range. + + while value <= 200 * SECONDS_PER_YEAR: + + d2 = d1 + dt.timedelta(microseconds=value*1e6) # end date range + + for numTicks in range(2, 12): + ticks, _, _ = calcTicks(d1, d2, numTicks) + + margin = 2.5 + self.assertTrue( + numTicks/margin <= len(ticks) <= numTicks*margin, + "Condition {} <= {} <= {} failed for # ticks={} and d2={}:" + .format(numTicks/margin, len(ticks), numTicks * margin, + numTicks, d2)) + + value = value * 1.5 # let date period grow exponentially diff --git a/src/silx/gui/plot/_utils/test/test_ticklayout.py b/src/silx/gui/plot/_utils/test/test_ticklayout.py new file mode 100644 index 0000000..884b71b --- /dev/null +++ b/src/silx/gui/plot/_utils/test/test_ticklayout.py @@ -0,0 +1,81 @@ +# coding: utf-8 +# /*########################################################################## +# +# Copyright (c) 2015-2017 European Synchrotron Radiation Facility +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. +# +# ###########################################################################*/ + +from __future__ import absolute_import, division, unicode_literals + +__authors__ = ["T. Vincent"] +__license__ = "MIT" +__date__ = "17/01/2018" + + +import unittest +import numpy + +from silx.utils.testutils import ParametricTestCase + +from silx.gui.plot._utils import ticklayout + + +class TestTickLayout(ParametricTestCase): + """Test ticks layout algorithms""" + + def testTicks(self): + """Test of :func:`ticks`""" + tests = { # (vmin, vmax): ref_ticks + (1., 1.): (1.,), + (0.5, 10.5): (2.0, 4.0, 6.0, 8.0, 10.0), + (0.001, 0.005): (0.001, 0.002, 0.003, 0.004, 0.005) + } + + for (vmin, vmax), ref_ticks in tests.items(): + with self.subTest(vmin=vmin, vmax=vmax): + ticks, labels = ticklayout.ticks(vmin, vmax) + self.assertTrue(numpy.allclose(ticks, ref_ticks)) + + def testNiceNumbers(self): + """Minimalistic tests of :func:`niceNumbers`""" + tests = { # (vmin, vmax): ref_ticks + (0.5, 10.5): (0.0, 12.0, 2.0, 0), + (10000., 10000.5): (10000.0, 10000.5, 0.1, 1), + (0.001, 0.005): (0.001, 0.005, 0.001, 3) + } + + for (vmin, vmax), ref_ticks in tests.items(): + with self.subTest(vmin=vmin, vmax=vmax): + ticks = ticklayout.niceNumbers(vmin, vmax) + self.assertEqual(ticks, ref_ticks) + + def testNiceNumbersLog(self): + """Minimalistic tests of :func:`niceNumbersForLog10`""" + tests = { # (log10(min), log10(max): ref_ticks + (0., 3.): (0, 3, 1, 0), + (-3., 3): (-3, 3, 1, 0), + (-32., 0.): (-36, 0, 6, 0) + } + + for (vmin, vmax), ref_ticks in tests.items(): + with self.subTest(vmin=vmin, vmax=vmax): + ticks = ticklayout.niceNumbersForLog10(vmin, vmax) + self.assertEqual(ticks, ref_ticks) diff --git a/src/silx/gui/plot/_utils/ticklayout.py b/src/silx/gui/plot/_utils/ticklayout.py new file mode 100644 index 0000000..c9fd3e6 --- /dev/null +++ b/src/silx/gui/plot/_utils/ticklayout.py @@ -0,0 +1,267 @@ +# coding: utf-8 +# /*########################################################################## +# +# Copyright (c) 2014-2018 European Synchrotron Radiation Facility +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. +# +# ###########################################################################*/ +"""This module implements labels layout on graph axes.""" + +from __future__ import absolute_import, division, unicode_literals + +__authors__ = ["T. Vincent"] +__license__ = "MIT" +__date__ = "18/10/2016" + + +import math + + +# utils ####################################################################### + +def numberOfDigits(tickSpacing): + """Returns the number of digits to display for text label. + + :param float tickSpacing: Step between ticks in data space. + :return: Number of digits to show for labels. + :rtype: int + """ + nfrac = int(-math.floor(math.log10(tickSpacing))) + if nfrac < 0: + nfrac = 0 + return nfrac + + +# Nice Numbers ################################################################ + +# This is the original niceNum implementation. For the date time ticks a more +# generic implementation was needed. +# +# def _niceNum(value, isRound=False): +# expvalue = math.floor(math.log10(value)) +# frac = value/pow(10., expvalue) +# if isRound: +# if frac < 1.5: +# nicefrac = 1. +# elif frac < 3.: # In niceNumGeneric this is (2+5)/2 = 3.5 +# nicefrac = 2. +# elif frac < 7.: +# nicefrac = 5. # In niceNumGeneric this is (5+10)/2 = 7.5 +# else: +# nicefrac = 10. +# else: +# if frac <= 1.: +# nicefrac = 1. +# elif frac <= 2.: +# nicefrac = 2. +# elif frac <= 5.: +# nicefrac = 5. +# else: +# nicefrac = 10. +# return nicefrac * pow(10., expvalue) + + +def niceNumGeneric(value, niceFractions=None, isRound=False): + """ A more generic implementation of the _niceNum function + + Allows the user to specify the fractions instead of using a hardcoded + list of [1, 2, 5, 10.0]. + """ + if value == 0: + return value + + if niceFractions is None: # Use default values + niceFractions = 1., 2., 5., 10. + roundFractions = (1.5, 3., 7., 10.) if isRound else niceFractions + + else: + roundFractions = list(niceFractions) + if isRound: + # Take the average with the next element. The last remains the same. + for i in range(len(roundFractions) - 1): + roundFractions[i] = (niceFractions[i] + niceFractions[i+1]) / 2 + + highest = niceFractions[-1] + value = float(value) + + expvalue = math.floor(math.log(value, highest)) + frac = value / pow(highest, expvalue) + + for niceFrac, roundFrac in zip(niceFractions, roundFractions): + if frac <= roundFrac: + return niceFrac * pow(highest, expvalue) + + # should not come here + assert False, "should not come here" + + +def niceNumbers(vMin, vMax, nTicks=5): + """Returns tick positions. + + This function implements graph labels layout using nice numbers + by Paul Heckbert from "Graphics Gems", Academic Press, 1990. + See `C code <http://tog.acm.org/resources/GraphicsGems/gems/Label.c>`_. + + :param float vMin: The min value on the axis + :param float vMax: The max value on the axis + :param int nTicks: The number of ticks to position + :returns: min, max, increment value of tick positions and + number of fractional digit to show + :rtype: tuple + """ + vrange = niceNumGeneric(vMax - vMin, isRound=False) + spacing = niceNumGeneric(vrange / nTicks, isRound=True) + graphmin = math.floor(vMin / spacing) * spacing + graphmax = math.ceil(vMax / spacing) * spacing + nfrac = numberOfDigits(spacing) + return graphmin, graphmax, spacing, nfrac + + +def _frange(start, stop, step): + """range for float (including stop).""" + assert step >= 0. + while start <= stop: + yield start + start += step + + +def ticks(vMin, vMax, nbTicks=5): + """Returns tick positions and labels using nice numbers algorithm. + + This enforces ticks to be within [vMin, vMax] range. + It returns at least 1 tick (when vMin == vMax). + + :param float vMin: The min value on the axis + :param float vMax: The max value on the axis + :param int nbTicks: The number of ticks to position + :returns: tick positions and corresponding text labels + :rtype: 2-tuple: list of float, list of string + """ + assert vMin <= vMax + if vMin == vMax: + positions = [vMin] + nfrac = 0 + + else: + start, end, step, nfrac = niceNumbers(vMin, vMax, nbTicks) + positions = [t for t in _frange(start, end, step) if vMin <= t <= vMax] + + # Makes sure there is at least 2 ticks + if len(positions) < 2: + positions = [vMin, vMax] + nfrac = numberOfDigits(vMax - vMin) + + # Generate labels + format_ = '%g' if nfrac == 0 else '%.{}f'.format(nfrac) + labels = [format_ % tick for tick in positions] + return positions, labels + + +def niceNumbersAdaptative(vMin, vMax, axisLength, tickDensity): + """Returns tick positions using :func:`niceNumbers` and a + density of ticks. + + axisLength and tickDensity are based on the same unit (e.g., pixel). + + :param float vMin: The min value on the axis + :param float vMax: The max value on the axis + :param float axisLength: The length of the axis. + :param float tickDensity: The density of ticks along the axis. + :returns: min, max, increment value of tick positions and + number of fractional digit to show + :rtype: tuple + """ + # At least 2 ticks + nticks = max(2, int(round(tickDensity * axisLength))) + tickmin, tickmax, step, nfrac = niceNumbers(vMin, vMax, nticks) + + return tickmin, tickmax, step, nfrac + + +# Nice Numbers for log scale ################################################## + +def niceNumbersForLog10(minLog, maxLog, nTicks=5): + """Return tick positions for logarithmic scale + + :param float minLog: log10 of the min value on the axis + :param float maxLog: log10 of the max value on the axis + :param int nTicks: The number of ticks to position + :returns: log10 of min, max, increment value of tick positions and + number of fractional digit to show + :rtype: tuple of int + """ + graphminlog = math.floor(minLog) + graphmaxlog = math.ceil(maxLog) + rangelog = graphmaxlog - graphminlog + + if rangelog <= nTicks: + spacing = 1. + else: + spacing = math.floor(rangelog / nTicks) + + graphminlog = math.floor(graphminlog / spacing) * spacing + graphmaxlog = math.ceil(graphmaxlog / spacing) * spacing + + nfrac = numberOfDigits(spacing) + + return int(graphminlog), int(graphmaxlog), int(spacing), nfrac + + +def niceNumbersAdaptativeForLog10(vMin, vMax, axisLength, tickDensity): + """Returns tick positions using :func:`niceNumbers` and a + density of ticks. + + axisLength and tickDensity are based on the same unit (e.g., pixel). + + :param float vMin: The min value on the axis + :param float vMax: The max value on the axis + :param float axisLength: The length of the axis. + :param float tickDensity: The density of ticks along the axis. + :returns: log10 of min, max, increment value of tick positions and + number of fractional digit to show + :rtype: tuple + """ + # At least 2 ticks + nticks = max(2, int(round(tickDensity * axisLength))) + tickmin, tickmax, step, nfrac = niceNumbersForLog10(vMin, vMax, nticks) + + return tickmin, tickmax, step, nfrac + + +def computeLogSubTicks(ticks, lowBound, highBound): + """Return the sub ticks for the log scale for all given ticks if subtick + is in [lowBound, highBound] + + :param ticks: log10 of the ticks + :param lowBound: the lower boundary of ticks + :param highBound: the higher boundary of ticks + :return: all the sub ticks contained in ticks (log10) + """ + if len(ticks) < 1: + return [] + + res = [] + for logPos in ticks: + dataOrigPos = logPos + for index in range(2, 10): + dataPos = dataOrigPos * index + if lowBound <= dataPos <= highBound: + res.append(dataPos) + return res diff --git a/src/silx/gui/plot/actions/PlotAction.py b/src/silx/gui/plot/actions/PlotAction.py new file mode 100644 index 0000000..2983775 --- /dev/null +++ b/src/silx/gui/plot/actions/PlotAction.py @@ -0,0 +1,78 @@ +# coding: utf-8 +# /*########################################################################## +# +# Copyright (c) 2004-2017 European Synchrotron Radiation Facility +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. +# +# ###########################################################################*/ +""" +The class :class:`.PlotAction` help the creation of a qt.QAction associated +with a :class:`.PlotWidget`. +""" + +from __future__ import division + + +__authors__ = ["V.A. Sole", "T. Vincent", "P. Knobel"] +__license__ = "MIT" +__date__ = "03/01/2018" + + +import weakref +from silx.gui import icons +from silx.gui import qt + + +class PlotAction(qt.QAction): + """Base class for QAction that operates on a PlotWidget. + + :param plot: :class:`.PlotWidget` instance on which to operate. + :param icon: QIcon or str name of icon to use + :param str text: The name of this action to be used for menu label + :param str tooltip: The text of the tooltip + :param triggered: The callback to connect to the action's triggered + signal or None for no callback. + :param bool checkable: True for checkable action, False otherwise (default) + :param parent: See :class:`QAction`. + """ + + def __init__(self, plot, icon, text, tooltip=None, + triggered=None, checkable=False, parent=None): + assert plot is not None + self._plotRef = weakref.ref(plot) + + if not isinstance(icon, qt.QIcon): + # Try with icon as a string and load corresponding icon + icon = icons.getQIcon(icon) + + super(PlotAction, self).__init__(icon, text, parent) + + if tooltip is not None: + self.setToolTip(tooltip) + + self.setCheckable(checkable) + + if triggered is not None: + self.triggered[bool].connect(triggered) + + @property + def plot(self): + """The :class:`.PlotWidget` this action group is controlling.""" + return self._plotRef() diff --git a/src/silx/gui/plot/actions/PlotToolAction.py b/src/silx/gui/plot/actions/PlotToolAction.py new file mode 100644 index 0000000..fbb0b0f --- /dev/null +++ b/src/silx/gui/plot/actions/PlotToolAction.py @@ -0,0 +1,150 @@ +# coding: utf-8 +# /*########################################################################## +# +# Copyright (c) 2004-2020 European Synchrotron Radiation Facility +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. +# +# ###########################################################################*/ +""" +The class :class:`.PlotToolAction` help the creation of a qt.QAction associating +a tool window with a :class:`.PlotWidget`. +""" + +from __future__ import division + + +__authors__ = ["V.A. Sole", "T. Vincent", "P. Knobel"] +__license__ = "MIT" +__date__ = "10/10/2018" + + +import weakref + +from .PlotAction import PlotAction +from silx.gui import qt + + +class PlotToolAction(PlotAction): + """Base class for QAction that maintain a tool window operating on a + PlotWidget.""" + + def __init__(self, plot, icon, text, tooltip=None, + triggered=None, checkable=False, parent=None): + PlotAction.__init__(self, + plot=plot, + icon=icon, + text=text, + tooltip=tooltip, + triggered=self._triggered, + parent=parent, + checkable=True) + self._previousGeometry = None + self._toolWindow = None + + def _triggered(self, checked): + """Update the plot of the histogram visibility status + + :param bool checked: status of the action button + """ + self._setToolWindowVisible(checked) + + def _setToolWindowVisible(self, visible): + """Set the tool window visible or hidden.""" + tool = self._getToolWindow() + if tool.isVisible() == visible: + # Nothing to do + return + + if visible: + self._connectPlot(tool) + tool.show() + if self._previousGeometry is not None: + # Restore the geometry + tool.setGeometry(self._previousGeometry) + else: + self._disconnectPlot(tool) + # Save the geometry + self._previousGeometry = tool.geometry() + tool.hide() + + def _connectPlot(self, window): + """Called if the tool is visible and have to be updated according to + event of the plot. + + :param qt.QWidget window: The tool window + """ + pass + + def _disconnectPlot(self, window): + """Called if the tool is not visible and dont have anymore to be updated + according to event of the plot. + + :param qt.QWidget window: The tool window + """ + pass + + def _isWindowInUse(self): + """Returns true if the tool window is currently in use.""" + if not self.isChecked(): + return False + return self._toolWindow is not None + + def _ownerVisibilityChanged(self, isVisible): + """Called when the visibility of the parent of the tool window changes + + :param bool isVisible: True if the parent became visible + """ + if self._isWindowInUse(): + self._setToolWindowVisible(isVisible) + + def eventFilter(self, qobject, event): + """Observe when the close event is emitted then + simply uncheck the action button + + :param qobject: the object observe + :param event: the event received by qobject + """ + if event.type() == qt.QEvent.Close: + if self._toolWindow is not None: + window = self._toolWindow() + self._previousGeometry = window.geometry() + window.hide() + self.setChecked(False) + + return PlotAction.eventFilter(self, qobject, event) + + def _getToolWindow(self): + """Returns the window containing the tool. + + It uses lazy loading to create this tool.. + """ + if self._toolWindow is None: + window = self._createToolWindow() + if self._previousGeometry is not None: + window.setGeometry(self._previousGeometry) + window.installEventFilter(self) + plot = self.plot + plot.sigVisibilityChanged.connect(self._ownerVisibilityChanged) + self._toolWindow = weakref.ref(window) + return self._toolWindow() + + def _createToolWindow(self): + """Create the tool window managing the plot.""" + raise NotImplementedError() diff --git a/src/silx/gui/plot/actions/__init__.py b/src/silx/gui/plot/actions/__init__.py new file mode 100644 index 0000000..930c728 --- /dev/null +++ b/src/silx/gui/plot/actions/__init__.py @@ -0,0 +1,42 @@ +# coding: utf-8 +# /*########################################################################## +# +# Copyright (c) 2017-2018 European Synchrotron Radiation Facility +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. +# +# ###########################################################################*/ +"""This package provides a set of QAction to use with +:class:`~silx.gui.plot.PlotWidget` + +Those actions are useful to add menu items or toolbar items +that interact with a :class:`~silx.gui.plot.PlotWidget`. + +It provides a base class used to define new plot actions: +:class:`~silx.gui.plot.actions.PlotAction`. +""" + +__authors__ = ["H. Payno"] +__license__ = "MIT" +__date__ = "16/08/2017" + +from .PlotAction import PlotAction +from . import control +from . import mode +from . import io diff --git a/src/silx/gui/plot/actions/control.py b/src/silx/gui/plot/actions/control.py new file mode 100755 index 0000000..439985e --- /dev/null +++ b/src/silx/gui/plot/actions/control.py @@ -0,0 +1,694 @@ +# coding: utf-8 +# /*########################################################################## +# +# Copyright (c) 2004-2019 European Synchrotron Radiation Facility +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. +# +# ###########################################################################*/ +""" +:mod:`silx.gui.plot.actions.control` provides a set of QAction relative to control +of a :class:`.PlotWidget`. + +The following QAction are available: + +- :class:`ColormapAction` +- :class:`CrosshairAction` +- :class:`CurveStyleAction` +- :class:`GridAction` +- :class:`KeepAspectRatioAction` +- :class:`PanWithArrowKeysAction` +- :class:`ResetZoomAction` +- :class:`ShowAxisAction` +- :class:`XAxisLogarithmicAction` +- :class:`XAxisAutoScaleAction` +- :class:`YAxisInvertedAction` +- :class:`YAxisLogarithmicAction` +- :class:`YAxisAutoScaleAction` +- :class:`ZoomBackAction` +- :class:`ZoomInAction` +- :class:`ZoomOutAction` +""" + +from __future__ import division + +__authors__ = ["V.A. Sole", "T. Vincent", "P. Knobel"] +__license__ = "MIT" +__date__ = "27/11/2020" + +from . import PlotAction +import logging +from silx.gui.plot import items +from silx.gui.plot._utils import applyZoomToPlot as _applyZoomToPlot +from silx.gui import qt +from silx.gui import icons + +_logger = logging.getLogger(__name__) + + +class ResetZoomAction(PlotAction): + """QAction controlling reset zoom on a :class:`.PlotWidget`. + + :param plot: :class:`.PlotWidget` instance on which to operate + :param parent: See :class:`QAction` + """ + + def __init__(self, plot, parent=None): + super(ResetZoomAction, self).__init__( + plot, icon='zoom-original', text='Reset Zoom', + tooltip='Auto-scale the graph', + triggered=self._actionTriggered, + checkable=False, parent=parent) + self._autoscaleChanged(True) + plot.getXAxis().sigAutoScaleChanged.connect(self._autoscaleChanged) + plot.getYAxis().sigAutoScaleChanged.connect(self._autoscaleChanged) + + def _autoscaleChanged(self, enabled): + xAxis = self.plot.getXAxis() + yAxis = self.plot.getYAxis() + self.setEnabled(xAxis.isAutoScale() or yAxis.isAutoScale()) + + if xAxis.isAutoScale() and yAxis.isAutoScale(): + tooltip = 'Auto-scale the graph' + elif xAxis.isAutoScale(): # And not Y axis + tooltip = 'Auto-scale the x-axis of the graph only' + elif yAxis.isAutoScale(): # And not X axis + tooltip = 'Auto-scale the y-axis of the graph only' + else: # no axis in autoscale + tooltip = 'Auto-scale the graph' + self.setToolTip(tooltip) + + def _actionTriggered(self, checked=False): + self.plot.resetZoom() + + +class ZoomBackAction(PlotAction): + """QAction performing a zoom-back in :class:`.PlotWidget` limits history. + + :param plot: :class:`.PlotWidget` instance on which to operate + :param parent: See :class:`QAction` + """ + + def __init__(self, plot, parent=None): + super(ZoomBackAction, self).__init__( + plot, icon='zoom-back', text='Zoom Back', + tooltip='Zoom back the plot', + triggered=self._actionTriggered, + checkable=False, parent=parent) + self.setShortcutContext(qt.Qt.WidgetShortcut) + + def _actionTriggered(self, checked=False): + self.plot.getLimitsHistory().pop() + + +class ZoomInAction(PlotAction): + """QAction performing a zoom-in on a :class:`.PlotWidget`. + + :param plot: :class:`.PlotWidget` instance on which to operate + :param parent: See :class:`QAction` + """ + + def __init__(self, plot, parent=None): + super(ZoomInAction, self).__init__( + plot, icon='zoom-in', text='Zoom In', + tooltip='Zoom in the plot', + triggered=self._actionTriggered, + checkable=False, parent=parent) + self.setShortcut(qt.QKeySequence.ZoomIn) + self.setShortcutContext(qt.Qt.WidgetShortcut) + + def _actionTriggered(self, checked=False): + _applyZoomToPlot(self.plot, 1.1) + + +class ZoomOutAction(PlotAction): + """QAction performing a zoom-out on a :class:`.PlotWidget`. + + :param plot: :class:`.PlotWidget` instance on which to operate + :param parent: See :class:`QAction` + """ + + def __init__(self, plot, parent=None): + super(ZoomOutAction, self).__init__( + plot, icon='zoom-out', text='Zoom Out', + tooltip='Zoom out the plot', + triggered=self._actionTriggered, + checkable=False, parent=parent) + self.setShortcut(qt.QKeySequence.ZoomOut) + self.setShortcutContext(qt.Qt.WidgetShortcut) + + def _actionTriggered(self, checked=False): + _applyZoomToPlot(self.plot, 1. / 1.1) + + +class XAxisAutoScaleAction(PlotAction): + """QAction controlling X axis autoscale on a :class:`.PlotWidget`. + + :param plot: :class:`.PlotWidget` instance on which to operate + :param parent: See :class:`QAction` + """ + + def __init__(self, plot, parent=None): + super(XAxisAutoScaleAction, self).__init__( + plot, icon='plot-xauto', text='X Autoscale', + tooltip='Enable x-axis auto-scale when checked.\n' + 'If unchecked, x-axis does not change when reseting zoom.', + triggered=self._actionTriggered, + checkable=True, parent=parent) + self.setChecked(plot.getXAxis().isAutoScale()) + plot.getXAxis().sigAutoScaleChanged.connect(self.setChecked) + + def _actionTriggered(self, checked=False): + self.plot.getXAxis().setAutoScale(checked) + if checked: + self.plot.resetZoom() + + +class YAxisAutoScaleAction(PlotAction): + """QAction controlling Y axis autoscale on a :class:`.PlotWidget`. + + :param plot: :class:`.PlotWidget` instance on which to operate + :param parent: See :class:`QAction` + """ + + def __init__(self, plot, parent=None): + super(YAxisAutoScaleAction, self).__init__( + plot, icon='plot-yauto', text='Y Autoscale', + tooltip='Enable y-axis auto-scale when checked.\n' + 'If unchecked, y-axis does not change when reseting zoom.', + triggered=self._actionTriggered, + checkable=True, parent=parent) + self.setChecked(plot.getYAxis().isAutoScale()) + plot.getYAxis().sigAutoScaleChanged.connect(self.setChecked) + + def _actionTriggered(self, checked=False): + self.plot.getYAxis().setAutoScale(checked) + if checked: + self.plot.resetZoom() + + +class XAxisLogarithmicAction(PlotAction): + """QAction controlling X axis log scale on a :class:`.PlotWidget`. + + :param plot: :class:`.PlotWidget` instance on which to operate + :param parent: See :class:`QAction` + """ + + def __init__(self, plot, parent=None): + super(XAxisLogarithmicAction, self).__init__( + plot, icon='plot-xlog', text='X Log. scale', + tooltip='Logarithmic x-axis when checked', + triggered=self._actionTriggered, + checkable=True, parent=parent) + self.axis = plot.getXAxis() + self.setChecked(self.axis.getScale() == self.axis.LOGARITHMIC) + self.axis.sigScaleChanged.connect(self._setCheckedIfLogScale) + + def _setCheckedIfLogScale(self, scale): + self.setChecked(scale == self.axis.LOGARITHMIC) + + def _actionTriggered(self, checked=False): + scale = self.axis.LOGARITHMIC if checked else self.axis.LINEAR + self.axis.setScale(scale) + + +class YAxisLogarithmicAction(PlotAction): + """QAction controlling Y axis log scale on a :class:`.PlotWidget`. + + :param plot: :class:`.PlotWidget` instance on which to operate + :param parent: See :class:`QAction` + """ + + def __init__(self, plot, parent=None): + super(YAxisLogarithmicAction, self).__init__( + plot, icon='plot-ylog', text='Y Log. scale', + tooltip='Logarithmic y-axis when checked', + triggered=self._actionTriggered, + checkable=True, parent=parent) + self.axis = plot.getYAxis() + self.setChecked(self.axis.getScale() == self.axis.LOGARITHMIC) + self.axis.sigScaleChanged.connect(self._setCheckedIfLogScale) + + def _setCheckedIfLogScale(self, scale): + self.setChecked(scale == self.axis.LOGARITHMIC) + + def _actionTriggered(self, checked=False): + scale = self.axis.LOGARITHMIC if checked else self.axis.LINEAR + self.axis.setScale(scale) + + +class GridAction(PlotAction): + """QAction controlling grid mode on a :class:`.PlotWidget`. + + :param plot: :class:`.PlotWidget` instance on which to operate + :param str gridMode: The grid mode to use in 'both', 'major'. + See :meth:`.PlotWidget.setGraphGrid` + :param parent: See :class:`QAction` + """ + + def __init__(self, plot, gridMode='both', parent=None): + assert gridMode in ('both', 'major') + self._gridMode = gridMode + + super(GridAction, self).__init__( + plot, icon='plot-grid', text='Grid', + tooltip='Toggle grid (on/off)', + triggered=self._actionTriggered, + checkable=True, parent=parent) + self.setChecked(plot.getGraphGrid() is not None) + plot.sigSetGraphGrid.connect(self._gridChanged) + + def _gridChanged(self, which): + """Slot listening for PlotWidget grid mode change.""" + self.setChecked(which != 'None') + + def _actionTriggered(self, checked=False): + self.plot.setGraphGrid(self._gridMode if checked else None) + + +class CurveStyleAction(PlotAction): + """QAction controlling curve style on a :class:`.PlotWidget`. + + It changes the default line and markers style which updates all + curves on the plot. + + :param plot: :class:`.PlotWidget` instance on which to operate + :param parent: See :class:`QAction` + """ + + def __init__(self, plot, parent=None): + super(CurveStyleAction, self).__init__( + plot, icon='plot-toggle-points', text='Curve style', + tooltip='Change curve line and markers style', + triggered=self._actionTriggered, + checkable=False, parent=parent) + + def _actionTriggered(self, checked=False): + currentState = (self.plot.isDefaultPlotLines(), + self.plot.isDefaultPlotPoints()) + + if currentState == (False, False): + newState = True, False + else: + # line only, line and symbol, symbol only + states = (True, False), (True, True), (False, True) + newState = states[(states.index(currentState) + 1) % 3] + + self.plot.setDefaultPlotLines(newState[0]) + self.plot.setDefaultPlotPoints(newState[1]) + + +class ColormapAction(PlotAction): + """QAction opening a ColormapDialog to update the colormap. + + Both the active image colormap and the default colormap are updated. + + :param plot: :class:`.PlotWidget` instance on which to operate + :param parent: See :class:`QAction` + """ + + def __init__(self, plot, parent=None): + self._dialog = None # To store an instance of ColormapDialog + super(ColormapAction, self).__init__( + plot, icon='colormap', text='Colormap', + tooltip="Change colormap", + triggered=self._actionTriggered, + checkable=True, parent=parent) + self.plot.sigActiveImageChanged.connect(self._updateColormap) + self.plot.sigActiveScatterChanged.connect(self._updateColormap) + + def setColorDialog(self, colorDialog): + """Set a specific color dialog instead of using the default dialog.""" + assert(colorDialog is not None) + assert(self._dialog is None) + self._dialog = colorDialog + self._dialog.visibleChanged.connect(self._dialogVisibleChanged) + self.setChecked(self._dialog.isVisible()) + + @staticmethod + def _createDialog(parent): + """Create the dialog if not already existing + + :parent QWidget parent: Parent of the new colormap + :rtype: ColormapDialog + """ + from silx.gui.dialog.ColormapDialog import ColormapDialog + dialog = ColormapDialog(parent=parent) + dialog.setModal(False) + return dialog + + def _actionTriggered(self, checked=False): + """Create a cmap dialog and update active image and default cmap.""" + if self._dialog is None: + self._dialog = self._createDialog(self.plot) + self._dialog.visibleChanged.connect(self._dialogVisibleChanged) + + # Run the dialog listening to colormap change + if checked is True: + self._updateColormap() + self._dialog.show() + else: + self._dialog.hide() + + def _dialogVisibleChanged(self, isVisible): + self.setChecked(isVisible) + + def _updateColormap(self): + if self._dialog is None: + return + image = self.plot.getActiveImage() + + if isinstance(image, items.ColormapMixIn): + # Set dialog from active image + colormap = image.getColormap() + # Set histogram and range if any + self._dialog.setItem(image) + + else: + # No active image or active image is RGBA, + # Check for active scatter plot + scatter = self.plot._getActiveItem(kind='scatter') + if scatter is not None: + colormap = scatter.getColormap() + self._dialog.setItem(scatter) + + else: + # No active data image nor scatter, + # set dialog from default info + colormap = self.plot.getDefaultColormap() + # Reset histogram and range if any + self._dialog.setData(None) + + self._dialog.setColormap(colormap) + + +class ColorBarAction(PlotAction): + """QAction opening the ColorBarWidget of the specified plot. + + :param plot: :class:`.PlotWidget` instance on which to operate + :param parent: See :class:`QAction` + """ + + def __init__(self, plot, parent=None): + self._dialog = None # To store an instance of ColorBar + super(ColorBarAction, self).__init__( + plot, icon='colorbar', text='Colorbar', + tooltip="Show/Hide the colorbar", + triggered=self._actionTriggered, + checkable=True, parent=parent) + colorBarWidget = self.plot.getColorBarWidget() + old = self.blockSignals(True) + self.setChecked(colorBarWidget.isVisibleTo(self.plot)) + self.blockSignals(old) + colorBarWidget.sigVisibleChanged.connect(self._widgetVisibleChanged) + + def _widgetVisibleChanged(self, isVisible): + """Callback when the colorbar `visible` property change.""" + if self.isChecked() == isVisible: + return + self.setChecked(isVisible) + + def _actionTriggered(self, checked=False): + """Create a cmap dialog and update active image and default cmap.""" + colorBarWidget = self.plot.getColorBarWidget() + if not colorBarWidget.isHidden() == checked: + return + self.plot.getColorBarWidget().setVisible(checked) + + +class KeepAspectRatioAction(PlotAction): + """QAction controlling aspect ratio on a :class:`.PlotWidget`. + + :param plot: :class:`.PlotWidget` instance on which to operate + :param parent: See :class:`QAction` + """ + + def __init__(self, plot, parent=None): + # Uses two images for checked/unchecked states + self._states = { + False: (icons.getQIcon('shape-circle-solid'), + "Keep data aspect ratio"), + True: (icons.getQIcon('shape-ellipse-solid'), + "Do no keep data aspect ratio") + } + + icon, tooltip = self._states[plot.isKeepDataAspectRatio()] + super(KeepAspectRatioAction, self).__init__( + plot, + icon=icon, + text='Toggle keep aspect ratio', + tooltip=tooltip, + triggered=self._actionTriggered, + checkable=False, + parent=parent) + plot.sigSetKeepDataAspectRatio.connect( + self._keepDataAspectRatioChanged) + + def _keepDataAspectRatioChanged(self, aspectRatio): + """Handle Plot set keep aspect ratio signal""" + icon, tooltip = self._states[aspectRatio] + self.setIcon(icon) + self.setToolTip(tooltip) + + def _actionTriggered(self, checked=False): + # This will trigger _keepDataAspectRatioChanged + self.plot.setKeepDataAspectRatio(not self.plot.isKeepDataAspectRatio()) + + +class YAxisInvertedAction(PlotAction): + """QAction controlling Y orientation on a :class:`.PlotWidget`. + + :param plot: :class:`.PlotWidget` instance on which to operate + :param parent: See :class:`QAction` + """ + + def __init__(self, plot, parent=None): + # Uses two images for checked/unchecked states + self._states = { + False: (icons.getQIcon('plot-ydown'), + "Orient Y axis downward"), + True: (icons.getQIcon('plot-yup'), + "Orient Y axis upward"), + } + + icon, tooltip = self._states[plot.getYAxis().isInverted()] + super(YAxisInvertedAction, self).__init__( + plot, + icon=icon, + text='Invert Y Axis', + tooltip=tooltip, + triggered=self._actionTriggered, + checkable=False, + parent=parent) + plot.getYAxis().sigInvertedChanged.connect(self._yAxisInvertedChanged) + + def _yAxisInvertedChanged(self, inverted): + """Handle Plot set y axis inverted signal""" + icon, tooltip = self._states[inverted] + self.setIcon(icon) + self.setToolTip(tooltip) + + def _actionTriggered(self, checked=False): + # This will trigger _yAxisInvertedChanged + yAxis = self.plot.getYAxis() + yAxis.setInverted(not yAxis.isInverted()) + + +class CrosshairAction(PlotAction): + """QAction toggling crosshair cursor on a :class:`.PlotWidget`. + + :param plot: :class:`.PlotWidget` instance on which to operate + :param str color: Color to use to draw the crosshair + :param int linewidth: Width of the crosshair cursor + :param str linestyle: Style of line. See :meth:`.Plot.setGraphCursor` + :param parent: See :class:`QAction` + """ + + def __init__(self, plot, color='black', linewidth=1, linestyle='-', + parent=None): + self.color = color + """Color used to draw the crosshair (str).""" + + self.linewidth = linewidth + """Width of the crosshair cursor (int).""" + + self.linestyle = linestyle + """Style of line of the cursor (str).""" + + super(CrosshairAction, self).__init__( + plot, icon='crosshair', text='Crosshair Cursor', + tooltip='Enable crosshair cursor when checked', + triggered=self._actionTriggered, + checkable=True, parent=parent) + self.setChecked(plot.getGraphCursor() is not None) + plot.sigSetGraphCursor.connect(self.setChecked) + + def _actionTriggered(self, checked=False): + self.plot.setGraphCursor(checked, + color=self.color, + linestyle=self.linestyle, + linewidth=self.linewidth) + + +class PanWithArrowKeysAction(PlotAction): + """QAction toggling pan with arrow keys on a :class:`.PlotWidget`. + + :param plot: :class:`.PlotWidget` instance on which to operate + :param parent: See :class:`QAction` + """ + + def __init__(self, plot, parent=None): + + super(PanWithArrowKeysAction, self).__init__( + plot, icon='arrow-keys', text='Pan with arrow keys', + tooltip='Enable pan with arrow keys when checked', + triggered=self._actionTriggered, + checkable=True, parent=parent) + self.setChecked(plot.isPanWithArrowKeys()) + plot.sigSetPanWithArrowKeys.connect(self.setChecked) + + def _actionTriggered(self, checked=False): + self.plot.setPanWithArrowKeys(checked) + + +class ShowAxisAction(PlotAction): + """QAction controlling axis visibility on a :class:`.PlotWidget`. + + :param plot: :class:`.PlotWidget` instance on which to operate + :param parent: See :class:`QAction` + """ + + def __init__(self, plot, parent=None): + tooltip = 'Show plot axis when checked, otherwise hide them' + PlotAction.__init__(self, + plot, + icon='axis', + text='show axis', + tooltip=tooltip, + triggered=self._actionTriggered, + checkable=True, + parent=parent) + self.setChecked(self.plot.isAxesDisplayed()) + plot._sigAxesVisibilityChanged.connect(self.setChecked) + + def _actionTriggered(self, checked=False): + self.plot.setAxesDisplayed(checked) + + +class ClosePolygonInteractionAction(PlotAction): + """QAction controlling closure of a polygon in draw interaction mode + if the :class:`.PlotWidget`. + + :param plot: :class:`.PlotWidget` instance on which to operate + :param parent: See :class:`QAction` + """ + + def __init__(self, plot, parent=None): + tooltip = 'Close the current polygon drawn' + PlotAction.__init__(self, + plot, + icon='add-shape-polygon', + text='Close the polygon', + tooltip=tooltip, + triggered=self._actionTriggered, + checkable=True, + parent=parent) + self.plot.sigInteractiveModeChanged.connect(self._modeChanged) + self._modeChanged(None) + + def _modeChanged(self, source): + mode = self.plot.getInteractiveMode() + enabled = "shape" in mode and mode["shape"] == "polygon" + self.setEnabled(enabled) + + def _actionTriggered(self, checked=False): + self.plot._eventHandler.validate() + + +class OpenGLAction(PlotAction): + """QAction controlling rendering of a :class:`.PlotWidget`. + + For now it can enable or not the OpenGL backend. + + :param plot: :class:`.PlotWidget` instance on which to operate + :param parent: See :class:`QAction` + """ + + def __init__(self, plot, parent=None): + # Uses two images for checked/unchecked states + self._states = { + "opengl": (icons.getQIcon('backend-opengl'), + "OpenGL rendering (fast)\nClick to disable OpenGL"), + "matplotlib": (icons.getQIcon('backend-opengl'), + "Matplotlib rendering (safe)\nClick to enable OpenGL"), + "unknown": (icons.getQIcon('backend-opengl'), + "Custom rendering") + } + + name = self._getBackendName(plot) + self.__state = name + icon, tooltip = self._states[name] + super(OpenGLAction, self).__init__( + plot, + icon=icon, + text='Enable/disable OpenGL rendering', + tooltip=tooltip, + triggered=self._actionTriggered, + checkable=True, + parent=parent) + + def _backendUpdated(self): + name = self._getBackendName(self.plot) + self.__state = name + icon, tooltip = self._states[name] + self.setIcon(icon) + self.setToolTip(tooltip) + self.setChecked(name == "opengl") + + def _getBackendName(self, plot): + backend = plot.getBackend() + name = type(backend).__name__.lower() + if "opengl" in name: + return "opengl" + elif "matplotlib" in name: + return "matplotlib" + else: + return "unknown" + + def _actionTriggered(self, checked=False): + plot = self.plot + name = self._getBackendName(self.plot) + if self.__state != name: + # THere is no event to know the backend was updated + # So here we check if there is a mismatch between the displayed state + # and the real state of the widget + self._backendUpdated() + return + if name != "opengl": + from silx.gui.utils import glutils + result = glutils.isOpenGLAvailable() + if not result: + qt.QMessageBox.critical(plot, "OpenGL rendering not available", result.error) + # Uncheck if needed + self._backendUpdated() + return + plot.setBackend("opengl") + else: + plot.setBackend("matplotlib") + self._backendUpdated() diff --git a/src/silx/gui/plot/actions/fit.py b/src/silx/gui/plot/actions/fit.py new file mode 100644 index 0000000..e130b24 --- /dev/null +++ b/src/silx/gui/plot/actions/fit.py @@ -0,0 +1,485 @@ +# coding: utf-8 +# /*########################################################################## +# +# Copyright (c) 2004-2021 European Synchrotron Radiation Facility +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. +# +# ###########################################################################*/ +""" +:mod:`silx.gui.plot.actions.fit` module provides actions relative to fit. + +The following QAction are available: + +- :class:`.FitAction` + +.. autoclass:`.FitAction` +""" + +from __future__ import division + +__authors__ = ["V.A. Sole", "T. Vincent", "P. Knobel"] +__license__ = "MIT" +__date__ = "10/10/2018" + +import logging +import sys +import weakref +import numpy + +from .PlotToolAction import PlotToolAction +from .. import items +from ....utils.deprecation import deprecated +from silx.gui import qt +from silx.gui.plot.ItemsSelectionDialog import ItemsSelectionDialog + +_logger = logging.getLogger(__name__) + + +def _getUniqueCurveOrHistogram(plot): + """Returns unique :class:`Curve` or :class:`Histogram` in a `PlotWidget`. + + If there is an active curve, returns it, else return curve or histogram + only if alone in the plot. + + :param PlotWidget plot: + :rtype: Union[None,~silx.gui.plot.items.Curve,~silx.gui.plot.items.Histogram] + """ + curve = plot.getActiveCurve() + if curve is not None: + return curve + + visibleItems = [item for item in plot.getItems() if item.isVisible()] + histograms = [item for item in visibleItems + if isinstance(item, items.Histogram)] + curves = [item for item in visibleItems + if isinstance(item, items.Curve)] + + if len(histograms) == 1 and len(curves) == 0: + return histograms[0] + elif len(curves) == 1 and len(histograms) == 0: + return curves[0] + else: + return None + + +class _FitItemSelector(qt.QObject): + """ + :class:`PlotWidget` observer that emits signal when fit selection changes. + + Track active curve or unique curve or histogram. + """ + + sigCurrentItemChanged = qt.Signal(object) + """Signal emitted when the item to fit has changed""" + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.__plotWidgetRef = None + self.__currentItem = None + + def getCurrentItem(self): + """Return currently selected item + + :rtype: Union[Item,None] + """ + return self.__currentItem + + def getPlotWidget(self): + """Return currently attached :class:`PlotWidget` + + :rtype: Union[PlotWidget,None] + """ + return None if self.__plotWidgetRef is None else self.__plotWidgetRef() + + def setPlotWidget(self, plotWidget): + """Set the :class:`PlotWidget` for which to track changes + + :param Union[PlotWidget,None] plotWidget: + The :class:`PlotWidget` to observe + """ + # disconnect from previous plot + previousPlotWidget = self.getPlotWidget() + if previousPlotWidget is not None: + previousPlotWidget.sigItemAdded.disconnect( + self.__plotWidgetUpdated) + previousPlotWidget.sigItemRemoved.disconnect( + self.__plotWidgetUpdated) + previousPlotWidget.sigActiveCurveChanged.disconnect( + self.__plotWidgetUpdated) + + if plotWidget is None: + self.__plotWidgetRef = None + self.__setCurrentItem(None) + return + self.__plotWidgetRef = weakref.ref(plotWidget, self.__plotDeleted) + + # connect to new plot + plotWidget.sigItemAdded.connect(self.__plotWidgetUpdated) + plotWidget.sigItemRemoved.connect(self.__plotWidgetUpdated) + plotWidget.sigActiveCurveChanged.connect(self.__plotWidgetUpdated) + self.__plotWidgetUpdated() + + def __plotDeleted(self): + """Handle deletion of PlotWidget""" + self.__setCurrentItem(None) + + def __plotWidgetUpdated(self, *args, **kwargs): + """Handle updates of PlotWidget content""" + plotWidget = self.getPlotWidget() + if plotWidget is None: + return + self.__setCurrentItem(_getUniqueCurveOrHistogram(plotWidget)) + + def __setCurrentItem(self, item): + """Handle change of current item""" + if sys.is_finalizing(): + return + + previousItem = self.getCurrentItem() + if item != previousItem: + if previousItem is not None: + previousItem.sigItemChanged.disconnect(self.__itemUpdated) + + self.__currentItem = item + + if self.__currentItem is not None: + self.__currentItem.sigItemChanged.connect(self.__itemUpdated) + self.sigCurrentItemChanged.emit(self.__currentItem) + + def __itemUpdated(self, event): + """Handle change on current item""" + if event == items.ItemChangedType.DATA: + self.sigCurrentItemChanged.emit(self.__currentItem) + + +class FitAction(PlotToolAction): + """QAction to open a :class:`FitWidget` and set its data to the + active curve if any, or to the first curve. + + :param plot: :class:`.PlotWidget` instance on which to operate + :param parent: See :class:`QAction` + """ + + def __init__(self, plot, parent=None): + self.__item = None + self.__activeCurveSynchroEnabled = False + self.__range = 0, 1 + self.__rangeAutoUpdate = False + self.__x, self.__y = None, None # Data to fit + self.__curveParams = {} # Store curve parameters to use for fit result + self.__legend = None + + super(FitAction, self).__init__( + plot, icon='math-fit', text='Fit curve', + tooltip='Open a fit dialog', + parent=parent) + + self.__fitItemSelector = _FitItemSelector() + self.__fitItemSelector.sigCurrentItemChanged.connect( + self._setFittedItem) + + + @property + @deprecated(replacement='getXRange()[0]', since_version='0.13.0') + def xmin(self): + return self.getXRange()[0] + + @property + @deprecated(replacement='getXRange()[1]', since_version='0.13.0') + def xmax(self): + return self.getXRange()[1] + + @property + @deprecated(replacement='getXData()', since_version='0.13.0') + def x(self): + return self.getXData() + + @property + @deprecated(replacement='getYData()', since_version='0.13.0') + def y(self): + return self.getYData() + + @property + @deprecated(since_version='0.13.0') + def xlabel(self): + return self.__curveParams.get('xlabel', None) + + @property + @deprecated(since_version='0.13.0') + def ylabel(self): + return self.__curveParams.get('ylabel', None) + + @property + @deprecated(since_version='0.13.0') + def legend(self): + return self.__legend + + def _createToolWindow(self): + # import done here rather than at module level to avoid circular import + # FitWidget -> BackgroundWidget -> PlotWindow -> actions -> fit -> FitWidget + from ...fit.FitWidget import FitWidget + + window = FitWidget(parent=self.plot) + window.setWindowFlags(qt.Qt.Dialog) + window.sigFitWidgetSignal.connect(self.handle_signal) + return window + + def _connectPlot(self, window): + if self.isXRangeUpdatedOnZoom(): + self.__setAutoXRangeEnabled(True) + else: + plot = self.plot + if plot is None: + _logger.error("No associated PlotWidget") + return + self._setXRange(*plot.getXAxis().getLimits()) + + if self.isFittedItemUpdatedFromActiveCurve(): + self.__setFittedItemAutoUpdateEnabled(True) + else: + # Wait for the next iteration, else the plot is not yet initialized + # No curve available + qt.QTimer.singleShot(10, self._initFit) + + def _disconnectPlot(self, window): + if self.isXRangeUpdatedOnZoom(): + self.__setAutoXRangeEnabled(False) + + if self.isFittedItemUpdatedFromActiveCurve(): + self.__setFittedItemAutoUpdateEnabled(False) + + def _initFit(self): + plot = self.plot + if plot is None: + _logger.error("No associated PlotWidget") + return + + item = _getUniqueCurveOrHistogram(plot) + if item is None: + # ambiguous case, we need to ask which plot item to fit + isd = ItemsSelectionDialog(parent=plot, plot=plot) + isd.setWindowTitle("Select item to be fitted") + isd.setItemsSelectionMode(qt.QTableWidget.SingleSelection) + isd.setAvailableKinds(["curve", "histogram"]) + isd.selectAllKinds() + + if not isd.exec(): # Cancel + self._getToolWindow().setVisible(False) + else: + selectedItems = isd.getSelectedItems() + item = selectedItems[0] if len(selectedItems) == 1 else None + + self._setXRange(*plot.getXAxis().getLimits()) + self._setFittedItem(item) + + def __updateFitWidget(self): + """Update the data/range used by the FitWidget""" + fitWidget = self._getToolWindow() + + item = self._getFittedItem() + xdata = self.getXData(copy=False) + ydata = self.getYData(copy=False) + if item is None or xdata is None or ydata is None: + fitWidget.setData(y=None) + fitWidget.setWindowTitle("No curve selected") + + else: + xmin, xmax = self.getXRange() + fitWidget.setData( + xdata, ydata, xmin=xmin, xmax=xmax) + fitWidget.setWindowTitle( + "Fitting " + item.getName() + + " on x range %f-%f" % (xmin, xmax)) + + # X Range management + + def getXRange(self): + """Returns the range on the X axis on which to perform the fit.""" + return self.__range + + def _setXRange(self, xmin, xmax): + """Set the range on which the fit is done. + + :param float xmin: + :param float xmax: + """ + range_ = float(xmin), float(xmax) + if self.__range != range_: + self.__range = range_ + self.__updateFitWidget() + + def __setAutoXRangeEnabled(self, enabled): + """Implement the change of update mode of the X range. + + :param bool enabled: + """ + plot = self.plot + if plot is None: + _logger.error("No associated PlotWidget") + return + + if enabled: + self._setXRange(*plot.getXAxis().getLimits()) + plot.getXAxis().sigLimitsChanged.connect(self._setXRange) + else: + plot.getXAxis().sigLimitsChanged.disconnect(self._setXRange) + + def setXRangeUpdatedOnZoom(self, enabled): + """Set whether or not to update the X range on zoom change. + + :param bool enabled: + """ + if enabled != self.__rangeAutoUpdate: + self.__rangeAutoUpdate = enabled + if self._getToolWindow().isVisible(): + self.__setAutoXRangeEnabled(enabled) + + def isXRangeUpdatedOnZoom(self): + """Returns the current mode of fitted data X range update. + + :rtype: bool + """ + return self.__rangeAutoUpdate + + # Fitted item update + + def getXData(self, copy=True): + """Returns the X data used for the fit or None if undefined. + + :param bool copy: + True to get a copy of the data, False to get the internal data. + :rtype: Union[numpy.ndarray,None] + """ + return None if self.__x is None else numpy.array(self.__x, copy=copy) + + def getYData(self, copy=True): + """Returns the Y data used for the fit or None if undefined. + + :param bool copy: + True to get a copy of the data, False to get the internal data. + :rtype: Union[numpy.ndarray,None] + """ + return None if self.__y is None else numpy.array(self.__y, copy=copy) + + def _getFittedItem(self): + """Returns the current item used for the fit + + :rtype: Union[~silx.gui.plot.items.Curve,~silx.gui.plot.items.Histogram,None] + """ + return self.__item + + def _setFittedItem(self, item): + """Set the curve to use for fitting. + + :param Union[~silx.gui.plot.items.Curve,~silx.gui.plot.items.Histogram,None] item: + """ + plot = self.plot + if plot is None: + _logger.error("No associated PlotWidget") + + if plot is None or item is None: + self.__item = None + self.__curveParams = {} + self.__updateFitWidget() + return + + axis = item.getYAxis() if isinstance(item, items.YAxisMixIn) else 'left' + self.__curveParams = { + 'yaxis': axis, + 'xlabel': plot.getXAxis().getLabel(), + 'ylabel': plot.getYAxis(axis).getLabel(), + } + self.__legend = item.getName() + + if isinstance(item, items.Histogram): + bin_edges = item.getBinEdgesData(copy=False) + # take the middle coordinate between adjacent bin edges + self.__x = (bin_edges[1:] + bin_edges[:-1]) / 2 + self.__y = item.getValueData(copy=False) + # else take the active curve, or else the unique curve + elif isinstance(item, items.Curve): + self.__x = item.getXData(copy=False) + self.__y = item.getYData(copy=False) + + self.__item = item + self.__updateFitWidget() + + def __setFittedItemAutoUpdateEnabled(self, enabled): + """Implement the change of fitted item update mode + + :param bool enabled: + """ + plot = self.plot + if plot is None: + _logger.error("No associated PlotWidget") + return + + self.__fitItemSelector.setPlotWidget(self.plot if enabled else None) + + def setFittedItemUpdatedFromActiveCurve(self, enabled): + """Toggle fitted data synchronization with plot active curve. + + :param bool enabled: + """ + enabled = bool(enabled) + if enabled != self.__activeCurveSynchroEnabled: + self.__activeCurveSynchroEnabled = enabled + if self._getToolWindow().isVisible(): + self.__setFittedItemAutoUpdateEnabled(enabled) + + def isFittedItemUpdatedFromActiveCurve(self): + """Returns True if fitted data is synchronized with plot. + + :rtype: bool + """ + return self.__activeCurveSynchroEnabled + + # Handle fit completed + + def handle_signal(self, ddict): + xdata = self.getXData(copy=False) + if xdata is None: + _logger.error("No reference data to display fit result for") + return + + xmin, xmax = self.getXRange() + x_fit = xdata[xmin <= xdata] + x_fit = x_fit[x_fit <= xmax] + fit_legend = "Fit <%s>" % self.__legend + fit_curve = self.plot.getCurve(fit_legend) + + if ddict["event"] == "FitFinished": + fit_widget = self._getToolWindow() + if fit_widget is None: + return + y_fit = fit_widget.fitmanager.gendata() + if fit_curve is None: + self.plot.addCurve(x_fit, y_fit, + fit_legend, + resetzoom=False, + **self.__curveParams) + else: + fit_curve.setData(x_fit, y_fit) + fit_curve.setVisible(True) + fit_curve.setYAxis(self.__curveParams.get('yaxis', 'left')) + + if ddict["event"] in ["FitStarted", "FitFailed"]: + if fit_curve is not None: + fit_curve.setVisible(False) diff --git a/src/silx/gui/plot/actions/histogram.py b/src/silx/gui/plot/actions/histogram.py new file mode 100644 index 0000000..be9f5a7 --- /dev/null +++ b/src/silx/gui/plot/actions/histogram.py @@ -0,0 +1,542 @@ +# coding: utf-8 +# /*########################################################################## +# +# Copyright (c) 2004-2021 European Synchrotron Radiation Facility +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. +# +# ###########################################################################*/ +""" +:mod:`silx.gui.plot.actions.histogram` provides actions relative to histograms +for :class:`.PlotWidget`. + +The following QAction are available: + +- :class:`PixelIntensitiesHistoAction` +""" + +from __future__ import division + +__authors__ = ["V.A. Sole", "T. Vincent", "P. Knobel"] +__date__ = "01/12/2020" +__license__ = "MIT" + +from typing import Optional, Tuple +import numpy +import logging +import weakref + +from .PlotToolAction import PlotToolAction + +from silx.math.histogram import Histogramnd +from silx.math.combo import min_max +from silx.gui import qt +from silx.gui.plot import items +from silx.gui.widgets.ElidedLabel import ElidedLabel +from silx.gui.widgets.RangeSlider import RangeSlider +from silx.utils.deprecation import deprecated + +_logger = logging.getLogger(__name__) + + +class _ElidedLabel(ElidedLabel): + """QLabel with a default size larger than what is displayed.""" + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.setTextInteractionFlags(qt.Qt.TextSelectableByMouse) + + def sizeHint(self): + hint = super().sizeHint() + nbchar = max(len(self.getText()), 12) + width = self.fontMetrics().boundingRect('#' * nbchar).width() + return qt.QSize(max(hint.width(), width), hint.height()) + + +class _StatWidget(qt.QWidget): + """Widget displaying a name and a value + + :param parent: + :param name: + """ + + def __init__(self, parent=None, name: str=''): + super().__init__(parent) + layout = qt.QHBoxLayout(self) + layout.setContentsMargins(0, 0, 0, 0) + + keyWidget = qt.QLabel(parent=self) + keyWidget.setText("<b>" + name.capitalize() + ":<b>") + layout.addWidget(keyWidget) + self.__valueWidget = _ElidedLabel(parent=self) + self.__valueWidget.setText("-") + self.__valueWidget.setTextInteractionFlags( + qt.Qt.TextSelectableByMouse | qt.Qt.TextSelectableByKeyboard) + layout.addWidget(self.__valueWidget) + + def setValue(self, value: Optional[float]): + """Set the displayed value + + :param value: + """ + self.__valueWidget.setText( + "-" if value is None else "{:.5g}".format(value)) + + +class _IntEdit(qt.QLineEdit): + """QLineEdit for integers with a default value and update on validation. + + :param QWidget parent: + """ + + sigValueChanged = qt.Signal(int) + """Signal emitted when the value has changed (on editing finished)""" + + def __init__(self, parent=None): + super().__init__(parent) + self.__value = None + self.setAlignment(qt.Qt.AlignRight) + validator = qt.QIntValidator() + self.setValidator(validator) + validator.bottomChanged.connect(self.__updateSize) + validator.topChanged.connect(self.__updateSize) + self.__updateSize() + + self.textEdited.connect(self.__textEdited) + + def __updateSize(self, *args): + """Update widget's maximum size according to bounds""" + bottom, top = self.getRange() + nbchar = max(len(str(bottom)), len(str(top))) + font = self.font() + font.setStyle(qt.QFont.StyleItalic) + fontMetrics = qt.QFontMetrics(font) + self.setMaximumWidth( + fontMetrics.boundingRect('0' * (nbchar + 1)).width() + ) + self.setMaxLength(nbchar) + + def __textEdited(self, _): + if self.font().style() != qt.QFont.StyleItalic: + font = self.font() + font.setStyle(qt.QFont.StyleItalic) + self.setFont(font) + + # Use events rather than editingFinished to also trigger with empty text + + def focusOutEvent(self, event): + self.__commitValue() + return super().focusOutEvent(event) + + def keyPressEvent(self, event): + if event.key() in (qt.Qt.Key_Enter, qt.Qt.Key_Return): + self.__commitValue() + return super().keyPressEvent(event) + + def __commitValue(self): + """Update the value returned by :meth:`getValue`""" + value = self.getCurrentValue() + if value is None: + value = self.getDefaultValue() + if value is None: + return # No value, keep previous one + + if self.font().style() != qt.QFont.StyleNormal: + font = self.font() + font.setStyle(qt.QFont.StyleNormal) + self.setFont(font) + + if value != self.__value: + self.__value = value + self.sigValueChanged.emit(value) + + def getValue(self) -> Optional[int]: + """Return current value (None if never set).""" + return self.__value + + def setRange(self, bottom: int, top: int): + """Set the range of valid values""" + self.validator().setRange(bottom, top) + + def getRange(self) -> Tuple[int, int]: + """Returns the current range of valid values + + :returns: (bottom, top) + """ + return self.validator().bottom(), self.validator().top() + + def __validate(self, value: int, extend_range: bool): + """Ensure value is in range + + :param int value: + :param bool extend_range: + True to extend range if needed. + False to clip value if needed. + """ + if extend_range: + bottom, top = self.getRange() + self.setRange(min(value, bottom), max(value, top)) + return numpy.clip(value, *self.getRange()) + + def setDefaultValue(self, value: int, extend_range: bool=False): + """Set default value when QLineEdit is empty + + :param int value: + :param bool extend_range: + True to extend range if needed. + False to clip value if needed + """ + self.setPlaceholderText(str(self.__validate(value, extend_range))) + if self.getCurrentValue() is None: + self.__commitValue() + + def getDefaultValue(self) -> Optional[int]: + """Return the default value or the bottom one if not set""" + try: + return int(self.placeholderText()) + except ValueError: + return None + + def setCurrentValue(self, value: int, extend_range: bool=False): + """Set the currently displayed value + + :param int value: + :param bool extend_range: + True to extend range if needed. + False to clip value if needed + """ + self.setText(str(self.__validate(value, extend_range))) + self.__commitValue() + + def getCurrentValue(self) -> Optional[int]: + """Returns the displayed value or None if not correct""" + try: + return int(self.text()) + except ValueError: + return None + + +class HistogramWidget(qt.QWidget): + """Widget displaying a histogram and some statistic indicators""" + + _SUPPORTED_ITEM_CLASS = items.ImageBase, items.Scatter + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.setWindowTitle('Histogram') + + self.__itemRef = None # weakref on the item to track + + layout = qt.QVBoxLayout(self) + layout.setContentsMargins(0, 0, 0, 0) + layout.setSpacing(0) + + # Plot + # Lazy import to avoid circular dependencies + from silx.gui.plot.PlotWindow import Plot1D + self.__plot = Plot1D(self) + layout.addWidget(self.__plot) + + self.__plot.setDataMargins(0.1, 0.1, 0.1, 0.1) + self.__plot.getXAxis().setLabel("Value") + self.__plot.getYAxis().setLabel("Count") + posInfo = self.__plot.getPositionInfoWidget() + posInfo.setSnappingMode(posInfo.SNAPPING_CURVE) + + # Histogram controls + controlsWidget = qt.QWidget(self) + layout.addWidget(controlsWidget) + controlsLayout = qt.QHBoxLayout(controlsWidget) + controlsLayout.setContentsMargins(4, 4, 4, 4) + + controlsLayout.addWidget(qt.QLabel("<b>Histogram:<b>")) + controlsLayout.addWidget(qt.QLabel("N. bins:")) + self.__nbinsLineEdit = _IntEdit(self) + self.__nbinsLineEdit.setRange(2, 9999) + self.__nbinsLineEdit.sigValueChanged.connect( + self.__updateHistogramFromControls) + controlsLayout.addWidget(self.__nbinsLineEdit) + self.__rangeLabel = qt.QLabel("Range:") + controlsLayout.addWidget(self.__rangeLabel) + self.__rangeSlider = RangeSlider(parent=self) + self.__rangeSlider.sigValueChanged.connect( + self.__updateHistogramFromControls) + self.__rangeSlider.sigValueChanged.connect(self.__rangeChanged) + controlsLayout.addWidget(self.__rangeSlider) + controlsLayout.addStretch(1) + + # Stats display + statsWidget = qt.QWidget(self) + layout.addWidget(statsWidget) + statsLayout = qt.QHBoxLayout(statsWidget) + statsLayout.setContentsMargins(4, 4, 4, 4) + + self.__statsWidgets = dict( + (name, _StatWidget(parent=statsWidget, name=name)) + for name in ("min", "max", "mean", "std", "sum")) + + for widget in self.__statsWidgets.values(): + statsLayout.addWidget(widget) + statsLayout.addStretch(1) + + def getPlotWidget(self): + """Returns :class:`PlotWidget` use to display the histogram""" + return self.__plot + + def resetZoom(self): + """Reset PlotWidget zoom""" + self.getPlotWidget().resetZoom() + + def reset(self): + """Clear displayed information""" + self.getPlotWidget().clear() + self.setStatistics() + + def getItem(self) -> Optional[items.Item]: + """Returns item used to display histogram and statistics.""" + return None if self.__itemRef is None else self.__itemRef() + + def setItem(self, item: Optional[items.Item]): + """Set item from which to display histogram and statistics. + + :param item: + """ + previous = self.getItem() + if previous is not None: + previous.sigItemChanged.disconnect(self.__itemChanged) + + self.__itemRef = None if item is None else weakref.ref(item) + if item is not None: + if isinstance(item, self._SUPPORTED_ITEM_CLASS): + # Only listen signal for supported items + item.sigItemChanged.connect(self.__itemChanged) + self._updateFromItem() + + def __itemChanged(self, event): + """Handle update of the item""" + if event in (items.ItemChangedType.DATA, items.ItemChangedType.MASK): + self._updateFromItem() + + def __updateHistogramFromControls(self, *args): + """Handle udates coming from histogram control widgets""" + + hist = self.getHistogram(copy=False) + if hist is not None: + count, edges = hist + if (len(count) == self.__nbinsLineEdit.getValue() and + (edges[0], edges[-1]) == self.__rangeSlider.getValues()): + return # Nothing has changed + + self._updateFromItem() + + def __rangeChanged(self, first, second): + """Handle change of histogram range from the range slider""" + tooltip = "Histogram range:\n[%g, %g]" % (first, second) + self.__rangeSlider.setToolTip(tooltip) + self.__rangeLabel.setToolTip(tooltip) + + def _updateFromItem(self): + """Update histogram and stats from the item""" + item = self.getItem() + + if item is None: + self.reset() + return + + if not isinstance(item, self._SUPPORTED_ITEM_CLASS): + _logger.error("Unsupported item", item) + self.reset() + return + + # Compute histogram and stats + array = item.getValueData(copy=False) + + if array.size == 0: + self.reset() + return + + xmin, xmax = min_max(array, min_positive=False, finite=True) + if xmin is None or xmax is None: # All not finite data + self.reset() + return + guessed_nbins = min(1024, int(numpy.sqrt(array.size))) + + # bad hack: get 256 bins in the case we have a B&W + if numpy.issubdtype(array.dtype, numpy.integer): + if guessed_nbins > xmax - xmin: + guessed_nbins = xmax - xmin + guessed_nbins = max(2, guessed_nbins) + + # Set default nbins + self.__nbinsLineEdit.setDefaultValue(guessed_nbins, extend_range=True) + # Set slider range: do not keep the range value, but the relative pos. + previousPositions = self.__rangeSlider.getPositions() + if xmin == xmax: # Enlarge range is none + if xmin == 0: + range_ = -0.01, 0.01 + else: + range_ = sorted((xmin * .99, xmin * 1.01)) + else: + range_ = xmin, xmax + + self.__rangeSlider.setRange(*range_) + self.__rangeSlider.setPositions(*previousPositions) + + histogram = Histogramnd( + array.ravel().astype(numpy.float32), + n_bins=max(2, self.__nbinsLineEdit.getValue()), + histo_range=self.__rangeSlider.getValues(), + ) + if len(histogram.edges) != 1: + _logger.error("Error while computing the histogram") + self.reset() + return + + self.setHistogram(histogram.histo, histogram.edges[0]) + self.resetZoom() + self.setStatistics( + min_=xmin, + max_=xmax, + mean=numpy.nanmean(array), + std=numpy.nanstd(array), + sum_=numpy.nansum(array)) + + def setHistogram(self, histogram, edges): + """Set displayed histogram + + :param histogram: Bin values (N) + :param edges: Bin edges (N+1) + """ + # Only useful if setHistogram is called directly + # TODO + #nbins = len(histogram) + #if nbins != self.__nbinsLineEdit.getDefaultValue(): + # self.__nbinsLineEdit.setValue(nbins, extend_range=True) + #self.__rangeSlider.setValues(edges[0], edges[-1]) + + self.getPlotWidget().addHistogram( + histogram=histogram, + edges=edges, + legend='histogram', + fill=True, + color='#66aad7', + resetzoom=False) + + def getHistogram(self, copy: bool=True): + """Returns currently displayed histogram. + + :param copy: True to get a copy, + False to get internal representation (Do not modify!) + :return: (histogram, edges) or None + """ + for item in self.getPlotWidget().getItems(): + if item.getName() == 'histogram': + return (item.getValueData(copy=copy), + item.getBinEdgesData(copy=copy)) + else: + return None + + def setStatistics(self, + min_: Optional[float] = None, + max_: Optional[float] = None, + mean: Optional[float] = None, + std: Optional[float] = None, + sum_: Optional[float] = None): + """Set displayed statistic indicators.""" + self.__statsWidgets['min'].setValue(min_) + self.__statsWidgets['max'].setValue(max_) + self.__statsWidgets['mean'].setValue(mean) + self.__statsWidgets['std'].setValue(std) + self.__statsWidgets['sum'].setValue(sum_) + + +class PixelIntensitiesHistoAction(PlotToolAction): + """QAction to plot the pixels intensities diagram + + :param plot: :class:`.PlotWidget` instance on which to operate + :param parent: See :class:`QAction` + """ + + def __init__(self, plot, parent=None): + PlotToolAction.__init__(self, + plot, + icon='pixel-intensities', + text='pixels intensity', + tooltip='Compute image intensity distribution', + parent=parent) + + def _connectPlot(self, window): + plot = self.plot + if plot is not None: + selection = plot.selection() + selection.sigSelectedItemsChanged.connect(self._selectedItemsChanged) + self._updateSelectedItem() + + PlotToolAction._connectPlot(self, window) + + def _disconnectPlot(self, window): + plot = self.plot + if plot is not None: + selection = self.plot.selection() + selection.sigSelectedItemsChanged.disconnect(self._selectedItemsChanged) + + PlotToolAction._disconnectPlot(self, window) + self.getHistogramWidget().setItem(None) + + def _updateSelectedItem(self): + """Synchronises selected item with plot widget.""" + plot = self.plot + if plot is not None: + selected = plot.selection().getSelectedItems() + # Give priority to image over scatter + for klass in (items.ImageBase, items.Scatter): + for item in selected: + if isinstance(item, klass): + # Found a matching item, use it + self.getHistogramWidget().setItem(item) + return + self.getHistogramWidget().setItem(None) + + def _selectedItemsChanged(self): + if self._isWindowInUse(): + self._updateSelectedItem() + + @deprecated(since_version='0.15.0') + def computeIntensityDistribution(self): + self.getHistogramWidget()._updateFromItem() + + def getHistogramWidget(self): + """Returns the widget displaying the histogram""" + return self._getToolWindow() + + @deprecated(since_version='0.15.0', + replacement='getHistogramWidget().getPlotWidget()') + def getHistogramPlotWidget(self): + return self._getToolWindow().getPlotWidget() + + def _createToolWindow(self): + return HistogramWidget(self.plot, qt.Qt.Window) + + def getHistogram(self) -> Optional[numpy.ndarray]: + """Return the last computed histogram + + :return: the histogram displayed in the HistogramWidget + """ + histogram = self.getHistogramWidget().getHistogram() + return None if histogram is None else histogram[0] diff --git a/src/silx/gui/plot/actions/io.py b/src/silx/gui/plot/actions/io.py new file mode 100644 index 0000000..7f4edd3 --- /dev/null +++ b/src/silx/gui/plot/actions/io.py @@ -0,0 +1,819 @@ +# coding: utf-8 +# /*########################################################################## +# +# Copyright (c) 2004-2021 European Synchrotron Radiation Facility +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. +# +# ###########################################################################*/ +""" +:mod:`silx.gui.plot.actions.io` provides a set of QAction relative of inputs +and outputs for a :class:`.PlotWidget`. + +The following QAction are available: + +- :class:`CopyAction` +- :class:`PrintAction` +- :class:`SaveAction` +""" + +from __future__ import division + +__authors__ = ["V.A. Sole", "T. Vincent", "P. Knobel"] +__license__ = "MIT" +__date__ = "25/09/2020" + +from . import PlotAction +from silx.io.utils import save1D, savespec, NEXUS_HDF5_EXT +from silx.io.nxdata import save_NXdata +import logging +import sys +import os.path +from collections import OrderedDict +import traceback +import numpy +from silx.utils.deprecation import deprecated +from silx.gui import qt, printer +from silx.gui.dialog.GroupDialog import GroupDialog +from silx.third_party.EdfFile import EdfFile +from silx.third_party.TiffIO import TiffIO +from ...utils.image import convertArrayToQImage +if sys.version_info[0] == 3: + from io import BytesIO +else: + import cStringIO as _StringIO + BytesIO = _StringIO.StringIO + +_logger = logging.getLogger(__name__) + +_NEXUS_HDF5_EXT_STR = ' '.join(['*' + ext for ext in NEXUS_HDF5_EXT]) + + +def selectOutputGroup(h5filename): + """Open a dialog to prompt the user to select a group in + which to output data. + + :param str h5filename: name of an existing HDF5 file + :rtype: str + :return: Name of output group, or None if the dialog was cancelled + """ + dialog = GroupDialog() + dialog.addFile(h5filename) + dialog.setWindowTitle("Select an output group") + if not dialog.exec(): + return None + return dialog.getSelectedDataUrl().data_path() + + +class SaveAction(PlotAction): + """QAction for saving Plot content. + + It opens a Save as... dialog. + + :param plot: :class:`.PlotWidget` instance on which to operate. + :param parent: See :class:`QAction`. + """ + + SNAPSHOT_FILTER_SVG = 'Plot Snapshot as SVG (*.svg)' + SNAPSHOT_FILTER_PNG = 'Plot Snapshot as PNG (*.png)' + + DEFAULT_ALL_FILTERS = (SNAPSHOT_FILTER_PNG, SNAPSHOT_FILTER_SVG) + + # Dict of curve filters with CSV-like format + # Using ordered dict to guarantee filters order + # Note: '%.18e' is numpy.savetxt default format + CURVE_FILTERS_TXT = OrderedDict(( + ('Curve as Raw ASCII (*.txt)', + {'fmt': '%.18e', 'delimiter': ' ', 'header': False}), + ('Curve as ";"-separated CSV (*.csv)', + {'fmt': '%.18e', 'delimiter': ';', 'header': True}), + ('Curve as ","-separated CSV (*.csv)', + {'fmt': '%.18e', 'delimiter': ',', 'header': True}), + ('Curve as tab-separated CSV (*.csv)', + {'fmt': '%.18e', 'delimiter': '\t', 'header': True}), + ('Curve as OMNIC CSV (*.csv)', + {'fmt': '%.7E', 'delimiter': ',', 'header': False}), + ('Curve as SpecFile (*.dat)', + {'fmt': '%.10g', 'delimiter': '', 'header': False}) + )) + + CURVE_FILTER_NPY = 'Curve as NumPy binary file (*.npy)' + + CURVE_FILTER_NXDATA = 'Curve as NXdata (%s)' % _NEXUS_HDF5_EXT_STR + + DEFAULT_CURVE_FILTERS = list(CURVE_FILTERS_TXT.keys()) + [ + CURVE_FILTER_NPY, CURVE_FILTER_NXDATA] + + DEFAULT_ALL_CURVES_FILTERS = ("All curves as SpecFile (*.dat)",) + + IMAGE_FILTER_EDF = 'Image data as EDF (*.edf)' + IMAGE_FILTER_TIFF = 'Image data as TIFF (*.tif)' + IMAGE_FILTER_NUMPY = 'Image data as NumPy binary file (*.npy)' + IMAGE_FILTER_ASCII = 'Image data as ASCII (*.dat)' + IMAGE_FILTER_CSV_COMMA = 'Image data as ,-separated CSV (*.csv)' + IMAGE_FILTER_CSV_SEMICOLON = 'Image data as ;-separated CSV (*.csv)' + IMAGE_FILTER_CSV_TAB = 'Image data as tab-separated CSV (*.csv)' + IMAGE_FILTER_RGB_PNG = 'Image as PNG (*.png)' + IMAGE_FILTER_NXDATA = 'Image as NXdata (%s)' % _NEXUS_HDF5_EXT_STR + + DEFAULT_IMAGE_FILTERS = (IMAGE_FILTER_EDF, + IMAGE_FILTER_TIFF, + IMAGE_FILTER_NUMPY, + IMAGE_FILTER_ASCII, + IMAGE_FILTER_CSV_COMMA, + IMAGE_FILTER_CSV_SEMICOLON, + IMAGE_FILTER_CSV_TAB, + IMAGE_FILTER_RGB_PNG, + IMAGE_FILTER_NXDATA) + + SCATTER_FILTER_NXDATA = 'Scatter as NXdata (%s)' % _NEXUS_HDF5_EXT_STR + DEFAULT_SCATTER_FILTERS = (SCATTER_FILTER_NXDATA,) + + # filters for which we don't want an "overwrite existing file" warning + DEFAULT_APPEND_FILTERS = (CURVE_FILTER_NXDATA, IMAGE_FILTER_NXDATA, + SCATTER_FILTER_NXDATA) + + def __init__(self, plot, parent=None): + self._filters = { + 'all': OrderedDict(), + 'curve': OrderedDict(), + 'curves': OrderedDict(), + 'image': OrderedDict(), + 'scatter': OrderedDict()} + + self._appendFilters = list(self.DEFAULT_APPEND_FILTERS) + + # Initialize filters + for nameFilter in self.DEFAULT_ALL_FILTERS: + self.setFileFilter( + dataKind='all', nameFilter=nameFilter, func=self._saveSnapshot) + + for nameFilter in self.DEFAULT_CURVE_FILTERS: + self.setFileFilter( + dataKind='curve', nameFilter=nameFilter, func=self._saveCurve) + + for nameFilter in self.DEFAULT_ALL_CURVES_FILTERS: + self.setFileFilter( + dataKind='curves', nameFilter=nameFilter, func=self._saveCurves) + + for nameFilter in self.DEFAULT_IMAGE_FILTERS: + self.setFileFilter( + dataKind='image', nameFilter=nameFilter, func=self._saveImage) + + for nameFilter in self.DEFAULT_SCATTER_FILTERS: + self.setFileFilter( + dataKind='scatter', nameFilter=nameFilter, func=self._saveScatter) + + super(SaveAction, self).__init__( + plot, icon='document-save', text='Save as...', + tooltip='Save curve/image/plot snapshot dialog', + triggered=self._actionTriggered, + checkable=False, parent=parent) + self.setShortcut(qt.QKeySequence.Save) + self.setShortcutContext(qt.Qt.WidgetShortcut) + + @staticmethod + def _errorMessage(informativeText='', parent=None): + """Display an error message.""" + # TODO issue with QMessageBox size fixed and too small + msg = qt.QMessageBox(parent) + msg.setIcon(qt.QMessageBox.Critical) + msg.setInformativeText(informativeText + ' ' + str(sys.exc_info()[1])) + msg.setDetailedText(traceback.format_exc()) + msg.exec() + + def _saveSnapshot(self, plot, filename, nameFilter): + """Save a snapshot of the :class:`PlotWindow` widget. + + :param str filename: The name of the file to write + :param str nameFilter: The selected name filter + :return: False if format is not supported or save failed, + True otherwise. + """ + if nameFilter == self.SNAPSHOT_FILTER_PNG: + fileFormat = 'png' + elif nameFilter == self.SNAPSHOT_FILTER_SVG: + fileFormat = 'svg' + else: # Format not supported + _logger.error( + 'Saving plot snapshot failed: format not supported') + return False + + plot.saveGraph(filename, fileFormat=fileFormat) + return True + + def _getAxesLabels(self, item): + # If curve has no associated label, get the default from the plot + xlabel = item.getXLabel() or self.plot.getXAxis().getLabel() + ylabel = item.getYLabel() or self.plot.getYAxis().getLabel() + return xlabel, ylabel + + def _get1dData(self, item): + "provide xdata, [ydata], xlabel, [ylabel] and manages error bars" + xlabel, ylabel = self._getAxesLabels(item) + x_data = item.getXData(copy=False) + y_data = item.getYData(copy=False) + x_err = item.getXErrorData(copy=False) + y_err = item.getYErrorData(copy=False) + labels = [ylabel] + data = [y_data] + + if x_err is not None: + if numpy.isscalar(x_err): + data.append(numpy.zeros_like(y_data) + x_err) + labels.append(xlabel + "_errors") + elif x_err.ndim == 1: + data.append(x_err) + labels.append(xlabel + "_errors") + elif x_err.ndim == 2: + data.append(x_err[0]) + labels.append(xlabel + "_errors_below") + data.append(x_err[1]) + labels.append(xlabel + "_errors_above") + + if y_err is not None: + if numpy.isscalar(y_err): + data.append(numpy.zeros_like(y_data) + y_err) + labels.append(ylabel + "_errors") + elif y_err.ndim == 1: + data.append(y_err) + labels.append(ylabel + "_errors") + elif y_err.ndim == 2: + data.append(y_err[0]) + labels.append(ylabel + "_errors_below") + data.append(y_err[1]) + labels.append(ylabel + "_errors_above") + return x_data, data, xlabel, labels + + @staticmethod + def _selectWriteableOutputGroup(filename, parent): + if os.path.exists(filename) and os.path.isfile(filename) \ + and os.access(filename, os.W_OK): + entryPath = selectOutputGroup(filename) + if entryPath is None: + _logger.info("Save operation cancelled") + return None + return entryPath + elif not os.path.exists(filename): + # create new entry in new file + return "/entry" + else: + SaveAction._errorMessage('Save failed (file access issue)\n', parent=parent) + return None + + def _saveCurveAsNXdata(self, curve, filename): + entryPath = self._selectWriteableOutputGroup(filename, parent=self.plot) + if entryPath is None: + return False + + xlabel, ylabel = self._getAxesLabels(curve) + + return save_NXdata( + filename, + nxentry_name=entryPath, + signal=curve.getYData(copy=False), + axes=[curve.getXData(copy=False)], + signal_name="y", + axes_names=["x"], + signal_long_name=ylabel, + axes_long_names=[xlabel], + signal_errors=curve.getYErrorData(copy=False), + axes_errors=[curve.getXErrorData(copy=True)], + title=self.plot.getGraphTitle()) + + def _saveCurve(self, plot, filename, nameFilter): + """Save a curve from the plot. + + :param str filename: The name of the file to write + :param str nameFilter: The selected name filter + :return: False if format is not supported or save failed, + True otherwise. + """ + if nameFilter not in self.DEFAULT_CURVE_FILTERS: + return False + + # Check if a curve is to be saved + curve = plot.getActiveCurve() + # before calling _saveCurve, if there is no selected curve, we + # make sure there is only one curve on the graph + if curve is None: + curves = plot.getAllCurves() + if not curves: + self._errorMessage("No curve to be saved", parent=self.plot) + return False + curve = curves[0] + + if nameFilter in self.CURVE_FILTERS_TXT: + filter_ = self.CURVE_FILTERS_TXT[nameFilter] + fmt = filter_['fmt'] + csvdelim = filter_['delimiter'] + autoheader = filter_['header'] + else: + # .npy or nxdata + fmt, csvdelim, autoheader = ("", "", False) + + if nameFilter == self.CURVE_FILTER_NXDATA: + return self._saveCurveAsNXdata(curve, filename) + + xdata, data, xlabel, labels = self._get1dData(curve) + + try: + save1D(filename, + xdata, data, + xlabel, labels, + fmt=fmt, csvdelim=csvdelim, + autoheader=autoheader) + except IOError: + self._errorMessage('Save failed\n', parent=self.plot) + return False + + return True + + def _saveCurves(self, plot, filename, nameFilter): + """Save all curves from the plot. + + :param str filename: The name of the file to write + :param str nameFilter: The selected name filter + :return: False if format is not supported or save failed, + True otherwise. + """ + if nameFilter not in self.DEFAULT_ALL_CURVES_FILTERS: + return False + + curves = plot.getAllCurves() + if not curves: + self._errorMessage("No curves to be saved", parent=self.plot) + return False + + curve = curves[0] + scanno = 1 + try: + xdata, data, xlabel, labels = self._get1dData(curve) + + specfile = savespec(filename, + xdata, data, + xlabel, labels, + fmt="%.7g", scan_number=1, mode="w", + write_file_header=True, + close_file=False) + except IOError: + self._errorMessage('Save failed\n', parent=self.plot) + return False + + for curve in curves[1:]: + try: + scanno += 1 + xdata, data, xlabel, labels = self._get1dData(curve) + specfile = savespec(specfile, + xdata, data, + xlabel, labels, + fmt="%.7g", scan_number=scanno, + write_file_header=False, + close_file=False) + except IOError: + self._errorMessage('Save failed\n', parent=self.plot) + return False + specfile.close() + + return True + + def _saveImage(self, plot, filename, nameFilter): + """Save an image from the plot. + + :param str filename: The name of the file to write + :param str nameFilter: The selected name filter + :return: False if format is not supported or save failed, + True otherwise. + """ + if nameFilter not in self.DEFAULT_IMAGE_FILTERS: + return False + + image = plot.getActiveImage() + if image is None: + qt.QMessageBox.warning( + plot, "No Data", "No image to be saved") + return False + + data = image.getData(copy=False) + + # TODO Use silx.io for writing files + if nameFilter == self.IMAGE_FILTER_EDF: + edfFile = EdfFile(filename, access="w+") + edfFile.WriteImage({}, data, Append=0) + return True + + elif nameFilter == self.IMAGE_FILTER_TIFF: + tiffFile = TiffIO(filename, mode='w') + tiffFile.writeImage(data, software='silx') + return True + + elif nameFilter == self.IMAGE_FILTER_NUMPY: + try: + numpy.save(filename, data) + except IOError: + self._errorMessage('Save failed\n', parent=self.plot) + return False + return True + + elif nameFilter == self.IMAGE_FILTER_NXDATA: + entryPath = self._selectWriteableOutputGroup(filename, parent=self.plot) + if entryPath is None: + return False + xorigin, yorigin = image.getOrigin() + xscale, yscale = image.getScale() + xaxis = xorigin + xscale * numpy.arange(data.shape[1]) + yaxis = yorigin + yscale * numpy.arange(data.shape[0]) + xlabel, ylabel = self._getAxesLabels(image) + interpretation = "image" if len(data.shape) == 2 else "rgba-image" + + return save_NXdata(filename, + nxentry_name=entryPath, + signal=data, + axes=[yaxis, xaxis], + signal_name="image", + axes_names=["y", "x"], + axes_long_names=[ylabel, xlabel], + title=plot.getGraphTitle(), + interpretation=interpretation) + + elif nameFilter in (self.IMAGE_FILTER_ASCII, + self.IMAGE_FILTER_CSV_COMMA, + self.IMAGE_FILTER_CSV_SEMICOLON, + self.IMAGE_FILTER_CSV_TAB): + csvdelim, filetype = { + self.IMAGE_FILTER_ASCII: (' ', 'txt'), + self.IMAGE_FILTER_CSV_COMMA: (',', 'csv'), + self.IMAGE_FILTER_CSV_SEMICOLON: (';', 'csv'), + self.IMAGE_FILTER_CSV_TAB: ('\t', 'csv'), + }[nameFilter] + + height, width = data.shape + rows, cols = numpy.mgrid[0:height, 0:width] + try: + save1D(filename, rows.ravel(), (cols.ravel(), data.ravel()), + filetype=filetype, + xlabel='row', + ylabels=['column', 'value'], + csvdelim=csvdelim, + autoheader=True) + + except IOError: + self._errorMessage('Save failed\n', parent=self.plot) + return False + return True + + elif nameFilter == self.IMAGE_FILTER_RGB_PNG: + # Get displayed image + rgbaImage = image.getRgbaImageData(copy=False) + # Convert RGB QImage + qimage = convertArrayToQImage(rgbaImage[:, :, :3]) + + if qimage.save(filename, 'PNG'): + return True + else: + _logger.error('Failed to save image as %s', filename) + qt.QMessageBox.critical( + self.parent(), + 'Save image as', + 'Failed to save image') + + return False + + def _saveScatter(self, plot, filename, nameFilter): + """Save an image from the plot. + + :param str filename: The name of the file to write + :param str nameFilter: The selected name filter + :return: False if format is not supported or save failed, + True otherwise. + """ + if nameFilter not in self.DEFAULT_SCATTER_FILTERS: + return False + + if nameFilter == self.SCATTER_FILTER_NXDATA: + entryPath = self._selectWriteableOutputGroup(filename, parent=self.plot) + if entryPath is None: + return False + scatter = plot.getScatter() + + x = scatter.getXData(copy=False) + y = scatter.getYData(copy=False) + z = scatter.getValueData(copy=False) + + xerror = scatter.getXErrorData(copy=False) + if isinstance(xerror, float): + xerror = xerror * numpy.ones(x.shape, dtype=numpy.float32) + + yerror = scatter.getYErrorData(copy=False) + if isinstance(yerror, float): + yerror = yerror * numpy.ones(x.shape, dtype=numpy.float32) + + xlabel = plot.getGraphXLabel() + ylabel = plot.getGraphYLabel() + + return save_NXdata( + filename, + nxentry_name=entryPath, + signal=z, + axes=[x, y], + signal_name="values", + axes_names=["x", "y"], + axes_long_names=[xlabel, ylabel], + axes_errors=[xerror, yerror], + title=plot.getGraphTitle()) + + def setFileFilter(self, dataKind, nameFilter, func, index=None, appendToFile=False): + """Set a name filter to add/replace a file format support + + :param str dataKind: + The kind of data for which the provided filter is valid. + One of: 'all', 'curve', 'curves', 'image', 'scatter' + :param str nameFilter: The name filter in the QFileDialog. + See :meth:`QFileDialog.setNameFilters`. + :param callable func: The function to call to perform saving. + Expected signature is: + bool func(PlotWidget plot, str filename, str nameFilter) + :param bool appendToFile: True to append the data into the selected + file. + :param integer index: Index of the filter in the final list (or None) + """ + assert dataKind in ('all', 'curve', 'curves', 'image', 'scatter') + + if appendToFile: + self._appendFilters.append(nameFilter) + + # first append or replace the new filter to prevent colissions + self._filters[dataKind][nameFilter] = func + if index is None: + # we are already done + return + + # get the current ordered list of keys + keyList = list(self._filters[dataKind].keys()) + + # deal with negative indices + if index < 0: + index = len(keyList) + index + if index < 0: + index = 0 + + if index >= len(keyList): + # nothing to be done, already at the end + txt = 'Requested index %d impossible, already at the end' % index + _logger.info(txt) + return + + # get the new ordered list + oldIndex = keyList.index(nameFilter) + del keyList[oldIndex] + keyList.insert(index, nameFilter) + + # build the new filters + newFilters = OrderedDict() + for key in keyList: + newFilters[key] = self._filters[dataKind][key] + + # and update the filters + self._filters[dataKind] = newFilters + return + + def getFileFilters(self, dataKind): + """Returns the nameFilter and associated function for a kind of data. + + :param str dataKind: + The kind of data for which the provided filter is valid. + On of: 'all', 'curve', 'curves', 'image', 'scatter' + :return: {nameFilter: function} associations. + :rtype: collections.OrderedDict + """ + assert dataKind in ('all', 'curve', 'curves', 'image', 'scatter') + + return self._filters[dataKind].copy() + + def _actionTriggered(self, checked=False): + """Handle save action.""" + # Set-up filters + filters = OrderedDict() + + # Add image filters if there is an active image + if self.plot.getActiveImage() is not None: + filters.update(self._filters['image'].items()) + + # Add curve filters if there is a curve to save + if (self.plot.getActiveCurve() is not None or + len(self.plot.getAllCurves()) == 1): + filters.update(self._filters['curve'].items()) + if len(self.plot.getAllCurves()) >= 1: + filters.update(self._filters['curves'].items()) + + # Add scatter filters if there is a scatter + # todo: CSV + if self.plot.getScatter() is not None: + filters.update(self._filters['scatter'].items()) + + filters.update(self._filters['all'].items()) + + # Create and run File dialog + dialog = qt.QFileDialog(self.plot) + dialog.setOption(dialog.DontUseNativeDialog) + dialog.setWindowTitle("Output File Selection") + dialog.setModal(1) + dialog.setNameFilters(list(filters.keys())) + + dialog.setFileMode(dialog.AnyFile) + dialog.setAcceptMode(dialog.AcceptSave) + + def onFilterSelection(filt_): + # disable overwrite confirmation for NXdata types, + # because we append the data to existing files + if filt_ in self._appendFilters: + dialog.setOption(dialog.DontConfirmOverwrite) + else: + dialog.setOption(dialog.DontConfirmOverwrite, False) + + dialog.filterSelected.connect(onFilterSelection) + + if not dialog.exec(): + return False + + nameFilter = dialog.selectedNameFilter() + filename = dialog.selectedFiles()[0] + dialog.close() + + if '(' in nameFilter and ')' == nameFilter.strip()[-1]: + # Check for correct file extension + # Extract file extensions as .something + extensions = [ext[ext.find('.'):] for ext in + nameFilter[nameFilter.find('(') + 1:-1].split()] + for ext in extensions: + if (len(filename) > len(ext) and + filename[-len(ext):].lower() == ext.lower()): + break + else: # filename has no extension supported in nameFilter, add one + if len(extensions) >= 1: + filename += extensions[0] + + # Handle save + func = filters.get(nameFilter, None) + if func is not None: + return func(self.plot, filename, nameFilter) + else: + _logger.error('Unsupported file filter: %s', nameFilter) + return False + + +def _plotAsPNG(plot): + """Save a :class:`Plot` as PNG and return the payload. + + :param plot: The :class:`Plot` to save + """ + pngFile = BytesIO() + plot.saveGraph(pngFile, fileFormat='png') + pngFile.flush() + pngFile.seek(0) + data = pngFile.read() + pngFile.close() + return data + + +class PrintAction(PlotAction): + """QAction for printing the plot. + + It opens a Print dialog. + + Current implementation print a bitmap of the plot area and not vector + graphics, so printing quality is not great. + + :param plot: :class:`.PlotWidget` instance on which to operate. + :param parent: See :class:`QAction`. + """ + + def __init__(self, plot, parent=None): + super(PrintAction, self).__init__( + plot, icon='document-print', text='Print...', + tooltip='Open print dialog', + triggered=self.printPlot, + checkable=False, parent=parent) + self.setShortcut(qt.QKeySequence.Print) + self.setShortcutContext(qt.Qt.WidgetShortcut) + + def getPrinter(self): + """The QPrinter instance used by the PrintAction. + + :rtype: QPrinter + """ + return printer.getDefaultPrinter() + + @property + @deprecated(replacement="getPrinter()", since_version="0.8.0") + def printer(self): + return self.getPrinter() + + def printPlotAsWidget(self): + """Open the print dialog and print the plot. + + Use :meth:`QWidget.render` to print the plot + + :return: True if successful + """ + dialog = qt.QPrintDialog(self.getPrinter(), self.plot) + dialog.setWindowTitle('Print Plot') + if not dialog.exec(): + return False + + # Print a snapshot of the plot widget at the top of the page + widget = self.plot.centralWidget() + + painter = qt.QPainter() + if not painter.begin(self.getPrinter()): + return False + + pageRect = self.getPrinter().pageRect(qt.QPrinter.DevicePixel) + xScale = pageRect.width() / widget.width() + yScale = pageRect.height() / widget.height() + scale = min(xScale, yScale) + + painter.translate(pageRect.width() / 2., 0.) + painter.scale(scale, scale) + painter.translate(-widget.width() / 2., 0.) + widget.render(painter) + painter.end() + + return True + + def printPlot(self): + """Open the print dialog and print the plot. + + Use :meth:`Plot.saveGraph` to print the plot. + + :return: True if successful + """ + # Init printer and start printer dialog + dialog = qt.QPrintDialog(self.getPrinter(), self.plot) + dialog.setWindowTitle('Print Plot') + if not dialog.exec(): + return False + + # Save Plot as PNG and make a pixmap from it with default dpi + pngData = _plotAsPNG(self.plot) + + pixmap = qt.QPixmap() + pixmap.loadFromData(pngData, 'png') + + pageRect = self.getPrinter().pageRect(qt.QPrinter.DevicePixel) + xScale = pageRect.width() / pixmap.width() + yScale = pageRect.height() / pixmap.height() + scale = min(xScale, yScale) + + # Draw pixmap with painter + painter = qt.QPainter() + if not painter.begin(self.getPrinter()): + return False + + painter.drawPixmap(0, 0, + pixmap.width() * scale, + pixmap.height() * scale, + pixmap) + painter.end() + + return True + + +class CopyAction(PlotAction): + """QAction to copy :class:`.PlotWidget` content to clipboard. + + :param plot: :class:`.PlotWidget` instance on which to operate + :param parent: See :class:`QAction` + """ + + def __init__(self, plot, parent=None): + super(CopyAction, self).__init__( + plot, icon='edit-copy', text='Copy plot', + tooltip='Copy a snapshot of the plot into the clipboard', + triggered=self.copyPlot, + checkable=False, parent=parent) + self.setShortcut(qt.QKeySequence.Copy) + self.setShortcutContext(qt.Qt.WidgetShortcut) + + def copyPlot(self): + """Copy plot content to the clipboard as a bitmap.""" + # Save Plot as PNG and make a QImage from it with default dpi + pngData = _plotAsPNG(self.plot) + image = qt.QImage.fromData(pngData, 'png') + qt.QApplication.clipboard().setImage(image) diff --git a/src/silx/gui/plot/actions/medfilt.py b/src/silx/gui/plot/actions/medfilt.py new file mode 100644 index 0000000..f86a377 --- /dev/null +++ b/src/silx/gui/plot/actions/medfilt.py @@ -0,0 +1,147 @@ +# coding: utf-8 +# /*########################################################################## +# +# Copyright (c) 2004-2020 European Synchrotron Radiation Facility +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. +# +# ###########################################################################*/ +""" +:mod:`silx.gui.plot.actions.medfilt` provides a set of QAction to apply filter +on data contained in a :class:`.PlotWidget`. + +The following QAction are available: + +- :class:`MedianFilterAction` +- :class:`MedianFilter1DAction` +- :class:`MedianFilter2DAction` + +""" + +from __future__ import division + +__authors__ = ["V.A. Sole", "T. Vincent", "P. Knobel"] +__license__ = "MIT" + +__date__ = "10/10/2018" + +from .PlotToolAction import PlotToolAction +from silx.gui.widgets.MedianFilterDialog import MedianFilterDialog +from silx.math.medianfilter import medfilt2d +import logging + +_logger = logging.getLogger(__name__) + + +class MedianFilterAction(PlotToolAction): + """QAction to plot the pixels intensities diagram + + :param plot: :class:`.PlotWidget` instance on which to operate + :param parent: See :class:`QAction` + """ + + def __init__(self, plot, parent=None): + PlotToolAction.__init__(self, + plot, + icon='median-filter', + text='median filter', + tooltip='Apply a median filter on the image', + parent=parent) + self._originalImage = None + self._legend = None + self._filteredImage = None + + def _createToolWindow(self): + popup = MedianFilterDialog(parent=self.plot) + popup.sigFilterOptChanged.connect(self._updateFilter) + return popup + + def _connectPlot(self, window): + PlotToolAction._connectPlot(self, window) + self.plot.sigActiveImageChanged.connect(self._updateActiveImage) + self._updateActiveImage() + + def _disconnectPlot(self, window): + PlotToolAction._disconnectPlot(self, window) + self.plot.sigActiveImageChanged.disconnect(self._updateActiveImage) + + def _updateActiveImage(self): + """Set _activeImageLegend and _originalImage from the active image""" + self._activeImageLegend = self.plot.getActiveImage(just_legend=True) + if self._activeImageLegend is None: + self._originalImage = None + self._legend = None + else: + self._originalImage = self.plot.getImage(self._activeImageLegend).getData(copy=False) + self._legend = self.plot.getImage(self._activeImageLegend).getName() + + def _updateFilter(self, kernelWidth, conditional=False): + if self._originalImage is None: + return + + self.plot.sigActiveImageChanged.disconnect(self._updateActiveImage) + filteredImage = self._computeFilteredImage(kernelWidth, conditional) + self.plot.addImage(data=filteredImage, + legend=self._legend, + replace=True) + self.plot.sigActiveImageChanged.connect(self._updateActiveImage) + + def _computeFilteredImage(self, kernelWidth, conditional): + raise NotImplementedError('MedianFilterAction is a an abstract class') + + def getFilteredImage(self): + """ + :return: the image with the median filter apply on""" + return self._filteredImage + + +class MedianFilter1DAction(MedianFilterAction): + """Define the MedianFilterAction for 1D + + :param plot: :class:`.PlotWidget` instance on which to operate + :param parent: See :class:`QAction` + """ + def __init__(self, plot, parent=None): + MedianFilterAction.__init__(self, + plot, + parent=parent) + + def _computeFilteredImage(self, kernelWidth, conditional): + assert(self.plot is not None) + return medfilt2d(self._originalImage, + (kernelWidth, 1), + conditional) + + +class MedianFilter2DAction(MedianFilterAction): + """Define the MedianFilterAction for 2D + + :param plot: :class:`.PlotWidget` instance on which to operate + :param parent: See :class:`QAction` + """ + def __init__(self, plot, parent=None): + MedianFilterAction.__init__(self, + plot, + parent=parent) + + def _computeFilteredImage(self, kernelWidth, conditional): + assert(self.plot is not None) + return medfilt2d(self._originalImage, + (kernelWidth, kernelWidth), + conditional) diff --git a/src/silx/gui/plot/actions/mode.py b/src/silx/gui/plot/actions/mode.py new file mode 100644 index 0000000..ee05256 --- /dev/null +++ b/src/silx/gui/plot/actions/mode.py @@ -0,0 +1,104 @@ +# coding: utf-8 +# /*########################################################################## +# +# Copyright (c) 2004-2018 European Synchrotron Radiation Facility +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. +# +# ###########################################################################*/ +""" +:mod:`silx.gui.plot.actions.mode` provides a set of QAction relative to mouse +mode of a :class:`.PlotWidget`. + +The following QAction are available: + +- :class:`ZoomModeAction` +- :class:`PanModeAction` +""" + +from __future__ import division + +__authors__ = ["V. Valls"] +__license__ = "MIT" +__date__ = "16/08/2017" + +from . import PlotAction +import logging + +_logger = logging.getLogger(__name__) + + +class ZoomModeAction(PlotAction): + """QAction controlling the zoom mode of a :class:`.PlotWidget`. + + :param plot: :class:`.PlotWidget` instance on which to operate + :param parent: See :class:`QAction` + """ + + def __init__(self, plot, parent=None): + super(ZoomModeAction, self).__init__( + plot, icon='zoom', text='Zoom mode', + tooltip='Zoom in or out', + triggered=self._actionTriggered, + checkable=True, parent=parent) + # Listen to mode change + self.plot.sigInteractiveModeChanged.connect(self._modeChanged) + # Init the state + self._modeChanged(None) + + def _modeChanged(self, source): + modeDict = self.plot.getInteractiveMode() + old = self.blockSignals(True) + self.setChecked(modeDict["mode"] == "zoom") + self.blockSignals(old) + + def _actionTriggered(self, checked=False): + plot = self.plot + if plot is not None: + plot.setInteractiveMode('zoom', source=self) + + +class PanModeAction(PlotAction): + """QAction controlling the pan mode of a :class:`.PlotWidget`. + + :param plot: :class:`.PlotWidget` instance on which to operate + :param parent: See :class:`QAction` + """ + + def __init__(self, plot, parent=None): + super(PanModeAction, self).__init__( + plot, icon='pan', text='Pan mode', + tooltip='Pan the view', + triggered=self._actionTriggered, + checkable=True, parent=parent) + # Listen to mode change + self.plot.sigInteractiveModeChanged.connect(self._modeChanged) + # Init the state + self._modeChanged(None) + + def _modeChanged(self, source): + modeDict = self.plot.getInteractiveMode() + old = self.blockSignals(True) + self.setChecked(modeDict["mode"] == "pan") + self.blockSignals(old) + + def _actionTriggered(self, checked=False): + plot = self.plot + if plot is not None: + plot.setInteractiveMode('pan', source=self) diff --git a/src/silx/gui/plot/backends/BackendBase.py b/src/silx/gui/plot/backends/BackendBase.py new file mode 100755 index 0000000..1e86807 --- /dev/null +++ b/src/silx/gui/plot/backends/BackendBase.py @@ -0,0 +1,568 @@ +# coding: utf-8 +# /*########################################################################## +# +# Copyright (c) 2004-2021 European Synchrotron Radiation Facility +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. +# +# ############################################################################*/ +"""Base class for Plot backends. + +It documents the Plot backend API. + +This API is a simplified version of PyMca PlotBackend API. +""" + +__authors__ = ["V.A. Sole", "T. Vincent"] +__license__ = "MIT" +__date__ = "21/12/2018" + +import weakref +from ... import qt + + +# Names for setCursor +CURSOR_DEFAULT = 'default' +CURSOR_POINTING = 'pointing' +CURSOR_SIZE_HOR = 'size horizontal' +CURSOR_SIZE_VER = 'size vertical' +CURSOR_SIZE_ALL = 'size all' + + +class BackendBase(object): + """Class defining the API a backend of the Plot should provide.""" + + def __init__(self, plot, parent=None): + """Init. + + :param Plot plot: The Plot this backend is attached to + :param parent: The parent widget of the plot widget. + """ + self.__xLimits = 1., 100. + self.__yLimits = {'left': (1., 100.), 'right': (1., 100.)} + self.__yAxisInverted = False + self.__keepDataAspectRatio = False + self.__xAxisTimeSeries = False + self._xAxisTimeZone = None + # Store a weakref to get access to the plot state. + self._setPlot(plot) + + @property + def _plot(self): + """The plot this backend is attached to.""" + if self._plotRef is None: + raise RuntimeError('This backend is not attached to a Plot') + + plot = self._plotRef() + if plot is None: + raise RuntimeError('This backend is no more attached to a Plot') + return plot + + def _setPlot(self, plot): + """Allow to set plot after init. + + Use with caution, basically **immediately** after init. + """ + self._plotRef = weakref.ref(plot) + + # Add methods + + def addCurve(self, x, y, + color, symbol, linewidth, linestyle, + yaxis, + xerror, yerror, + fill, alpha, symbolsize, baseline): + """Add a 1D curve given by x an y to the graph. + + :param numpy.ndarray x: The data corresponding to the x axis + :param numpy.ndarray y: The data corresponding to the y axis + :param color: color(s) to be used + :type color: string ("#RRGGBB") or (npoints, 4) unsigned byte array or + one of the predefined color names defined in colors.py + :param str symbol: Symbol to be drawn at each (x, y) position:: + + - ' ' or '' no symbol + - 'o' circle + - '.' point + - ',' pixel + - '+' cross + - 'x' x-cross + - 'd' diamond + - 's' square + + :param float linewidth: The width of the curve in pixels + :param str linestyle: Type of line:: + + - ' ' or '' no line + - '-' solid line + - '--' dashed line + - '-.' dash-dot line + - ':' dotted line + + :param str yaxis: The Y axis this curve belongs to in: 'left', 'right' + :param xerror: Values with the uncertainties on the x values + :type xerror: numpy.ndarray or None + :param yerror: Values with the uncertainties on the y values + :type yerror: numpy.ndarray or None + :param bool fill: True to fill the curve, False otherwise + :param float alpha: Curve opacity, as a float in [0., 1.] + :param float symbolsize: Size of the symbol (if any) drawn + at each (x, y) position. + :returns: The handle used by the backend to univocally access the curve + """ + return object() + + def addImage(self, data, + origin, scale, + colormap, alpha): + """Add an image to the plot. + + :param numpy.ndarray data: (nrows, ncolumns) data or + (nrows, ncolumns, RGBA) ubyte array + :param origin: (origin X, origin Y) of the data. + Default: (0., 0.) + :type origin: 2-tuple of float + :param scale: (scale X, scale Y) of the data. + Default: (1., 1.) + :type scale: 2-tuple of float + :param ~silx.gui.colors.Colormap colormap: Colormap object to use. + Ignored if data is RGB(A). + :param float alpha: Opacity of the image, as a float in range [0, 1]. + :returns: The handle used by the backend to univocally access the image + """ + return object() + + def addTriangles(self, x, y, triangles, + color, alpha): + """Add a set of triangles. + + :param numpy.ndarray x: The data corresponding to the x axis + :param numpy.ndarray y: The data corresponding to the y axis + :param numpy.ndarray triangles: The indices to make triangles + as a (Ntriangle, 3) array + :param numpy.ndarray color: color(s) as (npoints, 4) array + :param float alpha: Opacity as a float in [0., 1.] + :returns: The triangles' unique identifier used by the backend + """ + return object() + + def addShape(self, x, y, shape, color, fill, overlay, + linestyle, linewidth, linebgcolor): + """Add an item (i.e. a shape) to the plot. + + :param numpy.ndarray x: The X coords of the points of the shape + :param numpy.ndarray y: The Y coords of the points of the shape + :param str shape: Type of item to be drawn in + hline, polygon, rectangle, vline, polylines + :param str color: Color of the item + :param bool fill: True to fill the shape + :param bool overlay: True if item is an overlay, False otherwise + :param str linestyle: Style of the line. + Only relevant for line markers where X or Y is None. + Value in: + + - ' ' no line + - '-' solid line + - '--' dashed line + - '-.' dash-dot line + - ':' dotted line + :param float linewidth: Width of the line. + Only relevant for line markers where X or Y is None. + :param str linebgcolor: Background color of the line, e.g., 'blue', 'b', + '#FF0000'. It is used to draw dotted line using a second color. + :returns: The handle used by the backend to univocally access the item + """ + return object() + + def addMarker(self, x, y, text, color, + symbol, linestyle, linewidth, constraint, yaxis): + """Add a point, vertical line or horizontal line marker to the plot. + + :param float x: Horizontal position of the marker in graph coordinates. + If None, the marker is a horizontal line. + :param float y: Vertical position of the marker in graph coordinates. + If None, the marker is a vertical line. + :param str text: Text associated to the marker (or None for no text) + :param str color: Color to be used for instance 'blue', 'b', '#FF0000' + :param str symbol: Symbol representing the marker. + Only relevant for point markers where X and Y are not None. + Value in: + + - 'o' circle + - '.' point + - ',' pixel + - '+' cross + - 'x' x-cross + - 'd' diamond + - 's' square + :param str linestyle: Style of the line. + Only relevant for line markers where X or Y is None. + Value in: + + - ' ' no line + - '-' solid line + - '--' dashed line + - '-.' dash-dot line + - ':' dotted line + :param float linewidth: Width of the line. + Only relevant for line markers where X or Y is None. + :param constraint: A function filtering marker displacement by + dragging operations or None for no filter. + This function is called each time a marker is + moved. + :type constraint: None or a callable that takes the coordinates of + the current cursor position in the plot as input + and that returns the filtered coordinates. + :param str yaxis: The Y axis this marker belongs to in: 'left', 'right' + :return: Handle used by the backend to univocally access the marker + """ + return object() + + # Remove methods + + def remove(self, item): + """Remove an existing item from the plot. + + :param item: A backend specific item handle returned by a add* method + """ + pass + + # Interaction methods + + def setGraphCursorShape(self, cursor): + """Set the cursor shape. + + To override in interactive backends. + + :param str cursor: Name of the cursor shape or None + """ + pass + + def setGraphCursor(self, flag, color, linewidth, linestyle): + """Toggle the display of a crosshair cursor and set its attributes. + + To override in interactive backends. + + :param bool flag: Toggle the display of a crosshair cursor. + :param color: The color to use for the crosshair. + :type color: A string (either a predefined color name in colors.py + or "#RRGGBB")) or a 4 columns unsigned byte array. + :param int linewidth: The width of the lines of the crosshair. + :param linestyle: Type of line:: + + - ' ' no line + - '-' solid line + - '--' dashed line + - '-.' dash-dot line + - ':' dotted line + + :type linestyle: None or one of the predefined styles. + """ + pass + + def getItemsFromBackToFront(self, condition=None): + """Returns the list of plot items order as rendered by the backend. + + This is the order used for rendering. + By default, it takes into account overlays, z value and order of addition of items, + but backends can override it. + + :param callable condition: + Callable taking an item as input and returning False for items to skip. + If None (default), no item is skipped. + :rtype: List[~silx.gui.plot.items.Item] + """ + # Sort items: Overlays first, then others + # and in each category ordered by z and then by order of addition + # as content keeps this order. + content = self._plot.getItems() + if condition is not None: + content = [item for item in content if condition(item)] + + return sorted( + content, + key=lambda i: ((1 if i.isOverlay() else 0), i.getZValue())) + + def pickItem(self, x, y, item): + """Return picked indices if any, or None. + + :param float x: The x pixel coord where to pick. + :param float y: The y pixel coord where to pick. + :param item: A backend item created with add* methods. + :return: None if item was not picked, else returns + picked indices information. + :rtype: Union[None,List] + """ + return None + + # Update curve + + def setCurveColor(self, curve, color): + """Set the color of a curve. + + :param curve: The curve handle + :param str color: The color to use. + """ + pass + + # Misc. + + def getWidgetHandle(self): + """Return the widget this backend is drawing to.""" + return None + + def postRedisplay(self): + """Trigger backend update and repaint.""" + self.replot() + + def replot(self): + """Redraw the plot.""" + with self._plot._paintContext(): + pass + + def saveGraph(self, fileName, fileFormat, dpi): + """Save the graph to a file (or a StringIO) + + At least "png", "svg" are supported. + + :param fileName: Destination + :type fileName: String or StringIO or BytesIO + :param str fileFormat: String specifying the format + :param int dpi: The resolution to use or None. + """ + pass + + # Graph labels + + def setGraphTitle(self, title): + """Set the main title of the plot. + + :param str title: Title associated to the plot + """ + pass + + def setGraphXLabel(self, label): + """Set the X axis label. + + :param str label: label associated to the plot bottom X axis + """ + pass + + def setGraphYLabel(self, label, axis): + """Set the left Y axis label. + + :param str label: label associated to the plot left Y axis + :param str axis: The axis for which to get the limits: left or right + """ + pass + + # Graph limits + + def setLimits(self, xmin, xmax, ymin, ymax, y2min=None, y2max=None): + """Set the limits of the X and Y axes at once. + + :param float xmin: minimum bottom axis value + :param float xmax: maximum bottom axis value + :param float ymin: minimum left axis value + :param float ymax: maximum left axis value + :param float y2min: minimum right axis value + :param float y2max: maximum right axis value + """ + self.__xLimits = xmin, xmax + self.__yLimits['left'] = ymin, ymax + if y2min is not None and y2max is not None: + self.__yLimits['right'] = y2min, y2max + + def getGraphXLimits(self): + """Get the graph X (bottom) limits. + + :return: Minimum and maximum values of the X axis + """ + return self.__xLimits + + def setGraphXLimits(self, xmin, xmax): + """Set the limits of X axis. + + :param float xmin: minimum bottom axis value + :param float xmax: maximum bottom axis value + """ + self.__xLimits = xmin, xmax + + def getGraphYLimits(self, axis): + """Get the graph Y (left) limits. + + :param str axis: The axis for which to get the limits: left or right + :return: Minimum and maximum values of the Y axis + """ + return self.__yLimits[axis] + + def setGraphYLimits(self, ymin, ymax, axis): + """Set the limits of the Y axis. + + :param float ymin: minimum left axis value + :param float ymax: maximum left axis value + :param str axis: The axis for which to get the limits: left or right + """ + self.__yLimits[axis] = ymin, ymax + + # Graph axes + + + def getXAxisTimeZone(self): + """Returns tzinfo that is used if the X-Axis plots date-times. + + None means the datetimes are interpreted as local time. + + :rtype: datetime.tzinfo of None. + """ + return self._xAxisTimeZone + + def setXAxisTimeZone(self, tz): + """Sets tzinfo that is used if the X-Axis plots date-times. + + Use None to let the datetimes be interpreted as local time. + + :rtype: datetime.tzinfo of None. + """ + self._xAxisTimeZone = tz + + def isXAxisTimeSeries(self): + """Return True if the X-axis scale shows datetime objects. + + :rtype: bool + """ + return self.__xAxisTimeSeries + + def setXAxisTimeSeries(self, isTimeSeries): + """Set whether the X-axis is a time series + + :param bool flag: True to switch to time series, False for regular axis. + """ + self.__xAxisTimeSeries = bool(isTimeSeries) + + def setXAxisLogarithmic(self, flag): + """Set the X axis scale between linear and log. + + :param bool flag: If True, the bottom axis will use a log scale + """ + pass + + def setYAxisLogarithmic(self, flag): + """Set the Y axis scale between linear and log. + + :param bool flag: If True, the left axis will use a log scale + """ + pass + + def setYAxisInverted(self, flag): + """Invert the Y axis. + + :param bool flag: If True, put the vertical axis origin on the top + """ + self.__yAxisInverted = bool(flag) + + def isYAxisInverted(self): + """Return True if left Y axis is inverted, False otherwise.""" + return self.__yAxisInverted + + def isKeepDataAspectRatio(self): + """Returns whether the plot is keeping data aspect ratio or not.""" + return self.__keepDataAspectRatio + + def setKeepDataAspectRatio(self, flag): + """Set whether to keep data aspect ratio or not. + + :param flag: True to respect data aspect ratio + :type flag: Boolean, default True + """ + self.__keepDataAspectRatio = bool(flag) + + def setGraphGrid(self, which): + """Set grid. + + :param which: None to disable grid, 'major' for major grid, + 'both' for major and minor grid + """ + pass + + # Data <-> Pixel coordinates conversion + + def dataToPixel(self, x, y, axis): + """Convert a position in data space to a position in pixels + in the widget. + + :param float x: The X coordinate in data space. + :param float y: The Y coordinate in data space. + :param str axis: The Y axis to use for the conversion + ('left' or 'right'). + :returns: The corresponding position in pixels or + None if the data position is not in the displayed area. + :rtype: A tuple of 2 floats: (xPixel, yPixel) or None. + """ + raise NotImplementedError() + + def pixelToData(self, x, y, axis): + """Convert a position in pixels in the widget to a position in + the data space. + + :param float x: The X coordinate in pixels. + :param float y: The Y coordinate in pixels. + :param str axis: The Y axis to use for the conversion + ('left' or 'right'). + :returns: The corresponding position in data space or + None if the pixel position is not in the plot area. + :rtype: A tuple of 2 floats: (xData, yData) or None. + """ + raise NotImplementedError() + + def getPlotBoundsInPixels(self): + """Plot area bounds in widget coordinates in pixels. + + :return: bounds as a 4-tuple of int: (left, top, width, height) + """ + raise NotImplementedError() + + def setAxesMargins(self, left: float, top: float, right: float, bottom: float): + """Set the size of plot margins as ratios. + + Values are expected in [0., 1.] + + :param float left: + :param float top: + :param float right: + :param float bottom: + """ + pass + + def setForegroundColors(self, foregroundColor, gridColor): + """Set foreground and grid colors used to display this widget. + + :param List[float] foregroundColor: RGBA foreground color of the widget + :param List[float] gridColor: RGBA grid color of the data view + """ + pass + + def setBackgroundColors(self, backgroundColor, dataBackgroundColor): + """Set background colors used to display this widget. + + :param List[float] backgroundColor: RGBA background color of the widget + :param Union[Tuple[float],None] dataBackgroundColor: + RGBA background color of the data view + """ + pass diff --git a/src/silx/gui/plot/backends/BackendMatplotlib.py b/src/silx/gui/plot/backends/BackendMatplotlib.py new file mode 100755 index 0000000..7fe4ec0 --- /dev/null +++ b/src/silx/gui/plot/backends/BackendMatplotlib.py @@ -0,0 +1,1557 @@ +# coding: utf-8 +# /*########################################################################## +# +# Copyright (c) 2004-2021 European Synchrotron Radiation Facility +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. +# +# ###########################################################################*/ +"""Matplotlib Plot backend.""" + +from __future__ import division + +__authors__ = ["V.A. Sole", "T. Vincent, H. Payno"] +__license__ = "MIT" +__date__ = "21/12/2018" + + +import logging +import datetime as dt +from typing import Tuple +import numpy + +from pkg_resources import parse_version as _parse_version + + +_logger = logging.getLogger(__name__) + + +from ... import qt + +# First of all init matplotlib and set its backend +from ...utils.matplotlib import FigureCanvasQTAgg +import matplotlib +from matplotlib.container import Container +from matplotlib.figure import Figure +from matplotlib.patches import Rectangle, Polygon +from matplotlib.image import AxesImage +from matplotlib.backend_bases import MouseEvent +from matplotlib.lines import Line2D +from matplotlib.text import Text +from matplotlib.collections import PathCollection, LineCollection +from matplotlib.ticker import Formatter, ScalarFormatter, Locator +from matplotlib.tri import Triangulation +from matplotlib.collections import TriMesh +from matplotlib import path as mpath + +from . import BackendBase +from .. import items +from .._utils import FLOAT32_MINPOS +from .._utils.dtime_ticklayout import calcTicks, bestFormatString, timestamp + +_PATCH_LINESTYLE = { + "-": 'solid', + "--": 'dashed', + '-.': 'dashdot', + ':': 'dotted', + '': "solid", + None: "solid", +} +"""Patches do not uses the same matplotlib syntax""" + +_MARKER_PATHS = {} +"""Store cached extra marker paths""" + +_SPECIAL_MARKERS = { + 'tickleft': 0, + 'tickright': 1, + 'tickup': 2, + 'tickdown': 3, + 'caretleft': 4, + 'caretright': 5, + 'caretup': 6, + 'caretdown': 7, +} + + +def normalize_linestyle(linestyle): + """Normalize known old-style linestyle, else return the provided value.""" + return _PATCH_LINESTYLE.get(linestyle, linestyle) + +def get_path_from_symbol(symbol): + """Get the path representation of a symbol, else None if + it is not provided. + + :param str symbol: Symbol description used by silx + :rtype: Union[None,matplotlib.path.Path] + """ + if symbol == u'\u2665': + path = _MARKER_PATHS.get(symbol, None) + if path is not None: + return path + vertices = numpy.array([ + [0,-99], + [31,-73], [47,-55], [55,-46], + [63,-37], [94,-2], [94,33], + [94,69], [71,89], [47,89], + [24,89], [8,74], [0,58], + [-8,74], [-24,89], [-47,89], + [-71,89], [-94,69], [-94,33], + [-94,-2], [-63,-37], [-55,-46], + [-47,-55], [-31,-73], [0,-99], + [0,-99]]) + codes = [mpath.Path.CURVE4] * len(vertices) + codes[0] = mpath.Path.MOVETO + codes[-1] = mpath.Path.CLOSEPOLY + path = mpath.Path(vertices, codes) + _MARKER_PATHS[symbol] = path + return path + return None + +class NiceDateLocator(Locator): + """ + Matplotlib Locator that uses Nice Numbers algorithm (adapted to dates) + to find the tick locations. This results in the same number behaviour + as when using the silx Open GL backend. + + Expects the data to be posix timestampes (i.e. seconds since 1970) + """ + def __init__(self, numTicks=5, tz=None): + """ + :param numTicks: target number of ticks + :param datetime.tzinfo tz: optional time zone. None is local time. + """ + super(NiceDateLocator, self).__init__() + self.numTicks = numTicks + + self._spacing = None + self._unit = None + self.tz = tz + + @property + def spacing(self): + """ The current spacing. Will be updated when new tick value are made""" + return self._spacing + + @property + def unit(self): + """ The current DtUnit. Will be updated when new tick value are made""" + return self._unit + + def __call__(self): + """Return the locations of the ticks""" + vmin, vmax = self.axis.get_view_interval() + return self.tick_values(vmin, vmax) + + def tick_values(self, vmin, vmax): + """ Calculates tick values + """ + if vmax < vmin: + vmin, vmax = vmax, vmin + + # vmin and vmax should be timestamps (i.e. seconds since 1 Jan 1970) + dtMin = dt.datetime.fromtimestamp(vmin, tz=self.tz) + dtMax = dt.datetime.fromtimestamp(vmax, tz=self.tz) + dtTicks, self._spacing, self._unit = \ + calcTicks(dtMin, dtMax, self.numTicks) + + # Convert datetime back to time stamps. + ticks = [timestamp(dtTick) for dtTick in dtTicks] + return ticks + + +class NiceAutoDateFormatter(Formatter): + """ + Matplotlib FuncFormatter that is linked to a NiceDateLocator and gives the + best possible formats given the locators current spacing an date unit. + """ + + def __init__(self, locator, tz=None): + """ + :param niceDateLocator: a NiceDateLocator object + :param datetime.tzinfo tz: optional time zone. None is local time. + """ + super(NiceAutoDateFormatter, self).__init__() + self.locator = locator + self.tz = tz + + @property + def formatString(self): + if self.locator.spacing is None or self.locator.unit is None: + # Locator has no spacing or units yet. Return elaborate fmtString + return "Y-%m-%d %H:%M:%S" + else: + return bestFormatString(self.locator.spacing, self.locator.unit) + + def __call__(self, x, pos=None): + """Return the format for tick val *x* at position *pos* + Expects x to be a POSIX timestamp (seconds since 1 Jan 1970) + """ + dateTime = dt.datetime.fromtimestamp(x, tz=self.tz) + tickStr = dateTime.strftime(self.formatString) + return tickStr + + +class _PickableContainer(Container): + """Artists container with a :meth:`contains` method""" + + def __init__(self, *args, **kwargs): + Container.__init__(self, *args, **kwargs) + self.__zorder = None + + @property + def axes(self): + """Mimin Artist.axes""" + for child in self.get_children(): + if hasattr(child, 'axes'): + return child.axes + return None + + def draw(self, *args, **kwargs): + """artist-like draw to broadcast draw to children""" + for child in self.get_children(): + child.draw(*args, **kwargs) + + def get_zorder(self): + """Mimic Artist.get_zorder""" + return self.__zorder + + def set_zorder(self, z): + """Mimic Artist.set_zorder to broadcast to children""" + if z != self.__zorder: + self.__zorder = z + for child in self.get_children(): + child.set_zorder(z) + + def contains(self, mouseevent): + """Mimic Artist.contains, and call it on all children. + + :param mouseevent: + :return: Picking status and associated information as a dict + :rtype: (bool,dict) + """ + # Goes through children from front to back and return first picked one. + for child in reversed(self.get_children()): + picked, info = child.contains(mouseevent) + if picked: + return picked, info + return False, {} + + +class _TextWithOffset(Text): + """Text object which can be displayed at a specific position + of the plot, but with a pixel offset""" + + def __init__(self, *args, **kwargs): + Text.__init__(self, *args, **kwargs) + self.pixel_offset = (0, 0) + self.__cache = None + + def draw(self, renderer): + self.__cache = None + return Text.draw(self, renderer) + + def __get_xy(self): + if self.__cache is not None: + return self.__cache + + align = self.get_horizontalalignment() + if align == "left": + xoffset = self.pixel_offset[0] + elif align == "right": + xoffset = -self.pixel_offset[0] + else: + xoffset = 0 + + align = self.get_verticalalignment() + if align == "top": + yoffset = -self.pixel_offset[1] + elif align == "bottom": + yoffset = self.pixel_offset[1] + else: + yoffset = 0 + + trans = self.get_transform() + x = super(_TextWithOffset, self).convert_xunits(self._x) + y = super(_TextWithOffset, self).convert_xunits(self._y) + pos = x, y + + try: + invtrans = trans.inverted() + except numpy.linalg.LinAlgError: + # Cannot inverse transform, fallback: pos without offset + self.__cache = None + return pos + + proj = trans.transform_point(pos) + proj = proj + numpy.array((xoffset, yoffset)) + pos = invtrans.transform_point(proj) + self.__cache = pos + return pos + + def convert_xunits(self, x): + """Return the pixel position of the annotated point.""" + return self.__get_xy()[0] + + def convert_yunits(self, y): + """Return the pixel position of the annotated point.""" + return self.__get_xy()[1] + + +class _MarkerContainer(_PickableContainer): + """Marker artists container supporting draw/remove and text position update + + :param artists: + Iterable with either one Line2D or a Line2D and a Text. + The use of an iterable if enforced by Container being + a subclass of tuple that defines a specific __new__. + :param x: X coordinate of the marker (None for horizontal lines) + :param y: Y coordinate of the marker (None for vertical lines) + """ + + def __init__(self, artists, symbol, x, y, yAxis): + self.line = artists[0] + self.text = artists[1] if len(artists) > 1 else None + self.symbol = symbol + self.x = x + self.y = y + self.yAxis = yAxis + + _PickableContainer.__init__(self, artists) + + def draw(self, *args, **kwargs): + """artist-like draw to broadcast draw to line and text""" + self.line.draw(*args, **kwargs) + if self.text is not None: + self.text.draw(*args, **kwargs) + + def updateMarkerText(self, xmin, xmax, ymin, ymax, yinverted): + """Update marker text position and visibility according to plot limits + + :param xmin: X axis lower limit + :param xmax: X axis upper limit + :param ymin: Y axis lower limit + :param ymax: Y axis upper limit + :param yinverted: True if the y axis is inverted + """ + if self.text is not None: + visible = ((self.x is None or xmin <= self.x <= xmax) and + (self.y is None or ymin <= self.y <= ymax)) + self.text.set_visible(visible) + + if self.x is not None and self.y is not None: + if self.symbol is None: + valign = 'baseline' + else: + if yinverted: + valign = 'bottom' + else: + valign = 'top' + self.text.set_verticalalignment(valign) + + elif self.y is None: # vertical line + # Always display it on top + center = (ymax + ymin) * 0.5 + pos = (ymax - ymin) * 0.5 * 0.99 + if yinverted: + pos = -pos + self.text.set_y(center + pos) + + elif self.x is None: # Horizontal line + delta = abs(xmax - xmin) + if xmin > xmax: + xmax = xmin + xmax -= 0.005 * delta + self.text.set_x(xmax) + + def contains(self, mouseevent): + """Mimic Artist.contains, and call it on the line Artist. + + :param mouseevent: + :return: Picking status and associated information as a dict + :rtype: (bool,dict) + """ + return self.line.contains(mouseevent) + + +class _DoubleColoredLinePatch(matplotlib.patches.Patch): + """Matplotlib patch to display any patch using double color.""" + + def __init__(self, patch): + super(_DoubleColoredLinePatch, self).__init__() + self.__patch = patch + self.linebgcolor = None + + def __getattr__(self, name): + return getattr(self.__patch, name) + + def draw(self, renderer): + oldLineStype = self.__patch.get_linestyle() + if self.linebgcolor is not None and oldLineStype != "solid": + oldLineColor = self.__patch.get_edgecolor() + oldHatch = self.__patch.get_hatch() + self.__patch.set_linestyle("solid") + self.__patch.set_edgecolor(self.linebgcolor) + self.__patch.set_hatch(None) + self.__patch.draw(renderer) + self.__patch.set_linestyle(oldLineStype) + self.__patch.set_edgecolor(oldLineColor) + self.__patch.set_hatch(oldHatch) + self.__patch.draw(renderer) + + def set_transform(self, transform): + self.__patch.set_transform(transform) + + def get_path(self): + return self.__patch.get_path() + + def contains(self, mouseevent, radius=None): + return self.__patch.contains(mouseevent, radius) + + def contains_point(self, point, radius=None): + return self.__patch.contains_point(point, radius) + + +class Image(AxesImage): + """An AxesImage with a fast path for uint8 RGBA images. + + :param List[float] silx_origin: (ox, oy) Offset of the image. + :param List[float] silx_scale: (sx, sy) Scale of the image. + """ + + def __init__(self, *args, + silx_origin=(0., 0.), + silx_scale=(1., 1.), + **kwargs): + super().__init__(*args, **kwargs) + self.__silx_origin = silx_origin + self.__silx_scale = silx_scale + + def contains(self, mouseevent): + """Overridden to fill 'ind' with row and column""" + inside, info = super().contains(mouseevent) + if inside: + x, y = mouseevent.xdata, mouseevent.ydata + ox, oy = self.__silx_origin + sx, sy = self.__silx_scale + height, width = self.get_size() + column = numpy.clip(int((x - ox) / sx), 0, width - 1) + row = numpy.clip(int((y - oy) / sy), 0, height - 1) + info['ind'] = (row,), (column,) + return inside, info + + def set_data(self, A): + """Overridden to add a fast path for RGBA unit8 images""" + A = numpy.array(A, copy=False) + if A.ndim != 3 or A.shape[2] != 4 or A.dtype != numpy.uint8: + super(Image, self).set_data(A) + else: + # Call AxesImage.set_data with small data to set attributes + super(Image, self).set_data(numpy.zeros((2, 2, 4), dtype=A.dtype)) + self._A = A # Override stored data + + +class BackendMatplotlib(BackendBase.BackendBase): + """Base class for Matplotlib backend without a FigureCanvas. + + For interactive on screen plot, see :class:`BackendMatplotlibQt`. + + See :class:`BackendBase.BackendBase` for public API documentation. + """ + + def __init__(self, plot, parent=None): + super(BackendMatplotlib, self).__init__(plot, parent) + + # matplotlib is handling keep aspect ratio at draw time + # When keep aspect ratio is on, and one changes the limits and + # ask them *before* next draw has been performed he will get the + # limits without applying keep aspect ratio. + # This attribute is used to ensure consistent values returned + # when getting the limits at the expense of a replot + self._dirtyLimits = True + self._axesDisplayed = True + self._matplotlibVersion = _parse_version(matplotlib.__version__) + + self.fig = Figure() + self.fig.set_facecolor("w") + + self.ax = self.fig.add_axes([.15, .15, .75, .75], label="left") + self.ax2 = self.ax.twinx() + self.ax2.set_label("right") + # Make sure background of Axes is displayed + self.ax2.patch.set_visible(False) + self.ax.patch.set_visible(True) + + # Set axis zorder=0.5 so grid is displayed at 0.5 + self.ax.set_axisbelow(True) + + # disable the use of offsets + try: + axes = [ + self.ax.get_yaxis().get_major_formatter(), + self.ax.get_xaxis().get_major_formatter(), + self.ax2.get_yaxis().get_major_formatter(), + self.ax2.get_xaxis().get_major_formatter(), + ] + for axis in axes: + axis.set_useOffset(False) + axis.set_scientific(False) + except: + _logger.warning('Cannot disabled axes offsets in %s ' + % matplotlib.__version__) + + self.ax2.set_autoscaley_on(True) + + # this works but the figure color is left + if self._matplotlibVersion < _parse_version('2'): + self.ax.set_axis_bgcolor('none') + else: + self.ax.set_facecolor('none') + self.fig.sca(self.ax) + + self._background = None + + self._colormaps = {} + + self._graphCursor = tuple() + + self._enableAxis('right', False) + self._isXAxisTimeSeries = False + + def getItemsFromBackToFront(self, condition=None): + """Order as BackendBase + take into account matplotlib Axes structure""" + def axesOrder(item): + if item.isOverlay(): + return 2 + elif isinstance(item, items.YAxisMixIn) and item.getYAxis() == 'right': + return 1 + else: + return 0 + + return sorted( + BackendBase.BackendBase.getItemsFromBackToFront( + self, condition=condition), + key=axesOrder) + + def _overlayItems(self): + """Generator of backend renderer for overlay items""" + for item in self._plot.getItems(): + if (item.isOverlay() and + item.isVisible() and + item._backendRenderer is not None): + yield item._backendRenderer + + def _hasOverlays(self): + """Returns whether there is an overlay layer or not. + + The overlay layers contains overlay items and the crosshair. + + :rtype: bool + """ + if self._graphCursor: + return True # There is the crosshair + + for item in self._overlayItems(): + return True # There is at least one overlay item + return False + + # Add methods + + def _getMarkerFromSymbol(self, symbol): + """Returns a marker that can be displayed by matplotlib. + + :param str symbol: A symbol description used by silx + :rtype: Union[str,int,matplotlib.path.Path] + """ + path = get_path_from_symbol(symbol) + if path is not None: + return path + num = _SPECIAL_MARKERS.get(symbol, None) + if num is not None: + return num + # This symbol must be supported by matplotlib + return symbol + + def addCurve(self, x, y, + color, symbol, linewidth, linestyle, + yaxis, + xerror, yerror, + fill, alpha, symbolsize, baseline): + for parameter in (x, y, color, symbol, linewidth, linestyle, + yaxis, fill, alpha, symbolsize): + assert parameter is not None + assert yaxis in ('left', 'right') + + if (len(color) == 4 and + type(color[3]) in [type(1), numpy.uint8, numpy.int8]): + color = numpy.array(color, dtype=numpy.float64) / 255. + + if yaxis == "right": + axes = self.ax2 + self._enableAxis("right", True) + else: + axes = self.ax + + pickradius = 3 + + artists = [] # All the artists composing the curve + + # First add errorbars if any so they are behind the curve + if xerror is not None or yerror is not None: + if hasattr(color, 'dtype') and len(color) == len(x): + errorbarColor = 'k' + else: + errorbarColor = color + + # Nx1 error array deprecated in matplotlib >=3.1 (removed in 3.3) + if (isinstance(xerror, numpy.ndarray) and xerror.ndim == 2 and + xerror.shape[1] == 1): + xerror = numpy.ravel(xerror) + if (isinstance(yerror, numpy.ndarray) and yerror.ndim == 2 and + yerror.shape[1] == 1): + yerror = numpy.ravel(yerror) + + errorbars = axes.errorbar(x, y, + xerr=xerror, yerr=yerror, + linestyle=' ', color=errorbarColor) + artists += list(errorbars.get_children()) + + if hasattr(color, 'dtype') and len(color) == len(x): + # scatter plot + if color.dtype not in [numpy.float32, numpy.float64]: + actualColor = color / 255. + else: + actualColor = color + + if linestyle not in ["", " ", None]: + # scatter plot with an actual line ... + # we need to assign a color ... + curveList = axes.plot(x, y, + linestyle=linestyle, + color=actualColor[0], + linewidth=linewidth, + picker=True, + pickradius=pickradius, + marker=None) + artists += list(curveList) + + marker = self._getMarkerFromSymbol(symbol) + scatter = axes.scatter(x, y, + color=actualColor, + marker=marker, + picker=True, + pickradius=pickradius, + s=symbolsize**2) + artists.append(scatter) + + if fill: + if baseline is None: + _baseline = FLOAT32_MINPOS + else: + _baseline = baseline + artists.append(axes.fill_between( + x, _baseline, y, facecolor=actualColor[0], linestyle='')) + + else: # Curve + curveList = axes.plot(x, y, + linestyle=linestyle, + color=color, + linewidth=linewidth, + marker=symbol, + picker=True, + pickradius=pickradius, + markersize=symbolsize) + artists += list(curveList) + + if fill: + if baseline is None: + _baseline = FLOAT32_MINPOS + else: + _baseline = baseline + artists.append( + axes.fill_between(x, _baseline, y, facecolor=color)) + + for artist in artists: + if alpha < 1: + artist.set_alpha(alpha) + + return _PickableContainer(artists) + + def addImage(self, data, origin, scale, colormap, alpha): + # Non-uniform image + # http://wiki.scipy.org/Cookbook/Histograms + # Non-linear axes + # http://stackoverflow.com/questions/11488800/non-linear-axes-for-imshow-in-matplotlib + for parameter in (data, origin, scale): + assert parameter is not None + + origin = float(origin[0]), float(origin[1]) + scale = float(scale[0]), float(scale[1]) + height, width = data.shape[0:2] + + # All image are shown as RGBA image + image = Image(self.ax, + interpolation='nearest', + picker=True, + origin='lower', + silx_origin=origin, + silx_scale=scale) + + if alpha < 1: + image.set_alpha(alpha) + + # Set image extent + xmin = origin[0] + xmax = xmin + scale[0] * width + if scale[0] < 0.: + xmin, xmax = xmax, xmin + + ymin = origin[1] + ymax = ymin + scale[1] * height + if scale[1] < 0.: + ymin, ymax = ymax, ymin + + image.set_extent((xmin, xmax, ymin, ymax)) + + # Set image data + if scale[0] < 0. or scale[1] < 0.: + # For negative scale, step by -1 + xstep = 1 if scale[0] >= 0. else -1 + ystep = 1 if scale[1] >= 0. else -1 + data = data[::ystep, ::xstep] + + if data.ndim == 2: # Data image, convert to RGBA image + data = colormap.applyToData(data) + elif data.dtype == numpy.uint16: + # Normalize uint16 data to have a similar behavior as opengl backend + data = data.astype(numpy.float32) + data /= 65535 + + image.set_data(data) + self.ax.add_artist(image) + return image + + def addTriangles(self, x, y, triangles, color, alpha): + for parameter in (x, y, triangles, color, alpha): + assert parameter is not None + + color = numpy.array(color, copy=False) + assert color.ndim == 2 and len(color) == len(x) + + if color.dtype not in [numpy.float32, numpy.float64]: + color = color.astype(numpy.float32) / 255. + + collection = TriMesh( + Triangulation(x, y, triangles), + alpha=alpha, + pickradius=0) # 0 enables picking on filled triangle + collection.set_color(color) + self.ax.add_collection(collection) + + return collection + + def addShape(self, x, y, shape, color, fill, overlay, + linestyle, linewidth, linebgcolor): + if (linebgcolor is not None and + shape not in ('rectangle', 'polygon', 'polylines')): + _logger.warning( + 'linebgcolor not implemented for %s with matplotlib backend', + shape) + xView = numpy.array(x, copy=False) + yView = numpy.array(y, copy=False) + + linestyle = normalize_linestyle(linestyle) + + if shape == "line": + item = self.ax.plot(x, y, color=color, + linestyle=linestyle, linewidth=linewidth, + marker=None)[0] + + elif shape == "hline": + if hasattr(y, "__len__"): + y = y[-1] + item = self.ax.axhline(y, color=color, + linestyle=linestyle, linewidth=linewidth) + + elif shape == "vline": + if hasattr(x, "__len__"): + x = x[-1] + item = self.ax.axvline(x, color=color, + linestyle=linestyle, linewidth=linewidth) + + elif shape == 'rectangle': + xMin = numpy.nanmin(xView) + xMax = numpy.nanmax(xView) + yMin = numpy.nanmin(yView) + yMax = numpy.nanmax(yView) + w = xMax - xMin + h = yMax - yMin + item = Rectangle(xy=(xMin, yMin), + width=w, + height=h, + fill=False, + color=color, + linestyle=linestyle, + linewidth=linewidth) + if fill: + item.set_hatch('.') + + if linestyle != "solid" and linebgcolor is not None: + item = _DoubleColoredLinePatch(item) + item.linebgcolor = linebgcolor + + self.ax.add_patch(item) + + elif shape in ('polygon', 'polylines'): + points = numpy.array((xView, yView)).T + if shape == 'polygon': + closed = True + else: # shape == 'polylines' + closed = numpy.all(numpy.equal(points[0], points[-1])) + item = Polygon(points, + closed=closed, + fill=False, + color=color, + linestyle=linestyle, + linewidth=linewidth) + if fill and shape == 'polygon': + item.set_hatch('/') + + if linestyle != "solid" and linebgcolor is not None: + item = _DoubleColoredLinePatch(item) + item.linebgcolor = linebgcolor + + self.ax.add_patch(item) + + else: + raise NotImplementedError("Unsupported item shape %s" % shape) + + if overlay: + item.set_animated(True) + + return item + + def addMarker(self, x, y, text, color, + symbol, linestyle, linewidth, constraint, yaxis): + textArtist = None + + xmin, xmax = self.getGraphXLimits() + ymin, ymax = self.getGraphYLimits(axis=yaxis) + + if yaxis == 'left': + ax = self.ax + elif yaxis == 'right': + ax = self.ax2 + else: + assert(False) + + marker = self._getMarkerFromSymbol(symbol) + if x is not None and y is not None: + line = ax.plot(x, y, + linestyle=" ", + color=color, + marker=marker, + markersize=10.)[-1] + + if text is not None: + textArtist = _TextWithOffset(x, y, text, + color=color, + horizontalalignment='left') + if symbol is not None: + textArtist.pixel_offset = 10, 3 + elif x is not None: + line = ax.axvline(x, + color=color, + linewidth=linewidth, + linestyle=linestyle) + if text is not None: + # Y position will be updated in updateMarkerText call + textArtist = _TextWithOffset(x, 1., text, + color=color, + horizontalalignment='left', + verticalalignment='top') + textArtist.pixel_offset = 5, 3 + elif y is not None: + line = ax.axhline(y, + color=color, + linewidth=linewidth, + linestyle=linestyle) + + if text is not None: + # X position will be updated in updateMarkerText call + textArtist = _TextWithOffset(1., y, text, + color=color, + horizontalalignment='right', + verticalalignment='top') + textArtist.pixel_offset = 5, 3 + else: + raise RuntimeError('A marker must at least have one coordinate') + + line.set_picker(True) + line.set_pickradius(5) + + # All markers are overlays + line.set_animated(True) + if textArtist is not None: + ax.add_artist(textArtist) + textArtist.set_animated(True) + + artists = [line] if textArtist is None else [line, textArtist] + container = _MarkerContainer(artists, symbol, x, y, yaxis) + container.updateMarkerText(xmin, xmax, ymin, ymax, self.isYAxisInverted()) + + return container + + def _updateMarkers(self): + xmin, xmax = self.ax.get_xbound() + ymin1, ymax1 = self.ax.get_ybound() + ymin2, ymax2 = self.ax2.get_ybound() + yinverted = self.isYAxisInverted() + for item in self._overlayItems(): + if isinstance(item, _MarkerContainer): + if item.yAxis == 'left': + item.updateMarkerText(xmin, xmax, ymin1, ymax1, yinverted) + else: + item.updateMarkerText(xmin, xmax, ymin2, ymax2, yinverted) + + # Remove methods + + def remove(self, item): + try: + item.remove() + except ValueError: + pass # Already removed e.g., in set[X|Y]AxisLogarithmic + + # Interaction methods + + def setGraphCursor(self, flag, color, linewidth, linestyle): + if flag: + lineh = self.ax.axhline( + self.ax.get_ybound()[0], visible=False, color=color, + linewidth=linewidth, linestyle=linestyle) + lineh.set_animated(True) + + linev = self.ax.axvline( + self.ax.get_xbound()[0], visible=False, color=color, + linewidth=linewidth, linestyle=linestyle) + linev.set_animated(True) + + self._graphCursor = lineh, linev + else: + if self._graphCursor: + lineh, linev = self._graphCursor + lineh.remove() + linev.remove() + self._graphCursor = tuple() + + # Active curve + + def setCurveColor(self, curve, color): + # Store Line2D and PathCollection + for artist in curve.get_children(): + if isinstance(artist, (Line2D, LineCollection)): + artist.set_color(color) + elif isinstance(artist, PathCollection): + artist.set_facecolors(color) + artist.set_edgecolors(color) + else: + _logger.warning( + 'setActiveCurve ignoring artist %s', str(artist)) + + # Misc. + + def getWidgetHandle(self): + return self.fig.canvas + + def _enableAxis(self, axis, flag=True): + """Show/hide Y axis + + :param str axis: Axis name: 'left' or 'right' + :param bool flag: Default, True + """ + assert axis in ('right', 'left') + axes = self.ax2 if axis == 'right' else self.ax + axes.get_yaxis().set_visible(flag) + + def replot(self): + """Do not perform rendering. + + Override in subclass to actually draw something. + """ + with self._plot._paintContext(): + self._replot() + + def _replot(self): + """Call from subclass :meth:`replot` to handle updates""" + # TODO images, markers? scatter plot? move in remove? + # Right Y axis only support curve for now + # Hide right Y axis if no line is present + self._dirtyLimits = False + if not self.ax2.lines: + self._enableAxis('right', False) + + def _drawOverlays(self): + """Draw overlays if any.""" + def condition(item): + return (item.isVisible() and + item._backendRenderer is not None and + item.isOverlay()) + + for item in self.getItemsFromBackToFront(condition=condition): + if (isinstance(item, items.YAxisMixIn) and + item.getYAxis() == 'right'): + axes = self.ax2 + else: + axes = self.ax + axes.draw_artist(item._backendRenderer) + + for item in self._graphCursor: + self.ax.draw_artist(item) + + def updateZOrder(self): + """Reorder all items with z order from 0 to 1""" + items = self.getItemsFromBackToFront( + lambda item: item.isVisible() and item._backendRenderer is not None) + count = len(items) + for index, item in enumerate(items): + if item.getZValue() < 0.5: + # Make sure matplotlib z order is below the grid (with z=0.5) + zorder = 0.5 * index / count + else: # Make sure matplotlib z order is above the grid (> 0.5) + zorder = 1. + index / count + if zorder != item._backendRenderer.get_zorder(): + item._backendRenderer.set_zorder(zorder) + + def saveGraph(self, fileName, fileFormat, dpi): + self.updateZOrder() + + # fileName can be also a StringIO or file instance + if dpi is not None: + self.fig.savefig(fileName, format=fileFormat, dpi=dpi) + else: + self.fig.savefig(fileName, format=fileFormat) + self._plot._setDirtyPlot() + + # Graph labels + + def setGraphTitle(self, title): + self.ax.set_title(title) + + def setGraphXLabel(self, label): + self.ax.set_xlabel(label) + + def setGraphYLabel(self, label, axis): + axes = self.ax if axis == 'left' else self.ax2 + axes.set_ylabel(label) + + # Graph limits + + def setLimits(self, xmin, xmax, ymin, ymax, y2min=None, y2max=None): + # Let matplotlib taking care of keep aspect ratio if any + self._dirtyLimits = True + self.ax.set_xlim(min(xmin, xmax), max(xmin, xmax)) + + if y2min is not None and y2max is not None: + if not self.isYAxisInverted(): + self.ax2.set_ylim(min(y2min, y2max), max(y2min, y2max)) + else: + self.ax2.set_ylim(max(y2min, y2max), min(y2min, y2max)) + + if not self.isYAxisInverted(): + self.ax.set_ylim(min(ymin, ymax), max(ymin, ymax)) + else: + self.ax.set_ylim(max(ymin, ymax), min(ymin, ymax)) + + self._updateMarkers() + + def getGraphXLimits(self): + if self._dirtyLimits and self.isKeepDataAspectRatio(): + self.ax.apply_aspect() + self.ax2.apply_aspect() + self._dirtyLimits = False + return self.ax.get_xbound() + + def setGraphXLimits(self, xmin, xmax): + self._dirtyLimits = True + self.ax.set_xlim(min(xmin, xmax), max(xmin, xmax)) + self._updateMarkers() + + def getGraphYLimits(self, axis): + assert axis in ('left', 'right') + ax = self.ax2 if axis == 'right' else self.ax + + if not ax.get_visible(): + return None + + if self._dirtyLimits and self.isKeepDataAspectRatio(): + self.ax.apply_aspect() + self.ax2.apply_aspect() + self._dirtyLimits = False + + return ax.get_ybound() + + def setGraphYLimits(self, ymin, ymax, axis): + ax = self.ax2 if axis == 'right' else self.ax + if ymax < ymin: + ymin, ymax = ymax, ymin + self._dirtyLimits = True + + if self.isKeepDataAspectRatio(): + # matplotlib keeps limits of shared axis when keeping aspect ratio + # So x limits are kept when changing y limits.... + # Change x limits first by taking into account aspect ratio + # and then change y limits.. so matplotlib does not need + # to make change (to y) to keep aspect ratio + xmin, xmax = ax.get_xbound() + curYMin, curYMax = ax.get_ybound() + + newXRange = (xmax - xmin) * (ymax - ymin) / (curYMax - curYMin) + xcenter = 0.5 * (xmin + xmax) + ax.set_xlim(xcenter - 0.5 * newXRange, xcenter + 0.5 * newXRange) + + if not self.isYAxisInverted(): + ax.set_ylim(ymin, ymax) + else: + ax.set_ylim(ymax, ymin) + + self._updateMarkers() + + # Graph axes + + def setXAxisTimeZone(self, tz): + super(BackendMatplotlib, self).setXAxisTimeZone(tz) + + # Make new formatter and locator with the time zone. + self.setXAxisTimeSeries(self.isXAxisTimeSeries()) + + def isXAxisTimeSeries(self): + return self._isXAxisTimeSeries + + def setXAxisTimeSeries(self, isTimeSeries): + self._isXAxisTimeSeries = isTimeSeries + if self._isXAxisTimeSeries: + # We can't use a matplotlib.dates.DateFormatter because it expects + # the data to be in datetimes. Silx works internally with + # timestamps (floats). + locator = NiceDateLocator(tz=self.getXAxisTimeZone()) + self.ax.xaxis.set_major_locator(locator) + self.ax.xaxis.set_major_formatter( + NiceAutoDateFormatter(locator, tz=self.getXAxisTimeZone())) + else: + try: + scalarFormatter = ScalarFormatter(useOffset=False) + except: + _logger.warning('Cannot disabled axes offsets in %s ' % + matplotlib.__version__) + scalarFormatter = ScalarFormatter() + self.ax.xaxis.set_major_formatter(scalarFormatter) + + def setXAxisLogarithmic(self, flag): + # Workaround for matplotlib 2.1.0 when one tries to set an axis + # to log scale with both limits <= 0 + # In this case a draw with positive limits is needed first + if flag and self._matplotlibVersion >= _parse_version('2.1.0'): + xlim = self.ax.get_xlim() + if xlim[0] <= 0 and xlim[1] <= 0: + self.ax.set_xlim(1, 10) + self.draw() + + self.ax2.set_xscale('log' if flag else 'linear') + self.ax.set_xscale('log' if flag else 'linear') + + def setYAxisLogarithmic(self, flag): + # Workaround for matplotlib 2.0 issue with negative bounds + # before switching to log scale + if flag and self._matplotlibVersion >= _parse_version('2.0.0'): + redraw = False + for axis, dataRangeIndex in ((self.ax, 1), (self.ax2, 2)): + ylim = axis.get_ylim() + if ylim[0] <= 0 or ylim[1] <= 0: + dataRange = self._plot.getDataRange()[dataRangeIndex] + if dataRange is None: + dataRange = 1, 100 # Fallback + axis.set_ylim(*dataRange) + redraw = True + if redraw: + self.draw() + + self.ax2.set_yscale('log' if flag else 'linear') + self.ax.set_yscale('log' if flag else 'linear') + + def setYAxisInverted(self, flag): + if self.ax.yaxis_inverted() != bool(flag): + self.ax.invert_yaxis() + self._updateMarkers() + + def isYAxisInverted(self): + return self.ax.yaxis_inverted() + + def isKeepDataAspectRatio(self): + return self.ax.get_aspect() in (1.0, 'equal') + + def setKeepDataAspectRatio(self, flag): + self.ax.set_aspect(1.0 if flag else 'auto') + self.ax2.set_aspect(1.0 if flag else 'auto') + + def setGraphGrid(self, which): + self.ax.grid(False, which='both') # Disable all grid first + if which is not None: + self.ax.grid(True, which=which) + + # Data <-> Pixel coordinates conversion + + def _getDevicePixelRatio(self) -> float: + """Compatibility wrapper for devicePixelRatioF""" + return 1. + + def _mplToQtPosition(self, x: float, y: float) -> Tuple[float, float]: + """Convert matplotlib "display" space coord to Qt widget logical pixel + """ + ratio = self._getDevicePixelRatio() + # Convert from matplotlib origin (bottom) to Qt origin (top) + # and apply device pixel ratio + return x / ratio, (self.fig.get_window_extent().height - y) / ratio + + def _qtToMplPosition(self, x: float, y: float) -> Tuple[float, float]: + """Convert Qt widget logical pixel to matplotlib "display" space coord + """ + ratio = self._getDevicePixelRatio() + # Apply device pixel ration and + # convert from Qt origin (top) to matplotlib origin (bottom) + return x * ratio, self.fig.get_window_extent().height - (y * ratio) + + def dataToPixel(self, x, y, axis): + ax = self.ax2 if axis == "right" else self.ax + displayPos = ax.transData.transform_point((x, y)).transpose() + return self._mplToQtPosition(*displayPos) + + def pixelToData(self, x, y, axis): + ax = self.ax2 if axis == "right" else self.ax + displayPos = self._qtToMplPosition(x, y) + return tuple(ax.transData.inverted().transform_point(displayPos)) + + def getPlotBoundsInPixels(self): + bbox = self.ax.get_window_extent() + # Warning this is not returning int... + ratio = self._getDevicePixelRatio() + return tuple(int(value / ratio) for value in ( + bbox.xmin, + self.fig.get_window_extent().height - bbox.ymax, + bbox.width, + bbox.height)) + + def setAxesMargins(self, left: float, top: float, right: float, bottom: float): + width, height = 1. - left - right, 1. - top - bottom + position = left, bottom, width, height + + # Toggle display of axes and viewbox rect + isFrameOn = position != (0., 0., 1., 1.) + self.ax.set_frame_on(isFrameOn) + self.ax2.set_frame_on(isFrameOn) + + self.ax.set_position(position) + self.ax2.set_position(position) + + self._synchronizeBackgroundColors() + self._synchronizeForegroundColors() + self._plot._setDirtyPlot() + + def _synchronizeBackgroundColors(self): + backgroundColor = self._plot.getBackgroundColor().getRgbF() + + dataBackgroundColor = self._plot.getDataBackgroundColor() + if dataBackgroundColor.isValid(): + dataBackgroundColor = dataBackgroundColor.getRgbF() + else: + dataBackgroundColor = backgroundColor + + if self.ax.get_frame_on(): + self.fig.patch.set_facecolor(backgroundColor) + if self._matplotlibVersion < _parse_version('2'): + self.ax.set_axis_bgcolor(dataBackgroundColor) + else: + self.ax.set_facecolor(dataBackgroundColor) + else: + self.fig.patch.set_facecolor(dataBackgroundColor) + + def _synchronizeForegroundColors(self): + foregroundColor = self._plot.getForegroundColor().getRgbF() + + gridColor = self._plot.getGridColor() + if gridColor.isValid(): + gridColor = gridColor.getRgbF() + else: + gridColor = foregroundColor + + for axes in (self.ax, self.ax2): + if axes.get_frame_on(): + axes.spines['bottom'].set_color(foregroundColor) + axes.spines['top'].set_color(foregroundColor) + axes.spines['right'].set_color(foregroundColor) + axes.spines['left'].set_color(foregroundColor) + axes.tick_params(axis='x', colors=foregroundColor) + axes.tick_params(axis='y', colors=foregroundColor) + axes.yaxis.label.set_color(foregroundColor) + axes.xaxis.label.set_color(foregroundColor) + axes.title.set_color(foregroundColor) + + for line in axes.get_xgridlines(): + line.set_color(gridColor) + + for line in axes.get_ygridlines(): + line.set_color(gridColor) + # axes.grid().set_markeredgecolor(gridColor) + + def setBackgroundColors(self, backgroundColor, dataBackgroundColor): + self._synchronizeBackgroundColors() + + def setForegroundColors(self, foregroundColor, gridColor): + self._synchronizeForegroundColors() + + +class BackendMatplotlibQt(FigureCanvasQTAgg, BackendMatplotlib): + """QWidget matplotlib backend using a QtAgg canvas. + + It adds fast overlay drawing and mouse event management. + """ + + _sigPostRedisplay = qt.Signal() + """Signal handling automatic asynchronous replot""" + + def __init__(self, plot, parent=None): + BackendMatplotlib.__init__(self, plot, parent) + FigureCanvasQTAgg.__init__(self, self.fig) + self.setParent(parent) + + self._limitsBeforeResize = None + + FigureCanvasQTAgg.setSizePolicy( + self, qt.QSizePolicy.Expanding, qt.QSizePolicy.Expanding) + FigureCanvasQTAgg.updateGeometry(self) + + # Make postRedisplay asynchronous using Qt signal + self._sigPostRedisplay.connect( + self.__deferredReplot, qt.Qt.QueuedConnection) + + self._picked = None + + self.mpl_connect('button_press_event', self._onMousePress) + self.mpl_connect('button_release_event', self._onMouseRelease) + self.mpl_connect('motion_notify_event', self._onMouseMove) + self.mpl_connect('scroll_event', self._onMouseWheel) + + def postRedisplay(self): + self._sigPostRedisplay.emit() + + def __deferredReplot(self): + # Since this is deferred, makes sure it is still needed + plot = self._plotRef() + if (plot is not None and + plot._getDirtyPlot() and + plot.getBackend() is self): + self.replot() + + def _getDevicePixelRatio(self) -> float: + """Compatibility wrapper for devicePixelRatioF""" + if hasattr(self, 'devicePixelRatioF'): + ratio = self.devicePixelRatioF() + else: # Qt < 5.6 compatibility + ratio = float(self.devicePixelRatio()) + # Safety net: avoid returning 0 + return ratio if ratio != 0. else 1. + + # Mouse event forwarding + + _MPL_TO_PLOT_BUTTONS = {1: 'left', 2: 'middle', 3: 'right'} + + def _onMousePress(self, event): + button = self._MPL_TO_PLOT_BUTTONS.get(event.button, None) + if button is not None: + x, y = self._mplToQtPosition(event.x, event.y) + self._plot.onMousePress(int(x), int(y), button) + + def _onMouseMove(self, event): + x, y = self._mplToQtPosition(event.x, event.y) + if self._graphCursor: + position = self._plot.pixelToData( + x, y, axis='left', check=True) + lineh, linev = self._graphCursor + if position is not None: + linev.set_visible(True) + linev.set_xdata((position[0], position[0])) + lineh.set_visible(True) + lineh.set_ydata((position[1], position[1])) + self._plot._setDirtyPlot(overlayOnly=True) + elif lineh.get_visible(): + lineh.set_visible(False) + linev.set_visible(False) + self._plot._setDirtyPlot(overlayOnly=True) + # onMouseMove must trigger replot if dirty flag is raised + + self._plot.onMouseMove(int(x), int(y)) + + def _onMouseRelease(self, event): + button = self._MPL_TO_PLOT_BUTTONS.get(event.button, None) + if button is not None: + x, y = self._mplToQtPosition(event.x, event.y) + self._plot.onMouseRelease(int(x), int(y), button) + + def _onMouseWheel(self, event): + x, y = self._mplToQtPosition(event.x, event.y) + self._plot.onMouseWheel(int(x), int(y), event.step) + + def leaveEvent(self, event): + """QWidget event handler""" + try: + plot = self._plot + except RuntimeError: + pass + else: + plot.onMouseLeaveWidget() + + # picking + + def pickItem(self, x, y, item): + xDisplay, yDisplay = self._qtToMplPosition(x, y) + mouseEvent = MouseEvent( + 'button_press_event', self, int(xDisplay), int(yDisplay)) + # Override axes and data position with the axes + mouseEvent.inaxes = item.axes + mouseEvent.xdata, mouseEvent.ydata = self.pixelToData( + x, y, axis='left' if item.axes is self.ax else 'right') + picked, info = item.contains(mouseEvent) + + if not picked: + return None + + elif isinstance(item, TriMesh): + # Convert selected triangle to data point indices + triangulation = item._triangulation + indices = triangulation.get_masked_triangles()[info['ind'][0]] + + # Sort picked triangle points by distance to mouse + # from furthest to closest to put closest point last + # This is to be somewhat consistent with last scatter point + # being the top one. + xdata, ydata = self.pixelToData(x, y, axis='left') + dists = ((triangulation.x[indices] - xdata) ** 2 + + (triangulation.y[indices] - ydata) ** 2) + return indices[numpy.flip(numpy.argsort(dists), axis=0)] + + else: # Returns indices if any + return info.get('ind', ()) + + # replot control + + def resizeEvent(self, event): + # Store current limits + self._limitsBeforeResize = ( + self.ax.get_xbound(), self.ax.get_ybound(), self.ax2.get_ybound()) + + FigureCanvasQTAgg.resizeEvent(self, event) + if self.isKeepDataAspectRatio() or self._hasOverlays(): + # This is needed with matplotlib 1.5.x and 2.0.x + self._plot._setDirtyPlot() + + def draw(self): + """Overload draw + + It performs a full redraw (including overlays) of the plot. + It also resets background and emit limits changed signal. + + This is directly called by matplotlib for widget resize. + """ + self.updateZOrder() + + # Starting with mpl 2.1.0, toggling autoscale raises a ValueError + # in some situations. See #1081, #1136, #1163, + if self._matplotlibVersion >= _parse_version("2.0.0"): + try: + FigureCanvasQTAgg.draw(self) + except ValueError as err: + _logger.debug( + "ValueError caught while calling FigureCanvasQTAgg.draw: " + "'%s'", err) + else: + FigureCanvasQTAgg.draw(self) + + if self._hasOverlays(): + # Save background + self._background = self.copy_from_bbox(self.fig.bbox) + else: + self._background = None # Reset background + + # Check if limits changed due to a resize of the widget + if self._limitsBeforeResize is not None: + xLimits, yLimits, yRightLimits = self._limitsBeforeResize + self._limitsBeforeResize = None + + if (xLimits != self.ax.get_xbound() or + yLimits != self.ax.get_ybound()): + self._updateMarkers() + + if xLimits != self.ax.get_xbound(): + self._plot.getXAxis()._emitLimitsChanged() + if yLimits != self.ax.get_ybound(): + self._plot.getYAxis(axis='left')._emitLimitsChanged() + if yRightLimits != self.ax2.get_ybound(): + self._plot.getYAxis(axis='right')._emitLimitsChanged() + + self._drawOverlays() + + def replot(self): + with self._plot._paintContext(): + BackendMatplotlib._replot(self) + + dirtyFlag = self._plot._getDirtyPlot() + + if dirtyFlag == 'overlay': + # Only redraw overlays using fast rendering path + if self._background is None: + self._background = self.copy_from_bbox(self.fig.bbox) + self.restore_region(self._background) + self._drawOverlays() + self.blit(self.fig.bbox) + + elif dirtyFlag: # Need full redraw + self.draw() + + # Workaround issue of rendering overlays with some matplotlib versions + if (_parse_version('1.5') <= self._matplotlibVersion < _parse_version('2.1') and + not hasattr(self, '_firstReplot')): + self._firstReplot = False + if self._hasOverlays(): + qt.QTimer.singleShot(0, self.draw) # Request async draw + + # cursor + + _QT_CURSORS = { + BackendBase.CURSOR_DEFAULT: qt.Qt.ArrowCursor, + BackendBase.CURSOR_POINTING: qt.Qt.PointingHandCursor, + BackendBase.CURSOR_SIZE_HOR: qt.Qt.SizeHorCursor, + BackendBase.CURSOR_SIZE_VER: qt.Qt.SizeVerCursor, + BackendBase.CURSOR_SIZE_ALL: qt.Qt.SizeAllCursor, + } + + def setGraphCursorShape(self, cursor): + if cursor is None: + FigureCanvasQTAgg.unsetCursor(self) + else: + cursor = self._QT_CURSORS[cursor] + FigureCanvasQTAgg.setCursor(self, qt.QCursor(cursor)) diff --git a/src/silx/gui/plot/backends/BackendOpenGL.py b/src/silx/gui/plot/backends/BackendOpenGL.py new file mode 100755 index 0000000..f1a12af --- /dev/null +++ b/src/silx/gui/plot/backends/BackendOpenGL.py @@ -0,0 +1,1420 @@ +# coding: utf-8 +# /*########################################################################## +# +# Copyright (c) 2014-2021 European Synchrotron Radiation Facility +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. +# +# ############################################################################*/ +"""OpenGL Plot backend.""" + +from __future__ import division + +__authors__ = ["T. Vincent"] +__license__ = "MIT" +__date__ = "21/12/2018" + +import logging +import weakref + +import numpy + +from .. import items +from .._utils import FLOAT32_MINPOS +from . import BackendBase +from ... import colors +from ... import qt + +from ..._glutils import gl +from ... import _glutils as glu +from . import glutils +from .glutils.PlotImageFile import saveImageToFile + +_logger = logging.getLogger(__name__) + + +# TODO idea: BackendQtMixIn class to share code between mpl and gl +# TODO check if OpenGL is available +# TODO make an off-screen mesa backend + +# Content ##################################################################### + +class _ShapeItem(dict): + def __init__(self, x, y, shape, color, fill, overlay, + linestyle, linewidth, linebgcolor): + super(_ShapeItem, self).__init__() + + if shape not in ('polygon', 'rectangle', 'line', + 'vline', 'hline', 'polylines'): + raise NotImplementedError("Unsupported shape {0}".format(shape)) + + x = numpy.array(x, copy=False) + y = numpy.array(y, copy=False) + + if shape == 'rectangle': + xMin, xMax = x + x = numpy.array((xMin, xMin, xMax, xMax)) + yMin, yMax = y + y = numpy.array((yMin, yMax, yMax, yMin)) + + # Ignore fill for polylines to mimic matplotlib + fill = fill if shape != 'polylines' else False + + self.update({ + 'shape': shape, + 'color': colors.rgba(color), + 'fill': 'hatch' if fill else None, + 'x': x, + 'y': y, + 'linestyle': linestyle, + 'linewidth': linewidth, + 'linebgcolor': linebgcolor, + }) + + +class _MarkerItem(dict): + def __init__(self, x, y, text, color, + symbol, linestyle, linewidth, constraint, yaxis): + super(_MarkerItem, self).__init__() + + if symbol is None: + symbol = '+' + + # Apply constraint to provided position + isConstraint = (constraint is not None and + x is not None and y is not None) + if isConstraint: + x, y = constraint(x, y) + + self.update({ + 'x': x, + 'y': y, + 'text': text, + 'color': colors.rgba(color), + 'constraint': constraint if isConstraint else None, + 'symbol': symbol, + 'linestyle': linestyle, + 'linewidth': linewidth, + 'yaxis': yaxis, + }) + + +# shaders ##################################################################### + +_baseVertShd = """ + attribute vec2 position; + uniform mat4 matrix; + uniform bvec2 isLog; + + const float oneOverLog10 = 0.43429448190325176; + + void main(void) { + vec2 posTransformed = position; + if (isLog.x) { + posTransformed.x = oneOverLog10 * log(position.x); + } + if (isLog.y) { + posTransformed.y = oneOverLog10 * log(position.y); + } + gl_Position = matrix * vec4(posTransformed, 0.0, 1.0); + } + """ + +_baseFragShd = """ + uniform vec4 color; + uniform int hatchStep; + uniform float tickLen; + + void main(void) { + if (tickLen != 0.) { + if (mod((gl_FragCoord.x + gl_FragCoord.y) / tickLen, 2.) < 1.) { + gl_FragColor = color; + } else { + discard; + } + } else if (hatchStep == 0 || + mod(gl_FragCoord.x - gl_FragCoord.y, float(hatchStep)) == 0.) { + gl_FragColor = color; + } else { + discard; + } + } + """ + +_texVertShd = """ + attribute vec2 position; + attribute vec2 texCoords; + uniform mat4 matrix; + + varying vec2 coords; + + void main(void) { + gl_Position = matrix * vec4(position, 0.0, 1.0); + coords = texCoords; + } + """ + +_texFragShd = """ + uniform sampler2D tex; + + varying vec2 coords; + + void main(void) { + gl_FragColor = texture2D(tex, coords); + gl_FragColor.a = 1.0; + } + """ + +# BackendOpenGL ############################################################### + + +class BackendOpenGL(BackendBase.BackendBase, glu.OpenGLWidget): + """OpenGL-based Plot backend. + + WARNINGS: + Unless stated otherwise, this API is NOT thread-safe and MUST be + called from the main thread. + When numpy arrays are passed as arguments to the API (through + :func:`addCurve` and :func:`addImage`), they are copied only if + required. + So, the caller should not modify these arrays afterwards. + """ + + def __init__(self, plot, parent=None, f=qt.Qt.WindowFlags()): + glu.OpenGLWidget.__init__(self, parent, + alphaBufferSize=8, + depthBufferSize=0, + stencilBufferSize=0, + version=(2, 1), + f=f) + BackendBase.BackendBase.__init__(self, plot, parent) + + self._backgroundColor = 1., 1., 1., 1. + self._dataBackgroundColor = 1., 1., 1., 1. + + self.matScreenProj = glutils.mat4Identity() + + self._progBase = glu.Program( + _baseVertShd, _baseFragShd, attrib0='position') + self._progTex = glu.Program( + _texVertShd, _texFragShd, attrib0='position') + self._plotFBOs = weakref.WeakKeyDictionary() + + self._keepDataAspectRatio = False + + self._crosshairCursor = None + self._mousePosInPixels = None + + self._glGarbageCollector = [] + + self._plotFrame = glutils.GLPlotFrame2D( + foregroundColor=(0., 0., 0., 1.), + gridColor=(.7, .7, .7, 1.), + marginRatios=(.15, .1, .1, .15)) + self._plotFrame.size = ( # Init size with size int + int(self.getDevicePixelRatio() * 640), + int(self.getDevicePixelRatio() * 480)) + + self.setAutoFillBackground(False) + self.setMouseTracking(True) + + # QWidget + + _MOUSE_BTNS = {1: 'left', 2: 'right', 4: 'middle'} + + def sizeHint(self): + return qt.QSize(8 * 80, 6 * 80) # Mimic MatplotlibBackend + + def mousePressEvent(self, event): + if event.button() not in self._MOUSE_BTNS: + return super(BackendOpenGL, self).mousePressEvent(event) + self._plot.onMousePress( + event.x(), event.y(), self._MOUSE_BTNS[event.button()]) + event.accept() + + def mouseMoveEvent(self, event): + qtPos = event.x(), event.y() + + previousMousePosInPixels = self._mousePosInPixels + if qtPos == self._mouseInPlotArea(*qtPos): + devicePixelRatio = self.getDevicePixelRatio() + devicePos = qtPos[0] * devicePixelRatio, qtPos[1] * devicePixelRatio + self._mousePosInPixels = devicePos # Mouse in plot area + else: + self._mousePosInPixels = None # Mouse outside plot area + + if (self._crosshairCursor is not None and + previousMousePosInPixels != self._mousePosInPixels): + # Avoid replot when cursor remains outside plot area + self._plot._setDirtyPlot(overlayOnly=True) + + self._plot.onMouseMove(*qtPos) + event.accept() + + def mouseReleaseEvent(self, event): + if event.button() not in self._MOUSE_BTNS: + return super(BackendOpenGL, self).mouseReleaseEvent(event) + self._plot.onMouseRelease( + event.x(), event.y(), self._MOUSE_BTNS[event.button()]) + event.accept() + + def wheelEvent(self, event): + delta = event.angleDelta().y() + angleInDegrees = delta / 8. + if qt.BINDING == "PySide6": + x, y = event.position().x(), event.position().y() + else: + x, y = event.x(), event.y() + self._plot.onMouseWheel(x, y, angleInDegrees) + event.accept() + + def leaveEvent(self, _): + self._plot.onMouseLeaveWidget() + + # OpenGLWidget API + + def initializeGL(self): + gl.testGL() + + gl.glClearStencil(0) + + gl.glEnable(gl.GL_BLEND) + # gl.glBlendFunc(gl.GL_SRC_ALPHA, gl.GL_ONE_MINUS_SRC_ALPHA) + gl.glBlendFuncSeparate(gl.GL_SRC_ALPHA, + gl.GL_ONE_MINUS_SRC_ALPHA, + gl.GL_ONE, + gl.GL_ONE) + + # For lines + gl.glHint(gl.GL_LINE_SMOOTH_HINT, gl.GL_NICEST) + + # For points + gl.glEnable(gl.GL_VERTEX_PROGRAM_POINT_SIZE) # OpenGL 2 + gl.glEnable(gl.GL_POINT_SPRITE) # OpenGL 2 + # gl.glEnable(gl.GL_PROGRAM_POINT_SIZE) + + def _paintDirectGL(self): + self._renderPlotAreaGL() + self._plotFrame.render() + self._renderOverlayGL() + + def _paintFBOGL(self): + context = glu.Context.getCurrent() + plotFBOTex = self._plotFBOs.get(context) + if (self._plot._getDirtyPlot() or self._plotFrame.isDirty or + plotFBOTex is None): + self._plotVertices = ( + # Vertex coordinates + numpy.array(((-1., -1.), (1., -1.), (-1., 1.), (1., 1.)), + dtype=numpy.float32), + # Texture coordinates + numpy.array(((0., 0.), (1., 0.), (0., 1.), (1., 1.)), + dtype=numpy.float32)) + if plotFBOTex is None or \ + plotFBOTex.shape[1] != self._plotFrame.size[0] or \ + plotFBOTex.shape[0] != self._plotFrame.size[1]: + if plotFBOTex is not None: + plotFBOTex.discard() + plotFBOTex = glu.FramebufferTexture( + gl.GL_RGBA, + shape=(self._plotFrame.size[1], + self._plotFrame.size[0]), + minFilter=gl.GL_NEAREST, + magFilter=gl.GL_NEAREST, + wrap=(gl.GL_CLAMP_TO_EDGE, + gl.GL_CLAMP_TO_EDGE)) + self._plotFBOs[context] = plotFBOTex + + with plotFBOTex: + gl.glClearColor(*self._backgroundColor) + gl.glClear(gl.GL_COLOR_BUFFER_BIT | gl.GL_STENCIL_BUFFER_BIT) + self._renderPlotAreaGL() + self._plotFrame.render() + + # Render plot in screen coords + gl.glViewport(0, 0, self._plotFrame.size[0], self._plotFrame.size[1]) + + self._progTex.use() + texUnit = 0 + + gl.glUniform1i(self._progTex.uniforms['tex'], texUnit) + gl.glUniformMatrix4fv(self._progTex.uniforms['matrix'], 1, gl.GL_TRUE, + glutils.mat4Identity().astype(numpy.float32)) + + gl.glEnableVertexAttribArray(self._progTex.attributes['position']) + gl.glVertexAttribPointer(self._progTex.attributes['position'], + 2, + gl.GL_FLOAT, + gl.GL_FALSE, + 0, + self._plotVertices[0]) + + gl.glEnableVertexAttribArray(self._progTex.attributes['texCoords']) + gl.glVertexAttribPointer(self._progTex.attributes['texCoords'], + 2, + gl.GL_FLOAT, + gl.GL_FALSE, + 0, + self._plotVertices[1]) + + with plotFBOTex.texture: + gl.glDrawArrays(gl.GL_TRIANGLE_STRIP, 0, len(self._plotVertices[0])) + + self._renderOverlayGL() + + def paintGL(self): + plot = self._plotRef() + if plot is None: + return + + with plot._paintContext(): + with glu.Context.current(self.context()): + # Release OpenGL resources + for item in self._glGarbageCollector: + item.discard() + self._glGarbageCollector = [] + + gl.glClearColor(*self._backgroundColor) + gl.glClear(gl.GL_COLOR_BUFFER_BIT | gl.GL_STENCIL_BUFFER_BIT) + + # Check if window is large enough + if self._plotFrame.plotSize <= (2, 2): + return + + # Sync plot frame with window + self._plotFrame.devicePixelRatio = self.getDevicePixelRatio() + # self._paintDirectGL() + self._paintFBOGL() + + def _renderItems(self, overlay=False): + """Render items according to :class:`PlotWidget` order + + Note: Scissor test should already be set. + + :param bool overlay: + False (the default) to render item that are not overlays. + True to render items that are overlays. + """ + # Values that are often used + plotWidth, plotHeight = self._plotFrame.plotSize + isXLog = self._plotFrame.xAxis.isLog + isYLog = self._plotFrame.yAxis.isLog + isYInverted = self._plotFrame.isYAxisInverted + + # Used by marker rendering + labels = [] + pixelOffset = 3 + + context = glutils.RenderContext( + isXLog=isXLog, isYLog=isYLog, dpi=self.getDotsPerInch()) + + for plotItem in self.getItemsFromBackToFront( + condition=lambda i: i.isVisible() and i.isOverlay() == overlay): + if plotItem._backendRenderer is None: + continue + + item = plotItem._backendRenderer + + if isinstance(item, glutils.GLPlotItem): # Render data items + gl.glViewport(self._plotFrame.margins.left, + self._plotFrame.margins.bottom, + plotWidth, plotHeight) + # Set matrix + if item.yaxis == 'right': + context.matrix = self._plotFrame.transformedDataY2ProjMat + else: + context.matrix = self._plotFrame.transformedDataProjMat + item.render(context) + + elif isinstance(item, _ShapeItem): # Render shape items + gl.glViewport(0, 0, self._plotFrame.size[0], self._plotFrame.size[1]) + + if ((isXLog and numpy.min(item['x']) < FLOAT32_MINPOS) or + (isYLog and numpy.min(item['y']) < FLOAT32_MINPOS)): + # Ignore items <= 0. on log axes + continue + + if item['shape'] == 'hline': + width = self._plotFrame.size[0] + _, yPixel = self._plotFrame.dataToPixel( + 0.5 * sum(self._plotFrame.dataRanges[0]), + item['y'], + axis='left') + subShapes = [numpy.array(((0., yPixel), (width, yPixel)), + dtype=numpy.float32)] + + elif item['shape'] == 'vline': + xPixel, _ = self._plotFrame.dataToPixel( + item['x'], + 0.5 * sum(self._plotFrame.dataRanges[1]), + axis='left') + height = self._plotFrame.size[1] + subShapes = [numpy.array(((xPixel, 0), (xPixel, height)), + dtype=numpy.float32)] + + else: + # Split sub-shapes at not finite values + splits = numpy.nonzero(numpy.logical_not(numpy.logical_and( + numpy.isfinite(item['x']), numpy.isfinite(item['y']))))[0] + splits = numpy.concatenate(([-1], splits, [len(item['x'])])) + subShapes = [] + for begin, end in zip(splits[:-1] + 1, splits[1:]): + if end > begin: + subShapes.append(numpy.array([ + self._plotFrame.dataToPixel(x, y, axis='left') + for (x, y) in zip(item['x'][begin:end], item['y'][begin:end])])) + + for points in subShapes: # Draw each sub-shape + # Draw the fill + if (item['fill'] is not None and + item['shape'] not in ('hline', 'vline')): + self._progBase.use() + gl.glUniformMatrix4fv( + self._progBase.uniforms['matrix'], 1, gl.GL_TRUE, + self.matScreenProj.astype(numpy.float32)) + gl.glUniform2i(self._progBase.uniforms['isLog'], False, False) + gl.glUniform1f(self._progBase.uniforms['tickLen'], 0.) + + shape2D = glutils.FilledShape2D( + points, style=item['fill'], color=item['color']) + shape2D.render( + posAttrib=self._progBase.attributes['position'], + colorUnif=self._progBase.uniforms['color'], + hatchStepUnif=self._progBase.uniforms['hatchStep']) + + # Draw the stroke + if item['linestyle'] not in ('', ' ', None): + if item['shape'] != 'polylines': + # close the polyline + points = numpy.append(points, + numpy.atleast_2d(points[0]), axis=0) + + lines = glutils.GLLines2D( + points[:, 0], points[:, 1], + style=item['linestyle'], + color=item['color'], + dash2ndColor=item['linebgcolor'], + width=item['linewidth']) + context.matrix = self.matScreenProj + lines.render(context) + + elif isinstance(item, _MarkerItem): + gl.glViewport(0, 0, self._plotFrame.size[0], self._plotFrame.size[1]) + + xCoord, yCoord, yAxis = item['x'], item['y'], item['yaxis'] + + if ((isXLog and xCoord is not None and xCoord <= 0) or + (isYLog and yCoord is not None and yCoord <= 0)): + # Do not render markers with negative coords on log axis + continue + + color = item['color'] + intensity = color[0] * 0.299 + color[1] * 0.587 + color[2] * 0.114 + bgColor = (1., 1., 1., 0.5) if intensity <= 0.5 else (0., 0., 0., 0.5) + if xCoord is None or yCoord is None: + if xCoord is None: # Horizontal line in data space + pixelPos = self._plotFrame.dataToPixel( + 0.5 * sum(self._plotFrame.dataRanges[0]), + yCoord, + axis=yAxis) + + if item['text'] is not None: + x = self._plotFrame.size[0] - \ + self._plotFrame.margins.right - pixelOffset + y = pixelPos[1] - pixelOffset + label = glutils.Text2D( + item['text'], x, y, + color=item['color'], + bgColor=bgColor, + align=glutils.RIGHT, + valign=glutils.BOTTOM, + devicePixelRatio=self.getDevicePixelRatio()) + labels.append(label) + + width = self._plotFrame.size[0] + lines = glutils.GLLines2D( + (0, width), (pixelPos[1], pixelPos[1]), + style=item['linestyle'], + color=item['color'], + width=item['linewidth']) + context.matrix = self.matScreenProj + lines.render(context) + + else: # yCoord is None: vertical line in data space + yRange = self._plotFrame.dataRanges[1 if yAxis == 'left' else 2] + pixelPos = self._plotFrame.dataToPixel( + xCoord, 0.5 * sum(yRange), axis=yAxis) + + if item['text'] is not None: + x = pixelPos[0] + pixelOffset + y = self._plotFrame.margins.top + pixelOffset + label = glutils.Text2D( + item['text'], x, y, + color=item['color'], + bgColor=bgColor, + align=glutils.LEFT, + valign=glutils.TOP, + devicePixelRatio=self.getDevicePixelRatio()) + labels.append(label) + + height = self._plotFrame.size[1] + lines = glutils.GLLines2D( + (pixelPos[0], pixelPos[0]), (0, height), + style=item['linestyle'], + color=item['color'], + width=item['linewidth']) + context.matrix = self.matScreenProj + lines.render(context) + + else: + xmin, xmax = self._plot.getXAxis().getLimits() + ymin, ymax = self._plot.getYAxis(axis=yAxis).getLimits() + if not xmin < xCoord < xmax or not ymin < yCoord < ymax: + # Do not render markers outside visible plot area + continue + pixelPos = self._plotFrame.dataToPixel( + xCoord, yCoord, axis=yAxis) + + if isYInverted: + valign = glutils.BOTTOM + vPixelOffset = -pixelOffset + else: + valign = glutils.TOP + vPixelOffset = pixelOffset + + if item['text'] is not None: + x = pixelPos[0] + pixelOffset + y = pixelPos[1] + vPixelOffset + label = glutils.Text2D( + item['text'], x, y, + color=item['color'], + bgColor=bgColor, + align=glutils.LEFT, + valign=valign, + devicePixelRatio=self.getDevicePixelRatio()) + labels.append(label) + + # For now simple implementation: using a curve for each marker + # Should pack all markers to a single set of points + markerCurve = glutils.GLPlotCurve2D( + numpy.array((pixelPos[0],), dtype=numpy.float64), + numpy.array((pixelPos[1],), dtype=numpy.float64), + marker=item['symbol'], + markerColor=item['color'], + markerSize=11) + + context = glutils.RenderContext( + matrix=self.matScreenProj, + isXLog=False, + isYLog=False, + dpi=self.getDotsPerInch()) + markerCurve.render(context) + markerCurve.discard() + + else: + _logger.error('Unsupported item: %s', str(item)) + continue + + # Render marker labels + gl.glViewport(0, 0, self._plotFrame.size[0], self._plotFrame.size[1]) + for label in labels: + label.render(self.matScreenProj) + + def _renderOverlayGL(self): + """Render overlay layer: overlay items and crosshair.""" + plotWidth, plotHeight = self._plotFrame.plotSize + + # Scissor to plot area + gl.glScissor(self._plotFrame.margins.left, + self._plotFrame.margins.bottom, + plotWidth, plotHeight) + gl.glEnable(gl.GL_SCISSOR_TEST) + + self._renderItems(overlay=True) + + # Render crosshair cursor + if self._crosshairCursor is not None and self._mousePosInPixels is not None: + self._progBase.use() + gl.glUniform2i(self._progBase.uniforms['isLog'], False, False) + gl.glUniform1f(self._progBase.uniforms['tickLen'], 0.) + posAttrib = self._progBase.attributes['position'] + matrixUnif = self._progBase.uniforms['matrix'] + colorUnif = self._progBase.uniforms['color'] + hatchStepUnif = self._progBase.uniforms['hatchStep'] + + gl.glViewport(0, 0, self._plotFrame.size[0], self._plotFrame.size[1]) + + gl.glUniformMatrix4fv(matrixUnif, 1, gl.GL_TRUE, + self.matScreenProj.astype(numpy.float32)) + + color, lineWidth = self._crosshairCursor + gl.glUniform4f(colorUnif, *color) + gl.glUniform1i(hatchStepUnif, 0) + + xPixel, yPixel = self._mousePosInPixels + xPixel, yPixel = xPixel + 0.5, yPixel + 0.5 + vertices = numpy.array(((0., yPixel), + (self._plotFrame.size[0], yPixel), + (xPixel, 0.), + (xPixel, self._plotFrame.size[1])), + dtype=numpy.float32) + + gl.glEnableVertexAttribArray(posAttrib) + gl.glVertexAttribPointer(posAttrib, + 2, + gl.GL_FLOAT, + gl.GL_FALSE, + 0, vertices) + gl.glLineWidth(lineWidth) + gl.glDrawArrays(gl.GL_LINES, 0, len(vertices)) + + gl.glDisable(gl.GL_SCISSOR_TEST) + + def _renderPlotAreaGL(self): + """Render base layer of plot area. + + It renders the background, grid and items except overlays + """ + plotWidth, plotHeight = self._plotFrame.plotSize + + gl.glScissor(self._plotFrame.margins.left, + self._plotFrame.margins.bottom, + plotWidth, plotHeight) + gl.glEnable(gl.GL_SCISSOR_TEST) + + if self._dataBackgroundColor != self._backgroundColor: + gl.glClearColor(*self._dataBackgroundColor) + gl.glClear(gl.GL_COLOR_BUFFER_BIT) + + self._plotFrame.renderGrid() + + # Matrix + trBounds = self._plotFrame.transformedDataRanges + if trBounds.x[0] != trBounds.x[1] and trBounds.y[0] != trBounds.y[1]: + # Do rendering of items + self._renderItems(overlay=False) + + gl.glDisable(gl.GL_SCISSOR_TEST) + + def resizeGL(self, width, height): + if width == 0 or height == 0: # Do not resize + return + + self._plotFrame.size = ( + int(self.getDevicePixelRatio() * width), + int(self.getDevicePixelRatio() * height)) + + self.matScreenProj = glutils.mat4Ortho( + 0, self._plotFrame.size[0], + self._plotFrame.size[1], 0, + 1, -1) + + # Store current ranges + previousXRange = self.getGraphXLimits() + previousYRange = self.getGraphYLimits(axis='left') + previousYRightRange = self.getGraphYLimits(axis='right') + + (xMin, xMax), (yMin, yMax), (y2Min, y2Max) = \ + self._plotFrame.dataRanges + self.setLimits(xMin, xMax, yMin, yMax, y2Min, y2Max) + + # If plot range has changed, then emit signal + if previousXRange != self.getGraphXLimits(): + self._plot.getXAxis()._emitLimitsChanged() + if previousYRange != self.getGraphYLimits(axis='left'): + self._plot.getYAxis(axis='left')._emitLimitsChanged() + if previousYRightRange != self.getGraphYLimits(axis='right'): + self._plot.getYAxis(axis='right')._emitLimitsChanged() + + # Add methods + + @staticmethod + def _castArrayTo(v): + """Returns best floating type to cast the array to. + + :param numpy.ndarray v: Array to cast + :rtype: numpy.dtype + :raise ValueError: If dtype is not supported + """ + if numpy.issubdtype(v.dtype, numpy.floating): + return numpy.float32 if v.itemsize <= 4 else numpy.float64 + elif numpy.issubdtype(v.dtype, numpy.integer): + return numpy.float32 if v.itemsize <= 2 else numpy.float64 + else: + raise ValueError('Unsupported data type') + + def addCurve(self, x, y, + color, symbol, linewidth, linestyle, + yaxis, + xerror, yerror, + fill, alpha, symbolsize, baseline): + for parameter in (x, y, color, symbol, linewidth, linestyle, + yaxis, fill, symbolsize): + assert parameter is not None + assert yaxis in ('left', 'right') + + # Convert input data + x = numpy.array(x, copy=False) + y = numpy.array(y, copy=False) + + # Check if float32 is enough + if (self._castArrayTo(x) is numpy.float32 and + self._castArrayTo(y) is numpy.float32): + dtype = numpy.float32 + else: + dtype = numpy.float64 + + x = numpy.array(x, dtype=dtype, copy=False, order='C') + y = numpy.array(y, dtype=dtype, copy=False, order='C') + + # Convert errors to float32 + if xerror is not None: + xerror = numpy.array( + xerror, dtype=numpy.float32, copy=False, order='C') + if yerror is not None: + yerror = numpy.array( + yerror, dtype=numpy.float32, copy=False, order='C') + + # Handle axes log scale: convert data + + if self._plotFrame.xAxis.isLog: + logX = numpy.log10(x) + + if xerror is not None: + # Transform xerror so that + # log10(x) +/- xerror' = log10(x +/- xerror) + if hasattr(xerror, 'shape') and len(xerror.shape) == 2: + xErrorMinus, xErrorPlus = xerror[0], xerror[1] + else: + xErrorMinus, xErrorPlus = xerror, xerror + with numpy.errstate(divide='ignore', invalid='ignore'): + # Ignore divide by zero, invalid value encountered in log10 + xErrorMinus = logX - numpy.log10(x - xErrorMinus) + xErrorPlus = numpy.log10(x + xErrorPlus) - logX + xerror = numpy.array((xErrorMinus, xErrorPlus), + dtype=numpy.float32) + + x = logX + + isYLog = (yaxis == 'left' and self._plotFrame.yAxis.isLog) or ( + yaxis == 'right' and self._plotFrame.y2Axis.isLog) + + if isYLog: + logY = numpy.log10(y) + + if yerror is not None: + # Transform yerror so that + # log10(y) +/- yerror' = log10(y +/- yerror) + if hasattr(yerror, 'shape') and len(yerror.shape) == 2: + yErrorMinus, yErrorPlus = yerror[0], yerror[1] + else: + yErrorMinus, yErrorPlus = yerror, yerror + with numpy.errstate(divide='ignore', invalid='ignore'): + # Ignore divide by zero, invalid value encountered in log10 + yErrorMinus = logY - numpy.log10(y - yErrorMinus) + yErrorPlus = numpy.log10(y + yErrorPlus) - logY + yerror = numpy.array((yErrorMinus, yErrorPlus), + dtype=numpy.float32) + + y = logY + + # TODO check if need more filtering of error (e.g., clip to positive) + + # TODO check and improve this + if (len(color) == 4 and + type(color[3]) in [type(1), numpy.uint8, numpy.int8]): + color = numpy.array(color, dtype=numpy.float32) / 255. + + if isinstance(color, numpy.ndarray) and color.ndim == 2: + colorArray = color + color = None + else: + colorArray = None + color = colors.rgba(color) + + if alpha < 1.: # Apply image transparency + if colorArray is not None and colorArray.shape[1] == 4: + # multiply alpha channel + colorArray[:, 3] = colorArray[:, 3] * alpha + if color is not None: + color = color[0], color[1], color[2], color[3] * alpha + + fillColor = None + if fill is True: + fillColor = color + curve = glutils.GLPlotCurve2D( + x, y, colorArray, + xError=xerror, + yError=yerror, + lineStyle=linestyle, + lineColor=color, + lineWidth=linewidth, + marker=symbol, + markerColor=color, + markerSize=symbolsize, + fillColor=fillColor, + baseline=baseline, + isYLog=isYLog) + curve.yaxis = 'left' if yaxis is None else yaxis + + if yaxis == "right": + self._plotFrame.isY2Axis = True + + return curve + + def addImage(self, data, + origin, scale, + colormap, alpha): + for parameter in (data, origin, scale): + assert parameter is not None + + if data.ndim == 2: + # Ensure array is contiguous and eventually convert its type + dtypes = [dtype for dtype in ( + numpy.float32, numpy.float16, numpy.uint8, numpy.uint16) + if glu.isSupportedGLType(dtype)] + if data.dtype in dtypes: + data = numpy.array(data, copy=False, order='C') + else: + _logger.info( + 'addImage: Convert %s data to float32', str(data.dtype)) + data = numpy.array(data, dtype=numpy.float32, order='C') + + normalization = colormap.getNormalization() + if normalization in glutils.GLPlotColormap.SUPPORTED_NORMALIZATIONS: + # Fast path applying colormap on the GPU + cmapRange = colormap.getColormapRange(data=data) + colormapLut = colormap.getNColors(nbColors=256) + gamma = colormap.getGammaNormalizationParameter() + nanColor = colors.rgba(colormap.getNaNColor()) + + image = glutils.GLPlotColormap( + data, + origin, + scale, + colormapLut, + normalization, + gamma, + cmapRange, + alpha, + nanColor) + + else: # Fallback applying colormap on CPU + rgba = colormap.applyToData(data) + image = glutils.GLPlotRGBAImage(rgba, origin, scale, alpha) + + elif len(data.shape) == 3: + # For RGB, RGBA data + assert data.shape[2] in (3, 4) + + if numpy.issubdtype(data.dtype, numpy.floating): + data = numpy.array(data, dtype=numpy.float32, copy=False) + elif data.dtype in [numpy.uint8, numpy.uint16]: + pass + elif numpy.issubdtype(data.dtype, numpy.integer): + data = numpy.array(data, dtype=numpy.uint8, copy=False) + else: + raise ValueError('Unsupported data type') + + image = glutils.GLPlotRGBAImage(data, origin, scale, alpha) + + else: + raise RuntimeError("Unsupported data shape {0}".format(data.shape)) + + # TODO is this needed? + if self._plotFrame.xAxis.isLog and image.xMin <= 0.: + raise RuntimeError( + 'Cannot add image with X <= 0 with X axis log scale') + if self._plotFrame.yAxis.isLog and image.yMin <= 0.: + raise RuntimeError( + 'Cannot add image with Y <= 0 with Y axis log scale') + + return image + + def addTriangles(self, x, y, triangles, + color, alpha): + # Handle axes log scale: convert data + if self._plotFrame.xAxis.isLog: + x = numpy.log10(x) + if self._plotFrame.yAxis.isLog: + y = numpy.log10(y) + + triangles = glutils.GLPlotTriangles(x, y, color, triangles, alpha) + + return triangles + + def addShape(self, x, y, shape, color, fill, overlay, + linestyle, linewidth, linebgcolor): + x = numpy.array(x, copy=False) + y = numpy.array(y, copy=False) + + # TODO is this needed? + if self._plotFrame.xAxis.isLog and x.min() <= 0.: + raise RuntimeError( + 'Cannot add item with X <= 0 with X axis log scale') + if self._plotFrame.yAxis.isLog and y.min() <= 0.: + raise RuntimeError( + 'Cannot add item with Y <= 0 with Y axis log scale') + + return _ShapeItem(x, y, shape, color, fill, overlay, + linestyle, linewidth, linebgcolor) + + def addMarker(self, x, y, text, color, + symbol, linestyle, linewidth, constraint, yaxis): + return _MarkerItem(x, y, text, color, + symbol, linestyle, linewidth, constraint, yaxis) + + # Remove methods + + def remove(self, item): + if isinstance(item, glutils.GLPlotItem): + if item.yaxis == 'right': + # Check if some curves remains on the right Y axis + y2AxisItems = (item for item in self._plot.getItems() + if isinstance(item, items.YAxisMixIn) and + item.getYAxis() == 'right') + self._plotFrame.isY2Axis = next(y2AxisItems, None) is not None + + if item.isInitialized(): + self._glGarbageCollector.append(item) + + elif isinstance(item, (_MarkerItem, _ShapeItem)): + pass # No-op + + else: + _logger.error('Unsupported item: %s', str(item)) + + # Interaction methods + + _QT_CURSORS = { + BackendBase.CURSOR_DEFAULT: qt.Qt.ArrowCursor, + BackendBase.CURSOR_POINTING: qt.Qt.PointingHandCursor, + BackendBase.CURSOR_SIZE_HOR: qt.Qt.SizeHorCursor, + BackendBase.CURSOR_SIZE_VER: qt.Qt.SizeVerCursor, + BackendBase.CURSOR_SIZE_ALL: qt.Qt.SizeAllCursor, + } + + def setGraphCursorShape(self, cursor): + if cursor is None: + super(BackendOpenGL, self).unsetCursor() + else: + cursor = self._QT_CURSORS[cursor] + super(BackendOpenGL, self).setCursor(qt.QCursor(cursor)) + + def setGraphCursor(self, flag, color, linewidth, linestyle): + if linestyle != '-': + _logger.warning( + "BackendOpenGL.setGraphCursor linestyle parameter ignored") + + if flag: + color = colors.rgba(color) + crosshairCursor = color, linewidth + else: + crosshairCursor = None + + if crosshairCursor != self._crosshairCursor: + self._crosshairCursor = crosshairCursor + + _PICK_OFFSET = 3 # Offset in pixel used for picking + + def _mouseInPlotArea(self, x, y): + """Returns closest visible position in the plot. + + This is performed in Qt widget pixel, not device pixel. + + :param float x: X coordinate in Qt widget pixel + :param float y: Y coordinate in Qt widget pixel + :return: (x, y) closest point in the plot. + :rtype: List[float] + """ + left, top, width, height = self.getPlotBoundsInPixels() + return (numpy.clip(x, left, left + width - 1), # TODO -1? + numpy.clip(y, top, top + height - 1)) + + def __pickCurves(self, item, x, y): + """Perform picking on a curve item. + + :param GLPlotCurve2D item: + :param float x: X position of the mouse in widget coordinates + :param float y: Y position of the mouse in widget coordinates + :return: List of indices of picked points or None if not picked + :rtype: Union[List[int],None] + """ + offset = self._PICK_OFFSET + if item.marker is not None: + # Convert markerSize from points to qt pixels + qtDpi = self.getDotsPerInch() / self.getDevicePixelRatio() + size = item.markerSize / 72. * qtDpi + offset = max(size / 2., offset) + if item.lineStyle is not None: + # Convert line width from points to qt pixels + qtDpi = self.getDotsPerInch() / self.getDevicePixelRatio() + lineWidth = item.lineWidth / 72. * qtDpi + offset = max(lineWidth / 2., offset) + + inAreaPos = self._mouseInPlotArea(x - offset, y - offset) + dataPos = self._plot.pixelToData(inAreaPos[0], inAreaPos[1], + axis=item.yaxis, check=True) + if dataPos is None: + return None + xPick0, yPick0 = dataPos + + inAreaPos = self._mouseInPlotArea(x + offset, y + offset) + dataPos = self._plot.pixelToData(inAreaPos[0], inAreaPos[1], + axis=item.yaxis, check=True) + if dataPos is None: + return None + xPick1, yPick1 = dataPos + + if xPick0 < xPick1: + xPickMin, xPickMax = xPick0, xPick1 + else: + xPickMin, xPickMax = xPick1, xPick0 + + if yPick0 < yPick1: + yPickMin, yPickMax = yPick0, yPick1 + else: + yPickMin, yPickMax = yPick1, yPick0 + + # Apply log scale if axis is log + if self._plotFrame.xAxis.isLog: + xPickMin = numpy.log10(xPickMin) + xPickMax = numpy.log10(xPickMax) + + if (item.yaxis == 'left' and self._plotFrame.yAxis.isLog) or ( + item.yaxis == 'right' and self._plotFrame.y2Axis.isLog): + yPickMin = numpy.log10(yPickMin) + yPickMax = numpy.log10(yPickMax) + + return item.pick(xPickMin, yPickMin, + xPickMax, yPickMax) + + def pickItem(self, x, y, item): + # Picking is performed in Qt widget pixels not device pixels + dataPos = self._plot.pixelToData(x, y, axis='left', check=True) + if dataPos is None: + return None # Outside plot area + + if item is None: + _logger.error("No item provided for picking") + return None + + # Pick markers + if isinstance(item, _MarkerItem): + yaxis = item['yaxis'] + pixelPos = self._plot.dataToPixel( + item['x'], item['y'], axis=yaxis, check=False) + if pixelPos is None: + return None # negative coord on a log axis + + if item['x'] is None: # Horizontal line + pt1 = self._plot.pixelToData( + x, y - self._PICK_OFFSET, axis=yaxis, check=False) + pt2 = self._plot.pixelToData( + x, y + self._PICK_OFFSET, axis=yaxis, check=False) + isPicked = (min(pt1[1], pt2[1]) <= item['y'] <= + max(pt1[1], pt2[1])) + + elif item['y'] is None: # Vertical line + pt1 = self._plot.pixelToData( + x - self._PICK_OFFSET, y, axis=yaxis, check=False) + pt2 = self._plot.pixelToData( + x + self._PICK_OFFSET, y, axis=yaxis, check=False) + isPicked = (min(pt1[0], pt2[0]) <= item['x'] <= + max(pt1[0], pt2[0])) + + else: + isPicked = ( + numpy.fabs(x - pixelPos[0]) <= self._PICK_OFFSET and + numpy.fabs(y - pixelPos[1]) <= self._PICK_OFFSET) + + return (0,) if isPicked else None + + # Pick image, curve, triangles + elif isinstance(item, glutils.GLPlotItem): + if isinstance(item, glutils.GLPlotCurve2D): + return self.__pickCurves(item, x, y) + else: + return item.pick(*dataPos) # Might be None + + # Update curve + + def setCurveColor(self, curve, color): + pass # TODO + + # Misc. + + def getWidgetHandle(self): + return self + + def postRedisplay(self): + self.update() + + def replot(self): + self.update() # async redraw + + def saveGraph(self, fileName, fileFormat, dpi): + if dpi is not None: + _logger.warning("saveGraph ignores dpi parameter") + + if fileFormat not in ['png', 'ppm', 'svg', 'tiff']: + raise NotImplementedError('Unsupported format: %s' % fileFormat) + + if not self.isValid(): + _logger.error('OpenGL 2.1 not available, cannot save OpenGL image') + width, height = self._plotFrame.size + data = numpy.zeros((height, width, 3), dtype=numpy.uint8) + else: + self.makeCurrent() + + data = numpy.empty( + (self._plotFrame.size[1], self._plotFrame.size[0], 3), + dtype=numpy.uint8, order='C') + + context = self.context() + framebufferTexture = self._plotFBOs.get(context) + if framebufferTexture is None: + # Fallback, supports direct rendering mode: _paintDirectGL + # might have issues as it can read on-screen framebuffer + fboName = self.defaultFramebufferObject() + width, height = self._plotFrame.size + else: + fboName = framebufferTexture.name + height, width = framebufferTexture.shape + + previousFramebuffer = gl.glGetInteger(gl.GL_FRAMEBUFFER_BINDING) + gl.glBindFramebuffer(gl.GL_FRAMEBUFFER, fboName) + gl.glPixelStorei(gl.GL_PACK_ALIGNMENT, 1) + gl.glReadPixels(0, 0, width, height, + gl.GL_RGB, gl.GL_UNSIGNED_BYTE, data) + gl.glBindFramebuffer(gl.GL_FRAMEBUFFER, previousFramebuffer) + + # glReadPixels gives bottom to top, + # while images are stored as top to bottom + data = numpy.flipud(data) + + # fileName is either a file-like object or a str + saveImageToFile(data, fileName, fileFormat) + + # Graph labels + + def setGraphTitle(self, title): + self._plotFrame.title = title + + def setGraphXLabel(self, label): + self._plotFrame.xAxis.title = label + + def setGraphYLabel(self, label, axis): + if axis == 'left': + self._plotFrame.yAxis.title = label + else: # right axis + self._plotFrame.y2Axis.title = label + + # Graph limits + + def _setDataRanges(self, xlim=None, ylim=None, y2lim=None): + """Set the visible range of data in the plot frame. + + This clips the ranges to possible values (takes care of float32 + range + positive range for log). + This also takes care of non-orthogonal axes. + + This should be moved to PlotFrame. + """ + # Update axes range with a clipped range if too wide + self._plotFrame.setDataRanges(xlim, ylim, y2lim) + + def _ensureAspectRatio(self, keepDim=None): + """Update plot bounds in order to keep aspect ratio. + + Warning: keepDim on right Y axis is not implemented ! + + :param str keepDim: The dimension to maintain: 'x', 'y' or None. + If None (the default), the dimension with the largest range. + """ + plotWidth, plotHeight = self._plotFrame.plotSize + if plotWidth <= 2 or plotHeight <= 2: + return + + if keepDim is None: + ranges = self._plot.getDataRange() + if (ranges.y is not None and + ranges.x is not None and + (ranges.y[1] - ranges.y[0]) != 0.): + dataRatio = (ranges.x[1] - ranges.x[0]) / float(ranges.y[1] - ranges.y[0]) + plotRatio = plotWidth / float(plotHeight) # Test != 0 before + + keepDim = 'x' if dataRatio > plotRatio else 'y' + else: # Limit case + keepDim = 'x' + + (xMin, xMax), (yMin, yMax), (y2Min, y2Max) = \ + self._plotFrame.dataRanges + if keepDim == 'y': + dataW = (yMax - yMin) * plotWidth / float(plotHeight) + xCenter = 0.5 * (xMin + xMax) + xMin = xCenter - 0.5 * dataW + xMax = xCenter + 0.5 * dataW + elif keepDim == 'x': + dataH = (xMax - xMin) * plotHeight / float(plotWidth) + yCenter = 0.5 * (yMin + yMax) + yMin = yCenter - 0.5 * dataH + yMax = yCenter + 0.5 * dataH + y2Center = 0.5 * (y2Min + y2Max) + y2Min = y2Center - 0.5 * dataH + y2Max = y2Center + 0.5 * dataH + else: + raise RuntimeError('Unsupported dimension to keep: %s' % keepDim) + + # Update plot frame bounds + self._setDataRanges(xlim=(xMin, xMax), + ylim=(yMin, yMax), + y2lim=(y2Min, y2Max)) + + def _setPlotBounds(self, xRange=None, yRange=None, y2Range=None, + keepDim=None): + # Update axes range with a clipped range if too wide + self._setDataRanges(xlim=xRange, + ylim=yRange, + y2lim=y2Range) + + # Keep data aspect ratio + if self.isKeepDataAspectRatio(): + self._ensureAspectRatio(keepDim) + + def setLimits(self, xmin, xmax, ymin, ymax, y2min=None, y2max=None): + assert xmin < xmax + assert ymin < ymax + + if y2min is None or y2max is None: + y2Range = None + else: + assert y2min < y2max + y2Range = y2min, y2max + self._setPlotBounds((xmin, xmax), (ymin, ymax), y2Range) + + def getGraphXLimits(self): + return self._plotFrame.dataRanges.x + + def setGraphXLimits(self, xmin, xmax): + assert xmin < xmax + self._setPlotBounds(xRange=(xmin, xmax), keepDim='x') + + def getGraphYLimits(self, axis): + assert axis in ("left", "right") + if axis == "left": + return self._plotFrame.dataRanges.y + else: + return self._plotFrame.dataRanges.y2 + + def setGraphYLimits(self, ymin, ymax, axis): + assert ymin < ymax + assert axis in ("left", "right") + + if axis == "left": + self._setPlotBounds(yRange=(ymin, ymax), keepDim='y') + else: + self._setPlotBounds(y2Range=(ymin, ymax), keepDim='y') + + # Graph axes + + def getXAxisTimeZone(self): + return self._plotFrame.xAxis.timeZone + + def setXAxisTimeZone(self, tz): + self._plotFrame.xAxis.timeZone = tz + + def isXAxisTimeSeries(self): + return self._plotFrame.xAxis.isTimeSeries + + def setXAxisTimeSeries(self, isTimeSeries): + self._plotFrame.xAxis.isTimeSeries = isTimeSeries + + def setXAxisLogarithmic(self, flag): + if flag != self._plotFrame.xAxis.isLog: + if flag and self._keepDataAspectRatio: + _logger.warning( + "KeepDataAspectRatio is ignored with log axes") + + self._plotFrame.xAxis.isLog = flag + + def setYAxisLogarithmic(self, flag): + if (flag != self._plotFrame.yAxis.isLog or + flag != self._plotFrame.y2Axis.isLog): + if flag and self._keepDataAspectRatio: + _logger.warning( + "KeepDataAspectRatio is ignored with log axes") + + self._plotFrame.yAxis.isLog = flag + self._plotFrame.y2Axis.isLog = flag + + def setYAxisInverted(self, flag): + if flag != self._plotFrame.isYAxisInverted: + self._plotFrame.isYAxisInverted = flag + + def isYAxisInverted(self): + return self._plotFrame.isYAxisInverted + + def isKeepDataAspectRatio(self): + if self._plotFrame.xAxis.isLog or self._plotFrame.yAxis.isLog: + return False + else: + return self._keepDataAspectRatio + + def setKeepDataAspectRatio(self, flag): + if flag and (self._plotFrame.xAxis.isLog or + self._plotFrame.yAxis.isLog): + _logger.warning("KeepDataAspectRatio is ignored with log axes") + + self._keepDataAspectRatio = flag + + def setGraphGrid(self, which): + assert which in (None, 'major', 'both') + self._plotFrame.grid = which is not None # TODO True grid support + + # Data <-> Pixel coordinates conversion + + def dataToPixel(self, x, y, axis): + result = self._plotFrame.dataToPixel(x, y, axis) + if result is None: + return None + else: + devicePixelRatio = self.getDevicePixelRatio() + return tuple(value/devicePixelRatio for value in result) + + def pixelToData(self, x, y, axis): + devicePixelRatio = self.getDevicePixelRatio() + return self._plotFrame.pixelToData( + x * devicePixelRatio, y * devicePixelRatio, axis) + + def getPlotBoundsInPixels(self): + devicePixelRatio = self.getDevicePixelRatio() + return tuple(int(value / devicePixelRatio) + for value in self._plotFrame.plotOrigin + self._plotFrame.plotSize) + + def setAxesMargins(self, left: float, top: float, right: float, bottom: float): + self._plotFrame.marginRatios = left, top, right, bottom + + def setForegroundColors(self, foregroundColor, gridColor): + self._plotFrame.foregroundColor = foregroundColor + self._plotFrame.gridColor = gridColor + + def setBackgroundColors(self, backgroundColor, dataBackgroundColor): + self._backgroundColor = backgroundColor + self._dataBackgroundColor = dataBackgroundColor diff --git a/src/silx/gui/plot/backends/__init__.py b/src/silx/gui/plot/backends/__init__.py new file mode 100644 index 0000000..966d9df --- /dev/null +++ b/src/silx/gui/plot/backends/__init__.py @@ -0,0 +1,29 @@ +# coding: utf-8 +# /*########################################################################## +# +# Copyright (c) 2017 European Synchrotron Radiation Facility +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. +# +# ###########################################################################*/ +"""This package implements the backend of the Plot.""" + +__authors__ = ["T. Vincent"] +__license__ = "MIT" +__date__ = "21/03/2017" diff --git a/src/silx/gui/plot/backends/glutils/GLPlotCurve.py b/src/silx/gui/plot/backends/glutils/GLPlotCurve.py new file mode 100644 index 0000000..e4667b4 --- /dev/null +++ b/src/silx/gui/plot/backends/glutils/GLPlotCurve.py @@ -0,0 +1,1380 @@ +# coding: utf-8 +# /*########################################################################## +# +# Copyright (c) 2014-2021 European Synchrotron Radiation Facility +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. +# +# ############################################################################*/ +""" +This module provides classes to render 2D lines and scatter plots +""" + +from __future__ import division + +__authors__ = ["T. Vincent"] +__license__ = "MIT" +__date__ = "03/04/2017" + + +import math +import logging + +import numpy + +from silx.math.combo import min_max + +from ...._glutils import gl +from ...._glutils import Program, vertexBuffer, VertexBufferAttrib +from .GLSupport import buildFillMaskIndices, mat4Identity, mat4Translate +from .GLPlotImage import GLPlotItem + + +_logger = logging.getLogger(__name__) + + +_MPL_NONES = None, 'None', '', ' ' +"""Possible values for None""" + + +def _notNaNSlices(array, length=1): + """Returns slices of none NaN values in the array. + + :param numpy.ndarray array: 1D array from which to get slices + :param int length: Slices shorter than length gets discarded + :return: Array of (start, end) slice indices + :rtype: numpy.ndarray + """ + isnan = numpy.isnan(numpy.array(array, copy=False).reshape(-1)) + notnan = numpy.logical_not(isnan) + start = numpy.where(numpy.logical_and(isnan[:-1], notnan[1:]))[0] + 1 + if notnan[0]: + start = numpy.append(0, start) + end = numpy.where(numpy.logical_and(notnan[:-1], isnan[1:]))[0] + 1 + if notnan[-1]: + end = numpy.append(end, len(array)) + slices = numpy.transpose((start, end)) + if length > 1: + # discard slices with less than length values + slices = slices[numpy.diff(slices, axis=1).ravel() >= length] + return slices + + +# fill ######################################################################## + +class _Fill2D(object): + """Object rendering curve filling as polygons + + :param numpy.ndarray xData: X coordinates of points + :param numpy.ndarray yData: Y coordinates of points + :param float baseline: Y value of the 'bottom' of the fill. + 0 for linear Y scale, -38 for log Y scale + :param List[float] color: RGBA color as 4 float in [0, 1] + :param List[float] offset: Translation of coordinates (ox, oy) + """ + + _PROGRAM = Program( + vertexShader=""" + #version 120 + + uniform mat4 matrix; + attribute float xPos; + attribute float yPos; + + void main(void) { + gl_Position = matrix * vec4(xPos, yPos, 0.0, 1.0); + } + """, + fragmentShader=""" + #version 120 + + uniform vec4 color; + + void main(void) { + gl_FragColor = color; + } + """, + attrib0='xPos') + + def __init__(self, xData=None, yData=None, + baseline=0, + color=(0., 0., 0., 1.), + offset=(0., 0.)): + self.xData = xData + self.yData = yData + self._xFillVboData = None + self._yFillVboData = None + self.color = color + self.offset = offset + + # Offset baseline + self.baseline = baseline - self.offset[1] + + def prepare(self): + """Rendering preparation: build indices and bounding box vertices""" + if (self._xFillVboData is None and + self.xData is not None and self.yData is not None): + + # Get slices of not NaN values longer than 1 element + isnan = numpy.logical_or(numpy.isnan(self.xData), numpy.isnan(self.yData)) + notnan = numpy.logical_not(isnan) + start = numpy.where(numpy.logical_and(isnan[:-1], notnan[1:]))[0] + 1 + if notnan[0]: + start = numpy.append(0, start) + end = numpy.where(numpy.logical_and(notnan[:-1], isnan[1:]))[0] + 1 + if notnan[-1]: + end = numpy.append(end, len(isnan)) + slices = numpy.transpose((start, end)) + # discard slices with less than length values + slices = slices[numpy.diff(slices, axis=1).reshape(-1) >= 2] + + # Number of points: slice + 2 * leading and trailing points + # Twice leading and trailing points to produce degenerated triangles + nbPoints = numpy.sum(numpy.diff(slices, axis=1)) * 2 + 4 * len(slices) + points = numpy.empty((nbPoints, 2), dtype=numpy.float32) + + offset = 0 + # invert baseline for filling + new_y_data = numpy.append(self.yData, self.baseline) + for start, end in slices: + # Duplicate first point for connecting degenerated triangle + points[offset:offset+2] = self.xData[start], new_y_data[start] + + # 2nd point of the polygon is last point + points[offset+2] = self.xData[start], self.baseline[start] + + indices = numpy.append(numpy.arange(start, end), + numpy.arange(len(self.xData) + end-1, len(self.xData) + start-1, -1)) + indices = indices[buildFillMaskIndices(len(indices))] + + points[offset+3:offset+3+len(indices), 0] = self.xData[indices % len(self.xData)] + points[offset+3:offset+3+len(indices), 1] = new_y_data[indices] + + # Duplicate last point for connecting degenerated triangle + points[offset+3+len(indices)] = points[offset+3+len(indices)-1] + + offset += len(indices) + 4 + + self._xFillVboData, self._yFillVboData = vertexBuffer(points.T) + + def render(self, context): + """Perform rendering + + :param RenderContext context: + """ + self.prepare() + + if self._xFillVboData is None: + return # Nothing to display + + self._PROGRAM.use() + + gl.glUniformMatrix4fv( + self._PROGRAM.uniforms['matrix'], 1, gl.GL_TRUE, + numpy.dot(context.matrix, + mat4Translate(*self.offset)).astype(numpy.float32)) + + gl.glUniform4f(self._PROGRAM.uniforms['color'], *self.color) + + xPosAttrib = self._PROGRAM.attributes['xPos'] + yPosAttrib = self._PROGRAM.attributes['yPos'] + + gl.glEnableVertexAttribArray(xPosAttrib) + self._xFillVboData.setVertexAttrib(xPosAttrib) + + gl.glEnableVertexAttribArray(yPosAttrib) + self._yFillVboData.setVertexAttrib(yPosAttrib) + + # Prepare fill mask + gl.glEnable(gl.GL_STENCIL_TEST) + gl.glStencilMask(1) + gl.glStencilFunc(gl.GL_ALWAYS, 1, 1) + gl.glStencilOp(gl.GL_INVERT, gl.GL_INVERT, gl.GL_INVERT) + gl.glColorMask(gl.GL_FALSE, gl.GL_FALSE, gl.GL_FALSE, gl.GL_FALSE) + gl.glDepthMask(gl.GL_FALSE) + + gl.glDrawArrays(gl.GL_TRIANGLE_STRIP, 0, self._xFillVboData.size) + + gl.glStencilFunc(gl.GL_EQUAL, 1, 1) + # Reset stencil while drawing + gl.glStencilOp(gl.GL_ZERO, gl.GL_ZERO, gl.GL_ZERO) + gl.glColorMask(gl.GL_TRUE, gl.GL_TRUE, gl.GL_TRUE, gl.GL_TRUE) + gl.glDepthMask(gl.GL_TRUE) + + # Draw directly in NDC + gl.glUniformMatrix4fv(self._PROGRAM.uniforms['matrix'], 1, gl.GL_TRUE, + mat4Identity().astype(numpy.float32)) + + # NDC vertices + gl.glVertexAttribPointer( + xPosAttrib, 1, gl.GL_FLOAT, gl.GL_FALSE, 0, + numpy.array((-1., -1., 1., 1.), dtype=numpy.float32)) + gl.glVertexAttribPointer( + yPosAttrib, 1, gl.GL_FLOAT, gl.GL_FALSE, 0, + numpy.array((-1., 1., -1., 1.), dtype=numpy.float32)) + + gl.glDrawArrays(gl.GL_TRIANGLE_STRIP, 0, 4) + + gl.glDisable(gl.GL_STENCIL_TEST) + + def discard(self): + """Release VBOs""" + if self.isInitialized(): + self._xFillVboData.vbo.discard() + + self._xFillVboData = None + self._yFillVboData = None + + def isInitialized(self): + return self._xFillVboData is not None + + +# line ######################################################################## + +SOLID, DASHED, DASHDOT, DOTTED = '-', '--', '-.', ':' + + +class GLLines2D(object): + """Object rendering curve as a polyline + + :param xVboData: X coordinates VBO + :param yVboData: Y coordinates VBO + :param colorVboData: VBO of colors + :param distVboData: VBO of distance along the polyline + :param str style: Line style in: '-', '--', '-.', ':' + :param List[float] color: RGBA color as 4 float in [0, 1] + :param float width: Line width + :param float dashPeriod: Period of dashes + :param drawMode: OpenGL drawing mode + :param List[float] offset: Translation of coordinates (ox, oy) + """ + + STYLES = SOLID, DASHED, DASHDOT, DOTTED + """Supported line styles""" + + _SOLID_PROGRAM = Program( + vertexShader=""" + #version 120 + + uniform mat4 matrix; + attribute float xPos; + attribute float yPos; + attribute vec4 color; + + varying vec4 vColor; + + void main(void) { + gl_Position = matrix * vec4(xPos, yPos, 0., 1.) ; + vColor = color; + } + """, + fragmentShader=""" + #version 120 + + varying vec4 vColor; + + void main(void) { + gl_FragColor = vColor; + } + """, + attrib0='xPos') + + # Limitation: Dash using an estimate of distance in screen coord + # to avoid computing distance when viewport is resized + # results in inequal dashes when viewport aspect ratio is far from 1 + _DASH_PROGRAM = Program( + vertexShader=""" + #version 120 + + uniform mat4 matrix; + uniform vec2 halfViewportSize; + attribute float xPos; + attribute float yPos; + attribute vec4 color; + attribute float distance; + + varying float vDist; + varying vec4 vColor; + + void main(void) { + gl_Position = matrix * vec4(xPos, yPos, 0., 1.); + //Estimate distance in pixels + vec2 probe = vec2(matrix * vec4(1., 1., 0., 0.)) * + halfViewportSize; + float pixelPerDataEstimate = length(probe)/sqrt(2.); + vDist = distance * pixelPerDataEstimate; + vColor = color; + } + """, + fragmentShader=""" + #version 120 + + /* Dashes: [0, x], [y, z] + Dash period: w */ + uniform vec4 dash; + uniform vec4 dash2ndColor; + + varying float vDist; + varying vec4 vColor; + + void main(void) { + float dist = mod(vDist, dash.w); + if ((dist > dash.x && dist < dash.y) || dist > dash.z) { + if (dash2ndColor.a == 0.) { + discard; // Discard full transparent bg color + } else { + gl_FragColor = dash2ndColor; + } + } else { + gl_FragColor = vColor; + } + } + """, + attrib0='xPos') + + def __init__(self, xVboData=None, yVboData=None, + colorVboData=None, distVboData=None, + style=SOLID, color=(0., 0., 0., 1.), dash2ndColor=None, + width=1, dashPeriod=10., drawMode=None, + offset=(0., 0.)): + if (xVboData is not None and + not isinstance(xVboData, VertexBufferAttrib)): + xVboData = numpy.array(xVboData, copy=False, dtype=numpy.float32) + self.xVboData = xVboData + + if (yVboData is not None and + not isinstance(yVboData, VertexBufferAttrib)): + yVboData = numpy.array(yVboData, copy=False, dtype=numpy.float32) + self.yVboData = yVboData + + # Compute distances if not given while providing numpy array coordinates + if (isinstance(self.xVboData, numpy.ndarray) and + isinstance(self.yVboData, numpy.ndarray) and + distVboData is None): + distVboData = distancesFromArrays(self.xVboData, self.yVboData) + + if (distVboData is not None and + not isinstance(distVboData, VertexBufferAttrib)): + distVboData = numpy.array( + distVboData, copy=False, dtype=numpy.float32) + self.distVboData = distVboData + + if colorVboData is not None: + assert isinstance(colorVboData, VertexBufferAttrib) + self.colorVboData = colorVboData + self.useColorVboData = colorVboData is not None + + self.color = color + self.dash2ndColor = dash2ndColor + self.width = width + self._style = None + self.style = style + self.dashPeriod = dashPeriod + self.offset = offset + + self._drawMode = drawMode if drawMode is not None else gl.GL_LINE_STRIP + + @property + def style(self): + """Line style (Union[str,None])""" + return self._style + + @style.setter + def style(self, style): + if style in _MPL_NONES: + self._style = None + else: + assert style in self.STYLES + self._style = style + + @classmethod + def init(cls): + """OpenGL context initialization""" + gl.glHint(gl.GL_LINE_SMOOTH_HINT, gl.GL_NICEST) + + def render(self, context): + """Perform rendering + + :param RenderContext context: + """ + width = self.width / 72. * context.dpi + + style = self.style + if style is None: + return + + elif style == SOLID: + program = self._SOLID_PROGRAM + program.use() + + else: # DASHED, DASHDOT, DOTTED + program = self._DASH_PROGRAM + program.use() + + x, y, viewWidth, viewHeight = gl.glGetFloatv(gl.GL_VIEWPORT) + gl.glUniform2f(program.uniforms['halfViewportSize'], + 0.5 * viewWidth, 0.5 * viewHeight) + + dashPeriod = self.dashPeriod * width + if self.style == DOTTED: + dash = (0.2 * dashPeriod, + 0.5 * dashPeriod, + 0.7 * dashPeriod, + dashPeriod) + elif self.style == DASHDOT: + dash = (0.3 * dashPeriod, + 0.5 * dashPeriod, + 0.6 * dashPeriod, + dashPeriod) + else: + dash = (0.5 * dashPeriod, + dashPeriod, + dashPeriod, + dashPeriod) + + gl.glUniform4f(program.uniforms['dash'], *dash) + + if self.dash2ndColor is None: + # Use fully transparent color which gets discarded in shader + dash2ndColor = (0., 0., 0., 0.) + else: + dash2ndColor = self.dash2ndColor + gl.glUniform4f(program.uniforms['dash2ndColor'], *dash2ndColor) + + distAttrib = program.attributes['distance'] + gl.glEnableVertexAttribArray(distAttrib) + if isinstance(self.distVboData, VertexBufferAttrib): + self.distVboData.setVertexAttrib(distAttrib) + else: + gl.glVertexAttribPointer(distAttrib, + 1, + gl.GL_FLOAT, + False, + 0, + self.distVboData) + + if width != 1: + gl.glEnable(gl.GL_LINE_SMOOTH) + + matrix = numpy.dot(context.matrix, + mat4Translate(*self.offset)).astype(numpy.float32) + gl.glUniformMatrix4fv(program.uniforms['matrix'], + 1, gl.GL_TRUE, matrix) + + colorAttrib = program.attributes['color'] + if self.useColorVboData and self.colorVboData is not None: + gl.glEnableVertexAttribArray(colorAttrib) + self.colorVboData.setVertexAttrib(colorAttrib) + else: + gl.glDisableVertexAttribArray(colorAttrib) + gl.glVertexAttrib4f(colorAttrib, *self.color) + + xPosAttrib = program.attributes['xPos'] + gl.glEnableVertexAttribArray(xPosAttrib) + if isinstance(self.xVboData, VertexBufferAttrib): + self.xVboData.setVertexAttrib(xPosAttrib) + else: + gl.glVertexAttribPointer(xPosAttrib, + 1, + gl.GL_FLOAT, + False, + 0, + self.xVboData) + + yPosAttrib = program.attributes['yPos'] + gl.glEnableVertexAttribArray(yPosAttrib) + if isinstance(self.yVboData, VertexBufferAttrib): + self.yVboData.setVertexAttrib(yPosAttrib) + else: + gl.glVertexAttribPointer(yPosAttrib, + 1, + gl.GL_FLOAT, + False, + 0, + self.yVboData) + + gl.glLineWidth(width) + gl.glDrawArrays(self._drawMode, 0, self.xVboData.size) + + gl.glDisable(gl.GL_LINE_SMOOTH) + + +def distancesFromArrays(xData, yData): + """Returns distances between each points + + :param numpy.ndarray xData: X coordinate of points + :param numpy.ndarray yData: Y coordinate of points + :rtype: numpy.ndarray + """ + # Split array into sub-shapes at not finite points + splits = numpy.nonzero(numpy.logical_not(numpy.logical_and( + numpy.isfinite(xData), numpy.isfinite(yData))))[0] + splits = numpy.concatenate(([-1], splits, [len(xData) - 1])) + + # Compute distance independently for each sub-shapes, + # putting not finite points as last points of sub-shapes + distances = [] + for begin, end in zip(splits[:-1] + 1, splits[1:] + 1): + if begin == end: # Empty shape + continue + elif end - begin == 1: # Single element + distances.append([0]) + else: + deltas = numpy.dstack(( + numpy.ediff1d(xData[begin:end], to_begin=numpy.float32(0.)), + numpy.ediff1d(yData[begin:end], to_begin=numpy.float32(0.))))[0] + distances.append( + numpy.cumsum(numpy.sqrt(numpy.sum(deltas ** 2, axis=1)))) + return numpy.concatenate(distances) + + +# points ###################################################################### + +DIAMOND, CIRCLE, SQUARE, PLUS, X_MARKER, POINT, PIXEL, ASTERISK = \ + 'd', 'o', 's', '+', 'x', '.', ',', '*' + +H_LINE, V_LINE, HEART = '_', '|', u'\u2665' + +TICK_LEFT = "tickleft" +TICK_RIGHT = "tickright" +TICK_UP = "tickup" +TICK_DOWN = "tickdown" +CARET_LEFT = "caretleft" +CARET_RIGHT = "caretright" +CARET_UP = "caretup" +CARET_DOWN = "caretdown" + + +class _Points2D(object): + """Object rendering curve markers + + :param xVboData: X coordinates VBO + :param yVboData: Y coordinates VBO + :param colorVboData: VBO of colors + :param str marker: Kind of symbol to use, see :attr:`MARKERS`. + :param List[float] color: RGBA color as 4 float in [0, 1] + :param float size: Marker size + :param List[float] offset: Translation of coordinates (ox, oy) + """ + + MARKERS = (DIAMOND, CIRCLE, SQUARE, PLUS, X_MARKER, POINT, PIXEL, ASTERISK, + H_LINE, V_LINE, HEART, TICK_LEFT, TICK_RIGHT, TICK_UP, TICK_DOWN, + CARET_LEFT, CARET_RIGHT, CARET_UP, CARET_DOWN) + """List of supported markers""" + + _VERTEX_SHADER = """ + #version 120 + + uniform mat4 matrix; + uniform int transform; + uniform float size; + attribute float xPos; + attribute float yPos; + attribute vec4 color; + + varying vec4 vColor; + + void main(void) { + gl_Position = matrix * vec4(xPos, yPos, 0., 1.); + vColor = color; + gl_PointSize = size; + } + """ + + _FRAGMENT_SHADER_SYMBOLS = { + DIAMOND: """ + float alphaSymbol(vec2 coord, float size) { + vec2 centerCoord = abs(coord - vec2(0.5, 0.5)); + float f = centerCoord.x + centerCoord.y; + return clamp(size * (0.5 - f), 0.0, 1.0); + } + """, + CIRCLE: """ + float alphaSymbol(vec2 coord, float size) { + float radius = 0.5; + float r = distance(coord, vec2(0.5, 0.5)); + return clamp(size * (radius - r), 0.0, 1.0); + } + """, + SQUARE: """ + float alphaSymbol(vec2 coord, float size) { + return 1.0; + } + """, + PLUS: """ + float alphaSymbol(vec2 coord, float size) { + vec2 d = abs(size * (coord - vec2(0.5, 0.5))); + if (min(d.x, d.y) < 0.5) { + return 1.0; + } else { + return 0.0; + } + } + """, + X_MARKER: """ + float alphaSymbol(vec2 coord, float size) { + vec2 pos = floor(size * coord) + 0.5; + vec2 d_x = abs(pos.x + vec2(- pos.y, pos.y - size)); + if (min(d_x.x, d_x.y) <= 0.5) { + return 1.0; + } else { + return 0.0; + } + } + """, + ASTERISK: """ + float alphaSymbol(vec2 coord, float size) { + /* Combining +, x and circle */ + vec2 d_plus = abs(size * (coord - vec2(0.5, 0.5))); + vec2 pos = floor(size * coord) + 0.5; + vec2 d_x = abs(pos.x + vec2(- pos.y, pos.y - size)); + if (min(d_plus.x, d_plus.y) < 0.5) { + return 1.0; + } else if (min(d_x.x, d_x.y) <= 0.5) { + float r = distance(coord, vec2(0.5, 0.5)); + return clamp(size * (0.5 - r), 0.0, 1.0); + } else { + return 0.0; + } + } + """, + H_LINE: """ + float alphaSymbol(vec2 coord, float size) { + float dy = abs(size * (coord.y - 0.5)); + if (dy < 0.5) { + return 1.0; + } else { + return 0.0; + } + } + """, + V_LINE: """ + float alphaSymbol(vec2 coord, float size) { + float dx = abs(size * (coord.x - 0.5)); + if (dx < 0.5) { + return 1.0; + } else { + return 0.0; + } + } + """, + HEART: """ + float alphaSymbol(vec2 coord, float size) { + coord = (coord - 0.5) * 2.; + coord *= 0.75; + coord.y += 0.25; + float a = atan(coord.x,-coord.y)/3.141593; + float r = length(coord); + float h = abs(a); + float d = (13.0*h - 22.0*h*h + 10.0*h*h*h)/(6.0-5.0*h); + float res = clamp(r-d, 0., 1.); + // antialiasing + res = smoothstep(0.1, 0.001, res); + return res; + } + """, + TICK_LEFT: """ + float alphaSymbol(vec2 coord, float size) { + coord = size * (coord - 0.5); + float dy = abs(coord.y); + if (dy < 0.5 && coord.x < 0.5) { + return 1.0; + } else { + return 0.0; + } + } + """, + TICK_RIGHT: """ + float alphaSymbol(vec2 coord, float size) { + coord = size * (coord - 0.5); + float dy = abs(coord.y); + if (dy < 0.5 && coord.x > -0.5) { + return 1.0; + } else { + return 0.0; + } + } + """, + TICK_UP: """ + float alphaSymbol(vec2 coord, float size) { + coord = size * (coord - 0.5); + float dx = abs(coord.x); + if (dx < 0.5 && coord.y < 0.5) { + return 1.0; + } else { + return 0.0; + } + } + """, + TICK_DOWN: """ + float alphaSymbol(vec2 coord, float size) { + coord = size * (coord - 0.5); + float dx = abs(coord.x); + if (dx < 0.5 && coord.y > -0.5) { + return 1.0; + } else { + return 0.0; + } + } + """, + CARET_LEFT: """ + float alphaSymbol(vec2 coord, float size) { + coord = size * (coord - 0.5); + float d = abs(coord.x) - abs(coord.y); + if (d >= -0.1 && coord.x > 0.5) { + return smoothstep(-0.1, 0.1, d); + } else { + return 0.0; + } + } + """, + CARET_RIGHT: """ + float alphaSymbol(vec2 coord, float size) { + coord = size * (coord - 0.5); + float d = abs(coord.x) - abs(coord.y); + if (d >= -0.1 && coord.x < 0.5) { + return smoothstep(-0.1, 0.1, d); + } else { + return 0.0; + } + } + """, + CARET_UP: """ + float alphaSymbol(vec2 coord, float size) { + coord = size * (coord - 0.5); + float d = abs(coord.y) - abs(coord.x); + if (d >= -0.1 && coord.y > 0.5) { + return smoothstep(-0.1, 0.1, d); + } else { + return 0.0; + } + } + """, + CARET_DOWN: """ + float alphaSymbol(vec2 coord, float size) { + coord = size * (coord - 0.5); + float d = abs(coord.y) - abs(coord.x); + if (d >= -0.1 && coord.y < 0.5) { + return smoothstep(-0.1, 0.1, d); + } else { + return 0.0; + } + } + """, + } + + _FRAGMENT_SHADER_TEMPLATE = """ + #version 120 + + uniform float size; + + varying vec4 vColor; + + %s + + void main(void) { + float alpha = alphaSymbol(gl_PointCoord, size); + if (alpha <= 0.0) { + discard; + } else { + gl_FragColor = vec4(vColor.rgb, alpha * clamp(vColor.a, 0.0, 1.0)); + } + } + """ + + _PROGRAMS = {} + + def __init__(self, xVboData=None, yVboData=None, colorVboData=None, + marker=SQUARE, color=(0., 0., 0., 1.), size=7, + offset=(0., 0.)): + self.color = color + self._marker = None + self.marker = marker + self.size = size + self.offset = offset + + self.xVboData = xVboData + self.yVboData = yVboData + self.colorVboData = colorVboData + self.useColorVboData = colorVboData is not None + + @property + def marker(self): + """Symbol used to display markers (str)""" + return self._marker + + @marker.setter + def marker(self, marker): + if marker in _MPL_NONES: + self._marker = None + else: + assert marker in self.MARKERS + self._marker = marker + + @classmethod + def _getProgram(cls, marker): + """On-demand shader program creation.""" + if marker == PIXEL: + marker = SQUARE + elif marker == POINT: + marker = CIRCLE + + if marker not in cls._PROGRAMS: + cls._PROGRAMS[marker] = Program( + vertexShader=cls._VERTEX_SHADER, + fragmentShader=(cls._FRAGMENT_SHADER_TEMPLATE % + cls._FRAGMENT_SHADER_SYMBOLS[marker]), + attrib0='xPos') + + return cls._PROGRAMS[marker] + + @classmethod + def init(cls): + """OpenGL context initialization""" + version = gl.glGetString(gl.GL_VERSION) + majorVersion = int(version[0]) + assert majorVersion >= 2 + gl.glEnable(gl.GL_VERTEX_PROGRAM_POINT_SIZE) # OpenGL 2 + gl.glEnable(gl.GL_POINT_SPRITE) # OpenGL 2 + if majorVersion >= 3: # OpenGL 3 + gl.glEnable(gl.GL_PROGRAM_POINT_SIZE) + + def render(self, context): + """Perform rendering + + :param RenderContext context: + """ + if self.marker is None: + return + + program = self._getProgram(self.marker) + program.use() + + matrix = numpy.dot(context.matrix, + mat4Translate(*self.offset)).astype(numpy.float32) + gl.glUniformMatrix4fv(program.uniforms['matrix'], 1, gl.GL_TRUE, matrix) + + if self.marker == PIXEL: + size = 1 + elif self.marker == POINT: + size = math.ceil(0.5 * self.size) + 1 # Mimic Matplotlib point + else: + size = self.size + size = size / 72. * context.dpi + + if self.marker in (PLUS, H_LINE, V_LINE, + TICK_LEFT, TICK_RIGHT, TICK_UP, TICK_DOWN): + # Convert to nearest odd number + size = size // 2 * 2 + 1. + + gl.glUniform1f(program.uniforms['size'], size) + # gl.glPointSize(self.size) + + cAttrib = program.attributes['color'] + if self.useColorVboData and self.colorVboData is not None: + gl.glEnableVertexAttribArray(cAttrib) + self.colorVboData.setVertexAttrib(cAttrib) + else: + gl.glDisableVertexAttribArray(cAttrib) + gl.glVertexAttrib4f(cAttrib, *self.color) + + xAttrib = program.attributes['xPos'] + gl.glEnableVertexAttribArray(xAttrib) + self.xVboData.setVertexAttrib(xAttrib) + + yAttrib = program.attributes['yPos'] + gl.glEnableVertexAttribArray(yAttrib) + self.yVboData.setVertexAttrib(yAttrib) + + gl.glDrawArrays(gl.GL_POINTS, 0, self.xVboData.size) + + gl.glUseProgram(0) + + +# error bars ################################################################## + +class _ErrorBars(object): + """Display errors bars. + + This is using its own VBO as opposed to fill/points/lines. + There is no picking on error bars. + + It uses 2 vertices per error bars and uses :class:`GLLines2D` to + render error bars and :class:`_Points2D` to render the ends. + + :param numpy.ndarray xData: X coordinates of the data. + :param numpy.ndarray yData: Y coordinates of the data. + :param xError: The absolute error on the X axis. + :type xError: A float, or a numpy.ndarray of float32. + If it is an array, it can either be a 1D array of + same length as the data or a 2D array with 2 rows + of same length as the data: row 0 for negative errors, + row 1 for positive errors. + :param yError: The absolute error on the Y axis. + :type yError: A float, or a numpy.ndarray of float32. See xError. + :param float xMin: The min X value already computed by GLPlotCurve2D. + :param float yMin: The min Y value already computed by GLPlotCurve2D. + :param List[float] color: RGBA color as 4 float in [0, 1] + :param List[float] offset: Translation of coordinates (ox, oy) + """ + + def __init__(self, xData, yData, xError, yError, + xMin, yMin, + color=(0., 0., 0., 1.), + offset=(0., 0.)): + self._attribs = None + self._xMin, self._yMin = xMin, yMin + self.offset = offset + + if xError is not None or yError is not None: + self._xData = numpy.array( + xData, order='C', dtype=numpy.float32, copy=False) + self._yData = numpy.array( + yData, order='C', dtype=numpy.float32, copy=False) + + # This also works if xError, yError is a float/int + self._xError = numpy.array( + xError, order='C', dtype=numpy.float32, copy=False) + self._yError = numpy.array( + yError, order='C', dtype=numpy.float32, copy=False) + else: + self._xData, self._yData = None, None + self._xError, self._yError = None, None + + self._lines = GLLines2D( + None, None, color=color, drawMode=gl.GL_LINES, offset=offset) + self._xErrPoints = _Points2D( + None, None, color=color, marker=V_LINE, offset=offset) + self._yErrPoints = _Points2D( + None, None, color=color, marker=H_LINE, offset=offset) + + def _buildVertices(self): + """Generates error bars vertices""" + nbLinesPerDataPts = (0 if self._xError is None else 2) + \ + (0 if self._yError is None else 2) + + nbDataPts = len(self._xData) + + # interleave coord+error, coord-error. + # xError vertices first if any, then yError vertices if any. + xCoords = numpy.empty(nbDataPts * nbLinesPerDataPts * 2, + dtype=numpy.float32) + yCoords = numpy.empty(nbDataPts * nbLinesPerDataPts * 2, + dtype=numpy.float32) + + if self._xError is not None: # errors on the X axis + if len(self._xError.shape) == 2: + xErrorMinus, xErrorPlus = self._xError[0], self._xError[1] + else: + # numpy arrays of len 1 or len(xData) + xErrorMinus, xErrorPlus = self._xError, self._xError + + # Interleave vertices for xError + endXError = 4 * nbDataPts + with numpy.errstate(invalid="ignore"): + xCoords[0:endXError-3:4] = self._xData + xErrorPlus + xCoords[1:endXError-2:4] = self._xData + xCoords[2:endXError-1:4] = self._xData + with numpy.errstate(invalid="ignore"): + xCoords[3:endXError:4] = self._xData - xErrorMinus + + yCoords[0:endXError-3:4] = self._yData + yCoords[1:endXError-2:4] = self._yData + yCoords[2:endXError-1:4] = self._yData + yCoords[3:endXError:4] = self._yData + + else: + endXError = 0 + + if self._yError is not None: # errors on the Y axis + if len(self._yError.shape) == 2: + yErrorMinus, yErrorPlus = self._yError[0], self._yError[1] + else: + # numpy arrays of len 1 or len(yData) + yErrorMinus, yErrorPlus = self._yError, self._yError + + # Interleave vertices for yError + xCoords[endXError::4] = self._xData + xCoords[endXError+1::4] = self._xData + xCoords[endXError+2::4] = self._xData + xCoords[endXError+3::4] = self._xData + + with numpy.errstate(invalid="ignore"): + yCoords[endXError::4] = self._yData + yErrorPlus + yCoords[endXError+1::4] = self._yData + yCoords[endXError+2::4] = self._yData + with numpy.errstate(invalid="ignore"): + yCoords[endXError+3::4] = self._yData - yErrorMinus + + return xCoords, yCoords + + def prepare(self): + """Rendering preparation: build indices and bounding box vertices""" + if self._xData is None: + return + + if self._attribs is None: + xCoords, yCoords = self._buildVertices() + + xAttrib, yAttrib = vertexBuffer((xCoords, yCoords)) + self._attribs = xAttrib, yAttrib + + self._lines.xVboData = xAttrib + self._lines.yVboData = yAttrib + + # Set xError points using the same VBO as lines + self._xErrPoints.xVboData = xAttrib.copy() + self._xErrPoints.xVboData.size //= 2 + self._xErrPoints.yVboData = yAttrib.copy() + self._xErrPoints.yVboData.size //= 2 + + # Set yError points using the same VBO as lines + self._yErrPoints.xVboData = xAttrib.copy() + self._yErrPoints.xVboData.size //= 2 + self._yErrPoints.xVboData.offset += (xAttrib.itemsize * + xAttrib.size // 2) + self._yErrPoints.yVboData = yAttrib.copy() + self._yErrPoints.yVboData.size //= 2 + self._yErrPoints.yVboData.offset += (yAttrib.itemsize * + yAttrib.size // 2) + + def render(self, context): + """Perform rendering + + :param RenderContext context: + """ + self.prepare() + + if self._attribs is not None: + self._lines.render(context) + self._xErrPoints.render(context) + self._yErrPoints.render(context) + + def discard(self): + """Release VBOs""" + if self.isInitialized(): + self._lines.xVboData, self._lines.yVboData = None, None + self._xErrPoints.xVboData, self._xErrPoints.yVboData = None, None + self._yErrPoints.xVboData, self._yErrPoints.yVboData = None, None + self._attribs[0].vbo.discard() + self._attribs = None + + def isInitialized(self): + return self._attribs is not None + + +# curves ###################################################################### + +def _proxyProperty(*componentsAttributes): + """Create a property to access an attribute of attribute(s). + Useful for composition. + Supports multiple components this way: + getter returns the first found, setter sets all + """ + def getter(self): + for compName, attrName in componentsAttributes: + try: + component = getattr(self, compName) + except AttributeError: + pass + else: + return getattr(component, attrName) + + def setter(self, value): + for compName, attrName in componentsAttributes: + component = getattr(self, compName) + setattr(component, attrName, value) + return property(getter, setter) + + +class GLPlotCurve2D(GLPlotItem): + def __init__(self, xData, yData, colorData=None, + xError=None, yError=None, + lineStyle=SOLID, + lineColor=(0., 0., 0., 1.), + lineWidth=1, + lineDashPeriod=20, + marker=SQUARE, + markerColor=(0., 0., 0., 1.), + markerSize=7, + fillColor=None, + baseline=None, + isYLog=False): + super().__init__() + self.colorData = colorData + + # Compute x bounds + if xError is None: + self.xMin, self.xMax = min_max(xData, min_positive=False) + else: + # Takes the error into account + if hasattr(xError, 'shape') and len(xError.shape) == 2: + xErrorMinus, xErrorPlus = xError[0], xError[1] + else: + xErrorMinus, xErrorPlus = xError, xError + self.xMin = numpy.nanmin(xData - xErrorMinus) + self.xMax = numpy.nanmax(xData + xErrorPlus) + + # Compute y bounds + if yError is None: + self.yMin, self.yMax = min_max(yData, min_positive=False) + else: + # Takes the error into account + if hasattr(yError, 'shape') and len(yError.shape) == 2: + yErrorMinus, yErrorPlus = yError[0], yError[1] + else: + yErrorMinus, yErrorPlus = yError, yError + self.yMin = numpy.nanmin(yData - yErrorMinus) + self.yMax = numpy.nanmax(yData + yErrorPlus) + + # Handle data offset + if xData.itemsize > 4 or yData.itemsize > 4: # Use normalization + # offset data, do not offset error as it is relative + self.offset = self.xMin, self.yMin + with numpy.errstate(invalid="ignore"): + self.xData = (xData - self.offset[0]).astype(numpy.float32) + self.yData = (yData - self.offset[1]).astype(numpy.float32) + + else: # float32 + self.offset = 0., 0. + self.xData = xData + self.yData = yData + if fillColor is not None: + def deduce_baseline(baseline): + if baseline is None: + _baseline = 0 + else: + _baseline = baseline + if not isinstance(_baseline, numpy.ndarray): + _baseline = numpy.repeat(_baseline, + len(self.xData)) + if isYLog is True: + with numpy.errstate(divide='ignore', invalid='ignore'): + log_val = numpy.log10(_baseline) + _baseline = numpy.where(_baseline>0.0, log_val, -38) + return _baseline + + _baseline = deduce_baseline(baseline) + + # Use different baseline depending of Y log scale + self.fill = _Fill2D(self.xData, self.yData, + baseline=_baseline, + color=fillColor, + offset=self.offset) + else: + self.fill = None + + self._errorBars = _ErrorBars(self.xData, self.yData, + xError, yError, + self.xMin, self.yMin, + offset=self.offset) + + self.lines = GLLines2D() + self.lines.style = lineStyle + self.lines.color = lineColor + self.lines.width = lineWidth + self.lines.dashPeriod = lineDashPeriod + self.lines.offset = self.offset + + self.points = _Points2D() + self.points.marker = marker + self.points.color = markerColor + self.points.size = markerSize + self.points.offset = self.offset + + xVboData = _proxyProperty(('lines', 'xVboData'), ('points', 'xVboData')) + + yVboData = _proxyProperty(('lines', 'yVboData'), ('points', 'yVboData')) + + colorVboData = _proxyProperty(('lines', 'colorVboData'), + ('points', 'colorVboData')) + + useColorVboData = _proxyProperty(('lines', 'useColorVboData'), + ('points', 'useColorVboData')) + + distVboData = _proxyProperty(('lines', 'distVboData')) + + lineStyle = _proxyProperty(('lines', 'style')) + + lineColor = _proxyProperty(('lines', 'color')) + + lineWidth = _proxyProperty(('lines', 'width')) + + lineDashPeriod = _proxyProperty(('lines', 'dashPeriod')) + + marker = _proxyProperty(('points', 'marker')) + + markerColor = _proxyProperty(('points', 'color')) + + markerSize = _proxyProperty(('points', 'size')) + + @classmethod + def init(cls): + """OpenGL context initialization""" + GLLines2D.init() + _Points2D.init() + + def prepare(self): + """Rendering preparation: build indices and bounding box vertices""" + if self.xVboData is None: + xAttrib, yAttrib, cAttrib, dAttrib = None, None, None, None + if self.lineStyle in (DASHED, DASHDOT, DOTTED): + dists = distancesFromArrays(self.xData, self.yData) + if self.colorData is None: + xAttrib, yAttrib, dAttrib = vertexBuffer( + (self.xData, self.yData, dists)) + else: + xAttrib, yAttrib, cAttrib, dAttrib = vertexBuffer( + (self.xData, self.yData, self.colorData, dists)) + elif self.colorData is None: + xAttrib, yAttrib = vertexBuffer((self.xData, self.yData)) + else: + xAttrib, yAttrib, cAttrib = vertexBuffer( + (self.xData, self.yData, self.colorData)) + + self.xVboData = xAttrib + self.yVboData = yAttrib + self.distVboData = dAttrib + + if cAttrib is not None and self.colorData.dtype.kind == 'u': + cAttrib.normalization = True # Normalize uint to [0, 1] + self.colorVboData = cAttrib + self.useColorVboData = cAttrib is not None + + def render(self, context): + """Perform rendering + + :param RenderContext context: Rendering information + """ + self.prepare() + if self.fill is not None: + self.fill.render(context) + self._errorBars.render(context) + self.lines.render(context) + self.points.render(context) + + def discard(self): + """Release VBOs""" + if self.xVboData is not None: + self.xVboData.vbo.discard() + + self.xVboData = None + self.yVboData = None + self.colorVboData = None + self.distVboData = None + + self._errorBars.discard() + if self.fill is not None: + self.fill.discard() + + def isInitialized(self): + return (self.xVboData is not None or + self._errorBars.isInitialized() or + (self.fill is not None and self.fill.isInitialized())) + + def pick(self, xPickMin, yPickMin, xPickMax, yPickMax): + """Perform picking on the curve according to its rendering. + + The picking area is [xPickMin, xPickMax], [yPickMin, yPickMax]. + + In case a segment between 2 points with indices i, i+1 is picked, + only its lower index end point (i.e., i) is added to the result. + In case an end point with index i is picked it is added to the result, + and the segment [i-1, i] is not tested for picking. + + :return: The indices of the picked data + :rtype: Union[List[int],None] + """ + if (self.marker is None and self.lineStyle is None) or \ + self.xMin > xPickMax or xPickMin > self.xMax or \ + self.yMin > yPickMax or yPickMin > self.yMax: + return None + + # offset picking bounds + xPickMin = xPickMin - self.offset[0] + xPickMax = xPickMax - self.offset[0] + yPickMin = yPickMin - self.offset[1] + yPickMax = yPickMax - self.offset[1] + + if self.lineStyle is not None: + # Using Cohen-Sutherland algorithm for line clipping + with numpy.errstate(invalid='ignore'): # Ignore NaN comparison warnings + codes = ((self.yData > yPickMax) << 3) | \ + ((self.yData < yPickMin) << 2) | \ + ((self.xData > xPickMax) << 1) | \ + (self.xData < xPickMin) + + notNaN = numpy.logical_not(numpy.logical_or( + numpy.isnan(self.xData), numpy.isnan(self.yData))) + + # Add all points that are inside the picking area + indices = numpy.nonzero( + numpy.logical_and(codes == 0, notNaN))[0].tolist() + + # Segment that might cross the area with no end point inside it + segToTestIdx = numpy.nonzero((codes[:-1] != 0) & + (codes[1:] != 0) & + ((codes[:-1] & codes[1:]) == 0))[0] + + TOP, BOTTOM, RIGHT, LEFT = (1 << 3), (1 << 2), (1 << 1), (1 << 0) + + for index in segToTestIdx: + if index not in indices: + x0, y0 = self.xData[index], self.yData[index] + x1, y1 = self.xData[index + 1], self.yData[index + 1] + code1 = codes[index + 1] + + # check for crossing with horizontal bounds + # y0 == y1 is a never event: + # => pt0 and pt1 in same vertical area are not in segToTest + if code1 & TOP: + x = x0 + (x1 - x0) * (yPickMax - y0) / (y1 - y0) + elif code1 & BOTTOM: + x = x0 + (x1 - x0) * (yPickMin - y0) / (y1 - y0) + else: + x = None # No horizontal bounds intersection test + + if x is not None and xPickMin <= x <= xPickMax: + # Intersection + indices.append(index) + + else: + # check for crossing with vertical bounds + # x0 == x1 is a never event (see remark for y) + if code1 & RIGHT: + y = y0 + (y1 - y0) * (xPickMax - x0) / (x1 - x0) + elif code1 & LEFT: + y = y0 + (y1 - y0) * (xPickMin - x0) / (x1 - x0) + else: + y = None # No vertical bounds intersection test + + if y is not None and yPickMin <= y <= yPickMax: + # Intersection + indices.append(index) + + indices.sort() + + else: + with numpy.errstate(invalid='ignore'): # Ignore NaN comparison warnings + indices = numpy.nonzero((self.xData >= xPickMin) & + (self.xData <= xPickMax) & + (self.yData >= yPickMin) & + (self.yData <= yPickMax))[0].tolist() + + return tuple(indices) if len(indices) > 0 else None diff --git a/src/silx/gui/plot/backends/glutils/GLPlotFrame.py b/src/silx/gui/plot/backends/glutils/GLPlotFrame.py new file mode 100644 index 0000000..1fccb02 --- /dev/null +++ b/src/silx/gui/plot/backends/glutils/GLPlotFrame.py @@ -0,0 +1,1210 @@ +# coding: utf-8 +# /*########################################################################## +# +# Copyright (c) 2014-2021 European Synchrotron Radiation Facility +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. +# +# ############################################################################*/ +""" +This modules provides the rendering of plot titles, axes and grid. +""" + +__authors__ = ["T. Vincent"] +__license__ = "MIT" +__date__ = "03/04/2017" + + +# TODO +# keep aspect ratio managed here? +# smarter dirty flag handling? + +import datetime as dt +import math +import weakref +import logging +from collections import namedtuple + +import numpy + +from ...._glutils import gl, Program +from ..._utils import checkAxisLimits, FLOAT32_MINPOS +from .GLSupport import mat4Ortho +from .GLText import Text2D, CENTER, BOTTOM, TOP, LEFT, RIGHT, ROTATE_270 +from ..._utils.ticklayout import niceNumbersAdaptative, niceNumbersForLog10 +from ..._utils.dtime_ticklayout import calcTicksAdaptive, bestFormatString +from ..._utils.dtime_ticklayout import timestamp + +_logger = logging.getLogger(__name__) + + +# PlotAxis #################################################################### + +class PlotAxis(object): + """Represents a 1D axis of the plot. + This class is intended to be used with :class:`GLPlotFrame`. + """ + + def __init__(self, plotFrame, + tickLength=(0., 0.), + foregroundColor=(0., 0., 0., 1.0), + labelAlign=CENTER, labelVAlign=CENTER, + titleAlign=CENTER, titleVAlign=CENTER, + titleRotate=0, titleOffset=(0., 0.)): + self._ticks = None + + self._plotFrameRef = weakref.ref(plotFrame) + + self._isDateTime = False + self._timeZone = None + self._isLog = False + self._dataRange = 1., 100. + self._displayCoords = (0., 0.), (1., 0.) + self._title = '' + + self._tickLength = tickLength + self._foregroundColor = foregroundColor + self._labelAlign = labelAlign + self._labelVAlign = labelVAlign + self._titleAlign = titleAlign + self._titleVAlign = titleVAlign + self._titleRotate = titleRotate + self._titleOffset = titleOffset + + @property + def dataRange(self): + """The range of the data represented on the axis as a tuple + of 2 floats: (min, max).""" + return self._dataRange + + @dataRange.setter + def dataRange(self, dataRange): + assert len(dataRange) == 2 + assert dataRange[0] <= dataRange[1] + dataRange = float(dataRange[0]), float(dataRange[1]) + + if dataRange != self._dataRange: + self._dataRange = dataRange + self._dirtyTicks() + + @property + def isLog(self): + """Whether the axis is using a log10 scale or not as a bool.""" + return self._isLog + + @isLog.setter + def isLog(self, isLog): + isLog = bool(isLog) + if isLog != self._isLog: + self._isLog = isLog + self._dirtyTicks() + + @property + def timeZone(self): + """Returnss datetime.tzinfo that is used if this axis plots date times.""" + return self._timeZone + + @timeZone.setter + def timeZone(self, tz): + """Sets dateetime.tzinfo that is used if this axis plots date times.""" + self._timeZone = tz + self._dirtyTicks() + + @property + def isTimeSeries(self): + """Whether the axis is showing floats as datetime objects""" + return self._isDateTime + + @isTimeSeries.setter + def isTimeSeries(self, isTimeSeries): + isTimeSeries = bool(isTimeSeries) + if isTimeSeries != self._isDateTime: + self._isDateTime = isTimeSeries + self._dirtyTicks() + + @property + def displayCoords(self): + """The coordinates of the start and end points of the axis + in display space (i.e., in pixels) as a tuple of 2 tuples of + 2 floats: ((x0, y0), (x1, y1)). + """ + return self._displayCoords + + @displayCoords.setter + def displayCoords(self, displayCoords): + assert len(displayCoords) == 2 + assert len(displayCoords[0]) == 2 + assert len(displayCoords[1]) == 2 + displayCoords = tuple(displayCoords[0]), tuple(displayCoords[1]) + if displayCoords != self._displayCoords: + self._displayCoords = displayCoords + self._dirtyTicks() + + @property + def devicePixelRatio(self): + """Returns the ratio between qt pixels and device pixels.""" + plotFrame = self._plotFrameRef() + return plotFrame.devicePixelRatio if plotFrame is not None else 1. + + @property + def title(self): + """The text label associated with this axis as a str in latin-1.""" + return self._title + + @title.setter + def title(self, title): + if title != self._title: + self._title = title + self._dirtyPlotFrame() + + @property + def titleOffset(self): + """Title offset in pixels (x: int, y: int)""" + return self._titleOffset + + @titleOffset.setter + def titleOffset(self, offset): + if offset != self._titleOffset: + self._titleOffset = offset + self._dirtyTicks() + + @property + def foregroundColor(self): + """Color used for frame and labels""" + return self._foregroundColor + + @foregroundColor.setter + def foregroundColor(self, color): + """Color used for frame and labels""" + assert len(color) == 4, \ + "foregroundColor must have length 4, got {}".format(len(self._foregroundColor)) + if self._foregroundColor != color: + self._foregroundColor = color + self._dirtyTicks() + + @property + def ticks(self): + """Ticks as tuples: ((x, y) in display, dataPos, textLabel).""" + if self._ticks is None: + self._ticks = tuple(self._ticksGenerator()) + return self._ticks + + def getVerticesAndLabels(self): + """Create the list of vertices for axis and associated text labels. + + :returns: A tuple: List of 2D line vertices, List of Text2D labels. + """ + vertices = list(self.displayCoords) # Add start and end points + labels = [] + tickLabelsSize = [0., 0.] + + xTickLength, yTickLength = self._tickLength + xTickLength *= self.devicePixelRatio + yTickLength *= self.devicePixelRatio + for (xPixel, yPixel), dataPos, text in self.ticks: + if text is None: + tickScale = 0.5 + else: + tickScale = 1. + + label = Text2D(text=text, + color=self._foregroundColor, + x=xPixel - xTickLength, + y=yPixel - yTickLength, + align=self._labelAlign, + valign=self._labelVAlign, + devicePixelRatio=self.devicePixelRatio) + + width, height = label.size + if width > tickLabelsSize[0]: + tickLabelsSize[0] = width + if height > tickLabelsSize[1]: + tickLabelsSize[1] = height + + labels.append(label) + + vertices.append((xPixel, yPixel)) + vertices.append((xPixel + tickScale * xTickLength, + yPixel + tickScale * yTickLength)) + + (x0, y0), (x1, y1) = self.displayCoords + xAxisCenter = 0.5 * (x0 + x1) + yAxisCenter = 0.5 * (y0 + y1) + + xOffset, yOffset = self.titleOffset + + # Adaptative title positioning: + # tickNorm = math.sqrt(xTickLength ** 2 + yTickLength ** 2) + # xOffset = -tickLabelsSize[0] * xTickLength / tickNorm + # xOffset -= 3 * xTickLength + # yOffset = -tickLabelsSize[1] * yTickLength / tickNorm + # yOffset -= 3 * yTickLength + + axisTitle = Text2D(text=self.title, + color=self._foregroundColor, + x=xAxisCenter + xOffset, + y=yAxisCenter + yOffset, + align=self._titleAlign, + valign=self._titleVAlign, + rotate=self._titleRotate, + devicePixelRatio=self.devicePixelRatio) + labels.append(axisTitle) + + return vertices, labels + + def _dirtyPlotFrame(self): + """Dirty parent GLPlotFrame""" + plotFrame = self._plotFrameRef() + if plotFrame is not None: + plotFrame._dirty() + + def _dirtyTicks(self): + """Mark ticks as dirty and notify listener (i.e., background).""" + self._ticks = None + self._dirtyPlotFrame() + + @staticmethod + def _frange(start, stop, step): + """range for float (including stop).""" + while start <= stop: + yield start + start += step + + def _ticksGenerator(self): + """Generator of ticks as tuples: + ((x, y) in display, dataPos, textLabel). + """ + dataMin, dataMax = self.dataRange + if self.isLog and dataMin <= 0.: + _logger.warning( + 'Getting ticks while isLog=True and dataRange[0]<=0.') + dataMin = 1. + if dataMax < dataMin: + dataMax = 1. + + if dataMin != dataMax: # data range is not null + (x0, y0), (x1, y1) = self.displayCoords + + if self.isLog: + + if self.isTimeSeries: + _logger.warning("Time series not implemented for log-scale") + + logMin, logMax = math.log10(dataMin), math.log10(dataMax) + tickMin, tickMax, step, _ = niceNumbersForLog10(logMin, logMax) + + xScale = (x1 - x0) / (logMax - logMin) + yScale = (y1 - y0) / (logMax - logMin) + + for logPos in self._frange(tickMin, tickMax, step): + if logMin <= logPos <= logMax: + dataPos = 10 ** logPos + xPixel = x0 + (logPos - logMin) * xScale + yPixel = y0 + (logPos - logMin) * yScale + text = '1e%+03d' % logPos + yield ((xPixel, yPixel), dataPos, text) + + if step == 1: + ticks = list(self._frange(tickMin, tickMax, step))[:-1] + for logPos in ticks: + dataOrigPos = 10 ** logPos + for index in range(2, 10): + dataPos = dataOrigPos * index + if dataMin <= dataPos <= dataMax: + logSubPos = math.log10(dataPos) + xPixel = x0 + (logSubPos - logMin) * xScale + yPixel = y0 + (logSubPos - logMin) * yScale + yield ((xPixel, yPixel), dataPos, None) + + else: + xScale = (x1 - x0) / (dataMax - dataMin) + yScale = (y1 - y0) / (dataMax - dataMin) + + nbPixels = math.sqrt(pow(x1 - x0, 2) + pow(y1 - y0, 2)) / self.devicePixelRatio + + # Density of 1.3 label per 92 pixels + # i.e., 1.3 label per inch on a 92 dpi screen + tickDensity = 1.3 / 92 + + if not self.isTimeSeries: + tickMin, tickMax, step, nbFrac = niceNumbersAdaptative( + dataMin, dataMax, nbPixels, tickDensity) + + for dataPos in self._frange(tickMin, tickMax, step): + if dataMin <= dataPos <= dataMax: + xPixel = x0 + (dataPos - dataMin) * xScale + yPixel = y0 + (dataPos - dataMin) * yScale + + if nbFrac == 0: + text = '%g' % dataPos + else: + text = ('%.' + str(nbFrac) + 'f') % dataPos + yield ((xPixel, yPixel), dataPos, text) + else: + # Time series + dtMin = dt.datetime.fromtimestamp(dataMin, tz=self.timeZone) + dtMax = dt.datetime.fromtimestamp(dataMax, tz=self.timeZone) + + tickDateTimes, spacing, unit = calcTicksAdaptive( + dtMin, dtMax, nbPixels, tickDensity) + + for tickDateTime in tickDateTimes: + if dtMin <= tickDateTime <= dtMax: + + dataPos = timestamp(tickDateTime) + xPixel = x0 + (dataPos - dataMin) * xScale + yPixel = y0 + (dataPos - dataMin) * yScale + + fmtStr = bestFormatString(spacing, unit) + text = tickDateTime.strftime(fmtStr) + + yield ((xPixel, yPixel), dataPos, text) + + +# GLPlotFrame ################################################################# + +class GLPlotFrame(object): + """Base class for rendering a 2D frame surrounded by axes.""" + + _TICK_LENGTH_IN_PIXELS = 5 + _LINE_WIDTH = 1 + + _SHADERS = { + 'vertex': """ + attribute vec2 position; + uniform mat4 matrix; + + void main(void) { + gl_Position = matrix * vec4(position, 0.0, 1.0); + } + """, + 'fragment': """ + uniform vec4 color; + uniform float tickFactor; /* = 1./tickLength or 0. for solid line */ + + void main(void) { + if (mod(tickFactor * (gl_FragCoord.x + gl_FragCoord.y), 2.) < 1.) { + gl_FragColor = color; + } else { + discard; + } + } + """ + } + + _Margins = namedtuple('Margins', ('left', 'right', 'top', 'bottom')) + + # Margins used when plot frame is not displayed + _NoDisplayMargins = _Margins(0, 0, 0, 0) + + def __init__(self, marginRatios, foregroundColor, gridColor): + """ + :param List[float] marginRatios: + The ratios of margins around plot area for axis and labels. + (left, top, right, bottom) as float in [0., 1.] + :param foregroundColor: color used for the frame and labels. + :type foregroundColor: tuple with RGBA values ranging from 0.0 to 1.0 + :param gridColor: color used for grid lines. + :type gridColor: tuple RGBA with RGBA values ranging from 0.0 to 1.0 + """ + self._renderResources = None + + self.__marginRatios = marginRatios + self.__marginsCache = None + + self._foregroundColor = foregroundColor + self._gridColor = gridColor + + self.axes = [] # List of PlotAxis to be updated by subclasses + + self._grid = False + self._size = 0., 0. + self._title = '' + + self._devicePixelRatio = 1. + + @property + def isDirty(self): + """True if it need to refresh graphic rendering, False otherwise.""" + return self._renderResources is None + + GRID_NONE = 0 + GRID_MAIN_TICKS = 1 + GRID_SUB_TICKS = 2 + GRID_ALL_TICKS = (GRID_MAIN_TICKS + GRID_SUB_TICKS) + + @property + def foregroundColor(self): + """Color used for frame and labels""" + return self._foregroundColor + + @foregroundColor.setter + def foregroundColor(self, color): + """Color used for frame and labels""" + assert len(color) == 4, \ + "foregroundColor must have length 4, got {}".format(len(self._foregroundColor)) + if self._foregroundColor != color: + self._foregroundColor = color + for axis in self.axes: + axis.foregroundColor = color + self._dirty() + + @property + def gridColor(self): + """Color used for frame and labels""" + return self._gridColor + + @gridColor.setter + def gridColor(self, color): + """Color used for frame and labels""" + assert len(color) == 4, \ + "gridColor must have length 4, got {}".format(len(self._gridColor)) + if self._gridColor != color: + self._gridColor = color + self._dirty() + + @property + def marginRatios(self): + """Plot margin ratios: (left, top, right, bottom) as 4 float in [0, 1]. + """ + return self.__marginRatios + + @marginRatios.setter + def marginRatios(self, ratios): + ratios = tuple(float(v) for v in ratios) + assert len(ratios) == 4 + for value in ratios: + assert 0. <= value <= 1. + assert ratios[0] + ratios[2] < 1. + assert ratios[1] + ratios[3] < 1. + + if self.__marginRatios != ratios: + self.__marginRatios = ratios + self.__marginsCache = None # Clear cached margins + self._dirty() + + @property + def margins(self): + """Margins in pixels around the plot.""" + if self.__marginsCache is None: + width, height = self.size + left, top, right, bottom = self.marginRatios + self.__marginsCache = self._Margins( + left=int(left*width), + right=int(right*width), + top=int(top*height), + bottom=int(bottom*height)) + return self.__marginsCache + + @property + def devicePixelRatio(self): + return self._devicePixelRatio + + @devicePixelRatio.setter + def devicePixelRatio(self, ratio): + if ratio != self._devicePixelRatio: + self._devicePixelRatio = ratio + self._dirty() + + @property + def grid(self): + """Grid display mode: + - 0: No grid. + - 1: Grid on main ticks. + - 2: Grid on sub-ticks for log scale axes. + - 3: Grid on main and sub ticks.""" + return self._grid + + @grid.setter + def grid(self, grid): + assert grid in (self.GRID_NONE, self.GRID_MAIN_TICKS, + self.GRID_SUB_TICKS, self.GRID_ALL_TICKS) + if grid != self._grid: + self._grid = grid + self._dirty() + + @property + def size(self): + """Size in device pixels of the plot area including margins.""" + return self._size + + @size.setter + def size(self, size): + assert len(size) == 2 + size = tuple(size) + if size != self._size: + self._size = size + self.__marginsCache = None # Clear cached margins + self._dirty() + + @property + def plotOrigin(self): + """Plot area origin (left, top) in widget coordinates in pixels.""" + return self.margins.left, self.margins.top + + @property + def plotSize(self): + """Plot area size (width, height) in pixels.""" + w, h = self.size + w -= self.margins.left + self.margins.right + h -= self.margins.top + self.margins.bottom + return w, h + + @property + def title(self): + """Main title as a str in latin-1.""" + return self._title + + @title.setter + def title(self, title): + if title != self._title: + self._title = title + self._dirty() + + # In-place update + # if self._renderResources is not None: + # self._renderResources[-1][-1].text = title + + def _dirty(self): + # When Text2D require discard we need to handle it + self._renderResources = None + + def _buildGridVertices(self): + if self._grid == self.GRID_NONE: + return [] + + elif self._grid == self.GRID_MAIN_TICKS: + def test(text): + return text is not None + elif self._grid == self.GRID_SUB_TICKS: + def test(text): + return text is None + elif self._grid == self.GRID_ALL_TICKS: + def test(_): + return True + else: + logging.warning('Wrong grid mode: %d' % self._grid) + return [] + + return self._buildGridVerticesWithTest(test) + + def _buildGridVerticesWithTest(self, test): + """Override in subclass to generate grid vertices""" + return [] + + def _buildVerticesAndLabels(self): + # To fill with copy of axes lists + vertices = [] + labels = [] + + for axis in self.axes: + axisVertices, axisLabels = axis.getVerticesAndLabels() + vertices += axisVertices + labels += axisLabels + + vertices = numpy.array(vertices, dtype=numpy.float32) + + # Add main title + xTitle = (self.size[0] + self.margins.left - + self.margins.right) // 2 + yTitle = self.margins.top - self._TICK_LENGTH_IN_PIXELS + labels.append(Text2D(text=self.title, + color=self._foregroundColor, + x=xTitle, + y=yTitle, + align=CENTER, + valign=BOTTOM, + devicePixelRatio=self.devicePixelRatio)) + + # grid + gridVertices = numpy.array(self._buildGridVertices(), + dtype=numpy.float32) + + self._renderResources = (vertices, gridVertices, labels) + + _program = Program( + _SHADERS['vertex'], _SHADERS['fragment'], attrib0='position') + + def render(self): + if self.margins == self._NoDisplayMargins: + return + + if self._renderResources is None: + self._buildVerticesAndLabels() + vertices, gridVertices, labels = self._renderResources + + width, height = self.size + matProj = mat4Ortho(0, width, height, 0, 1, -1) + + gl.glViewport(0, 0, width, height) + + prog = self._program + prog.use() + + gl.glLineWidth(self._LINE_WIDTH) + + gl.glUniformMatrix4fv(prog.uniforms['matrix'], 1, gl.GL_TRUE, + matProj.astype(numpy.float32)) + gl.glUniform4f(prog.uniforms['color'], *self._foregroundColor) + gl.glUniform1f(prog.uniforms['tickFactor'], 0.) + + gl.glEnableVertexAttribArray(prog.attributes['position']) + gl.glVertexAttribPointer(prog.attributes['position'], + 2, + gl.GL_FLOAT, + gl.GL_FALSE, + 0, vertices) + + gl.glDrawArrays(gl.GL_LINES, 0, len(vertices)) + + for label in labels: + label.render(matProj) + + def renderGrid(self): + if self._grid == self.GRID_NONE: + return + + if self._renderResources is None: + self._buildVerticesAndLabels() + vertices, gridVertices, labels = self._renderResources + + width, height = self.size + matProj = mat4Ortho(0, width, height, 0, 1, -1) + + gl.glViewport(0, 0, width, height) + + prog = self._program + prog.use() + + gl.glLineWidth(self._LINE_WIDTH) + gl.glUniformMatrix4fv(prog.uniforms['matrix'], 1, gl.GL_TRUE, + matProj.astype(numpy.float32)) + gl.glUniform4f(prog.uniforms['color'], *self._gridColor) + gl.glUniform1f(prog.uniforms['tickFactor'], 0.) # 1/2.) # 1/tickLen + + gl.glEnableVertexAttribArray(prog.attributes['position']) + gl.glVertexAttribPointer(prog.attributes['position'], + 2, + gl.GL_FLOAT, + gl.GL_FALSE, + 0, gridVertices) + + gl.glDrawArrays(gl.GL_LINES, 0, len(gridVertices)) + + +# GLPlotFrame2D ############################################################### + +class GLPlotFrame2D(GLPlotFrame): + def __init__(self, marginRatios, foregroundColor, gridColor): + """ + :param List[float] marginRatios: + The ratios of margins around plot area for axis and labels. + (left, top, right, bottom) as float in [0., 1.] + :param foregroundColor: color used for the frame and labels. + :type foregroundColor: tuple with RGBA values ranging from 0.0 to 1.0 + :param gridColor: color used for grid lines. + :type gridColor: tuple RGBA with RGBA values ranging from 0.0 to 1.0 + + """ + super(GLPlotFrame2D, self).__init__(marginRatios, foregroundColor, gridColor) + self.axes.append(PlotAxis(self, + tickLength=(0., -5.), + foregroundColor=self._foregroundColor, + labelAlign=CENTER, labelVAlign=TOP, + titleAlign=CENTER, titleVAlign=TOP, + titleRotate=0)) + + self._x2AxisCoords = () + + self.axes.append(PlotAxis(self, + tickLength=(5., 0.), + foregroundColor=self._foregroundColor, + labelAlign=RIGHT, labelVAlign=CENTER, + titleAlign=CENTER, titleVAlign=BOTTOM, + titleRotate=ROTATE_270)) + + self._y2Axis = PlotAxis(self, + tickLength=(-5., 0.), + foregroundColor=self._foregroundColor, + labelAlign=LEFT, labelVAlign=CENTER, + titleAlign=CENTER, titleVAlign=TOP, + titleRotate=ROTATE_270) + + self._isYAxisInverted = False + + self._dataRanges = { + 'x': (1., 100.), 'y': (1., 100.), 'y2': (1., 100.)} + + self._baseVectors = (1., 0.), (0., 1.) + + self._transformedDataRanges = None + self._transformedDataProjMat = None + self._transformedDataY2ProjMat = None + + def _dirty(self): + super(GLPlotFrame2D, self)._dirty() + self._transformedDataRanges = None + self._transformedDataProjMat = None + self._transformedDataY2ProjMat = None + + @property + def isDirty(self): + """True if it need to refresh graphic rendering, False otherwise.""" + return (super(GLPlotFrame2D, self).isDirty or + self._transformedDataRanges is None or + self._transformedDataProjMat is None or + self._transformedDataY2ProjMat is None) + + @property + def xAxis(self): + return self.axes[0] + + @property + def yAxis(self): + return self.axes[1] + + @property + def y2Axis(self): + return self._y2Axis + + @property + def isY2Axis(self): + """Whether to display the left Y axis or not.""" + return len(self.axes) == 3 + + @isY2Axis.setter + def isY2Axis(self, isY2Axis): + if isY2Axis != self.isY2Axis: + if isY2Axis: + self.axes.append(self._y2Axis) + else: + self.axes = self.axes[:2] + + self._dirty() + + @property + def isYAxisInverted(self): + """Whether Y axes are inverted or not as a bool.""" + return self._isYAxisInverted + + @isYAxisInverted.setter + def isYAxisInverted(self, value): + value = bool(value) + if value != self._isYAxisInverted: + self._isYAxisInverted = value + self._dirty() + + DEFAULT_BASE_VECTORS = (1., 0.), (0., 1.) + """Values of baseVectors for orthogonal axes.""" + + @property + def baseVectors(self): + """Coordinates of the X and Y axes in the orthogonal plot coords. + + Raises ValueError if corresponding matrix is singular. + + 2 tuples of 2 floats: (xx, xy), (yx, yy) + """ + return self._baseVectors + + @baseVectors.setter + def baseVectors(self, baseVectors): + self._dirty() + + (xx, xy), (yx, yy) = baseVectors + vectors = (float(xx), float(xy)), (float(yx), float(yy)) + + det = (vectors[0][0] * vectors[1][1] - vectors[1][0] * vectors[0][1]) + if det == 0.: + raise ValueError("Singular matrix for base vectors: " + + str(vectors)) + + if vectors != self._baseVectors: + self._baseVectors = vectors + self._dirty() + + def _updateTitleOffset(self): + """Update axes title offset according to margins""" + margins = self.margins + self.xAxis.titleOffset = 0, margins.bottom // 2 + self.yAxis.titleOffset = -3 * margins.left // 4, 0 + self.y2Axis.titleOffset = 3 * margins.right // 4, 0 + + # Override size and marginRatios setters to update titleOffsets + @GLPlotFrame.size.setter + def size(self, size): + GLPlotFrame.size.fset(self, size) + self._updateTitleOffset() + + @GLPlotFrame.marginRatios.setter + def marginRatios(self, ratios): + GLPlotFrame.marginRatios.fset(self, ratios) + self._updateTitleOffset() + + @property + def dataRanges(self): + """Ranges of data visible in the plot on x, y and y2 axes. + + This is different to the axes range when axes are not orthogonal. + + Type: ((xMin, xMax), (yMin, yMax), (y2Min, y2Max)) + """ + return self._DataRanges(self._dataRanges['x'], + self._dataRanges['y'], + self._dataRanges['y2']) + + def setDataRanges(self, x=None, y=None, y2=None): + """Set data range over each axes. + + The provided ranges are clipped to possible values + (i.e., 32 float range + positive range for log scale). + + :param x: (min, max) data range over X axis + :param y: (min, max) data range over Y axis + :param y2: (min, max) data range over Y2 axis + """ + if x is not None: + self._dataRanges['x'] = checkAxisLimits( + x[0], x[1], self.xAxis.isLog, name='x') + + if y is not None: + self._dataRanges['y'] = checkAxisLimits( + y[0], y[1], self.yAxis.isLog, name='y') + + if y2 is not None: + self._dataRanges['y2'] = checkAxisLimits( + y2[0], y2[1], self.y2Axis.isLog, name='y2') + + self.xAxis.dataRange = self._dataRanges['x'] + self.yAxis.dataRange = self._dataRanges['y'] + self.y2Axis.dataRange = self._dataRanges['y2'] + + _DataRanges = namedtuple('dataRanges', ('x', 'y', 'y2')) + + @property + def transformedDataRanges(self): + """Bounds of the displayed area in transformed data coordinates + (i.e., log scale applied if any as well as skew) + + 3-tuple of 2-tuple (min, max) for each axis: x, y, y2. + """ + if self._transformedDataRanges is None: + (xMin, xMax), (yMin, yMax), (y2Min, y2Max) = self.dataRanges + + if self.xAxis.isLog: + try: + xMin = math.log10(xMin) + except ValueError: + _logger.info('xMin: warning log10(%f)', xMin) + xMin = 0. + try: + xMax = math.log10(xMax) + except ValueError: + _logger.info('xMax: warning log10(%f)', xMax) + xMax = 0. + + if self.yAxis.isLog: + try: + yMin = math.log10(yMin) + except ValueError: + _logger.info('yMin: warning log10(%f)', yMin) + yMin = 0. + try: + yMax = math.log10(yMax) + except ValueError: + _logger.info('yMax: warning log10(%f)', yMax) + yMax = 0. + + try: + y2Min = math.log10(y2Min) + except ValueError: + _logger.info('yMin: warning log10(%f)', y2Min) + y2Min = 0. + try: + y2Max = math.log10(y2Max) + except ValueError: + _logger.info('yMax: warning log10(%f)', y2Max) + y2Max = 0. + + self._transformedDataRanges = self._DataRanges( + (xMin, xMax), (yMin, yMax), (y2Min, y2Max)) + + return self._transformedDataRanges + + @property + def transformedDataProjMat(self): + """Orthographic projection matrix for rendering transformed data + + :type: numpy.matrix + """ + if self._transformedDataProjMat is None: + xMin, xMax = self.transformedDataRanges.x + yMin, yMax = self.transformedDataRanges.y + + if self.isYAxisInverted: + mat = mat4Ortho(xMin, xMax, yMax, yMin, 1, -1) + else: + mat = mat4Ortho(xMin, xMax, yMin, yMax, 1, -1) + self._transformedDataProjMat = mat + + return self._transformedDataProjMat + + @property + def transformedDataY2ProjMat(self): + """Orthographic projection matrix for rendering transformed data + for the 2nd Y axis + + :type: numpy.matrix + """ + if self._transformedDataY2ProjMat is None: + xMin, xMax = self.transformedDataRanges.x + y2Min, y2Max = self.transformedDataRanges.y2 + + if self.isYAxisInverted: + mat = mat4Ortho(xMin, xMax, y2Max, y2Min, 1, -1) + else: + mat = mat4Ortho(xMin, xMax, y2Min, y2Max, 1, -1) + self._transformedDataY2ProjMat = mat + + return self._transformedDataY2ProjMat + + def dataToPixel(self, x, y, axis='left'): + """Convert data coordinate to widget pixel coordinate. + """ + assert axis in ('left', 'right') + + trBounds = self.transformedDataRanges + + if self.xAxis.isLog: + if x < FLOAT32_MINPOS: + return None + xDataTr = math.log10(x) + else: + xDataTr = x + + if self.yAxis.isLog: + if y < FLOAT32_MINPOS: + return None + yDataTr = math.log10(y) + else: + yDataTr = y + + # Non-orthogonal axes + if self.baseVectors != self.DEFAULT_BASE_VECTORS: + (xx, xy), (yx, yy) = self.baseVectors + skew_mat = numpy.array(((xx, yx), (xy, yy))) + + coords = numpy.dot(skew_mat, numpy.array((xDataTr, yDataTr))) + xDataTr, yDataTr = coords + + plotWidth, plotHeight = self.plotSize + + xPixel = int(self.margins.left + + plotWidth * (xDataTr - trBounds.x[0]) / + (trBounds.x[1] - trBounds.x[0])) + + usedAxis = trBounds.y if axis == "left" else trBounds.y2 + yOffset = (plotHeight * (yDataTr - usedAxis[0]) / + (usedAxis[1] - usedAxis[0])) + + if self.isYAxisInverted: + yPixel = int(self.margins.top + yOffset) + else: + yPixel = int(self.size[1] - self.margins.bottom - yOffset) + + return xPixel, yPixel + + def pixelToData(self, x, y, axis="left"): + """Convert pixel position to data coordinates. + + :param float x: X coord + :param float y: Y coord + :param str axis: Y axis to use in ('left', 'right') + :return: (x, y) position in data coords + """ + assert axis in ("left", "right") + + plotWidth, plotHeight = self.plotSize + + trBounds = self.transformedDataRanges + + xData = (x - self.margins.left + 0.5) / float(plotWidth) + xData = trBounds.x[0] + xData * (trBounds.x[1] - trBounds.x[0]) + + usedAxis = trBounds.y if axis == "left" else trBounds.y2 + if self.isYAxisInverted: + yData = (y - self.margins.top + 0.5) / float(plotHeight) + yData = usedAxis[0] + yData * (usedAxis[1] - usedAxis[0]) + else: + yData = self.size[1] - self.margins.bottom - y - 0.5 + yData /= float(plotHeight) + yData = usedAxis[0] + yData * (usedAxis[1] - usedAxis[0]) + + # non-orthogonal axis + if self.baseVectors != self.DEFAULT_BASE_VECTORS: + (xx, xy), (yx, yy) = self.baseVectors + skew_mat = numpy.array(((xx, yx), (xy, yy))) + skew_mat = numpy.linalg.inv(skew_mat) + + coords = numpy.dot(skew_mat, numpy.array((xData, yData))) + xData, yData = coords + + if self.xAxis.isLog: + xData = pow(10, xData) + if self.yAxis.isLog: + yData = pow(10, yData) + + return xData, yData + + def _buildGridVerticesWithTest(self, test): + vertices = [] + + if self.baseVectors == self.DEFAULT_BASE_VECTORS: + for axis in self.axes: + for (xPixel, yPixel), data, text in axis.ticks: + if test(text): + vertices.append((xPixel, yPixel)) + if axis == self.xAxis: + vertices.append((xPixel, self.margins.top)) + elif axis == self.yAxis: + vertices.append((self.size[0] - self.margins.right, + yPixel)) + else: # axis == self.y2Axis + vertices.append((self.margins.left, yPixel)) + + else: + # Get plot corners in data coords + plotLeft, plotTop = self.plotOrigin + plotWidth, plotHeight = self.plotSize + + corners = [(plotLeft, plotTop), + (plotLeft, plotTop + plotHeight), + (plotLeft + plotWidth, plotTop + plotHeight), + (plotLeft + plotWidth, plotTop)] + + for axis in self.axes: + if axis == self.xAxis: + cornersInData = numpy.array([ + self.pixelToData(x, y) for (x, y) in corners]) + borders = ((cornersInData[0], cornersInData[3]), # top + (cornersInData[1], cornersInData[0]), # left + (cornersInData[3], cornersInData[2])) # right + + for (xPixel, yPixel), data, text in axis.ticks: + if test(text): + for (x0, y0), (x1, y1) in borders: + if min(x0, x1) <= data < max(x0, x1): + yIntersect = (data - x0) * \ + (y1 - y0) / (x1 - x0) + y0 + + pixelPos = self.dataToPixel( + data, yIntersect) + if pixelPos is not None: + vertices.append((xPixel, yPixel)) + vertices.append(pixelPos) + break # Stop at first intersection + + else: # y or y2 axes + if axis == self.yAxis: + axis_name = 'left' + cornersInData = numpy.array([ + self.pixelToData(x, y) for (x, y) in corners]) + borders = ( + (cornersInData[3], cornersInData[2]), # right + (cornersInData[0], cornersInData[3]), # top + (cornersInData[2], cornersInData[1])) # bottom + + else: # axis == self.y2Axis + axis_name = 'right' + corners = numpy.array([self.pixelToData( + x, y, axis='right') for (x, y) in corners]) + borders = ( + (cornersInData[1], cornersInData[0]), # left + (cornersInData[0], cornersInData[3]), # top + (cornersInData[2], cornersInData[1])) # bottom + + for (xPixel, yPixel), data, text in axis.ticks: + if test(text): + for (x0, y0), (x1, y1) in borders: + if min(y0, y1) <= data < max(y0, y1): + xIntersect = (data - y0) * \ + (x1 - x0) / (y1 - y0) + x0 + + pixelPos = self.dataToPixel( + xIntersect, data, axis=axis_name) + if pixelPos is not None: + vertices.append((xPixel, yPixel)) + vertices.append(pixelPos) + break # Stop at first intersection + + return vertices + + def _buildVerticesAndLabels(self): + width, height = self.size + + xCoords = (self.margins.left - 0.5, + width - self.margins.right + 0.5) + yCoords = (height - self.margins.bottom + 0.5, + self.margins.top - 0.5) + + self.axes[0].displayCoords = ((xCoords[0], yCoords[0]), + (xCoords[1], yCoords[0])) + + self._x2AxisCoords = ((xCoords[0], yCoords[1]), + (xCoords[1], yCoords[1])) + + if self.isYAxisInverted: + # Y axes are inverted, axes coordinates are inverted + yCoords = yCoords[1], yCoords[0] + + self.axes[1].displayCoords = ((xCoords[0], yCoords[0]), + (xCoords[0], yCoords[1])) + + self._y2Axis.displayCoords = ((xCoords[1], yCoords[0]), + (xCoords[1], yCoords[1])) + + super(GLPlotFrame2D, self)._buildVerticesAndLabels() + + vertices, gridVertices, labels = self._renderResources + + # Adds vertices for borders without axis + extraVertices = [] + extraVertices += self._x2AxisCoords + if not self.isY2Axis: + extraVertices += self._y2Axis.displayCoords + + extraVertices = numpy.array( + extraVertices, copy=False, dtype=numpy.float32) + vertices = numpy.append(vertices, extraVertices, axis=0) + + self._renderResources = (vertices, gridVertices, labels) + + @property + def foregroundColor(self): + """Color used for frame and labels""" + return self._foregroundColor + + @foregroundColor.setter + def foregroundColor(self, color): + """Color used for frame and labels""" + assert len(color) == 4, \ + "foregroundColor must have length 4, got {}".format(len(self._foregroundColor)) + if self._foregroundColor != color: + self._y2Axis.foregroundColor = color + GLPlotFrame.foregroundColor.fset(self, color) # call parent property diff --git a/src/silx/gui/plot/backends/glutils/GLPlotImage.py b/src/silx/gui/plot/backends/glutils/GLPlotImage.py new file mode 100644 index 0000000..3ad94b9 --- /dev/null +++ b/src/silx/gui/plot/backends/glutils/GLPlotImage.py @@ -0,0 +1,756 @@ +# coding: utf-8 +# /*########################################################################## +# +# Copyright (c) 2014-2021 European Synchrotron Radiation Facility +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. +# +# ############################################################################*/ +""" +This module provides a class to render 2D array as a colormap or RGB(A) image +""" + +__authors__ = ["T. Vincent"] +__license__ = "MIT" +__date__ = "03/04/2017" + + +import math +import numpy + +from silx.math.combo import min_max + +from ...._glutils import gl, Program, Texture +from ..._utils import FLOAT32_MINPOS +from .GLSupport import mat4Translate, mat4Scale +from .GLTexture import Image +from .GLPlotItem import GLPlotItem + + +class _GLPlotData2D(GLPlotItem): + def __init__(self, data, origin, scale): + super().__init__() + self.data = data + assert len(origin) == 2 + self.origin = tuple(origin) + assert len(scale) == 2 + self.scale = tuple(scale) + + def pick(self, x, y): + if self.xMin <= x <= self.xMax and self.yMin <= y <= self.yMax: + ox, oy = self.origin + sx, sy = self.scale + col = int((x - ox) / sx) + row = int((y - oy) / sy) + return (row,), (col,) + else: + return None + + @property + def xMin(self): + ox, sx = self.origin[0], self.scale[0] + return ox if sx >= 0. else ox + sx * self.data.shape[1] + + @property + def yMin(self): + oy, sy = self.origin[1], self.scale[1] + return oy if sy >= 0. else oy + sy * self.data.shape[0] + + @property + def xMax(self): + ox, sx = self.origin[0], self.scale[0] + return ox + sx * self.data.shape[1] if sx >= 0. else ox + + @property + def yMax(self): + oy, sy = self.origin[1], self.scale[1] + return oy + sy * self.data.shape[0] if sy >= 0. else oy + + +class GLPlotColormap(_GLPlotData2D): + + _SHADERS = { + 'linear': { + 'vertex': """ + #version 120 + + uniform mat4 matrix; + attribute vec2 texCoords; + attribute vec2 position; + + varying vec2 coords; + + void main(void) { + coords = texCoords; + gl_Position = matrix * vec4(position, 0.0, 1.0); + } + """, + 'fragTransform': """ + vec2 textureCoords(void) { + return coords; + } + """}, + + 'log': { + 'vertex': """ + #version 120 + + attribute vec2 position; + uniform mat4 matrix; + uniform mat4 matOffset; + uniform bvec2 isLog; + + varying vec2 coords; + + const float oneOverLog10 = 0.43429448190325176; + + void main(void) { + vec4 dataPos = matOffset * vec4(position, 0.0, 1.0); + if (isLog.x) { + dataPos.x = oneOverLog10 * log(dataPos.x); + } + if (isLog.y) { + dataPos.y = oneOverLog10 * log(dataPos.y); + } + coords = dataPos.xy; + gl_Position = matrix * dataPos; + } + """, + 'fragTransform': """ + uniform bvec2 isLog; + uniform vec2 bounds_oneOverRange; + uniform vec2 bounds_originOverRange; + + vec2 textureCoords(void) { + vec2 pos = coords; + if (isLog.x) { + pos.x = pow(10., coords.x); + } + if (isLog.y) { + pos.y = pow(10., coords.y); + } + return pos * bounds_oneOverRange - bounds_originOverRange; + // TODO texture coords in range different from [0, 1] + } + """}, + + 'fragment': """ + #version 120 + + /* isnan declaration for compatibility with GLSL 1.20 */ + bool isnan(float value) { + return (value != value); + } + + uniform sampler2D data; + uniform sampler2D cmap_texture; + uniform int cmap_normalization; + uniform float cmap_parameter; + uniform float cmap_min; + uniform float cmap_oneOverRange; + uniform float alpha; + uniform vec4 nancolor; + + varying vec2 coords; + + %s + + const float oneOverLog10 = 0.43429448190325176; + + void main(void) { + float data = texture2D(data, textureCoords()).r; + float value = data; + if (cmap_normalization == 1) { /*Logarithm mapping*/ + if (value > 0.) { + value = clamp(cmap_oneOverRange * + (oneOverLog10 * log(value) - cmap_min), + 0., 1.); + } else { + value = 0.; + } + } else if (cmap_normalization == 2) { /*Square root mapping*/ + if (value >= 0.) { + value = clamp(cmap_oneOverRange * (sqrt(value) - cmap_min), + 0., 1.); + } else { + value = 0.; + } + } else if (cmap_normalization == 3) { /*Gamma correction mapping*/ + value = pow( + clamp(cmap_oneOverRange * (value - cmap_min), 0., 1.), + cmap_parameter); + } else if (cmap_normalization == 4) { /* arcsinh mapping */ + /* asinh = log(x + sqrt(x*x + 1) for compatibility with GLSL 1.20 */ + value = clamp(cmap_oneOverRange * (log(value + sqrt(value*value + 1.0)) - cmap_min), 0., 1.); + } else { /*Linear mapping and fallback*/ + value = clamp(cmap_oneOverRange * (value - cmap_min), 0., 1.); + } + + if (isnan(data)) { + gl_FragColor = nancolor; + } else { + gl_FragColor = texture2D(cmap_texture, vec2(value, 0.5)); + } + gl_FragColor.a *= alpha; + } + """ + } + + _DATA_TEX_UNIT = 0 + _CMAP_TEX_UNIT = 1 + + _INTERNAL_FORMATS = { + numpy.dtype(numpy.float32): gl.GL_R32F, + numpy.dtype(numpy.float16): gl.GL_R16F, + # Use normalized integer for unsigned int formats + numpy.dtype(numpy.uint16): gl.GL_R16, + numpy.dtype(numpy.uint8): gl.GL_R8, + } + + _linearProgram = Program(_SHADERS['linear']['vertex'], + _SHADERS['fragment'] % + _SHADERS['linear']['fragTransform'], + attrib0='position') + + _logProgram = Program(_SHADERS['log']['vertex'], + _SHADERS['fragment'] % + _SHADERS['log']['fragTransform'], + attrib0='position') + + SUPPORTED_NORMALIZATIONS = 'linear', 'log', 'sqrt', 'gamma', 'arcsinh' + + def __init__(self, data, origin, scale, + colormap, normalization='linear', gamma=0., cmapRange=None, + alpha=1.0, nancolor=(1., 1., 1., 0.)): + """Create a 2D colormap + + :param data: The 2D scalar data array to display + :type data: numpy.ndarray with 2 dimensions (dtype=numpy.float32) + :param origin: (x, y) coordinates of the origin of the data array + :type origin: 2-tuple of floats. + :param scale: (sx, sy) scale factors of the data array. + This is the size of a data pixel in plot data space. + :type scale: 2-tuple of floats. + :param str colormap: Name of the colormap to use + TODO: Accept a 1D scalar array as the colormap + :param str normalization: The colormap normalization. + One of: 'linear', 'log', 'sqrt', 'gamma' + ;param float gamma: The gamma parameter (for 'gamma' normalization) + :param cmapRange: The range of colormap or None for autoscale colormap + For logarithmic colormap, the range is in the untransformed data + TODO: check consistency with matplotlib + :type cmapRange: (float, float) or None + :param float alpha: Opacity from 0 (transparent) to 1 (opaque) + :param nancolor: RGBA color for Not-A-Number values + :type nancolor: 4-tuple of float in [0., 1.] + """ + assert data.dtype in self._INTERNAL_FORMATS + assert normalization in self.SUPPORTED_NORMALIZATIONS + + super(GLPlotColormap, self).__init__(data, origin, scale) + self.colormap = numpy.array(colormap, copy=False) + self.normalization = normalization + self.gamma = gamma + self._cmapRange = (1., 10.) # Colormap range + self.cmapRange = cmapRange # Update _cmapRange + self._alpha = numpy.clip(alpha, 0., 1.) + self._nancolor = numpy.clip(nancolor, 0., 1.) + + self._cmap_texture = None + self._texture = None + self._textureIsDirty = False + + def discard(self): + if self._cmap_texture is not None: + self._cmap_texture.discard() + self._cmap_texture = None + + if self._texture is not None: + self._texture.discard() + self._texture = None + self._textureIsDirty = False + + def isInitialized(self): + return (self._cmap_texture is not None or + self._texture is not None) + + @property + def cmapRange(self): + if self.normalization == 'log': + assert self._cmapRange[0] > 0. and self._cmapRange[1] > 0. + elif self.normalization == 'sqrt': + assert self._cmapRange[0] >= 0. and self._cmapRange[1] >= 0. + return self._cmapRange + + @cmapRange.setter + def cmapRange(self, cmapRange): + assert len(cmapRange) == 2 + assert cmapRange[0] <= cmapRange[1] + self._cmapRange = float(cmapRange[0]), float(cmapRange[1]) + + @property + def alpha(self): + return self._alpha + + def updateData(self, data): + assert data.dtype in self._INTERNAL_FORMATS + oldData = self.data + self.data = data + + if self._texture is not None: + if (self.data.shape != oldData.shape or + self.data.dtype != oldData.dtype): + self.discard() + else: + self._textureIsDirty = True + + def prepare(self): + if self._cmap_texture is None: + # TODO share cmap texture accross Images + # put all cmaps in one texture + colormap = numpy.empty((16, 256, self.colormap.shape[1]), + dtype=self.colormap.dtype) + colormap[:] = self.colormap + format_ = gl.GL_RGBA if colormap.shape[-1] == 4 else gl.GL_RGB + self._cmap_texture = Texture(internalFormat=format_, + data=colormap, + format_=format_, + texUnit=self._CMAP_TEX_UNIT, + minFilter=gl.GL_NEAREST, + magFilter=gl.GL_NEAREST, + wrap=(gl.GL_CLAMP_TO_EDGE, + gl.GL_CLAMP_TO_EDGE)) + self._cmap_texture.prepare() + + if self._texture is None: + internalFormat = self._INTERNAL_FORMATS[self.data.dtype] + + self._texture = Image(internalFormat, + self.data, + format_=gl.GL_RED, + texUnit=self._DATA_TEX_UNIT) + elif self._textureIsDirty: + self._textureIsDirty = True + self._texture.updateAll(format_=gl.GL_RED, data=self.data) + + def _setCMap(self, prog): + dataMin, dataMax = self.cmapRange # If log, it is stricly positive + param = 0. + + if self.data.dtype in (numpy.uint16, numpy.uint8): + # Using unsigned int as normalized integer in OpenGL + # So normalize range + maxInt = float(numpy.iinfo(self.data.dtype).max) + dataMin, dataMax = dataMin / maxInt, dataMax / maxInt + + if self.normalization == 'log': + dataMin = math.log10(dataMin) + dataMax = math.log10(dataMax) + normID = 1 + elif self.normalization == 'sqrt': + dataMin = math.sqrt(dataMin) + dataMax = math.sqrt(dataMax) + normID = 2 + elif self.normalization == 'gamma': + # Keep dataMin, dataMax as is + param = self.gamma + normID = 3 + elif self.normalization == 'arcsinh': + dataMin = numpy.arcsinh(dataMin) + dataMax = numpy.arcsinh(dataMax) + normID = 4 + else: # Linear and fallback + normID = 0 + + gl.glUniform1i(prog.uniforms['cmap_texture'], + self._cmap_texture.texUnit) + gl.glUniform1i(prog.uniforms['cmap_normalization'], normID) + gl.glUniform1f(prog.uniforms['cmap_parameter'], param) + gl.glUniform1f(prog.uniforms['cmap_min'], dataMin) + if dataMax > dataMin: + oneOverRange = 1. / (dataMax - dataMin) + else: + oneOverRange = 0. # Fall-back + gl.glUniform1f(prog.uniforms['cmap_oneOverRange'], oneOverRange) + + gl.glUniform4f(prog.uniforms['nancolor'], *self._nancolor) + + self._cmap_texture.bind() + + def _renderLinear(self, context): + """Perform rendering when both axes have linear scales + + :param RenderContext context: Rendering information + """ + self.prepare() + + prog = self._linearProgram + prog.use() + + gl.glUniform1i(prog.uniforms['data'], self._DATA_TEX_UNIT) + + mat = numpy.dot(numpy.dot(context.matrix, + mat4Translate(*self.origin)), + mat4Scale(*self.scale)) + gl.glUniformMatrix4fv(prog.uniforms['matrix'], 1, gl.GL_TRUE, + mat.astype(numpy.float32)) + + gl.glUniform1f(prog.uniforms['alpha'], self.alpha) + + self._setCMap(prog) + + self._texture.render(prog.attributes['position'], + prog.attributes['texCoords'], + self._DATA_TEX_UNIT) + + def _renderLog10(self, context): + """Perform rendering when one axis has log scale + + :param RenderContext context: Rendering information + """ + xMin, yMin = self.xMin, self.yMin + if ((context.isXLog and xMin < FLOAT32_MINPOS) or + (context.isYLog and yMin < FLOAT32_MINPOS)): + # Do not render images that are partly or totally <= 0 + return + + self.prepare() + + prog = self._logProgram + prog.use() + + ox, oy = self.origin + + gl.glUniform1i(prog.uniforms['data'], self._DATA_TEX_UNIT) + + gl.glUniformMatrix4fv(prog.uniforms['matrix'], 1, gl.GL_TRUE, + context.matrix.astype(numpy.float32)) + mat = numpy.dot(mat4Translate(ox, oy), mat4Scale(*self.scale)) + gl.glUniformMatrix4fv(prog.uniforms['matOffset'], 1, gl.GL_TRUE, + mat.astype(numpy.float32)) + + gl.glUniform2i(prog.uniforms['isLog'], context.isXLog, context.isYLog) + + ex = ox + self.scale[0] * self.data.shape[1] + ey = oy + self.scale[1] * self.data.shape[0] + + xOneOverRange = 1. / (ex - ox) + yOneOverRange = 1. / (ey - oy) + gl.glUniform2f(prog.uniforms['bounds_originOverRange'], + ox * xOneOverRange, oy * yOneOverRange) + gl.glUniform2f(prog.uniforms['bounds_oneOverRange'], + xOneOverRange, yOneOverRange) + + gl.glUniform1f(prog.uniforms['alpha'], self.alpha) + + self._setCMap(prog) + + try: + tiles = self._texture.tiles + except AttributeError: + raise RuntimeError("No texture, discard has already been called") + if len(tiles) > 1: + raise NotImplementedError( + "Image over multiple textures not supported with log scale") + + texture, vertices, info = tiles[0] + + texture.bind(self._DATA_TEX_UNIT) + + posAttrib = prog.attributes['position'] + stride = vertices.shape[-1] * vertices.itemsize + gl.glEnableVertexAttribArray(posAttrib) + gl.glVertexAttribPointer(posAttrib, + 2, + gl.GL_FLOAT, + gl.GL_FALSE, + stride, vertices) + + gl.glDrawArrays(gl.GL_TRIANGLE_STRIP, 0, len(vertices)) + + def render(self, context): + """Perform rendering + + :param RenderContext context: Rendering information + """ + if any((context.isXLog, context.isYLog)): + self._renderLog10(context) + else: + self._renderLinear(context) + + # Unbind colormap texture + gl.glActiveTexture(gl.GL_TEXTURE0 + self._cmap_texture.texUnit) + gl.glBindTexture(self._cmap_texture.target, 0) + + +# image ####################################################################### + +class GLPlotRGBAImage(_GLPlotData2D): + + _SHADERS = { + 'linear': { + 'vertex': """ + #version 120 + + attribute vec2 position; + attribute vec2 texCoords; + uniform mat4 matrix; + + varying vec2 coords; + + void main(void) { + gl_Position = matrix * vec4(position, 0.0, 1.0); + coords = texCoords; + } + """, + 'fragment': """ + #version 120 + + uniform sampler2D tex; + uniform float alpha; + + varying vec2 coords; + + void main(void) { + gl_FragColor = texture2D(tex, coords); + gl_FragColor.a *= alpha; + } + """}, + + 'log': { + 'vertex': """ + #version 120 + + attribute vec2 position; + uniform mat4 matrix; + uniform mat4 matOffset; + uniform bvec2 isLog; + + varying vec2 coords; + + const float oneOverLog10 = 0.43429448190325176; + + void main(void) { + vec4 dataPos = matOffset * vec4(position, 0.0, 1.0); + if (isLog.x) { + dataPos.x = oneOverLog10 * log(dataPos.x); + } + if (isLog.y) { + dataPos.y = oneOverLog10 * log(dataPos.y); + } + coords = dataPos.xy; + gl_Position = matrix * dataPos; + } + """, + 'fragment': """ + #version 120 + + uniform sampler2D tex; + uniform bvec2 isLog; + uniform vec2 bounds_oneOverRange; + uniform vec2 bounds_originOverRange; + uniform float alpha; + + varying vec2 coords; + + vec2 textureCoords(void) { + vec2 pos = coords; + if (isLog.x) { + pos.x = pow(10., coords.x); + } + if (isLog.y) { + pos.y = pow(10., coords.y); + } + return pos * bounds_oneOverRange - bounds_originOverRange; + // TODO texture coords in range different from [0, 1] + } + + void main(void) { + gl_FragColor = texture2D(tex, textureCoords()); + gl_FragColor.a *= alpha; + } + """} + } + + _DATA_TEX_UNIT = 0 + + _SUPPORTED_DTYPES = (numpy.dtype(numpy.float32), + numpy.dtype(numpy.uint8), + numpy.dtype(numpy.uint16)) + + _linearProgram = Program(_SHADERS['linear']['vertex'], + _SHADERS['linear']['fragment'], + attrib0='position') + + _logProgram = Program(_SHADERS['log']['vertex'], + _SHADERS['log']['fragment'], + attrib0='position') + + def __init__(self, data, origin, scale, alpha): + """Create a 2D RGB(A) image from data + + :param data: The 2D image data array to display + :type data: numpy.ndarray with 3 dimensions + (dtype=numpy.uint8 or numpy.float32) + :param origin: (x, y) coordinates of the origin of the data array + :type origin: 2-tuple of floats. + :param scale: (sx, sy) scale factors of the data array. + This is the size of a data pixel in plot data space. + :type scale: 2-tuple of floats. + :param float alpha: Opacity from 0 (transparent) to 1 (opaque) + """ + assert data.dtype in self._SUPPORTED_DTYPES + super(GLPlotRGBAImage, self).__init__(data, origin, scale) + self._texture = None + self._textureIsDirty = False + self._alpha = numpy.clip(alpha, 0., 1.) + + @property + def alpha(self): + return self._alpha + + def discard(self): + if self.isInitialized(): + self._texture.discard() + self._texture = None + self._textureIsDirty = False + + def isInitialized(self): + return self._texture is not None + + def updateData(self, data): + assert data.dtype in self._SUPPORTED_DTYPES + oldData = self.data + self.data = data + + if self._texture is not None: + if self.data.shape != oldData.shape: + self.discard() + else: + self._textureIsDirty = True + + def prepare(self): + if self._texture is None: + formatName = 'GL_RGBA' if self.data.shape[2] == 4 else 'GL_RGB' + format_ = getattr(gl, formatName) + + if self.data.dtype == numpy.uint16: + formatName += '16' # Use sized internal format for uint16 + internalFormat = getattr(gl, formatName) + + self._texture = Image(internalFormat, + self.data, + format_=format_, + texUnit=self._DATA_TEX_UNIT) + elif self._textureIsDirty: + self._textureIsDirty = False + + # We should check that internal format is the same + format_ = gl.GL_RGBA if self.data.shape[2] == 4 else gl.GL_RGB + self._texture.updateAll(format_=format_, data=self.data) + + def _renderLinear(self, context): + """Perform rendering with both axes having linear scales + + :param RenderContext context: Rendering information + """ + self.prepare() + + prog = self._linearProgram + prog.use() + + gl.glUniform1i(prog.uniforms['tex'], self._DATA_TEX_UNIT) + + mat = numpy.dot(numpy.dot(context.matrix, mat4Translate(*self.origin)), + mat4Scale(*self.scale)) + gl.glUniformMatrix4fv(prog.uniforms['matrix'], 1, gl.GL_TRUE, + mat.astype(numpy.float32)) + + gl.glUniform1f(prog.uniforms['alpha'], self.alpha) + + self._texture.render(prog.attributes['position'], + prog.attributes['texCoords'], + self._DATA_TEX_UNIT) + + def _renderLog(self, context): + """Perform rendering with axes having log scale + + :param RenderContext context: Rendering information + """ + self.prepare() + + prog = self._logProgram + prog.use() + + ox, oy = self.origin + + gl.glUniform1i(prog.uniforms['tex'], self._DATA_TEX_UNIT) + + gl.glUniformMatrix4fv(prog.uniforms['matrix'], 1, gl.GL_TRUE, + context.matrix.astype(numpy.float32)) + mat = numpy.dot(mat4Translate(ox, oy), mat4Scale(*self.scale)) + gl.glUniformMatrix4fv(prog.uniforms['matOffset'], 1, gl.GL_TRUE, + mat.astype(numpy.float32)) + + gl.glUniform2i(prog.uniforms['isLog'], context.isXLog, context.isYLog) + + gl.glUniform1f(prog.uniforms['alpha'], self.alpha) + + ex = ox + self.scale[0] * self.data.shape[1] + ey = oy + self.scale[1] * self.data.shape[0] + + xOneOverRange = 1. / (ex - ox) + yOneOverRange = 1. / (ey - oy) + gl.glUniform2f(prog.uniforms['bounds_originOverRange'], + ox * xOneOverRange, oy * yOneOverRange) + gl.glUniform2f(prog.uniforms['bounds_oneOverRange'], + xOneOverRange, yOneOverRange) + + try: + tiles = self._texture.tiles + except AttributeError: + raise RuntimeError("No texture, discard has already been called") + if len(tiles) > 1: + raise NotImplementedError( + "Image over multiple textures not supported with log scale") + + texture, vertices, info = tiles[0] + + texture.bind(self._DATA_TEX_UNIT) + + posAttrib = prog.attributes['position'] + stride = vertices.shape[-1] * vertices.itemsize + gl.glEnableVertexAttribArray(posAttrib) + gl.glVertexAttribPointer(posAttrib, + 2, + gl.GL_FLOAT, + gl.GL_FALSE, + stride, vertices) + + gl.glDrawArrays(gl.GL_TRIANGLE_STRIP, 0, len(vertices)) + + def render(self, context): + """Perform rendering + + :param RenderContext context: Rendering information + """ + if any((context.isXLog, context.isYLog)): + self._renderLog(context) + else: + self._renderLinear(context) diff --git a/src/silx/gui/plot/backends/glutils/GLPlotItem.py b/src/silx/gui/plot/backends/glutils/GLPlotItem.py new file mode 100644 index 0000000..ae13091 --- /dev/null +++ b/src/silx/gui/plot/backends/glutils/GLPlotItem.py @@ -0,0 +1,99 @@ +# coding: utf-8 +# /*########################################################################## +# +# Copyright (c) 2020-2021 European Synchrotron Radiation Facility +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. +# +# ############################################################################*/ +""" +This module provides a base class for PlotWidget OpenGL backend primitives +""" + +__authors__ = ["T. Vincent"] +__license__ = "MIT" +__date__ = "02/07/2020" + + +class RenderContext: + """Context with which to perform OpenGL rendering. + + :param numpy.ndarray matrix: 4x4 transform matrix to use for rendering + :param bool isXLog: Whether X axis is log scale or not + :param bool isYLog: Whether Y axis is log scale or not + :param float dpi: Number of device pixels per inch + """ + + def __init__(self, matrix=None, isXLog=False, isYLog=False, dpi=96.): + self.matrix = matrix + """Current transformation matrix""" + + self.__isXLog = isXLog + self.__isYLog = isYLog + self.__dpi = dpi + + @property + def isXLog(self): + """True if X axis is using log scale""" + return self.__isXLog + + @property + def isYLog(self): + """True if Y axis is using log scale""" + return self.__isYLog + + @property + def dpi(self): + """Number of device pixels per inch""" + return self.__dpi + + +class GLPlotItem: + """Base class for primitives used in the PlotWidget OpenGL backend""" + + def __init__(self): + self.yaxis = 'left' + "YAxis this item is attached to (either 'left' or 'right')" + + def pick(self, x, y): + """Perform picking at given position. + + :param float x: X coordinate in plot data frame of reference + :param float y: Y coordinate in plot data frame of reference + :returns: + Result of picking as a list of indices or None if nothing picked + :rtype: Union[List[int],None] + """ + return None + + def render(self, context): + """Performs OpenGL rendering of the item. + + :param RenderContext context: Rendering context information + """ + pass + + def discard(self): + """Discards OpenGL resources this item has created.""" + pass + + def isInitialized(self) -> bool: + """Returns True if resources where initialized and requires `discard`. + """ + return True diff --git a/src/silx/gui/plot/backends/glutils/GLPlotTriangles.py b/src/silx/gui/plot/backends/glutils/GLPlotTriangles.py new file mode 100644 index 0000000..fbe9e02 --- /dev/null +++ b/src/silx/gui/plot/backends/glutils/GLPlotTriangles.py @@ -0,0 +1,197 @@ +# coding: utf-8 +# /*########################################################################## +# +# Copyright (c) 2019-2021 European Synchrotron Radiation Facility +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. +# +# ############################################################################*/ +""" +This module provides a class to render a set of 2D triangles +""" + +__authors__ = ["T. Vincent"] +__license__ = "MIT" +__date__ = "03/04/2017" + + +import ctypes + +import numpy + +from .....math.combo import min_max +from .... import _glutils as glutils +from ...._glutils import gl +from .GLPlotItem import GLPlotItem + + +class GLPlotTriangles(GLPlotItem): + """Handle rendering of a set of colored triangles""" + + _PROGRAM = glutils.Program( + vertexShader=""" + #version 120 + + uniform mat4 matrix; + attribute float xPos; + attribute float yPos; + attribute vec4 color; + + varying vec4 vColor; + + void main(void) { + gl_Position = matrix * vec4(xPos, yPos, 0.0, 1.0); + vColor = color; + } + """, + fragmentShader=""" + #version 120 + + uniform float alpha; + varying vec4 vColor; + + void main(void) { + gl_FragColor = vColor; + gl_FragColor.a *= alpha; + } + """, + attrib0='xPos') + + def __init__(self, x, y, color, triangles, alpha=1.): + """ + + :param numpy.ndarray x: X coordinates of triangle corners + :param numpy.ndarray y: Y coordinates of triangle corners + :param numpy.ndarray color: color for each point + :param numpy.ndarray triangles: (N, 3) array of indices of triangles + :param float alpha: Opacity in [0, 1] + """ + super().__init__() + # Check and convert input data + x = numpy.ravel(numpy.array(x, dtype=numpy.float32)) + y = numpy.ravel(numpy.array(y, dtype=numpy.float32)) + color = numpy.array(color, copy=False) + # Cast to uint32 + triangles = numpy.array(triangles, copy=False, dtype=numpy.uint32) + + assert x.size == y.size + assert x.size == len(color) + assert color.ndim == 2 and color.shape[1] in (3, 4) + if numpy.issubdtype(color.dtype, numpy.floating): + color = numpy.array(color, dtype=numpy.float32, copy=False) + elif numpy.issubdtype(color.dtype, numpy.integer): + color = numpy.array(color, dtype=numpy.uint8, copy=False) + else: + raise ValueError('Unsupported color type') + assert triangles.ndim == 2 and triangles.shape[1] == 3 + + self.__x_y_color = x, y, color + self.xMin, self.xMax = min_max(x, finite=True) + self.yMin, self.yMax = min_max(y, finite=True) + self.__triangles = triangles + self.__alpha = numpy.clip(float(alpha), 0., 1.) + self.__vbos = None + self.__indicesVbo = None + self.__picking_triangles = None + + def pick(self, x, y): + """Perform picking + + :param float x: X coordinates in plot data frame + :param float y: Y coordinates in plot data frame + :return: List of picked data point indices + :rtype: Union[List[int],None] + """ + if (x < self.xMin or x > self.xMax or + y < self.yMin or y > self.yMax): + return None + + xPts, yPts = self.__x_y_color[:2] + if self.__picking_triangles is None: + self.__picking_triangles = numpy.zeros( + self.__triangles.shape + (3,), dtype=numpy.float32) + self.__picking_triangles[:, :, 0] = xPts[self.__triangles] + self.__picking_triangles[:, :, 1] = yPts[self.__triangles] + + segment = numpy.array(((x, y, -1), (x, y, 1)), dtype=numpy.float32) + # Picked triangle indices + indices = glutils.segmentTrianglesIntersection( + segment, self.__picking_triangles)[0] + # Point indices + indices = numpy.unique(numpy.ravel(self.__triangles[indices])) + + # Sorted from furthest to closest point + dists = (xPts[indices] - x) ** 2 + (yPts[indices] - y) ** 2 + indices = indices[numpy.flip(numpy.argsort(dists), axis=0)] + + return tuple(indices) if len(indices) > 0 else None + + def discard(self): + """Release resources on the GPU""" + if self.isInitialized(): + self.__vbos[0].vbo.discard() + self.__vbos = None + self.__indicesVbo.discard() + self.__indicesVbo = None + + def isInitialized(self): + return self.__vbos is not None + + def prepare(self): + """Allocate resources on the GPU""" + if self.__vbos is None: + self.__vbos = glutils.vertexBuffer(self.__x_y_color) + # Normalization is need for color + self.__vbos[-1].normalization = True + + if self.__indicesVbo is None: + self.__indicesVbo = glutils.VertexBuffer( + numpy.ravel(self.__triangles), + usage=gl.GL_STATIC_DRAW, + target=gl.GL_ELEMENT_ARRAY_BUFFER) + + def render(self, context): + """Perform rendering + + :param RenderContext context: Rendering information + """ + self.prepare() + + if self.__vbos is None or self.__indicesVbo is None: + return # Nothing to display + + self._PROGRAM.use() + + gl.glUniformMatrix4fv(self._PROGRAM.uniforms['matrix'], + 1, + gl.GL_TRUE, + context.matrix.astype(numpy.float32)) + + gl.glUniform1f(self._PROGRAM.uniforms['alpha'], self.__alpha) + + for index, name in enumerate(('xPos', 'yPos', 'color')): + attr = self._PROGRAM.attributes[name] + gl.glEnableVertexAttribArray(attr) + self.__vbos[index].setVertexAttrib(attr) + + with self.__indicesVbo: + gl.glDrawElements(gl.GL_TRIANGLES, + self.__triangles.size, + glutils.numpyToGLType(self.__triangles.dtype), + ctypes.c_void_p(0)) diff --git a/src/silx/gui/plot/backends/glutils/GLSupport.py b/src/silx/gui/plot/backends/glutils/GLSupport.py new file mode 100644 index 0000000..da6dffa --- /dev/null +++ b/src/silx/gui/plot/backends/glutils/GLSupport.py @@ -0,0 +1,158 @@ +# coding: utf-8 +# /*########################################################################## +# +# Copyright (c) 2014-2018 European Synchrotron Radiation Facility +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. +# +# ############################################################################*/ +""" +This module provides convenient classes and functions for OpenGL rendering. +""" + +__authors__ = ["T. Vincent"] +__license__ = "MIT" +__date__ = "03/04/2017" + + +import numpy + +from ...._glutils import gl + + +def buildFillMaskIndices(nIndices, dtype=None): + """Returns triangle strip indices for rendering a filled polygon mask + + :param int nIndices: Number of points + :param Union[numpy.dtype,None] dtype: + If specified the dtype of the returned indices array + :return: 1D array of indices constructing a triangle strip + :rtype: numpy.ndarray + """ + if dtype is None: + if nIndices <= numpy.iinfo(numpy.uint16).max + 1: + dtype = numpy.uint16 + else: + dtype = numpy.uint32 + + lastIndex = nIndices - 1 + splitIndex = lastIndex // 2 + 1 + indices = numpy.empty(nIndices, dtype=dtype) + indices[::2] = numpy.arange(0, splitIndex, step=1, dtype=dtype) + indices[1::2] = numpy.arange(lastIndex, splitIndex - 1, step=-1, + dtype=dtype) + return indices + + +class FilledShape2D(object): + _NO_HATCH = 0 + _HATCH_STEP = 20 + + def __init__(self, points, style='solid', color=(0., 0., 0., 1.)): + self.vertices = numpy.array(points, dtype=numpy.float32, copy=False) + self._indices = buildFillMaskIndices(len(self.vertices)) + + tVertex = numpy.transpose(self.vertices) + xMin, xMax = min(tVertex[0]), max(tVertex[0]) + yMin, yMax = min(tVertex[1]), max(tVertex[1]) + self.bboxVertices = numpy.array(((xMin, yMin), (xMin, yMax), + (xMax, yMin), (xMax, yMax)), + dtype=numpy.float32) + self._xMin, self._xMax = xMin, xMax + self._yMin, self._yMax = yMin, yMax + + self.style = style + self.color = color + + def render(self, posAttrib, colorUnif, hatchStepUnif): + assert self.style in ('hatch', 'solid') + gl.glUniform4f(colorUnif, *self.color) + step = self._HATCH_STEP if self.style == 'hatch' else self._NO_HATCH + gl.glUniform1i(hatchStepUnif, step) + + # Prepare fill mask + gl.glEnableVertexAttribArray(posAttrib) + gl.glVertexAttribPointer(posAttrib, + 2, + gl.GL_FLOAT, + gl.GL_FALSE, + 0, self.vertices) + + gl.glEnable(gl.GL_STENCIL_TEST) + gl.glStencilMask(1) + gl.glStencilFunc(gl.GL_ALWAYS, 1, 1) + gl.glStencilOp(gl.GL_INVERT, gl.GL_INVERT, gl.GL_INVERT) + gl.glColorMask(gl.GL_FALSE, gl.GL_FALSE, gl.GL_FALSE, gl.GL_FALSE) + gl.glDepthMask(gl.GL_FALSE) + + gl.glDrawElements(gl.GL_TRIANGLE_STRIP, len(self._indices), + gl.GL_UNSIGNED_SHORT, self._indices) + + gl.glStencilFunc(gl.GL_EQUAL, 1, 1) + # Reset stencil while drawing + gl.glStencilOp(gl.GL_ZERO, gl.GL_ZERO, gl.GL_ZERO) + gl.glColorMask(gl.GL_TRUE, gl.GL_TRUE, gl.GL_TRUE, gl.GL_TRUE) + gl.glDepthMask(gl.GL_TRUE) + + gl.glVertexAttribPointer(posAttrib, + 2, + gl.GL_FLOAT, + gl.GL_FALSE, + 0, self.bboxVertices) + gl.glDrawArrays(gl.GL_TRIANGLE_STRIP, 0, len(self.bboxVertices)) + + gl.glDisable(gl.GL_STENCIL_TEST) + + +# matrix ###################################################################### + +def mat4Ortho(left, right, bottom, top, near, far): + """Orthographic projection matrix (row-major)""" + return numpy.array(( + (2./(right - left), 0., 0., -(right+left)/float(right-left)), + (0., 2./(top - bottom), 0., -(top+bottom)/float(top-bottom)), + (0., 0., -2./(far-near), -(far+near)/float(far-near)), + (0., 0., 0., 1.)), dtype=numpy.float64) + + +def mat4Translate(x=0., y=0., z=0.): + """Translation matrix (row-major)""" + return numpy.array(( + (1., 0., 0., x), + (0., 1., 0., y), + (0., 0., 1., z), + (0., 0., 0., 1.)), dtype=numpy.float64) + + +def mat4Scale(sx=1., sy=1., sz=1.): + """Scale matrix (row-major)""" + return numpy.array(( + (sx, 0., 0., 0.), + (0., sy, 0., 0.), + (0., 0., sz, 0.), + (0., 0., 0., 1.)), dtype=numpy.float64) + + +def mat4Identity(): + """Identity matrix""" + return numpy.array(( + (1., 0., 0., 0.), + (0., 1., 0., 0.), + (0., 0., 1., 0.), + (0., 0., 0., 1.)), dtype=numpy.float64) diff --git a/src/silx/gui/plot/backends/glutils/GLText.py b/src/silx/gui/plot/backends/glutils/GLText.py new file mode 100644 index 0000000..d6ae6fa --- /dev/null +++ b/src/silx/gui/plot/backends/glutils/GLText.py @@ -0,0 +1,287 @@ +# coding: utf-8 +# /*########################################################################## +# +# Copyright (c) 2014-2020 European Synchrotron Radiation Facility +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. +# +# ############################################################################*/ +""" +This module provides minimalistic text support for OpenGL. +It provides Latin-1 (ISO8859-1) characters for one monospace font at one size. +""" + +__authors__ = ["T. Vincent"] +__license__ = "MIT" +__date__ = "03/04/2017" + + +from collections import OrderedDict +import weakref + +import numpy + +from ...._glutils import font, gl, Context, Program, Texture +from .GLSupport import mat4Translate + + +# TODO: Font should be configurable by the main program: using mpl.rcParams? + + +class _Cache(object): + """LRU (Least Recent Used) cache. + + :param int maxsize: Maximum number of (key, value) pairs in the cache + :param callable callback: + Called when a (key, value) pair is removed from the cache. + It must take 2 arguments: key and value. + """ + + def __init__(self, maxsize=128, callback=None): + self._maxsize = int(maxsize) + self._callback = callback + self._cache = OrderedDict() + + def __contains__(self, item): + return item in self._cache + + def __getitem__(self, key): + if key in self._cache: + # Remove/add key from ordered dict to store last access info + value = self._cache.pop(key) + self._cache[key] = value + return value + else: + raise KeyError + + def __setitem__(self, key, value): + """Add a key, value pair to the cache. + + :param key: The key to set + :param value: The corresponding value + """ + if key not in self._cache and len(self._cache) >= self._maxsize: + removedKey, removedValue = self._cache.popitem(last=False) + if self._callback is not None: + self._callback(removedKey, removedValue) + self._cache[key] = value + + +# Text2D ###################################################################### + +LEFT, CENTER, RIGHT = 'left', 'center', 'right' +TOP, BASELINE, BOTTOM = 'top', 'baseline', 'bottom' +ROTATE_90, ROTATE_180, ROTATE_270 = 90, 180, 270 + + +class Text2D(object): + + _SHADERS = { + 'vertex': """ + #version 120 + + attribute vec2 position; + attribute vec2 texCoords; + uniform mat4 matrix; + + varying vec2 vCoords; + + void main(void) { + gl_Position = matrix * vec4(position, 0.0, 1.0); + vCoords = texCoords; + } + """, + 'fragment': """ + #version 120 + + uniform sampler2D texText; + uniform vec4 color; + uniform vec4 bgColor; + + varying vec2 vCoords; + + void main(void) { + gl_FragColor = mix(bgColor, color, texture2D(texText, vCoords).r); + } + """ + } + + _TEX_COORDS = numpy.array(((0., 0.), (1., 0.), (0., 1.), (1., 1.)), + dtype=numpy.float32).ravel() + + _program = Program(_SHADERS['vertex'], + _SHADERS['fragment'], + attrib0='position') + + # Discard texture objects when removed from the cache + _textures = weakref.WeakKeyDictionary() + """Cache already created textures""" + + _sizes = _Cache() + """Cache already computed sizes""" + + def __init__(self, text, x=0, y=0, + color=(0., 0., 0., 1.), + bgColor=None, + align=LEFT, valign=BASELINE, + rotate=0, + devicePixelRatio= 1.): + self.devicePixelRatio = devicePixelRatio + self._vertices = None + self._text = text + self.x = x + self.y = y + self.color = color + self.bgColor = bgColor + + if align not in (LEFT, CENTER, RIGHT): + raise ValueError( + "Horizontal alignment not supported: {0}".format(align)) + self._align = align + + if valign not in (TOP, CENTER, BASELINE, BOTTOM): + raise ValueError( + "Vertical alignment not supported: {0}".format(valign)) + self._valign = valign + + self._rotate = numpy.radians(rotate) + + def _getTexture(self, text, devicePixelRatio): + # Retrieve/initialize texture cache for current context + textureKey = text, devicePixelRatio + + context = Context.getCurrent() + if context not in self._textures: + self._textures[context] = _Cache( + callback=lambda key, value: value[0].discard()) + textures = self._textures[context] + + if textureKey not in textures: + image, offset = font.rasterText( + text, + font.getDefaultFontFamily(), + devicePixelRatio=self.devicePixelRatio) + if textureKey not in self._sizes: + self._sizes[textureKey] = image.shape[1], image.shape[0] + + texture = Texture( + gl.GL_RED, + data=image, + minFilter=gl.GL_NEAREST, + magFilter=gl.GL_NEAREST, + wrap=(gl.GL_CLAMP_TO_EDGE, + gl.GL_CLAMP_TO_EDGE)) + texture.prepare() + textures[textureKey] = texture, offset + + return textures[textureKey] + + @property + def text(self): + return self._text + + @property + def size(self): + textureKey = self.text, self.devicePixelRatio + if textureKey not in self._sizes: + image, offset = font.rasterText( + self.text, + font.getDefaultFontFamily(), + devicePixelRatio=self.devicePixelRatio) + self._sizes[textureKey] = image.shape[1], image.shape[0] + return self._sizes[textureKey] + + def getVertices(self, offset, shape): + height, width = shape + + if self._align == LEFT: + xOrig = 0 + elif self._align == RIGHT: + xOrig = - width + else: # CENTER + xOrig = - width // 2 + + if self._valign == BASELINE: + yOrig = - offset + elif self._valign == TOP: + yOrig = 0 + elif self._valign == BOTTOM: + yOrig = - height + else: # CENTER + yOrig = - height // 2 + + vertices = numpy.array(( + (xOrig, yOrig), + (xOrig + width, yOrig), + (xOrig, yOrig + height), + (xOrig + width, yOrig + height)), dtype=numpy.float32) + + cos, sin = numpy.cos(self._rotate), numpy.sin(self._rotate) + vertices = numpy.ascontiguousarray(numpy.transpose(numpy.array(( + cos * vertices[:, 0] - sin * vertices[:, 1], + sin * vertices[:, 0] + cos * vertices[:, 1]), + dtype=numpy.float32))) + + return vertices + + def render(self, matrix): + if not self.text: + return + + prog = self._program + prog.use() + + texUnit = 0 + texture, offset = self._getTexture(self.text, self.devicePixelRatio) + + gl.glUniform1i(prog.uniforms['texText'], texUnit) + + mat = numpy.dot(matrix, mat4Translate(int(self.x), int(self.y))) + gl.glUniformMatrix4fv(prog.uniforms['matrix'], 1, gl.GL_TRUE, + mat.astype(numpy.float32)) + + gl.glUniform4f(prog.uniforms['color'], *self.color) + if self.bgColor is not None: + bgColor = self.bgColor + else: + bgColor = self.color[0], self.color[1], self.color[2], 0. + gl.glUniform4f(prog.uniforms['bgColor'], *bgColor) + + vertices = self.getVertices(offset, texture.shape) + + posAttrib = prog.attributes['position'] + gl.glEnableVertexAttribArray(posAttrib) + gl.glVertexAttribPointer(posAttrib, + 2, + gl.GL_FLOAT, + gl.GL_FALSE, + 0, + vertices) + + texAttrib = prog.attributes['texCoords'] + gl.glEnableVertexAttribArray(texAttrib) + gl.glVertexAttribPointer(texAttrib, + 2, + gl.GL_FLOAT, + gl.GL_FALSE, + 0, + self._TEX_COORDS) + + with texture: + gl.glDrawArrays(gl.GL_TRIANGLE_STRIP, 0, 4) diff --git a/src/silx/gui/plot/backends/glutils/GLTexture.py b/src/silx/gui/plot/backends/glutils/GLTexture.py new file mode 100644 index 0000000..37fbdd0 --- /dev/null +++ b/src/silx/gui/plot/backends/glutils/GLTexture.py @@ -0,0 +1,241 @@ +# coding: utf-8 +# /*########################################################################## +# +# Copyright (c) 2014-2020 European Synchrotron Radiation Facility +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. +# +# ############################################################################*/ +"""This module provides classes wrapping OpenGL texture.""" + +__authors__ = ["T. Vincent"] +__license__ = "MIT" +__date__ = "03/04/2017" + + +from ctypes import c_void_p +import logging + +import numpy + +from ...._glutils import gl, Texture, numpyToGLType + + +_logger = logging.getLogger(__name__) + + +def _checkTexture2D(internalFormat, shape, + format_=None, type_=gl.GL_FLOAT, border=0): + """Check if texture size with provided parameters is supported + + :rtype: bool + """ + height, width = shape + gl.glTexImage2D(gl.GL_PROXY_TEXTURE_2D, 0, internalFormat, + width, height, border, + format_ or internalFormat, + type_, c_void_p(0)) + width = gl.glGetTexLevelParameteriv( + gl.GL_PROXY_TEXTURE_2D, 0, gl.GL_TEXTURE_WIDTH) + return bool(width) + + +MIN_TEXTURE_SIZE = 64 + + +def _getMaxSquareTexture2DSize(internalFormat=gl.GL_RGBA, + format_=None, + type_=gl.GL_FLOAT, + border=0): + """Returns a supported size for a corresponding square texture + + :returns: GL_MAX_TEXTURE_SIZE or a smaller supported size (not optimal) + :rtype: int + """ + # Is this useful? + maxTexSize = gl.glGetIntegerv(gl.GL_MAX_TEXTURE_SIZE) + while maxTexSize > MIN_TEXTURE_SIZE and \ + not _checkTexture2D(internalFormat, (maxTexSize, maxTexSize), + format_, type_, border): + maxTexSize //= 2 + return max(MIN_TEXTURE_SIZE, maxTexSize) + + +class Image(object): + """Image of any size eventually using multiple textures or larger texture + """ + + _WRAP = (gl.GL_CLAMP_TO_EDGE, gl.GL_CLAMP_TO_EDGE) + _MIN_FILTER = gl.GL_NEAREST + _MAG_FILTER = gl.GL_NEAREST + + def __init__(self, internalFormat, data, format_=None, texUnit=0): + self.internalFormat = internalFormat + self.height, self.width = data.shape[0:2] + type_ = numpyToGLType(data.dtype) + + if _checkTexture2D(internalFormat, data.shape[0:2], format_, type_): + texture = Texture(internalFormat, + data, + format_, + texUnit=texUnit, + minFilter=self._MIN_FILTER, + magFilter=self._MAG_FILTER, + wrap=self._WRAP) + texture.prepare() + vertices = numpy.array(( + (0., 0., 0., 0.), + (self.width, 0., 1., 0.), + (0., self.height, 0., 1.), + (self.width, self.height, 1., 1.)), dtype=numpy.float32) + self.tiles = ((texture, vertices, + {'xOrigData': 0, 'yOrigData': 0, + 'wData': self.width, 'hData': self.height}),) + + else: + # Handle dimension too large: make tiles + maxTexSize = _getMaxSquareTexture2DSize(internalFormat, + format_, type_) + + nCols = (self.width+maxTexSize-1) // maxTexSize + colWidths = [self.width // nCols] * nCols + colWidths[-1] += self.width % nCols + + nRows = (self.height+maxTexSize-1) // maxTexSize + rowHeights = [self.height//nRows] * nRows + rowHeights[-1] += self.height % nRows + + tiles = [] + yOrig = 0 + for hData in rowHeights: + xOrig = 0 + for wData in colWidths: + if (hData < MIN_TEXTURE_SIZE or wData < MIN_TEXTURE_SIZE) \ + and not _checkTexture2D(internalFormat, + (hData, wData), + format_, + type_): + # Ensure texture size is at least MIN_TEXTURE_SIZE + tH = max(hData, MIN_TEXTURE_SIZE) + tW = max(wData, MIN_TEXTURE_SIZE) + + uMax, vMax = float(wData)/tW, float(hData)/tH + + # TODO issue with type_ and alignment + texture = Texture(internalFormat, + data=None, + format_=format_, + shape=(tH, tW), + texUnit=texUnit, + minFilter=self._MIN_FILTER, + magFilter=self._MAG_FILTER, + wrap=self._WRAP) + # TODO handle unpack + texture.update(format_, + data[yOrig:yOrig+hData, + xOrig:xOrig+wData]) + # texture.update(format_, type_, data, + # width=wData, height=hData, + # unpackRowLength=width, + # unpackSkipPixels=xOrig, + # unpackSkipRows=yOrig) + else: + uMax, vMax = 1, 1 + # TODO issue with type_ and unpacking tiles + # TODO idea to handle unpack: use array strides + # As it is now, it will make a copy + texture = Texture(internalFormat, + data[yOrig:yOrig+hData, + xOrig:xOrig+wData], + format_, + texUnit=texUnit, + minFilter=self._MIN_FILTER, + magFilter=self._MAG_FILTER, + wrap=self._WRAP) + # TODO + # unpackRowLength=width, + # unpackSkipPixels=xOrig, + # unpackSkipRows=yOrig) + vertices = numpy.array(( + (xOrig, yOrig, 0., 0.), + (xOrig + wData, yOrig, uMax, 0.), + (xOrig, yOrig + hData, 0., vMax), + (xOrig + wData, yOrig + hData, uMax, vMax)), + dtype=numpy.float32) + texture.prepare() + tiles.append((texture, vertices, + {'xOrigData': xOrig, 'yOrigData': yOrig, + 'wData': wData, 'hData': hData})) + xOrig += wData + yOrig += hData + self.tiles = tuple(tiles) + + def discard(self): + for texture, vertices, _ in self.tiles: + texture.discard() + del self.tiles + + def updateAll(self, format_, data, texUnit=0): + if not hasattr(self, 'tiles'): + raise RuntimeError("No texture, discard has already been called") + + assert data.shape[:2] == (self.height, self.width) + if len(self.tiles) == 1: + self.tiles[0][0].update(format_, data, texUnit=texUnit) + else: + for texture, _, info in self.tiles: + yOrig, xOrig = info['yOrigData'], info['xOrigData'] + height, width = info['hData'], info['wData'] + texture.update(format_, + data[yOrig:yOrig+height, xOrig:xOrig+width], + texUnit=texUnit) + texture.prepare() + # TODO check + # width=info['wData'], height=info['hData'], + # texUnit=texUnit, unpackAlign=unpackAlign, + # unpackRowLength=self.width, + # unpackSkipPixels=info['xOrigData'], + # unpackSkipRows=info['yOrigData']) + + def render(self, posAttrib, texAttrib, texUnit=0): + try: + tiles = self.tiles + except AttributeError: + raise RuntimeError("No texture, discard has already been called") + + for texture, vertices, _ in tiles: + texture.bind(texUnit) + + stride = vertices.shape[-1] * vertices.itemsize + gl.glEnableVertexAttribArray(posAttrib) + gl.glVertexAttribPointer(posAttrib, + 2, + gl.GL_FLOAT, + gl.GL_FALSE, + stride, vertices) + + texCoordsPtr = c_void_p(vertices.ctypes.data + + 2 * vertices.itemsize) + gl.glEnableVertexAttribArray(texAttrib) + gl.glVertexAttribPointer(texAttrib, + 2, + gl.GL_FLOAT, + gl.GL_FALSE, + stride, texCoordsPtr) + gl.glDrawArrays(gl.GL_TRIANGLE_STRIP, 0, len(vertices)) diff --git a/src/silx/gui/plot/backends/glutils/PlotImageFile.py b/src/silx/gui/plot/backends/glutils/PlotImageFile.py new file mode 100644 index 0000000..5fb6853 --- /dev/null +++ b/src/silx/gui/plot/backends/glutils/PlotImageFile.py @@ -0,0 +1,153 @@ +# coding: utf-8 +# /*########################################################################## +# +# Copyright (c) 2014-2020 European Synchrotron Radiation Facility +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. +# +# ############################################################################*/ +"""Function to save an image to a file.""" + +__authors__ = ["T. Vincent"] +__license__ = "MIT" +__date__ = "03/04/2017" + + +import base64 +import struct +import sys +import zlib + + +# Image writer ################################################################ + +def convertRGBDataToPNG(data): + """Convert a RGB bitmap to PNG. + + It only supports RGB bitmap with one byte per channel stored as a 3D array. + See `Definitive Guide <http://www.libpng.org/pub/png/book/>`_ and + `Specification <http://www.libpng.org/pub/png/spec/1.2/>`_ for details. + + :param data: A 3D array (h, w, rgb) storing an RGB image + :type data: numpy.ndarray of unsigned bytes + :returns: The PNG encoded data + :rtype: bytes + """ + height, width = data.shape[0], data.shape[1] + depth = 8 # 8 bit per channel + colorType = 2 # 'truecolor' = RGB + interlace = 0 # No + + IHDRdata = struct.pack(">ccccIIBBBBB", b'I', b'H', b'D', b'R', + width, height, depth, colorType, + 0, 0, interlace) + + # Add filter 'None' before each scanline + preparedData = b'\x00' + b'\x00'.join(line.tobytes() for line in data) + compressedData = zlib.compress(preparedData, 8) + + IDATdata = struct.pack("cccc", b'I', b'D', b'A', b'T') + IDATdata += compressedData + + return b''.join([ + b'\x89PNG\r\n\x1a\n', # PNG signature + # IHDR chunk: Image Header + struct.pack(">I", 13), # length + IHDRdata, + struct.pack(">I", zlib.crc32(IHDRdata) & 0xffffffff), # CRC + # IDAT chunk: Payload + struct.pack(">I", len(compressedData)), + IDATdata, + struct.pack(">I", zlib.crc32(IDATdata) & 0xffffffff), # CRC + b'\x00\x00\x00\x00IEND\xaeB`\x82' # IEND chunk: footer + ]) + + +def saveImageToFile(data, fileNameOrObj, fileFormat): + """Save a RGB image to a file. + + :param data: A 3D array (h, w, 3) storing an RGB image. + :type data: numpy.ndarray with of unsigned bytes. + :param fileNameOrObj: Filename or object to use to write the image. + :type fileNameOrObj: A str or a 'file-like' object with a 'write' method. + :param str fileFormat: The type of the file in: 'png', 'ppm', 'svg', 'tiff'. + """ + assert len(data.shape) == 3 + assert data.shape[2] == 3 + assert fileFormat in ('png', 'ppm', 'svg', 'tiff') + + if not hasattr(fileNameOrObj, 'write'): + if sys.version_info < (3, ): + fileObj = open(fileNameOrObj, "wb") + else: + if fileFormat in ('png', 'ppm', 'tiff'): + # Open in binary mode + fileObj = open(fileNameOrObj, 'wb') + else: + fileObj = open(fileNameOrObj, 'w', newline='') + else: # Use as a file-like object + fileObj = fileNameOrObj + + if fileFormat == 'svg': + height, width = data.shape[:2] + base64Data = base64.b64encode(convertRGBDataToPNG(data)) + + fileObj.write( + '<?xml version="1.0" encoding="UTF-8" standalone="no"?>\n') + fileObj.write('<!DOCTYPE svg PUBLIC "-//W3C//DTD SVG 1.1//EN"\n') + fileObj.write( + ' "http://www.w3.org/Graphics/SVG/1.1/DTD/svg11.dtd">\n') + fileObj.write('<svg xmlns:xlink="http://www.w3.org/1999/xlink"\n') + fileObj.write(' xmlns="http://www.w3.org/2000/svg"\n') + fileObj.write(' version="1.1"\n') + fileObj.write(' width="%d"\n' % width) + fileObj.write(' height="%d">\n' % height) + fileObj.write(' <image xlink:href="data:image/png;base64,') + fileObj.write(base64Data.decode('ascii')) + fileObj.write('"\n') + fileObj.write(' x="0"\n') + fileObj.write(' y="0"\n') + fileObj.write(' width="%d"\n' % width) + fileObj.write(' height="%d"\n' % height) + fileObj.write(' id="image" />\n') + fileObj.write('</svg>') + + elif fileFormat == 'ppm': + height, width = data.shape[:2] + + fileObj.write(b'P6\n') + fileObj.write(b'%d %d\n' % (width, height)) + fileObj.write(b'255\n') + fileObj.write(data.tobytes()) + + elif fileFormat == 'png': + fileObj.write(convertRGBDataToPNG(data)) + + elif fileFormat == 'tiff': + if fileObj == fileNameOrObj: + raise NotImplementedError( + 'Save TIFF to a file-like object not implemented') + + from silx.third_party.TiffIO import TiffIO + + tif = TiffIO(fileNameOrObj, mode='wb+') + tif.writeImage(data, info={'Title': 'OpenGL Plot Snapshot'}) + + if fileObj != fileNameOrObj: + fileObj.close() diff --git a/src/silx/gui/plot/backends/glutils/__init__.py b/src/silx/gui/plot/backends/glutils/__init__.py new file mode 100644 index 0000000..f87d7c1 --- /dev/null +++ b/src/silx/gui/plot/backends/glutils/__init__.py @@ -0,0 +1,46 @@ +# coding: utf-8 +# /*########################################################################## +# +# Copyright (c) 2014-2020 European Synchrotron Radiation Facility +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. +# +# ############################################################################*/ +"""This module provides convenient classes for the OpenGL rendering backend. +""" + +__authors__ = ["T. Vincent"] +__license__ = "MIT" +__date__ = "03/04/2017" + + +import logging + + +_logger = logging.getLogger(__name__) + + +from .GLPlotCurve import * # noqa +from .GLPlotFrame import * # noqa +from .GLPlotImage import * # noqa +from .GLPlotItem import GLPlotItem, RenderContext # noqa +from .GLPlotTriangles import GLPlotTriangles # noqa +from .GLSupport import * # noqa +from .GLText import * # noqa +from .GLTexture import * # noqa diff --git a/src/silx/gui/plot/items/__init__.py b/src/silx/gui/plot/items/__init__.py new file mode 100644 index 0000000..0fe29c2 --- /dev/null +++ b/src/silx/gui/plot/items/__init__.py @@ -0,0 +1,53 @@ +# coding: utf-8 +# /*########################################################################## +# +# Copyright (c) 2017-2021 European Synchrotron Radiation Facility +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. +# +# ###########################################################################*/ +"""This package provides classes that describes :class:`.PlotWidget` content. + +Instances of those classes are returned by :class:`.PlotWidget` methods that give +access to its content such as :meth:`.PlotWidget.getCurve`, :meth:`.PlotWidget.getImage`. +""" + +__authors__ = ["T. Vincent"] +__license__ = "MIT" +__date__ = "22/06/2017" + +from .core import (Item, DataItem, # noqa + LabelsMixIn, DraggableMixIn, ColormapMixIn, # noqa + SymbolMixIn, ColorMixIn, YAxisMixIn, FillMixIn, # noqa + AlphaMixIn, LineMixIn, ScatterVisualizationMixIn, # noqa + ComplexMixIn, ItemChangedType, PointsBase) # noqa +from .complex import ImageComplexData # noqa +from .curve import Curve, CurveStyle # noqa +from .histogram import Histogram # noqa +from .image import ImageBase, ImageData, ImageDataBase, ImageRgba, ImageStack, MaskImageData # noqa +from .image_aggregated import ImageDataAggregated # noqa +from .shape import Shape, BoundingRect, XAxisExtent, YAxisExtent # noqa +from .scatter import Scatter # noqa +from .marker import MarkerBase, Marker, XMarker, YMarker # noqa +from .axis import Axis, XAxis, YAxis, YRightAxis + +DATA_ITEMS = (ImageComplexData, Curve, Histogram, ImageBase, Scatter, + BoundingRect, XAxisExtent, YAxisExtent) +"""Classes of items representing data and to consider to compute data bounds. +""" diff --git a/src/silx/gui/plot/items/_arc_roi.py b/src/silx/gui/plot/items/_arc_roi.py new file mode 100644 index 0000000..23416ec --- /dev/null +++ b/src/silx/gui/plot/items/_arc_roi.py @@ -0,0 +1,878 @@ +# coding: utf-8 +# /*########################################################################## +# +# Copyright (c) 2018-2021 European Synchrotron Radiation Facility +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. +# +# ###########################################################################*/ +"""This module provides Arc ROI item for the :class:`~silx.gui.plot.PlotWidget`. +""" + +__authors__ = ["V. Valls"] +__license__ = "MIT" +__date__ = "28/06/2018" + +import logging +import numpy + +from ... import utils +from .. import items +from ...colors import rgba +from ....utils.proxy import docstring +from ._roi_base import HandleBasedROI +from ._roi_base import InteractionModeMixIn +from ._roi_base import RoiInteractionMode + + +logger = logging.getLogger(__name__) + + +class _ArcGeometry: + """ + Non-mutable object to store the geometry of the arc ROI. + + The aim is is to switch between consistent state without dealing with + intermediate values. + """ + def __init__(self, center, startPoint, endPoint, radius, + weight, startAngle, endAngle, closed=False): + """Constructor for a consistent arc geometry. + + There is also specific class method to create different kind of arc + geometry. + """ + self.center = center + self.startPoint = startPoint + self.endPoint = endPoint + self.radius = radius + self.weight = weight + self.startAngle = startAngle + self.endAngle = endAngle + self._closed = closed + + @classmethod + def createEmpty(cls): + """Create an arc geometry from an empty shape + """ + zero = numpy.array([0, 0]) + return cls(zero, zero.copy(), zero.copy(), 0, 0, 0, 0) + + @classmethod + def createRect(cls, startPoint, endPoint, weight): + """Create an arc geometry from a definition of a rectangle + """ + return cls(None, startPoint, endPoint, None, weight, None, None, False) + + @classmethod + def createCircle(cls, center, startPoint, endPoint, radius, + weight, startAngle, endAngle): + """Create an arc geometry from a definition of a circle + """ + return cls(center, startPoint, endPoint, radius, + weight, startAngle, endAngle, True) + + def withWeight(self, weight): + """Return a new geometry based on this object, with a specific weight + """ + return _ArcGeometry(self.center, self.startPoint, self.endPoint, + self.radius, weight, + self.startAngle, self.endAngle, self._closed) + + def withRadius(self, radius): + """Return a new geometry based on this object, with a specific radius. + + The weight and the center is conserved. + """ + startPoint = self.center + (self.startPoint - self.center) / self.radius * radius + endPoint = self.center + (self.endPoint - self.center) / self.radius * radius + return _ArcGeometry(self.center, startPoint, endPoint, + radius, self.weight, + self.startAngle, self.endAngle, self._closed) + + def withStartAngle(self, startAngle): + """Return a new geometry based on this object, with a specific start angle + """ + vector = numpy.array([numpy.cos(startAngle), numpy.sin(startAngle)]) + startPoint = self.center + vector * self.radius + + # Never add more than 180 to maintain coherency + deltaAngle = startAngle - self.startAngle + if deltaAngle > numpy.pi: + deltaAngle -= numpy.pi * 2 + elif deltaAngle < -numpy.pi: + deltaAngle += numpy.pi * 2 + + startAngle = self.startAngle + deltaAngle + return _ArcGeometry( + self.center, + startPoint, + self.endPoint, + self.radius, + self.weight, + startAngle, + self.endAngle, + self._closed, + ) + + def withEndAngle(self, endAngle): + """Return a new geometry based on this object, with a specific end angle + """ + vector = numpy.array([numpy.cos(endAngle), numpy.sin(endAngle)]) + endPoint = self.center + vector * self.radius + + # Never add more than 180 to maintain coherency + deltaAngle = endAngle - self.endAngle + if deltaAngle > numpy.pi: + deltaAngle -= numpy.pi * 2 + elif deltaAngle < -numpy.pi: + deltaAngle += numpy.pi * 2 + + endAngle = self.endAngle + deltaAngle + return _ArcGeometry( + self.center, + self.startPoint, + endPoint, + self.radius, + self.weight, + self.startAngle, + endAngle, + self._closed, + ) + + def translated(self, dx, dy): + """Return the translated geometry by dx, dy""" + delta = numpy.array([dx, dy]) + center = None if self.center is None else self.center + delta + startPoint = None if self.startPoint is None else self.startPoint + delta + endPoint = None if self.endPoint is None else self.endPoint + delta + return _ArcGeometry(center, startPoint, endPoint, + self.radius, self.weight, + self.startAngle, self.endAngle, self._closed) + + def getKind(self): + """Returns the kind of shape defined""" + if self.center is None: + return "rect" + elif numpy.isnan(self.startAngle): + return "point" + elif self.isClosed(): + if self.weight <= 0 or self.weight * 0.5 >= self.radius: + return "circle" + else: + return "donut" + else: + if self.weight * 0.5 < self.radius: + return "arc" + else: + return "camembert" + + def isClosed(self): + """Returns True if the geometry is a circle like""" + if self._closed is not None: + return self._closed + delta = numpy.abs(self.endAngle - self.startAngle) + self._closed = numpy.isclose(delta, numpy.pi * 2) + return self._closed + + def __str__(self): + return str((self.center, + self.startPoint, + self.endPoint, + self.radius, + self.weight, + self.startAngle, + self.endAngle, + self._closed)) + + +class ArcROI(HandleBasedROI, items.LineMixIn, InteractionModeMixIn): + """A ROI identifying an arc of a circle with a width. + + This ROI provides + - 3 handle to control the curvature + - 1 handle to control the weight + - 1 anchor to translate the shape. + """ + + ICON = 'add-shape-arc' + NAME = 'arc ROI' + SHORT_NAME = "arc" + """Metadata for this kind of ROI""" + + _plotShape = "line" + """Plot shape which is used for the first interaction""" + + ThreePointMode = RoiInteractionMode("3 points", "Provides 3 points to define the main radius circle") + PolarMode = RoiInteractionMode("Polar", "Provides anchors to edit the ROI in polar coords") + # FIXME: MoveMode was designed cause there is too much anchors + # FIXME: It would be good replace it by a dnd on the shape + MoveMode = RoiInteractionMode("Translation", "Provides anchors to only move the ROI") + + def __init__(self, parent=None): + HandleBasedROI.__init__(self, parent=parent) + items.LineMixIn.__init__(self) + InteractionModeMixIn.__init__(self) + + self._geometry = _ArcGeometry.createEmpty() + self._handleLabel = self.addLabelHandle() + + self._handleStart = self.addHandle() + self._handleMid = self.addHandle() + self._handleEnd = self.addHandle() + self._handleWeight = self.addHandle() + self._handleWeight._setConstraint(self._arcCurvatureMarkerConstraint) + self._handleMove = self.addTranslateHandle() + + shape = items.Shape("polygon") + shape.setPoints([[0, 0], [0, 0]]) + shape.setColor(rgba(self.getColor())) + shape.setFill(False) + shape.setOverlay(True) + shape.setLineStyle(self.getLineStyle()) + shape.setLineWidth(self.getLineWidth()) + self.__shape = shape + self.addItem(shape) + + self._initInteractionMode(self.ThreePointMode) + self._interactiveModeUpdated(self.ThreePointMode) + + def availableInteractionModes(self): + """Returns the list of available interaction modes + + :rtype: List[RoiInteractionMode] + """ + return [self.ThreePointMode, self.PolarMode, self.MoveMode] + + def _interactiveModeUpdated(self, modeId): + """Set the interaction mode. + + :param RoiInteractionMode modeId: + """ + if modeId is self.ThreePointMode: + self._handleStart.setSymbol("s") + self._handleMid.setSymbol("s") + self._handleEnd.setSymbol("s") + self._handleWeight.setSymbol("d") + self._handleMove.setSymbol("+") + elif modeId is self.PolarMode: + self._handleStart.setSymbol("o") + self._handleMid.setSymbol("o") + self._handleEnd.setSymbol("o") + self._handleWeight.setSymbol("d") + self._handleMove.setSymbol("+") + elif modeId is self.MoveMode: + self._handleStart.setSymbol("") + self._handleMid.setSymbol("+") + self._handleEnd.setSymbol("") + self._handleWeight.setSymbol("") + self._handleMove.setSymbol("+") + else: + assert False + if self._geometry.isClosed(): + if modeId != self.MoveMode: + self._handleStart.setSymbol("x") + self._handleEnd.setSymbol("x") + self._updateHandles() + + def _updated(self, event=None, checkVisibility=True): + if event == items.ItemChangedType.VISIBLE: + self._updateItemProperty(event, self, self.__shape) + super(ArcROI, self)._updated(event, checkVisibility) + + def _updatedStyle(self, event, style): + super(ArcROI, self)._updatedStyle(event, style) + self.__shape.setColor(style.getColor()) + self.__shape.setLineStyle(style.getLineStyle()) + self.__shape.setLineWidth(style.getLineWidth()) + + def setFirstShapePoints(self, points): + """"Initialize the ROI using the points from the first interaction. + + This interaction is constrained by the plot API and only supports few + shapes. + """ + # The first shape is a line + point0 = points[0] + point1 = points[1] + + # Compute a non collinear point for the curvature + center = (point1 + point0) * 0.5 + normal = point1 - center + normal = numpy.array((normal[1], -normal[0])) + defaultCurvature = numpy.pi / 5.0 + weightCoef = 0.20 + mid = center - normal * defaultCurvature + distance = numpy.linalg.norm(point0 - point1) + weight = distance * weightCoef + + geometry = self._createGeometryFromControlPoints(point0, mid, point1, weight) + self._geometry = geometry + self._updateHandles() + + def _updateText(self, text): + self._handleLabel.setText(text) + + def _updateMidHandle(self): + """Keep the same geometry, but update the location of the control + points. + + So calling this function do not trigger sigRegionChanged. + """ + geometry = self._geometry + + if geometry.isClosed(): + start = numpy.array(self._handleStart.getPosition()) + midPos = geometry.center + geometry.center - start + else: + if geometry.center is None: + midPos = geometry.startPoint * 0.5 + geometry.endPoint * 0.5 + else: + midAngle = geometry.startAngle * 0.5 + geometry.endAngle * 0.5 + vector = numpy.array([numpy.cos(midAngle), numpy.sin(midAngle)]) + midPos = geometry.center + geometry.radius * vector + + with utils.blockSignals(self._handleMid): + self._handleMid.setPosition(*midPos) + + def _updateWeightHandle(self): + geometry = self._geometry + if geometry.center is None: + # rectangle + center = (geometry.startPoint + geometry.endPoint) * 0.5 + normal = geometry.endPoint - geometry.startPoint + normal = numpy.array((normal[1], -normal[0])) + distance = numpy.linalg.norm(normal) + if distance != 0: + normal = normal / distance + weightPos = center + normal * geometry.weight * 0.5 + else: + if geometry.isClosed(): + midAngle = geometry.startAngle + numpy.pi * 0.5 + elif geometry.center is not None: + midAngle = (geometry.startAngle + geometry.endAngle) * 0.5 + vector = numpy.array([numpy.cos(midAngle), numpy.sin(midAngle)]) + weightPos = geometry.center + (geometry.radius + geometry.weight * 0.5) * vector + + with utils.blockSignals(self._handleWeight): + self._handleWeight.setPosition(*weightPos) + + def _getWeightFromHandle(self, weightPos): + geometry = self._geometry + if geometry.center is None: + # rectangle + center = (geometry.startPoint + geometry.endPoint) * 0.5 + return numpy.linalg.norm(center - weightPos) * 2 + else: + distance = numpy.linalg.norm(geometry.center - weightPos) + return abs(distance - geometry.radius) * 2 + + def _updateHandles(self): + geometry = self._geometry + with utils.blockSignals(self._handleStart): + self._handleStart.setPosition(*geometry.startPoint) + with utils.blockSignals(self._handleEnd): + self._handleEnd.setPosition(*geometry.endPoint) + + self._updateMidHandle() + self._updateWeightHandle() + self._updateShape() + + def _updateCurvature(self, start, mid, end, updateCurveHandles, checkClosed=False, updateStart=False): + """Update the curvature using 3 control points in the curve + + :param bool updateCurveHandles: If False curve handles are already at + the right location + """ + if checkClosed: + closed = self._isCloseInPixel(start, end) + else: + closed = self._geometry.isClosed() + if closed: + if updateStart: + start = end + else: + end = start + + if updateCurveHandles: + with utils.blockSignals(self._handleStart): + self._handleStart.setPosition(*start) + with utils.blockSignals(self._handleMid): + self._handleMid.setPosition(*mid) + with utils.blockSignals(self._handleEnd): + self._handleEnd.setPosition(*end) + + weight = self._geometry.weight + geometry = self._createGeometryFromControlPoints(start, mid, end, weight, closed=closed) + self._geometry = geometry + + self._updateWeightHandle() + self._updateShape() + + def _updateCloseInAngle(self, geometry, updateStart): + azim = numpy.abs(geometry.endAngle - geometry.startAngle) + if numpy.pi < azim < 3 * numpy.pi: + closed = self._isCloseInPixel(geometry.startPoint, geometry.endPoint) + geometry._closed = closed + if closed: + sign = 1 if geometry.startAngle < geometry.endAngle else -1 + if updateStart: + geometry.startPoint = geometry.endPoint + geometry.startAngle = geometry.endAngle - sign * 2*numpy.pi + else: + geometry.endPoint = geometry.startPoint + geometry.endAngle = geometry.startAngle + sign * 2*numpy.pi + + def handleDragUpdated(self, handle, origin, previous, current): + modeId = self.getInteractionMode() + if handle is self._handleStart: + if modeId is self.ThreePointMode: + mid = numpy.array(self._handleMid.getPosition()) + end = numpy.array(self._handleEnd.getPosition()) + self._updateCurvature( + current, mid, end, checkClosed=True, updateStart=True, + updateCurveHandles=False + ) + elif modeId is self.PolarMode: + v = current - self._geometry.center + startAngle = numpy.angle(complex(v[0], v[1])) + geometry = self._geometry.withStartAngle(startAngle) + self._updateCloseInAngle(geometry, updateStart=True) + self._geometry = geometry + self._updateHandles() + elif handle is self._handleMid: + if modeId is self.ThreePointMode: + if self._geometry.isClosed(): + radius = numpy.linalg.norm(self._geometry.center - current) + self._geometry = self._geometry.withRadius(radius) + self._updateHandles() + else: + start = numpy.array(self._handleStart.getPosition()) + end = numpy.array(self._handleEnd.getPosition()) + self._updateCurvature(start, current, end, updateCurveHandles=False) + elif modeId is self.PolarMode: + radius = numpy.linalg.norm(self._geometry.center - current) + self._geometry = self._geometry.withRadius(radius) + self._updateHandles() + elif modeId is self.MoveMode: + delta = current - previous + self.translate(*delta) + elif handle is self._handleEnd: + if modeId is self.ThreePointMode: + start = numpy.array(self._handleStart.getPosition()) + mid = numpy.array(self._handleMid.getPosition()) + self._updateCurvature( + start, mid, current, checkClosed=True, updateStart=False, + updateCurveHandles=False + ) + elif modeId is self.PolarMode: + v = current - self._geometry.center + endAngle = numpy.angle(complex(v[0], v[1])) + geometry = self._geometry.withEndAngle(endAngle) + self._updateCloseInAngle(geometry, updateStart=False) + self._geometry = geometry + self._updateHandles() + elif handle is self._handleWeight: + weight = self._getWeightFromHandle(current) + self._geometry = self._geometry.withWeight(weight) + self._updateShape() + elif handle is self._handleMove: + delta = current - previous + self.translate(*delta) + + def _isCloseInPixel(self, point1, point2): + manager = self.parent() + if manager is None: + return False + plot = manager.parent() + if plot is None: + return False + point1 = plot.dataToPixel(*point1) + if point1 is None: + return False + point2 = plot.dataToPixel(*point2) + if point2 is None: + return False + return abs(point1[0] - point2[0]) + abs(point1[1] - point2[1]) < 15 + + def _normalizeGeometry(self): + """Keep the same phisical geometry, but with normalized parameters. + """ + geometry = self._geometry + if geometry.weight * 0.5 >= geometry.radius: + radius = (geometry.weight * 0.5 + geometry.radius) * 0.5 + geometry = geometry.withRadius(radius) + geometry = geometry.withWeight(radius * 2) + self._geometry = geometry + return True + return False + + def handleDragFinished(self, handle, origin, current): + modeId = self.getInteractionMode() + if handle in [self._handleStart, self._handleMid, self._handleEnd]: + if modeId is self.ThreePointMode: + self._normalizeGeometry() + self._updateHandles() + + if self._geometry.isClosed(): + if modeId is self.MoveMode: + self._handleStart.setSymbol("") + self._handleEnd.setSymbol("") + else: + self._handleStart.setSymbol("x") + self._handleEnd.setSymbol("x") + else: + if modeId is self.ThreePointMode: + self._handleStart.setSymbol("s") + self._handleEnd.setSymbol("s") + elif modeId is self.PolarMode: + self._handleStart.setSymbol("o") + self._handleEnd.setSymbol("o") + if modeId is self.MoveMode: + self._handleStart.setSymbol("") + self._handleEnd.setSymbol("") + + def _createGeometryFromControlPoints(self, start, mid, end, weight, closed=None): + """Returns the geometry of the object""" + if closed or (closed is None and numpy.allclose(start, end)): + # Special arc: It's a closed circle + center = (start + mid) * 0.5 + radius = numpy.linalg.norm(start - center) + v = start - center + startAngle = numpy.angle(complex(v[0], v[1])) + endAngle = startAngle + numpy.pi * 2.0 + return _ArcGeometry.createCircle( + center, start, end, radius, weight, startAngle, endAngle + ) + + elif numpy.linalg.norm(numpy.cross(mid - start, end - start)) < 1e-5: + # Degenerated arc, it's a rectangle + return _ArcGeometry.createRect(start, end, weight) + else: + center, radius = self._circleEquation(start, mid, end) + v = start - center + startAngle = numpy.angle(complex(v[0], v[1])) + v = mid - center + midAngle = numpy.angle(complex(v[0], v[1])) + v = end - center + endAngle = numpy.angle(complex(v[0], v[1])) + + # Is it clockwise or anticlockwise + relativeMid = (endAngle - midAngle + 2 * numpy.pi) % (2 * numpy.pi) + relativeEnd = (endAngle - startAngle + 2 * numpy.pi) % (2 * numpy.pi) + if relativeMid < relativeEnd: + if endAngle < startAngle: + endAngle += 2 * numpy.pi + else: + if endAngle > startAngle: + endAngle -= 2 * numpy.pi + + return _ArcGeometry(center, start, end, + radius, weight, startAngle, endAngle) + + def _createShapeFromGeometry(self, geometry): + kind = geometry.getKind() + if kind == "rect": + # It is not an arc + # but we can display it as an intermediate shape + normal = geometry.endPoint - geometry.startPoint + normal = numpy.array((normal[1], -normal[0])) + distance = numpy.linalg.norm(normal) + if distance != 0: + normal /= distance + points = numpy.array([ + geometry.startPoint + normal * geometry.weight * 0.5, + geometry.endPoint + normal * geometry.weight * 0.5, + geometry.endPoint - normal * geometry.weight * 0.5, + geometry.startPoint - normal * geometry.weight * 0.5]) + elif kind == "point": + # It is not an arc + # but we can display it as an intermediate shape + # NOTE: At least 2 points are expected + points = numpy.array([geometry.startPoint, geometry.startPoint]) + elif kind == "circle": + outerRadius = geometry.radius + geometry.weight * 0.5 + angles = numpy.linspace(0, 2 * numpy.pi, num=50) + # It's a circle + points = [] + numpy.append(angles, angles[-1]) + for angle in angles: + direction = numpy.array([numpy.cos(angle), numpy.sin(angle)]) + points.append(geometry.center + direction * outerRadius) + points = numpy.array(points) + elif kind == "donut": + innerRadius = geometry.radius - geometry.weight * 0.5 + outerRadius = geometry.radius + geometry.weight * 0.5 + angles = numpy.linspace(0, 2 * numpy.pi, num=50) + # It's a donut + points = [] + # NOTE: NaN value allow to create 2 separated circle shapes + # using a single plot item. It's a kind of cheat + points.append(numpy.array([float("nan"), float("nan")])) + for angle in angles: + direction = numpy.array([numpy.cos(angle), numpy.sin(angle)]) + points.insert(0, geometry.center + direction * innerRadius) + points.append(geometry.center + direction * outerRadius) + points.append(numpy.array([float("nan"), float("nan")])) + points = numpy.array(points) + else: + innerRadius = geometry.radius - geometry.weight * 0.5 + outerRadius = geometry.radius + geometry.weight * 0.5 + + delta = 0.1 if geometry.endAngle >= geometry.startAngle else -0.1 + if geometry.startAngle == geometry.endAngle: + # Degenerated, it's a line (single radius) + angle = geometry.startAngle + direction = numpy.array([numpy.cos(angle), numpy.sin(angle)]) + points = [] + points.append(geometry.center + direction * innerRadius) + points.append(geometry.center + direction * outerRadius) + return numpy.array(points) + + angles = numpy.arange(geometry.startAngle, geometry.endAngle, delta) + if angles[-1] != geometry.endAngle: + angles = numpy.append(angles, geometry.endAngle) + + if kind == "camembert": + # It's a part of camembert + points = [] + points.append(geometry.center) + points.append(geometry.startPoint) + delta = 0.1 if geometry.endAngle >= geometry.startAngle else -0.1 + for angle in angles: + direction = numpy.array([numpy.cos(angle), numpy.sin(angle)]) + points.append(geometry.center + direction * outerRadius) + points.append(geometry.endPoint) + points.append(geometry.center) + elif kind == "arc": + # It's a part of donut + points = [] + points.append(geometry.startPoint) + for angle in angles: + direction = numpy.array([numpy.cos(angle), numpy.sin(angle)]) + points.insert(0, geometry.center + direction * innerRadius) + points.append(geometry.center + direction * outerRadius) + points.insert(0, geometry.endPoint) + points.append(geometry.endPoint) + else: + assert False + + points = numpy.array(points) + + return points + + def _updateShape(self): + geometry = self._geometry + points = self._createShapeFromGeometry(geometry) + self.__shape.setPoints(points) + + index = numpy.nanargmin(points[:, 1]) + pos = points[index] + with utils.blockSignals(self._handleLabel): + self._handleLabel.setPosition(pos[0], pos[1]) + + if geometry.center is None: + movePos = geometry.startPoint * 0.34 + geometry.endPoint * 0.66 + else: + movePos = geometry.center + + with utils.blockSignals(self._handleMove): + self._handleMove.setPosition(*movePos) + + self.sigRegionChanged.emit() + + def getGeometry(self): + """Returns a tuple containing the geometry of this ROI + + It is a symmetric function of :meth:`setGeometry`. + + If `startAngle` is smaller than `endAngle` the rotation is clockwise, + else the rotation is anticlockwise. + + :rtype: Tuple[numpy.ndarray,float,float,float,float] + :raise ValueError: In case the ROI can't be represented as section of + a circle + """ + geometry = self._geometry + if geometry.center is None: + raise ValueError("This ROI can't be represented as a section of circle") + return geometry.center, self.getInnerRadius(), self.getOuterRadius(), geometry.startAngle, geometry.endAngle + + def isClosed(self): + """Returns true if the arc is a closed shape, like a circle or a donut. + + :rtype: bool + """ + return self._geometry.isClosed() + + def getCenter(self): + """Returns the center of the circle used to draw arcs of this ROI. + + This center is usually outside the the shape itself. + + :rtype: numpy.ndarray + """ + return self._geometry.center + + def getStartAngle(self): + """Returns the angle of the start of the section of this ROI (in radian). + + If `startAngle` is smaller than `endAngle` the rotation is clockwise, + else the rotation is anticlockwise. + + :rtype: float + """ + return self._geometry.startAngle + + def getEndAngle(self): + """Returns the angle of the end of the section of this ROI (in radian). + + If `startAngle` is smaller than `endAngle` the rotation is clockwise, + else the rotation is anticlockwise. + + :rtype: float + """ + return self._geometry.endAngle + + def getInnerRadius(self): + """Returns the radius of the smaller arc used to draw this ROI. + + :rtype: float + """ + geometry = self._geometry + radius = geometry.radius - geometry.weight * 0.5 + if radius < 0: + radius = 0 + return radius + + def getOuterRadius(self): + """Returns the radius of the bigger arc used to draw this ROI. + + :rtype: float + """ + geometry = self._geometry + radius = geometry.radius + geometry.weight * 0.5 + return radius + + def setGeometry(self, center, innerRadius, outerRadius, startAngle, endAngle): + """ + Set the geometry of this arc. + + :param numpy.ndarray center: Center of the circle. + :param float innerRadius: Radius of the smaller arc of the section. + :param float outerRadius: Weight of the bigger arc of the section. + It have to be bigger than `innerRadius` + :param float startAngle: Location of the start of the section (in radian) + :param float endAngle: Location of the end of the section (in radian). + If `startAngle` is smaller than `endAngle` the rotation is clockwise, + else the rotation is anticlockwise. + """ + if innerRadius > outerRadius: + logger.error("inner radius larger than outer radius") + innerRadius, outerRadius = outerRadius, innerRadius + center = numpy.array(center) + radius = (innerRadius + outerRadius) * 0.5 + weight = outerRadius - innerRadius + + vector = numpy.array([numpy.cos(startAngle), numpy.sin(startAngle)]) + startPoint = center + vector * radius + vector = numpy.array([numpy.cos(endAngle), numpy.sin(endAngle)]) + endPoint = center + vector * radius + + geometry = _ArcGeometry(center, startPoint, endPoint, + radius, weight, + startAngle, endAngle, closed=None) + self._geometry = geometry + self._updateHandles() + + @docstring(HandleBasedROI) + def contains(self, position): + # first check distance, fastest + center = self.getCenter() + distance = numpy.sqrt((position[1] - center[1]) ** 2 + ((position[0] - center[0])) ** 2) + is_in_distance = self.getInnerRadius() <= distance <= self.getOuterRadius() + if not is_in_distance: + return False + rel_pos = position[1] - center[1], position[0] - center[0] + angle = numpy.arctan2(*rel_pos) + # angle is inside [-pi, pi] + + # Normalize the start angle between [-pi, pi] + # with a positive angle range + start_angle = self.getStartAngle() + end_angle = self.getEndAngle() + azim_range = end_angle - start_angle + if azim_range < 0: + start_angle = end_angle + azim_range = -azim_range + start_angle = numpy.mod(start_angle + numpy.pi, 2 * numpy.pi) - numpy.pi + + if angle < start_angle: + angle += 2 * numpy.pi + return start_angle <= angle <= start_angle + azim_range + + def translate(self, x, y): + self._geometry = self._geometry.translated(x, y) + self._updateHandles() + + def _arcCurvatureMarkerConstraint(self, x, y): + """Curvature marker remains on perpendicular bisector""" + geometry = self._geometry + if geometry.center is None: + center = (geometry.startPoint + geometry.endPoint) * 0.5 + vector = geometry.startPoint - geometry.endPoint + vector = numpy.array((vector[1], -vector[0])) + vdist = numpy.linalg.norm(vector) + if vdist != 0: + normal = numpy.array((vector[1], -vector[0])) / vdist + else: + normal = numpy.array((0, 0)) + else: + if geometry.isClosed(): + midAngle = geometry.startAngle + numpy.pi * 0.5 + else: + midAngle = (geometry.startAngle + geometry.endAngle) * 0.5 + normal = numpy.array([numpy.cos(midAngle), numpy.sin(midAngle)]) + center = geometry.center + dist = numpy.dot(normal, (numpy.array((x, y)) - center)) + dist = numpy.clip(dist, geometry.radius, geometry.radius * 2) + x, y = center + dist * normal + return x, y + + @staticmethod + def _circleEquation(pt1, pt2, pt3): + """Circle equation from 3 (x, y) points + + :return: Position of the center of the circle and the radius + :rtype: Tuple[Tuple[float,float],float] + """ + x, y, z = complex(*pt1), complex(*pt2), complex(*pt3) + w = z - x + w /= y - x + c = (x - y) * (w - abs(w) ** 2) / 2j / w.imag - x + return numpy.array((-c.real, -c.imag)), abs(c + x) + + def __str__(self): + try: + center, innerRadius, outerRadius, startAngle, endAngle = self.getGeometry() + params = center[0], center[1], innerRadius, outerRadius, startAngle, endAngle + params = 'center: %f %f; radius: %f %f; angles: %f %f' % params + except ValueError: + params = "invalid" + return "%s(%s)" % (self.__class__.__name__, params) diff --git a/src/silx/gui/plot/items/_pick.py b/src/silx/gui/plot/items/_pick.py new file mode 100644 index 0000000..8c8e781 --- /dev/null +++ b/src/silx/gui/plot/items/_pick.py @@ -0,0 +1,72 @@ +# coding: utf-8 +# /*########################################################################## +# +# Copyright (c) 2019-2020 European Synchrotron Radiation Facility +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. +# +# ###########################################################################*/ +"""This module provides classes supporting item picking.""" + +__authors__ = ["T. Vincent"] +__license__ = "MIT" +__date__ = "04/06/2019" + +import numpy + + +class PickingResult(object): + """Class to access picking information in a :class:`PlotWidget`""" + + def __init__(self, item, indices=None): + """Init + + :param item: The picked item + :param numpy.ndarray indices: Array-like of indices of picked data. + Either 1D or 2D with dim0: data dimension and dim1: indices. + No copy is made. + """ + self._item = item + + if indices is None or len(indices) == 0: + self._indices = None + else: + # Indices is set to None if indices array is empty + indices = numpy.array(indices, copy=False, dtype=numpy.int64) + self._indices = None if indices.size == 0 else indices + + def getItem(self): + """Returns the item this results corresponds to.""" + return self._item + + def getIndices(self, copy=True): + """Returns indices of picked data. + + If data is 1D, it returns a numpy.ndarray, otherwise + it returns a tuple with as many numpy.ndarray as there are + dimensions in the data. + + :param bool copy: True (default) to get a copy, + False to return internal arrays + :rtype: Union[None,numpy.ndarray,List[numpy.ndarray]] + """ + if self._indices is None: + return None + indices = numpy.array(self._indices, copy=copy) + return indices if indices.ndim == 1 else tuple(indices) diff --git a/src/silx/gui/plot/items/_roi_base.py b/src/silx/gui/plot/items/_roi_base.py new file mode 100644 index 0000000..3eb6cf4 --- /dev/null +++ b/src/silx/gui/plot/items/_roi_base.py @@ -0,0 +1,835 @@ +# coding: utf-8 +# /*########################################################################## +# +# Copyright (c) 2018-2020 European Synchrotron Radiation Facility +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. +# +# ###########################################################################*/ +"""This module provides base components to create ROI item for +the :class:`~silx.gui.plot.PlotWidget`. + +.. inheritance-diagram:: + silx.gui.plot.items.roi + :parts: 1 +""" + +__authors__ = ["T. Vincent"] +__license__ = "MIT" +__date__ = "28/06/2018" + + +import logging +import numpy +import weakref + +from ....utils.weakref import WeakList +from ... import qt +from .. import items +from ..items import core +from ...colors import rgba +import silx.utils.deprecation +from ....utils.proxy import docstring + + +logger = logging.getLogger(__name__) + + +class _RegionOfInterestBase(qt.QObject): + """Base class of 1D and 2D region of interest + + :param QObject parent: See QObject + :param str name: The name of the ROI + """ + + sigAboutToBeRemoved = qt.Signal() + """Signal emitted just before this ROI is removed from its manager.""" + + sigItemChanged = qt.Signal(object) + """Signal emitted when item has changed. + + It provides a flag describing which property of the item has changed. + See :class:`ItemChangedType` for flags description. + """ + + def __init__(self, parent=None): + qt.QObject.__init__(self, parent=parent) + self.__name = '' + + def getName(self): + """Returns the name of the ROI + + :return: name of the region of interest + :rtype: str + """ + return self.__name + + def setName(self, name): + """Set the name of the ROI + + :param str name: name of the region of interest + """ + name = str(name) + if self.__name != name: + self.__name = name + self._updated(items.ItemChangedType.NAME) + + def _updated(self, event=None, checkVisibility=True): + """Implement Item mix-in update method by updating the plot items + + See :class:`~silx.gui.plot.items.Item._updated` + """ + self.sigItemChanged.emit(event) + + def contains(self, position): + """Returns True if the `position` is in this ROI. + + :param tuple[float,float] position: position to check + :return: True if the value / point is consider to be in the region of + interest. + :rtype: bool + """ + return False # Override in subclass to perform actual test + + +class RoiInteractionMode(object): + """Description of an interaction mode. + + An interaction mode provide a specific kind of interaction for a ROI. + A ROI can implement many interaction. + """ + + def __init__(self, label, description=None): + self._label = label + self._description = description + + @property + def label(self): + return self._label + + @property + def description(self): + return self._description + + +class InteractionModeMixIn(object): + """Mix in feature which can be implemented by a ROI object. + + This provides user interaction to switch between different + interaction mode to edit the ROI. + + This ROI modes have to be described using `RoiInteractionMode`, + and taken into account during interation with handles. + """ + + sigInteractionModeChanged = qt.Signal(object) + + def __init__(self): + self.__modeId = None + + def _initInteractionMode(self, modeId): + """Set the mode without updating anything. + + Must be one of the returned :meth:`availableInteractionModes`. + + :param RoiInteractionMode modeId: Mode to use + """ + self.__modeId = modeId + + def availableInteractionModes(self): + """Returns the list of available interaction modes + + Must be implemented when inherited to provide all available modes. + + :rtype: List[RoiInteractionMode] + """ + raise NotImplementedError() + + def setInteractionMode(self, modeId): + """Set the interaction mode. + + :param RoiInteractionMode modeId: Mode to use + """ + self.__modeId = modeId + self._interactiveModeUpdated(modeId) + self.sigInteractionModeChanged.emit(modeId) + + def _interactiveModeUpdated(self, modeId): + """Called directly after an update of the mode. + + The signal `sigInteractionModeChanged` is triggered after this + call. + + Must be implemented when inherited to take care of the change. + """ + raise NotImplementedError() + + def getInteractionMode(self): + """Returns the interaction mode. + + Must be one of the returned :meth:`availableInteractionModes`. + + :rtype: RoiInteractionMode + """ + return self.__modeId + + +class RegionOfInterest(_RegionOfInterestBase, core.HighlightedMixIn): + """Object describing a region of interest in a plot. + + :param QObject parent: + The RegionOfInterestManager that created this object + """ + + _DEFAULT_LINEWIDTH = 1. + """Default line width of the curve""" + + _DEFAULT_LINESTYLE = '-' + """Default line style of the curve""" + + _DEFAULT_HIGHLIGHT_STYLE = items.CurveStyle(linewidth=2) + """Default highlight style of the item""" + + ICON, NAME, SHORT_NAME = None, None, None + """Metadata to describe the ROI in labels, tooltips and widgets + + Should be set by inherited classes to custom the ROI manager widget. + """ + + sigRegionChanged = qt.Signal() + """Signal emitted everytime the shape or position of the ROI changes""" + + sigEditingStarted = qt.Signal() + """Signal emitted when the user start editing the roi""" + + sigEditingFinished = qt.Signal() + """Signal emitted when the region edition is finished. During edition + sigEditionChanged will be emitted several times and + sigRegionEditionFinished only at end""" + + def __init__(self, parent=None): + # Avoid circular dependency + from ..tools import roi as roi_tools + assert parent is None or isinstance(parent, roi_tools.RegionOfInterestManager) + _RegionOfInterestBase.__init__(self, parent) + core.HighlightedMixIn.__init__(self) + self._color = rgba('red') + self._editable = False + self._selectable = False + self._focusProxy = None + self._visible = True + self._child = WeakList() + + def _connectToPlot(self, plot): + """Called after connection to a plot""" + for item in self.getItems(): + # This hack is needed to avoid reentrant call from _disconnectFromPlot + # to the ROI manager. It also speed up the item tests in _itemRemoved + item._roiGroup = True + plot.addItem(item) + + def _disconnectFromPlot(self, plot): + """Called before disconnection from a plot""" + for item in self.getItems(): + # The item could be already be removed by the plot + if item.getPlot() is not None: + del item._roiGroup + plot.removeItem(item) + + def _setItemName(self, item): + """Helper to generate a unique id to a plot item""" + legend = "__ROI-%d__%d" % (id(self), id(item)) + item.setName(legend) + + def setParent(self, parent): + """Set the parent of the RegionOfInterest + + :param Union[None,RegionOfInterestManager] parent: The new parent + """ + # Avoid circular dependency + from ..tools import roi as roi_tools + if (parent is not None and not isinstance(parent, roi_tools.RegionOfInterestManager)): + raise ValueError('Unsupported parent') + + previousParent = self.parent() + if previousParent is not None: + previousPlot = previousParent.parent() + if previousPlot is not None: + self._disconnectFromPlot(previousPlot) + super(RegionOfInterest, self).setParent(parent) + if parent is not None: + plot = parent.parent() + if plot is not None: + self._connectToPlot(plot) + + def addItem(self, item): + """Add an item to the set of this ROI children. + + This item will be added and removed to the plot used by the ROI. + + If the ROI is already part of a plot, the item will also be added to + the plot. + + It the item do not have a name already, a unique one is generated to + avoid item collision in the plot. + + :param silx.gui.plot.items.Item item: A plot item + """ + assert item is not None + self._child.append(item) + if item.getName() == '': + self._setItemName(item) + manager = self.parent() + if manager is not None: + plot = manager.parent() + if plot is not None: + item._roiGroup = True + plot.addItem(item) + + def removeItem(self, item): + """Remove an item from this ROI children. + + If the item is part of a plot it will be removed too. + + :param silx.gui.plot.items.Item item: A plot item + """ + assert item is not None + self._child.remove(item) + plot = item.getPlot() + if plot is not None: + del item._roiGroup + plot.removeItem(item) + + def getItems(self): + """Returns the list of PlotWidget items of this RegionOfInterest. + + :rtype: List[~silx.gui.plot.items.Item] + """ + return tuple(self._child) + + @classmethod + def _getShortName(cls): + """Return an human readable kind of ROI + + :rtype: str + """ + if hasattr(cls, "SHORT_NAME"): + name = cls.SHORT_NAME + if name is None: + name = cls.__name__ + return name + + def getColor(self): + """Returns the color of this ROI + + :rtype: QColor + """ + return qt.QColor.fromRgbF(*self._color) + + def setColor(self, color): + """Set the color used for this ROI. + + :param color: The color to use for ROI shape as + either a color name, a QColor, a list of uint8 or float in [0, 1]. + """ + color = rgba(color) + if color != self._color: + self._color = color + self._updated(items.ItemChangedType.COLOR) + + @silx.utils.deprecation.deprecated(reason='API modification', + replacement='getName()', + since_version=0.12) + def getLabel(self): + """Returns the label displayed for this ROI. + + :rtype: str + """ + return self.getName() + + @silx.utils.deprecation.deprecated(reason='API modification', + replacement='setName(name)', + since_version=0.12) + def setLabel(self, label): + """Set the label displayed with this ROI. + + :param str label: The text label to display + """ + self.setName(name=label) + + def isEditable(self): + """Returns whether the ROI is editable by the user or not. + + :rtype: bool + """ + return self._editable + + def setEditable(self, editable): + """Set whether the ROI can be changed interactively. + + :param bool editable: True to allow edition by the user, + False to disable. + """ + editable = bool(editable) + if self._editable != editable: + self._editable = editable + self._updated(items.ItemChangedType.EDITABLE) + + def isSelectable(self): + """Returns whether the ROI is selectable by the user or not. + + :rtype: bool + """ + return self._selectable + + def setSelectable(self, selectable): + """Set whether the ROI can be selected interactively. + + :param bool selectable: True to allow selection by the user, + False to disable. + """ + selectable = bool(selectable) + if self._selectable != selectable: + self._selectable = selectable + self._updated(items.ItemChangedType.SELECTABLE) + + def getFocusProxy(self): + """Returns the ROI which have to be selected when this ROI is selected, + else None if no proxy specified. + + :rtype: RegionOfInterest + """ + proxy = self._focusProxy + if proxy is None: + return None + proxy = proxy() + if proxy is None: + self._focusProxy = None + return proxy + + def setFocusProxy(self, roi): + """Set the real ROI which will be selected when this ROI is selected, + else None to remove the proxy already specified. + + :param RegionOfInterest roi: A ROI + """ + if roi is not None: + self._focusProxy = weakref.ref(roi) + else: + self._focusProxy = None + + def isVisible(self): + """Returns whether the ROI is visible in the plot. + + .. note:: + This does not take into account whether or not the plot + widget itself is visible (unlike :meth:`QWidget.isVisible` which + checks the visibility of all its parent widgets up to the window) + + :rtype: bool + """ + return self._visible + + def setVisible(self, visible): + """Set whether the plot items associated with this ROI are + visible in the plot. + + :param bool visible: True to show the ROI in the plot, False to + hide it. + """ + visible = bool(visible) + if self._visible != visible: + self._visible = visible + self._updated(items.ItemChangedType.VISIBLE) + + @classmethod + def showFirstInteractionShape(cls): + """Returns True if the shape created by the first interaction and + managed by the plot have to be visible. + + :rtype: bool + """ + return False + + @classmethod + def getFirstInteractionShape(cls): + """Returns the shape kind which will be used by the very first + interaction with the plot. + + This interactions are hardcoded inside the plot + + :rtype: str + """ + return cls._plotShape + + def setFirstShapePoints(self, points): + """"Initialize the ROI using the points from the first interaction. + + This interaction is constrained by the plot API and only supports few + shapes. + """ + raise NotImplementedError() + + def creationStarted(self): + """"Called when the ROI creation interaction was started. + """ + pass + + def creationFinalized(self): + """"Called when the ROI creation interaction was finalized. + """ + pass + + def _updateItemProperty(self, event, source, destination): + """Update the item property of a destination from an item source. + + :param items.ItemChangedType event: Property type to update + :param silx.gui.plot.items.Item source: The reference for the data + :param event Union[Item,List[Item]] destination: The item(s) to update + """ + if not isinstance(destination, (list, tuple)): + destination = [destination] + if event == items.ItemChangedType.NAME: + value = source.getName() + for d in destination: + d.setName(value) + elif event == items.ItemChangedType.EDITABLE: + value = source.isEditable() + for d in destination: + d.setEditable(value) + elif event == items.ItemChangedType.SELECTABLE: + value = source.isSelectable() + for d in destination: + d._setSelectable(value) + elif event == items.ItemChangedType.COLOR: + value = rgba(source.getColor()) + for d in destination: + d.setColor(value) + elif event == items.ItemChangedType.LINE_STYLE: + value = self.getLineStyle() + for d in destination: + d.setLineStyle(value) + elif event == items.ItemChangedType.LINE_WIDTH: + value = self.getLineWidth() + for d in destination: + d.setLineWidth(value) + elif event == items.ItemChangedType.SYMBOL: + value = self.getSymbol() + for d in destination: + d.setSymbol(value) + elif event == items.ItemChangedType.SYMBOL_SIZE: + value = self.getSymbolSize() + for d in destination: + d.setSymbolSize(value) + elif event == items.ItemChangedType.VISIBLE: + value = self.isVisible() + for d in destination: + d.setVisible(value) + else: + assert False + + def _updated(self, event=None, checkVisibility=True): + if event == items.ItemChangedType.HIGHLIGHTED: + style = self.getCurrentStyle() + self._updatedStyle(event, style) + else: + styleEvents = [items.ItemChangedType.COLOR, + items.ItemChangedType.LINE_STYLE, + items.ItemChangedType.LINE_WIDTH, + items.ItemChangedType.SYMBOL, + items.ItemChangedType.SYMBOL_SIZE] + if self.isHighlighted(): + styleEvents.append(items.ItemChangedType.HIGHLIGHTED_STYLE) + + if event in styleEvents: + style = self.getCurrentStyle() + self._updatedStyle(event, style) + + super(RegionOfInterest, self)._updated(event, checkVisibility) + + def _updatedStyle(self, event, style): + """Called when the current displayed style of the ROI was changed. + + :param event: The event responsible of the change of the style + :param items.CurveStyle style: The current style + """ + pass + + def getCurrentStyle(self): + """Returns the current curve style. + + Curve style depends on curve highlighting + + :rtype: CurveStyle + """ + baseColor = rgba(self.getColor()) + if isinstance(self, core.LineMixIn): + baseLinestyle = self.getLineStyle() + baseLinewidth = self.getLineWidth() + else: + baseLinestyle = self._DEFAULT_LINESTYLE + baseLinewidth = self._DEFAULT_LINEWIDTH + if isinstance(self, core.SymbolMixIn): + baseSymbol = self.getSymbol() + baseSymbolsize = self.getSymbolSize() + else: + baseSymbol = 'o' + baseSymbolsize = 1 + + if self.isHighlighted(): + style = self.getHighlightedStyle() + color = style.getColor() + linestyle = style.getLineStyle() + linewidth = style.getLineWidth() + symbol = style.getSymbol() + symbolsize = style.getSymbolSize() + + return items.CurveStyle( + color=baseColor if color is None else color, + linestyle=baseLinestyle if linestyle is None else linestyle, + linewidth=baseLinewidth if linewidth is None else linewidth, + symbol=baseSymbol if symbol is None else symbol, + symbolsize=baseSymbolsize if symbolsize is None else symbolsize) + else: + return items.CurveStyle(color=baseColor, + linestyle=baseLinestyle, + linewidth=baseLinewidth, + symbol=baseSymbol, + symbolsize=baseSymbolsize) + + def _editingStarted(self): + assert self._editable is True + self.sigEditingStarted.emit() + + def _editingFinished(self): + self.sigEditingFinished.emit() + + +class HandleBasedROI(RegionOfInterest): + """Manage a ROI based on a set of handles""" + + def __init__(self, parent=None): + RegionOfInterest.__init__(self, parent=parent) + self._handles = [] + self._posOrigin = None + self._posPrevious = None + + def addUserHandle(self, item=None): + """ + Add a new free handle to the ROI. + + This handle do nothing. It have to be managed by the ROI + implementing this class. + + :param Union[None,silx.gui.plot.items.Marker] item: The new marker to + add, else None to create a default marker. + :rtype: silx.gui.plot.items.Marker + """ + return self.addHandle(item, role="user") + + def addLabelHandle(self, item=None): + """ + Add a new label handle to the ROI. + + This handle is not draggable nor selectable. + + It is displayed without symbol, but it is always visible anyway + the ROI is editable, in order to display text. + + :param Union[None,silx.gui.plot.items.Marker] item: The new marker to + add, else None to create a default marker. + :rtype: silx.gui.plot.items.Marker + """ + return self.addHandle(item, role="label") + + def addTranslateHandle(self, item=None): + """ + Add a new translate handle to the ROI. + + Dragging translate handles affect the position position of the ROI + but not the shape itself. + + :param Union[None,silx.gui.plot.items.Marker] item: The new marker to + add, else None to create a default marker. + :rtype: silx.gui.plot.items.Marker + """ + return self.addHandle(item, role="translate") + + def addHandle(self, item=None, role="default"): + """ + Add a new handle to the ROI. + + Dragging handles while affect the position or the shape of the + ROI. + + :param Union[None,silx.gui.plot.items.Marker] item: The new marker to + add, else None to create a default marker. + :rtype: silx.gui.plot.items.Marker + """ + if item is None: + item = items.Marker() + color = rgba(self.getColor()) + color = self._computeHandleColor(color) + item.setColor(color) + if role == "default": + item.setSymbol("s") + elif role == "user": + pass + elif role == "translate": + item.setSymbol("+") + elif role == "label": + item.setSymbol("") + + if role == "user": + pass + elif role == "label": + item._setSelectable(False) + item._setDraggable(False) + item.setVisible(True) + else: + self.__updateEditable(item, self.isEditable(), remove=False) + item._setSelectable(False) + + self._handles.append((item, role)) + self.addItem(item) + return item + + def removeHandle(self, handle): + data = [d for d in self._handles if d[0] is handle][0] + self._handles.remove(data) + role = data[1] + if role not in ["user", "label"]: + if self.isEditable(): + self.__updateEditable(handle, False) + self.removeItem(handle) + + def getHandles(self): + """Returns the list of handles of this HandleBasedROI. + + :rtype: List[~silx.gui.plot.items.Marker] + """ + return tuple(data[0] for data in self._handles) + + def _updated(self, event=None, checkVisibility=True): + """Implement Item mix-in update method by updating the plot items + + See :class:`~silx.gui.plot.items.Item._updated` + """ + if event == items.ItemChangedType.NAME: + self._updateText(self.getName()) + elif event == items.ItemChangedType.VISIBLE: + for item, role in self._handles: + visible = self.isVisible() + editionVisible = visible and self.isEditable() + if role not in ["user", "label"]: + item.setVisible(editionVisible) + else: + item.setVisible(visible) + elif event == items.ItemChangedType.EDITABLE: + for item, role in self._handles: + editable = self.isEditable() + if role not in ["user", "label"]: + self.__updateEditable(item, editable) + super(HandleBasedROI, self)._updated(event, checkVisibility) + + def _updatedStyle(self, event, style): + super(HandleBasedROI, self)._updatedStyle(event, style) + + # Update color of shape items in the plot + color = rgba(self.getColor()) + handleColor = self._computeHandleColor(color) + for item, role in self._handles: + if role == 'user': + pass + elif role == 'label': + item.setColor(color) + else: + item.setColor(handleColor) + + def __updateEditable(self, handle, editable, remove=True): + # NOTE: visibility change emit a position update event + handle.setVisible(editable and self.isVisible()) + handle._setDraggable(editable) + if editable: + handle.sigDragStarted.connect(self._handleEditingStarted) + handle.sigItemChanged.connect(self._handleEditingUpdated) + handle.sigDragFinished.connect(self._handleEditingFinished) + else: + if remove: + handle.sigDragStarted.disconnect(self._handleEditingStarted) + handle.sigItemChanged.disconnect(self._handleEditingUpdated) + handle.sigDragFinished.disconnect(self._handleEditingFinished) + + def _handleEditingStarted(self): + super(HandleBasedROI, self)._editingStarted() + handle = self.sender() + self._posOrigin = numpy.array(handle.getPosition()) + self._posPrevious = numpy.array(self._posOrigin) + self.handleDragStarted(handle, self._posOrigin) + + def _handleEditingUpdated(self): + if self._posOrigin is None: + # Avoid to handle events when visibility change + return + handle = self.sender() + current = numpy.array(handle.getPosition()) + self.handleDragUpdated(handle, self._posOrigin, self._posPrevious, current) + self._posPrevious = current + + def _handleEditingFinished(self): + handle = self.sender() + current = numpy.array(handle.getPosition()) + self.handleDragFinished(handle, self._posOrigin, current) + self._posPrevious = None + self._posOrigin = None + super(HandleBasedROI, self)._editingFinished() + + def isHandleBeingDragged(self): + """Returns True if one of the handles is currently being dragged. + + :rtype: bool + """ + return self._posOrigin is not None + + def handleDragStarted(self, handle, origin): + """Called when an handler drag started""" + pass + + def handleDragUpdated(self, handle, origin, previous, current): + """Called when an handle drag position changed""" + pass + + def handleDragFinished(self, handle, origin, current): + """Called when an handle drag finished""" + pass + + def _computeHandleColor(self, color): + """Returns the anchor color from the base ROI color + + :param Union[numpy.array,Tuple,List]: color + :rtype: Union[numpy.array,Tuple,List] + """ + return color[:3] + (0.5,) + + def _updateText(self, text): + """Update the text displayed by this ROI + + :param str text: A text + """ + pass diff --git a/src/silx/gui/plot/items/axis.py b/src/silx/gui/plot/items/axis.py new file mode 100644 index 0000000..c73323e --- /dev/null +++ b/src/silx/gui/plot/items/axis.py @@ -0,0 +1,560 @@ +# coding: utf-8 +# /*########################################################################## +# +# Copyright (c) 2017-2021 European Synchrotron Radiation Facility +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. +# +# ###########################################################################*/ +"""This module provides the class for axes of the :class:`PlotWidget`. +""" + +__authors__ = ["V. Valls"] +__license__ = "MIT" +__date__ = "22/11/2018" + +import datetime as dt +import enum +import logging + +import dateutil.tz +import numpy + +from ... import qt +from .. import _utils + + +_logger = logging.getLogger(__name__) + + +class TickMode(enum.Enum): + """Determines if ticks are regular number or datetimes.""" + DEFAULT = 0 # Ticks are regular numbers + TIME_SERIES = 1 # Ticks are datetime objects + + +class Axis(qt.QObject): + """This class describes and controls a plot axis. + + Note: This is an abstract class. + """ + # States are half-stored on the backend of the plot, and half-stored on this + # object. + # TODO It would be good to store all the states of an axis in this object. + # i.e. vmin and vmax + + LINEAR = "linear" + """Constant defining a linear scale""" + + LOGARITHMIC = "log" + """Constant defining a logarithmic scale""" + + _SCALES = set([LINEAR, LOGARITHMIC]) + + sigInvertedChanged = qt.Signal(bool) + """Signal emitted when axis orientation has changed""" + + sigScaleChanged = qt.Signal(str) + """Signal emitted when axis scale has changed""" + + _sigLogarithmicChanged = qt.Signal(bool) + """Signal emitted when axis scale has changed to or from logarithmic""" + + sigAutoScaleChanged = qt.Signal(bool) + """Signal emitted when axis autoscale has changed""" + + sigLimitsChanged = qt.Signal(float, float) + """Signal emitted when axis limits have changed""" + + def __init__(self, plot): + """Constructor + + :param silx.gui.plot.PlotWidget.PlotWidget plot: Parent plot of this + axis + """ + qt.QObject.__init__(self, parent=plot) + self._scale = self.LINEAR + self._isAutoScale = True + # Store default labels provided to setGraph[X|Y]Label + self._defaultLabel = '' + # Store currently displayed labels + # Current label can differ from input one with active curve handling + self._currentLabel = '' + + def _getPlot(self): + """Returns the PlotWidget this Axis belongs to. + + :rtype: PlotWidget + """ + plot = self.parent() + if plot is None: + raise RuntimeError("Axis no longer attached to a PlotWidget") + return plot + + def _getBackend(self): + """Returns the backend + + :rtype: BackendBase + """ + return self._getPlot()._backend + + def getLimits(self): + """Get the limits of this axis. + + :return: Minimum and maximum values of this axis as tuple + """ + return self._internalGetLimits() + + def setLimits(self, vmin, vmax): + """Set this axis limits. + + :param float vmin: minimum axis value + :param float vmax: maximum axis value + """ + vmin, vmax = self._checkLimits(vmin, vmax) + if self.getLimits() == (vmin, vmax): + return + + self._internalSetLimits(vmin, vmax) + self._getPlot()._setDirtyPlot() + + self._emitLimitsChanged() + + def _emitLimitsChanged(self): + """Emit axis sigLimitsChanged and PlotWidget limitsChanged event""" + vmin, vmax = self.getLimits() + self.sigLimitsChanged.emit(vmin, vmax) + self._getPlot()._notifyLimitsChanged(emitSignal=False) + + def _checkLimits(self, vmin, vmax): + """Makes sure axis range is not empty and within supported range. + + :param float vmin: Min axis value + :param float vmax: Max axis value + :return: (min, max) making sure min < max + :rtype: 2-tuple of float + """ + return _utils.checkAxisLimits( + vmin, vmax, isLog=self._isLogarithmic(), name=self._defaultLabel) + + def isInverted(self): + """Return True if the axis is inverted (top to bottom for the y-axis), + False otherwise. It is always False for the X axis. + + :rtype: bool + """ + return False + + def setInverted(self, isInverted): + """Set the axis orientation. + + This is only available for the Y axis. + + :param bool flag: True for Y axis going from top to bottom, + False for Y axis going from bottom to top + """ + if isInverted == self.isInverted(): + return + raise NotImplementedError() + + def getLabel(self): + """Return the current displayed label of this axis. + + :param str axis: The Y axis for which to get the label (left or right) + :rtype: str + """ + return self._currentLabel + + def setLabel(self, label): + """Set the label displayed on the plot for this axis. + + The provided label can be temporarily replaced by the label of the + active curve if any. + + :param str label: The axis label + """ + self._defaultLabel = label + self._setCurrentLabel(label) + self._getPlot()._setDirtyPlot() + + def _setCurrentLabel(self, label): + """Define the label currently displayed. + + If the label is None or empty the default label is used. + + :param str label: Currently displayed label + """ + if label is None or label == '': + label = self._defaultLabel + if label is None: + label = '' + self._currentLabel = label + self._internalSetCurrentLabel(label) + + def getScale(self): + """Return the name of the scale used by this axis. + + :rtype: str + """ + return self._scale + + def setScale(self, scale): + """Set the scale to be used by this axis. + + :param str scale: Name of the scale ("log", or "linear") + """ + assert(scale in self._SCALES) + if self._scale == scale: + return + + # For the backward compatibility signal + emitLog = self._scale == self.LOGARITHMIC or scale == self.LOGARITHMIC + + self._scale = scale + + # TODO hackish way of forcing update of curves and images + plot = self._getPlot() + for item in plot.getItems(): + item._updated() + plot._invalidateDataRange() + + if scale == self.LOGARITHMIC: + self._internalSetLogarithmic(True) + elif scale == self.LINEAR: + self._internalSetLogarithmic(False) + else: + raise ValueError("Scale %s unsupported" % scale) + + plot._forceResetZoom() + + self.sigScaleChanged.emit(self._scale) + if emitLog: + self._sigLogarithmicChanged.emit(self._scale == self.LOGARITHMIC) + + def _isLogarithmic(self): + """Return True if this axis scale is logarithmic, False if linear. + + :rtype: bool + """ + return self._scale == self.LOGARITHMIC + + def _setLogarithmic(self, flag): + """Set the scale of this axes (either linear or logarithmic). + + :param bool flag: True to use a logarithmic scale, False for linear. + """ + flag = bool(flag) + self.setScale(self.LOGARITHMIC if flag else self.LINEAR) + + def getTimeZone(self): + """Sets tzinfo that is used if this axis plots date times. + + None means the datetimes are interpreted as local time. + + :rtype: datetime.tzinfo of None. + """ + raise NotImplementedError() + + def setTimeZone(self, tz): + """Sets tzinfo that is used if this axis' tickMode is TIME_SERIES + + The tz must be a descendant of the datetime.tzinfo class, "UTC" or None. + Use None to let the datetimes be interpreted as local time. + Use the string "UTC" to let the date datetimes be in UTC time. + + :param tz: datetime.tzinfo, "UTC" or None. + """ + raise NotImplementedError() + + def getTickMode(self): + """Determines if axis ticks are number or datetimes. + + :rtype: TickMode enum. + """ + raise NotImplementedError() + + def setTickMode(self, tickMode): + """Determines if axis ticks are number or datetimes. + + :param TickMode tickMode: tick mode enum. + """ + raise NotImplementedError() + + def isAutoScale(self): + """Return True if axis is automatically adjusting its limits. + + :rtype: bool + """ + return self._isAutoScale + + def setAutoScale(self, flag=True): + """Set the axis limits adjusting behavior of :meth:`resetZoom`. + + :param bool flag: True to resize limits automatically, + False to disable it. + """ + self._isAutoScale = bool(flag) + self.sigAutoScaleChanged.emit(self._isAutoScale) + + def _setLimitsConstraints(self, minPos=None, maxPos=None): + raise NotImplementedError() + + def setLimitsConstraints(self, minPos=None, maxPos=None): + """ + Set a constraint on the position of the axes. + + :param float minPos: Minimum allowed axis value. + :param float maxPos: Maximum allowed axis value. + :return: True if the constaints was updated + :rtype: bool + """ + updated = self._setLimitsConstraints(minPos, maxPos) + if updated: + plot = self._getPlot() + xMin, xMax = plot.getXAxis().getLimits() + yMin, yMax = plot.getYAxis().getLimits() + y2Min, y2Max = plot.getYAxis('right').getLimits() + plot.setLimits(xMin, xMax, yMin, yMax, y2Min, y2Max) + return updated + + def _setRangeConstraints(self, minRange=None, maxRange=None): + raise NotImplementedError() + + def setRangeConstraints(self, minRange=None, maxRange=None): + """ + Set a constraint on the position of the axes. + + :param float minRange: Minimum allowed left-to-right span across the + view + :param float maxRange: Maximum allowed left-to-right span across the + view + :return: True if the constaints was updated + :rtype: bool + """ + updated = self._setRangeConstraints(minRange, maxRange) + if updated: + plot = self._getPlot() + xMin, xMax = plot.getXAxis().getLimits() + yMin, yMax = plot.getYAxis().getLimits() + y2Min, y2Max = plot.getYAxis('right').getLimits() + plot.setLimits(xMin, xMax, yMin, yMax, y2Min, y2Max) + return updated + + +class XAxis(Axis): + """Axis class defining primitives for the X axis""" + + # TODO With some changes on the backend, it will be able to remove all this + # specialised implementations (prefixel by '_internal') + + def getTimeZone(self): + return self._getBackend().getXAxisTimeZone() + + def setTimeZone(self, tz): + if isinstance(tz, str) and tz.upper() == "UTC": + tz = dateutil.tz.tzutc() + elif not(tz is None or isinstance(tz, dt.tzinfo)): + raise TypeError("tz must be a dt.tzinfo object, None or 'UTC'.") + + self._getBackend().setXAxisTimeZone(tz) + self._getPlot()._setDirtyPlot() + + def getTickMode(self): + if self._getBackend().isXAxisTimeSeries(): + return TickMode.TIME_SERIES + else: + return TickMode.DEFAULT + + def setTickMode(self, tickMode): + if tickMode == TickMode.DEFAULT: + self._getBackend().setXAxisTimeSeries(False) + elif tickMode == TickMode.TIME_SERIES: + self._getBackend().setXAxisTimeSeries(True) + else: + raise ValueError("Unexpected TickMode: {}".format(tickMode)) + + def _internalSetCurrentLabel(self, label): + self._getBackend().setGraphXLabel(label) + + def _internalGetLimits(self): + return self._getBackend().getGraphXLimits() + + def _internalSetLimits(self, xmin, xmax): + self._getBackend().setGraphXLimits(xmin, xmax) + + def _internalSetLogarithmic(self, flag): + self._getBackend().setXAxisLogarithmic(flag) + + def _setLimitsConstraints(self, minPos=None, maxPos=None): + constrains = self._getPlot()._getViewConstraints() + updated = constrains.update(xMin=minPos, xMax=maxPos) + return updated + + def _setRangeConstraints(self, minRange=None, maxRange=None): + constrains = self._getPlot()._getViewConstraints() + updated = constrains.update(minXRange=minRange, maxXRange=maxRange) + return updated + + +class YAxis(Axis): + """Axis class defining primitives for the Y axis""" + + # TODO With some changes on the backend, it will be able to remove all this + # specialised implementations (prefixel by '_internal') + + def _internalSetCurrentLabel(self, label): + self._getBackend().setGraphYLabel(label, axis='left') + + def _internalGetLimits(self): + return self._getBackend().getGraphYLimits(axis='left') + + def _internalSetLimits(self, ymin, ymax): + self._getBackend().setGraphYLimits(ymin, ymax, axis='left') + + def _internalSetLogarithmic(self, flag): + self._getBackend().setYAxisLogarithmic(flag) + + def setInverted(self, flag=True): + """Set the axis orientation. + + This is only available for the Y axis. + + :param bool flag: True for Y axis going from top to bottom, + False for Y axis going from bottom to top + """ + flag = bool(flag) + if self.isInverted() == flag: + return + self._getBackend().setYAxisInverted(flag) + self._getPlot()._setDirtyPlot() + self.sigInvertedChanged.emit(flag) + + def isInverted(self): + """Return True if the axis is inverted (top to bottom for the y-axis), + False otherwise. It is always False for the X axis. + + :rtype: bool + """ + return self._getBackend().isYAxisInverted() + + def _setLimitsConstraints(self, minPos=None, maxPos=None): + constrains = self._getPlot()._getViewConstraints() + updated = constrains.update(yMin=minPos, yMax=maxPos) + return updated + + def _setRangeConstraints(self, minRange=None, maxRange=None): + constrains = self._getPlot()._getViewConstraints() + updated = constrains.update(minYRange=minRange, maxYRange=maxRange) + return updated + + +class YRightAxis(Axis): + """Proxy axis for the secondary Y axes. It manages it own label and limit + but share the some state like scale and direction with the main axis.""" + + # TODO With some changes on the backend, it will be able to remove all this + # specialised implementations (prefixel by '_internal') + + def __init__(self, plot, mainAxis): + """Constructor + + :param silx.gui.plot.PlotWidget.PlotWidget plot: Parent plot of this + axis + :param Axis mainAxis: Axis which sharing state with this axis + """ + Axis.__init__(self, plot) + self.__mainAxis = mainAxis + + @property + def sigInvertedChanged(self): + """Signal emitted when axis orientation has changed""" + return self.__mainAxis.sigInvertedChanged + + @property + def sigScaleChanged(self): + """Signal emitted when axis scale has changed""" + return self.__mainAxis.sigScaleChanged + + @property + def _sigLogarithmicChanged(self): + """Signal emitted when axis scale has changed to or from logarithmic""" + return self.__mainAxis._sigLogarithmicChanged + + @property + def sigAutoScaleChanged(self): + """Signal emitted when axis autoscale has changed""" + return self.__mainAxis.sigAutoScaleChanged + + def _internalSetCurrentLabel(self, label): + self._getBackend().setGraphYLabel(label, axis='right') + + def _internalGetLimits(self): + return self._getBackend().getGraphYLimits(axis='right') + + def _internalSetLimits(self, ymin, ymax): + self._getBackend().setGraphYLimits(ymin, ymax, axis='right') + + def setInverted(self, flag=True): + """Set the Y axis orientation. + + :param bool flag: True for Y axis going from top to bottom, + False for Y axis going from bottom to top + """ + return self.__mainAxis.setInverted(flag) + + def isInverted(self): + """Return True if Y axis goes from top to bottom, False otherwise.""" + return self.__mainAxis.isInverted() + + def getScale(self): + """Return the name of the scale used by this axis. + + :rtype: str + """ + return self.__mainAxis.getScale() + + def setScale(self, scale): + """Set the scale to be used by this axis. + + :param str scale: Name of the scale ("log", or "linear") + """ + self.__mainAxis.setScale(scale) + + def _isLogarithmic(self): + """Return True if Y axis scale is logarithmic, False if linear.""" + return self.__mainAxis._isLogarithmic() + + def _setLogarithmic(self, flag): + """Set the Y axes scale (either linear or logarithmic). + + :param bool flag: True to use a logarithmic scale, False for linear. + """ + return self.__mainAxis._setLogarithmic(flag) + + def isAutoScale(self): + """Return True if Y axes are automatically adjusting its limits.""" + return self.__mainAxis.isAutoScale() + + def setAutoScale(self, flag=True): + """Set the Y axis limits adjusting behavior of :meth:`PlotWidget.resetZoom`. + + :param bool flag: True to resize limits automatically, + False to disable it. + """ + return self.__mainAxis.setAutoScale(flag) diff --git a/src/silx/gui/plot/items/complex.py b/src/silx/gui/plot/items/complex.py new file mode 100644 index 0000000..abb64ad --- /dev/null +++ b/src/silx/gui/plot/items/complex.py @@ -0,0 +1,386 @@ +# coding: utf-8 +# /*########################################################################## +# +# Copyright (c) 2017-2021 European Synchrotron Radiation Facility +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. +# +# ###########################################################################*/ +"""This module provides the :class:`ImageComplexData` of the :class:`Plot`. +""" + +from __future__ import absolute_import + +__authors__ = ["Vincent Favre-Nicolin", "T. Vincent"] +__license__ = "MIT" +__date__ = "14/06/2018" + + +import logging + +import numpy + +from ....utils.proxy import docstring +from ....utils.deprecation import deprecated +from ...colors import Colormap +from .core import ColormapMixIn, ComplexMixIn, ItemChangedType +from .image import ImageBase + + +_logger = logging.getLogger(__name__) + + +# Complex colormap functions + +def _phase2rgb(colormap, data): + """Creates RGBA image with colour-coded phase. + + :param Colormap colormap: The colormap to use + :param numpy.ndarray data: The data to convert + :return: Array of RGBA colors + :rtype: numpy.ndarray + """ + if data.size == 0: + return numpy.zeros((0, 0, 4), dtype=numpy.uint8) + + phase = numpy.angle(data) + return colormap.applyToData(phase) + + +def _complex2rgbalog(phaseColormap, data, amin=0., dlogs=2, smax=None): + """Returns RGBA colors: colour-coded phases and log10(amplitude) in alpha. + + :param Colormap phaseColormap: Colormap to use for the phase + :param numpy.ndarray data: the complex data array to convert to RGBA + :param float amin: the minimum value for the alpha channel + :param float dlogs: amplitude range displayed, in log10 units + :param float smax: + if specified, all values above max will be displayed with an alpha=1 + """ + if data.size == 0: + return numpy.zeros((0, 0, 4), dtype=numpy.uint8) + + rgba = _phase2rgb(phaseColormap, data) + sabs = numpy.absolute(data) + if smax is not None: + sabs[sabs > smax] = smax + a = numpy.log10(sabs + 1e-20) + a -= a.max() - dlogs # display dlogs orders of magnitude + rgba[..., 3] = 255 * (amin + a / dlogs * (1 - amin) * (a > 0)) + return rgba + + +def _complex2rgbalin(phaseColormap, data, gamma=1.0, smax=None): + """Returns RGBA colors: colour-coded phase and linear amplitude in alpha. + + :param Colormap phaseColormap: Colormap to use for the phase + :param numpy.ndarray data: + :param float gamma: Optional exponent gamma applied to the amplitude + :param float smax: + """ + if data.size == 0: + return numpy.zeros((0, 0, 4), dtype=numpy.uint8) + + rgba = _phase2rgb(phaseColormap, data) + a = numpy.absolute(data) + if smax is not None: + a[a > smax] = smax + a /= a.max() + rgba[..., 3] = 255 * a**gamma + return rgba + + +class ImageComplexData(ImageBase, ColormapMixIn, ComplexMixIn): + """Specific plot item to force colormap when using complex colormap. + + This is returning the specific colormap when displaying + colored phase + amplitude. + """ + + _SUPPORTED_COMPLEX_MODES = ( + ComplexMixIn.ComplexMode.ABSOLUTE, + ComplexMixIn.ComplexMode.PHASE, + ComplexMixIn.ComplexMode.REAL, + ComplexMixIn.ComplexMode.IMAGINARY, + ComplexMixIn.ComplexMode.AMPLITUDE_PHASE, + ComplexMixIn.ComplexMode.LOG10_AMPLITUDE_PHASE, + ComplexMixIn.ComplexMode.SQUARE_AMPLITUDE) + """Overrides supported ComplexMode""" + + def __init__(self): + ImageBase.__init__(self, numpy.zeros((0, 0), dtype=numpy.complex64)) + ColormapMixIn.__init__(self) + ComplexMixIn.__init__(self) + self._dataByModesCache = {} + self._amplitudeRangeInfo = None, 2 + + # Use default from ColormapMixIn + colormap = super(ImageComplexData, self).getColormap() + + phaseColormap = Colormap( + name='hsv', + vmin=-numpy.pi, + vmax=numpy.pi) + + self._colormaps = { # Default colormaps for all modes + self.ComplexMode.ABSOLUTE: colormap, + self.ComplexMode.PHASE: phaseColormap, + self.ComplexMode.REAL: colormap, + self.ComplexMode.IMAGINARY: colormap, + self.ComplexMode.AMPLITUDE_PHASE: phaseColormap, + self.ComplexMode.LOG10_AMPLITUDE_PHASE: phaseColormap, + self.ComplexMode.SQUARE_AMPLITUDE: colormap, + } + + def _addBackendRenderer(self, backend): + """Update backend renderer""" + plot = self.getPlot() + assert plot is not None + if not self._isPlotLinear(plot): + # Do not render with non linear scales + return None + + mode = self.getComplexMode() + if mode in (self.ComplexMode.AMPLITUDE_PHASE, + self.ComplexMode.LOG10_AMPLITUDE_PHASE): + # For those modes, compute RGBA image here + colormap = None + data = self.getRgbaImageData(copy=False) + else: + colormap = self.getColormap() + if colormap.isAutoscale(): + # Avoid backend to compute autoscale: use item cache + colormap = colormap.copy() + colormap.setVRange(*colormap.getColormapRange(self)) + + data = self.getData(copy=False) + + if data.size == 0: + return None # No data to display + + return backend.addImage(data, + origin=self.getOrigin(), + scale=self.getScale(), + colormap=colormap, + alpha=self.getAlpha()) + + @docstring(ComplexMixIn) + def setComplexMode(self, mode): + changed = super(ImageComplexData, self).setComplexMode(mode) + if changed: + self._valueDataChanged() + + # Backward compatibility + self._updated(ItemChangedType.VISUALIZATION_MODE) + + # Update ColormapMixIn colormap + colormap = self._colormaps[self.getComplexMode()] + if colormap is not super(ImageComplexData, self).getColormap(): + super(ImageComplexData, self).setColormap(colormap) + + # Send data updated as value returned by getData has changed + self._updated(ItemChangedType.DATA) + return changed + + def _setAmplitudeRangeInfo(self, max_=None, delta=2): + """Set the amplitude range to display for 'log10_amplitude_phase' mode. + + :param max_: Max of the amplitude range. + If None it autoscales to data max. + :param float delta: Delta range in log10 to display + """ + self._amplitudeRangeInfo = max_, float(delta) + self._updated(ItemChangedType.VISUALIZATION_MODE) + + def _getAmplitudeRangeInfo(self): + """Returns the amplitude range to use for 'log10_amplitude_phase' mode. + + :return: (max, delta), if max is None, then it autoscales to data max + :rtype: 2-tuple""" + return self._amplitudeRangeInfo + + def setColormap(self, colormap, mode=None): + """Set the colormap for this specific mode. + + :param ~silx.gui.colors.Colormap colormap: The colormap + :param Union[ComplexMode,str] mode: + If specified, set the colormap of this specific mode. + Default: current mode. + """ + if mode is None: + mode = self.getComplexMode() + else: + mode = self.ComplexMode.from_value(mode) + + self._colormaps[mode] = colormap + if mode is self.getComplexMode(): + super(ImageComplexData, self).setColormap(colormap) + else: + self._updated(ItemChangedType.COLORMAP) + + def getColormap(self, mode=None): + """Get the colormap for the (current) mode. + + :param Union[ComplexMode,str] mode: + If specified, get the colormap of this specific mode. + Default: current mode. + :rtype: ~silx.gui.colors.Colormap + """ + if mode is None: + mode = self.getComplexMode() + else: + mode = self.ComplexMode.from_value(mode) + + return self._colormaps[mode] + + def setData(self, data, copy=True): + """"Set the image complex data + + :param numpy.ndarray data: 2D array of complex with 2 dimensions (h, w) + :param bool copy: True (Default) to get a copy, + False to use internal representation (do not modify!) + """ + data = numpy.array(data, copy=copy) + assert data.ndim == 2 + if not numpy.issubdtype(data.dtype, numpy.complexfloating): + _logger.warning( + 'Image is not complex, converting it to complex to plot it.') + data = numpy.array(data, dtype=numpy.complex64) + + # Compute current mode data and set colormap data + mode = self.getComplexMode() + dataForMode = self.__convertComplexData(data, self.getComplexMode()) + self._dataByModesCache = {mode: dataForMode} + + super().setData(data) + + def _updated(self, event=None, checkVisibility=True): + # Synchronizes colormapped data if changed + # ItemChangedType.COMPLEX_MODE triggers ItemChangedType.DATA + # No need to handle it twice. + if event in (ItemChangedType.DATA, ItemChangedType.MASK): + # Color-mapped data is NOT the `getValueData` for some modes + if self.getComplexMode() in ( + self.ComplexMode.AMPLITUDE_PHASE, + self.ComplexMode.LOG10_AMPLITUDE_PHASE): + data = self.getData(copy=False, mode=self.ComplexMode.PHASE) + mask = self.getMaskData(copy=False) + if mask is not None: + data = numpy.copy(data) + data[mask != 0] = numpy.nan + else: + data = self.getValueData(copy=False) + self._setColormappedData(data, copy=False) + super()._updated(event=event, checkVisibility=checkVisibility) + + def getComplexData(self, copy=True): + """Returns the image complex data + + :param bool copy: True (Default) to get a copy, + False to use internal representation (do not modify!) + :rtype: numpy.ndarray of complex + """ + return super().getData(copy=copy) + + def __convertComplexData(self, data, mode): + """Convert complex data to given mode. + + :param numpy.ndarray data: + :param Union[ComplexMode,str] mode: + :rtype: numpy.ndarray of float + """ + if mode is self.ComplexMode.PHASE: + return numpy.angle(data) + elif mode is self.ComplexMode.REAL: + return numpy.real(data) + elif mode is self.ComplexMode.IMAGINARY: + return numpy.imag(data) + elif mode in (self.ComplexMode.ABSOLUTE, + self.ComplexMode.LOG10_AMPLITUDE_PHASE, + self.ComplexMode.AMPLITUDE_PHASE): + return numpy.absolute(data) + elif mode is self.ComplexMode.SQUARE_AMPLITUDE: + return numpy.absolute(data) ** 2 + else: + _logger.error( + 'Unsupported conversion mode: %s, fallback to absolute', + str(mode)) + return numpy.absolute(data) + + def getData(self, copy=True, mode=None): + """Returns the image data corresponding to (current) mode. + + The returned data is always floats, to get the complex data, use + :meth:`getComplexData`. + + :param bool copy: True (Default) to get a copy, + False to use internal representation (do not modify!) + :param Union[ComplexMode,str] mode: + If specified, get data corresponding to the mode. + Default: Current mode. + :rtype: numpy.ndarray of float + """ + if mode is None: + mode = self.getComplexMode() + else: + mode = self.ComplexMode.from_value(mode) + + if mode not in self._dataByModesCache: + self._dataByModesCache[mode] = self.__convertComplexData( + self.getComplexData(copy=False), mode) + + return numpy.array(self._dataByModesCache[mode], copy=copy) + + def getRgbaImageData(self, copy=True, mode=None): + """Get the displayed RGB(A) image for (current) mode + + :param bool copy: Ignored for this class + :param Union[ComplexMode,str] mode: + If specified, get data corresponding to the mode. + Default: Current mode. + :rtype: numpy.ndarray of uint8 of shape (height, width, 4) + """ + if mode is None: + mode = self.getComplexMode() + else: + mode = self.ComplexMode.from_value(mode) + + colormap = self.getColormap(mode=mode) + if mode is self.ComplexMode.AMPLITUDE_PHASE: + data = self.getComplexData(copy=False) + return _complex2rgbalin(colormap, data) + elif mode is self.ComplexMode.LOG10_AMPLITUDE_PHASE: + data = self.getComplexData(copy=False) + max_, delta = self._getAmplitudeRangeInfo() + return _complex2rgbalog(colormap, data, dlogs=delta, smax=max_) + else: + data = self.getData(copy=False, mode=mode) + return colormap.applyToData(data) + + # Backward compatibility + + Mode = ComplexMixIn.ComplexMode + + @deprecated(replacement='setComplexMode', since_version='0.11.0') + def setVisualizationMode(self, mode): + return self.setComplexMode(mode) + + @deprecated(replacement='getComplexMode', since_version='0.11.0') + def getVisualizationMode(self): + return self.getComplexMode() diff --git a/src/silx/gui/plot/items/core.py b/src/silx/gui/plot/items/core.py new file mode 100644 index 0000000..fa3b8cf --- /dev/null +++ b/src/silx/gui/plot/items/core.py @@ -0,0 +1,1733 @@ +# coding: utf-8 +# /*########################################################################## +# +# Copyright (c) 2017-2021 European Synchrotron Radiation Facility +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. +# +# ###########################################################################*/ +"""This module provides the base class for items of the :class:`Plot`. +""" + +__authors__ = ["T. Vincent"] +__license__ = "MIT" +__date__ = "08/12/2020" + +import collections +try: + from collections import abc +except ImportError: # Python2 support + import collections as abc +from copy import deepcopy +import logging +import enum +from typing import Optional, Tuple +import warnings +import weakref + +import numpy + +from ....utils.deprecation import deprecated +from ....utils.proxy import docstring +from ....utils.enum import Enum as _Enum +from ....math.combo import min_max +from ... import qt +from ... import colors +from ...colors import Colormap +from ._pick import PickingResult + +from silx import config + +_logger = logging.getLogger(__name__) + + +@enum.unique +class ItemChangedType(enum.Enum): + """Type of modification provided by :attr:`Item.sigItemChanged` signal.""" + # Private setters and setInfo are not emitting sigItemChanged signal. + # Signals to consider: + # COLORMAP_SET emitted when setColormap is called but not forward colormap object signal + # CURRENT_COLOR_CHANGED emitted current color changed because highlight changed, + # highlighted color changed or color changed depending on hightlight state. + + VISIBLE = 'visibleChanged' + """Item's visibility changed flag.""" + + ZVALUE = 'zValueChanged' + """Item's Z value changed flag.""" + + COLORMAP = 'colormapChanged' # Emitted when set + forward events from the colormap object + """Item's colormap changed flag. + + This is emitted both when setting a new colormap and + when the current colormap object is updated. + """ + + SYMBOL = 'symbolChanged' + """Item's symbol changed flag.""" + + SYMBOL_SIZE = 'symbolSizeChanged' + """Item's symbol size changed flag.""" + + LINE_WIDTH = 'lineWidthChanged' + """Item's line width changed flag.""" + + LINE_STYLE = 'lineStyleChanged' + """Item's line style changed flag.""" + + COLOR = 'colorChanged' + """Item's color changed flag.""" + + LINE_BG_COLOR = 'lineBgColorChanged' + """Item's line background color changed flag.""" + + YAXIS = 'yAxisChanged' + """Item's Y axis binding changed flag.""" + + FILL = 'fillChanged' + """Item's fill changed flag.""" + + ALPHA = 'alphaChanged' + """Item's transparency alpha changed flag.""" + + DATA = 'dataChanged' + """Item's data changed flag""" + + MASK = 'maskChanged' + """Item's mask changed flag""" + + HIGHLIGHTED = 'highlightedChanged' + """Item's highlight state changed flag.""" + + HIGHLIGHTED_COLOR = 'highlightedColorChanged' + """Deprecated, use HIGHLIGHTED_STYLE instead.""" + + HIGHLIGHTED_STYLE = 'highlightedStyleChanged' + """Item's highlighted style changed flag.""" + + SCALE = 'scaleChanged' + """Item's scale changed flag.""" + + TEXT = 'textChanged' + """Item's text changed flag.""" + + POSITION = 'positionChanged' + """Item's position changed flag. + + This is emitted when a marker position changed and + when an image origin changed. + """ + + OVERLAY = 'overlayChanged' + """Item's overlay state changed flag.""" + + VISUALIZATION_MODE = 'visualizationModeChanged' + """Item's visualization mode changed flag.""" + + COMPLEX_MODE = 'complexModeChanged' + """Item's complex data visualization mode changed flag.""" + + NAME = 'nameChanged' + """Item's name changed flag.""" + + EDITABLE = 'editableChanged' + """Item's editable state changed flags.""" + + SELECTABLE = 'selectableChanged' + """Item's selectable state changed flags.""" + + +class Item(qt.QObject): + """Description of an item of the plot""" + + _DEFAULT_Z_LAYER = 0 + """Default layer for overlay rendering""" + + _DEFAULT_SELECTABLE = False + """Default selectable state of items""" + + sigItemChanged = qt.Signal(object) + """Signal emitted when the item has changed. + + It provides a flag describing which property of the item has changed. + See :class:`ItemChangedType` for flags description. + """ + + _sigVisibleBoundsChanged = qt.Signal() + """Signal emitted when the visible extent of the item in the plot has changed. + + This signal is emitted only if visible extent tracking is enabled + (see :meth:`_setVisibleBoundsTracking`). + """ + + def __init__(self): + qt.QObject.__init__(self) + self._dirty = True + self._plotRef = None + self._visible = True + self._selectable = self._DEFAULT_SELECTABLE + self._z = self._DEFAULT_Z_LAYER + self._info = None + self._xlabel = None + self._ylabel = None + self.__name = '' + + self.__visibleBoundsTracking = False + self.__previousVisibleBounds = None + + self._backendRenderer = None + + def getPlot(self): + """Returns the ~silx.gui.plot.PlotWidget this item belongs to. + + :rtype: Union[~silx.gui.plot.PlotWidget,None] + """ + return None if self._plotRef is None else self._plotRef() + + def _setPlot(self, plot): + """Set the plot this item belongs to. + + WARNING: This should only be called from the Plot. + + :param Union[~silx.gui.plot.PlotWidget,None] plot: The Plot instance. + """ + if plot is not None and self._plotRef is not None: + raise RuntimeError('Trying to add a node at two places.') + self.__disconnectFromPlotWidget() + self._plotRef = None if plot is None else weakref.ref(plot) + self.__connectToPlotWidget() + self._updated() + + def getBounds(self): # TODO return a Bounds object rather than a tuple + """Returns the bounding box of this item in data coordinates + + :returns: (xmin, xmax, ymin, ymax) or None + :rtype: 4-tuple of float or None + """ + return self._getBounds() + + def _getBounds(self): + """:meth:`getBounds` implementation to override by sub-class""" + return None + + def isVisible(self): + """True if item is visible, False otherwise + + :rtype: bool + """ + return self._visible + + def setVisible(self, visible): + """Set visibility of item. + + :param bool visible: True to display it, False otherwise + """ + visible = bool(visible) + if visible != self._visible: + self._visible = visible + # When visibility has changed, always mark as dirty + self._updated(ItemChangedType.VISIBLE, + checkVisibility=False) + + def isOverlay(self): + """Return true if item is drawn as an overlay. + + :rtype: bool + """ + return False + + def getName(self): + """Returns the name of the item which is used as legend. + + :rtype: str + """ + return self.__name + + def setName(self, name): + """Set the name of the item which is used as legend. + + :param str name: New name of the item + :raises RuntimeError: If item belongs to a PlotWidget. + """ + name = str(name) + if self.__name != name: + if self.getPlot() is not None: + raise RuntimeError( + "Cannot change name while item is in a PlotWidget") + + self.__name = name + self._updated(ItemChangedType.NAME) + + def getLegend(self): # Replaced by getName for API consistency + return self.getName() + + @deprecated(replacement='setName', since_version='0.13') + def _setLegend(self, legend): + legend = str(legend) if legend is not None else '' + self.setName(legend) + + def isSelectable(self): + """Returns true if item is selectable (bool)""" + return self._selectable + + def _setSelectable(self, selectable): # TODO support update + """Set whether item is selectable or not. + + This is private for now as change is not handled. + + :param bool selectable: True to make item selectable + """ + self._selectable = bool(selectable) + + def getZValue(self): + """Returns the layer on which to draw this item (int)""" + return self._z + + def setZValue(self, z): + z = int(z) if z is not None else self._DEFAULT_Z_LAYER + if z != self._z: + self._z = z + self._updated(ItemChangedType.ZVALUE) + + def getInfo(self, copy=True): + """Returns the info associated to this item + + :param bool copy: True to get a deepcopy, False otherwise. + """ + return deepcopy(self._info) if copy else self._info + + def setInfo(self, info, copy=True): + if copy: + info = deepcopy(info) + self._info = info + + def getVisibleBounds(self) -> Optional[Tuple[float, float, float, float]]: + """Returns visible bounds of the item bounding box in the plot area. + + :returns: + (xmin, xmax, ymin, ymax) in data coordinates of the visible area or + None if item is not visible in the plot area. + :rtype: Union[List[float],None] + """ + plot = self.getPlot() + bounds = self.getBounds() + if plot is None or bounds is None or not self.isVisible(): + return None + + xmin, xmax = numpy.clip(bounds[:2], *plot.getXAxis().getLimits()) + ymin, ymax = numpy.clip( + bounds[2:], *plot.getYAxis(self.__getYAxis()).getLimits()) + + if xmin == xmax or ymin == ymax: # Outside the plot area + return None + else: + return xmin, xmax, ymin, ymax + + def _isVisibleBoundsTracking(self) -> bool: + """Returns True if visible bounds changes are tracked. + + When enabled, :attr:`_sigVisibleBoundsChanged` is emitted upon changes. + :rtype: bool + """ + return self.__visibleBoundsTracking + + def _setVisibleBoundsTracking(self, enable: bool) -> None: + """Set whether or not to track visible bounds changes. + + :param bool enable: + """ + if enable != self.__visibleBoundsTracking: + self.__disconnectFromPlotWidget() + self.__previousVisibleBounds = None + self.__visibleBoundsTracking = enable + self.__connectToPlotWidget() + + def __getYAxis(self) -> str: + """Returns current Y axis ('left' or 'right')""" + return self.getYAxis() if isinstance(self, YAxisMixIn) else 'left' + + def __connectToPlotWidget(self) -> None: + """Connect to PlotWidget signals and install event filter""" + if not self._isVisibleBoundsTracking(): + return + + plot = self.getPlot() + if plot is not None: + for axis in (plot.getXAxis(), plot.getYAxis(self.__getYAxis())): + axis.sigLimitsChanged.connect(self._visibleBoundsChanged) + + plot.installEventFilter(self) + + self._visibleBoundsChanged() + + def __disconnectFromPlotWidget(self) -> None: + """Disconnect from PlotWidget signals and remove event filter""" + if not self._isVisibleBoundsTracking(): + return + + plot = self.getPlot() + if plot is not None: + for axis in (plot.getXAxis(), plot.getYAxis(self.__getYAxis())): + axis.sigLimitsChanged.disconnect(self._visibleBoundsChanged) + + plot.removeEventFilter(self) + + def _visibleBoundsChanged(self, *args) -> None: + """Check if visible extent actually changed and emit signal""" + if not self._isVisibleBoundsTracking(): + return # No visible extent tracking + + plot = self.getPlot() + if plot is None or not plot.isVisible(): + return # No plot or plot not visible + + extent = self.getVisibleBounds() + if extent != self.__previousVisibleBounds: + self.__previousVisibleBounds = extent + self._sigVisibleBoundsChanged.emit() + + def eventFilter(self, watched, event): + """Event filter to handle PlotWidget show events""" + if watched is self.getPlot() and event.type() == qt.QEvent.Show: + self._visibleBoundsChanged() + return super().eventFilter(watched, event) + + def _updated(self, event=None, checkVisibility=True): + """Mark the item as dirty (i.e., needing update). + + This also triggers Plot.replot. + + :param event: The event to send to :attr:`sigItemChanged` signal. + :param bool checkVisibility: True to only mark as dirty if visible, + False to always mark as dirty. + """ + if not checkVisibility or self.isVisible(): + if not self._dirty: + self._dirty = True + # TODO: send event instead of explicit call + plot = self.getPlot() + if plot is not None: + plot._itemRequiresUpdate(self) + if event is not None: + self.sigItemChanged.emit(event) + + def _update(self, backend): + """Called by Plot to update the backend for this item. + + This is meant to be called asynchronously from _updated. + This optimizes the number of call to _update. + + :param backend: The backend to update + """ + if self._dirty: + # Remove previous renderer from backend if any + self._removeBackendRenderer(backend) + + # If not visible, do not add renderer to backend + if self.isVisible(): + self._backendRenderer = self._addBackendRenderer(backend) + + self._dirty = False + + def _addBackendRenderer(self, backend): + """Override in subclass to add specific backend renderer. + + :param BackendBase backend: The backend to update + :return: The renderer handle to store or None if no renderer in backend + """ + return None + + def _removeBackendRenderer(self, backend): + """Override in subclass to remove specific backend renderer. + + :param BackendBase backend: The backend to update + """ + if self._backendRenderer is not None: + backend.remove(self._backendRenderer) + self._backendRenderer = None + + def pick(self, x, y): + """Run picking test on this item + + :param float x: The x pixel coord where to pick. + :param float y: The y pixel coord where to pick. + :return: None if not picked, else the picked position information + :rtype: Union[None,PickingResult] + """ + if not self.isVisible() or self._backendRenderer is None: + return None + plot = self.getPlot() + if plot is None: + return None + + indices = plot._backend.pickItem(x, y, self._backendRenderer) + if indices is None: + return None + else: + return PickingResult(self, indices) + + +class DataItem(Item): + """Item with a data extent in the plot""" + + def _boundsChanged(self, checkVisibility: bool=True) -> None: + """Call this method in subclass when data bounds has changed. + + :param bool checkVisibility: + """ + if not checkVisibility or self.isVisible(): + self._visibleBoundsChanged() + + # TODO hackish data range implementation + plot = self.getPlot() + if plot is not None: + plot._invalidateDataRange() + + @docstring(Item) + def setVisible(self, visible: bool): + if visible != self.isVisible(): + self._boundsChanged(checkVisibility=False) + super().setVisible(visible) + +# Mix-in classes ############################################################## + + +class ItemMixInBase(object): + """Base class for Item mix-in""" + + def _updated(self, event=None, checkVisibility=True): + """This is implemented in :class:`Item`. + + Mark the item as dirty (i.e., needing update). + This also triggers Plot.replot. + + :param event: The event to send to :attr:`sigItemChanged` signal. + :param bool checkVisibility: True to only mark as dirty if visible, + False to always mark as dirty. + """ + raise RuntimeError( + "Issue with Mix-In class inheritance order") + + +class LabelsMixIn(ItemMixInBase): + """Mix-in class for items with x and y labels + + Setters are private, otherwise it needs to check the plot + current active curve and access the internal current labels. + """ + + def __init__(self): + self._xlabel = None + self._ylabel = None + + def getXLabel(self): + """Return the X axis label associated to this curve + + :rtype: str or None + """ + return self._xlabel + + def _setXLabel(self, label): + """Set the X axis label associated with this curve + + :param str label: The X axis label + """ + self._xlabel = str(label) + + def getYLabel(self): + """Return the Y axis label associated to this curve + + :rtype: str or None + """ + return self._ylabel + + def _setYLabel(self, label): + """Set the Y axis label associated with this curve + + :param str label: The Y axis label + """ + self._ylabel = str(label) + + +class DraggableMixIn(ItemMixInBase): + """Mix-in class for draggable items""" + + def __init__(self): + self._draggable = False + + def isDraggable(self): + """Returns true if image is draggable + + :rtype: bool + """ + return self._draggable + + def _setDraggable(self, draggable): # TODO support update + """Set if image is draggable or not. + + This is private for not as it does not support update. + + :param bool draggable: + """ + self._draggable = bool(draggable) + + def drag(self, from_, to): + """Perform a drag of the item. + + :param List[float] from_: (x, y) previous position in data coordinates + :param List[float] to: (x, y) current position in data coordinates + """ + raise NotImplementedError("Must be implemented in subclass") + + +class ColormapMixIn(ItemMixInBase): + """Mix-in class for items with colormap""" + + def __init__(self): + self._colormap = Colormap() + self._colormap.sigChanged.connect(self._colormapChanged) + self.__data = None + self.__cacheColormapRange = {} # Store {normalization: range} + + def getColormap(self): + """Return the used colormap""" + return self._colormap + + def setColormap(self, colormap): + """Set the colormap of this item + + :param silx.gui.colors.Colormap colormap: colormap description + """ + if self._colormap is colormap: + return + if isinstance(colormap, dict): + colormap = Colormap._fromDict(colormap) + + if self._colormap is not None: + self._colormap.sigChanged.disconnect(self._colormapChanged) + self._colormap = colormap + if self._colormap is not None: + self._colormap.sigChanged.connect(self._colormapChanged) + self._colormapChanged() + + def _colormapChanged(self): + """Handle updates of the colormap""" + self._updated(ItemChangedType.COLORMAP) + + def _setColormappedData(self, data, copy=True, + min_=None, minPositive=None, max_=None): + """Set the data used to compute the colormapped display. + + It also resets the cache of data ranges. + + This method MUST be called by inheriting classes when data is updated. + + :param Union[None,numpy.ndarray] data: + :param Union[None,float] min_: Minimum value of the data + :param Union[None,float] minPositive: + Minimum of strictly positive values of the data + :param Union[None,float] max_: Maximum value of the data + """ + self.__data = None if data is None else numpy.array(data, copy=copy) + self.__cacheColormapRange = {} # Reset cache + + # Fill-up colormap range cache if values are provided + if max_ is not None and numpy.isfinite(max_): + if min_ is not None and numpy.isfinite(min_): + self.__cacheColormapRange[Colormap.LINEAR, Colormap.MINMAX] = min_, max_ + if minPositive is not None and numpy.isfinite(minPositive): + self.__cacheColormapRange[Colormap.LOGARITHM, Colormap.MINMAX] = minPositive, max_ + + colormap = self.getColormap() + if None in (colormap.getVMin(), colormap.getVMax()): + self._colormapChanged() + + def getColormappedData(self, copy=True): + """Returns the data used to compute the displayed colors + + :param bool copy: True to get a copy, + False to get internal data (do not modify!). + :rtype: Union[None,numpy.ndarray] + """ + if self.__data is None: + return None + else: + return numpy.array(self.__data, copy=copy) + + def _getColormapAutoscaleRange(self, colormap=None): + """Returns the autoscale range for current data and colormap. + + :param Union[None,~silx.gui.colors.Colormap] colormap: + The colormap for which to compute the autoscale range. + If None, the default, the colormap of the item is used + :return: (vmin, vmax) range (vmin and /or vmax might be `None`) + """ + if colormap is None: + colormap = self.getColormap() + + data = self.getColormappedData(copy=False) + if colormap is None or data is None: + return None, None + + normalization = colormap.getNormalization() + autoscaleMode = colormap.getAutoscaleMode() + key = normalization, autoscaleMode + vRange = self.__cacheColormapRange.get(key, None) + if vRange is None: + vRange = colormap._computeAutoscaleRange(data) + self.__cacheColormapRange[key] = vRange + return vRange + + +class SymbolMixIn(ItemMixInBase): + """Mix-in class for items with symbol type""" + + _DEFAULT_SYMBOL = None + """Default marker of the item""" + + _DEFAULT_SYMBOL_SIZE = config.DEFAULT_PLOT_SYMBOL_SIZE + """Default marker size of the item""" + + _SUPPORTED_SYMBOLS = collections.OrderedDict(( + ('o', 'Circle'), + ('d', 'Diamond'), + ('s', 'Square'), + ('+', 'Plus'), + ('x', 'Cross'), + ('.', 'Point'), + (',', 'Pixel'), + ('|', 'Vertical line'), + ('_', 'Horizontal line'), + ('tickleft', 'Tick left'), + ('tickright', 'Tick right'), + ('tickup', 'Tick up'), + ('tickdown', 'Tick down'), + ('caretleft', 'Caret left'), + ('caretright', 'Caret right'), + ('caretup', 'Caret up'), + ('caretdown', 'Caret down'), + (u'\u2665', 'Heart'), + ('', 'None'))) + """Dict of supported symbols""" + + def __init__(self): + if self._DEFAULT_SYMBOL is None: # Use default from config + self._symbol = config.DEFAULT_PLOT_SYMBOL + else: + self._symbol = self._DEFAULT_SYMBOL + + if self._DEFAULT_SYMBOL_SIZE is None: # Use default from config + self._symbol_size = config.DEFAULT_PLOT_SYMBOL_SIZE + else: + self._symbol_size = self._DEFAULT_SYMBOL_SIZE + + @classmethod + def getSupportedSymbols(cls): + """Returns the list of supported symbol names. + + :rtype: tuple of str + """ + return tuple(cls._SUPPORTED_SYMBOLS.keys()) + + @classmethod + def getSupportedSymbolNames(cls): + """Returns the list of supported symbol human-readable names. + + :rtype: tuple of str + """ + return tuple(cls._SUPPORTED_SYMBOLS.values()) + + def getSymbolName(self, symbol=None): + """Returns human-readable name for a symbol. + + :param str symbol: The symbol from which to get the name. + Default: current symbol. + :rtype: str + :raise KeyError: if symbol is not in :meth:`getSupportedSymbols`. + """ + if symbol is None: + symbol = self.getSymbol() + return self._SUPPORTED_SYMBOLS[symbol] + + def getSymbol(self): + """Return the point marker type. + + Marker type:: + + - 'o' circle + - '.' point + - ',' pixel + - '+' cross + - 'x' x-cross + - 'd' diamond + - 's' square + + :rtype: str + """ + return self._symbol + + def setSymbol(self, symbol): + """Set the marker type + + See :meth:`getSymbol`. + + :param str symbol: Marker type or marker name + """ + if symbol is None: + symbol = self._DEFAULT_SYMBOL + + elif symbol not in self.getSupportedSymbols(): + for symbolCode, name in self._SUPPORTED_SYMBOLS.items(): + if name.lower() == symbol.lower(): + symbol = symbolCode + break + else: + raise ValueError('Unsupported symbol %s' % str(symbol)) + + if symbol != self._symbol: + self._symbol = symbol + self._updated(ItemChangedType.SYMBOL) + + def getSymbolSize(self): + """Return the point marker size in points. + + :rtype: float + """ + return self._symbol_size + + def setSymbolSize(self, size): + """Set the point marker size in points. + + See :meth:`getSymbolSize`. + + :param str symbol: Marker type + """ + if size is None: + size = self._DEFAULT_SYMBOL_SIZE + if size != self._symbol_size: + self._symbol_size = size + self._updated(ItemChangedType.SYMBOL_SIZE) + + +class LineMixIn(ItemMixInBase): + """Mix-in class for item with line""" + + _DEFAULT_LINEWIDTH = 1. + """Default line width""" + + _DEFAULT_LINESTYLE = '-' + """Default line style""" + + _SUPPORTED_LINESTYLE = '', ' ', '-', '--', '-.', ':', None + """Supported line styles""" + + def __init__(self): + self._linewidth = self._DEFAULT_LINEWIDTH + self._linestyle = self._DEFAULT_LINESTYLE + + @classmethod + def getSupportedLineStyles(cls): + """Returns list of supported line styles. + + :rtype: List[str,None] + """ + return cls._SUPPORTED_LINESTYLE + + def getLineWidth(self): + """Return the curve line width in pixels + + :rtype: float + """ + return self._linewidth + + def setLineWidth(self, width): + """Set the width in pixel of the curve line + + See :meth:`getLineWidth`. + + :param float width: Width in pixels + """ + width = float(width) + if width != self._linewidth: + self._linewidth = width + self._updated(ItemChangedType.LINE_WIDTH) + + def getLineStyle(self): + """Return the type of the line + + Type of line:: + + - ' ' no line + - '-' solid line + - '--' dashed line + - '-.' dash-dot line + - ':' dotted line + + :rtype: str + """ + return self._linestyle + + def setLineStyle(self, style): + """Set the style of the curve line. + + See :meth:`getLineStyle`. + + :param str style: Line style + """ + style = str(style) + assert style in self.getSupportedLineStyles() + if style is None: + style = self._DEFAULT_LINESTYLE + if style != self._linestyle: + self._linestyle = style + self._updated(ItemChangedType.LINE_STYLE) + + +class ColorMixIn(ItemMixInBase): + """Mix-in class for item with color""" + + _DEFAULT_COLOR = (0., 0., 0., 1.) + """Default color of the item""" + + def __init__(self): + self._color = self._DEFAULT_COLOR + + def getColor(self): + """Returns the RGBA color of the item + + :rtype: 4-tuple of float in [0, 1] or array of colors + """ + return self._color + + def setColor(self, color, copy=True): + """Set item color + + :param color: color(s) to be used + :type color: str ("#RRGGBB") or (npoints, 4) unsigned byte array or + one of the predefined color names defined in colors.py + :param bool copy: True (Default) to get a copy, + False to use internal representation (do not modify!) + """ + if isinstance(color, str): + color = colors.rgba(color) + elif isinstance(color, qt.QColor): + color = colors.rgba(color) + else: + color = numpy.array(color, copy=copy) + # TODO more checks + improve color array support + if color.ndim == 1: # Single RGBA color + color = colors.rgba(color) + else: # Array of colors + assert color.ndim == 2 + + self._color = color + self._updated(ItemChangedType.COLOR) + + +class YAxisMixIn(ItemMixInBase): + """Mix-in class for item with yaxis""" + + _DEFAULT_YAXIS = 'left' + """Default Y axis the item belongs to""" + + def __init__(self): + self._yaxis = self._DEFAULT_YAXIS + + def getYAxis(self): + """Returns the Y axis this curve belongs to. + + Either 'left' or 'right'. + + :rtype: str + """ + return self._yaxis + + def setYAxis(self, yaxis): + """Set the Y axis this curve belongs to. + + :param str yaxis: 'left' or 'right' + """ + yaxis = str(yaxis) + assert yaxis in ('left', 'right') + if yaxis != self._yaxis: + self._yaxis = yaxis + # Handle data extent changed for DataItem + if isinstance(self, DataItem): + self._boundsChanged() + + # Handle visible extent changed + if self._isVisibleBoundsTracking(): + # Switch Y axis signal connection + plot = self.getPlot() + if plot is not None: + previousYAxis = 'left' if self.getXAxis() == 'right' else 'right' + plot.getYAxis(previousYAxis).sigLimitsChanged.disconnect( + self._visibleBoundsChanged) + plot.getYAxis(self.getYAxis()).sigLimitsChanged.connect( + self._visibleBoundsChanged) + self._visibleBoundsChanged() + + self._updated(ItemChangedType.YAXIS) + + +class FillMixIn(ItemMixInBase): + """Mix-in class for item with fill""" + + def __init__(self): + self._fill = False + + def isFill(self): + """Returns whether the item is filled or not. + + :rtype: bool + """ + return self._fill + + def setFill(self, fill): + """Set whether to fill the item or not. + + :param bool fill: + """ + fill = bool(fill) + if fill != self._fill: + self._fill = fill + self._updated(ItemChangedType.FILL) + + +class AlphaMixIn(ItemMixInBase): + """Mix-in class for item with opacity""" + + def __init__(self): + self._alpha = 1. + + def getAlpha(self): + """Returns the opacity of the item + + :rtype: float in [0, 1.] + """ + return self._alpha + + def setAlpha(self, alpha): + """Set the opacity of the item + + .. note:: + + If the colormap already has some transparency, this alpha + adds additional transparency. The alpha channel of the colormap + is multiplied by this value. + + :param alpha: Opacity of the item, between 0 (full transparency) + and 1. (full opacity) + :type alpha: float + """ + alpha = float(alpha) + alpha = max(0., min(alpha, 1.)) # Clip alpha to [0., 1.] range + if alpha != self._alpha: + self._alpha = alpha + self._updated(ItemChangedType.ALPHA) + + +class ComplexMixIn(ItemMixInBase): + """Mix-in class for complex data mode""" + + _SUPPORTED_COMPLEX_MODES = None + """Override to only support a subset of all ComplexMode""" + + class ComplexMode(_Enum): + """Identify available display mode for complex""" + NONE = 'none' + ABSOLUTE = 'amplitude' + PHASE = 'phase' + REAL = 'real' + IMAGINARY = 'imaginary' + AMPLITUDE_PHASE = 'amplitude_phase' + LOG10_AMPLITUDE_PHASE = 'log10_amplitude_phase' + SQUARE_AMPLITUDE = 'square_amplitude' + + def __init__(self): + self.__complex_mode = self.ComplexMode.ABSOLUTE + + def getComplexMode(self): + """Returns the current complex visualization mode. + + :rtype: ComplexMode + """ + return self.__complex_mode + + def setComplexMode(self, mode): + """Set the complex visualization mode. + + :param ComplexMode mode: The visualization mode in: + 'real', 'imaginary', 'phase', 'amplitude' + :return: True if value was set, False if is was already set + :rtype: bool + """ + mode = self.ComplexMode.from_value(mode) + assert mode in self.supportedComplexModes() + + if mode != self.__complex_mode: + self.__complex_mode = mode + self._updated(ItemChangedType.COMPLEX_MODE) + return True + else: + return False + + def _convertComplexData(self, data, mode=None): + """Convert complex data to the specific mode. + + :param Union[ComplexMode,None] mode: + The kind of value to compute. + If None (the default), the current complex mode is used. + :return: The converted dataset + :rtype: Union[numpy.ndarray[float],None] + """ + if data is None: + return None + + if mode is None: + mode = self.getComplexMode() + + if mode is self.ComplexMode.REAL: + return numpy.real(data) + elif mode is self.ComplexMode.IMAGINARY: + return numpy.imag(data) + elif mode is self.ComplexMode.ABSOLUTE: + return numpy.absolute(data) + elif mode is self.ComplexMode.PHASE: + return numpy.angle(data) + elif mode is self.ComplexMode.SQUARE_AMPLITUDE: + return numpy.absolute(data) ** 2 + else: + raise ValueError('Unsupported conversion mode: %s', str(mode)) + + @classmethod + def supportedComplexModes(cls): + """Returns the list of supported complex visualization modes. + + See :class:`ComplexMode` and :meth:`setComplexMode`. + + :rtype: List[ComplexMode] + """ + if cls._SUPPORTED_COMPLEX_MODES is None: + return cls.ComplexMode.members() + else: + return cls._SUPPORTED_COMPLEX_MODES + + +class ScatterVisualizationMixIn(ItemMixInBase): + """Mix-in class for scatter plot visualization modes""" + + _SUPPORTED_SCATTER_VISUALIZATION = None + """Allows to override supported Visualizations""" + + @enum.unique + class Visualization(_Enum): + """Different modes of scatter plot visualizations""" + + POINTS = 'points' + """Display scatter plot as a point cloud""" + + LINES = 'lines' + """Display scatter plot as a wireframe. + + This is based on Delaunay triangulation + """ + + SOLID = 'solid' + """Display scatter plot as a set of filled triangles. + + This is based on Delaunay triangulation + """ + + REGULAR_GRID = 'regular_grid' + """Display scatter plot as an image. + + It expects the points to be the intersection of a regular grid, + and the order of points following that of an image. + First line, then second one, and always in the same direction + (either all lines from left to right or all from right to left). + """ + + IRREGULAR_GRID = 'irregular_grid' + """Display scatter plot as contiguous quadrilaterals. + + It expects the points to be the intersection of an irregular grid, + and the order of points following that of an image. + First line, then second one, and always in the same direction + (either all lines from left to right or all from right to left). + """ + + BINNED_STATISTIC = 'binned_statistic' + """Display scatter plot as 2D binned statistic (i.e., generalized histogram). + """ + + @enum.unique + class VisualizationParameter(_Enum): + """Different parameter names for scatter plot visualizations""" + + GRID_MAJOR_ORDER = 'grid_major_order' + """The major order of points in the regular grid. + + Either 'row' (row-major, fast X) or 'column' (column-major, fast Y). + """ + + GRID_BOUNDS = 'grid_bounds' + """The expected range in data coordinates of the regular grid. + + A 2-tuple of 2-tuple: (begin (x, y), end (x, y)). + This provides the data coordinates of the first point and the expected + last on. + As for `GRID_SHAPE`, this can be wider than the current data. + """ + + GRID_SHAPE = 'grid_shape' + """The expected size of the regular grid (height, width). + + The given shape can be wider than the number of points, + in which case the grid is not fully filled. + """ + + BINNED_STATISTIC_SHAPE = 'binned_statistic_shape' + """The number of bins in each dimension (height, width). + """ + + BINNED_STATISTIC_FUNCTION = 'binned_statistic_function' + """The reduction function to apply to each bin (str). + + Available reduction functions are: 'mean' (default), 'count', 'sum'. + """ + + DATA_BOUNDS_HINT = 'data_bounds_hint' + """The expected bounds of the data in data coordinates. + + A 2-tuple of 2-tuple: ((ymin, ymax), (xmin, xmax)). + This provides a hint for the data ranges in both dimensions. + It is eventually enlarged with actually data ranges. + + WARNING: dimension 0 i.e., Y first. + """ + + _SUPPORTED_VISUALIZATION_PARAMETER_VALUES = { + VisualizationParameter.GRID_MAJOR_ORDER: ('row', 'column'), + VisualizationParameter.BINNED_STATISTIC_FUNCTION: ('mean', 'count', 'sum'), + } + """Supported visualization parameter values. + + Defined for parameters with a set of acceptable values. + """ + + def __init__(self): + self.__visualization = self.Visualization.POINTS + self.__parameters = dict(# Init parameters to None + (parameter, None) for parameter in self.VisualizationParameter) + self.__parameters[self.VisualizationParameter.BINNED_STATISTIC_FUNCTION] = 'mean' + + @classmethod + def supportedVisualizations(cls): + """Returns the list of supported scatter visualization modes. + + See :meth:`setVisualization` + + :rtype: List[Visualization] + """ + if cls._SUPPORTED_SCATTER_VISUALIZATION is None: + return cls.Visualization.members() + else: + return cls._SUPPORTED_SCATTER_VISUALIZATION + + @classmethod + def supportedVisualizationParameterValues(cls, parameter): + """Returns the list of supported scatter visualization modes. + + See :meth:`VisualizationParameters` + + :param VisualizationParameter parameter: + This parameter for which to retrieve the supported values. + :returns: tuple of supported of values or None if not defined. + """ + parameter = cls.VisualizationParameter(parameter) + return cls._SUPPORTED_VISUALIZATION_PARAMETER_VALUES.get( + parameter, None) + + def setVisualization(self, mode): + """Set the scatter plot visualization mode to use. + + See :class:`Visualization` for all possible values, + and :meth:`supportedVisualizations` for supported ones. + + :param Union[str,Visualization] mode: + The visualization mode to use. + :return: True if value was set, False if is was already set + :rtype: bool + """ + mode = self.Visualization.from_value(mode) + assert mode in self.supportedVisualizations() + + if mode != self.__visualization: + self.__visualization = mode + + self._updated(ItemChangedType.VISUALIZATION_MODE) + return True + else: + return False + + def getVisualization(self): + """Returns the scatter plot visualization mode in use. + + :rtype: Visualization + """ + return self.__visualization + + def setVisualizationParameter(self, parameter, value=None): + """Set the given visualization parameter. + + :param Union[str,VisualizationParameter] parameter: + The name of the parameter to set + :param value: The value to use for this parameter + Set to None to automatically set the parameter + :raises ValueError: If parameter is not supported + :return: True if parameter was set, False if is was already set + :rtype: bool + :raise ValueError: If value is not supported + """ + parameter = self.VisualizationParameter.from_value(parameter) + + if self.__parameters[parameter] != value: + validValues = self.supportedVisualizationParameterValues(parameter) + if validValues is not None and value not in validValues: + raise ValueError("Unsupported parameter value: %s" % str(value)) + + self.__parameters[parameter] = value + self._updated(ItemChangedType.VISUALIZATION_MODE) + return True + return False + + def getVisualizationParameter(self, parameter): + """Returns the value of the given visualization parameter. + + This method returns the parameter as set by + :meth:`setVisualizationParameter`. + + :param parameter: The name of the parameter to retrieve + :returns: The value previously set or None if automatically set + :raises ValueError: If parameter is not supported + """ + if parameter not in self.VisualizationParameter: + raise ValueError("parameter not supported: %s", parameter) + + return self.__parameters[parameter] + + def getCurrentVisualizationParameter(self, parameter): + """Returns the current value of the given visualization parameter. + + If the parameter was set by :meth:`setVisualizationParameter` to + a value that is not None, this value is returned; + else the current value that is automatically computed is returned. + + :param parameter: The name of the parameter to retrieve + :returns: The current value (either set or automatically computed) + :raises ValueError: If parameter is not supported + """ + # Override in subclass to provide automatically computed parameters + return self.getVisualizationParameter(parameter) + + +class PointsBase(DataItem, SymbolMixIn, AlphaMixIn): + """Base class for :class:`Curve` and :class:`Scatter`""" + # note: _logFilterData must be overloaded if you overload + # getData to change its signature + + _DEFAULT_Z_LAYER = 1 + """Default overlay layer for points, + on top of images.""" + + def __init__(self): + DataItem.__init__(self) + SymbolMixIn.__init__(self) + AlphaMixIn.__init__(self) + self._x = () + self._y = () + self._xerror = None + self._yerror = None + + # Store filtered data for x > 0 and/or y > 0 + self._filteredCache = {} + self._clippedCache = {} + + # Store bounds depending on axes filtering >0: + # key is (isXPositiveFilter, isYPositiveFilter) + self._boundsCache = {} + + @staticmethod + def _logFilterError(value, error): + """Filter/convert error values if they go <= 0. + + Replace error leading to negative values by nan + + :param numpy.ndarray value: 1D array of values + :param numpy.ndarray error: + Array of errors: scalar, N, Nx1 or 2xN or None. + :return: Filtered error so error bars are never negative + """ + if error is not None: + # Convert Nx1 to N + if error.ndim == 2 and error.shape[1] == 1 and len(value) != 1: + error = numpy.ravel(error) + + # Supports error being scalar, N or 2xN array + valueMinusError = value - numpy.atleast_2d(error)[0] + errorClipped = numpy.isnan(valueMinusError) + mask = numpy.logical_not(errorClipped) + errorClipped[mask] = valueMinusError[mask] <= 0 + + if numpy.any(errorClipped): # Need filtering + + # expand errorbars to 2xN + if error.size == 1: # Scalar + error = numpy.full( + (2, len(value)), error, dtype=numpy.float64) + + elif error.ndim == 1: # N array + newError = numpy.empty((2, len(value)), + dtype=numpy.float64) + newError[0,:] = error + newError[1,:] = error + error = newError + + elif error.size == 2 * len(value): # 2xN array + error = numpy.array( + error, copy=True, dtype=numpy.float64) + + else: + _logger.error("Unhandled error array") + return error + + error[0, errorClipped] = numpy.nan + + return error + + def _getClippingBoolArray(self, xPositive, yPositive): + """Compute a boolean array to filter out points with negative + coordinates on log axes. + + :param bool xPositive: True to filter arrays according to X coords. + :param bool yPositive: True to filter arrays according to Y coords. + :rtype: boolean numpy.ndarray + """ + assert xPositive or yPositive + if (xPositive, yPositive) not in self._clippedCache: + xclipped, yclipped = False, False + + if xPositive: + x = self.getXData(copy=False) + with numpy.errstate(invalid='ignore'): # Ignore NaN warnings + xclipped = x <= 0 + + if yPositive: + y = self.getYData(copy=False) + with numpy.errstate(invalid='ignore'): # Ignore NaN warnings + yclipped = y <= 0 + + self._clippedCache[(xPositive, yPositive)] = \ + numpy.logical_or(xclipped, yclipped) + return self._clippedCache[(xPositive, yPositive)] + + def _logFilterData(self, xPositive, yPositive): + """Filter out values with x or y <= 0 on log axes + + :param bool xPositive: True to filter arrays according to X coords. + :param bool yPositive: True to filter arrays according to Y coords. + :return: The filter arrays or unchanged object if filtering not needed + :rtype: (x, y, xerror, yerror) + """ + x = self.getXData(copy=False) + y = self.getYData(copy=False) + xerror = self.getXErrorData(copy=False) + yerror = self.getYErrorData(copy=False) + + if xPositive or yPositive: + clipped = self._getClippingBoolArray(xPositive, yPositive) + + if numpy.any(clipped): + # copy to keep original array and convert to float + x = numpy.array(x, copy=True, dtype=numpy.float64) + x[clipped] = numpy.nan + y = numpy.array(y, copy=True, dtype=numpy.float64) + y[clipped] = numpy.nan + + if xPositive and xerror is not None: + xerror = self._logFilterError(x, xerror) + + if yPositive and yerror is not None: + yerror = self._logFilterError(y, yerror) + + return x, y, xerror, yerror + + def _getBounds(self): + if self.getXData(copy=False).size == 0: # Empty data + return None + + plot = self.getPlot() + if plot is not None: + xPositive = plot.getXAxis()._isLogarithmic() + yPositive = plot.getYAxis()._isLogarithmic() + else: + xPositive = False + yPositive = False + + # TODO bounds do not take error bars into account + if (xPositive, yPositive) not in self._boundsCache: + # use the getData class method because instance method can be + # overloaded to return additional arrays + data = PointsBase.getData(self, copy=False, displayed=True) + if len(data) == 5: + # hack to avoid duplicating caching mechanism in Scatter + # (happens when cached data is used, caching done using + # Scatter._logFilterData) + x, y, _xerror, _yerror = data[0], data[1], data[3], data[4] + else: + x, y, _xerror, _yerror = data + + xmin, xmax = min_max(x, finite=True) + ymin, ymax = min_max(y, finite=True) + self._boundsCache[(xPositive, yPositive)] = tuple([ + (bound if bound is not None else numpy.nan) + for bound in (xmin, xmax, ymin, ymax)]) + return self._boundsCache[(xPositive, yPositive)] + + def _getCachedData(self): + """Return cached filtered data if applicable, + i.e. if any axis is in log scale. + Return None if caching is not applicable.""" + plot = self.getPlot() + if plot is not None: + xPositive = plot.getXAxis()._isLogarithmic() + yPositive = plot.getYAxis()._isLogarithmic() + if xPositive or yPositive: + # At least one axis has log scale, filter data + if (xPositive, yPositive) not in self._filteredCache: + self._filteredCache[(xPositive, yPositive)] = \ + self._logFilterData(xPositive, yPositive) + return self._filteredCache[(xPositive, yPositive)] + return None + + def getData(self, copy=True, displayed=False): + """Returns the x, y values of the curve points and xerror, yerror + + :param bool copy: True (Default) to get a copy, + False to use internal representation (do not modify!) + :param bool displayed: True to only get curve points that are displayed + in the plot. Default: False + Note: If plot has log scale, negative points + are not displayed. + :returns: (x, y, xerror, yerror) + :rtype: 4-tuple of numpy.ndarray + """ + if displayed: # filter data according to plot state + cached_data = self._getCachedData() + if cached_data is not None: + return cached_data + + return (self.getXData(copy), + self.getYData(copy), + self.getXErrorData(copy), + self.getYErrorData(copy)) + + def getXData(self, copy=True): + """Returns the x coordinates of the data points + + :param copy: True (Default) to get a copy, + False to use internal representation (do not modify!) + :rtype: numpy.ndarray + """ + return numpy.array(self._x, copy=copy) + + def getYData(self, copy=True): + """Returns the y coordinates of the data points + + :param copy: True (Default) to get a copy, + False to use internal representation (do not modify!) + :rtype: numpy.ndarray + """ + return numpy.array(self._y, copy=copy) + + def getXErrorData(self, copy=True): + """Returns the x error of the points + + :param copy: True (Default) to get a copy, + False to use internal representation (do not modify!) + :rtype: numpy.ndarray, float or None + """ + if isinstance(self._xerror, numpy.ndarray): + return numpy.array(self._xerror, copy=copy) + else: + return self._xerror # float or None + + def getYErrorData(self, copy=True): + """Returns the y error of the points + + :param copy: True (Default) to get a copy, + False to use internal representation (do not modify!) + :rtype: numpy.ndarray, float or None + """ + if isinstance(self._yerror, numpy.ndarray): + return numpy.array(self._yerror, copy=copy) + else: + return self._yerror # float or None + + def setData(self, x, y, xerror=None, yerror=None, copy=True): + """Set the data of the curve. + + :param numpy.ndarray x: The data corresponding to the x coordinates. + :param numpy.ndarray y: The data corresponding to the y coordinates. + :param xerror: Values with the uncertainties on the x values + :type xerror: A float, or a numpy.ndarray of float32. + If it is an array, it can either be a 1D array of + same length as the data or a 2D array with 2 rows + of same length as the data: row 0 for positive errors, + row 1 for negative errors. + :param yerror: Values with the uncertainties on the y values. + :type yerror: A float, or a numpy.ndarray of float32. See xerror. + :param bool copy: True make a copy of the data (default), + False to use provided arrays. + """ + x = numpy.array(x, copy=copy) + y = numpy.array(y, copy=copy) + assert len(x) == len(y) + assert x.ndim == y.ndim == 1 + + # Convert complex data + if numpy.iscomplexobj(x): + _logger.warning( + 'Converting x data to absolute value to plot it.') + x = numpy.absolute(x) + if numpy.iscomplexobj(y): + _logger.warning( + 'Converting y data to absolute value to plot it.') + y = numpy.absolute(y) + + if xerror is not None: + if isinstance(xerror, abc.Iterable): + xerror = numpy.array(xerror, copy=copy) + if numpy.iscomplexobj(xerror): + _logger.warning( + 'Converting xerror data to absolute value to plot it.') + xerror = numpy.absolute(xerror) + else: + xerror = float(xerror) + if yerror is not None: + if isinstance(yerror, abc.Iterable): + yerror = numpy.array(yerror, copy=copy) + if numpy.iscomplexobj(yerror): + _logger.warning( + 'Converting yerror data to absolute value to plot it.') + yerror = numpy.absolute(yerror) + else: + yerror = float(yerror) + # TODO checks on xerror, yerror + self._x, self._y = x, y + self._xerror, self._yerror = xerror, yerror + + self._boundsCache = {} # Reset cached bounds + self._filteredCache = {} # Reset cached filtered data + self._clippedCache = {} # Reset cached clipped bool array + + self._boundsChanged() + self._updated(ItemChangedType.DATA) + + +class BaselineMixIn(object): + """Base class for Baseline mix-in""" + + def __init__(self, baseline=None): + self._baseline = baseline + + def _setBaseline(self, baseline): + """ + Set baseline value + + :param baseline: baseline value(s) + :type: Union[None,float,numpy.ndarray] + """ + if (isinstance(baseline, abc.Iterable)): + baseline = numpy.array(baseline) + self._baseline = baseline + + def getBaseline(self, copy=True): + """ + + :param bool copy: + :return: histogram baseline + :rtype: Union[None,float,numpy.ndarray] + """ + if isinstance(self._baseline, numpy.ndarray): + return numpy.array(self._baseline, copy=True) + else: + return self._baseline + + +class _Style: + """Object which store styles""" + + +class HighlightedMixIn(ItemMixInBase): + + def __init__(self): + self._highlightStyle = self._DEFAULT_HIGHLIGHT_STYLE + self._highlighted = False + + def isHighlighted(self): + """Returns True if curve is highlighted. + + :rtype: bool + """ + return self._highlighted + + def setHighlighted(self, highlighted): + """Set the highlight state of the curve + + :param bool highlighted: + """ + highlighted = bool(highlighted) + if highlighted != self._highlighted: + self._highlighted = highlighted + # TODO inefficient: better to use backend's setCurveColor + self._updated(ItemChangedType.HIGHLIGHTED) + + def getHighlightedStyle(self): + """Returns the highlighted style in use + + :rtype: CurveStyle + """ + return self._highlightStyle + + def setHighlightedStyle(self, style): + """Set the style to use for highlighting + + :param CurveStyle style: New style to use + """ + previous = self.getHighlightedStyle() + if style != previous: + assert isinstance(style, _Style) + self._highlightStyle = style + self._updated(ItemChangedType.HIGHLIGHTED_STYLE) + + # Backward compatibility event + if previous.getColor() != style.getColor(): + self._updated(ItemChangedType.HIGHLIGHTED_COLOR) diff --git a/src/silx/gui/plot/items/curve.py b/src/silx/gui/plot/items/curve.py new file mode 100644 index 0000000..7cbe26e --- /dev/null +++ b/src/silx/gui/plot/items/curve.py @@ -0,0 +1,325 @@ +# coding: utf-8 +# /*########################################################################## +# +# Copyright (c) 2017-2021 European Synchrotron Radiation Facility +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. +# +# ###########################################################################*/ +"""This module provides the :class:`Curve` item of the :class:`Plot`. +""" + +__authors__ = ["T. Vincent"] +__license__ = "MIT" +__date__ = "24/04/2018" + + +import logging + +import numpy + +from ....utils.deprecation import deprecated +from ... import colors +from .core import (PointsBase, LabelsMixIn, ColorMixIn, YAxisMixIn, + FillMixIn, LineMixIn, SymbolMixIn, ItemChangedType, + BaselineMixIn, HighlightedMixIn, _Style) + + +_logger = logging.getLogger(__name__) + + +class CurveStyle(_Style): + """Object storing the style of a curve. + + Set a value to None to use the default + + :param color: Color + :param Union[str,None] linestyle: Style of the line + :param Union[float,None] linewidth: Width of the line + :param Union[str,None] symbol: Symbol for markers + :param Union[float,None] symbolsize: Size of the markers + """ + + def __init__(self, color=None, linestyle=None, linewidth=None, + symbol=None, symbolsize=None): + if color is None: + self._color = None + else: + if isinstance(color, str): + color = colors.rgba(color) + else: # array-like expected + color = numpy.array(color, copy=False) + if color.ndim == 1: # Array is 1D, this is a single color + color = colors.rgba(color) + self._color = color + + if linestyle is not None: + assert linestyle in LineMixIn.getSupportedLineStyles() + self._linestyle = linestyle + + self._linewidth = None if linewidth is None else float(linewidth) + + if symbol is not None: + assert symbol in SymbolMixIn.getSupportedSymbols() + self._symbol = symbol + + self._symbolsize = None if symbolsize is None else float(symbolsize) + + def getColor(self, copy=True): + """Returns the color or None if not set. + + :param bool copy: True to get a copy (default), + False to get internal representation (do not modify!) + + :rtype: Union[List[float],None] + """ + if isinstance(self._color, numpy.ndarray): + return numpy.array(self._color, copy=copy) + else: + return self._color + + def getLineStyle(self): + """Return the type of the line or None if not set. + + Type of line:: + + - ' ' no line + - '-' solid line + - '--' dashed line + - '-.' dash-dot line + - ':' dotted line + + :rtype: Union[str,None] + """ + return self._linestyle + + def getLineWidth(self): + """Return the curve line width in pixels or None if not set. + + :rtype: Union[float,None] + """ + return self._linewidth + + def getSymbol(self): + """Return the point marker type. + + Marker type:: + + - 'o' circle + - '.' point + - ',' pixel + - '+' cross + - 'x' x-cross + - 'd' diamond + - 's' square + + :rtype: Union[str,None] + """ + return self._symbol + + def getSymbolSize(self): + """Return the point marker size in points. + + :rtype: Union[float,None] + """ + return self._symbolsize + + def __eq__(self, other): + if isinstance(other, CurveStyle): + return (numpy.array_equal(self.getColor(), other.getColor()) and + self.getLineStyle() == other.getLineStyle() and + self.getLineWidth() == other.getLineWidth() and + self.getSymbol() == other.getSymbol() and + self.getSymbolSize() == other.getSymbolSize()) + else: + return False + + +class Curve(PointsBase, ColorMixIn, YAxisMixIn, FillMixIn, LabelsMixIn, + LineMixIn, BaselineMixIn, HighlightedMixIn): + """Description of a curve""" + + _DEFAULT_Z_LAYER = 1 + """Default overlay layer for curves""" + + _DEFAULT_SELECTABLE = True + """Default selectable state for curves""" + + _DEFAULT_LINEWIDTH = 1. + """Default line width of the curve""" + + _DEFAULT_LINESTYLE = '-' + """Default line style of the curve""" + + _DEFAULT_HIGHLIGHT_STYLE = CurveStyle(color='black') + """Default highlight style of the item""" + + _DEFAULT_BASELINE = None + + def __init__(self): + PointsBase.__init__(self) + ColorMixIn.__init__(self) + YAxisMixIn.__init__(self) + FillMixIn.__init__(self) + LabelsMixIn.__init__(self) + LineMixIn.__init__(self) + BaselineMixIn.__init__(self) + HighlightedMixIn.__init__(self) + + self._setBaseline(Curve._DEFAULT_BASELINE) + + def _addBackendRenderer(self, backend): + """Update backend renderer""" + # Filter-out values <= 0 + xFiltered, yFiltered, xerror, yerror = self.getData( + copy=False, displayed=True) + + if len(xFiltered) == 0 or not numpy.any(numpy.isfinite(xFiltered)): + return None # No data to display, do not add renderer to backend + + style = self.getCurrentStyle() + + return backend.addCurve(xFiltered, yFiltered, + color=style.getColor(), + symbol=style.getSymbol(), + linestyle=style.getLineStyle(), + linewidth=style.getLineWidth(), + yaxis=self.getYAxis(), + xerror=xerror, + yerror=yerror, + fill=self.isFill(), + alpha=self.getAlpha(), + symbolsize=style.getSymbolSize(), + baseline=self.getBaseline(copy=False)) + + def __getitem__(self, item): + """Compatibility with PyMca and silx <= 0.4.0""" + if isinstance(item, slice): + return [self[index] for index in range(*item.indices(5))] + elif item == 0: + return self.getXData(copy=False) + elif item == 1: + return self.getYData(copy=False) + elif item == 2: + return self.getName() + elif item == 3: + info = self.getInfo(copy=False) + return {} if info is None else info + elif item == 4: + params = { + 'info': self.getInfo(), + 'color': self.getColor(), + 'symbol': self.getSymbol(), + 'linewidth': self.getLineWidth(), + 'linestyle': self.getLineStyle(), + 'xlabel': self.getXLabel(), + 'ylabel': self.getYLabel(), + 'yaxis': self.getYAxis(), + 'xerror': self.getXErrorData(copy=False), + 'yerror': self.getYErrorData(copy=False), + 'z': self.getZValue(), + 'selectable': self.isSelectable(), + 'fill': self.isFill(), + } + return params + else: + raise IndexError("Index out of range: %s", str(item)) + + @deprecated(replacement='Curve.getHighlightedStyle().getColor()', + since_version='0.9.0') + def getHighlightedColor(self): + """Returns the RGBA highlight color of the item + + :rtype: 4-tuple of float in [0, 1] + """ + return self.getHighlightedStyle().getColor() + + @deprecated(replacement='Curve.setHighlightedStyle()', + since_version='0.9.0') + def setHighlightedColor(self, color): + """Set the color to use when highlighted + + :param color: color(s) to be used for highlight + :type color: str ("#RRGGBB") or (npoints, 4) unsigned byte array or + one of the predefined color names defined in colors.py + """ + self.setHighlightedStyle(CurveStyle(color)) + + def getCurrentStyle(self): + """Returns the current curve style. + + Curve style depends on curve highlighting + + :rtype: CurveStyle + """ + if self.isHighlighted(): + style = self.getHighlightedStyle() + color = style.getColor() + linestyle = style.getLineStyle() + linewidth = style.getLineWidth() + symbol = style.getSymbol() + symbolsize = style.getSymbolSize() + + return CurveStyle( + color=self.getColor() if color is None else color, + linestyle=self.getLineStyle() if linestyle is None else linestyle, + linewidth=self.getLineWidth() if linewidth is None else linewidth, + symbol=self.getSymbol() if symbol is None else symbol, + symbolsize=self.getSymbolSize() if symbolsize is None else symbolsize) + + else: + return CurveStyle(color=self.getColor(), + linestyle=self.getLineStyle(), + linewidth=self.getLineWidth(), + symbol=self.getSymbol(), + symbolsize=self.getSymbolSize()) + + @deprecated(replacement='Curve.getCurrentStyle()', + since_version='0.9.0') + def getCurrentColor(self): + """Returns the current color of the curve. + + This color is either the color of the curve or the highlighted color, + depending on the highlight state. + + :rtype: 4-tuple of float in [0, 1] + """ + return self.getCurrentStyle().getColor() + + def setData(self, x, y, xerror=None, yerror=None, baseline=None, copy=True): + """Set the data of the curve. + + :param numpy.ndarray x: The data corresponding to the x coordinates. + :param numpy.ndarray y: The data corresponding to the y coordinates. + :param xerror: Values with the uncertainties on the x values + :type xerror: A float, or a numpy.ndarray of float32. + If it is an array, it can either be a 1D array of + same length as the data or a 2D array with 2 rows + of same length as the data: row 0 for positive errors, + row 1 for negative errors. + :param yerror: Values with the uncertainties on the y values. + :type yerror: A float, or a numpy.ndarray of float32. See xerror. + :param baseline: curve baseline + :type baseline: Union[None,float,numpy.ndarray] + :param bool copy: True make a copy of the data (default), + False to use provided arrays. + """ + PointsBase.setData(self, x=x, y=y, xerror=xerror, yerror=yerror, + copy=copy) + self._setBaseline(baseline=baseline) diff --git a/src/silx/gui/plot/items/histogram.py b/src/silx/gui/plot/items/histogram.py new file mode 100644 index 0000000..16bbefa --- /dev/null +++ b/src/silx/gui/plot/items/histogram.py @@ -0,0 +1,389 @@ +# coding: utf-8 +# /*########################################################################## +# +# Copyright (c) 2017-2021 European Synchrotron Radiation Facility +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions::t +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. +# +# ###########################################################################*/ +"""This module provides the :class:`Histogram` item of the :class:`Plot`. +""" + +__authors__ = ["H. Payno", "T. Vincent"] +__license__ = "MIT" +__date__ = "28/08/2018" + +import logging +import typing + +import numpy +from collections import OrderedDict, namedtuple +try: + from collections import abc +except ImportError: # Python2 support + import collections as abc + +from ....utils.proxy import docstring +from .core import (DataItem, AlphaMixIn, BaselineMixIn, ColorMixIn, FillMixIn, + LineMixIn, YAxisMixIn, ItemChangedType, Item) +from ._pick import PickingResult + +_logger = logging.getLogger(__name__) + + +def _computeEdges(x, histogramType): + """Compute the edges from a set of xs and a rule to generate the edges + + :param x: the x value of the curve to transform into an histogram + :param histogramType: the type of histogram we wan't to generate. + This define the way to center the histogram values compared to the + curve value. Possible values can be:: + + - 'left' + - 'right' + - 'center' + + :return: the edges for the given x and the histogramType + """ + # for now we consider that the spaces between xs are constant + edges = x.copy() + if histogramType == 'left': + width = 1 + if len(x) > 1: + width = x[1] - x[0] + edges = numpy.append(x[0] - width, edges) + if histogramType == 'center': + edges = _computeEdges(edges, 'right') + widths = (edges[1:] - edges[0:-1]) / 2.0 + widths = numpy.append(widths, widths[-1]) + edges = edges - widths + if histogramType == 'right': + width = 1 + if len(x) > 1: + width = x[-1] - x[-2] + edges = numpy.append(edges, x[-1] + width) + + return edges + + +def _getHistogramCurve(histogram, edges): + """Returns the x and y value of a curve corresponding to the histogram + + :param numpy.ndarray histogram: The values of the histogram + :param numpy.ndarray edges: The bin edges of the histogram + :return: a tuple(x, y) which contains the value of the curve to use + to display the histogram + """ + assert len(histogram) + 1 == len(edges) + x = numpy.empty(len(histogram) * 2, dtype=edges.dtype) + y = numpy.empty(len(histogram) * 2, dtype=histogram.dtype) + # Make a curve with stairs + x[:-1:2] = edges[:-1] + x[1::2] = edges[1:] + y[:-1:2] = histogram + y[1::2] = histogram + + return x, y + + +# TODO: Yerror, test log scale +class Histogram(DataItem, AlphaMixIn, ColorMixIn, FillMixIn, + LineMixIn, YAxisMixIn, BaselineMixIn): + """Description of an histogram""" + + _DEFAULT_Z_LAYER = 1 + """Default overlay layer for histograms""" + + _DEFAULT_SELECTABLE = False + """Default selectable state for histograms""" + + _DEFAULT_LINEWIDTH = 1. + """Default line width of the histogram""" + + _DEFAULT_LINESTYLE = '-' + """Default line style of the histogram""" + + _DEFAULT_BASELINE = None + + def __init__(self): + DataItem.__init__(self) + AlphaMixIn.__init__(self) + BaselineMixIn.__init__(self) + ColorMixIn.__init__(self) + FillMixIn.__init__(self) + LineMixIn.__init__(self) + YAxisMixIn.__init__(self) + + self._histogram = () + self._edges = () + self._setBaseline(Histogram._DEFAULT_BASELINE) + + def _addBackendRenderer(self, backend): + """Update backend renderer""" + values, edges, baseline = self.getData(copy=False) + + if values.size == 0: + return None # No data to display, do not add renderer + + if values.size == 0: + return None # No data to display, do not add renderer to backend + + x, y = _getHistogramCurve(values, edges) + + # Filter-out values <= 0 + plot = self.getPlot() + if plot is not None: + xPositive = plot.getXAxis()._isLogarithmic() + yPositive = plot.getYAxis()._isLogarithmic() + else: + xPositive = False + yPositive = False + + if xPositive or yPositive: + clipped = numpy.logical_or( + (x <= 0) if xPositive else False, + (y <= 0) if yPositive else False) + # Make a copy and replace negative points by NaN + x = numpy.array(x, dtype=numpy.float64) + y = numpy.array(y, dtype=numpy.float64) + x[clipped] = numpy.nan + y[clipped] = numpy.nan + + return backend.addCurve(x, y, + color=self.getColor(), + symbol='', + linestyle=self.getLineStyle(), + linewidth=self.getLineWidth(), + yaxis=self.getYAxis(), + xerror=None, + yerror=None, + fill=self.isFill(), + alpha=self.getAlpha(), + baseline=baseline, + symbolsize=1) + + def _getBounds(self): + values, edges, baseline = self.getData(copy=False) + + plot = self.getPlot() + if plot is not None: + xPositive = plot.getXAxis()._isLogarithmic() + yPositive = plot.getYAxis()._isLogarithmic() + else: + xPositive = False + yPositive = False + + if xPositive or yPositive: + values = numpy.array(values, copy=True, dtype=numpy.float64) + + if xPositive: + # Replace edges <= 0 by NaN and corresponding values by NaN + clipped_edges = (edges <= 0) + edges = numpy.array(edges, copy=True, dtype=numpy.float64) + edges[clipped_edges] = numpy.nan + clipped_values = numpy.logical_or(clipped_edges[:-1], + clipped_edges[1:]) + else: + clipped_values = numpy.zeros_like(values, dtype=bool) + + if yPositive: + # Replace values <= 0 by NaN, do not modify edges + clipped_values = numpy.logical_or(clipped_values, values <= 0) + + values[clipped_values] = numpy.nan + + if yPositive: + return (numpy.nanmin(edges), + numpy.nanmax(edges), + numpy.nanmin(values), + numpy.nanmax(values)) + + else: # No log scale on y axis, include 0 in bounds + if numpy.all(numpy.isnan(values)): + return None + return (numpy.nanmin(edges), + numpy.nanmax(edges), + min(0, numpy.nanmin(values)), + max(0, numpy.nanmax(values))) + + def __pickFilledHistogram(self, x: float, y: float) -> typing.Optional[PickingResult]: + """Picking implementation for filled histogram + + :param x: X position in pixels + :param y: Y position in pixels + """ + if not self.isFill(): + return None + + plot = self.getPlot() + if plot is None: + return None + + xData, yData = plot.pixelToData(x, y, axis=self.getYAxis()) + xmin, xmax, ymin, ymax = self.getBounds() + if not xmin < xData < xmax or not ymin < yData < ymax: + return None # Outside bounding box + + # Check x + edges = self.getBinEdgesData(copy=False) + index = numpy.searchsorted(edges, (xData,), side='left')[0] - 1 + # Safe indexing in histogram values + index = numpy.clip(index, 0, len(edges) - 2) + + # Check y + baseline = self.getBaseline(copy=False) + if baseline is None: + baseline = 0 # Default value + + value = self.getValueData(copy=False)[index] + if ((baseline <= value and baseline <= yData <= value) or + (value < baseline and value <= yData <= baseline)): + return PickingResult(self, numpy.array([index])) + else: + return None + + @docstring(DataItem) + def pick(self, x, y): + if self.isFill(): + return self.__pickFilledHistogram(x, y) + else: + result = super().pick(x, y) + if result is None: + return None + else: # Convert from curve indices to histogram indices + return PickingResult(self, numpy.unique(result.getIndices() // 2)) + + def getValueData(self, copy=True): + """The values of the histogram + + :param copy: True (Default) to get a copy, + False to use internal representation (do not modify!) + :returns: The values of the histogram + :rtype: numpy.ndarray + """ + return numpy.array(self._histogram, copy=copy) + + def getBinEdgesData(self, copy=True): + """The bin edges of the histogram (number of histogram values + 1) + + :param copy: True (Default) to get a copy, + False to use internal representation (do not modify!) + :returns: The bin edges of the histogram + :rtype: numpy.ndarray + """ + return numpy.array(self._edges, copy=copy) + + def getData(self, copy=True): + """Return the histogram values, bin edges and baseline + + :param copy: True (Default) to get a copy, + False to use internal representation (do not modify!) + :returns: (N histogram value, N+1 bin edges) + :rtype: 2-tuple of numpy.nadarray + """ + return (self.getValueData(copy), + self.getBinEdgesData(copy), + self.getBaseline(copy)) + + def setData(self, histogram, edges, align='center', baseline=None, + copy=True): + """Set the histogram values and bin edges. + + :param numpy.ndarray histogram: The values of the histogram. + :param numpy.ndarray edges: + The bin edges of the histogram. + If histogram and edges have the same length, the bin edges + are computed according to the align parameter. + :param str align: + In case histogram values and edges have the same length N, + the N+1 bin edges are computed according to the alignment in: + 'center' (default), 'left', 'right'. + :param baseline: histogram baseline + :type baseline: Union[None,float,numpy.ndarray] + :param bool copy: True make a copy of the data (default), + False to use provided arrays. + """ + histogram = numpy.array(histogram, copy=copy) + edges = numpy.array(edges, copy=copy) + + assert histogram.ndim == 1 + assert edges.ndim == 1 + assert edges.size in (histogram.size, histogram.size + 1) + assert align in ('center', 'left', 'right') + + if histogram.size == 0: # No data + self._histogram = () + self._edges = () + else: + if edges.size == histogram.size: # Compute true bin edges + edges = _computeEdges(edges, align) + + # Check that bin edges are monotonic + edgesDiff = numpy.diff(edges) + edgesDiff = edgesDiff[numpy.logical_not(numpy.isnan(edgesDiff))] + assert numpy.all(edgesDiff >= 0) or numpy.all(edgesDiff <= 0) + # manage baseline + if (isinstance(baseline, abc.Iterable)): + baseline = numpy.array(baseline) + if baseline.size == histogram.size: + new_baseline = numpy.empty(baseline.shape[0] * 2) + for i_value, value in enumerate(baseline): + new_baseline[i_value*2:i_value*2+2] = value + baseline = new_baseline + self._histogram = histogram + self._edges = edges + self._alignement = align + self._setBaseline(baseline) + + self._boundsChanged() + self._updated(ItemChangedType.DATA) + + def getAlignment(self): + """ + + :return: histogram alignement. Value in ('center', 'left', 'right'). + """ + return self._alignement + + def _revertComputeEdges(self, x, histogramType): + """Compute the edges from a set of xs and a rule to generate the edges + + :param x: the x value of the curve to transform into an histogram + :param histogramType: the type of histogram we wan't to generate. + This define the way to center the histogram values compared to the + curve value. Possible values can be:: + + - 'left' + - 'right' + - 'center' + + :return: the edges for the given x and the histogramType + """ + # for now we consider that the spaces between xs are constant + edges = x.copy() + if histogramType == 'left': + return edges[1:] + if histogramType == 'center': + edges = (edges[1:] + edges[:-1]) / 2.0 + if histogramType == 'right': + width = 1 + if len(x) > 1: + width = x[-1] + x[-2] + edges = edges[:-1] + return edges diff --git a/src/silx/gui/plot/items/image.py b/src/silx/gui/plot/items/image.py new file mode 100644 index 0000000..5cc719b --- /dev/null +++ b/src/silx/gui/plot/items/image.py @@ -0,0 +1,641 @@ +# coding: utf-8 +# /*########################################################################## +# +# Copyright (c) 2017-2021 European Synchrotron Radiation Facility +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. +# +# ###########################################################################*/ +"""This module provides the :class:`ImageData` and :class:`ImageRgba` items +of the :class:`Plot`. +""" + +__authors__ = ["T. Vincent"] +__license__ = "MIT" +__date__ = "08/12/2020" + +try: + from collections import abc +except ImportError: # Python2 support + import collections as abc +import logging + +import numpy + +from ....utils.proxy import docstring +from .core import (DataItem, LabelsMixIn, DraggableMixIn, ColormapMixIn, + AlphaMixIn, ItemChangedType) + +_logger = logging.getLogger(__name__) + + +def _convertImageToRgba32(image, copy=True): + """Convert an RGB or RGBA image to RGBA32. + + It converts from floats in [0, 1], bool, integer and uint in [0, 255] + + If the input image is already an RGBA32 image, + the returned image shares the same data. + + :param image: Image to convert to + :type image: numpy.ndarray with 3 dimensions: height, width, color channels + :param bool copy: True (Default) to get a copy, False, avoid copy if possible + :return: The image converted to RGBA32 with dimension: (height, width, 4) + :rtype: numpy.ndarray of uint8 + """ + assert image.ndim == 3 + assert image.shape[-1] in (3, 4) + + # Convert type to uint8 + if image.dtype.name != 'uint8': + if image.dtype.kind == 'f': # Float in [0, 1] + image = (numpy.clip(image, 0., 1.) * 255).astype(numpy.uint8) + elif image.dtype.kind == 'b': # boolean + image = image.astype(numpy.uint8) * 255 + elif image.dtype.kind in ('i', 'u'): # int, uint + image = numpy.clip(image, 0, 255).astype(numpy.uint8) + else: + raise ValueError('Unsupported image dtype: %s', image.dtype.name) + copy = False # A copy as already been done, avoid next one + + # Convert RGB to RGBA + if image.shape[-1] == 3: + new_image = numpy.empty((image.shape[0], image.shape[1], 4), + dtype=numpy.uint8) + new_image[:,:,:3] = image + new_image[:,:, 3] = 255 + return new_image # This is a copy anyway + else: + return numpy.array(image, copy=copy) + + +class ImageBase(DataItem, LabelsMixIn, DraggableMixIn, AlphaMixIn): + """Description of an image + + :param numpy.ndarray data: Initial image data + """ + + def __init__(self, data=None, mask=None): + DataItem.__init__(self) + LabelsMixIn.__init__(self) + DraggableMixIn.__init__(self) + AlphaMixIn.__init__(self) + if data is None: + data = numpy.zeros((0, 0, 4), dtype=numpy.uint8) + self._data = data + self._mask = mask + self.__valueDataCache = None # Store default data + self._origin = (0., 0.) + self._scale = (1., 1.) + + def __getitem__(self, item): + """Compatibility with PyMca and silx <= 0.4.0""" + if isinstance(item, slice): + return [self[index] for index in range(*item.indices(5))] + elif item == 0: + return self.getData(copy=False) + elif item == 1: + return self.getName() + elif item == 2: + info = self.getInfo(copy=False) + return {} if info is None else info + elif item == 3: + return None + elif item == 4: + params = { + 'info': self.getInfo(), + 'origin': self.getOrigin(), + 'scale': self.getScale(), + 'z': self.getZValue(), + 'selectable': self.isSelectable(), + 'draggable': self.isDraggable(), + 'colormap': None, + 'xlabel': self.getXLabel(), + 'ylabel': self.getYLabel(), + } + return params + else: + raise IndexError("Index out of range: %s" % str(item)) + + def _isPlotLinear(self, plot): + """Return True if plot only uses linear scale for both of x and y + axes.""" + linear = plot.getXAxis().LINEAR + if plot.getXAxis().getScale() != linear: + return False + if plot.getYAxis().getScale() != linear: + return False + return True + + def _getBounds(self): + if self.getData(copy=False).size == 0: # Empty data + return None + + height, width = self.getData(copy=False).shape[:2] + origin = self.getOrigin() + scale = self.getScale() + # Taking care of scale might be < 0 + xmin, xmax = origin[0], origin[0] + width * scale[0] + if xmin > xmax: + xmin, xmax = xmax, xmin + # Taking care of scale might be < 0 + ymin, ymax = origin[1], origin[1] + height * scale[1] + if ymin > ymax: + ymin, ymax = ymax, ymin + + plot = self.getPlot() + if plot is not None and not self._isPlotLinear(plot): + return None + else: + return xmin, xmax, ymin, ymax + + @docstring(DraggableMixIn) + def drag(self, from_, to): + origin = self.getOrigin() + self.setOrigin((origin[0] + to[0] - from_[0], + origin[1] + to[1] - from_[1])) + + def getData(self, copy=True): + """Returns the image data + + :param bool copy: True (Default) to get a copy, + False to use internal representation (do not modify!) + :rtype: numpy.ndarray + """ + return numpy.array(self._data, copy=copy) + + def setData(self, data): + """Set the image data + + :param numpy.ndarray data: + """ + previousShape = self._data.shape + self._data = data + self._valueDataChanged() + self._boundsChanged() + self._updated(ItemChangedType.DATA) + + if (self.getMaskData(copy=False) is not None and + previousShape != self._data.shape): + # Data shape changed, so mask shape changes. + # Send event, mask is lazily updated in getMaskData + self._updated(ItemChangedType.MASK) + + def getMaskData(self, copy=True): + """Returns the mask data + + :param bool copy: True (Default) to get a copy, + False to use internal representation (do not modify!) + :rtype: Union[None,numpy.ndarray] + """ + if self._mask is None: + return None + + # Update mask if it does not match data shape + shape = self.getData(copy=False).shape[:2] + if self._mask.shape != shape: + # Clip/extend mask to match data + newMask = numpy.zeros(shape, dtype=self._mask.dtype) + newMask[:self._mask.shape[0], :self._mask.shape[1]] = self._mask[:shape[0], :shape[1]] + self._mask = newMask + + return numpy.array(self._mask, copy=copy) + + def setMaskData(self, mask, copy=True): + """Set the image data + + :param numpy.ndarray data: + :param bool copy: True (Default) to make a copy, + False to use as is (do not modify!) + """ + if mask is not None: + mask = numpy.array(mask, copy=copy) + + shape = self.getData(copy=False).shape[:2] + if mask.shape != shape: + _logger.warning("Inconsistent shape between mask and data %s, %s", mask.shape, shape) + # Clip/extent is done lazily in getMaskData + elif self._mask is None: + return # No update + + self._mask = mask + self._valueDataChanged() + self._updated(ItemChangedType.MASK) + + def _valueDataChanged(self): + """Clear cache of default data array""" + self.__valueDataCache = None + + def _getValueData(self, copy=True): + """Return data used by :meth:`getValueData` + + :param bool copy: + :rtype: numpy.ndarray + """ + return self.getData(copy=copy) + + def getValueData(self, copy=True): + """Return data (converted to int or float) with mask applied. + + Masked values are set to Not-A-Number. + It returns a 2D array of values (int or float). + + :param bool copy: + :rtype: numpy.ndarray + """ + if self.__valueDataCache is None: + data = self._getValueData(copy=False) + mask = self.getMaskData(copy=False) + if mask is not None: + if numpy.issubdtype(data.dtype, numpy.floating): + dtype = data.dtype + else: + dtype = numpy.float64 + data = numpy.array(data, dtype=dtype, copy=True) + data[mask != 0] = numpy.NaN + self.__valueDataCache = data + return numpy.array(self.__valueDataCache, copy=copy) + + def getRgbaImageData(self, copy=True): + """Get the displayed RGB(A) image + + :param bool copy: True (Default) to get a copy, + False to use internal representation (do not modify!) + :returns: numpy.ndarray of uint8 of shape (height, width, 4) + """ + raise NotImplementedError('This MUST be implemented in sub-class') + + def getOrigin(self): + """Returns the offset from origin at which to display the image. + + :rtype: 2-tuple of float + """ + return self._origin + + def setOrigin(self, origin): + """Set the offset from origin at which to display the image. + + :param origin: (ox, oy) Offset from origin + :type origin: float or 2-tuple of float + """ + if isinstance(origin, abc.Sequence): + origin = float(origin[0]), float(origin[1]) + else: # single value origin + origin = float(origin), float(origin) + if origin != self._origin: + self._origin = origin + self._boundsChanged() + self._updated(ItemChangedType.POSITION) + + def getScale(self): + """Returns the scale of the image in data coordinates. + + :rtype: 2-tuple of float + """ + return self._scale + + def setScale(self, scale): + """Set the scale of the image + + :param scale: (sx, sy) Scale of the image + :type scale: float or 2-tuple of float + """ + if isinstance(scale, abc.Sequence): + scale = float(scale[0]), float(scale[1]) + else: # single value scale + scale = float(scale), float(scale) + + if scale != self._scale: + self._scale = scale + self._boundsChanged() + self._updated(ItemChangedType.SCALE) + + +class ImageDataBase(ImageBase, ColormapMixIn): + """Base class for colormapped 2D data image""" + + def __init__(self): + ImageBase.__init__(self, numpy.zeros((0, 0), dtype=numpy.float32)) + ColormapMixIn.__init__(self) + + def _getColormapForRendering(self): + colormap = self.getColormap() + if colormap.isAutoscale(): + # Avoid backend to compute autoscale: use item cache + colormap = colormap.copy() + colormap.setVRange(*colormap.getColormapRange(self)) + return colormap + + def getRgbaImageData(self, copy=True): + """Get the displayed RGB(A) image + + :returns: Array of uint8 of shape (height, width, 4) + :rtype: numpy.ndarray + """ + return self.getColormap().applyToData(self) + + def setData(self, data, copy=True): + """"Set the image data + + :param numpy.ndarray data: Data array with 2 dimensions (h, w) + :param bool copy: True (Default) to get a copy, + False to use internal representation (do not modify!) + """ + data = numpy.array(data, copy=copy) + assert data.ndim == 2 + if data.dtype.kind == 'b': + _logger.warning( + 'Converting boolean image to int8 to plot it.') + data = numpy.array(data, copy=False, dtype=numpy.int8) + elif numpy.iscomplexobj(data): + _logger.warning( + 'Converting complex image to absolute value to plot it.') + data = numpy.absolute(data) + super().setData(data) + + def _updated(self, event=None, checkVisibility=True): + # Synchronizes colormapped data if changed + if event in (ItemChangedType.DATA, ItemChangedType.MASK): + self._setColormappedData(self.getValueData(copy=False), copy=False) + super()._updated(event=event, checkVisibility=checkVisibility) + + +class ImageData(ImageDataBase): + """Description of a data image with a colormap""" + + def __init__(self): + ImageDataBase.__init__(self) + self._alternativeImage = None + self.__alpha = None + + def _addBackendRenderer(self, backend): + """Update backend renderer""" + plot = self.getPlot() + assert plot is not None + if not self._isPlotLinear(plot): + # Do not render with non linear scales + return None + + if (self.getAlternativeImageData(copy=False) is not None or + self.getAlphaData(copy=False) is not None): + dataToUse = self.getRgbaImageData(copy=False) + else: + dataToUse = self.getData(copy=False) + + if dataToUse.size == 0: + return None # No data to display + + return backend.addImage(dataToUse, + origin=self.getOrigin(), + scale=self.getScale(), + colormap=self._getColormapForRendering(), + alpha=self.getAlpha()) + + def __getitem__(self, item): + """Compatibility with PyMca and silx <= 0.4.0""" + if item == 3: + return self.getAlternativeImageData(copy=False) + + params = ImageBase.__getitem__(self, item) + if item == 4: + params['colormap'] = self.getColormap() + + return params + + def getRgbaImageData(self, copy=True): + """Get the displayed RGB(A) image + + :returns: Array of uint8 of shape (height, width, 4) + :rtype: numpy.ndarray + """ + alternative = self.getAlternativeImageData(copy=False) + if alternative is not None: + return _convertImageToRgba32(alternative, copy=copy) + else: + image = super().getRgbaImageData(copy=copy) + alphaImage = self.getAlphaData(copy=False) + if alphaImage is not None: + # Apply transparency + image[:,:, 3] = image[:,:, 3] * alphaImage + return image + + def getAlternativeImageData(self, copy=True): + """Get the optional RGBA image that is displayed instead of the data + + :param bool copy: True (Default) to get a copy, + False to use internal representation (do not modify!) + :rtype: Union[None,numpy.ndarray] + """ + if self._alternativeImage is None: + return None + else: + return numpy.array(self._alternativeImage, copy=copy) + + def getAlphaData(self, copy=True): + """Get the optional transparency image applied on the data + + :param bool copy: True (Default) to get a copy, + False to use internal representation (do not modify!) + :rtype: Union[None,numpy.ndarray] + """ + if self.__alpha is None: + return None + else: + return numpy.array(self.__alpha, copy=copy) + + def setData(self, data, alternative=None, alpha=None, copy=True): + """"Set the image data and optionally an alternative RGB(A) representation + + :param numpy.ndarray data: Data array with 2 dimensions (h, w) + :param alternative: RGB(A) image to display instead of data, + shape: (h, w, 3 or 4) + :type alternative: Union[None,numpy.ndarray] + :param alpha: An array of transparency value in [0, 1] to use for + display with shape: (h, w) + :type alpha: Union[None,numpy.ndarray] + :param bool copy: True (Default) to get a copy, + False to use internal representation (do not modify!) + """ + data = numpy.array(data, copy=copy) + assert data.ndim == 2 + + if alternative is not None: + alternative = numpy.array(alternative, copy=copy) + assert alternative.ndim == 3 + assert alternative.shape[2] in (3, 4) + assert alternative.shape[:2] == data.shape[:2] + self._alternativeImage = alternative + + if alpha is not None: + alpha = numpy.array(alpha, copy=copy) + assert alpha.shape == data.shape + if alpha.dtype.kind != 'f': + alpha = alpha.astype(numpy.float32) + if numpy.any(numpy.logical_or(alpha < 0., alpha > 1.)): + alpha = numpy.clip(alpha, 0., 1.) + self.__alpha = alpha + + super().setData(data) + + +class ImageRgba(ImageBase): + """Description of an RGB(A) image""" + + def __init__(self): + ImageBase.__init__(self, numpy.zeros((0, 0, 4), dtype=numpy.uint8)) + + def _addBackendRenderer(self, backend): + """Update backend renderer""" + plot = self.getPlot() + assert plot is not None + if not self._isPlotLinear(plot): + # Do not render with non linear scales + return None + + data = self.getData(copy=False) + + if data.size == 0: + return None # No data to display + + return backend.addImage(data, + origin=self.getOrigin(), + scale=self.getScale(), + colormap=None, + alpha=self.getAlpha()) + + def getRgbaImageData(self, copy=True): + """Get the displayed RGB(A) image + + :returns: numpy.ndarray of uint8 of shape (height, width, 4) + """ + return _convertImageToRgba32(self.getData(copy=False), copy=copy) + + def setData(self, data, copy=True): + """Set the image data + + :param data: RGB(A) image data to set + :param bool copy: True (Default) to get a copy, + False to use internal representation (do not modify!) + """ + data = numpy.array(data, copy=copy) + assert data.ndim == 3 + assert data.shape[-1] in (3, 4) + super().setData(data) + + def _getValueData(self, copy=True): + """Compute the intensity of the RGBA image as default data. + + Conversion: https://en.wikipedia.org/wiki/YCbCr#ITU-R_BT.601_conversion + + :param bool copy: + """ + rgba = self.getRgbaImageData(copy=False).astype(numpy.float32) + intensity = (rgba[:, :, 0] * 0.299 + + rgba[:, :, 1] * 0.587 + + rgba[:, :, 2] * 0.114) + intensity *= rgba[:, :, 3] / 255. + return intensity + + +class MaskImageData(ImageData): + """Description of an image used as a mask. + + This class is used to flag mask items. This information is used to improve + internal silx widgets. + """ + pass + + +class ImageStack(ImageData): + """Item to store a stack of images and to show it in the plot as one + of the images of the stack. + + The stack is a 3D array ordered this way: `frame id, y, x`. + So the first image of the stack can be reached this way: `stack[0, :, :]` + """ + + def __init__(self): + ImageData.__init__(self) + self.__stack = None + """A 3D numpy array (or a mimic one, see ListOfImages)""" + self.__stackPosition = None + """Displayed position in the cube""" + + def setStackData(self, stack, position=None, copy=True): + """Set the stack data + + :param stack: A 3D numpy array like + :param int position: The position of the displayed image in the stack + :param bool copy: True (Default) to get a copy, + False to use internal representation (do not modify!) + """ + if self.__stack is stack: + return + if copy: + stack = numpy.array(stack) + assert stack.ndim == 3 + self.__stack = stack + if position is not None: + self.__stackPosition = position + if self.__stackPosition is None: + self.__stackPosition = 0 + self.__updateDisplayedData() + + def getStackData(self, copy=True): + """Get the stored stack array. + + :param bool copy: True (Default) to get a copy, + False to use internal representation (do not modify!) + :rtype: A 3D numpy array, or numpy array like + """ + if copy: + return numpy.array(self.__stack) + else: + return self.__stack + + def setStackPosition(self, pos): + """Set the displayed position on the stack. + + This function will clamp the stack position according to + the real size of the first axis of the stack. + + :param int pos: A position on the first axis of the stack. + """ + if self.__stackPosition == pos: + return + self.__stackPosition = pos + self.__updateDisplayedData() + + def getStackPosition(self): + """Get the displayed position of the stack. + + :rtype: int + """ + return self.__stackPosition + + def __updateDisplayedData(self): + """Update the displayed frame whenever the stack or the stack + position are updated.""" + if self.__stack is None or self.__stackPosition is None: + empty = numpy.array([]).reshape(0, 0) + self.setData(empty, copy=False) + return + size = len(self.__stack) + self.__stackPosition = numpy.clip(self.__stackPosition, 0, size) + self.setData(self.__stack[self.__stackPosition], copy=False) diff --git a/src/silx/gui/plot/items/image_aggregated.py b/src/silx/gui/plot/items/image_aggregated.py new file mode 100644 index 0000000..75fdd59 --- /dev/null +++ b/src/silx/gui/plot/items/image_aggregated.py @@ -0,0 +1,229 @@ +# coding: utf-8 +# /*########################################################################## +# +# Copyright (c) 2021 European Synchrotron Radiation Facility +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. +# +# ###########################################################################*/ +"""This module provides the :class:`ImageDataAggregated` items of the :class:`Plot`. +""" + +__authors__ = ["T. Vincent"] +__license__ = "MIT" +__date__ = "07/07/2021" + +import enum +import logging +from typing import Tuple, Union + +import numpy + +from ....utils.enum import Enum as _Enum +from ....utils.proxy import docstring +from .axis import Axis +from .core import ItemChangedType +from .image import ImageDataBase +from ._pick import PickingResult + + +_logger = logging.getLogger(__name__) + + +class ImageDataAggregated(ImageDataBase): + """Item displaying an image as a density map.""" + + @enum.unique + class Aggregation(_Enum): + NONE = "none" + "Do not aggregate data, display as is (default)" + + MAX = "max" + "Aggregates elements with max (ignore NaNs)" + + MEAN = "mean" + "Aggregates elements with mean (ignore NaNs)" + + MIN = "min" + "Aggregates elements with min (ignore NaNs)" + + def __init__(self): + super().__init__() + self.__cacheLODData = {} + self.__currentLOD = 0, 0 + self.__aggregationMode = self.Aggregation.NONE + + def setAggregationMode(self, mode: Union[str,Aggregation]): + """Set the aggregation method used to reduce the data to screen resolution. + + :param Aggregation mode: The aggregation method + """ + aggregationMode = self.Aggregation.from_value(mode) + if aggregationMode != self.__aggregationMode: + self.__aggregationMode = aggregationMode + self.__cacheLODData = {} # Clear cache + self._updated(ItemChangedType.VISUALIZATION_MODE) + + def getAggregationMode(self) -> Aggregation: + """Returns the currently used aggregation method.""" + return self.__aggregationMode + + def _addBackendRenderer(self, backend): + """Update backend renderer""" + plot = self.getPlot() + assert plot is not None + if not self._isPlotLinear(plot): + # Do not render with non linear scales + return None + + data = self.getData(copy=False) + if data.size == 0: + return None # No data to display + + aggregationMode = self.getAggregationMode() + if aggregationMode == self.Aggregation.NONE: # Pass data as it is + displayedData = data + scale = self.getScale() + + else: # Aggregate data according to level of details + if aggregationMode == self.Aggregation.MAX: + aggregator = numpy.nanmax + elif aggregationMode == self.Aggregation.MEAN: + aggregator = numpy.nanmean + elif aggregationMode == self.Aggregation.MIN: + aggregator = numpy.nanmin + else: + _logger.error("Unsupported aggregation mode") + return None + + lodx, lody = self._getLevelOfDetails() + + if (lodx, lody) not in self.__cacheLODData: + height, width = data.shape + self.__cacheLODData[(lodx, lody)] = aggregator( + data[: (height // lody) * lody, : (width // lodx) * lodx].reshape( + height // lody, lody, width // lodx, lodx + ), + axis=(1, 3), + ) + + self.__currentLOD = lodx, lody + displayedData = self.__cacheLODData[self.__currentLOD] + + sx, sy = self.getScale() + scale = sx * lodx, sy * lody + + return backend.addImage( + displayedData, + origin=self.getOrigin(), + scale=scale, + colormap=self._getColormapForRendering(), + alpha=self.getAlpha(), + ) + + def _getPixelSizeInData(self, axis="left"): + """Returns the size of a pixel in plot data coordinates + + :param str axis: Y axis to use in: 'left' (default), 'right' + :return: + Size (width, height) of a Qt pixel in data coordinates. + Size is None if it cannot be computed + :rtype: Union[List[float],None] + """ + assert axis in ("left", "right") + plot = self.getPlot() + if plot is None: + return None + + xaxis = plot.getXAxis() + yaxis = plot.getYAxis(axis) + + if ( + xaxis.getScale() != Axis.LINEAR + or yaxis.getScale() != Axis.LINEAR + ): + raise RuntimeError("Only available with linear axes") + + xmin, xmax = xaxis.getLimits() + ymin, ymax = yaxis.getLimits() + width, height = plot.getPlotBoundsInPixels()[2:] + if width == 0 or height == 0: + return None + else: + return (xmax - xmin) / width, (ymax - ymin) / height + + def _getLevelOfDetails(self) -> Tuple[int, int]: + """Return current level of details the image is displayed with.""" + plot = self.getPlot() + if plot is None or not self._isPlotLinear(plot): + return 1, 1 # Fallback to bas LOD + + sx, sy = self.getScale() + xUnitPerPixel, yUnitPerPixel = self._getPixelSizeInData() + lodx = max(1, int(numpy.ceil(xUnitPerPixel / sx))) + lody = max(1, int(numpy.ceil(yUnitPerPixel / sy))) + return lodx, lody + + @docstring(ImageDataBase) + def setData(self, data, copy=True): + self.__cacheLODData = {} # Reset cache + super().setData(data) + + @docstring(ImageDataBase) + def _setPlot(self, plot): + """Refresh image when plot limits change""" + previousPlot = self.getPlot() + if previousPlot is not None: + for axis in (previousPlot.getXAxis(), previousPlot.getYAxis()): + axis.sigLimitsChanged.disconnect(self.__plotLimitsChanged) + + super()._setPlot(plot) + + if plot is not None: + for axis in (plot.getXAxis(), plot.getYAxis()): + axis.sigLimitsChanged.connect(self.__plotLimitsChanged) + + def __plotLimitsChanged(self): + """Trigger update if level of details has changed""" + if (self.getAggregationMode() != self.Aggregation.NONE and + self.__currentLOD != self._getLevelOfDetails()): + self._updated() + + @docstring(ImageDataBase) + def pick(self, x, y): + result = super().pick(x, y) + if result is None: + return None + + # Compute indices in initial data + plot = self.getPlot() + if plot is None: + return None + dataPos = plot.pixelToData(x, y, axis="left", check=True) + if dataPos is None: + return None # Outside plot area + + ox, oy = self.getOrigin() + sx, sy = self.getScale() + col = int((dataPos[0] - ox) / sx) + row = int((dataPos[1] - oy) / sy) + height, width = self.getData(copy=False).shape[:2] + if 0 <= col < width and 0 <= row < height: + return PickingResult(self, ((row,), (col,))) + return None diff --git a/src/silx/gui/plot/items/marker.py b/src/silx/gui/plot/items/marker.py new file mode 100755 index 0000000..50d070c --- /dev/null +++ b/src/silx/gui/plot/items/marker.py @@ -0,0 +1,281 @@ +# coding: utf-8 +# /*########################################################################## +# +# Copyright (c) 2017-2020 European Synchrotron Radiation Facility +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. +# +# ###########################################################################*/ +"""This module provides markers item of the :class:`Plot`. +""" + +__authors__ = ["T. Vincent"] +__license__ = "MIT" +__date__ = "06/03/2017" + + +import logging + +from ....utils.proxy import docstring +from .core import (Item, DraggableMixIn, ColorMixIn, LineMixIn, SymbolMixIn, + ItemChangedType, YAxisMixIn) +from silx.gui import qt + +_logger = logging.getLogger(__name__) + + +class MarkerBase(Item, DraggableMixIn, ColorMixIn, YAxisMixIn): + """Base class for markers""" + + sigDragStarted = qt.Signal() + """Signal emitted when the marker is pressed""" + sigDragFinished = qt.Signal() + """Signal emitted when the marker is released""" + + _DEFAULT_COLOR = (0., 0., 0., 1.) + """Default color of the markers""" + + def __init__(self): + Item.__init__(self) + DraggableMixIn.__init__(self) + ColorMixIn.__init__(self) + YAxisMixIn.__init__(self) + + self._text = '' + self._x = None + self._y = None + self._constraint = self._defaultConstraint + self.__isBeingDragged = False + + def _addRendererCall(self, backend, + symbol=None, linestyle='-', linewidth=1): + """Perform the update of the backend renderer""" + return backend.addMarker( + x=self.getXPosition(), + y=self.getYPosition(), + text=self.getText(), + color=self.getColor(), + symbol=symbol, + linestyle=linestyle, + linewidth=linewidth, + constraint=self.getConstraint(), + yaxis=self.getYAxis()) + + def _addBackendRenderer(self, backend): + """Update backend renderer""" + raise NotImplementedError() + + @docstring(DraggableMixIn) + def drag(self, from_, to): + self.setPosition(to[0], to[1]) + + def isOverlay(self): + """Returns True: A marker is always rendered as an overlay. + + :rtype: bool + """ + return True + + def getText(self): + """Returns marker text. + + :rtype: str + """ + return self._text + + def setText(self, text): + """Set the text of the marker. + + :param str text: The text to use + """ + text = str(text) + if text != self._text: + self._text = text + self._updated(ItemChangedType.TEXT) + + def getXPosition(self): + """Returns the X position of the marker line in data coordinates + + :rtype: float or None + """ + return self._x + + def getYPosition(self): + """Returns the Y position of the marker line in data coordinates + + :rtype: float or None + """ + return self._y + + def getPosition(self): + """Returns the (x, y) position of the marker in data coordinates + + :rtype: 2-tuple of float or None + """ + return self._x, self._y + + def setPosition(self, x, y): + """Set marker position in data coordinates + + Constraint are applied if any. + + :param float x: X coordinates in data frame + :param float y: Y coordinates in data frame + """ + x, y = self.getConstraint()(x, y) + x, y = float(x), float(y) + if x != self._x or y != self._y: + self._x, self._y = x, y + self._updated(ItemChangedType.POSITION) + + def getConstraint(self): + """Returns the dragging constraint of this item""" + return self._constraint + + def _setConstraint(self, constraint): # TODO support update + """Set the constraint. + + This is private for now as update is not handled. + + :param callable constraint: + :param constraint: A function filtering item displacement by + dragging operations or None for no filter. + This function is called each time the item is + moved. + This is only used if isDraggable returns True. + :type constraint: None or a callable that takes the coordinates of + the current cursor position in the plot as input + and that returns the filtered coordinates. + """ + if constraint is None: + constraint = self._defaultConstraint + assert callable(constraint) + self._constraint = constraint + + @staticmethod + def _defaultConstraint(*args): + """Default constraint not doing anything""" + return args + + def _startDrag(self): + self.__isBeingDragged = True + self.sigDragStarted.emit() + + def _endDrag(self): + self.__isBeingDragged = False + self.sigDragFinished.emit() + + def isBeingDragged(self) -> bool: + """Returns whether the marker is currently dragged by the user.""" + return self.__isBeingDragged + + +class Marker(MarkerBase, SymbolMixIn): + """Description of a marker""" + + _DEFAULT_SYMBOL = '+' + """Default symbol of the marker""" + + def __init__(self): + MarkerBase.__init__(self) + SymbolMixIn.__init__(self) + + self._x = 0. + self._y = 0. + + def _addBackendRenderer(self, backend): + return self._addRendererCall(backend, symbol=self.getSymbol()) + + def _setConstraint(self, constraint): + """Set the constraint function of the marker drag. + + It also supports 'horizontal' and 'vertical' str as constraint. + + :param constraint: The constraint of the dragging of this marker + :type: constraint: callable or str + """ + if constraint == 'horizontal': + constraint = self._horizontalConstraint + elif constraint == 'vertical': + constraint = self._verticalConstraint + + super(Marker, self)._setConstraint(constraint) + + def _horizontalConstraint(self, _, y): + return self.getXPosition(), y + + def _verticalConstraint(self, x, _): + return x, self.getYPosition() + + +class _LineMarker(MarkerBase, LineMixIn): + """Base class for line markers""" + + def __init__(self): + MarkerBase.__init__(self) + LineMixIn.__init__(self) + + def _addBackendRenderer(self, backend): + return self._addRendererCall(backend, + linestyle=self.getLineStyle(), + linewidth=self.getLineWidth()) + + +class XMarker(_LineMarker): + """Description of a marker""" + + def __init__(self): + _LineMarker.__init__(self) + self._x = 0. + + def setPosition(self, x, y): + """Set marker line position in data coordinates + + Constraint are applied if any. + + :param float x: X coordinates in data frame + :param float y: Y coordinates in data frame + """ + x, _ = self.getConstraint()(x, y) + x = float(x) + if x != self._x: + self._x = x + self._updated(ItemChangedType.POSITION) + + +class YMarker(_LineMarker): + """Description of a marker""" + + def __init__(self): + _LineMarker.__init__(self) + self._y = 0. + + def setPosition(self, x, y): + """Set marker line position in data coordinates + + Constraint are applied if any. + + :param float x: X coordinates in data frame + :param float y: Y coordinates in data frame + """ + _, y = self.getConstraint()(x, y) + y = float(y) + if y != self._y: + self._y = y + self._updated(ItemChangedType.POSITION) diff --git a/src/silx/gui/plot/items/roi.py b/src/silx/gui/plot/items/roi.py new file mode 100644 index 0000000..38a1424 --- /dev/null +++ b/src/silx/gui/plot/items/roi.py @@ -0,0 +1,1519 @@ +# coding: utf-8 +# /*########################################################################## +# +# Copyright (c) 2018-2020 European Synchrotron Radiation Facility +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. +# +# ###########################################################################*/ +"""This module provides ROI item for the :class:`~silx.gui.plot.PlotWidget`. + +.. inheritance-diagram:: + silx.gui.plot.items.roi + :parts: 1 +""" + +__authors__ = ["T. Vincent"] +__license__ = "MIT" +__date__ = "28/06/2018" + + +import logging +import numpy + +from ... import utils +from .. import items +from ...colors import rgba +from silx.image.shapes import Polygon +from silx.image._boundingbox import _BoundingBox +from ....utils.proxy import docstring +from ..utils.intersections import segments_intersection +from ._roi_base import _RegionOfInterestBase + +# He following imports have to be exposed by this module +from ._roi_base import RegionOfInterest +from ._roi_base import HandleBasedROI +from ._arc_roi import ArcROI # noqa +from ._roi_base import InteractionModeMixIn # noqa +from ._roi_base import RoiInteractionMode # noqa + + +logger = logging.getLogger(__name__) + + +class PointROI(RegionOfInterest, items.SymbolMixIn): + """A ROI identifying a point in a 2D plot.""" + + ICON = 'add-shape-point' + NAME = 'point markers' + SHORT_NAME = "point" + """Metadata for this kind of ROI""" + + _plotShape = "point" + """Plot shape which is used for the first interaction""" + + _DEFAULT_SYMBOL = '+' + """Default symbol of the PointROI + + It overwrite the `SymbolMixIn` class attribte. + """ + + def __init__(self, parent=None): + RegionOfInterest.__init__(self, parent=parent) + items.SymbolMixIn.__init__(self) + self._marker = items.Marker() + self._marker.sigItemChanged.connect(self._pointPositionChanged) + self._marker.setSymbol(self._DEFAULT_SYMBOL) + self._marker.sigDragStarted.connect(self._editingStarted) + self._marker.sigDragFinished.connect(self._editingFinished) + self.addItem(self._marker) + + def setFirstShapePoints(self, points): + self.setPosition(points[0]) + + def _updated(self, event=None, checkVisibility=True): + if event == items.ItemChangedType.NAME: + label = self.getName() + self._marker.setText(label) + elif event == items.ItemChangedType.EDITABLE: + self._marker._setDraggable(self.isEditable()) + elif event in [items.ItemChangedType.VISIBLE, + items.ItemChangedType.SELECTABLE]: + self._updateItemProperty(event, self, self._marker) + super(PointROI, self)._updated(event, checkVisibility) + + def _updatedStyle(self, event, style): + self._marker.setColor(style.getColor()) + + def getPosition(self): + """Returns the position of this ROI + + :rtype: numpy.ndarray + """ + return self._marker.getPosition() + + def setPosition(self, pos): + """Set the position of this ROI + + :param numpy.ndarray pos: 2d-coordinate of this point + """ + self._marker.setPosition(*pos) + + @docstring(_RegionOfInterestBase) + def contains(self, position): + roiPos = self.getPosition() + return position[0] == roiPos[0] and position[1] == roiPos[1] + + def _pointPositionChanged(self, event): + """Handle position changed events of the marker""" + if event is items.ItemChangedType.POSITION: + self.sigRegionChanged.emit() + + def __str__(self): + params = '%f %f' % self.getPosition() + return "%s(%s)" % (self.__class__.__name__, params) + + +class CrossROI(HandleBasedROI, items.LineMixIn): + """A ROI identifying a point in a 2D plot and displayed as a cross + """ + + ICON = 'add-shape-cross' + NAME = 'cross marker' + SHORT_NAME = "cross" + """Metadata for this kind of ROI""" + + _plotShape = "point" + """Plot shape which is used for the first interaction""" + + def __init__(self, parent=None): + HandleBasedROI.__init__(self, parent=parent) + items.LineMixIn.__init__(self) + self._handle = self.addHandle() + self._handle.sigItemChanged.connect(self._handlePositionChanged) + self._handleLabel = self.addLabelHandle() + self._vmarker = self.addUserHandle(items.YMarker()) + self._vmarker._setSelectable(False) + self._vmarker._setDraggable(False) + self._vmarker.setPosition(*self.getPosition()) + self._hmarker = self.addUserHandle(items.XMarker()) + self._hmarker._setSelectable(False) + self._hmarker._setDraggable(False) + self._hmarker.setPosition(*self.getPosition()) + + def _updated(self, event=None, checkVisibility=True): + if event in [items.ItemChangedType.VISIBLE]: + markers = (self._vmarker, self._hmarker) + self._updateItemProperty(event, self, markers) + super(CrossROI, self)._updated(event, checkVisibility) + + def _updateText(self, text): + self._handleLabel.setText(text) + + def _updatedStyle(self, event, style): + super(CrossROI, self)._updatedStyle(event, style) + for marker in [self._vmarker, self._hmarker]: + marker.setColor(style.getColor()) + marker.setLineStyle(style.getLineStyle()) + marker.setLineWidth(style.getLineWidth()) + + def setFirstShapePoints(self, points): + pos = points[0] + self.setPosition(pos) + + def getPosition(self): + """Returns the position of this ROI + + :rtype: numpy.ndarray + """ + return self._handle.getPosition() + + def setPosition(self, pos): + """Set the position of this ROI + + :param numpy.ndarray pos: 2d-coordinate of this point + """ + self._handle.setPosition(*pos) + + def _handlePositionChanged(self, event): + """Handle center marker position updates""" + if event is items.ItemChangedType.POSITION: + position = self.getPosition() + self._handleLabel.setPosition(*position) + self._vmarker.setPosition(*position) + self._hmarker.setPosition(*position) + self.sigRegionChanged.emit() + + @docstring(HandleBasedROI) + def contains(self, position): + roiPos = self.getPosition() + return position[0] == roiPos[0] or position[1] == roiPos[1] + + +class LineROI(HandleBasedROI, items.LineMixIn): + """A ROI identifying a line in a 2D plot. + + This ROI provides 1 anchor for each boundary of the line, plus an center + in the center to translate the full ROI. + """ + + ICON = 'add-shape-diagonal' + NAME = 'line ROI' + SHORT_NAME = "line" + """Metadata for this kind of ROI""" + + _plotShape = "line" + """Plot shape which is used for the first interaction""" + + def __init__(self, parent=None): + HandleBasedROI.__init__(self, parent=parent) + items.LineMixIn.__init__(self) + self._handleStart = self.addHandle() + self._handleEnd = self.addHandle() + self._handleCenter = self.addTranslateHandle() + self._handleLabel = self.addLabelHandle() + + shape = items.Shape("polylines") + shape.setPoints([[0, 0], [0, 0]]) + shape.setColor(rgba(self.getColor())) + shape.setFill(False) + shape.setOverlay(True) + shape.setLineStyle(self.getLineStyle()) + shape.setLineWidth(self.getLineWidth()) + self.__shape = shape + self.addItem(shape) + + def _updated(self, event=None, checkVisibility=True): + if event == items.ItemChangedType.VISIBLE: + self._updateItemProperty(event, self, self.__shape) + super(LineROI, self)._updated(event, checkVisibility) + + def _updatedStyle(self, event, style): + super(LineROI, self)._updatedStyle(event, style) + self.__shape.setColor(style.getColor()) + self.__shape.setLineStyle(style.getLineStyle()) + self.__shape.setLineWidth(style.getLineWidth()) + + def setFirstShapePoints(self, points): + assert len(points) == 2 + self.setEndPoints(points[0], points[1]) + + def _updateText(self, text): + self._handleLabel.setText(text) + + def setEndPoints(self, startPoint, endPoint): + """Set this line location using the ending points + + :param numpy.ndarray startPoint: Staring bounding point of the line + :param numpy.ndarray endPoint: Ending bounding point of the line + """ + if not numpy.array_equal((startPoint, endPoint), self.getEndPoints()): + self.__updateEndPoints(startPoint, endPoint) + + def __updateEndPoints(self, startPoint, endPoint): + """Update marker and shape to match given end points + + :param numpy.ndarray startPoint: Staring bounding point of the line + :param numpy.ndarray endPoint: Ending bounding point of the line + """ + startPoint = numpy.array(startPoint) + endPoint = numpy.array(endPoint) + center = (startPoint + endPoint) * 0.5 + + with utils.blockSignals(self._handleStart): + self._handleStart.setPosition(startPoint[0], startPoint[1]) + with utils.blockSignals(self._handleEnd): + self._handleEnd.setPosition(endPoint[0], endPoint[1]) + with utils.blockSignals(self._handleCenter): + self._handleCenter.setPosition(center[0], center[1]) + with utils.blockSignals(self._handleLabel): + self._handleLabel.setPosition(center[0], center[1]) + + line = numpy.array((startPoint, endPoint)) + self.__shape.setPoints(line) + self.sigRegionChanged.emit() + + def getEndPoints(self): + """Returns bounding points of this ROI. + + :rtype: Tuple(numpy.ndarray,numpy.ndarray) + """ + startPoint = numpy.array(self._handleStart.getPosition()) + endPoint = numpy.array(self._handleEnd.getPosition()) + return (startPoint, endPoint) + + def handleDragUpdated(self, handle, origin, previous, current): + if handle is self._handleStart: + _start, end = self.getEndPoints() + self.__updateEndPoints(current, end) + elif handle is self._handleEnd: + start, _end = self.getEndPoints() + self.__updateEndPoints(start, current) + elif handle is self._handleCenter: + start, end = self.getEndPoints() + delta = current - previous + start += delta + end += delta + self.setEndPoints(start, end) + + @docstring(_RegionOfInterestBase) + def contains(self, position): + bottom_left = position[0], position[1] + bottom_right = position[0] + 1, position[1] + top_left = position[0], position[1] + 1 + top_right = position[0] + 1, position[1] + 1 + + points = self.__shape.getPoints() + line_pt1 = points[0] + line_pt2 = points[1] + + bb1 = _BoundingBox.from_points(points) + if not bb1.contains(position): + return False + + return ( + segments_intersection(seg1_start_pt=line_pt1, seg1_end_pt=line_pt2, + seg2_start_pt=bottom_left, seg2_end_pt=bottom_right) or + segments_intersection(seg1_start_pt=line_pt1, seg1_end_pt=line_pt2, + seg2_start_pt=bottom_right, seg2_end_pt=top_right) or + segments_intersection(seg1_start_pt=line_pt1, seg1_end_pt=line_pt2, + seg2_start_pt=top_right, seg2_end_pt=top_left) or + segments_intersection(seg1_start_pt=line_pt1, seg1_end_pt=line_pt2, + seg2_start_pt=top_left, seg2_end_pt=bottom_left) + ) is not None + + def __str__(self): + start, end = self.getEndPoints() + params = start[0], start[1], end[0], end[1] + params = 'start: %f %f; end: %f %f' % params + return "%s(%s)" % (self.__class__.__name__, params) + + +class HorizontalLineROI(RegionOfInterest, items.LineMixIn): + """A ROI identifying an horizontal line in a 2D plot.""" + + ICON = 'add-shape-horizontal' + NAME = 'horizontal line ROI' + SHORT_NAME = "hline" + """Metadata for this kind of ROI""" + + _plotShape = "hline" + """Plot shape which is used for the first interaction""" + + def __init__(self, parent=None): + RegionOfInterest.__init__(self, parent=parent) + items.LineMixIn.__init__(self) + self._marker = items.YMarker() + self._marker.sigItemChanged.connect(self._linePositionChanged) + self._marker.sigDragStarted.connect(self._editingStarted) + self._marker.sigDragFinished.connect(self._editingFinished) + self.addItem(self._marker) + + def _updated(self, event=None, checkVisibility=True): + if event == items.ItemChangedType.NAME: + label = self.getName() + self._marker.setText(label) + elif event == items.ItemChangedType.EDITABLE: + self._marker._setDraggable(self.isEditable()) + elif event in [items.ItemChangedType.VISIBLE, + items.ItemChangedType.SELECTABLE]: + self._updateItemProperty(event, self, self._marker) + super(HorizontalLineROI, self)._updated(event, checkVisibility) + + def _updatedStyle(self, event, style): + self._marker.setColor(style.getColor()) + self._marker.setLineStyle(style.getLineStyle()) + self._marker.setLineWidth(style.getLineWidth()) + + def setFirstShapePoints(self, points): + pos = points[0, 1] + if pos == self.getPosition(): + return + self.setPosition(pos) + + def getPosition(self): + """Returns the position of this line if the horizontal axis + + :rtype: float + """ + pos = self._marker.getPosition() + return pos[1] + + def setPosition(self, pos): + """Set the position of this ROI + + :param float pos: Horizontal position of this line + """ + self._marker.setPosition(0, pos) + + @docstring(_RegionOfInterestBase) + def contains(self, position): + return position[1] == self.getPosition() + + def _linePositionChanged(self, event): + """Handle position changed events of the marker""" + if event is items.ItemChangedType.POSITION: + self.sigRegionChanged.emit() + + def __str__(self): + params = 'y: %f' % self.getPosition() + return "%s(%s)" % (self.__class__.__name__, params) + + +class VerticalLineROI(RegionOfInterest, items.LineMixIn): + """A ROI identifying a vertical line in a 2D plot.""" + + ICON = 'add-shape-vertical' + NAME = 'vertical line ROI' + SHORT_NAME = "vline" + """Metadata for this kind of ROI""" + + _plotShape = "vline" + """Plot shape which is used for the first interaction""" + + def __init__(self, parent=None): + RegionOfInterest.__init__(self, parent=parent) + items.LineMixIn.__init__(self) + self._marker = items.XMarker() + self._marker.sigItemChanged.connect(self._linePositionChanged) + self._marker.sigDragStarted.connect(self._editingStarted) + self._marker.sigDragFinished.connect(self._editingFinished) + self.addItem(self._marker) + + def _updated(self, event=None, checkVisibility=True): + if event == items.ItemChangedType.NAME: + label = self.getName() + self._marker.setText(label) + elif event == items.ItemChangedType.EDITABLE: + self._marker._setDraggable(self.isEditable()) + elif event in [items.ItemChangedType.VISIBLE, + items.ItemChangedType.SELECTABLE]: + self._updateItemProperty(event, self, self._marker) + super(VerticalLineROI, self)._updated(event, checkVisibility) + + def _updatedStyle(self, event, style): + self._marker.setColor(style.getColor()) + self._marker.setLineStyle(style.getLineStyle()) + self._marker.setLineWidth(style.getLineWidth()) + + def setFirstShapePoints(self, points): + pos = points[0, 0] + self.setPosition(pos) + + def getPosition(self): + """Returns the position of this line if the horizontal axis + + :rtype: float + """ + pos = self._marker.getPosition() + return pos[0] + + def setPosition(self, pos): + """Set the position of this ROI + + :param float pos: Horizontal position of this line + """ + self._marker.setPosition(pos, 0) + + @docstring(RegionOfInterest) + def contains(self, position): + return position[0] == self.getPosition() + + def _linePositionChanged(self, event): + """Handle position changed events of the marker""" + if event is items.ItemChangedType.POSITION: + self.sigRegionChanged.emit() + + def __str__(self): + params = 'x: %f' % self.getPosition() + return "%s(%s)" % (self.__class__.__name__, params) + + +class RectangleROI(HandleBasedROI, items.LineMixIn): + """A ROI identifying a rectangle in a 2D plot. + + This ROI provides 1 anchor for each corner, plus an anchor in the + center to translate the full ROI. + """ + + ICON = 'add-shape-rectangle' + NAME = 'rectangle ROI' + SHORT_NAME = "rectangle" + """Metadata for this kind of ROI""" + + _plotShape = "rectangle" + """Plot shape which is used for the first interaction""" + + def __init__(self, parent=None): + HandleBasedROI.__init__(self, parent=parent) + items.LineMixIn.__init__(self) + self._handleTopLeft = self.addHandle() + self._handleTopRight = self.addHandle() + self._handleBottomLeft = self.addHandle() + self._handleBottomRight = self.addHandle() + self._handleCenter = self.addTranslateHandle() + self._handleLabel = self.addLabelHandle() + + shape = items.Shape("rectangle") + shape.setPoints([[0, 0], [0, 0]]) + shape.setFill(False) + shape.setOverlay(True) + shape.setLineStyle(self.getLineStyle()) + shape.setLineWidth(self.getLineWidth()) + shape.setColor(rgba(self.getColor())) + self.__shape = shape + self.addItem(shape) + + def _updated(self, event=None, checkVisibility=True): + if event in [items.ItemChangedType.VISIBLE]: + self._updateItemProperty(event, self, self.__shape) + super(RectangleROI, self)._updated(event, checkVisibility) + + def _updatedStyle(self, event, style): + super(RectangleROI, self)._updatedStyle(event, style) + self.__shape.setColor(style.getColor()) + self.__shape.setLineStyle(style.getLineStyle()) + self.__shape.setLineWidth(style.getLineWidth()) + + def setFirstShapePoints(self, points): + assert len(points) == 2 + self._setBound(points) + + def _setBound(self, points): + """Initialize the rectangle from a bunch of points""" + top = max(points[:, 1]) + bottom = min(points[:, 1]) + left = min(points[:, 0]) + right = max(points[:, 0]) + size = right - left, top - bottom + self._updateGeometry(origin=(left, bottom), size=size) + + def _updateText(self, text): + self._handleLabel.setText(text) + + def getCenter(self): + """Returns the central point of this rectangle + + :rtype: numpy.ndarray([float,float]) + """ + pos = self._handleCenter.getPosition() + return numpy.array(pos) + + def getOrigin(self): + """Returns the corner point with the smaller coordinates + + :rtype: numpy.ndarray([float,float]) + """ + pos = self._handleBottomLeft.getPosition() + return numpy.array(pos) + + def getSize(self): + """Returns the size of this rectangle + + :rtype: numpy.ndarray([float,float]) + """ + vmin = self._handleBottomLeft.getPosition() + vmax = self._handleTopRight.getPosition() + vmin, vmax = numpy.array(vmin), numpy.array(vmax) + return vmax - vmin + + def setOrigin(self, position): + """Set the origin position of this ROI + + :param numpy.ndarray position: Location of the smaller corner of the ROI + """ + size = self.getSize() + self.setGeometry(origin=position, size=size) + + def setSize(self, size): + """Set the size of this ROI + + :param numpy.ndarray size: Size of the center of the ROI + """ + origin = self.getOrigin() + self.setGeometry(origin=origin, size=size) + + def setCenter(self, position): + """Set the size of this ROI + + :param numpy.ndarray position: Location of the center of the ROI + """ + size = self.getSize() + self.setGeometry(center=position, size=size) + + def setGeometry(self, origin=None, size=None, center=None): + """Set the geometry of the ROI + """ + if ((origin is None or numpy.array_equal(origin, self.getOrigin())) and + (center is None or numpy.array_equal(center, self.getCenter())) and + numpy.array_equal(size, self.getSize())): + return # Nothing has changed + + self._updateGeometry(origin, size, center) + + def _updateGeometry(self, origin=None, size=None, center=None): + """Forced update of the geometry of the ROI""" + if origin is not None: + origin = numpy.array(origin) + size = numpy.array(size) + points = numpy.array([origin, origin + size]) + center = origin + size * 0.5 + elif center is not None: + center = numpy.array(center) + size = numpy.array(size) + points = numpy.array([center - size * 0.5, center + size * 0.5]) + else: + raise ValueError("Origin or center expected") + + with utils.blockSignals(self._handleBottomLeft): + self._handleBottomLeft.setPosition(points[0, 0], points[0, 1]) + with utils.blockSignals(self._handleBottomRight): + self._handleBottomRight.setPosition(points[1, 0], points[0, 1]) + with utils.blockSignals(self._handleTopLeft): + self._handleTopLeft.setPosition(points[0, 0], points[1, 1]) + with utils.blockSignals(self._handleTopRight): + self._handleTopRight.setPosition(points[1, 0], points[1, 1]) + with utils.blockSignals(self._handleCenter): + self._handleCenter.setPosition(center[0], center[1]) + with utils.blockSignals(self._handleLabel): + self._handleLabel.setPosition(points[0, 0], points[0, 1]) + + self.__shape.setPoints(points) + self.sigRegionChanged.emit() + + @docstring(HandleBasedROI) + def contains(self, position): + assert isinstance(position, (tuple, list, numpy.array)) + points = self.__shape.getPoints() + bb1 = _BoundingBox.from_points(points) + return bb1.contains(position) + + def handleDragUpdated(self, handle, origin, previous, current): + if handle is self._handleCenter: + # It is the center anchor + size = self.getSize() + self._updateGeometry(center=current, size=size) + else: + opposed = { + self._handleBottomLeft: self._handleTopRight, + self._handleTopRight: self._handleBottomLeft, + self._handleBottomRight: self._handleTopLeft, + self._handleTopLeft: self._handleBottomRight, + } + handle2 = opposed[handle] + current2 = handle2.getPosition() + points = numpy.array([current, current2]) + + # Switch handles if they were crossed by interaction + if self._handleBottomLeft.getXPosition() > self._handleBottomRight.getXPosition(): + self._handleBottomLeft, self._handleBottomRight = self._handleBottomRight, self._handleBottomLeft + + if self._handleTopLeft.getXPosition() > self._handleTopRight.getXPosition(): + self._handleTopLeft, self._handleTopRight = self._handleTopRight, self._handleTopLeft + + if self._handleBottomLeft.getYPosition() > self._handleTopLeft.getYPosition(): + self._handleBottomLeft, self._handleTopLeft = self._handleTopLeft, self._handleBottomLeft + + if self._handleBottomRight.getYPosition() > self._handleTopRight.getYPosition(): + self._handleBottomRight, self._handleTopRight = self._handleTopRight, self._handleBottomRight + + self._setBound(points) + + def __str__(self): + origin = self.getOrigin() + w, h = self.getSize() + params = origin[0], origin[1], w, h + params = 'origin: %f %f; width: %f; height: %f' % params + return "%s(%s)" % (self.__class__.__name__, params) + + +class CircleROI(HandleBasedROI, items.LineMixIn): + """A ROI identifying a circle in a 2D plot. + + This ROI provides 1 anchor at the center to translate the circle, + and one anchor on the perimeter to change the radius. + """ + + ICON = 'add-shape-circle' + NAME = 'circle ROI' + SHORT_NAME = "circle" + """Metadata for this kind of ROI""" + + _kind = "Circle" + """Label for this kind of ROI""" + + _plotShape = "line" + """Plot shape which is used for the first interaction""" + + def __init__(self, parent=None): + items.LineMixIn.__init__(self) + HandleBasedROI.__init__(self, parent=parent) + self._handlePerimeter = self.addHandle() + self._handleCenter = self.addTranslateHandle() + self._handleCenter.sigItemChanged.connect(self._centerPositionChanged) + self._handleLabel = self.addLabelHandle() + + shape = items.Shape("polygon") + shape.setPoints([[0, 0], [0, 0]]) + shape.setColor(rgba(self.getColor())) + shape.setFill(False) + shape.setOverlay(True) + shape.setLineStyle(self.getLineStyle()) + shape.setLineWidth(self.getLineWidth()) + self.__shape = shape + self.addItem(shape) + + self.__radius = 0 + + def _updated(self, event=None, checkVisibility=True): + if event == items.ItemChangedType.VISIBLE: + self._updateItemProperty(event, self, self.__shape) + super(CircleROI, self)._updated(event, checkVisibility) + + def _updatedStyle(self, event, style): + super(CircleROI, self)._updatedStyle(event, style) + self.__shape.setColor(style.getColor()) + self.__shape.setLineStyle(style.getLineStyle()) + self.__shape.setLineWidth(style.getLineWidth()) + + def setFirstShapePoints(self, points): + assert len(points) == 2 + self._setRay(points) + + def _setRay(self, points): + """Initialize the circle from the center point and a + perimeter point.""" + center = points[0] + radius = numpy.linalg.norm(points[0] - points[1]) + self.setGeometry(center=center, radius=radius) + + def _updateText(self, text): + self._handleLabel.setText(text) + + def getCenter(self): + """Returns the central point of this rectangle + + :rtype: numpy.ndarray([float,float]) + """ + pos = self._handleCenter.getPosition() + return numpy.array(pos) + + def getRadius(self): + """Returns the radius of this circle + + :rtype: float + """ + return self.__radius + + def setCenter(self, position): + """Set the center point of this ROI + + :param numpy.ndarray position: Location of the center of the circle + """ + self._handleCenter.setPosition(*position) + + def setRadius(self, radius): + """Set the size of this ROI + + :param float size: Radius of the circle + """ + radius = float(radius) + if radius != self.__radius: + self.__radius = radius + self._updateGeometry() + + def setGeometry(self, center, radius): + """Set the geometry of the ROI + """ + if numpy.array_equal(center, self.getCenter()): + self.setRadius(radius) + else: + self.__radius = float(radius) # Update radius directly + self.setCenter(center) # Calls _updateGeometry + + def _updateGeometry(self): + """Update the handles and shape according to given parameters""" + center = self.getCenter() + perimeter_point = numpy.array([center[0] + self.__radius, center[1]]) + + self._handlePerimeter.setPosition(perimeter_point[0], perimeter_point[1]) + self._handleLabel.setPosition(center[0], center[1]) + + nbpoints = 27 + angles = numpy.arange(nbpoints) * 2.0 * numpy.pi / nbpoints + circleShape = numpy.array((numpy.cos(angles) * self.__radius, + numpy.sin(angles) * self.__radius)).T + circleShape += center + self.__shape.setPoints(circleShape) + self.sigRegionChanged.emit() + + def _centerPositionChanged(self, event): + """Handle position changed events of the center marker""" + if event is items.ItemChangedType.POSITION: + self._updateGeometry() + + def handleDragUpdated(self, handle, origin, previous, current): + if handle is self._handlePerimeter: + center = self.getCenter() + self.setRadius(numpy.linalg.norm(center - current)) + + @docstring(HandleBasedROI) + def contains(self, position): + return numpy.linalg.norm(self.getCenter() - position) <= self.getRadius() + + def __str__(self): + center = self.getCenter() + radius = self.getRadius() + params = center[0], center[1], radius + params = 'center: %f %f; radius: %f;' % params + return "%s(%s)" % (self.__class__.__name__, params) + + +class EllipseROI(HandleBasedROI, items.LineMixIn): + """A ROI identifying an oriented ellipse in a 2D plot. + + This ROI provides 1 anchor at the center to translate the circle, + and two anchors on the perimeter to modify the major-radius and + minor-radius. These two anchors also allow to change the orientation. + """ + + ICON = 'add-shape-ellipse' + NAME = 'ellipse ROI' + SHORT_NAME = "ellipse" + """Metadata for this kind of ROI""" + + _plotShape = "line" + """Plot shape which is used for the first interaction""" + + def __init__(self, parent=None): + items.LineMixIn.__init__(self) + HandleBasedROI.__init__(self, parent=parent) + self._handleAxis0 = self.addHandle() + self._handleAxis1 = self.addHandle() + self._handleCenter = self.addTranslateHandle() + self._handleCenter.sigItemChanged.connect(self._centerPositionChanged) + self._handleLabel = self.addLabelHandle() + + shape = items.Shape("polygon") + shape.setPoints([[0, 0], [0, 0]]) + shape.setColor(rgba(self.getColor())) + shape.setFill(False) + shape.setOverlay(True) + shape.setLineStyle(self.getLineStyle()) + shape.setLineWidth(self.getLineWidth()) + self.__shape = shape + self.addItem(shape) + + self._radius = 0., 0. + self._orientation = 0. # angle in radians between the X-axis and the _handleAxis0 + + def _updated(self, event=None, checkVisibility=True): + if event == items.ItemChangedType.VISIBLE: + self._updateItemProperty(event, self, self.__shape) + super(EllipseROI, self)._updated(event, checkVisibility) + + def _updatedStyle(self, event, style): + super(EllipseROI, self)._updatedStyle(event, style) + self.__shape.setColor(style.getColor()) + self.__shape.setLineStyle(style.getLineStyle()) + self.__shape.setLineWidth(style.getLineWidth()) + + def setFirstShapePoints(self, points): + assert len(points) == 2 + self._setRay(points) + + @staticmethod + def _calculateOrientation(p0, p1): + """return angle in radians between the vector p0-p1 + and the X axis + + :param p0: first point coordinates (x, y) + :param p1: second point coordinates + :return: + """ + vector = (p1[0] - p0[0], p1[1] - p0[1]) + x_unit_vector = (1, 0) + norm = numpy.linalg.norm(vector) + if norm != 0: + theta = numpy.arccos(numpy.dot(vector, x_unit_vector) / norm) + else: + theta = 0 + if vector[1] < 0: + # arccos always returns values in range [0, pi] + theta = 2 * numpy.pi - theta + return theta + + def _setRay(self, points): + """Initialize the circle from the center point and a + perimeter point.""" + center = points[0] + radius = numpy.linalg.norm(points[0] - points[1]) + orientation = self._calculateOrientation(points[0], points[1]) + self.setGeometry(center=center, + radius=(radius, radius), + orientation=orientation) + + def _updateText(self, text): + self._handleLabel.setText(text) + + def getCenter(self): + """Returns the central point of this rectangle + + :rtype: numpy.ndarray([float,float]) + """ + pos = self._handleCenter.getPosition() + return numpy.array(pos) + + def getMajorRadius(self): + """Returns the half-diameter of the major axis. + + :rtype: float + """ + return max(self._radius) + + def getMinorRadius(self): + """Returns the half-diameter of the minor axis. + + :rtype: float + """ + return min(self._radius) + + def getOrientation(self): + """Return angle in radians between the horizontal (X) axis + and the major axis of the ellipse in [0, 2*pi[ + + :rtype: float: + """ + return self._orientation + + def setCenter(self, center): + """Set the center point of this ROI + + :param numpy.ndarray position: Coordinates (X, Y) of the center + of the ellipse + """ + self._handleCenter.setPosition(*center) + + def setMajorRadius(self, radius): + """Set the half-diameter of the major axis of the ellipse. + + :param float radius: + Major radius of the ellipsis. Must be a positive value. + """ + if self._radius[0] > self._radius[1]: + newRadius = radius, self._radius[1] + else: + newRadius = self._radius[0], radius + self.setGeometry(radius=newRadius) + + def setMinorRadius(self, radius): + """Set the half-diameter of the minor axis of the ellipse. + + :param float radius: + Minor radius of the ellipsis. Must be a positive value. + """ + if self._radius[0] > self._radius[1]: + newRadius = self._radius[0], radius + else: + newRadius = radius, self._radius[1] + self.setGeometry(radius=newRadius) + + def setOrientation(self, orientation): + """Rotate the ellipse + + :param float orientation: Angle in radians between the horizontal and + the major axis. + :return: + """ + self.setGeometry(orientation=orientation) + + def setGeometry(self, center=None, radius=None, orientation=None): + """ + + :param center: (X, Y) coordinates + :param float majorRadius: + :param float minorRadius: + :param float orientation: angle in radians between the major axis and the + horizontal + :return: + """ + if center is None: + center = self.getCenter() + + if radius is None: + radius = self._radius + else: + radius = float(radius[0]), float(radius[1]) + + if orientation is None: + orientation = self._orientation + else: + # ensure that we store the orientation in range [0, 2*pi + orientation = numpy.mod(orientation, 2 * numpy.pi) + + if (numpy.array_equal(center, self.getCenter()) or + radius != self._radius or + orientation != self._orientation): + + # Update parameters directly + self._radius = radius + self._orientation = orientation + + if numpy.array_equal(center, self.getCenter()): + self._updateGeometry() + else: + # This will call _updateGeometry + self.setCenter(center) + + def _updateGeometry(self): + """Update shape and markers""" + center = self.getCenter() + + orientation = self.getOrientation() + if self._radius[1] > self._radius[0]: + # _handleAxis1 is the major axis + orientation -= numpy.pi / 2 + + point0 = numpy.array([center[0] + self._radius[0] * numpy.cos(orientation), + center[1] + self._radius[0] * numpy.sin(orientation)]) + point1 = numpy.array([center[0] - self._radius[1] * numpy.sin(orientation), + center[1] + self._radius[1] * numpy.cos(orientation)]) + with utils.blockSignals(self._handleAxis0): + self._handleAxis0.setPosition(*point0) + with utils.blockSignals(self._handleAxis1): + self._handleAxis1.setPosition(*point1) + with utils.blockSignals(self._handleLabel): + self._handleLabel.setPosition(*center) + + nbpoints = 27 + angles = numpy.arange(nbpoints) * 2.0 * numpy.pi / nbpoints + X = (self._radius[0] * numpy.cos(angles) * numpy.cos(orientation) + - self._radius[1] * numpy.sin(angles) * numpy.sin(orientation)) + Y = (self._radius[0] * numpy.cos(angles) * numpy.sin(orientation) + + self._radius[1] * numpy.sin(angles) * numpy.cos(orientation)) + + ellipseShape = numpy.array((X, Y)).T + ellipseShape += center + self.__shape.setPoints(ellipseShape) + self.sigRegionChanged.emit() + + def handleDragUpdated(self, handle, origin, previous, current): + if handle in (self._handleAxis0, self._handleAxis1): + center = self.getCenter() + orientation = self._calculateOrientation(center, current) + distance = numpy.linalg.norm(center - current) + + if handle is self._handleAxis1: + if self._radius[0] > distance: + # _handleAxis1 is not the major axis, rotate -90 degrees + orientation -= numpy.pi / 2 + radius = self._radius[0], distance + + else: # _handleAxis0 + if self._radius[1] > distance: + # _handleAxis0 is not the major axis, rotate +90 degrees + orientation += numpy.pi / 2 + radius = distance, self._radius[1] + + self.setGeometry(radius=radius, orientation=orientation) + + def _centerPositionChanged(self, event): + """Handle position changed events of the center marker""" + if event is items.ItemChangedType.POSITION: + self._updateGeometry() + + @docstring(HandleBasedROI) + def contains(self, position): + major, minor = self.getMajorRadius(), self.getMinorRadius() + delta = self.getOrientation() + x, y = position - self.getCenter() + return ((x*numpy.cos(delta) + y*numpy.sin(delta))**2/major**2 + + (x*numpy.sin(delta) - y*numpy.cos(delta))**2/minor**2) <= 1 + + def __str__(self): + center = self.getCenter() + major = self.getMajorRadius() + minor = self.getMinorRadius() + orientation = self.getOrientation() + params = center[0], center[1], major, minor, orientation + params = 'center: %f %f; major radius: %f: minor radius: %f; orientation: %f' % params + return "%s(%s)" % (self.__class__.__name__, params) + + +class PolygonROI(HandleBasedROI, items.LineMixIn): + """A ROI identifying a closed polygon in a 2D plot. + + This ROI provides 1 anchor for each point of the polygon. + """ + + ICON = 'add-shape-polygon' + NAME = 'polygon ROI' + SHORT_NAME = "polygon" + """Metadata for this kind of ROI""" + + _plotShape = "polygon" + """Plot shape which is used for the first interaction""" + + def __init__(self, parent=None): + HandleBasedROI.__init__(self, parent=parent) + items.LineMixIn.__init__(self) + self._handleLabel = self.addLabelHandle() + self._handleCenter = self.addTranslateHandle() + self._handlePoints = [] + self._points = numpy.empty((0, 2)) + self._handleClose = None + + self._polygon_shape = None + shape = self.__createShape() + self.__shape = shape + self.addItem(shape) + + def _updated(self, event=None, checkVisibility=True): + if event in [items.ItemChangedType.VISIBLE]: + self._updateItemProperty(event, self, self.__shape) + super(PolygonROI, self)._updated(event, checkVisibility) + + def _updatedStyle(self, event, style): + super(PolygonROI, self)._updatedStyle(event, style) + self.__shape.setColor(style.getColor()) + self.__shape.setLineStyle(style.getLineStyle()) + self.__shape.setLineWidth(style.getLineWidth()) + if self._handleClose is not None: + color = self._computeHandleColor(style.getColor()) + self._handleClose.setColor(color) + + def __createShape(self, interaction=False): + kind = "polygon" if not interaction else "polylines" + shape = items.Shape(kind) + shape.setPoints([[0, 0], [0, 0]]) + shape.setFill(False) + shape.setOverlay(True) + style = self.getCurrentStyle() + shape.setLineStyle(style.getLineStyle()) + shape.setLineWidth(style.getLineWidth()) + shape.setColor(rgba(style.getColor())) + return shape + + def setFirstShapePoints(self, points): + if self._handleClose is not None: + self._handleClose.setPosition(*points[0]) + self.setPoints(points) + + def creationStarted(self): + """"Called when the ROI creation interaction was started. + """ + # Handle to see where to close the polygon + self._handleClose = self.addUserHandle() + self._handleClose.setSymbol("o") + color = self._computeHandleColor(rgba(self.getColor())) + self._handleClose.setColor(color) + + # Hide the center while creating the first shape + self._handleCenter.setSymbol("") + + # In interaction replace the polygon by a line, to display something unclosed + self.removeItem(self.__shape) + self.__shape = self.__createShape(interaction=True) + self.__shape.setPoints(self._points) + self.addItem(self.__shape) + + def isBeingCreated(self): + """Returns true if the ROI is in creation step""" + return self._handleClose is not None + + def creationFinalized(self): + """"Called when the ROI creation interaction was finalized. + """ + self.removeHandle(self._handleClose) + self._handleClose = None + self.removeItem(self.__shape) + self.__shape = self.__createShape() + self.__shape.setPoints(self._points) + self.addItem(self.__shape) + # Hide the center while creating the first shape + self._handleCenter.setSymbol("+") + for handle in self._handlePoints: + handle.setSymbol("s") + + def _updateText(self, text): + self._handleLabel.setText(text) + + def getPoints(self): + """Returns the list of the points of this polygon. + + :rtype: numpy.ndarray + """ + return self._points.copy() + + def setPoints(self, points): + """Set the position of this ROI + + :param numpy.ndarray pos: 2d-coordinate of this point + """ + assert(len(points.shape) == 2 and points.shape[1] == 2) + + if numpy.array_equal(points, self._points): + return # Nothing has changed + + self._polygon_shape = None + + # Update the needed handles + while len(self._handlePoints) != len(points): + if len(self._handlePoints) < len(points): + handle = self.addHandle() + self._handlePoints.append(handle) + if self.isBeingCreated(): + handle.setSymbol("") + else: + handle = self._handlePoints.pop(-1) + self.removeHandle(handle) + + for handle, position in zip(self._handlePoints, points): + with utils.blockSignals(handle): + handle.setPosition(position[0], position[1]) + + if len(points) > 0: + if not self.isHandleBeingDragged(): + vmin = numpy.min(points, axis=0) + vmax = numpy.max(points, axis=0) + center = (vmax + vmin) * 0.5 + with utils.blockSignals(self._handleCenter): + self._handleCenter.setPosition(center[0], center[1]) + + num = numpy.argmin(points[:, 1]) + pos = points[num] + with utils.blockSignals(self._handleLabel): + self._handleLabel.setPosition(pos[0], pos[1]) + + if len(points) == 0: + self._points = numpy.empty((0, 2)) + else: + self._points = points + self.__shape.setPoints(self._points) + self.sigRegionChanged.emit() + + def translate(self, x, y): + points = self.getPoints() + delta = numpy.array([x, y]) + self.setPoints(points) + self.setPoints(points + delta) + + def handleDragUpdated(self, handle, origin, previous, current): + if handle is self._handleCenter: + delta = current - previous + self.translate(delta[0], delta[1]) + else: + points = self.getPoints() + num = self._handlePoints.index(handle) + points[num] = current + self.setPoints(points) + + def handleDragFinished(self, handle, origin, current): + points = self._points + if len(points) > 0: + # Only update the center at the end + # To avoid to disturb the interaction + vmin = numpy.min(points, axis=0) + vmax = numpy.max(points, axis=0) + center = (vmax + vmin) * 0.5 + with utils.blockSignals(self._handleCenter): + self._handleCenter.setPosition(center[0], center[1]) + + def __str__(self): + points = self._points + params = '; '.join('%f %f' % (pt[0], pt[1]) for pt in points) + return "%s(%s)" % (self.__class__.__name__, params) + + @docstring(HandleBasedROI) + def contains(self, position): + bb1 = _BoundingBox.from_points(self.getPoints()) + if bb1.contains(position) is False: + return False + + if self._polygon_shape is None: + self._polygon_shape = Polygon(vertices=self.getPoints()) + + # warning: both the polygon and the value are inverted + return self._polygon_shape.is_inside(row=position[0], col=position[1]) + + def _setControlPoints(self, points): + RegionOfInterest._setControlPoints(self, points=points) + self._polygon_shape = None + + +class HorizontalRangeROI(RegionOfInterest, items.LineMixIn): + """A ROI identifying an horizontal range in a 1D plot.""" + + ICON = 'add-range-horizontal' + NAME = 'horizontal range ROI' + SHORT_NAME = "hrange" + + _plotShape = "line" + """Plot shape which is used for the first interaction""" + + def __init__(self, parent=None): + RegionOfInterest.__init__(self, parent=parent) + items.LineMixIn.__init__(self) + self._markerMin = items.XMarker() + self._markerMax = items.XMarker() + self._markerCen = items.XMarker() + self._markerCen.setLineStyle(" ") + self._markerMin._setConstraint(self.__positionMinConstraint) + self._markerMax._setConstraint(self.__positionMaxConstraint) + self._markerMin.sigDragStarted.connect(self._editingStarted) + self._markerMin.sigDragFinished.connect(self._editingFinished) + self._markerMax.sigDragStarted.connect(self._editingStarted) + self._markerMax.sigDragFinished.connect(self._editingFinished) + self._markerCen.sigDragStarted.connect(self._editingStarted) + self._markerCen.sigDragFinished.connect(self._editingFinished) + self.addItem(self._markerCen) + self.addItem(self._markerMin) + self.addItem(self._markerMax) + self.__filterReentrant = utils.LockReentrant() + + def setFirstShapePoints(self, points): + vmin = min(points[:, 0]) + vmax = max(points[:, 0]) + self._updatePos(vmin, vmax) + + def _updated(self, event=None, checkVisibility=True): + if event == items.ItemChangedType.NAME: + self._updateText() + elif event == items.ItemChangedType.EDITABLE: + self._updateEditable() + self._updateText() + elif event == items.ItemChangedType.LINE_STYLE: + markers = [self._markerMin, self._markerMax] + self._updateItemProperty(event, self, markers) + elif event in [items.ItemChangedType.VISIBLE, + items.ItemChangedType.SELECTABLE]: + markers = [self._markerMin, self._markerMax, self._markerCen] + self._updateItemProperty(event, self, markers) + super(HorizontalRangeROI, self)._updated(event, checkVisibility) + + def _updatedStyle(self, event, style): + markers = [self._markerMin, self._markerMax, self._markerCen] + for m in markers: + m.setColor(style.getColor()) + m.setLineWidth(style.getLineWidth()) + + def _updateText(self): + text = self.getName() + if self.isEditable(): + self._markerMin.setText("") + self._markerCen.setText(text) + else: + self._markerMin.setText(text) + self._markerCen.setText("") + + def _updateEditable(self): + editable = self.isEditable() + self._markerMin._setDraggable(editable) + self._markerMax._setDraggable(editable) + self._markerCen._setDraggable(editable) + if self.isEditable(): + self._markerMin.sigItemChanged.connect(self._minPositionChanged) + self._markerMax.sigItemChanged.connect(self._maxPositionChanged) + self._markerCen.sigItemChanged.connect(self._cenPositionChanged) + self._markerCen.setLineStyle(":") + else: + self._markerMin.sigItemChanged.disconnect(self._minPositionChanged) + self._markerMax.sigItemChanged.disconnect(self._maxPositionChanged) + self._markerCen.sigItemChanged.disconnect(self._cenPositionChanged) + self._markerCen.setLineStyle(" ") + + def _updatePos(self, vmin, vmax, force=False): + """Update marker position and emit signal. + + :param float vmin: + :param float vmax: + :param bool force: + True to update even if already at the right position. + """ + if not force and numpy.array_equal((vmin, vmax), self.getRange()): + return # Nothing has changed + + center = (vmin + vmax) * 0.5 + with self.__filterReentrant: + with utils.blockSignals(self._markerMin): + self._markerMin.setPosition(vmin, 0) + with utils.blockSignals(self._markerCen): + self._markerCen.setPosition(center, 0) + with utils.blockSignals(self._markerMax): + self._markerMax.setPosition(vmax, 0) + self.sigRegionChanged.emit() + + def setRange(self, vmin, vmax): + """Set the range of this ROI. + + :param float vmin: Staring location of the range + :param float vmax: Ending location of the range + """ + if vmin is None or vmax is None: + err = "Can't set vmin or vmax to None" + raise ValueError(err) + if vmin > vmax: + err = "Can't set vmin and vmax because vmin >= vmax " \ + "vmin = %s, vmax = %s" % (vmin, vmax) + raise ValueError(err) + self._updatePos(vmin, vmax) + + def getRange(self): + """Returns the range of this ROI. + + :rtype: Tuple[float,float] + """ + vmin = self.getMin() + vmax = self.getMax() + return vmin, vmax + + def setMin(self, vmin): + """Set the min of this ROI. + + :param float vmin: New min + """ + vmax = self.getMax() + self._updatePos(vmin, vmax) + + def getMin(self): + """Returns the min value of this ROI. + + :rtype: float + """ + return self._markerMin.getPosition()[0] + + def setMax(self, vmax): + """Set the max of this ROI. + + :param float vmax: New max + """ + vmin = self.getMin() + self._updatePos(vmin, vmax) + + def getMax(self): + """Returns the max value of this ROI. + + :rtype: float + """ + return self._markerMax.getPosition()[0] + + def setCenter(self, center): + """Set the center of this ROI. + + :param float center: New center + """ + vmin, vmax = self.getRange() + previousCenter = (vmin + vmax) * 0.5 + delta = center - previousCenter + self._updatePos(vmin + delta, vmax + delta) + + def getCenter(self): + """Returns the center location of this ROI. + + :rtype: float + """ + vmin, vmax = self.getRange() + return (vmin + vmax) * 0.5 + + def __positionMinConstraint(self, x, y): + """Constraint of the min marker""" + if self.__filterReentrant.locked(): + # Ignore the constraint when we set an explicit value + return x, y + vmax = self.getMax() + if vmax is None: + return x, y + return min(x, vmax), y + + def __positionMaxConstraint(self, x, y): + """Constraint of the max marker""" + if self.__filterReentrant.locked(): + # Ignore the constraint when we set an explicit value + return x, y + vmin = self.getMin() + if vmin is None: + return x, y + return max(x, vmin), y + + def _minPositionChanged(self, event): + """Handle position changed events of the marker""" + if event is items.ItemChangedType.POSITION: + marker = self.sender() + self._updatePos(marker.getXPosition(), self.getMax(), force=True) + + def _maxPositionChanged(self, event): + """Handle position changed events of the marker""" + if event is items.ItemChangedType.POSITION: + marker = self.sender() + self._updatePos(self.getMin(), marker.getXPosition(), force=True) + + def _cenPositionChanged(self, event): + """Handle position changed events of the marker""" + if event is items.ItemChangedType.POSITION: + marker = self.sender() + self.setCenter(marker.getXPosition()) + + @docstring(HandleBasedROI) + def contains(self, position): + return self.getMin() <= position[0] <= self.getMax() + + def __str__(self): + vrange = self.getRange() + params = 'min: %f; max: %f' % vrange + return "%s(%s)" % (self.__class__.__name__, params) diff --git a/src/silx/gui/plot/items/scatter.py b/src/silx/gui/plot/items/scatter.py new file mode 100644 index 0000000..fdc66f7 --- /dev/null +++ b/src/silx/gui/plot/items/scatter.py @@ -0,0 +1,1002 @@ +# coding: utf-8 +# /*########################################################################## +# +# Copyright (c) 2017-2021 European Synchrotron Radiation Facility +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. +# +# ###########################################################################*/ +"""This module provides the :class:`Scatter` item of the :class:`Plot`. +""" + +from __future__ import division + + +__authors__ = ["T. Vincent", "P. Knobel"] +__license__ = "MIT" +__date__ = "29/03/2017" + + +from collections import namedtuple +import logging +import threading +import numpy + +from collections import defaultdict +from concurrent.futures import ThreadPoolExecutor, CancelledError + +from ....utils.proxy import docstring +from ....math.combo import min_max +from ....math.histogram import Histogramnd +from ....utils.weakref import WeakList +from .._utils.delaunay import delaunay +from .core import PointsBase, ColormapMixIn, ScatterVisualizationMixIn +from .axis import Axis +from ._pick import PickingResult + + +_logger = logging.getLogger(__name__) + + +class _GreedyThreadPoolExecutor(ThreadPoolExecutor): + """:class:`ThreadPoolExecutor` with an extra :meth:`submit_greedy` method. + """ + + def __init__(self, *args, **kwargs): + super(_GreedyThreadPoolExecutor, self).__init__(*args, **kwargs) + self.__futures = defaultdict(WeakList) + self.__lock = threading.RLock() + + def submit_greedy(self, queue, fn, *args, **kwargs): + """Same as :meth:`submit` but cancel previous tasks in given queue. + + This means that when a new task is submitted for a given queue, + all other pending tasks of that queue are cancelled. + + :param queue: Identifier of the queue. This must be hashable. + :param callable fn: The callable to call with provided extra arguments + :return: Future corresponding to this task + :rtype: concurrent.futures.Future + """ + with self.__lock: + # Cancel previous tasks in given queue + for future in self.__futures.pop(queue, []): + if not future.done(): + future.cancel() + + future = super(_GreedyThreadPoolExecutor, self).submit( + fn, *args, **kwargs) + self.__futures[queue].append(future) + + return future + + +# Functions to guess grid shape from coordinates + +def _get_z_line_length(array): + """Return length of line if array is a Z-like 2D regular grid. + + :param numpy.ndarray array: The 1D array of coordinates to check + :return: 0 if no line length could be found, + else the number of element per line. + :rtype: int + """ + sign = numpy.sign(numpy.diff(array)) + if len(sign) == 0 or sign[0] == 0: # We don't handle that + return 0 + # Check this way to account for 0 sign (i.e., diff == 0) + beginnings = numpy.where(sign == - sign[0])[0] + 1 + if len(beginnings) == 0: + return 0 + length = beginnings[0] + if numpy.all(numpy.equal(numpy.diff(beginnings), length)): + return length + return 0 + + +def _guess_z_grid_shape(x, y): + """Guess the shape of a grid from (x, y) coordinates. + + The grid might contain more elements than x and y, + as the last line might be partly filled. + + :param numpy.ndarray x: + :paran numpy.ndarray y: + :returns: (order, (height, width)) of the regular grid, + or None if could not guess one. + 'order' is 'row' if X (i.e., column) is the fast dimension, else 'column'. + :rtype: Union[List(str,int),None] + """ + width = _get_z_line_length(x) + if width != 0: + return 'row', (int(numpy.ceil(len(x) / width)), width) + else: + height = _get_z_line_length(y) + if height != 0: + return 'column', (height, int(numpy.ceil(len(y) / height))) + return None + + +def is_monotonic(array): + """Returns whether array is monotonic (increasing or decreasing). + + :param numpy.ndarray array: 1D array-like container. + :returns: 1 if array is monotonically increasing, + -1 if array is monotonically decreasing, + 0 if array is not monotonic + :rtype: int + """ + diff = numpy.diff(numpy.ravel(array)) + with numpy.errstate(invalid='ignore'): + if numpy.all(diff >= 0): + return 1 + elif numpy.all(diff <= 0): + return -1 + else: + return 0 + + +def _guess_grid(x, y): + """Guess a regular grid from the points. + + Result convention is (x, y) + + :param numpy.ndarray x: X coordinates of the points + :param numpy.ndarray y: Y coordinates of the points + :returns: (order, (height, width) + order is 'row' or 'column' + :rtype: Union[List[str,List[int]],None] + """ + x, y = numpy.ravel(x), numpy.ravel(y) + + guess = _guess_z_grid_shape(x, y) + if guess is not None: + return guess + + else: + # Cannot guess a regular grid + # Let's assume it's a single line + order = 'row' # or 'column' doesn't matter for a single line + y_monotonic = is_monotonic(y) + if is_monotonic(x) or y_monotonic: # we can guess a line + x_min, x_max = min_max(x) + y_min, y_max = min_max(y) + + if not y_monotonic or x_max - x_min >= y_max - y_min: + # x only is monotonic or both are and X varies more + # line along X + shape = 1, len(x) + else: + # y only is monotonic or both are and Y varies more + # line along Y + shape = len(y), 1 + + else: # Cannot guess a line from the points + return None + + return order, shape + + +def _quadrilateral_grid_coords(points): + """Compute an irregular grid of quadrilaterals from a set of points + + The input points are expected to lie on a grid. + + :param numpy.ndarray points: + 3D data set of 2D input coordinates (height, width, 2) + height and width must be at least 2. + :return: 3D dataset of 2D coordinates of the grid (height+1, width+1, 2) + """ + assert points.ndim == 3 + assert points.shape[0] >= 2 + assert points.shape[1] >= 2 + assert points.shape[2] == 2 + + dim0, dim1 = points.shape[:2] + grid_points = numpy.zeros((dim0 + 1, dim1 + 1, 2), dtype=numpy.float64) + + # Compute inner points as mean of 4 neighbours + neighbour_view = numpy.lib.stride_tricks.as_strided( + points, + shape=(dim0 - 1, dim1 - 1, 2, 2, points.shape[2]), + strides=points.strides[:2] + points.strides[:2] + points.strides[-1:], writeable=False) + inner_points = numpy.mean(neighbour_view, axis=(2, 3)) + grid_points[1:-1, 1:-1] = inner_points + + # Compute 'vertical' sides + # Alternative: grid_points[1:-1, [0, -1]] = points[:-1, [0, -1]] + points[1:, [0, -1]] - inner_points[:, [0, -1]] + grid_points[1:-1, [0, -1], 0] = points[:-1, [0, -1], 0] + points[1:, [0, -1], 0] - inner_points[:, [0, -1], 0] + grid_points[1:-1, [0, -1], 1] = inner_points[:, [0, -1], 1] + + # Compute 'horizontal' sides + grid_points[[0, -1], 1:-1, 0] = inner_points[[0, -1], :, 0] + grid_points[[0, -1], 1:-1, 1] = points[[0, -1], :-1, 1] + points[[0, -1], 1:, 1] - inner_points[[0, -1], :, 1] + + # Compute corners + d0, d1 = [0, 0, -1, -1], [0, -1, -1, 0] + grid_points[d0, d1] = 2 * points[d0, d1] - inner_points[d0, d1] + return grid_points + + +def _quadrilateral_grid_as_triangles(points): + """Returns the points and indices to make a grid of quadirlaterals + + :param numpy.ndarray points: + 3D array of points (height, width, 2) + :return: triangle corners (4 * N, 2), triangle indices (2 * N, 3) + With N = height * width, the number of input points + """ + nbpoints = numpy.prod(points.shape[:2]) + + grid = _quadrilateral_grid_coords(points) + coords = numpy.empty((4 * nbpoints, 2), dtype=grid.dtype) + coords[::4] = grid[:-1, :-1].reshape(-1, 2) + coords[1::4] = grid[1:, :-1].reshape(-1, 2) + coords[2::4] = grid[:-1, 1:].reshape(-1, 2) + coords[3::4] = grid[1:, 1:].reshape(-1, 2) + + indices = numpy.empty((2 * nbpoints, 3), dtype=numpy.uint32) + indices[::2, 0] = numpy.arange(0, 4 * nbpoints, 4) + indices[::2, 1] = numpy.arange(1, 4 * nbpoints, 4) + indices[::2, 2] = numpy.arange(2, 4 * nbpoints, 4) + indices[1::2, 0] = indices[::2, 1] + indices[1::2, 1] = indices[::2, 2] + indices[1::2, 2] = numpy.arange(3, 4 * nbpoints, 4) + + return coords, indices + + +_RegularGridInfo = namedtuple( + '_RegularGridInfo', ['bounds', 'origin', 'scale', 'shape', 'order']) + + +_HistogramInfo = namedtuple( + '_HistogramInfo', ['mean', 'count', 'sum', 'origin', 'scale', 'shape']) + + +class Scatter(PointsBase, ColormapMixIn, ScatterVisualizationMixIn): + """Description of a scatter""" + + _DEFAULT_SELECTABLE = True + """Default selectable state for scatter plots""" + + _SUPPORTED_SCATTER_VISUALIZATION = ( + ScatterVisualizationMixIn.Visualization.POINTS, + ScatterVisualizationMixIn.Visualization.SOLID, + ScatterVisualizationMixIn.Visualization.REGULAR_GRID, + ScatterVisualizationMixIn.Visualization.IRREGULAR_GRID, + ScatterVisualizationMixIn.Visualization.BINNED_STATISTIC, + ) + """Overrides supported Visualizations""" + + def __init__(self): + PointsBase.__init__(self) + ColormapMixIn.__init__(self) + ScatterVisualizationMixIn.__init__(self) + self._value = () + self.__alpha = None + # Cache Delaunay triangulation future object + self.__delaunayFuture = None + # Cache interpolator future object + self.__interpolatorFuture = None + self.__executor = None + + # Cache triangles: x, y, indices + self.__cacheTriangles = None, None, None + + # Cache regular grid and histogram info + self.__cacheRegularGridInfo = None + self.__cacheHistogramInfo = None + + def _updateColormappedData(self): + """Update the colormapped data, to be called when changed""" + if self.getVisualization() is self.Visualization.BINNED_STATISTIC: + histoInfo = self.__getHistogramInfo() + if histoInfo is None: + data = None + else: + data = getattr( + histoInfo, + self.getVisualizationParameter( + self.VisualizationParameter.BINNED_STATISTIC_FUNCTION)) + else: + data = self.getValueData(copy=False) + self._setColormappedData(data, copy=False) + + @docstring(ScatterVisualizationMixIn) + def setVisualization(self, mode): + previous = self.getVisualization() + if super().setVisualization(mode): + if (bool(mode is self.Visualization.BINNED_STATISTIC) ^ + bool(previous is self.Visualization.BINNED_STATISTIC)): + self._updateColormappedData() + return True + else: + return False + + @docstring(ScatterVisualizationMixIn) + def setVisualizationParameter(self, parameter, value): + parameter = self.VisualizationParameter.from_value(parameter) + + if super(Scatter, self).setVisualizationParameter(parameter, value): + if parameter in (self.VisualizationParameter.GRID_BOUNDS, + self.VisualizationParameter.GRID_MAJOR_ORDER, + self.VisualizationParameter.GRID_SHAPE): + self.__cacheRegularGridInfo = None + + if parameter in (self.VisualizationParameter.BINNED_STATISTIC_SHAPE, + self.VisualizationParameter.BINNED_STATISTIC_FUNCTION, + self.VisualizationParameter.DATA_BOUNDS_HINT): + if parameter in (self.VisualizationParameter.BINNED_STATISTIC_SHAPE, + self.VisualizationParameter.DATA_BOUNDS_HINT): + self.__cacheHistogramInfo = None # Clean-up cache + if self.getVisualization() is self.Visualization.BINNED_STATISTIC: + self._updateColormappedData() + return True + else: + return False + + @docstring(ScatterVisualizationMixIn) + def getCurrentVisualizationParameter(self, parameter): + value = self.getVisualizationParameter(parameter) + if (parameter is self.VisualizationParameter.DATA_BOUNDS_HINT or + value is not None): + return value # Value has been set, return it + + elif parameter is self.VisualizationParameter.GRID_BOUNDS: + grid = self.__getRegularGridInfo() + return None if grid is None else grid.bounds + + elif parameter is self.VisualizationParameter.GRID_MAJOR_ORDER: + grid = self.__getRegularGridInfo() + return None if grid is None else grid.order + + elif parameter is self.VisualizationParameter.GRID_SHAPE: + grid = self.__getRegularGridInfo() + return None if grid is None else grid.shape + + elif parameter is self.VisualizationParameter.BINNED_STATISTIC_SHAPE: + info = self.__getHistogramInfo() + return None if info is None else info.shape + + else: + raise NotImplementedError() + + def __getRegularGridInfo(self): + """Get grid info""" + if self.__cacheRegularGridInfo is None: + shape = self.getVisualizationParameter( + self.VisualizationParameter.GRID_SHAPE) + order = self.getVisualizationParameter( + self.VisualizationParameter.GRID_MAJOR_ORDER) + if shape is None or order is None: + guess = _guess_grid(self.getXData(copy=False), + self.getYData(copy=False)) + if guess is None: + _logger.warning( + 'Cannot guess a grid: Cannot display as regular grid image') + return None + if shape is None: + shape = guess[1] + if order is None: + order = guess[0] + + nbpoints = len(self.getXData(copy=False)) + if nbpoints > shape[0] * shape[1]: + # More data points that provided grid shape: enlarge grid + _logger.warning( + "More data points than provided grid shape size: extends grid") + dim0, dim1 = shape + if order == 'row': # keep dim1, enlarge dim0 + dim0 = nbpoints // dim1 + (1 if nbpoints % dim1 else 0) + else: # keep dim0, enlarge dim1 + dim1 = nbpoints // dim0 + (1 if nbpoints % dim0 else 0) + shape = dim0, dim1 + + bounds = self.getVisualizationParameter( + self.VisualizationParameter.GRID_BOUNDS) + if bounds is None: + x, y = self.getXData(copy=False), self.getYData(copy=False) + min_, max_ = min_max(x) + xRange = (min_, max_) if (x[0] - min_) < (max_ - x[0]) else (max_, min_) + min_, max_ = min_max(y) + yRange = (min_, max_) if (y[0] - min_) < (max_ - y[0]) else (max_, min_) + bounds = (xRange[0], yRange[0]), (xRange[1], yRange[1]) + + begin, end = bounds + scale = ((end[0] - begin[0]) / max(1, shape[1] - 1), + (end[1] - begin[1]) / max(1, shape[0] - 1)) + if scale[0] == 0 and scale[1] == 0: + scale = 1., 1. + elif scale[0] == 0: + scale = scale[1], scale[1] + elif scale[1] == 0: + scale = scale[0], scale[0] + + origin = begin[0] - 0.5 * scale[0], begin[1] - 0.5 * scale[1] + + self.__cacheRegularGridInfo = _RegularGridInfo( + bounds=bounds, origin=origin, scale=scale, shape=shape, order=order) + + return self.__cacheRegularGridInfo + + def __getHistogramInfo(self): + """Get histogram info""" + if self.__cacheHistogramInfo is None: + shape = self.getVisualizationParameter( + self.VisualizationParameter.BINNED_STATISTIC_SHAPE) + if shape is None: + shape = 100, 100 # TODO compute auto shape + + x, y, values = self.getData(copy=False)[:3] + if len(x) == 0: # No histogram + return None + + if not numpy.issubdtype(x.dtype, numpy.floating): + x = x.astype(numpy.float64) + if not numpy.issubdtype(y.dtype, numpy.floating): + y = y.astype(numpy.float64) + if not numpy.issubdtype(values.dtype, numpy.floating): + values = values.astype(numpy.float64) + + ranges = (tuple(min_max(y, finite=True)), + tuple(min_max(x, finite=True))) + rangesHint = self.getVisualizationParameter( + self.VisualizationParameter.DATA_BOUNDS_HINT) + if rangesHint is not None: + ranges = tuple((min(dataMin, hintMin), max(dataMax, hintMax)) + for (dataMin, dataMax), (hintMin, hintMax) in zip(ranges, rangesHint)) + + points = numpy.transpose(numpy.array((y, x))) + counts, sums, bin_edges = Histogramnd( + points, + histo_range=ranges, + n_bins=shape, + weights=values) + yEdges, xEdges = bin_edges + origin = xEdges[0], yEdges[0] + scale = ((xEdges[-1] - xEdges[0]) / (len(xEdges) - 1), + (yEdges[-1] - yEdges[0]) / (len(yEdges) - 1)) + + with numpy.errstate(divide='ignore', invalid='ignore'): + histo = sums / counts + + self.__cacheHistogramInfo = _HistogramInfo( + mean=histo, count=counts, sum=sums, + origin=origin, scale=scale, shape=shape) + + return self.__cacheHistogramInfo + + def __applyColormapToData(self): + """Compute colors by applying colormap to values. + + :returns: Array of RGBA colors + """ + cmap = self.getColormap() + rgbacolors = cmap.applyToData(self) + + if self.__alpha is not None: + rgbacolors[:, -1] = (rgbacolors[:, -1] * self.__alpha).astype(numpy.uint8) + return rgbacolors + + def _addBackendRenderer(self, backend): + """Update backend renderer""" + # Filter-out values <= 0 + xFiltered, yFiltered, valueFiltered, xerror, yerror = self.getData( + copy=False, displayed=True) + + # Remove not finite numbers (this includes filtered out x, y <= 0) + mask = numpy.logical_and(numpy.isfinite(xFiltered), numpy.isfinite(yFiltered)) + xFiltered = xFiltered[mask] + yFiltered = yFiltered[mask] + + if len(xFiltered) == 0: + return None # No data to display, do not add renderer to backend + + visualization = self.getVisualization() + + if visualization is self.Visualization.BINNED_STATISTIC: + plot = self.getPlot() + if (plot is None or + plot.getXAxis().getScale() != Axis.LINEAR or + plot.getYAxis().getScale() != Axis.LINEAR): + # Those visualizations are not available with log scaled axes + return None + + histoInfo = self.__getHistogramInfo() + if histoInfo is None: + return None + data = getattr(histoInfo, self.getVisualizationParameter( + self.VisualizationParameter.BINNED_STATISTIC_FUNCTION)) + + return backend.addImage( + data=data, + origin=histoInfo.origin, + scale=histoInfo.scale, + colormap=self.getColormap(), + alpha=self.getAlpha()) + + elif visualization is self.Visualization.POINTS: + rgbacolors = self.__applyColormapToData() + return backend.addCurve(xFiltered, yFiltered, + color=rgbacolors[mask], + symbol=self.getSymbol(), + linewidth=0, + linestyle="", + yaxis='left', + xerror=xerror, + yerror=yerror, + fill=False, + alpha=self.getAlpha(), + symbolsize=self.getSymbolSize(), + baseline=None) + + else: + plot = self.getPlot() + if (plot is None or + plot.getXAxis().getScale() != Axis.LINEAR or + plot.getYAxis().getScale() != Axis.LINEAR): + # Those visualizations are not available with log scaled axes + return None + + if visualization is self.Visualization.SOLID: + triangulation = self._getDelaunay().result() + if triangulation is None: + _logger.warning( + 'Cannot get a triangulation: Cannot display as solid surface') + return None + else: + rgbacolors = self.__applyColormapToData() + triangles = triangulation.simplices.astype(numpy.int32) + return backend.addTriangles(xFiltered, + yFiltered, + triangles, + color=rgbacolors[mask], + alpha=self.getAlpha()) + + elif visualization is self.Visualization.REGULAR_GRID: + gridInfo = self.__getRegularGridInfo() + if gridInfo is None: + return None + + dim0, dim1 = gridInfo.shape + if gridInfo.order == 'column': # transposition needed + dim0, dim1 = dim1, dim0 + + values = self.getValueData(copy=False) + if self.__alpha is None and len(values) == dim0 * dim1: + image = values.reshape(dim0, dim1) + else: + # The points do not fill the whole image + if (self.__alpha is None and + numpy.issubdtype(values.dtype, numpy.floating)): + image = numpy.empty(dim0 * dim1, dtype=values.dtype) + image[:len(values)] = values + image[len(values):] = float('nan') # Transparent pixels + image.shape = dim0, dim1 + else: # Per value alpha or no NaN, so convert to RGBA + rgbacolors = self.__applyColormapToData() + image = numpy.empty((dim0 * dim1, 4), dtype=numpy.uint8) + image[:len(rgbacolors)] = rgbacolors + image[len(rgbacolors):] = (0, 0, 0, 0) # Transparent pixels + image.shape = dim0, dim1, 4 + + if gridInfo.order == 'column': + if image.ndim == 2: + image = numpy.transpose(image) + else: + image = numpy.transpose(image, axes=(1, 0, 2)) + + if image.ndim == 2: + colormap = self.getColormap() + if colormap.isAutoscale(): + # Avoid backend to compute autoscale: use item cache + colormap = colormap.copy() + colormap.setVRange(*colormap.getColormapRange(self)) + else: + colormap = None + + return backend.addImage( + data=image, + origin=gridInfo.origin, + scale=gridInfo.scale, + colormap=colormap, + alpha=self.getAlpha()) + + elif visualization is self.Visualization.IRREGULAR_GRID: + gridInfo = self.__getRegularGridInfo() + if gridInfo is None: + return None + + shape = gridInfo.shape + if shape is None: # No shape, no display + return None + + rgbacolors = self.__applyColormapToData() + + nbpoints = len(xFiltered) + if nbpoints == 1: + # single point, render as a square points + return backend.addCurve(xFiltered, yFiltered, + color=rgbacolors[mask], + symbol='s', + linewidth=0, + linestyle="", + yaxis='left', + xerror=None, + yerror=None, + fill=False, + alpha=self.getAlpha(), + symbolsize=7, + baseline=None) + + # Make shape include all points + gridOrder = gridInfo.order + if nbpoints != numpy.prod(shape): + if gridOrder == 'row': + shape = int(numpy.ceil(nbpoints / shape[1])), shape[1] + else: # column-major order + shape = shape[0], int(numpy.ceil(nbpoints / shape[0])) + + if shape[0] < 2 or shape[1] < 2: # Single line, at least 2 points + points = numpy.ones((2, nbpoints, 2), dtype=numpy.float64) + # Use row/column major depending on shape, not on info value + gridOrder = 'row' if shape[0] == 1 else 'column' + + if gridOrder == 'row': + points[0, :, 0] = xFiltered + points[0, :, 1] = yFiltered + else: # column-major order + points[0, :, 0] = yFiltered + points[0, :, 1] = xFiltered + + # Add a second line that will be clipped in the end + points[1, :-1] = points[0, :-1] + numpy.cross( + points[0, 1:] - points[0, :-1], (0., 0., 1.))[:, :2] + points[1, -1] = points[0, -1] + numpy.cross( + points[0, -1] - points[0, -2], (0., 0., 1.))[:2] + + points.shape = 2, nbpoints, 2 # Use same shape for both orders + coords, indices = _quadrilateral_grid_as_triangles(points) + + elif gridOrder == 'row': # row-major order + if nbpoints != numpy.prod(shape): + points = numpy.empty((numpy.prod(shape), 2), dtype=numpy.float64) + points[:nbpoints, 0] = xFiltered + points[:nbpoints, 1] = yFiltered + # Index of last element of last fully filled row + index = (nbpoints // shape[1]) * shape[1] + points[nbpoints:, 0] = xFiltered[index - (numpy.prod(shape) - nbpoints):index] + points[nbpoints:, 1] = yFiltered[-1] + else: + points = numpy.transpose((xFiltered, yFiltered)) + points.shape = shape[0], shape[1], 2 + + else: # column-major order + if nbpoints != numpy.prod(shape): + points = numpy.empty((numpy.prod(shape), 2), dtype=numpy.float64) + points[:nbpoints, 0] = yFiltered + points[:nbpoints, 1] = xFiltered + # Index of last element of last fully filled column + index = (nbpoints // shape[0]) * shape[0] + points[nbpoints:, 0] = yFiltered[index - (numpy.prod(shape) - nbpoints):index] + points[nbpoints:, 1] = xFiltered[-1] + else: + points = numpy.transpose((yFiltered, xFiltered)) + points.shape = shape[1], shape[0], 2 + + coords, indices = _quadrilateral_grid_as_triangles(points) + + # Remove unused extra triangles + coords = coords[:4*nbpoints] + indices = indices[:2*nbpoints] + + if gridOrder == 'row': + x, y = coords[:, 0], coords[:, 1] + else: # column-major order + y, x = coords[:, 0], coords[:, 1] + + rgbacolors = rgbacolors[mask] # Filter-out not finite points + gridcolors = numpy.empty( + (4 * nbpoints, rgbacolors.shape[-1]), dtype=rgbacolors.dtype) + for first in range(4): + gridcolors[first::4] = rgbacolors[:nbpoints] + + return backend.addTriangles(x, + y, + indices, + color=gridcolors, + alpha=self.getAlpha()) + + else: + _logger.error("Unhandled visualization %s", visualization) + return None + + @docstring(PointsBase) + def pick(self, x, y): + result = super(Scatter, self).pick(x, y) + + if result is not None: + visualization = self.getVisualization() + + if visualization is self.Visualization.IRREGULAR_GRID: + # Specific handling of picking for the irregular grid mode + index = result.getIndices(copy=False)[0] // 4 + result = PickingResult(self, (index,)) + + elif visualization is self.Visualization.REGULAR_GRID: + # Specific handling of picking for the regular grid mode + picked = result.getIndices(copy=False) + if picked is None: + return None + row, column = picked[0][0], picked[1][0] + + gridInfo = self.__getRegularGridInfo() + if gridInfo is None: + return None + + if gridInfo.order == 'row': + index = row * gridInfo.shape[1] + column + else: + index = row + column * gridInfo.shape[0] + if index >= len(self.getXData(copy=False)): # OK as long as not log scale + return None # Image can be larger than scatter + + result = PickingResult(self, (index,)) + + elif visualization is self.Visualization.BINNED_STATISTIC: + picked = result.getIndices(copy=False) + if picked is None or len(picked) == 0 or len(picked[0]) == 0: + return None + row, col = picked[0][0], picked[1][0] + histoInfo = self.__getHistogramInfo() + if histoInfo is None: + return None + sx, sy = histoInfo.scale + ox, oy = histoInfo.origin + xdata = self.getXData(copy=False) + ydata = self.getYData(copy=False) + indices = numpy.nonzero(numpy.logical_and( + numpy.logical_and(xdata >= ox + sx * col, xdata < ox + sx * (col + 1)), + numpy.logical_and(ydata >= oy + sy * row, ydata < oy + sy * (row + 1))))[0] + result = None if len(indices) == 0 else PickingResult(self, indices) + + return result + + def __getExecutor(self): + """Returns async greedy executor + + :rtype: _GreedyThreadPoolExecutor + """ + if self.__executor is None: + self.__executor = _GreedyThreadPoolExecutor(max_workers=2) + return self.__executor + + def _getDelaunay(self): + """Returns a :class:`Future` which result is the Delaunay object. + + :rtype: concurrent.futures.Future + """ + if self.__delaunayFuture is None or self.__delaunayFuture.cancelled(): + # Need to init a new delaunay + x, y = self.getData(copy=False)[:2] + # Remove not finite points + mask = numpy.logical_and(numpy.isfinite(x), numpy.isfinite(y)) + + self.__delaunayFuture = self.__getExecutor().submit_greedy( + 'delaunay', delaunay, x[mask], y[mask]) + + return self.__delaunayFuture + + @staticmethod + def __initInterpolator(delaunayFuture, values): + """Returns an interpolator for the given data points + + :param concurrent.futures.Future delaunayFuture: + Future object which result is a Delaunay object + :param numpy.ndarray values: The data value of valid points. + :rtype: Union[callable,None] + """ + # Wait for Delaunay to complete + try: + triangulation = delaunayFuture.result() + except CancelledError: + triangulation = None + + if triangulation is None: + interpolator = None # Error case + else: + # Lazy-loading of interpolator + try: + from scipy.interpolate import LinearNDInterpolator + except ImportError: + LinearNDInterpolator = None + + if LinearNDInterpolator is not None: + interpolator = LinearNDInterpolator(triangulation, values) + + # First call takes a while, do it here + interpolator([(0., 0.)]) + + else: + # Fallback using matplotlib interpolator + import matplotlib.tri + + x, y = triangulation.points.T + tri = matplotlib.tri.Triangulation( + x, y, triangles=triangulation.simplices) + mplInterpolator = matplotlib.tri.LinearTriInterpolator( + tri, values) + + # Wrap interpolator to have same API as scipy's one + def interpolator(points): + return mplInterpolator(*points.T) + + return interpolator + + def _getInterpolator(self): + """Returns a :class:`Future` which result is the interpolator. + + The interpolator is a callable taking an array Nx2 of points + as a single argument. + The :class:`Future` result is None in case the interpolator cannot + be initialized. + + :rtype: concurrent.futures.Future + """ + if (self.__interpolatorFuture is None or + self.__interpolatorFuture.cancelled()): + # Need to init a new interpolator + x, y, values = self.getData(copy=False)[:3] + # Remove not finite points + mask = numpy.logical_and(numpy.isfinite(x), numpy.isfinite(y)) + x, y, values = x[mask], y[mask], values[mask] + + self.__interpolatorFuture = self.__getExecutor().submit_greedy( + 'interpolator', + self.__initInterpolator, self._getDelaunay(), values) + return self.__interpolatorFuture + + def _logFilterData(self, xPositive, yPositive): + """Filter out values with x or y <= 0 on log axes + + :param bool xPositive: True to filter arrays according to X coords. + :param bool yPositive: True to filter arrays according to Y coords. + :return: The filtered arrays or unchanged object if not filtering needed + :rtype: (x, y, value, xerror, yerror) + """ + # overloaded from PointsBase to filter also value. + value = self.getValueData(copy=False) + + if xPositive or yPositive: + clipped = self._getClippingBoolArray(xPositive, yPositive) + + if numpy.any(clipped): + # copy to keep original array and convert to float + value = numpy.array(value, copy=True, dtype=numpy.float64) + value[clipped] = numpy.nan + + x, y, xerror, yerror = PointsBase._logFilterData(self, xPositive, yPositive) + + return x, y, value, xerror, yerror + + def getValueData(self, copy=True): + """Returns the value assigned to the scatter data points. + + :param copy: True (Default) to get a copy, + False to use internal representation (do not modify!) + :rtype: numpy.ndarray + """ + return numpy.array(self._value, copy=copy) + + def getAlphaData(self, copy=True): + """Returns the alpha (transparency) assigned to the scatter data points. + + :param copy: True (Default) to get a copy, + False to use internal representation (do not modify!) + :rtype: numpy.ndarray + """ + return numpy.array(self.__alpha, copy=copy) + + def getData(self, copy=True, displayed=False): + """Returns the x, y coordinates and the value of the data points + + :param copy: True (Default) to get a copy, + False to use internal representation (do not modify!) + :param bool displayed: True to only get curve points that are displayed + in the plot. Default: False. + Note: If plot has log scale, negative points + are not displayed. + :returns: (x, y, value, xerror, yerror) + :rtype: 5-tuple of numpy.ndarray + """ + if displayed: + data = self._getCachedData() + if data is not None: + assert len(data) == 5 + return data + + return (self.getXData(copy), + self.getYData(copy), + self.getValueData(copy), + self.getXErrorData(copy), + self.getYErrorData(copy)) + + # reimplemented from PointsBase to handle `value` + def setData(self, x, y, value, xerror=None, yerror=None, alpha=None, copy=True): + """Set the data of the scatter. + + :param numpy.ndarray x: The data corresponding to the x coordinates. + :param numpy.ndarray y: The data corresponding to the y coordinates. + :param numpy.ndarray value: The data corresponding to the value of + the data points. + :param xerror: Values with the uncertainties on the x values + :type xerror: A float, or a numpy.ndarray of float32. + If it is an array, it can either be a 1D array of + same length as the data or a 2D array with 2 rows + of same length as the data: row 0 for positive errors, + row 1 for negative errors. + :param yerror: Values with the uncertainties on the y values + :type yerror: A float, or a numpy.ndarray of float32. See xerror. + :param alpha: Values with the transparency (between 0 and 1) + :type alpha: A float, or a numpy.ndarray of float32 + :param bool copy: True make a copy of the data (default), + False to use provided arrays. + """ + value = numpy.array(value, copy=copy) + assert value.ndim == 1 + assert len(x) == len(value) + + # Convert complex data + if numpy.iscomplexobj(value): + _logger.warning( + 'Converting value data to absolute value to plot it.') + value = numpy.absolute(value) + + # Reset triangulation and interpolator + if self.__delaunayFuture is not None: + self.__delaunayFuture.cancel() + self.__delaunayFuture = None + if self.__interpolatorFuture is not None: + self.__interpolatorFuture.cancel() + self.__interpolatorFuture = None + + # Data changed, this needs update + self.__cacheRegularGridInfo = None + self.__cacheHistogramInfo = None + + self._value = value + + if alpha is not None: + # Make sure alpha is an array of float in [0, 1] + alpha = numpy.array(alpha, copy=copy) + assert alpha.ndim == 1 + assert len(x) == len(alpha) + if alpha.dtype.kind != 'f': + alpha = alpha.astype(numpy.float32) + if numpy.any(numpy.logical_or(alpha < 0., alpha > 1.)): + alpha = numpy.clip(alpha, 0., 1.) + self.__alpha = alpha + + # set x, y, xerror, yerror + + # call self._updated + plot._invalidateDataRange() + PointsBase.setData(self, x, y, xerror, yerror, copy) + + self._updateColormappedData() diff --git a/src/silx/gui/plot/items/shape.py b/src/silx/gui/plot/items/shape.py new file mode 100644 index 0000000..00ac5f5 --- /dev/null +++ b/src/silx/gui/plot/items/shape.py @@ -0,0 +1,287 @@ +# coding: utf-8 +# /*########################################################################## +# +# Copyright (c) 2017-2021 European Synchrotron Radiation Facility +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. +# +# ###########################################################################*/ +"""This module provides the :class:`Shape` item of the :class:`Plot`. +""" + +__authors__ = ["T. Vincent"] +__license__ = "MIT" +__date__ = "21/12/2018" + + +import logging + +import numpy + +from ... import colors +from .core import ( + Item, DataItem, + ColorMixIn, FillMixIn, ItemChangedType, LineMixIn, YAxisMixIn) + + +_logger = logging.getLogger(__name__) + + +# TODO probably make one class for each kind of shape +# TODO check fill:polygon/polyline + fill = duplicated +class Shape(Item, ColorMixIn, FillMixIn, LineMixIn): + """Description of a shape item + + :param str type_: The type of shape in: + 'hline', 'polygon', 'rectangle', 'vline', 'polylines' + """ + + def __init__(self, type_): + Item.__init__(self) + ColorMixIn.__init__(self) + FillMixIn.__init__(self) + LineMixIn.__init__(self) + self._overlay = False + assert type_ in ('hline', 'polygon', 'rectangle', 'vline', 'polylines') + self._type = type_ + self._points = () + self._lineBgColor = None + + self._handle = None + + def _addBackendRenderer(self, backend): + """Update backend renderer""" + points = self.getPoints(copy=False) + x, y = points.T[0], points.T[1] + return backend.addShape(x, + y, + shape=self.getType(), + color=self.getColor(), + fill=self.isFill(), + overlay=self.isOverlay(), + linestyle=self.getLineStyle(), + linewidth=self.getLineWidth(), + linebgcolor=self.getLineBgColor()) + + def isOverlay(self): + """Return true if shape is drawn as an overlay + + :rtype: bool + """ + return self._overlay + + def setOverlay(self, overlay): + """Set the overlay state of the shape + + :param bool overlay: True to make it an overlay + """ + overlay = bool(overlay) + if overlay != self._overlay: + self._overlay = overlay + self._updated(ItemChangedType.OVERLAY) + + def getType(self): + """Returns the type of shape to draw. + + One of: 'hline', 'polygon', 'rectangle', 'vline', 'polylines' + + :rtype: str + """ + return self._type + + def getPoints(self, copy=True): + """Get the control points of the shape. + + :param bool copy: True (Default) to get a copy, + False to use internal representation (do not modify!) + :return: Array of point coordinates + :rtype: numpy.ndarray with 2 dimensions + """ + return numpy.array(self._points, copy=copy) + + def setPoints(self, points, copy=True): + """Set the point coordinates + + :param numpy.ndarray points: Array of point coordinates + :param bool copy: True (Default) to get a copy, + False to use internal representation (do not modify!) + :return: + """ + self._points = numpy.array(points, copy=copy) + self._updated(ItemChangedType.DATA) + + def getLineBgColor(self): + """Returns the RGBA color of the item + :rtype: 4-tuple of float in [0, 1] or array of colors + """ + return self._lineBgColor + + def setLineBgColor(self, color, copy=True): + """Set item color + :param color: color(s) to be used + :type color: str ("#RRGGBB") or (npoints, 4) unsigned byte array or + one of the predefined color names defined in colors.py + :param bool copy: True (Default) to get a copy, + False to use internal representation (do not modify!) + """ + if color is not None: + if isinstance(color, str): + color = colors.rgba(color) + else: + color = numpy.array(color, copy=copy) + # TODO more checks + improve color array support + if color.ndim == 1: # Single RGBA color + color = colors.rgba(color) + else: # Array of colors + assert color.ndim == 2 + + self._lineBgColor = color + self._updated(ItemChangedType.LINE_BG_COLOR) + + +class BoundingRect(DataItem, YAxisMixIn): + """An invisible shape which enforce the plot view to display the defined + space on autoscale. + + This item do not display anything. But if the visible property is true, + this bounding box is used by the plot, if not, the bounding box is + ignored. That's the default behaviour for plot items. + + It can be applied on the "left" or "right" axes. Not both at the same time. + """ + + def __init__(self): + DataItem.__init__(self) + YAxisMixIn.__init__(self) + self.__bounds = None + + def setBounds(self, rect): + """Set the bounding box of this item in data coordinates + + :param Union[None,List[float]] rect: (xmin, xmax, ymin, ymax) or None + """ + if rect is not None: + rect = float(rect[0]), float(rect[1]), float(rect[2]), float(rect[3]) + assert rect[0] <= rect[1] + assert rect[2] <= rect[3] + + if rect != self.__bounds: + self.__bounds = rect + self._boundsChanged() + self._updated(ItemChangedType.DATA) + + def _getBounds(self): + if self.__bounds is None: + return None + plot = self.getPlot() + if plot is not None: + xPositive = plot.getXAxis()._isLogarithmic() + yPositive = plot.getYAxis()._isLogarithmic() + if xPositive or yPositive: + bounds = list(self.__bounds) + if xPositive and bounds[1] <= 0: + return None + if xPositive and bounds[0] <= 0: + bounds[0] = bounds[1] + if yPositive and bounds[3] <= 0: + return None + if yPositive and bounds[2] <= 0: + bounds[2] = bounds[3] + return tuple(bounds) + + return self.__bounds + + +class _BaseExtent(DataItem): + """Base class for :class:`XAxisExtent` and :class:`YAxisExtent`. + + :param str axis: Either 'x' or 'y'. + """ + + def __init__(self, axis='x'): + assert axis in ('x', 'y') + DataItem.__init__(self) + self.__axis = axis + self.__range = 1., 100. + + def setRange(self, min_, max_): + """Set the range of the extent of this item in data coordinates. + + :param float min_: Lower bound of the extent + :param float max_: Upper bound of the extent + :raises ValueError: If min > max or not finite bounds + """ + range_ = float(min_), float(max_) + if not numpy.all(numpy.isfinite(range_)): + raise ValueError("min_ and max_ must be finite numbers.") + if range_[0] > range_[1]: + raise ValueError("min_ must be lesser or equal to max_") + + if range_ != self.__range: + self.__range = range_ + self._boundsChanged() + self._updated(ItemChangedType.DATA) + + def getRange(self): + """Returns the range (min, max) of the extent in data coordinates. + + :rtype: List[float] + """ + return self.__range + + def _getBounds(self): + min_, max_ = self.getRange() + + plot = self.getPlot() + if plot is not None: + axis = plot.getXAxis() if self.__axis == 'x' else plot.getYAxis() + if axis._isLogarithmic(): + if max_ <= 0: + return None + if min_ <= 0: + min_ = max_ + + if self.__axis == 'x': + return min_, max_, float('nan'), float('nan') + else: + return float('nan'), float('nan'), min_, max_ + + +class XAxisExtent(_BaseExtent): + """Invisible item with a settable horizontal data extent. + + This item do not display anything, but it behaves as a data + item with a horizontal extent regarding plot data bounds, i.e., + :meth:`PlotWidget.resetZoom` will take this horizontal extent into account. + """ + def __init__(self): + _BaseExtent.__init__(self, axis='x') + + +class YAxisExtent(_BaseExtent, YAxisMixIn): + """Invisible item with a settable vertical data extent. + + This item do not display anything, but it behaves as a data + item with a vertical extent regarding plot data bounds, i.e., + :meth:`PlotWidget.resetZoom` will take this vertical extent into account. + """ + + def __init__(self): + _BaseExtent.__init__(self, axis='y') + YAxisMixIn.__init__(self) diff --git a/src/silx/gui/plot/matplotlib/Colormap.py b/src/silx/gui/plot/matplotlib/Colormap.py new file mode 100644 index 0000000..dc432b2 --- /dev/null +++ b/src/silx/gui/plot/matplotlib/Colormap.py @@ -0,0 +1,249 @@ +# coding: utf-8 +# /*########################################################################## +# Copyright (C) 2017-2020 European Synchrotron Radiation Facility +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. +# +# ############################################################################*/ +"""Matplotlib's new colormaps""" + +import numpy +import logging +from matplotlib.colors import ListedColormap +import matplotlib.colors +import matplotlib.cm +import silx.resources +from silx.utils.deprecation import deprecated, deprecated_warning + + +deprecated_warning(type_='module', + name=__file__, + replacement='silx.gui.colors.Colormap', + since_version='0.10.0') + + +_logger = logging.getLogger(__name__) + +_AVAILABLE_AS_RESOURCE = ('magma', 'inferno', 'plasma', 'viridis') +"""List available colormap name as resources""" + +_AVAILABLE_AS_BUILTINS = ('gray', 'reversed gray', + 'temperature', 'red', 'green', 'blue') +"""List of colormaps available through built-in declarations""" + +_CMAPS = {} +"""Cache colormaps""" + + +@property +@deprecated(since_version='0.10.0') +def magma(): + return getColormap('magma') + + +@property +@deprecated(since_version='0.10.0') +def inferno(): + return getColormap('inferno') + + +@property +@deprecated(since_version='0.10.0') +def plasma(): + return getColormap('plasma') + + +@property +@deprecated(since_version='0.10.0') +def viridis(): + return getColormap('viridis') + + +@deprecated(since_version='0.10.0') +def getColormap(name): + """Returns matplotlib colormap corresponding to given name + + :param str name: The name of the colormap + :return: The corresponding colormap + :rtype: matplolib.colors.Colormap + """ + if not _CMAPS: # Lazy initialization of own colormaps + cdict = {'red': ((0.0, 0.0, 0.0), + (1.0, 1.0, 1.0)), + 'green': ((0.0, 0.0, 0.0), + (1.0, 0.0, 0.0)), + 'blue': ((0.0, 0.0, 0.0), + (1.0, 0.0, 0.0))} + _CMAPS['red'] = matplotlib.colors.LinearSegmentedColormap( + 'red', cdict, 256) + + cdict = {'red': ((0.0, 0.0, 0.0), + (1.0, 0.0, 0.0)), + 'green': ((0.0, 0.0, 0.0), + (1.0, 1.0, 1.0)), + 'blue': ((0.0, 0.0, 0.0), + (1.0, 0.0, 0.0))} + _CMAPS['green'] = matplotlib.colors.LinearSegmentedColormap( + 'green', cdict, 256) + + cdict = {'red': ((0.0, 0.0, 0.0), + (1.0, 0.0, 0.0)), + 'green': ((0.0, 0.0, 0.0), + (1.0, 0.0, 0.0)), + 'blue': ((0.0, 0.0, 0.0), + (1.0, 1.0, 1.0))} + _CMAPS['blue'] = matplotlib.colors.LinearSegmentedColormap( + 'blue', cdict, 256) + + # Temperature as defined in spslut + cdict = {'red': ((0.0, 0.0, 0.0), + (0.5, 0.0, 0.0), + (0.75, 1.0, 1.0), + (1.0, 1.0, 1.0)), + 'green': ((0.0, 0.0, 0.0), + (0.25, 1.0, 1.0), + (0.75, 1.0, 1.0), + (1.0, 0.0, 0.0)), + 'blue': ((0.0, 1.0, 1.0), + (0.25, 1.0, 1.0), + (0.5, 0.0, 0.0), + (1.0, 0.0, 0.0))} + # but limited to 256 colors for a faster display (of the colorbar) + _CMAPS['temperature'] = \ + matplotlib.colors.LinearSegmentedColormap( + 'temperature', cdict, 256) + + # reversed gray + cdict = {'red': ((0.0, 1.0, 1.0), + (1.0, 0.0, 0.0)), + 'green': ((0.0, 1.0, 1.0), + (1.0, 0.0, 0.0)), + 'blue': ((0.0, 1.0, 1.0), + (1.0, 0.0, 0.0))} + + _CMAPS['reversed gray'] = \ + matplotlib.colors.LinearSegmentedColormap( + 'yerg', cdict, 256) + + if name in _CMAPS: + return _CMAPS[name] + elif name in _AVAILABLE_AS_RESOURCE: + filename = silx.resources.resource_filename("gui/colormaps/%s.npy" % name) + data = numpy.load(filename) + lut = ListedColormap(data, name=name) + _CMAPS[name] = lut + return lut + else: + # matplotlib built-in + return matplotlib.cm.get_cmap(name) + + +@deprecated(since_version='0.10.0') +def getScalarMappable(colormap, data=None): + """Returns matplotlib ScalarMappable corresponding to colormap + + :param :class:`.Colormap` colormap: The colormap to convert + :param numpy.ndarray data: + The data on which the colormap is applied. + If provided, it is used to compute autoscale. + :return: matplotlib object corresponding to colormap + :rtype: matplotlib.cm.ScalarMappable + """ + assert colormap is not None + + if colormap.getName() is not None: + cmap = getColormap(colormap.getName()) + + else: # No name, use custom colors + if colormap.getColormapLUT() is None: + raise ValueError( + 'addImage: colormap no name nor list of colors.') + colors = colormap.getColormapLUT() + assert len(colors.shape) == 2 + assert colors.shape[-1] in (3, 4) + if colors.dtype == numpy.uint8: + # Convert to float in [0., 1.] + colors = colors.astype(numpy.float32) / 255. + cmap = matplotlib.colors.ListedColormap(colors) + + vmin, vmax = colormap.getColormapRange(data) + normalization = colormap.getNormalization() + if normalization == colormap.LOGARITHM: + norm = matplotlib.colors.LogNorm(vmin, vmax) + elif normalization == colormap.LINEAR: + norm = matplotlib.colors.Normalize(vmin, vmax) + else: + raise RuntimeError("Unsupported normalization: %s" % normalization) + + return matplotlib.cm.ScalarMappable(norm=norm, cmap=cmap) + + +@deprecated(replacement='silx.colors.Colormap.applyToData', + since_version='0.8.0') +def applyColormapToData(data, colormap): + """Apply a colormap to the data and returns the RGBA image + + This supports data of any dimensions (not only of dimension 2). + The returned array will have one more dimension (with 4 entries) + than the input data to store the RGBA channels + corresponding to each bin in the array. + + :param numpy.ndarray data: The data to convert. + :param :class:`.Colormap`: The colormap to apply + """ + # Debian 7 specific support + # No transparent colormap with matplotlib < 1.2.0 + # Add support for transparent colormap for uint8 data with + # colormap with 256 colors, linear norm, [0, 255] range + if matplotlib.__version__ < '1.2.0': + if (colormap.getName() is None and + colormap.getColormapLUT() is not None): + colors = colormap.getColormapLUT() + if (colors.shape[-1] == 4 and + not numpy.all(numpy.equal(colors[3], 255))): + # This is a transparent colormap + if (colors.shape == (256, 4) and + colormap.getNormalization() == 'linear' and + not colormap.isAutoscale() and + colormap.getVMin() == 0 and + colormap.getVMax() == 255 and + data.dtype == numpy.uint8): + # Supported case, convert data to RGBA + return colors[data.reshape(-1)].reshape( + data.shape + (4,)) + else: + _logger.warning( + 'matplotlib %s does not support transparent ' + 'colormap.', matplotlib.__version__) + + scalarMappable = getScalarMappable(colormap, data) + rgbaImage = scalarMappable.to_rgba(data, bytes=True) + + return rgbaImage + + +@deprecated(replacement='silx.colors.Colormap.getSupportedColormaps', + since_version='0.10.0') +def getSupportedColormaps(): + """Get the supported colormap names as a tuple of str. + """ + colormaps = set(matplotlib.cm.datad.keys()) + colormaps.update(_AVAILABLE_AS_BUILTINS) + colormaps.update(_AVAILABLE_AS_RESOURCE) + return tuple(sorted(colormaps)) diff --git a/src/silx/gui/plot/matplotlib/__init__.py b/src/silx/gui/plot/matplotlib/__init__.py new file mode 100644 index 0000000..e787240 --- /dev/null +++ b/src/silx/gui/plot/matplotlib/__init__.py @@ -0,0 +1,37 @@ +# coding: utf-8 +# /*########################################################################## +# +# Copyright (c) 2020 European Synchrotron Radiation Facility +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. +# +# ###########################################################################*/ + +__authors__ = ["T. Vincent"] +__license__ = "MIT" +__date__ = "15/07/2020" + +from silx.utils.deprecation import deprecated_warning + +deprecated_warning(type_='module', + name=__file__, + replacement='silx.gui.utils.matplotlib', + since_version='0.14.0') + +from silx.gui.utils.matplotlib import FigureCanvasQTAgg # noqa diff --git a/src/silx/gui/plot/setup.py b/src/silx/gui/plot/setup.py new file mode 100644 index 0000000..e0b2c91 --- /dev/null +++ b/src/silx/gui/plot/setup.py @@ -0,0 +1,54 @@ +# coding: utf-8 +# /*########################################################################## +# +# Copyright (c) 2016-2018 European Synchrotron Radiation Facility +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. +# +# ###########################################################################*/ +__authors__ = ["T. Vincent"] +__license__ = "MIT" +__date__ = "29/06/2017" + + +from numpy.distutils.misc_util import Configuration + + +def configuration(parent_package='', top_path=None): + config = Configuration('plot', parent_package, top_path) + config.add_subpackage('_utils') + config.add_subpackage('utils') + config.add_subpackage('matplotlib') + config.add_subpackage('stats') + config.add_subpackage('backends') + config.add_subpackage('backends.glutils') + config.add_subpackage('items') + config.add_subpackage('test') + config.add_subpackage('tools') + config.add_subpackage('tools.profile') + config.add_subpackage('tools.test') + config.add_subpackage('actions') + + return config + + +if __name__ == "__main__": + from numpy.distutils.core import setup + + setup(configuration=configuration) diff --git a/src/silx/gui/plot/stats/__init__.py b/src/silx/gui/plot/stats/__init__.py new file mode 100644 index 0000000..04a5327 --- /dev/null +++ b/src/silx/gui/plot/stats/__init__.py @@ -0,0 +1,33 @@ +# coding: utf-8 +# /*########################################################################## +# +# Copyright (c) 2017 European Synchrotron Radiation Facility +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. +# +# ###########################################################################*/ +""" +""" + +__authors__ = ["H. Payno"] +__license__ = "MIT" +__date__ = "07/03/2018" + + +from .stats import * diff --git a/src/silx/gui/plot/stats/stats.py b/src/silx/gui/plot/stats/stats.py new file mode 100644 index 0000000..a81f7bb --- /dev/null +++ b/src/silx/gui/plot/stats/stats.py @@ -0,0 +1,890 @@ +# coding: utf-8 +# /*########################################################################## +# +# Copyright (c) 2017-2021 European Synchrotron Radiation Facility +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. +# +# ###########################################################################*/ +"""This module provides mechanism relative to stats calculation within a +:class:`PlotWidget`. +It also include the implementation of the statistics themselves. +""" + +__authors__ = ["H. Payno"] +__license__ = "MIT" +__date__ = "06/06/2018" + + +from collections import OrderedDict +from functools import lru_cache +import logging + +import numpy +import numpy.ma + +from .. import items +from ..CurvesROIWidget import ROI +from ..items.roi import RegionOfInterest + +from ....math.combo import min_max +from silx.utils.proxy import docstring +from ....utils.deprecation import deprecated + +logger = logging.getLogger(__name__) + + +class Stats(OrderedDict): + """Class to define a set of statistic relative to a dataset + (image, curve...). + + The goal of this class is to avoid multiple recalculation of some + basic operations such as filtering data area where the statistics has to + be apply. + Min and max are also stored because they can be used several time. + + :param List statslist: List of the :class:`Stat` object to be computed. + """ + def __init__(self, statslist=None): + OrderedDict.__init__(self) + _statslist = statslist if not None else [] + if statslist is not None: + for stat in _statslist: + self.add(stat) + + def calculate(self, item, plot, onlimits, roi, data_changed=False, + roi_changed=False): + """ + Call all :class:`Stat` object registered and return the result of the + computation. + + :param item: the item for which we want statistics + :param plot: plot containing the item + :param bool onlimits: True if we want to apply statistic only on + visible data. + :param roi: region of interest for statistic calculation. Incompatible + with the `onlimits` option. + :type roi: Union[None, :class:`~_RegionOfInterestBase`] + :param bool data_changed: did the data changed since last calculation. + :param bool roi_changed: did the associated roi (if any) has changed + since last calculation. + :return dict: dictionary with :class:`Stat` name as ket and result + of the calculation as value + """ + res = {} + context = self._getContext(item=item, plot=plot, onlimits=onlimits, + roi=roi) + for statName, stat in list(self.items()): + if context.kind not in stat.compatibleKinds: + logger.debug('kind %s not managed by statistic %s' + % (context.kind, stat.name)) + res[statName] = None + else: + if roi_changed is True: + context.clear_mask() + if data_changed is True or roi_changed is True: + # if data changed or mask changed + context.clipData(item=item, plot=plot, onlimits=onlimits, + roi=roi) + # init roi and data + res[statName] = stat.calculate(context) + return res + + def __setitem__(self, key, value): + assert isinstance(value, StatBase) + OrderedDict.__setitem__(self, key, value) + + def add(self, stat): + """Add a :class:`Stat` to the set + + :param Stat stat: stat to add to the set + """ + self.__setitem__(key=stat.name, value=stat) + + @staticmethod + @lru_cache(maxsize=50) + def _getContext(item, plot, onlimits, roi): + context = None + # Check for PlotWidget items + if isinstance(item, items.Curve): + context = _CurveContext(item, plot, onlimits, roi=roi) + elif isinstance(item, items.ImageData): + context = _ImageContext(item, plot, onlimits, roi=roi) + elif isinstance(item, items.Scatter): + context = _ScatterContext(item, plot, onlimits, roi=roi) + elif isinstance(item, items.Histogram): + context = _HistogramContext(item, plot, onlimits, roi=roi) + else: + # Check for SceneWidget items + from ...plot3d import items as items3d # Lazy import + + if isinstance(item, (items3d.Scatter2D, items3d.Scatter3D)): + context = _plot3DScatterContext(item, plot, onlimits, + roi=roi) + elif isinstance(item, + (items3d.ImageData, items3d.ScalarField3D)): + context = _plot3DArrayContext(item, plot, onlimits, + roi=roi) + if context is None: + raise ValueError('Item type not managed') + return context + + +class _StatsContext(object): + """ + The context is designed to be a simple buffer and avoid repetition of + calculations that can appear during stats evaluation. + + .. warning:: this class gives access to the data to be used for computation + . It deal with filtering data visible by the user on plot. + The filtering is a simple data sub-sampling. No interpolation + is made to fit data to boundaries. + + :param item: the item for which we want to compute the context + :param str kind: the kind of the item + :param plot: the plot containing the item + :param bool onlimits: True if we want to apply statistic only on + visible data. + :param roi: Region of interest for computing the statistics. + For now, incompatible with `onlimits` calculation + :type roi: Union[None,:class:`_RegionOfInterestBase`] + """ + def __init__(self, item, kind, plot, onlimits, roi): + assert item + assert plot + assert type(onlimits) is bool + self.kind = kind + self.min = None + self.max = None + self.data = None + self.roi = None + self.onlimits = onlimits + + self.values = None + """The array of data with limit filtering if any. Is a numpy.ma.array, + meaning that it embed the mask applied by the roi if any""" + + self.axes = None + """A list of array of position on each axis. + + If the signal is an array, + then each axis has the length of that dimension, + and the order is (z, y, x) (i.e., as the array shape). + If the signal is not an array, + then each axis has the same length as the signal, + and the order is (x, y, z). + """ + + self.clipData(item, plot, onlimits, roi=roi) + + def clear_mask(self): + """ + Remove the mask to force recomputation of it on next iteration + :return: + """ + raise NotImplementedError() + + @property + def mask(self): + if self.values is not None: + assert isinstance(self.values, numpy.ma.MaskedArray) + return self.values.mask + else: + return None + + @property + def is_mask_valid(self, **kwargs): + """Return if the mask is valid for the data or need to be recomputed""" + raise NotImplementedError("Base class") + + def _set_mask_validity(self, **kwargs): + """User to set some values that allows to define the mask properties + and boundaries""" + raise NotImplementedError("Base class") + + def clipData(self, item, plot, onlimits, roi): + """Clip the data to the current mask to have accurate statistics + + Function called before computing each statistics associated to this + context. It will insure the context for the (item, plot, onlimits, roi) + is created. + + :param item: item for which we want statistics + :param plot: plot containing the statistics + :param bool onlimits: True if we want to apply statistic only on + visible data. + :param roi: Region of interest for computing the statistics. + For now, incompatible with `onlimits` calculation + :type roi: Union[None,:class:`_RegionOfInterestBase`] + """ + raise NotImplementedError("Base class") + + @deprecated(reason="context are now stored and keep during stats life." + "So this function will be called only once", + replacement="clipData", since_version="0.13.0") + def createContext(self, item, plot, onlimits, roi): + return self.clipData(item=item, plot=plot, onlimits=onlimits, + roi=roi) + + def isStructuredData(self): + """Returns True if data as an array-like structure. + + :rtype: bool + """ + if self.values is None or self.axes is None: + return False + + if numpy.prod([len(axis) for axis in self.axes]) == self.values.size: + return True + else: + # Make sure there is the right number of value in axes + for axis in self.axes: + assert len(axis) == self.values.size + return False + + def isScalarData(self): + """Returns True if data is a scalar. + + :rtype: bool + """ + if self.values is None or self.axes is None: + return False + if self.isStructuredData(): + return len(self.axes) == self.values.ndim + else: + return self.values.ndim == 1 + + def _checkContextInputs(self, item, plot, onlimits, roi): + if roi is not None and onlimits is True: + raise ValueError('Stats context is unable to manage both a ROI' + 'and the `onlimits` option') + + +class _ScatterCurveHistoMixInContext(_StatsContext): + def __init__(self, kind, item, plot, onlimits, roi): + self.clear_mask() + _StatsContext.__init__(self, item=item, kind=kind, + plot=plot, onlimits=onlimits, roi=roi) + + def _set_mask_validity(self, onlimits, from_, to_): + self._onlimits = onlimits + self._from_ = from_ + self._to_ = to_ + + def clear_mask(self): + self._onlimits = None + self._from_ = None + self._to_ = None + + def is_mask_valid(self, onlimits, from_, to_): + return (onlimits == self.onlimits and from_ == self._from_ and + to_ == self._to_) + + +class _CurveContext(_ScatterCurveHistoMixInContext): + """ + StatsContext for :class:`Curve` + + :param item: the item for which we want to compute the context + :param plot: the plot containing the item + :param bool onlimits: True if we want to apply statistic only on + visible data. + :param roi: Region of interest for computing the statistics. + For now, incompatible with `onlinits` calculation + :type roi: Union[None, :class:`ROI`] + """ + def __init__(self, item, plot, onlimits, roi): + _ScatterCurveHistoMixInContext.__init__(self, kind='curve', item=item, + plot=plot, onlimits=onlimits, + roi=roi) + + @docstring(_StatsContext) + def clipData(self, item, plot, onlimits, roi): + self._checkContextInputs(item=item, plot=plot, onlimits=onlimits, + roi=roi) + self.roi = roi + self.onlimits = onlimits + xData, yData = item.getData(copy=True)[0:2] + + if onlimits: + minX, maxX = plot.getXAxis().getLimits() + if self.is_mask_valid(onlimits=onlimits, from_=minX, to_=maxX): + mask = self.mask + else: + mask = (minX <= xData) & (xData <= maxX) + mask = mask == 0 + self._set_mask_validity(onlimits=onlimits, from_=minX, to_=maxX) + elif roi: + minX, maxX = roi.getFrom(), roi.getTo() + if self.is_mask_valid(onlimits=onlimits, from_=minX, to_=maxX): + mask = self.mask + else: + mask = (minX <= xData) & (xData <= maxX) + mask = mask == 0 + self._set_mask_validity(onlimits=onlimits, from_=minX, to_=maxX) + else: + mask = numpy.zeros_like(yData) + + mask = mask.astype(numpy.uint32) + self.xData = xData + self.yData = yData + self.values = numpy.ma.array(yData, mask=mask) + unmasked_data = self.values.compressed() + if len(unmasked_data) > 0: + self.min, self.max = min_max(unmasked_data) + else: + self.min, self.max = None, None + self.data = (xData, yData) + self.axes = (xData,) + + def _checkContextInputs(self, item, plot, onlimits, roi): + _StatsContext._checkContextInputs(self, item=item, plot=plot, + onlimits=onlimits, roi=roi) + if roi is not None and not isinstance(roi, ROI): + raise TypeError('curve `context` can ony manage 1D roi') + + +class _HistogramContext(_ScatterCurveHistoMixInContext): + """ + StatsContext for :class:`Histogram` + + :param item: the item for which we want to compute the context + :param plot: the plot containing the item + :param bool onlimits: True if we want to apply statistic only on + visible data. + :param roi: Region of interest for computing the statistics. + For now, incompatible with `onlinits` calculation + :type roi: Union[None, :class:`ROI`] + """ + def __init__(self, item, plot, onlimits, roi): + _ScatterCurveHistoMixInContext.__init__(self, kind='histogram', + item=item, plot=plot, + onlimits=onlimits, roi=roi) + + @docstring(_StatsContext) + def clipData(self, item, plot, onlimits, roi): + self._checkContextInputs(item=item, plot=plot, onlimits=onlimits, + roi=roi) + yData, edges = item.getData(copy=True)[0:2] + xData = item._revertComputeEdges(x=edges, histogramType=item.getAlignment()) + + if onlimits: + minX, maxX = plot.getXAxis().getLimits() + if self.is_mask_valid(onlimits=onlimits, from_=minX, to_=maxX): + mask = self.mask + else: + mask = (minX <= xData) & (xData <= maxX) + mask = mask == 0 + self._set_mask_validity(onlimits=onlimits, from_=minX, to_=maxX) + elif roi: + if self.is_mask_valid(onlimits=onlimits, from_=roi._fromdata, to_=roi._todata): + mask = self.mask + else: + mask = (roi._fromdata <= xData) & (xData <= roi._todata) + mask = mask == 0 + self._set_mask_validity(onlimits=onlimits, from_=roi._fromdata, + to_=roi._todata) + else: + mask = numpy.zeros_like(yData) + mask = mask.astype(numpy.uint32) + self.xData = xData + self.yData = yData + self.values = numpy.ma.array(yData, mask=(mask)) + unmasked_data = self.values.compressed() + if len(unmasked_data) > 0: + self.min, self.max = min_max(unmasked_data) + else: + self.min, self.max = None, None + self.data = (self.xData, self.yData) + self.axes = (self.xData,) + + def _checkContextInputs(self, item, plot, onlimits, roi): + _StatsContext._checkContextInputs(self, item=item, plot=plot, + onlimits=onlimits, roi=roi) + + if roi is not None and not isinstance(roi, ROI): + raise TypeError('curve `context` can ony manage 1D roi') + + +class _ScatterContext(_ScatterCurveHistoMixInContext): + """StatsContext scatter plots. + + It supports :class:`~silx.gui.plot.items.Scatter`. + + :param item: the item for which we want to compute the context + :param plot: the plot containing the item + :param bool onlimits: True if we want to apply statistic only on + visible data. + :param roi: Region of interest for computing the statistics. + For now, incompatible with `onlinits` calculation + :type roi: Union[None, :class:`ROI`] + """ + def __init__(self, item, plot, onlimits, roi): + _ScatterCurveHistoMixInContext.__init__(self, kind='scatter', + item=item, plot=plot, + onlimits=onlimits, roi=roi) + + @docstring(_ScatterCurveHistoMixInContext) + def clipData(self, item, plot, onlimits, roi): + self._checkContextInputs(item=item, plot=plot, onlimits=onlimits, + roi=roi) + valueData = item.getValueData(copy=True) + xData = item.getXData(copy=True) + yData = item.getYData(copy=True) + + if onlimits: + minX, maxX = plot.getXAxis().getLimits() + minY, maxY = plot.getYAxis().getLimits() + + # filter on X axis + valueData = valueData[(minX <= xData) & (xData <= maxX)] + yData = yData[(minX <= xData) & (xData <= maxX)] + xData = xData[(minX <= xData) & (xData <= maxX)] + # filter on Y axis + valueData = valueData[(minY <= yData) & (yData <= maxY)] + xData = xData[(minY <= yData) & (yData <= maxY)] + yData = yData[(minY <= yData) & (yData <= maxY)] + + if roi: + if self.is_mask_valid(onlimits=onlimits, from_=roi.getFrom(), + to_=roi.getTo()): + mask = self.mask + else: + mask = (xData < roi.getFrom()) | (xData > roi.getTo()) + else: + mask = numpy.zeros_like(xData) + + self.data = (xData, yData, valueData) + self.values = numpy.ma.array(valueData, mask=mask) + self.axes = (xData, yData) + + unmasked_values = self.values.compressed() + if len(unmasked_values) > 0: + self.min, self.max = min_max(unmasked_values) + else: + self.min, self.max = None, None + + def _checkContextInputs(self, item, plot, onlimits, roi): + _StatsContext._checkContextInputs(self, item=item, plot=plot, + onlimits=onlimits, roi=roi) + + if roi is not None and not isinstance(roi, ROI): + raise TypeError('curve `context` can ony manage 1D roi') + + +class _ImageContext(_StatsContext): + """StatsContext for images. + + It supports :class:`~silx.gui.plot.items.ImageData`. + + :warning: behaviour of scale images: now the statistics are computed on + the entire data array (there is no sampling in the array or + interpolation regarding the scale). + This also mean that the result can differ from what is displayed. + But I guess there is no perfect behaviour. + + :warning: `isIn` functions for image context: for now have basically a + binary approach, the pixel is in a roi or not. To have a fully + 'correct behaviour' we should add a weight on stats calculation + to moderate the pixel value. + + :param item: the item for which we want to compute the context + :param plot: the plot containing the item + :param bool onlimits: True if we want to apply statistic only on + visible data. + :param roi: Region of interest for computing the statistics. + For now, incompatible with `onlinits` calculation + :type roi: Union[None, :class:`ROI`] + """ + def __init__(self, item, plot, onlimits, roi): + self.clear_mask() + _StatsContext.__init__(self, kind='image', item=item, + plot=plot, onlimits=onlimits, roi=roi) + + def _set_mask_validity(self, xmin: float, xmax: float, ymin: float, ymax + : float): + self._mask_x_min = xmin + self._mask_x_max = xmax + self._mask_y_min = ymin + self._mask_y_max = ymax + + def clear_mask(self): + self._mask_x_min = None + self._mask_x_max = None + self._mask_y_min = None + self._mask_y_max = None + + def is_mask_valid(self, xmin, xmax, ymin, ymax): + return (xmin == self._mask_x_min and xmax == self._mask_x_max and + ymin == self._mask_y_min and ymax == self._mask_y_max) + + @docstring(_StatsContext) + def clipData(self, item, plot, onlimits, roi): + self._checkContextInputs(item=item, plot=plot, onlimits=onlimits, + roi=roi) + self.origin = item.getOrigin() + self.scale = item.getScale() + + self.data = item.getData(copy=True) + mask = numpy.zeros_like(self.data) + """mask use to know of the stat should be count in or not""" + + if onlimits: + minX, maxX = plot.getXAxis().getLimits() + minY, maxY = plot.getYAxis().getLimits() + + XMinBound = int((minX - self.origin[0]) / self.scale[0]) + YMinBound = int((minY - self.origin[1]) / self.scale[1]) + XMaxBound = int((maxX - self.origin[0]) / self.scale[0]) + YMaxBound = int((maxY - self.origin[1]) / self.scale[1]) + + XMinBound = max(XMinBound, 0) + YMinBound = max(YMinBound, 0) + + if onlimits: + if XMaxBound <= XMinBound or YMaxBound <= YMinBound: + self.data = None + else: + self.data = self.data[YMinBound:YMaxBound + 1, + XMinBound:XMaxBound + 1] + mask = numpy.zeros_like(self.data) + elif roi: + minX, maxX = 0, self.data.shape[1] + minY, maxY = 0, self.data.shape[0] + + XMinBound = max(minX, 0) + YMinBound = max(minY, 0) + XMaxBound = min(maxX, self.data.shape[1]) + YMaxBound = min(maxY, self.data.shape[0]) + + if self.is_mask_valid(xmin=XMinBound, xmax=XMaxBound, + ymin=YMinBound, ymax=YMaxBound): + mask = self.mask + else: + for x in range(XMinBound, XMaxBound): + for y in range(YMinBound, YMaxBound): + _x = (x * self.scale[0]) + self.origin[0] + _y = (y * self.scale[1]) + self.origin[1] + mask[y, x] = not roi.contains((_x, _y)) + self._set_mask_validity(xmin=XMinBound, xmax=XMaxBound, + ymin=YMinBound, ymax=YMaxBound) + self.values = numpy.ma.array(self.data, mask=mask) + if self.values.compressed().size > 0: + self.min, self.max = min_max(self.values.compressed()) + else: + self.min, self.max = None, None + + if self.values is not None: + self.axes = (self.origin[1] + self.scale[1] * numpy.arange(self.data.shape[0]), + self.origin[0] + self.scale[0] * numpy.arange(self.data.shape[1])) + + def _checkContextInputs(self, item, plot, onlimits, roi): + _StatsContext._checkContextInputs(self, item=item, plot=plot, + onlimits=onlimits, roi=roi) + + if roi is not None and not isinstance(roi, RegionOfInterest): + raise TypeError('curve `context` can ony manage 2D roi') + + +class _plot3DScatterContext(_StatsContext): + """StatsContext for 3D scatter plots. + + It supports :class:`~silx.gui.plot3d.items.Scatter2D` and + :class:`~silx.gui.plot3d.items.Scatter3D`. + + :param item: the item for which we want to compute the context + :param plot: the plot containing the item + :param bool onlimits: True if we want to apply statistic only on + visible data. + :param roi: Region of interest for computing the statistics. + For now, incompatible with `onlinits` calculation + :type roi: Union[None, :class:`ROI`] + """ + def __init__(self, item, plot, onlimits, roi): + _StatsContext.__init__(self, kind='scatter', item=item, plot=plot, + onlimits=onlimits, roi=roi) + + @docstring(_StatsContext) + def clipData(self, item, plot, onlimits, roi): + self._checkContextInputs(item=item, plot=plot, onlimits=onlimits, + roi=roi) + if onlimits: + raise RuntimeError("Unsupported plot %s" % str(plot)) + values = item.getValueData(copy=False) + if roi: + logger.warning("Roi are unsupported on volume for now") + mask = numpy.zeros_like(values) + else: + mask = numpy.zeros_like(values) + + if values is not None and len(values) > 0: + self.values = values + axes = [item.getXData(copy=False), item.getYData(copy=False)] + if self.values.ndim == 3: + axes.append(item.getZData(copy=False)) + self.axes = tuple(axes) + self.min, self.max = min_max(self.values) + self.values = numpy.ma.array(self.values, mask=mask) + else: + self.values = None + self.axes = None + self.min, self.max = None, None + + def _checkContextInputs(self, item, plot, onlimits, roi): + _StatsContext._checkContextInputs(self, item=item, plot=plot, + onlimits=onlimits, roi=roi) + + if roi is not None and not isinstance(roi, RegionOfInterest): + raise TypeError('curve `context` can ony manage 2D roi') + + +class _plot3DArrayContext(_StatsContext): + """StatsContext for 3D scalar field and data image. + + It supports :class:`~silx.gui.plot3d.items.ScalarField3D` and + :class:`~silx.gui.plot3d.items.ImageData`. + + :param item: the item for which we want to compute the context + :param plot: the plot containing the item + :param bool onlimits: True if we want to apply statistic only on + visible data. + :param roi: Region of interest for computing the statistics. + For now, incompatible with `onlinits` calculation + :type roi: Union[None, :class:`ROI`] + """ + def __init__(self, item, plot, onlimits, roi): + _StatsContext.__init__(self, kind='image', item=item, plot=plot, + onlimits=onlimits, roi=roi) + + @docstring(_StatsContext) + def clipData(self, item, plot, onlimits, roi): + self._checkContextInputs(item=item, plot=plot, onlimits=onlimits, + roi=roi) + if onlimits: + raise RuntimeError("Unsupported plot %s" % str(plot)) + + values = item.getData(copy=False) + if roi: + logger.warning("Roi are unsuported on volume for now") + mask = numpy.zeros_like(values) + else: + mask = numpy.zeros_like(values) + + if values is not None and len(values) > 0: + self.values = values + self.axes = tuple([numpy.arange(size) for size in self.values.shape]) + self.min, self.max = min_max(self.values) + self.values = numpy.ma.array(self.values, mask=mask) + else: + self.values = None + self.axes = None + self.min, self.max = None, None + + def _checkContextInputs(self, item, plot, onlimits, roi): + _StatsContext._checkContextInputs(self, item=item, plot=plot, + onlimits=onlimits, roi=roi) + + if roi is not None and not isinstance(roi, RegionOfInterest): + raise TypeError('curve `context` can ony manage 2D roi') + + +BASIC_COMPATIBLE_KINDS = 'curve', 'image', 'scatter', 'histogram' + + +class StatBase(object): + """ + Base class for defining a statistic. + + :param str name: the name of the statistic. Must be unique. + :param List[str] compatibleKinds: + The kind of items (curve, scatter...) for which the statistic apply. + """ + def __init__(self, name, compatibleKinds=BASIC_COMPATIBLE_KINDS, description=None): + self.name = name + self.compatibleKinds = compatibleKinds + self.description = description + + def calculate(self, context): + """ + compute the statistic for the given :class:`StatsContext` + + :param _StatsContext context: + :return dict: key is stat name, statistic computed is the dict value + """ + raise NotImplementedError('Base class') + + def getToolTip(self, kind): + """ + If necessary add a tooltip for a stat kind + + :param str kind: the kind of item the statistic is compute for. + :return: tooltip or None if no tooltip + """ + return None + + +class Stat(StatBase): + """ + Create a StatBase class based on a function pointer. + + :param str name: name of the statistic. Used as id + :param fct: function which should have as unique mandatory parameter the + data. Should be able to adapt to all `kinds` defined as + compatible + :param tuple kinds: the compatible item kinds of the function (curve, + image...) + """ + def __init__(self, name, fct, kinds=BASIC_COMPATIBLE_KINDS): + StatBase.__init__(self, name, kinds) + self._fct = fct + + @docstring(StatBase) + def calculate(self, context): + if context.values is not None: + if context.kind in self.compatibleKinds: + return self._fct(context.values) + else: + raise ValueError('Kind %s not managed by %s' + '' % (context.kind, self.name)) + else: + return None + + +class StatMin(StatBase): + """Compute the minimal value on data""" + def __init__(self): + StatBase.__init__(self, name='min') + + @docstring(StatBase) + def calculate(self, context): + return context.min + + +class StatMax(StatBase): + """Compute the maximal value on data""" + def __init__(self): + StatBase.__init__(self, name='max') + + @docstring(StatBase) + def calculate(self, context): + return context.max + + +class StatDelta(StatBase): + """Compute the delta between minimal and maximal on data""" + def __init__(self): + StatBase.__init__(self, name='delta') + + @docstring(StatBase) + def calculate(self, context): + return context.max - context.min + + +class _StatCoord(StatBase): + """Base class for argmin and argmax stats""" + + def _indexToCoordinates(self, context, index): + """Returns the coordinates of data point at given index + + If data is an array, coordinates are in reverse order from data shape. + + :param _StatsContext context: + :param int index: Index in the flattened data array + :rtype: List[int] + """ + + axes = context.axes + + if context.isStructuredData() or context.roi: + coordinates = [] + for axis in reversed(axes): + coordinates.append(axis[index % len(axis)]) + index = index // len(axis) + return tuple(coordinates) + else: + return tuple(axis[index] for axis in axes) + + +class StatCoordMin(_StatCoord): + """Compute the coordinates of the first minimum value of the data""" + def __init__(self): + _StatCoord.__init__(self, name='coords min') + + @docstring(StatBase) + def calculate(self, context): + if context.values is None or not context.isScalarData(): + return None + + index = context.values.argmin() + return self._indexToCoordinates(context, index) + + @docstring(StatBase) + def getToolTip(self, kind): + return "Coordinates of the first minimum value of the data" + + +class StatCoordMax(_StatCoord): + """Compute the coordinates of the first maximum value of the data""" + def __init__(self): + _StatCoord.__init__(self, name='coords max') + + @docstring(StatBase) + def calculate(self, context): + if context.values is None or not context.isScalarData(): + return None + + # TODO: the values should be a mask array by default, will be simpler + # if possible + index = context.values.argmax() + return self._indexToCoordinates(context, index) + + @docstring(StatBase) + def getToolTip(self, kind): + return "Coordinates of the first maximum value of the data" + + +class StatCOM(StatBase): + """Compute data center of mass""" + def __init__(self): + StatBase.__init__(self, name='COM', description='Center of mass') + + @docstring(StatBase) + def calculate(self, context): + if context.values is None or not context.isScalarData(): + return None + + values = numpy.ma.array(context.values, mask=context.mask, dtype=numpy.float64) + sum_ = numpy.sum(values) + if sum_ == 0.: + return (numpy.nan,) * len(context.axes) + + if context.isStructuredData(): + centerofmass = [] + for index, axis in enumerate(context.axes): + axes = tuple([i for i in range(len(context.axes)) if i != index]) + centerofmass.append( + numpy.sum(axis * numpy.sum(values, axis=axes)) / sum_) + return tuple(reversed(centerofmass)) + else: + return tuple( + numpy.sum(axis * values) / sum_ for axis in context.axes) + + @docstring(StatBase) + def getToolTip(self, kind): + return "Compute the center of mass of the dataset" diff --git a/src/silx/gui/plot/stats/statshandler.py b/src/silx/gui/plot/stats/statshandler.py new file mode 100644 index 0000000..17578d8 --- /dev/null +++ b/src/silx/gui/plot/stats/statshandler.py @@ -0,0 +1,202 @@ +# coding: utf-8 +# /*########################################################################## +# +# Copyright (c) 2017-2019 European Synchrotron Radiation Facility +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. +# +# ###########################################################################*/ +"""This module containts the classes relative to the management of statistics +display. +""" + +__authors__ = ["H. Payno"] +__license__ = "MIT" +__date__ = "05/06/2018" + + +import logging + +from silx.gui import qt +from silx.gui.plot import stats as statsmdl + +logger = logging.getLogger(__name__) + + +class _FloatItem(qt.QTableWidgetItem): + """Simple QTableWidgetItem allowing ordering on floats""" + + def __init__(self, type=qt.QTableWidgetItem.Type): + qt.QTableWidgetItem.__init__(self, type=type) + + def __lt__(self, other): + self_values = self.text().lstrip('(').rstrip(')').split(',') + other_values = other.text().lstrip('(').rstrip(')').split(',') + for self_value, other_value in zip(self_values, other_values): + f_self_value = float(self_value) + f_other_value = float(other_value) + if f_self_value != f_other_value: + return f_self_value < f_other_value + return False + + +class StatFormatter(object): + """ + Class used to apply format on :class:`Stat` + + :param formatter: the formatter. Defined as str.format() + :param qItemClass: the class inheriting from :class:`QTableWidgetItem` + which will be used to display the result of the + statistic computation. + """ + DEFAULT_FORMATTER = '{0:.3f}' + + def __init__(self, formatter=DEFAULT_FORMATTER, qItemClass=_FloatItem): + self.formatter = formatter + self.tabWidgetItemClass = qItemClass + + def format(self, val): + if self.formatter is None or val is None: + return str(val) + else: + return self.formatter.format(val) + + +class StatsHandler(object): + """ + Give + create: + + * Stats object which will manage the statistic computation + * Associate formatter and :class:`Stat` + + :param statFormatters: Stat and optional formatter. + If elements are given as a tuple, elements + should be (:class:`Stat`, formatter). + Otherwise should be :class:`Stat` elements. + :rtype: List or tuple + """ + + def __init__(self, statFormatters): + self.stats = statsmdl.Stats() + self.formatters = {} + for elmt in statFormatters: + stat, formatter = self._processStatArgument(elmt) + self.add(stat=stat, formatter=formatter) + + @staticmethod + def _processStatArgument(arg): + """Process an element of the init arguments + + :param arg: The argument to process + :return: Corresponding (StatBase, StatFormatter) + """ + stat, formatter = None, None + + if isinstance(arg, statsmdl.StatBase): + stat = arg + else: + assert len(arg) > 0 + if isinstance(arg[0], statsmdl.StatBase): + stat = arg[0] + if len(arg) > 2: + raise ValueError('To many argument with %s. At most one ' + 'argument can be associated with the ' + 'BaseStat (the `StatFormatter`') + if len(arg) == 2: + assert arg[1] is None or isinstance(arg[1], (StatFormatter, str)) + formatter = arg[1] + else: + if isinstance(arg[0], tuple): + if len(arg) > 1: + formatter = arg[1] + arg = arg[0] + + if type(arg[0]) is not str: + raise ValueError('first element of the tuple should be a string' + ' or a StatBase instance') + if len(arg) == 1: + raise ValueError('A function should be associated with the' + 'stat name') + if len(arg) > 3: + raise ValueError('Two much argument given for defining statistic.' + 'Take at most three arguments (name, function, ' + 'kinds)') + if len(arg) == 2: + stat = statsmdl.Stat(name=arg[0], fct=arg[1]) + else: + stat = statsmdl.Stat(name=arg[0], fct=arg[1], kinds=arg[2]) + + return stat, formatter + + def add(self, stat, formatter=None): + """Add a stat to the list. + + :param StatBase stat: + :param Union[None,StatFormatter] formatter: + """ + assert isinstance(stat, statsmdl.StatBase) + self.stats.add(stat) + _formatter = formatter + if type(_formatter) is str: + _formatter = StatFormatter(formatter=_formatter) + self.formatters[stat.name] = _formatter + + def format(self, name, val): + """Apply the format for the `name` statistic and the given value + + :param str name: the name of the associated statistic + :param val: value before formatting + :return: formatted value + """ + if name not in self.formatters: + logger.warning("statistic %s haven't been registred" % name) + return val + else: + if self.formatters[name] is None: + return str(val) + else: + if isinstance(val, (tuple, list)): + res = [] + [res.append(self.formatters[name].format(_val)) for _val in val] + return ', '.join(res) + else: + return self.formatters[name].format(val) + + def calculate(self, item, plot, onlimits, roi=None, data_changed=False, + roi_changed=False): + """ + compute all statistic registered and return the list of formatted + statistics result. + + :param item: item for which we want to compute statistics + :param plot: plot containing the item + :param onlimits: True if we want to compute statistics on visible data + only + :type: bool + :param roi: region of interest for statistic calculation + :type: Union[None,:class:`_RegionOfInterestBase`] + :return: list of formatted statistics (as str) + :rtype: dict + """ + res = self.stats.calculate(item, plot, onlimits, roi, + data_changed=data_changed, roi_changed=roi_changed) + for resName, resValue in list(res.items()): + res[resName] = self.format(resName, res[resName]) + return res diff --git a/src/silx/gui/plot/test/__init__.py b/src/silx/gui/plot/test/__init__.py new file mode 100644 index 0000000..3ad225d --- /dev/null +++ b/src/silx/gui/plot/test/__init__.py @@ -0,0 +1,24 @@ +# coding: utf-8 +# /*########################################################################## +# +# Copyright (c) 2016-2018 European Synchrotron Radiation Facility +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. +# +# ###########################################################################*/ diff --git a/src/silx/gui/plot/test/testAlphaSlider.py b/src/silx/gui/plot/test/testAlphaSlider.py new file mode 100644 index 0000000..ca57bf5 --- /dev/null +++ b/src/silx/gui/plot/test/testAlphaSlider.py @@ -0,0 +1,204 @@ +# coding: utf-8 +# /*########################################################################## +# +# Copyright (c) 2017-2019 European Synchrotron Radiation Facility +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. +# +# ###########################################################################*/ +"""Tests for ImageAlphaSlider""" + + +__authors__ = ["P. Knobel"] +__license__ = "MIT" +__date__ = "28/03/2017" + +import numpy +import unittest + +from silx.gui import qt +from silx.gui.utils.testutils import TestCaseQt +from silx.gui.plot import PlotWidget +from silx.gui.plot import AlphaSlider + + +class TestActiveImageAlphaSlider(TestCaseQt): + def setUp(self): + super(TestActiveImageAlphaSlider, self).setUp() + self.plot = PlotWidget() + self.aslider = AlphaSlider.ActiveImageAlphaSlider(plot=self.plot) + self.aslider.setOrientation(qt.Qt.Horizontal) + + toolbar = qt.QToolBar("plot", self.plot) + toolbar.addWidget(self.aslider) + self.plot.addToolBar(toolbar) + + self.plot.show() + self.qWaitForWindowExposed(self.plot) + + self.mouseMove(self.plot) # Move to center + self.qapp.processEvents() + + def tearDown(self): + self.qapp.processEvents() + self.plot.setAttribute(qt.Qt.WA_DeleteOnClose) + self.plot.close() + del self.plot + del self.aslider + + super(TestActiveImageAlphaSlider, self).tearDown() + + def testWidgetEnabled(self): + # no active image initially, slider must be deactivate + self.assertFalse(self.aslider.isEnabled()) + + self.plot.addImage(numpy.array([[0, 1, 2], [3, 4, 5]])) + # now we have an active image + self.assertTrue(self.aslider.isEnabled()) + + self.plot.setActiveImage(None) + self.assertFalse(self.aslider.isEnabled()) + + def testGetImage(self): + self.plot.addImage(numpy.array([[0, 1, 2], [3, 4, 5]])) + self.assertEqual(self.plot.getActiveImage(), + self.aslider.getItem()) + + self.plot.addImage(numpy.array([[0, 1, 3], [2, 4, 6]]), legend="2") + self.plot.setActiveImage("2") + self.assertEqual(self.plot.getImage("2"), + self.aslider.getItem()) + + def testGetAlpha(self): + self.plot.addImage(numpy.array([[0, 1, 2], [3, 4, 5]]), legend="1") + self.aslider.setValue(137) + self.assertAlmostEqual(self.aslider.getAlpha(), + 137. / 255) + + +class TestNamedImageAlphaSlider(TestCaseQt): + def setUp(self): + super(TestNamedImageAlphaSlider, self).setUp() + self.plot = PlotWidget() + self.aslider = AlphaSlider.NamedImageAlphaSlider(plot=self.plot) + self.aslider.setOrientation(qt.Qt.Horizontal) + + toolbar = qt.QToolBar("plot", self.plot) + toolbar.addWidget(self.aslider) + self.plot.addToolBar(toolbar) + + self.plot.show() + self.qWaitForWindowExposed(self.plot) + + self.mouseMove(self.plot) # Move to center + self.qapp.processEvents() + + def tearDown(self): + self.qapp.processEvents() + self.plot.setAttribute(qt.Qt.WA_DeleteOnClose) + self.plot.close() + del self.plot + del self.aslider + + super(TestNamedImageAlphaSlider, self).tearDown() + + def testWidgetEnabled(self): + # no image set initially, slider must be deactivate + self.assertFalse(self.aslider.isEnabled()) + + self.plot.addImage(numpy.array([[0, 1, 2], [3, 4, 5]]), legend="1") + self.aslider.setLegend("1") + # now we have an image set + self.assertTrue(self.aslider.isEnabled()) + + def testGetImage(self): + self.plot.addImage(numpy.array([[0, 1, 2], [3, 4, 5]]), legend="1") + self.plot.addImage(numpy.array([[0, 1, 3], [2, 4, 6]]), legend="2") + self.aslider.setLegend("1") + self.assertEqual(self.plot.getImage("1"), + self.aslider.getItem()) + + self.aslider.setLegend("2") + self.assertEqual(self.plot.getImage("2"), + self.aslider.getItem()) + + def testGetAlpha(self): + self.plot.addImage(numpy.array([[0, 1, 2], [3, 4, 5]]), legend="1") + self.aslider.setLegend("1") + self.aslider.setValue(128) + self.assertAlmostEqual(self.aslider.getAlpha(), + 128. / 255) + + +class TestNamedScatterAlphaSlider(TestCaseQt): + def setUp(self): + super(TestNamedScatterAlphaSlider, self).setUp() + self.plot = PlotWidget() + self.aslider = AlphaSlider.NamedScatterAlphaSlider(plot=self.plot) + self.aslider.setOrientation(qt.Qt.Horizontal) + + toolbar = qt.QToolBar("plot", self.plot) + toolbar.addWidget(self.aslider) + self.plot.addToolBar(toolbar) + + self.plot.show() + self.qWaitForWindowExposed(self.plot) + + self.mouseMove(self.plot) # Move to center + self.qapp.processEvents() + + def tearDown(self): + self.qapp.processEvents() + self.plot.setAttribute(qt.Qt.WA_DeleteOnClose) + self.plot.close() + del self.plot + del self.aslider + + super(TestNamedScatterAlphaSlider, self).tearDown() + + def testWidgetEnabled(self): + # no Scatter set initially, slider must be deactivate + self.assertFalse(self.aslider.isEnabled()) + + self.plot.addScatter([0, 1, 2], [2, 3, 4], [5, 6, 7], + legend="1") + self.aslider.setLegend("1") + # now we have an image set + self.assertTrue(self.aslider.isEnabled()) + + def testGetScatter(self): + self.plot.addScatter([0, 1, 2], [2, 3, 4], [5, 6, 7], + legend="1") + self.plot.addScatter([0, 10, 20], [20, 30, 40], [50, 60, 70], + legend="2") + self.aslider.setLegend("1") + self.assertEqual(self.plot.getScatter("1"), + self.aslider.getItem()) + + self.aslider.setLegend("2") + self.assertEqual(self.plot.getScatter("2"), + self.aslider.getItem()) + + def testGetAlpha(self): + self.plot.addScatter([0, 10, 20], [20, 30, 40], [50, 60, 70], + legend="1") + self.aslider.setLegend("1") + self.aslider.setValue(128) + self.assertAlmostEqual(self.aslider.getAlpha(), + 128. / 255) diff --git a/src/silx/gui/plot/test/testColorBar.py b/src/silx/gui/plot/test/testColorBar.py new file mode 100644 index 0000000..3dc8ff1 --- /dev/null +++ b/src/silx/gui/plot/test/testColorBar.py @@ -0,0 +1,340 @@ +# coding: utf-8 +# /*########################################################################## +# +# Copyright (c) 2016-2021 European Synchrotron Radiation Facility +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. +# +# ###########################################################################*/ +"""Basic tests for ColorBar featues and sub widgets of Colorbar module""" + +__authors__ = ["H. Payno"] +__license__ = "MIT" +__date__ = "24/04/2018" + +import unittest +from silx.gui.utils.testutils import TestCaseQt +from silx.gui.plot.ColorBar import _ColorScale +from silx.gui.plot.ColorBar import ColorBarWidget +from silx.gui.colors import Colormap +from silx.math.colormap import LinearNormalization, LogarithmicNormalization +from silx.gui.plot import Plot2D +from silx.gui import qt +import numpy + + +class TestColorScale(TestCaseQt): + """Test that interaction with the colorScale is correct""" + def setUp(self): + super(TestColorScale, self).setUp() + self.colorScaleWidget = _ColorScale(colormap=None, parent=None) + self.colorScaleWidget.show() + self.qWaitForWindowExposed(self.colorScaleWidget) + + def tearDown(self): + self.qapp.processEvents() + self.colorScaleWidget.setAttribute(qt.Qt.WA_DeleteOnClose) + self.colorScaleWidget.close() + del self.colorScaleWidget + super(TestColorScale, self).tearDown() + + def testNoColormap(self): + """Test _ColorScale without a colormap""" + colormap = self.colorScaleWidget.getColormap() + self.assertIsNone(colormap) + + def testRelativePositionLinear(self): + self.colorMapLin1 = Colormap(name='gray', + normalization=Colormap.LINEAR, + vmin=0.0, + vmax=1.0) + self.colorScaleWidget.setColormap(self.colorMapLin1) + + self.assertTrue( + self.colorScaleWidget.getValueFromRelativePosition(0.25) == 0.25) + self.assertTrue( + self.colorScaleWidget.getValueFromRelativePosition(0.5) == 0.5) + self.assertTrue( + self.colorScaleWidget.getValueFromRelativePosition(1.0) == 1.0) + + self.colorMapLin2 = Colormap(name='viridis', + normalization=Colormap.LINEAR, + vmin=-10, + vmax=0) + self.colorScaleWidget.setColormap(self.colorMapLin2) + + self.assertTrue( + self.colorScaleWidget.getValueFromRelativePosition(0.25) == -7.5) + self.assertTrue( + self.colorScaleWidget.getValueFromRelativePosition(0.5) == -5.0) + self.assertTrue( + self.colorScaleWidget.getValueFromRelativePosition(1.0) == 0.0) + + def testRelativePositionLog(self): + self.colorMapLog1 = Colormap(name='temperature', + normalization=Colormap.LOGARITHM, + vmin=1.0, + vmax=100.0) + + self.colorScaleWidget.setColormap(self.colorMapLog1) + + val = self.colorScaleWidget.getValueFromRelativePosition(1.0) + self.assertAlmostEqual(val, 100.0) + + val = self.colorScaleWidget.getValueFromRelativePosition(0.5) + self.assertAlmostEqual(val, 10.0) + + val = self.colorScaleWidget.getValueFromRelativePosition(0.0) + self.assertTrue(val == 1.0) + + +class TestNoAutoscale(TestCaseQt): + """Test that ticks and color displayed are correct in the case of a colormap + with no autoscale + """ + + def setUp(self): + super(TestNoAutoscale, self).setUp() + self.plot = Plot2D() + self.colorBar = self.plot.getColorBarWidget() + self.colorBar.setVisible(True) # Makes sure the colormap is visible + self.tickBar = self.colorBar.getColorScaleBar().getTickBar() + self.colorScale = self.colorBar.getColorScaleBar().getColorScale() + + self.plot.show() + self.qWaitForWindowExposed(self.plot) + + def tearDown(self): + self.qapp.processEvents() + self.tickBar = None + self.colorScale = None + del self.colorBar + self.plot.setAttribute(qt.Qt.WA_DeleteOnClose) + self.plot.close() + del self.plot + super(TestNoAutoscale, self).tearDown() + + def testLogNormNoAutoscale(self): + colormapLog = Colormap(name='gray', + normalization=Colormap.LOGARITHM, + vmin=1.0, + vmax=100.0) + + data = numpy.linspace(10, 1e10, 9).reshape(3, 3) + self.plot.addImage(data=data, colormap=colormapLog, legend='toto') + self.plot.setActiveImage('toto') + + # test Ticks + self.tickBar.setTicksNumber(10) + self.tickBar.computeTicks() + + ticksTh = numpy.linspace(1.0, 100.0, 10) + ticksTh = 10**ticksTh + numpy.array_equal(self.tickBar.ticks, ticksTh) + + # test ColorScale + val = self.colorScale.getValueFromRelativePosition(1.0) + self.assertAlmostEqual(val, 100.0) + + val = self.colorScale.getValueFromRelativePosition(0.0) + self.assertTrue(val == 1.0) + + def testLinearNormNoAutoscale(self): + colormapLog = Colormap(name='gray', + normalization=Colormap.LINEAR, + vmin=-4, + vmax=5) + + data = numpy.linspace(1, 9, 9).reshape(3, 3) + self.plot.addImage(data=data, colormap=colormapLog, legend='toto') + self.plot.setActiveImage('toto') + + # test Ticks + self.tickBar.setTicksNumber(10) + self.tickBar.computeTicks() + + numpy.array_equal(self.tickBar.ticks, numpy.linspace(-4, 5, 10)) + + # test ColorScale + val = self.colorScale.getValueFromRelativePosition(1.0) + self.assertTrue(val == 5.0) + + val = self.colorScale.getValueFromRelativePosition(0.0) + self.assertTrue(val == -4.0) + + +class TestColorBarWidget(TestCaseQt): + """Test interaction with the ColorBarWidget""" + + def setUp(self): + super(TestColorBarWidget, self).setUp() + self.plot = Plot2D() + self.colorBar = self.plot.getColorBarWidget() + self.colorBar.setVisible(True) # Makes sure the colormap is visible + + self.plot.show() + self.qWaitForWindowExposed(self.plot) + + def tearDown(self): + self.qapp.processEvents() + del self.colorBar + self.plot.setAttribute(qt.Qt.WA_DeleteOnClose) + self.plot.close() + del self.plot + super(TestColorBarWidget, self).tearDown() + + def testEmptyColorBar(self): + colorBar = ColorBarWidget(parent=None) + colorBar.show() + self.qWaitForWindowExposed(colorBar) + + def testNegativeColormaps(self): + """test the behavior of the ColorBarWidget in the case of negative + values + + Note : colorbar is modified by the Plot directly not ColorBarWidget + """ + colormapLog = Colormap(name='gray', + normalization=Colormap.LOGARITHM, + vmin=None, + vmax=None) + + data = numpy.array([-5, -4, 0, 2, 3, 5, 10, 20, 30]) + data = data.reshape(3, 3) + self.plot.addImage(data=data, colormap=colormapLog, legend='toto') + self.plot.setActiveImage('toto') + + # default behavior when with log and negative values: should set vmin + # to 1 and vmax to 10 + self.assertTrue(self.colorBar.getColorScaleBar().minVal == 2) + self.assertTrue(self.colorBar.getColorScaleBar().maxVal == 30) + + # if data is positive + data[data < 1] = data.max() + self.plot.addImage(data=data, + colormap=colormapLog, + legend='toto', + replace=True) + self.plot.setActiveImage('toto') + + self.assertTrue(self.colorBar.getColorScaleBar().minVal == data.min()) + self.assertTrue(self.colorBar.getColorScaleBar().maxVal == data.max()) + + def testPlotAssocation(self): + """Make sure the ColorBarWidget is properly connected with the plot""" + colormap = Colormap(name='gray', + normalization=Colormap.LINEAR, + vmin=None, + vmax=None) + + # make sure that default settings are the same (but a copy of the + self.colorBar.setPlot(self.plot) + self.assertTrue( + self.colorBar.getColormap() is self.plot.getDefaultColormap()) + + data = numpy.linspace(0, 10, 100).reshape(10, 10) + self.plot.addImage(data=data, colormap=colormap, legend='toto') + self.plot.setActiveImage('toto') + + # make sure the modification of the colormap has been done + self.assertFalse( + self.colorBar.getColormap() is self.plot.getDefaultColormap()) + self.assertTrue( + self.colorBar.getColormap() is colormap) + + # test that colorbar is updated when default plot colormap changes + self.plot.clear() + plotColormap = Colormap(name='gray', + normalization=Colormap.LOGARITHM, + vmin=None, + vmax=None) + self.plot.setDefaultColormap(plotColormap) + self.assertTrue(self.colorBar.getColormap() is plotColormap) + + def testColormapWithoutRange(self): + """Test with a colormap with vmin==vmax""" + colormap = Colormap(name='gray', + normalization=Colormap.LINEAR, + vmin=1.0, + vmax=1.0) + self.colorBar.setColormap(colormap) + + +class TestColorBarUpdate(TestCaseQt): + """Test that the ColorBar is correctly updated when the signal 'sigChanged' + of the colormap is emitted + """ + + def setUp(self): + super(TestColorBarUpdate, self).setUp() + self.plot = Plot2D() + self.colorBar = self.plot.getColorBarWidget() + self.colorBar.setVisible(True) # Makes sure the colormap is visible + self.colorBar.setPlot(self.plot) + + self.plot.show() + self.qWaitForWindowExposed(self.plot) + self.data = numpy.random.rand(9).reshape(3, 3) + + def tearDown(self): + self.qapp.processEvents() + del self.colorBar + self.plot.setAttribute(qt.Qt.WA_DeleteOnClose) + self.plot.close() + del self.plot + super(TestColorBarUpdate, self).tearDown() + + def testUpdateColorMap(self): + colormap = Colormap(name='gray', + normalization='linear', + vmin=0, + vmax=1) + + # check inital state + self.plot.addImage(data=self.data, colormap=colormap, legend='toto') + self.plot.setActiveImage('toto') + + self.assertTrue(self.colorBar.getColorScaleBar().minVal == 0) + self.assertTrue(self.colorBar.getColorScaleBar().maxVal == 1) + self.assertTrue( + self.colorBar.getColorScaleBar().getTickBar()._vmin == 0) + self.assertTrue( + self.colorBar.getColorScaleBar().getTickBar()._vmax == 1) + self.assertIsInstance( + self.colorBar.getColorScaleBar().getTickBar()._normalizer, + LinearNormalization) + + # update colormap + colormap.setVMin(0.5) + self.assertTrue(self.colorBar.getColorScaleBar().minVal == 0.5) + self.assertTrue( + self.colorBar.getColorScaleBar().getTickBar()._vmin == 0.5) + + colormap.setVMax(0.8) + self.assertTrue(self.colorBar.getColorScaleBar().maxVal == 0.8) + self.assertTrue( + self.colorBar.getColorScaleBar().getTickBar()._vmax == 0.8) + + colormap.setNormalization('log') + self.assertIsInstance( + self.colorBar.getColorScaleBar().getTickBar()._normalizer, + LogarithmicNormalization) + + # TODO : should also check that if the colormap is changing then values (especially in log scale) + # should be coherent if in autoscale diff --git a/src/silx/gui/plot/test/testCompareImages.py b/src/silx/gui/plot/test/testCompareImages.py new file mode 100644 index 0000000..cf54b99 --- /dev/null +++ b/src/silx/gui/plot/test/testCompareImages.py @@ -0,0 +1,106 @@ +# coding: utf-8 +# /*########################################################################## +# +# Copyright (c) 2016-2017 European Synchrotron Radiation Facility +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. +# +# ###########################################################################*/ +"""Tests for CompareImages widget""" + +__authors__ = ["H. Payno"] +__license__ = "MIT" +__date__ = "23/07/2018" + +import unittest +import numpy +import weakref + +from silx.gui.utils.testutils import TestCaseQt +from silx.gui.plot.CompareImages import CompareImages + + +class TestCompareImages(TestCaseQt): + """Test that CompareImages widget is working in some cases""" + + def setUp(self): + super(TestCompareImages, self).setUp() + self.widget = CompareImages() + + def tearDown(self): + ref = weakref.ref(self.widget) + self.widget = None + self.qWaitForDestroy(ref) + super(TestCompareImages, self).tearDown() + + def testIntensityImage(self): + image1 = numpy.random.rand(10, 10) + image2 = numpy.random.rand(10, 10) + self.widget.setData(image1, image2) + + def testRgbImage(self): + image1 = numpy.random.randint(0, 255, size=(10, 10, 3)) + image2 = numpy.random.randint(0, 255, size=(10, 10, 3)) + self.widget.setData(image1, image2) + + def testRgbaImage(self): + image1 = numpy.random.randint(0, 255, size=(10, 10, 4)) + image2 = numpy.random.randint(0, 255, size=(10, 10, 4)) + self.widget.setData(image1, image2) + + def testVizualisations(self): + image1 = numpy.random.rand(10, 10) + image2 = numpy.random.rand(10, 10) + self.widget.setData(image1, image2) + for mode in CompareImages.VisualizationMode: + self.widget.setVisualizationMode(mode) + + def testAlignemnt(self): + image1 = numpy.random.rand(10, 10) + image2 = numpy.random.rand(5, 5) + self.widget.setData(image1, image2) + for mode in CompareImages.AlignmentMode: + self.widget.setAlignmentMode(mode) + + def testGetPixel(self): + image1 = numpy.random.rand(11, 11) + image2 = numpy.random.rand(5, 5) + image1[5, 5] = 111.111 + image2[2, 2] = 222.222 + self.widget.setData(image1, image2) + expectedValue = {} + expectedValue[CompareImages.AlignmentMode.CENTER] = 222.222 + expectedValue[CompareImages.AlignmentMode.STRETCH] = 222.222 + expectedValue[CompareImages.AlignmentMode.ORIGIN] = None + for mode in expectedValue.keys(): + self.widget.setAlignmentMode(mode) + data = self.widget.getRawPixelData(11 / 2.0, 11 / 2.0) + data1, data2 = data + self.assertEqual(data1, 111.111) + self.assertEqual(data2, expectedValue[mode]) + + def testImageEmpty(self): + self.widget.setData(image1=None, image2=None) + self.assertTrue(self.widget.getRawPixelData(11 / 2.0, 11 / 2.0) == (None, None)) + + def testSetImageSeparately(self): + self.widget.setImage1(numpy.random.rand(10, 10)) + self.widget.setImage2(numpy.random.rand(10, 10)) + for mode in CompareImages.VisualizationMode: + self.widget.setVisualizationMode(mode) diff --git a/src/silx/gui/plot/test/testComplexImageView.py b/src/silx/gui/plot/test/testComplexImageView.py new file mode 100644 index 0000000..46025b9 --- /dev/null +++ b/src/silx/gui/plot/test/testComplexImageView.py @@ -0,0 +1,84 @@ +# coding: utf-8 +# /*########################################################################## +# +# Copyright (c) 2017-2020 European Synchrotron Radiation Facility +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. +# +# ###########################################################################*/ +"""Test suite for :class:`ComplexImageView`""" + +__authors__ = ["T. Vincent"] +__license__ = "MIT" +__date__ = "17/01/2018" + + +import unittest +import logging +import numpy + +from silx.utils.testutils import ParametricTestCase +from silx.gui.plot import ComplexImageView + +from .utils import PlotWidgetTestCase + + +logger = logging.getLogger(__name__) + + +class TestComplexImageView(PlotWidgetTestCase, ParametricTestCase): + """Test suite of ComplexImageView widget""" + + def _createPlot(self): + return ComplexImageView.ComplexImageView() + + def testPlot2DComplex(self): + """Test API of ComplexImageView widget""" + data = numpy.array(((0, 1j), (1, 1 + 1j)), dtype=numpy.complex64) + self.plot.setData(data) + self.plot.setKeepDataAspectRatio(True) + self.plot.getPlot().resetZoom() + self.qWait(100) + + # Test colormap API + colormap = self.plot.getColormap().copy() + colormap.setName('magma') + self.plot.setColormap(colormap) + self.qWait(100) + + # Test all modes + modes = self.plot.supportedComplexModes() + for mode in modes: + with self.subTest(mode=mode): + self.plot.setComplexMode(mode) + self.qWait(100) + + # Test origin and scale API + self.plot.setScale((2, 1)) + self.qWait(100) + self.plot.setOrigin((1, 1)) + self.qWait(100) + + # Test no data + self.plot.setData(numpy.zeros((0, 0), dtype=numpy.complex64)) + self.qWait(100) + + # Test float data + self.plot.setData(numpy.arange(100, dtype=numpy.float64).reshape(10, 10)) + self.qWait(100) diff --git a/src/silx/gui/plot/test/testCurvesROIWidget.py b/src/silx/gui/plot/test/testCurvesROIWidget.py new file mode 100644 index 0000000..d7dfafd --- /dev/null +++ b/src/silx/gui/plot/test/testCurvesROIWidget.py @@ -0,0 +1,465 @@ +# coding: utf-8 +# /*########################################################################## +# +# Copyright (c) 2016-2021 European Synchrotron Radiation Facility +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. +# +# ###########################################################################*/ +"""Basic tests for CurvesROIWidget""" + +__authors__ = ["T. Vincent", "P. Knobel", "H. Payno"] +__license__ = "MIT" +__date__ = "16/11/2017" + + +import logging +import os.path +import pytest +from collections import OrderedDict +import numpy + +from silx.gui import qt +from silx.gui.plot import items +from silx.gui.plot import Plot1D +from silx.test.utils import temp_dir +from silx.gui.utils.testutils import TestCaseQt, SignalListener +from silx.gui.plot import PlotWindow, CurvesROIWidget +from silx.gui.plot.CurvesROIWidget import ROITable +from silx.gui.utils.testutils import getQToolButtonFromAction +from silx.gui.plot.PlotInteraction import ItemsInteraction + +_logger = logging.getLogger(__name__) + + +class TestCurvesROIWidget(TestCaseQt): + """Basic test for CurvesROIWidget""" + + def setUp(self): + super(TestCurvesROIWidget, self).setUp() + self.plot = PlotWindow() + self.plot.show() + self.qWaitForWindowExposed(self.plot) + + self.widget = self.plot.getCurvesRoiDockWidget() + + self.widget.show() + self.qWaitForWindowExposed(self.widget) + + def tearDown(self): + self.plot.setAttribute(qt.Qt.WA_DeleteOnClose) + self.plot.close() + del self.plot + + self.widget.setAttribute(qt.Qt.WA_DeleteOnClose) + self.widget.close() + del self.widget + + super(TestCurvesROIWidget, self).tearDown() + + def testDummyAPI(self): + """Simple test of the getRois and setRois API""" + roi_neg = CurvesROIWidget.ROI(name='negative', fromdata=-20, + todata=-10, type_='X') + roi_pos = CurvesROIWidget.ROI(name='positive', fromdata=10, + todata=20, type_='X') + + self.widget.roiWidget.setRois((roi_pos, roi_neg)) + + rois_defs = self.widget.roiWidget.getRois() + self.widget.roiWidget.setRois(rois=rois_defs) + + def testWithCurves(self): + """Plot with curves: test all ROI widget buttons""" + for offset in range(2): + self.plot.addCurve(numpy.arange(1000), + offset + numpy.random.random(1000), + legend=str(offset)) + + # Add two ROI + self.mouseClick(self.widget.roiWidget.addButton, qt.Qt.LeftButton) + self.qWait(200) + self.mouseClick(self.widget.roiWidget.addButton, qt.Qt.LeftButton) + self.qWait(200) + + # Change active curve + self.plot.setActiveCurve(str(1)) + + # Delete a ROI + self.mouseClick(self.widget.roiWidget.delButton, qt.Qt.LeftButton) + self.qWait(200) + + with temp_dir() as tmpDir: + self.tmpFile = os.path.join(tmpDir, 'test.ini') + + # Save ROIs + self.widget.roiWidget.save(self.tmpFile) + self.assertTrue(os.path.isfile(self.tmpFile)) + self.assertEqual(len(self.widget.getRois()), 2) + + # Reset ROIs + self.mouseClick(self.widget.roiWidget.resetButton, + qt.Qt.LeftButton) + self.qWait(200) + rois = self.widget.getRois() + self.assertEqual(len(rois), 1) + roiID = list(rois.keys())[0] + self.assertEqual(rois[roiID].getName(), 'ICR') + + # Load ROIs + self.widget.roiWidget.load(self.tmpFile) + self.assertEqual(len(self.widget.getRois()), 2) + + del self.tmpFile + + def testMiddleMarker(self): + """Test with middle marker enabled""" + self.widget.roiWidget.roiTable.setMiddleROIMarkerFlag(True) + + # Add a ROI + self.mouseClick(self.widget.roiWidget.addButton, qt.Qt.LeftButton) + + for roiID in self.widget.roiWidget.roiTable._markersHandler._roiMarkerHandlers: + handler = self.widget.roiWidget.roiTable._markersHandler._roiMarkerHandlers[roiID] + assert handler.getMarker('min') + xleftMarker = handler.getMarker('min').getXPosition() + xMiddleMarker = handler.getMarker('middle').getXPosition() + xRightMarker = handler.getMarker('max').getXPosition() + thValue = xleftMarker + (xRightMarker - xleftMarker) / 2. + self.assertAlmostEqual(xMiddleMarker, thValue) + + def testAreaCalculation(self): + """Test result of area calculation""" + x = numpy.arange(100.) + y = numpy.arange(100.) + + # Add two curves + self.plot.addCurve(x, y, legend="positive") + self.plot.addCurve(-x, y, legend="negative") + + # Make sure there is an active curve and it is the positive one + self.plot.setActiveCurve("positive") + + # Add two ROIs + roi_neg = CurvesROIWidget.ROI(name='negative', fromdata=-20, + todata=-10, type_='X') + roi_pos = CurvesROIWidget.ROI(name='positive', fromdata=10, + todata=20, type_='X') + + self.widget.roiWidget.setRois((roi_pos, roi_neg)) + + posCurve = self.plot.getCurve('positive') + negCurve = self.plot.getCurve('negative') + + self.assertEqual(roi_pos.computeRawAndNetArea(posCurve), + (numpy.trapz(y=[10, 20], x=[10, 20]), + 0.0)) + self.assertEqual(roi_pos.computeRawAndNetArea(negCurve), + (0.0, 0.0)) + self.assertEqual(roi_neg.computeRawAndNetArea(posCurve), + ((0.0), 0.0)) + self.assertEqual(roi_neg.computeRawAndNetArea(negCurve), + ((-150.0), 0.0)) + + def testCountsCalculation(self): + """Test result of count calculation""" + x = numpy.arange(100.) + y = numpy.arange(100.) + + # Add two curves + self.plot.addCurve(x, y, legend="positive") + self.plot.addCurve(-x, y, legend="negative") + + # Make sure there is an active curve and it is the positive one + self.plot.setActiveCurve("positive") + + # Add two ROIs + roi_neg = CurvesROIWidget.ROI(name='negative', fromdata=-20, + todata=-10, type_='X') + roi_pos = CurvesROIWidget.ROI(name='positive', fromdata=10, + todata=20, type_='X') + + self.widget.roiWidget.setRois((roi_pos, roi_neg)) + + posCurve = self.plot.getCurve('positive') + negCurve = self.plot.getCurve('negative') + + self.assertEqual(roi_pos.computeRawAndNetCounts(posCurve), + (y[10:21].sum(), 0.0)) + self.assertEqual(roi_pos.computeRawAndNetCounts(negCurve), + (0.0, 0.0)) + self.assertEqual(roi_neg.computeRawAndNetCounts(posCurve), + ((0.0), 0.0)) + self.assertEqual(roi_neg.computeRawAndNetCounts(negCurve), + (y[10:21].sum(), 0.0)) + + def testDeferedInit(self): + """Test behavior of the deferedInit""" + x = numpy.arange(100.) + y = numpy.arange(100.) + self.plot.addCurve(x=x, y=y, legend="name", replace="True") + roisDefs = OrderedDict([ + ["range1", + OrderedDict([["from", 20], ["to", 200], ["type", "energy"]])], + ["range2", + OrderedDict([["from", 300], ["to", 500], ["type", "energy"]])] + ]) + + roiWidget = self.plot.getCurvesRoiDockWidget().roiWidget + self.plot.getCurvesRoiDockWidget().setRois(roisDefs) + self.assertEqual(len(roiWidget.getRois()), len(roisDefs)) + self.plot.getCurvesRoiDockWidget().setVisible(True) + self.assertEqual(len(roiWidget.getRois()), len(roisDefs)) + + def testDictCompatibility(self): + """Test that ROI api is valid with dict and not information is lost""" + roiDict = {'from': 20, 'to': 200, 'type': 'energy', 'comment': 'no', + 'name': 'myROI', 'calibration': [1, 2, 3]} + roi = CurvesROIWidget.ROI._fromDict(roiDict) + self.assertEqual(roi.toDict(), roiDict) + + def testShowAllROI(self): + """Test the show allROI action""" + x = numpy.arange(100.) + y = numpy.arange(100.) + self.plot.addCurve(x=x, y=y, legend="name", replace="True") + + roisDefsDict = { + "range1": {"from": 20, "to": 200,"type": "energy"}, + "range2": {"from": 300, "to": 500, "type": "energy"} + } + + roisDefsObj = ( + CurvesROIWidget.ROI(name='range3', fromdata=20, todata=200, + type_='energy'), + CurvesROIWidget.ROI(name='range4', fromdata=300, todata=500, + type_='energy') + ) + self.widget.roiWidget.showAllMarkers(True) + roiWidget = self.plot.getCurvesRoiDockWidget().roiWidget + roiWidget.setRois(roisDefsDict) + markers = [item for item in self.plot.getItems() + if isinstance(item, items.MarkerBase)] + self.assertEqual(len(markers), 2*3) + + markersHandler = self.widget.roiWidget.roiTable._markersHandler + roiWidget.showAllMarkers(True) + ICRROI = markersHandler.getVisibleRois() + self.assertEqual(len(ICRROI), 2) + + roiWidget.showAllMarkers(False) + ICRROI = markersHandler.getVisibleRois() + self.assertEqual(len(ICRROI), 1) + + roiWidget.setRois(roisDefsObj) + self.qapp.processEvents() + markers = [item for item in self.plot.getItems() + if isinstance(item, items.MarkerBase)] + self.assertEqual(len(markers), 2*3) + + markersHandler = self.widget.roiWidget.roiTable._markersHandler + roiWidget.showAllMarkers(True) + ICRROI = markersHandler.getVisibleRois() + self.assertEqual(len(ICRROI), 2) + + roiWidget.showAllMarkers(False) + ICRROI = markersHandler.getVisibleRois() + self.assertEqual(len(ICRROI), 1) + + def testRoiEdition(self): + """Make sure if the ROI object is edited the ROITable will be updated + """ + roi = CurvesROIWidget.ROI(name='linear', fromdata=0, todata=5) + self.widget.roiWidget.setRois((roi, )) + + x = (0, 1, 1, 2, 2, 3) + y = (1, 1, 2, 2, 1, 1) + self.plot.addCurve(x=x, y=y, legend='linearCurve') + self.plot.setActiveCurve(legend='linearCurve') + self.widget.calculateROIs() + + roiTable = self.widget.roiWidget.roiTable + indexesColumns = CurvesROIWidget.ROITable.COLUMNS_INDEX + itemRawCounts = roiTable.item(0, indexesColumns['Raw Counts']) + itemNetCounts = roiTable.item(0, indexesColumns['Net Counts']) + + self.assertTrue(itemRawCounts.text() == '8.0') + self.assertTrue(itemNetCounts.text() == '2.0') + + itemRawArea = roiTable.item(0, indexesColumns['Raw Area']) + itemNetArea = roiTable.item(0, indexesColumns['Net Area']) + + self.assertTrue(itemRawArea.text() == '4.0') + self.assertTrue(itemNetArea.text() == '1.0') + + roi.setTo(2) + itemRawArea = roiTable.item(0, indexesColumns['Raw Area']) + self.assertTrue(itemRawArea.text() == '3.0') + roi.setFrom(1) + itemRawArea = roiTable.item(0, indexesColumns['Raw Area']) + self.assertTrue(itemRawArea.text() == '2.0') + + def testRemoveActiveROI(self): + """Test widget behavior when removing the active ROI""" + roi = CurvesROIWidget.ROI(name='linear', fromdata=0, todata=5) + self.widget.roiWidget.setRois((roi,)) + + self.widget.roiWidget.roiTable.setActiveRoi(None) + self.assertEqual(len(self.widget.roiWidget.roiTable.selectedItems()), 0) + self.widget.roiWidget.setRois((roi,)) + self.plot.setActiveCurve(legend='linearCurve') + self.widget.calculateROIs() + + def testEmitCurrentROI(self): + """Test behavior of the CurvesROIWidget.sigROISignal""" + roi = CurvesROIWidget.ROI(name='linear', fromdata=0, todata=5) + self.widget.roiWidget.setRois((roi,)) + signalListener = SignalListener() + self.widget.roiWidget.sigROISignal.connect(signalListener.partial()) + self.widget.show() + self.qapp.processEvents() + self.assertEqual(signalListener.callCount(), 0) + self.assertIs(self.widget.roiWidget.roiTable.activeRoi, roi) + roi.setFrom(0.0) + self.qapp.processEvents() + self.assertEqual(signalListener.callCount(), 0) + roi.setFrom(0.3) + self.qapp.processEvents() + self.assertEqual(signalListener.callCount(), 1) + + +class TestRoiWidgetSignals(TestCaseQt): + """Test Signals emitted by the RoiWidgetSignals""" + + def setUp(self): + self.plot = Plot1D() + x = range(20) + y = range(20) + self.plot.addCurve(x, y, legend='curve0') + self.listener = SignalListener() + self.curves_roi_widget = self.plot.getCurvesRoiWidget() + self.curves_roi_widget.sigROISignal.connect(self.listener) + assert self.curves_roi_widget.isVisible() is False + assert self.listener.callCount() == 0 + self.plot.show() + self.qWaitForWindowExposed(self.plot) + + toolButton = getQToolButtonFromAction(self.plot.getRoiAction()) + self.qapp.processEvents() + self.mouseClick(widget=toolButton, button=qt.Qt.LeftButton) + + self.curves_roi_widget.show() + self.qWaitForWindowExposed(self.curves_roi_widget) + + def tearDown(self): + self.plot = None + self.curves_roi_widget = None + + def testSigROISignalAddRmRois(self): + """Test SigROISignal when adding and removing ROIS""" + self.listener.clear() + + roi1 = CurvesROIWidget.ROI(name='linear', fromdata=0, todata=5) + self.curves_roi_widget.roiTable.addRoi(roi1) + self.assertEqual(self.listener.callCount(), 1) + self.assertTrue(self.listener.arguments()[0][0]['current'] == 'linear') + self.listener.clear() + + roi2 = CurvesROIWidget.ROI(name='linear2', fromdata=0, todata=5) + self.curves_roi_widget.roiTable.addRoi(roi2) + self.assertEqual(self.listener.callCount(), 1) + self.assertTrue(self.listener.arguments()[0][0]['current'] == 'linear2') + self.listener.clear() + + self.curves_roi_widget.roiTable.removeROI(roi2) + self.assertEqual(self.listener.callCount(), 1) + self.assertTrue(self.curves_roi_widget.roiTable.activeRoi == roi1) + self.assertTrue(self.listener.arguments()[0][0]['current'] == 'linear') + self.listener.clear() + + self.curves_roi_widget.roiTable.deleteActiveRoi() + self.assertEqual(self.listener.callCount(), 1) + self.assertTrue(self.curves_roi_widget.roiTable.activeRoi is None) + self.assertTrue(self.listener.arguments()[0][0]['current'] is None) + self.listener.clear() + + self.curves_roi_widget.roiTable.addRoi(roi1) + self.assertEqual(self.listener.callCount(), 1) + self.assertTrue(self.listener.arguments()[0][0]['current'] == 'linear') + self.assertTrue(self.curves_roi_widget.roiTable.activeRoi == roi1) + self.listener.clear() + self.qapp.processEvents() + + self.curves_roi_widget.roiTable.removeROI(roi1) + self.qapp.processEvents() + self.assertEqual(self.listener.callCount(), 1) + self.assertTrue(self.listener.arguments()[0][0]['current'] == 'ICR') + self.listener.clear() + + def testSigROISignalModifyROI(self): + """Test SigROISignal when modifying it""" + self.curves_roi_widget.roiTable.setMiddleROIMarkerFlag(True) + roi1 = CurvesROIWidget.ROI(name='linear', fromdata=2, todata=5) + self.curves_roi_widget.roiTable.addRoi(roi1) + self.curves_roi_widget.roiTable.setActiveRoi(roi1) + + # test modify the roi2 object + self.listener.clear() + roi1.setFrom(0.56) + self.assertEqual(self.listener.callCount(), 1) + self.listener.clear() + roi1.setTo(2.56) + self.assertEqual(self.listener.callCount(), 1) + self.listener.clear() + roi1.setName('linear2') + self.assertEqual(self.listener.callCount(), 1) + self.listener.clear() + roi1.setType('new type') + self.assertEqual(self.listener.callCount(), 1) + + widget = self.plot.getWidgetHandle() + widget.setFocus(qt.Qt.OtherFocusReason) + self.plot.raise_() + self.qapp.processEvents() + + # modify roi limits (from the gui) + roi_marker_handler = self.curves_roi_widget.roiTable._markersHandler.getMarkerHandler(roi1.getID()) + for marker_type in ('min', 'max', 'middle'): + with self.subTest(marker_type=marker_type): + self.listener.clear() + marker = roi_marker_handler.getMarker(marker_type) + x_pix, y_pix = self.plot.dataToPixel(marker.getXPosition(), marker.getYPosition()) + self.mouseMove(widget, pos=(x_pix, y_pix)) + self.qWait(100) + self.mousePress(widget, qt.Qt.LeftButton, pos=(x_pix, y_pix)) + self.mouseMove(widget, pos=(x_pix+20, y_pix)) + self.qWait(100) + self.mouseRelease(widget, qt.Qt.LeftButton, pos=(x_pix+20, y_pix)) + self.qWait(100) + self.mouseMove(widget, pos=(x_pix, y_pix)) + self.qapp.processEvents() + self.assertEqual(self.listener.callCount(), 1) + + def testSetActiveCurve(self): + """Test sigRoiSignal when set an active curve""" + roi1 = CurvesROIWidget.ROI(name='linear', fromdata=2, todata=5) + self.curves_roi_widget.roiTable.setActiveRoi(roi1) + self.listener.clear() + self.plot.setActiveCurve('curve0') + self.assertEqual(self.listener.callCount(), 0) diff --git a/src/silx/gui/plot/test/testImageStack.py b/src/silx/gui/plot/test/testImageStack.py new file mode 100644 index 0000000..5c44691 --- /dev/null +++ b/src/silx/gui/plot/test/testImageStack.py @@ -0,0 +1,186 @@ +# coding: utf-8 +# /*########################################################################## +# +# Copyright (c) 2020 European Synchrotron Radiation Facility +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. +# +# ###########################################################################*/ +"""Basic tests for ImageStack""" + +__authors__ = ["H. Payno"] +__license__ = "MIT" +__date__ = "15/01/2020" + + +import unittest +import tempfile +import numpy +import h5py + +from silx.gui import qt +from silx.gui.utils.testutils import TestCaseQt +from silx.io.url import DataUrl +from silx.gui.plot.ImageStack import ImageStack +from silx.gui.utils.testutils import SignalListener +from collections import OrderedDict +import os +import time +import shutil + + +class TestImageStack(TestCaseQt): + """Simple test of the Image stack""" + + def setUp(self): + TestCaseQt.setUp(self) + self.urls = OrderedDict() + self._raw_data = {} + self._folder = tempfile.mkdtemp() + self._n_urls = 10 + file_name = os.path.join(self._folder, 'test_inage_stack_file.h5') + with h5py.File(file_name, 'w') as h5f: + for i in range(self._n_urls): + width = numpy.random.randint(10, 40) + height = numpy.random.randint(10, 40) + raw_data = numpy.random.random((width, height)) + self._raw_data[i] = raw_data + h5f[str(i)] = raw_data + self.urls[i] = DataUrl(file_path=file_name, + data_path=str(i), + scheme='silx') + self.widget = ImageStack() + + self.urlLoadedListener = SignalListener() + self.widget.sigLoaded.connect(self.urlLoadedListener) + + self.currentUrlChangedListener = SignalListener() + self.widget.sigCurrentUrlChanged.connect(self.currentUrlChangedListener) + + def tearDown(self): + shutil.rmtree(self._folder) + self.widget.setAttribute(qt.Qt.WA_DeleteOnClose, True) + self.widget.close() + TestCaseQt.setUp(self) + + def testControls(self): + """Test that selection using the url table and the slider are working + """ + self.widget.show() + self.assertEqual(self.widget.getCurrentUrl(), None) + self.assertEqual(self.widget.getCurrentUrlIndex(), None) + self.widget.setUrls(list(self.urls.values())) + + # wait for image to be loaded + self._waitUntilUrlLoaded() + + self.assertEqual(self.widget.getCurrentUrl(), self.urls[0]) + + # make sure all image are loaded + self.assertEqual(self.urlLoadedListener.callCount(), self._n_urls) + numpy.testing.assert_array_equal( + self.widget.getPlotWidget().getActiveImage(just_legend=False).getData(), + self._raw_data[0]) + self.assertEqual(self.widget._slider.value(), 0) + + self.widget._urlsTable.setUrl(self.urls[4]) + numpy.testing.assert_array_equal( + self.widget.getPlotWidget().getActiveImage(just_legend=False).getData(), + self._raw_data[4]) + self.assertEqual(self.widget._slider.value(), 4) + self.assertEqual(self.widget.getCurrentUrl(), self.urls[4]) + self.assertEqual(self.widget.getCurrentUrlIndex(), 4) + + self.widget._slider.setUrlIndex(6) + numpy.testing.assert_array_equal( + self.widget.getPlotWidget().getActiveImage(just_legend=False).getData(), + self._raw_data[6]) + self.assertEqual(self.widget._urlsTable.currentItem().text(), + self.urls[6].path()) + + def testCurrentUrlSignals(self): + """Test emission of 'currentUrlChangedListener'""" + # check initialization + self.assertEqual(self.currentUrlChangedListener.callCount(), 0) + self.widget.setUrls(list(self.urls.values())) + self.qapp.processEvents() + time.sleep(0.5) + self.qapp.processEvents() + # once loaded the two signals should have been sended + self.assertEqual(self.currentUrlChangedListener.callCount(), 1) + # if the slider is stuck to the same position no signal should be + # emitted + self.qapp.processEvents() + time.sleep(0.5) + self.qapp.processEvents() + self.assertEqual(self.widget._slider.value(), 0) + self.assertEqual(self.currentUrlChangedListener.callCount(), 1) + # if slider position is changed, one of each signal should have been + # emitted + self.widget._urlsTable.setUrl(self.urls[4]) + self.qapp.processEvents() + time.sleep(1.5) + self.qapp.processEvents() + self.assertEqual(self.currentUrlChangedListener.callCount(), 2) + + def testUtils(self): + """Test that some utils functions are working""" + self.widget.show() + self.widget.setUrls(list(self.urls.values())) + self.assertEqual(len(self.widget.getUrls()), len(self.urls)) + + # wait for image to be loaded + self._waitUntilUrlLoaded() + + urls_values = list(self.urls.values()) + self.assertEqual(urls_values[0], self.urls[0]) + self.assertEqual(urls_values[7], self.urls[7]) + + self.assertEqual(self.widget._getNextUrl(urls_values[2]).path(), + urls_values[3].path()) + self.assertEqual(self.widget._getPreviousUrl(urls_values[0]), None) + self.assertEqual(self.widget._getPreviousUrl(urls_values[6]).path(), + urls_values[5].path()) + + self.assertEqual(self.widget._getNNextUrls(2, urls_values[0]), + urls_values[1:3]) + self.assertEqual(self.widget._getNNextUrls(5, urls_values[7]), + urls_values[8:]) + self.assertEqual(self.widget._getNPreviousUrls(3, urls_values[2]), + urls_values[:2]) + self.assertEqual(self.widget._getNPreviousUrls(5, urls_values[8]), + urls_values[3:8]) + + def _waitUntilUrlLoaded(self, timeout=2.0): + """Wait until all image urls are loaded""" + loop_duration = 0.2 + remaining_duration = timeout + while(len(self.widget._loadingThreads) > 0 and remaining_duration > 0): + remaining_duration -= loop_duration + time.sleep(loop_duration) + self.qapp.processEvents() + + if remaining_duration <= 0.0: + remaining_urls = [] + for thread_ in self.widget._loadingThreads: + remaining_urls.append(thread_.url.path()) + mess = 'All images are not loaded after the time out. ' \ + 'Remaining urls are: ' + str(remaining_urls) + raise TimeoutError(mess) + return True diff --git a/src/silx/gui/plot/test/testImageView.py b/src/silx/gui/plot/test/testImageView.py new file mode 100644 index 0000000..7c1355f --- /dev/null +++ b/src/silx/gui/plot/test/testImageView.py @@ -0,0 +1,194 @@ +# coding: utf-8 +# /*########################################################################## +# +# Copyright (c) 2017-2021 European Synchrotron Radiation Facility +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. +# +# ###########################################################################*/ +"""Basic tests for PlotWindow""" + +__authors__ = ["T. Vincent"] +__license__ = "MIT" +__date__ = "24/04/2018" + + +import numpy + +from silx.gui import qt +from silx.gui.utils.testutils import TestCaseQt +from silx.gui.plot import items + +from silx.gui.plot.ImageView import ImageView +from silx.gui.colors import Colormap + + +class TestImageView(TestCaseQt): + """Tests of ImageView widget.""" + + def setUp(self): + super(TestImageView, self).setUp() + self.plot = ImageView() + self.plot.show() + self.qWaitForWindowExposed(self.plot) + + def tearDown(self): + self.qapp.processEvents() + self.plot.setAttribute(qt.Qt.WA_DeleteOnClose) + self.plot.close() + del self.plot + self.qapp.processEvents() + super(TestImageView, self).tearDown() + + def testSetImage(self): + """Test setImage""" + image = numpy.arange(100).reshape(10, 10) + + self.plot.setImage(image, reset=True) + self.qWait(100) + self.assertEqual(self.plot.getXAxis().getLimits(), (0, 10)) + self.assertEqual(self.plot.getYAxis().getLimits(), (0, 10)) + + # With reset=False + self.plot.setImage(image[::2, ::2], reset=False) + self.qWait(100) + self.assertEqual(self.plot.getXAxis().getLimits(), (0, 10)) + self.assertEqual(self.plot.getYAxis().getLimits(), (0, 10)) + + self.plot.setImage(image, origin=(10, 20), scale=(2, 4), reset=False) + self.qWait(100) + self.assertEqual(self.plot.getXAxis().getLimits(), (0, 10)) + self.assertEqual(self.plot.getYAxis().getLimits(), (0, 10)) + + # With reset=True + self.plot.setImage(image, origin=(1, 2), scale=(1, 0.5), reset=True) + self.qWait(100) + self.assertEqual(self.plot.getXAxis().getLimits(), (1, 11)) + self.assertEqual(self.plot.getYAxis().getLimits(), (2, 7)) + + self.plot.setImage(image[::2, ::2], reset=True) + self.qWait(100) + self.assertEqual(self.plot.getXAxis().getLimits(), (0, 5)) + self.assertEqual(self.plot.getYAxis().getLimits(), (0, 5)) + + def testColormap(self): + """Test get|setColormap""" + image = numpy.arange(100).reshape(10, 10) + self.plot.setImage(image) + + # Colormap as dict + self.plot.setColormap({'name': 'viridis', + 'normalization': 'log', + 'autoscale': False, + 'vmin': 0, + 'vmax': 1}) + colormap = self.plot.getColormap() + self.assertEqual(colormap.getName(), 'viridis') + self.assertEqual(colormap.getNormalization(), 'log') + self.assertEqual(colormap.getVMin(), 0) + self.assertEqual(colormap.getVMax(), 1) + + # Colormap as keyword arguments + self.plot.setColormap(colormap='magma', + normalization='linear', + autoscale=True, + vmin=1, + vmax=2) + self.assertEqual(colormap.getName(), 'magma') + self.assertEqual(colormap.getNormalization(), 'linear') + self.assertEqual(colormap.getVMin(), None) + self.assertEqual(colormap.getVMax(), None) + + # Update colormap with keyword argument + self.plot.setColormap(normalization='log') + self.assertEqual(colormap.getNormalization(), 'log') + + # Colormap as Colormap object + cmap = Colormap() + self.plot.setColormap(cmap) + self.assertIs(self.plot.getColormap(), cmap) + + def testSetProfileWindowBehavior(self): + """Test change of profile window display behavior""" + self.assertIs( + self.plot.getProfileWindowBehavior(), + ImageView.ProfileWindowBehavior.POPUP, + ) + + self.plot.setProfileWindowBehavior('embedded') + self.assertIs( + self.plot.getProfileWindowBehavior(), + ImageView.ProfileWindowBehavior.EMBEDDED, + ) + + image = numpy.arange(100).reshape(10, 10) + self.plot.setImage(image) + + self.plot.setProfileWindowBehavior( + ImageView.ProfileWindowBehavior.POPUP + ) + self.assertIs( + self.plot.getProfileWindowBehavior(), + ImageView.ProfileWindowBehavior.POPUP, + ) + + def testRGBImage(self): + """Test setImage""" + image = numpy.arange(100 * 3, dtype=numpy.uint8).reshape(10, 10, 3) + + self.plot.setImage(image, reset=True) + self.qWait(100) + self.assertEqual(self.plot.getXAxis().getLimits(), (0, 10)) + self.assertEqual(self.plot.getYAxis().getLimits(), (0, 10)) + + def testRGBAImage(self): + """Test setImage""" + image = numpy.arange(100 * 4, dtype=numpy.uint8).reshape(10, 10, 4) + + self.plot.setImage(image, reset=True) + self.qWait(100) + self.assertEqual(self.plot.getXAxis().getLimits(), (0, 10)) + self.assertEqual(self.plot.getYAxis().getLimits(), (0, 10)) + + def testImageAggregationMode(self): + """Test setImage""" + image = numpy.arange(100).reshape(10, 10) + self.plot.setImage(image, reset=True) + self.qWait(100) + self.plot.getAggregationModeAction().setAggregationMode(items.ImageDataAggregated.Aggregation.MAX) + self.qWait(100) + + def testImageAggregationModeBackToNormalMode(self): + """Test setImage""" + image = numpy.arange(100).reshape(10, 10) + self.plot.setImage(image, reset=True) + self.qWait(100) + self.plot.getAggregationModeAction().setAggregationMode(items.ImageDataAggregated.Aggregation.MAX) + self.qWait(100) + self.plot.getAggregationModeAction().setAggregationMode(items.ImageDataAggregated.Aggregation.NONE) + self.qWait(100) + + def testRGBAInAggregationMode(self): + """Test setImage""" + image = numpy.arange(100 * 3, dtype=numpy.uint8).reshape(10, 10, 3) + + self.plot.setImage(image, reset=True) + self.qWait(100) + self.plot.getAggregationModeAction().setAggregationMode(items.ImageDataAggregated.Aggregation.MAX) + self.qWait(100) diff --git a/src/silx/gui/plot/test/testInteraction.py b/src/silx/gui/plot/test/testInteraction.py new file mode 100644 index 0000000..d136b21 --- /dev/null +++ b/src/silx/gui/plot/test/testInteraction.py @@ -0,0 +1,78 @@ +# coding: utf-8 +# /*########################################################################## +# +# Copyright (c) 2016-2020 European Synchrotron Radiation Facility +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. +# +# ###########################################################################*/ +"""Tests from interaction state machines""" + +__authors__ = ["T. Vincent"] +__license__ = "MIT" +__date__ = "18/02/2016" + + +import unittest + +from silx.gui.plot import Interaction + + +class TestInteraction(unittest.TestCase): + def testClickOrDrag(self): + """Minimalistic test for click or drag state machine.""" + events = [] + + class TestClickOrDrag(Interaction.ClickOrDrag): + def click(self, x, y, btn): + events.append(('click', x, y, btn)) + + def beginDrag(self, x, y, btn): + events.append(('beginDrag', x, y, btn)) + + def drag(self, x, y, btn): + events.append(('drag', x, y, btn)) + + def endDrag(self, start, end, btn): + events.append(('endDrag', start, end, btn)) + + clickOrDrag = TestClickOrDrag() + + # click + clickOrDrag.handleEvent('press', 10, 10, Interaction.LEFT_BTN) + self.assertEqual(len(events), 0) + + clickOrDrag.handleEvent('release', 10, 10, Interaction.LEFT_BTN) + self.assertEqual(len(events), 1) + self.assertEqual(events[0], ('click', 10, 10, Interaction.LEFT_BTN)) + + # drag + events = [] + clickOrDrag.handleEvent('press', 10, 10, Interaction.LEFT_BTN) + self.assertEqual(len(events), 0) + clickOrDrag.handleEvent('move', 15, 10) + self.assertEqual(len(events), 2) # Received beginDrag and drag + self.assertEqual(events[0], ('beginDrag', 10, 10, Interaction.LEFT_BTN)) + self.assertEqual(events[1], ('drag', 15, 10, Interaction.LEFT_BTN)) + clickOrDrag.handleEvent('move', 20, 10) + self.assertEqual(len(events), 3) + self.assertEqual(events[-1], ('drag', 20, 10, Interaction.LEFT_BTN)) + clickOrDrag.handleEvent('release', 20, 10, Interaction.LEFT_BTN) + self.assertEqual(len(events), 4) + self.assertEqual(events[-1], ('endDrag', (10, 10), (20, 10), Interaction.LEFT_BTN)) diff --git a/src/silx/gui/plot/test/testItem.py b/src/silx/gui/plot/test/testItem.py new file mode 100644 index 0000000..0b15dc3 --- /dev/null +++ b/src/silx/gui/plot/test/testItem.py @@ -0,0 +1,360 @@ +# coding: utf-8 +# /*########################################################################## +# +# Copyright (c) 2017-2021 European Synchrotron Radiation Facility +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. +# +# ###########################################################################*/ +"""Tests for PlotWidget items.""" + +__authors__ = ["T. Vincent"] +__license__ = "MIT" +__date__ = "01/09/2017" + + +import unittest + +import numpy + +from silx.gui.utils.testutils import SignalListener +from silx.gui.plot.items import ItemChangedType +from silx.gui.plot import items +from .utils import PlotWidgetTestCase + + +class TestSigItemChangedSignal(PlotWidgetTestCase): + """Test item's sigItemChanged signal""" + + def testCurveChanged(self): + """Test sigItemChanged for curve""" + self.plot.addCurve(numpy.arange(10), numpy.arange(10), legend='test') + curve = self.plot.getCurve('test') + + listener = SignalListener() + curve.sigItemChanged.connect(listener) + + # Test for signal in Item class + curve.setVisible(False) + curve.setVisible(True) + curve.setZValue(100) + + # Test for signals in PointsBase class + curve.setData(numpy.arange(100), numpy.arange(100)) + + # SymbolMixIn + curve.setSymbol('Circle') + curve.setSymbol('d') + curve.setSymbolSize(20) + + # AlphaMixIn + curve.setAlpha(0.5) + + # Test for signals in Curve class + # ColorMixIn + curve.setColor('yellow') + # YAxisMixIn + curve.setYAxis('right') + # FillMixIn + curve.setFill(True) + # LineMixIn + curve.setLineStyle(':') + curve.setLineStyle(':') # Not sending event + curve.setLineWidth(2) + + self.assertEqual(listener.arguments(argumentIndex=0), + [ItemChangedType.VISIBLE, + ItemChangedType.VISIBLE, + ItemChangedType.ZVALUE, + ItemChangedType.DATA, + ItemChangedType.SYMBOL, + ItemChangedType.SYMBOL, + ItemChangedType.SYMBOL_SIZE, + ItemChangedType.ALPHA, + ItemChangedType.COLOR, + ItemChangedType.YAXIS, + ItemChangedType.FILL, + ItemChangedType.LINE_STYLE, + ItemChangedType.LINE_WIDTH]) + + def testHistogramChanged(self): + """Test sigItemChanged for Histogram""" + self.plot.addHistogram( + numpy.arange(10), edges=numpy.arange(11), legend='test') + histogram = self.plot.getHistogram('test') + listener = SignalListener() + histogram.sigItemChanged.connect(listener) + + # Test signals in Histogram class + histogram.setData(numpy.zeros(10), numpy.arange(11)) + + self.assertEqual(listener.arguments(argumentIndex=0), + [ItemChangedType.DATA]) + + def testImageDataChanged(self): + """Test sigItemChanged for ImageData""" + self.plot.addImage(numpy.arange(100).reshape(10, 10), legend='test') + image = self.plot.getImage('test') + + listener = SignalListener() + image.sigItemChanged.connect(listener) + + # ColormapMixIn + colormap = self.plot.getDefaultColormap().copy() + image.setColormap(colormap) + image.getColormap().setName('viridis') + + # Test of signals in ImageBase class + image.setOrigin(10) + image.setScale(2) + + # Test of signals in ImageData class + image.setData(numpy.ones((10, 10))) + + self.assertEqual(listener.arguments(argumentIndex=0), + [ItemChangedType.COLORMAP, + ItemChangedType.COLORMAP, + ItemChangedType.POSITION, + ItemChangedType.SCALE, + ItemChangedType.COLORMAP, + ItemChangedType.DATA]) + + def testImageRgbaChanged(self): + """Test sigItemChanged for ImageRgba""" + self.plot.addImage(numpy.ones((10, 10, 3)), legend='rgb') + image = self.plot.getImage('rgb') + + listener = SignalListener() + image.sigItemChanged.connect(listener) + + # Test of signals in ImageRgba class + image.setData(numpy.zeros((10, 10, 3))) + + self.assertEqual(listener.arguments(argumentIndex=0), + [ItemChangedType.DATA]) + + def testMarkerChanged(self): + """Test sigItemChanged for markers""" + self.plot.addMarker(10, 20, legend='test') + marker = self.plot._getMarker('test') + + listener = SignalListener() + marker.sigItemChanged.connect(listener) + + # Test signals in _BaseMarker + marker.setPosition(10, 10) + marker.setPosition(10, 10) # Not sending event + marker.setText('toto') + self.assertEqual(listener.arguments(argumentIndex=0), + [ItemChangedType.POSITION, + ItemChangedType.TEXT]) + + # XMarker + self.plot.addXMarker(10, legend='x') + marker = self.plot._getMarker('x') + + listener = SignalListener() + marker.sigItemChanged.connect(listener) + marker.setPosition(20, 20) + self.assertEqual(listener.arguments(argumentIndex=0), + [ItemChangedType.POSITION]) + + # YMarker + self.plot.addYMarker(10, legend='x') + marker = self.plot._getMarker('x') + + listener = SignalListener() + marker.sigItemChanged.connect(listener) + marker.setPosition(20, 20) + self.assertEqual(listener.arguments(argumentIndex=0), + [ItemChangedType.POSITION]) + + def testScatterChanged(self): + """Test sigItemChanged for scatter""" + data = numpy.arange(10) + self.plot.addScatter(data, data, data, legend='test') + scatter = self.plot.getScatter('test') + + listener = SignalListener() + scatter.sigItemChanged.connect(listener) + + # ColormapMixIn + scatter.getColormap().setName('viridis') + + # Test of signals in Scatter class + scatter.setData((0, 1, 2), (1, 0, 2), (0, 1, 2)) + + # Visualization mode changed + scatter.setVisualization(scatter.Visualization.SOLID) + + self.assertEqual(listener.arguments(), + [(ItemChangedType.COLORMAP,), + (ItemChangedType.DATA,), + (ItemChangedType.COLORMAP,), + (ItemChangedType.VISUALIZATION_MODE,)]) + + def testShapeChanged(self): + """Test sigItemChanged for shape""" + data = numpy.array((1., 10.)) + self.plot.addShape(data, data, legend='test', shape='rectangle') + shape = self.plot._getItem(kind='item', legend='test') + + listener = SignalListener() + shape.sigItemChanged.connect(listener) + + shape.setOverlay(True) + shape.setPoints(((2., 2.), (3., 3.))) + + self.assertEqual(listener.arguments(), + [(ItemChangedType.OVERLAY,), + (ItemChangedType.DATA,)]) + + +class TestSymbol(PlotWidgetTestCase): + """Test item's symbol """ + + def test(self): + """Test sigItemChanged for curve""" + self.plot.addCurve(numpy.arange(10), numpy.arange(10), legend='test') + curve = self.plot.getCurve('test') + + # SymbolMixIn + curve.setSymbol('o') + name = curve.getSymbolName() + self.assertEqual('Circle', name) + + name = curve.getSymbolName('d') + self.assertEqual('Diamond', name) + + +class TestVisibleExtent(PlotWidgetTestCase): + """Test item's visible extent feature""" + + def testGetVisibleBounds(self): + """Test Item.getVisibleBounds""" + + # Create test items (with a bounding box of x: [1,3], y: [0,2]) + curve = items.Curve() + curve.setData((1, 2, 3), (0, 1, 2)) + + histogram = items.Histogram() + histogram.setData((0, 1, 2), (1, 5/3, 7/3, 3)) + + image = items.ImageData() + image.setOrigin((1, 0)) + image.setData(numpy.arange(4).reshape(2, 2)) + + scatter = items.Scatter() + scatter.setData((1, 2, 3), (0, 1, 2), (1, 2, 3)) + + bbox = items.BoundingRect() + bbox.setBounds((1, 3, 0, 2)) + + xaxis, yaxis = self.plot.getXAxis(), self.plot.getYAxis() + for item in (curve, histogram, image, scatter, bbox): + with self.subTest(item=item): + xaxis.setLimits(0, 100) + yaxis.setLimits(0, 100) + self.plot.addItem(item) + self.assertEqual(item.getVisibleBounds(), (1., 3., 0., 2.)) + + xaxis.setLimits(0.5, 2.5) + self.assertEqual(item.getVisibleBounds(), (1, 2.5, 0., 2.)) + + yaxis.setLimits(0.5, 1.5) + self.assertEqual(item.getVisibleBounds(), (1, 2.5, 0.5, 1.5)) + + item.setVisible(False) + self.assertIsNone(item.getVisibleBounds()) + + self.plot.clear() + + def testVisibleExtentTracking(self): + """Test Item's visible extent tracking""" + image = items.ImageData() + image.setData(numpy.arange(6).reshape(2, 3)) + + listener = SignalListener() + image._sigVisibleBoundsChanged.connect(listener) + image._setVisibleBoundsTracking(True) + self.assertTrue(image._isVisibleBoundsTracking()) + + self.plot.addItem(image) + self.assertEqual(listener.callCount(), 1) + + self.plot.getXAxis().setLimits(0, 1) + self.assertEqual(listener.callCount(), 2) + + self.plot.hide() + self.qapp.processEvents() + # No event here + self.assertEqual(listener.callCount(), 2) + + self.plot.getXAxis().setLimits(1, 2) + # No event since PlotWidget is hidden, delayed to PlotWidget show + self.assertEqual(listener.callCount(), 2) + + self.plot.show() + self.qapp.processEvents() + # Receives delayed event now + self.assertEqual(listener.callCount(), 3) + + image.setOrigin((-1, -1)) + self.assertEqual(listener.callCount(), 4) + + image.setVisible(False) + image.setOrigin((0, 0)) + # No event since item is not visible + self.assertEqual(listener.callCount(), 4) + + image.setVisible(True) + # Receives delayed event now + self.assertEqual(listener.callCount(), 5) + + +class TestImageDataAggregated(PlotWidgetTestCase): + """Test ImageDataAggregated item""" + + def test(self): + data = numpy.random.random(1024**2).reshape(1024, 1024) + + item = items.ImageDataAggregated() + item.setData(data) + self.assertEqual(item.getAggregationMode(), item.Aggregation.NONE) + self.plot.addItem(item) + + for mode in item.Aggregation.members(): + with self.subTest(mode=mode): + self.plot.resetZoom() + self.qapp.processEvents() + + item.setAggregationMode(mode) + self.qapp.processEvents() + + # Zoom-out + for i in range(4): + xmin, xmax = self.plot.getXAxis().getLimits() + ymin, ymax = self.plot.getYAxis().getLimits() + self.plot.setLimits( + xmin - (xmax - xmin)/2, + xmax + (xmax - xmin)/2, + ymin - (ymax - ymin)/2, + ymax + (ymax - ymin)/2, + ) + self.qapp.processEvents() diff --git a/src/silx/gui/plot/test/testLegendSelector.py b/src/silx/gui/plot/test/testLegendSelector.py new file mode 100644 index 0000000..c40875d --- /dev/null +++ b/src/silx/gui/plot/test/testLegendSelector.py @@ -0,0 +1,130 @@ +# coding: utf-8 +# /*########################################################################## +# +# Copyright (c) 2004-2016 European Synchrotron Radiation Facility +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. +# +# ###########################################################################*/ +"""Basic tests for PlotWidget""" + +__authors__ = ["T. Rueter", "T. Vincent"] +__license__ = "MIT" +__date__ = "15/05/2017" + + +import logging +import unittest + +from silx.gui import qt +from silx.gui.utils.testutils import TestCaseQt +from silx.gui.plot import LegendSelector + + +_logger = logging.getLogger(__name__) + + +class TestLegendSelector(TestCaseQt): + """Basic test for LegendSelector""" + + def testLegendSelector(self): + """Test copied from __main__ of LegendSelector in PyMca""" + class Notifier(qt.QObject): + def __init__(self): + qt.QObject.__init__(self) + self.chk = True + + def signalReceived(self, **kw): + obj = self.sender() + _logger.info('NOTIFIER -- signal received\n\tsender: %s', + str(obj)) + + notifier = Notifier() + + legends = ['Legend0', + 'Legend1', + 'Long Legend 2', + 'Foo Legend 3', + 'Even Longer Legend 4', + 'Short Leg 5', + 'Dot symbol 6', + 'Comma symbol 7'] + colors = [qt.Qt.darkRed, qt.Qt.green, qt.Qt.yellow, qt.Qt.darkCyan, + qt.Qt.blue, qt.Qt.darkBlue, qt.Qt.red, qt.Qt.darkYellow] + symbols = ['o', 't', '+', 'x', 's', 'd', '.', ','] + + win = LegendSelector.LegendListView() + # win = LegendListContextMenu() + # win = qt.QWidget() + # layout = qt.QVBoxLayout() + # layout.setContentsMargins(0,0,0,0) + llist = [] + + for _idx, (l, c, s) in enumerate(zip(legends, colors, symbols)): + ddict = { + 'color': qt.QColor(c), + 'linewidth': 4, + 'symbol': s, + } + legend = l + llist.append((legend, ddict)) + # item = qt.QListWidgetItem(win) + # legendWidget = LegendListItemWidget(l) + # legendWidget.icon.setSymbol(s) + # legendWidget.icon.setColor(qt.QColor(c)) + # layout.addWidget(legendWidget) + # win.setItemWidget(item, legendWidget) + + # win = LegendListItemWidget('Some Legend 1') + # print(llist) + model = LegendSelector.LegendModel(legendList=llist) + win.setModel(model) + win.setSelectionModel(qt.QItemSelectionModel(model)) + win.setContextMenu() + # print('Edit triggers: %d'%win.editTriggers()) + + # win = LegendListWidget(None, legends) + # win[0].updateItem(ddict) + # win.setLayout(layout) + win.sigLegendSignal.connect(notifier.signalReceived) + win.show() + + win.clear() + win.setLegendList(llist) + + self.qWaitForWindowExposed(win) + + +class TestRenameCurveDialog(TestCaseQt): + """Basic test for RenameCurveDialog""" + + def testDialog(self): + """Create dialog, change name and press OK""" + self.dialog = LegendSelector.RenameCurveDialog( + None, 'curve1', ['curve1', 'curve2', 'curve3']) + self.dialog.open() + self.qWaitForWindowExposed(self.dialog) + self.keyClicks(self.dialog.lineEdit, 'changed') + self.mouseClick(self.dialog.okButton, qt.Qt.LeftButton) + self.qapp.processEvents() + ret = self.dialog.result() + self.assertEqual(ret, qt.QDialog.Accepted) + newName = self.dialog.getText() + self.assertEqual(newName, 'curve1changed') + del self.dialog diff --git a/src/silx/gui/plot/test/testLimitConstraints.py b/src/silx/gui/plot/test/testLimitConstraints.py new file mode 100644 index 0000000..0bd8e50 --- /dev/null +++ b/src/silx/gui/plot/test/testLimitConstraints.py @@ -0,0 +1,114 @@ +# coding: utf-8 +# /*########################################################################## +# +# Copyright (c) 2016-2018 European Synchrotron Radiation Facility +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. +# +# ###########################################################################*/ +"""Test setLimitConstaints on the PlotWidget""" + +__authors__ = ["V. Valls"] +__license__ = "MIT" +__date__ = "30/08/2017" + + +import unittest +from silx.gui.plot import PlotWidget + + +class TestLimitConstaints(unittest.TestCase): + """Tests setLimitConstaints class""" + + def setUp(self): + self.plot = PlotWidget() + + def tearDown(self): + self.plot = None + + def testApi(self): + """Test availability of the API""" + self.plot.getXAxis().setLimitsConstraints(minPos=1, maxPos=10) + self.plot.getXAxis().setRangeConstraints(minRange=1, maxRange=1) + self.plot.getYAxis().setLimitsConstraints(minPos=1, maxPos=10) + self.plot.getYAxis().setRangeConstraints(minRange=1, maxRange=1) + + def testXMinMax(self): + """Test limit constains on x-axis""" + self.plot.getXAxis().setLimitsConstraints(minPos=0, maxPos=100) + self.plot.setLimits(xmin=-1, xmax=101, ymin=-1, ymax=101) + self.assertEqual(self.plot.getXAxis().getLimits(), (0, 100)) + self.assertEqual(self.plot.getYAxis().getLimits(), (-1, 101)) + + def testYMinMax(self): + """Test limit constains on y-axis""" + self.plot.getYAxis().setLimitsConstraints(minPos=0, maxPos=100) + self.plot.setLimits(xmin=-1, xmax=101, ymin=-1, ymax=101) + self.assertEqual(self.plot.getXAxis().getLimits(), (-1, 101)) + self.assertEqual(self.plot.getYAxis().getLimits(), (0, 100)) + + def testMinXRange(self): + """Test min range constains on x-axis""" + self.plot.getXAxis().setRangeConstraints(minRange=100) + self.plot.setLimits(xmin=1, xmax=99, ymin=1, ymax=99) + limits = self.plot.getXAxis().getLimits() + self.assertEqual(limits[1] - limits[0], 100) + limits = self.plot.getYAxis().getLimits() + self.assertNotEqual(limits[1] - limits[0], 100) + + def testMaxXRange(self): + """Test max range constains on x-axis""" + self.plot.getXAxis().setRangeConstraints(maxRange=100) + self.plot.setLimits(xmin=-1, xmax=101, ymin=-1, ymax=101) + limits = self.plot.getXAxis().getLimits() + self.assertEqual(limits[1] - limits[0], 100) + limits = self.plot.getYAxis().getLimits() + self.assertNotEqual(limits[1] - limits[0], 100) + + def testMinYRange(self): + """Test min range constains on y-axis""" + self.plot.getYAxis().setRangeConstraints(minRange=100) + self.plot.setLimits(xmin=1, xmax=99, ymin=1, ymax=99) + limits = self.plot.getXAxis().getLimits() + self.assertNotEqual(limits[1] - limits[0], 100) + limits = self.plot.getYAxis().getLimits() + self.assertEqual(limits[1] - limits[0], 100) + + def testMaxYRange(self): + """Test max range constains on y-axis""" + self.plot.getYAxis().setRangeConstraints(maxRange=100) + self.plot.setLimits(xmin=-1, xmax=101, ymin=-1, ymax=101) + limits = self.plot.getXAxis().getLimits() + self.assertNotEqual(limits[1] - limits[0], 100) + limits = self.plot.getYAxis().getLimits() + self.assertEqual(limits[1] - limits[0], 100) + + def testChangeOfConstraints(self): + """Test changing of the constraints""" + self.plot.getXAxis().setRangeConstraints(minRange=10, maxRange=10) + # There is no more constraints on the range + self.plot.getXAxis().setRangeConstraints(minRange=None, maxRange=None) + self.plot.setLimits(xmin=-1, xmax=101, ymin=-1, ymax=101) + self.assertEqual(self.plot.getXAxis().getLimits(), (-1, 101)) + + def testSettingConstraints(self): + """Test setting a constaint (setLimits first then the constaint)""" + self.plot.setLimits(xmin=-1, xmax=101, ymin=-1, ymax=101) + self.plot.getXAxis().setLimitsConstraints(minPos=0, maxPos=100) + self.assertEqual(self.plot.getXAxis().getLimits(), (0, 100)) diff --git a/src/silx/gui/plot/test/testMaskToolsWidget.py b/src/silx/gui/plot/test/testMaskToolsWidget.py new file mode 100644 index 0000000..522ca51 --- /dev/null +++ b/src/silx/gui/plot/test/testMaskToolsWidget.py @@ -0,0 +1,306 @@ +# coding: utf-8 +# /*########################################################################## +# +# Copyright (c) 2016-2017 European Synchrotron Radiation Facility +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. +# +# ###########################################################################*/ +"""Basic tests for MaskToolsWidget""" + +__authors__ = ["T. Vincent"] +__license__ = "MIT" +__date__ = "17/01/2018" + + +import logging +import os.path +import unittest + +import numpy + +from silx.gui import qt +from silx.test.utils import temp_dir +from silx.utils.testutils import ParametricTestCase +from silx.gui.utils.testutils import getQToolButtonFromAction +from silx.gui.plot import PlotWindow, MaskToolsWidget +from .utils import PlotWidgetTestCase + +import fabio + + +_logger = logging.getLogger(__name__) + + +class TestMaskToolsWidget(PlotWidgetTestCase, ParametricTestCase): + """Basic test for MaskToolsWidget""" + + def _createPlot(self): + return PlotWindow() + + def setUp(self): + super(TestMaskToolsWidget, self).setUp() + self.widget = MaskToolsWidget.MaskToolsDockWidget(plot=self.plot, name='TEST') + self.plot.addDockWidget(qt.Qt.BottomDockWidgetArea, self.widget) + self.maskWidget = self.widget.widget() + + def tearDown(self): + del self.maskWidget + del self.widget + super(TestMaskToolsWidget, self).tearDown() + + def testEmptyPlot(self): + """Empty plot, display MaskToolsDockWidget, toggle multiple masks""" + self.maskWidget.setMultipleMasks('single') + self.qapp.processEvents() + + self.maskWidget.setMultipleMasks('exclusive') + self.qapp.processEvents() + + def _drag(self): + """Drag from plot center to offset position""" + plot = self.plot.getWidgetHandle() + xCenter, yCenter = plot.width() // 2, plot.height() // 2 + offset = min(plot.width(), plot.height()) // 10 + + pos0 = xCenter, yCenter + pos1 = xCenter + offset, yCenter + offset + + self.mouseMove(plot, pos=(0, 0)) + self.mouseMove(plot, pos=pos0) + self.qapp.processEvents() + self.mousePress(plot, qt.Qt.LeftButton, pos=pos0) + self.qapp.processEvents() + self.mouseMove(plot, pos=(pos0[0] + offset // 2, pos0[1] + offset // 2)) + self.mouseMove(plot, pos=pos1) + self.qapp.processEvents() + self.mouseRelease(plot, qt.Qt.LeftButton, pos=pos1) + self.qapp.processEvents() + self.mouseMove(plot, pos=(0, 0)) + + def _drawPolygon(self): + """Draw a star polygon in the plot""" + plot = self.plot.getWidgetHandle() + x, y = plot.width() // 2, plot.height() // 2 + offset = min(plot.width(), plot.height()) // 10 + + star = [(x, y + offset), + (x - offset, y - offset), + (x + offset, y), + (x - offset, y), + (x + offset, y - offset), + (x, y + offset)] # Close polygon + + self.mouseMove(plot, pos=(0, 0)) + for pos in star: + self.mouseMove(plot, pos=pos) + self.qapp.processEvents() + self.mousePress(plot, qt.Qt.LeftButton, pos=pos) + self.qapp.processEvents() + self.mouseRelease(plot, qt.Qt.LeftButton, pos=pos) + self.qapp.processEvents() + + def _drawPencil(self): + """Draw a star polygon in the plot""" + plot = self.plot.getWidgetHandle() + x, y = plot.width() // 2, plot.height() // 2 + offset = min(plot.width(), plot.height()) // 10 + + star = [(x, y + offset), + (x - offset, y - offset), + (x + offset, y), + (x - offset, y), + (x + offset, y - offset)] + + self.mouseMove(plot, pos=(0, 0)) + for start, end in zip(star[:-1], star[1:]): + self.mouseMove(plot, pos=start) + self.mousePress(plot, qt.Qt.LeftButton, pos=start) + self.qapp.processEvents() + self.mouseMove(plot, pos=end) + self.qapp.processEvents() + self.mouseRelease(plot, qt.Qt.LeftButton, pos=end) + self.qapp.processEvents() + + def _isMaskItemSync(self): + """Check if masks from item and tools are sync or not""" + if self.maskWidget.isItemMaskUpdated(): + return numpy.all(numpy.equal( + self.maskWidget.getSelectionMask(), + self.plot.getActiveImage().getMaskData(copy=False))) + else: + return True + + def testWithAnImage(self): + """Plot with an image: test MaskToolsWidget interactions""" + + # Add and remove a image (this should enable/disable GUI + change mask) + self.plot.addImage(numpy.random.random(1024**2).reshape(1024, 1024), + legend='test') + self.qapp.processEvents() + + self.plot.remove('test', kind='image') + self.qapp.processEvents() + + tests = [((0, 0), (1, 1)), + ((1000, 1000), (1, 1)), + ((0, 0), (-1, -1)), + ((1000, 1000), (-1, -1))] + + for itemMaskUpdated in (False, True): + for origin, scale in tests: + with self.subTest(origin=origin, scale=scale): + self.maskWidget.setItemMaskUpdated(itemMaskUpdated) + self.plot.addImage(numpy.arange(1024**2).reshape(1024, 1024), + legend='test', + origin=origin, + scale=scale) + self.qapp.processEvents() + + self.assertEqual( + self.maskWidget.isItemMaskUpdated(), itemMaskUpdated) + + # Test draw rectangle # + toolButton = getQToolButtonFromAction(self.maskWidget.rectAction) + self.assertIsNot(toolButton, None) + self.mouseClick(toolButton, qt.Qt.LeftButton) + + # mask + self.maskWidget.maskStateGroup.button(1).click() + self.qapp.processEvents() + self._drag() + self.assertFalse( + numpy.all(numpy.equal(self.maskWidget.getSelectionMask(), 0))) + self.assertTrue(self._isMaskItemSync()) + + # unmask same region + self.maskWidget.maskStateGroup.button(0).click() + self.qapp.processEvents() + self._drag() + self.assertTrue( + numpy.all(numpy.equal(self.maskWidget.getSelectionMask(), 0))) + self.assertTrue(self._isMaskItemSync()) + + # Test draw polygon # + toolButton = getQToolButtonFromAction(self.maskWidget.polygonAction) + self.assertIsNot(toolButton, None) + self.mouseClick(toolButton, qt.Qt.LeftButton) + + # mask + self.maskWidget.maskStateGroup.button(1).click() + self.qapp.processEvents() + self._drawPolygon() + self.assertFalse( + numpy.all(numpy.equal(self.maskWidget.getSelectionMask(), 0))) + self.assertTrue(self._isMaskItemSync()) + + # unmask same region + self.maskWidget.maskStateGroup.button(0).click() + self.qapp.processEvents() + self._drawPolygon() + self.assertTrue( + numpy.all(numpy.equal(self.maskWidget.getSelectionMask(), 0))) + self.assertTrue(self._isMaskItemSync()) + + # Test draw pencil # + toolButton = getQToolButtonFromAction(self.maskWidget.pencilAction) + self.assertIsNot(toolButton, None) + self.mouseClick(toolButton, qt.Qt.LeftButton) + + self.maskWidget.pencilSpinBox.setValue(30) + self.qapp.processEvents() + + # mask + self.maskWidget.maskStateGroup.button(1).click() + self.qapp.processEvents() + self._drawPencil() + self.assertFalse( + numpy.all(numpy.equal(self.maskWidget.getSelectionMask(), 0))) + self.assertTrue(self._isMaskItemSync()) + + # unmask same region + self.maskWidget.maskStateGroup.button(0).click() + self.qapp.processEvents() + self._drawPencil() + self.assertTrue( + numpy.all(numpy.equal(self.maskWidget.getSelectionMask(), 0))) + self.assertTrue(self._isMaskItemSync()) + + # Test no draw tool # + toolButton = getQToolButtonFromAction(self.maskWidget.browseAction) + self.assertIsNot(toolButton, None) + self.mouseClick(toolButton, qt.Qt.LeftButton) + + self.plot.clear() + + def __loadSave(self, file_format): + """Plot with an image: test MaskToolsWidget operations""" + self.plot.addImage(numpy.arange(1024**2).reshape(1024, 1024), + legend='test') + self.qapp.processEvents() + + # Draw a polygon mask + toolButton = getQToolButtonFromAction(self.maskWidget.polygonAction) + self.assertIsNot(toolButton, None) + self.mouseClick(toolButton, qt.Qt.LeftButton) + self._drawPolygon() + + ref_mask = self.maskWidget.getSelectionMask() + self.assertFalse(numpy.all(numpy.equal(ref_mask, 0))) + + with temp_dir() as tmp: + mask_filename = os.path.join(tmp, 'mask.' + file_format) + self.maskWidget.save(mask_filename, file_format) + + self.maskWidget.resetSelectionMask() + self.assertTrue( + numpy.all(numpy.equal(self.maskWidget.getSelectionMask(), 0))) + + self.maskWidget.load(mask_filename) + self.assertTrue(numpy.all(numpy.equal( + self.maskWidget.getSelectionMask(), ref_mask))) + + def testLoadSaveNpy(self): + self.__loadSave("npy") + + def testLoadSaveFit2D(self): + self.__loadSave("msk") + + def testSigMaskChangedEmitted(self): + self.plot.addImage(numpy.arange(512**2).reshape(512, 512), + legend='test') + self.plot.resetZoom() + self.qapp.processEvents() + + l = [] + + def slot(): + l.append(1) + + self.maskWidget.sigMaskChanged.connect(slot) + + # rectangle mask + toolButton = getQToolButtonFromAction(self.maskWidget.rectAction) + self.assertIsNot(toolButton, None) + self.mouseClick(toolButton, qt.Qt.LeftButton) + self.maskWidget.maskStateGroup.button(1).click() + self.qapp.processEvents() + self._drag() + + self.assertGreater(len(l), 0) diff --git a/src/silx/gui/plot/test/testPixelIntensityHistoAction.py b/src/silx/gui/plot/test/testPixelIntensityHistoAction.py new file mode 100644 index 0000000..14a467d --- /dev/null +++ b/src/silx/gui/plot/test/testPixelIntensityHistoAction.py @@ -0,0 +1,145 @@ +# coding: utf-8 +# /*########################################################################## +# +# Copyright (c) 2016-2018 European Synchrotron Radiation Facility +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. +# +# ###########################################################################*/ +"""Basic tests for PixelIntensitiesHistoAction""" + +__authors__ = ["T. Vincent"] +__license__ = "MIT" +__date__ = "02/03/2018" + + +import numpy +import unittest + +from silx.utils.testutils import ParametricTestCase +from silx.gui.utils.testutils import TestCaseQt, getQToolButtonFromAction +from silx.gui import qt +from silx.gui.plot import Plot2D + + +class TestPixelIntensitiesHisto(TestCaseQt, ParametricTestCase): + """Tests for PixelIntensitiesHistoAction widget.""" + + def setUp(self): + super(TestPixelIntensitiesHisto, self).setUp() + self.image = numpy.random.rand(10, 10) + self.plotImage = Plot2D() + self.plotImage.getIntensityHistogramAction().setVisible(True) + + def tearDown(self): + del self.plotImage + super(TestPixelIntensitiesHisto, self).tearDown() + + def testShowAndHide(self): + """Simple test that the plot is showing and hiding when activating the + action""" + self.plotImage.addImage(self.image, origin=(0, 0), legend='sino') + self.plotImage.show() + + histoAction = self.plotImage.getIntensityHistogramAction() + + # test the pixel intensity diagram is showing + button = getQToolButtonFromAction(histoAction) + self.assertIsNot(button, None) + self.mouseMove(button) + self.mouseClick(button, qt.Qt.LeftButton) + self.qapp.processEvents() + self.assertTrue(histoAction.getHistogramWidget().isVisible()) + + # test the pixel intensity diagram is hiding + self.qapp.setActiveWindow(self.plotImage) + self.qapp.processEvents() + self.mouseMove(button) + self.mouseClick(button, qt.Qt.LeftButton) + self.qapp.processEvents() + self.assertFalse(histoAction.getHistogramWidget().isVisible()) + + def testImageFormatInput(self): + """Test multiple type as image input""" + typesToTest = [numpy.uint8, numpy.int8, numpy.int16, numpy.int32, + numpy.float32, numpy.float64] + self.plotImage.addImage(self.image, origin=(0, 0), legend='sino') + self.plotImage.show() + button = getQToolButtonFromAction( + self.plotImage.getIntensityHistogramAction()) + self.mouseMove(button) + self.mouseClick(button, qt.Qt.LeftButton) + self.qapp.processEvents() + for typeToTest in typesToTest: + with self.subTest(typeToTest=typeToTest): + self.plotImage.addImage(self.image.astype(typeToTest), + origin=(0, 0), legend='sino') + + def testScatter(self): + """Test that an histogram from a scatter is displayed""" + xx = numpy.arange(10) + yy = numpy.arange(10) + value = numpy.sin(xx) + self.plotImage.addScatter(xx, yy, value) + self.plotImage.show() + + histoAction = self.plotImage.getIntensityHistogramAction() + + # test the pixel intensity diagram is showing + button = getQToolButtonFromAction(histoAction) + self.assertIsNot(button, None) + self.mouseMove(button) + self.mouseClick(button, qt.Qt.LeftButton) + self.qapp.processEvents() + + widget = histoAction.getHistogramWidget() + self.assertTrue(widget.isVisible()) + items = widget.getPlotWidget().getItems() + self.assertEqual(len(items), 1) + + def testChangeItem(self): + """Test that histogram changes it the item changes""" + xx = numpy.arange(10) + yy = numpy.arange(10) + value = numpy.sin(xx) + self.plotImage.addScatter(xx, yy, value) + self.plotImage.show() + + histoAction = self.plotImage.getIntensityHistogramAction() + + # test the pixel intensity diagram is showing + button = getQToolButtonFromAction(histoAction) + self.assertIsNot(button, None) + self.mouseMove(button) + self.mouseClick(button, qt.Qt.LeftButton) + self.qapp.processEvents() + + # Reach histogram from the first item + widget = histoAction.getHistogramWidget() + self.assertTrue(widget.isVisible()) + items = widget.getPlotWidget().getItems() + data1 = items[0].getValueData(copy=False) + + # Set another item to the plot + self.plotImage.addImage(self.image, origin=(0, 0), legend='sino') + self.qapp.processEvents() + data2 = items[0].getValueData(copy=False) + + # Histogram is not the same + self.assertFalse(numpy.array_equal(data1, data2)) diff --git a/src/silx/gui/plot/test/testPlotActions.py b/src/silx/gui/plot/test/testPlotActions.py new file mode 100644 index 0000000..f38e05b --- /dev/null +++ b/src/silx/gui/plot/test/testPlotActions.py @@ -0,0 +1,110 @@ +# coding: utf-8 +# /*########################################################################## +# +# Copyright (c) 2016-2020 European Synchrotron Radiation Facility +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. +# +# ###########################################################################*/ +"""Test of actions integrated in the plot window""" + +__authors__ = ["T. Vincent"] +__license__ = "MIT" +__date__ = "09/11/2018" + + +import pytest +import weakref + +from silx.gui import qt +from silx.gui.colors import Colormap +from silx.gui.plot.PlotWindow import PlotWindow + +import numpy + + +@pytest.fixture +def colormap1(): + colormap = Colormap(name='gray', + vmin=10.0, vmax=20.0, + normalization='linear') + yield colormap + + +@pytest.fixture +def colormap2(): + colormap = Colormap(name='red', + vmin=10.0, vmax=20.0, + normalization='linear') + yield colormap + + +@pytest.fixture +def plot(qapp): + plot = PlotWindow() + plot.setAttribute(qt.Qt.WA_DeleteOnClose) + yield weakref.proxy(plot) + plot.close() + qapp.processEvents() + + +def test_action_active_colormap(qapp_utils, plot, colormap1, colormap2): + plot.getColormapAction()._actionTriggered(checked=True) + colormapDialog = plot.getColormapAction()._dialog + + defaultColormap = plot.getDefaultColormap() + assert colormapDialog.getColormap() is defaultColormap + + plot.addImage(data=numpy.random.rand(10, 10), legend='img1', + origin=(0, 0), + colormap=colormap1) + plot.setActiveImage('img1') + assert colormapDialog.getColormap() is colormap1 + + plot.addImage(data=numpy.random.rand(10, 10), legend='img2', + origin=(0, 0), colormap=colormap2) + plot.addImage(data=numpy.random.rand(10, 10), legend='img3', + origin=(0, 0)) + + plot.setActiveImage('img3') + assert colormapDialog.getColormap() is defaultColormap + plot.getActiveImage().setColormap(colormap2) + assert colormapDialog.getColormap() is colormap2 + + plot.remove('img2') + plot.remove('img3') + plot.remove('img1') + assert colormapDialog.getColormap() is defaultColormap + + +def test_action_show_hide_colormap_dialog(qapp_utils, plot, colormap1): + plot.getColormapAction()._actionTriggered(checked=True) + colormapDialog = plot.getColormapAction()._dialog + + plot.getColormapAction()._actionTriggered(checked=False) + assert not plot.getColormapAction().isChecked() + plot.getColormapAction()._actionTriggered(checked=True) + assert plot.getColormapAction().isChecked() + plot.addImage(data=numpy.random.rand(10, 10), legend='img1', + origin=(0, 0), colormap=colormap1) + colormap1.setName('red') + plot.getColormapAction()._actionTriggered() + colormap1.setName('blue') + colormapDialog.close() + assert not plot.getColormapAction().isChecked() diff --git a/src/silx/gui/plot/test/testPlotInteraction.py b/src/silx/gui/plot/test/testPlotInteraction.py new file mode 100644 index 0000000..fba364e --- /dev/null +++ b/src/silx/gui/plot/test/testPlotInteraction.py @@ -0,0 +1,160 @@ +# coding: utf-8 +# /*########################################################################## +# +# Copyright (c) 2016=2017 European Synchrotron Radiation Facility +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. +# +# ###########################################################################*/ +"""Tests of plot interaction, through a PlotWidget""" + +__authors__ = ["T. Vincent"] +__license__ = "MIT" +__date__ = "01/09/2017" + + +import unittest +from silx.gui import qt +from .utils import PlotWidgetTestCase + + +class _SignalDump(object): + """Callable object that store passed arguments in a list""" + + def __init__(self): + self._received = [] + + def __call__(self, *args): + self._received.append(args) + + @property + def received(self): + """Return a shallow copy of the list of received arguments""" + return list(self._received) + + +class TestSelectPolygon(PlotWidgetTestCase): + """Test polygon selection interaction""" + + def _interactionModeChanged(self, source): + """Check that source received in event is the correct one""" + self.assertEqual(source, self) + + def _draw(self, polygon): + """Draw a polygon in the plot + + :param polygon: List of points (x, y) of the polygon (closed) + """ + plot = self.plot.getWidgetHandle() + + dump = _SignalDump() + self.plot.sigPlotSignal.connect(dump) + + for pos in polygon: + self.mouseMove(plot, pos=pos) + self.qapp.processEvents() + self.mousePress(plot, qt.Qt.LeftButton, pos=pos) + self.qapp.processEvents() + self.mouseRelease(plot, qt.Qt.LeftButton, pos=pos) + self.qapp.processEvents() + + self.plot.sigPlotSignal.disconnect(dump) + return [args[0] for args in dump.received] + + def test(self): + """Test draw polygons + events""" + self.plot.sigInteractiveModeChanged.connect( + self._interactionModeChanged) + + self.plot.setInteractiveMode( + 'draw', shape='polygon', label='test', source=self) + interaction = self.plot.getInteractiveMode() + + self.assertEqual(interaction['mode'], 'draw') + self.assertEqual(interaction['shape'], 'polygon') + + self.plot.sigInteractiveModeChanged.disconnect( + self._interactionModeChanged) + + plot = self.plot.getWidgetHandle() + xCenter, yCenter = plot.width() // 2, plot.height() // 2 + offset = min(plot.width(), plot.height()) // 10 + + # Star polygon + star = [(xCenter, yCenter + offset), + (xCenter - offset, yCenter - offset), + (xCenter + offset, yCenter), + (xCenter - offset, yCenter), + (xCenter + offset, yCenter - offset), + (xCenter, yCenter + offset)] # Close polygon + + # Draw while dumping signals + events = self._draw(star) + + # Test last event + drawEvents = [event for event in events + if event['event'].startswith('drawing')] + self.assertEqual(drawEvents[-1]['event'], 'drawingFinished') + self.assertEqual(len(drawEvents[-1]['points']), 6) + + # Large square + largeSquare = [(xCenter - offset, yCenter - offset), + (xCenter + offset, yCenter - offset), + (xCenter + offset, yCenter + offset), + (xCenter - offset, yCenter + offset), + (xCenter - offset, yCenter - offset)] # Close polygon + + # Draw while dumping signals + events = self._draw(largeSquare) + + # Test last event + drawEvents = [event for event in events + if event['event'].startswith('drawing')] + self.assertEqual(drawEvents[-1]['event'], 'drawingFinished') + self.assertEqual(len(drawEvents[-1]['points']), 5) + + # Rectangle too thin along X: Some points are ignored + thinRectX = [(xCenter, yCenter - offset), + (xCenter, yCenter + offset), + (xCenter + 1, yCenter + offset), + (xCenter + 1, yCenter - offset)] # Close polygon + + # Draw while dumping signals + events = self._draw(thinRectX) + + # Test last event + drawEvents = [event for event in events + if event['event'].startswith('drawing')] + self.assertEqual(drawEvents[-1]['event'], 'drawingFinished') + self.assertEqual(len(drawEvents[-1]['points']), 3) + + # Rectangle too thin along Y: Some points are ignored + thinRectY = [(xCenter - offset, yCenter), + (xCenter + offset, yCenter), + (xCenter + offset, yCenter + 1), + (xCenter - offset, yCenter + 1)] # Close polygon + + # Draw while dumping signals + events = self._draw(thinRectY) + + # Test last event + drawEvents = [event for event in events + if event['event'].startswith('drawing')] + self.assertEqual(drawEvents[-1]['event'], 'drawingFinished') + self.assertEqual(len(drawEvents[-1]['points']), 3) diff --git a/src/silx/gui/plot/test/testPlotWidget.py b/src/silx/gui/plot/test/testPlotWidget.py new file mode 100755 index 0000000..f6e108d --- /dev/null +++ b/src/silx/gui/plot/test/testPlotWidget.py @@ -0,0 +1,2113 @@ +# coding: utf-8 +# /*########################################################################## +# +# Copyright (c) 2016-2021 European Synchrotron Radiation Facility +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. +# +# ###########################################################################*/ +"""Basic tests for PlotWidget""" + +__authors__ = ["T. Vincent"] +__license__ = "MIT" +__date__ = "03/01/2019" + + +import unittest +import logging +import numpy +import pytest + +from silx.utils.testutils import ParametricTestCase +from silx.gui.utils.testutils import SignalListener +from silx.gui.utils.testutils import TestCaseQt + +from silx.gui import qt +from silx.gui.plot import PlotWidget +from silx.gui.plot.items.curve import CurveStyle +from silx.gui.plot.items import BoundingRect, XAxisExtent, YAxisExtent, Axis +from silx.gui.colors import Colormap + +from .utils import PlotWidgetTestCase + + +SIZE = 1024 +"""Size of the test image""" + +DATA_2D = numpy.arange(SIZE ** 2).reshape(SIZE, SIZE) +"""Image data set""" + + +logger = logging.getLogger(__name__) + + +class TestSpecialBackend(PlotWidgetTestCase, ParametricTestCase): + + def __init__(self, methodName='runTest', backend=None): + TestCaseQt.__init__(self, methodName=methodName) + self.__backend = backend + + def _createPlot(self): + return PlotWidget(backend=self.__backend) + + def testPlot(self): + self.assertIsNotNone(self.plot) + + +class TestPlotWidget(PlotWidgetTestCase, ParametricTestCase): + """Basic tests for PlotWidget""" + + def testShow(self): + """Most basic test""" + pass + + def testSetTitleLabels(self): + """Set title and axes labels""" + + title, xlabel, ylabel = 'the title', 'x label', 'y label' + self.plot.setGraphTitle(title) + self.plot.getXAxis().setLabel(xlabel) + self.plot.getYAxis().setLabel(ylabel) + self.qapp.processEvents() + + self.assertEqual(self.plot.getGraphTitle(), title) + self.assertEqual(self.plot.getXAxis().getLabel(), xlabel) + self.assertEqual(self.plot.getYAxis().getLabel(), ylabel) + + def _checkLimits(self, + expectedXLim=None, + expectedYLim=None, + expectedRatio=None): + """Assert that limits are as expected""" + xlim = self.plot.getXAxis().getLimits() + ylim = self.plot.getYAxis().getLimits() + ratio = abs(xlim[1] - xlim[0]) / abs(ylim[1] - ylim[0]) + + if expectedXLim is not None: + self.assertEqual(expectedXLim, xlim) + + if expectedYLim is not None: + self.assertEqual(expectedYLim, ylim) + + if expectedRatio is not None: + self.assertTrue( + numpy.allclose(expectedRatio, ratio, atol=0.01)) + + def testChangeLimitsWithAspectRatio(self): + self.plot.setKeepDataAspectRatio() + self.qapp.processEvents() + xlim = self.plot.getXAxis().getLimits() + ylim = self.plot.getYAxis().getLimits() + defaultRatio = abs(xlim[1] - xlim[0]) / abs(ylim[1] - ylim[0]) + + self.plot.getXAxis().setLimits(1., 10.) + self._checkLimits(expectedXLim=(1., 10.), expectedRatio=defaultRatio) + self.qapp.processEvents() + self._checkLimits(expectedXLim=(1., 10.), expectedRatio=defaultRatio) + + self.plot.getYAxis().setLimits(1., 10.) + self._checkLimits(expectedYLim=(1., 10.), expectedRatio=defaultRatio) + self.qapp.processEvents() + self._checkLimits(expectedYLim=(1., 10.), expectedRatio=defaultRatio) + + def testResizeWidget(self): + """Test resizing the widget and receiving limitsChanged events""" + self.plot.resize(200, 200) + self.qapp.processEvents() + self.qWait(100) + + xlim = self.plot.getXAxis().getLimits() + ylim = self.plot.getYAxis().getLimits() + + listener = SignalListener() + self.plot.getXAxis().sigLimitsChanged.connect(listener.partial('x')) + self.plot.getYAxis().sigLimitsChanged.connect(listener.partial('y')) + + # Resize without aspect ratio + self.plot.resize(200, 300) + self.qapp.processEvents() + self.qWait(100) + self._checkLimits(expectedXLim=xlim, expectedYLim=ylim) + self.assertEqual(listener.callCount(), 0) + + # Resize with aspect ratio + self.plot.setKeepDataAspectRatio(True) + self.qapp.processEvents() + self.qWait(1000) + listener.clear() # Clean-up received signal + + self.plot.resize(200, 200) + self.qapp.processEvents() + self.qWait(100) + self.assertNotEqual(listener.callCount(), 0) + + def testAddRemoveItemSignals(self): + """Test sigItemAdded and sigItemAboutToBeRemoved""" + listener = SignalListener() + self.plot.sigItemAdded.connect(listener.partial('add')) + self.plot.sigItemAboutToBeRemoved.connect(listener.partial('remove')) + + self.plot.addCurve((1, 2, 3), (3, 2, 1), legend='curve') + self.assertEqual(listener.callCount(), 1) + + curve = self.plot.getCurve('curve') + self.plot.remove('curve') + self.assertEqual(listener.callCount(), 2) + self.assertEqual(listener.arguments(callIndex=0), ('add', curve)) + self.assertEqual(listener.arguments(callIndex=1), ('remove', curve)) + + def testGetItems(self): + """Test getItems method""" + curve_x = 1, 2 + self.plot.addCurve(curve_x, (3, 4)) + image = (0, 1), (2, 3) + self.plot.addImage(image) + scatter_x = 10, 11 + self.plot.addScatter(scatter_x, (12, 13), (0, 1)) + marker_pos = 5, 5 + self.plot.addMarker(*marker_pos) + marker_x = 6 + self.plot.addXMarker(marker_x) + self.plot.addShape((0, 5), (2, 10), shape='rectangle') + + items = self.plot.getItems() + self.assertEqual(len(items), 6) + self.assertTrue(numpy.all(numpy.equal(items[0].getXData(), curve_x))) + self.assertTrue(numpy.all(numpy.equal(items[1].getData(), image))) + self.assertTrue(numpy.all(numpy.equal(items[2].getXData(), scatter_x))) + self.assertTrue(numpy.all(numpy.equal(items[3].getPosition(), marker_pos))) + self.assertTrue(numpy.all(numpy.equal(items[4].getPosition()[0], marker_x))) + self.assertEqual(items[5].getType(), 'rectangle') + + def testRemoveDiscardItem(self): + """Test removeItem and discardItem""" + self.plot.addCurve((1, 2, 3), (1, 2, 3)) + curve = self.plot.getItems()[0] + self.plot.removeItem(curve) + with self.assertRaises(ValueError): + self.plot.removeItem(curve) + + self.plot.addCurve((1, 2, 3), (1, 2, 3)) + curve = self.plot.getItems()[0] + result = self.plot.discardItem(curve) + self.assertTrue(result) + result = self.plot.discardItem(curve) + self.assertFalse(result) + + def testBackGroundColors(self): + self.plot.setVisible(True) + self.qWaitForWindowExposed(self.plot) + self.qapp.processEvents() + + # Custom the full background + color = self.plot.getBackgroundColor() + self.assertTrue(color.isValid()) + self.assertEqual(color, qt.QColor(255, 255, 255)) + self.plot.setBackgroundColor("red") + color = self.plot.getBackgroundColor() + self.assertTrue(color.isValid()) + self.qapp.processEvents() + + # Custom the data background + color = self.plot.getDataBackgroundColor() + self.assertFalse(color.isValid()) + self.plot.setDataBackgroundColor("red") + color = self.plot.getDataBackgroundColor() + self.assertTrue(color.isValid()) + self.qapp.processEvents() + + # Back to default + self.plot.setBackgroundColor('white') + self.plot.setDataBackgroundColor(None) + color = self.plot.getBackgroundColor() + self.assertTrue(color.isValid()) + self.assertEqual(color, qt.QColor(255, 255, 255)) + color = self.plot.getDataBackgroundColor() + self.assertFalse(color.isValid()) + self.qapp.processEvents() + + +class TestPlotImage(PlotWidgetTestCase, ParametricTestCase): + """Basic tests for addImage""" + + def setUp(self): + super(TestPlotImage, self).setUp() + + self.plot.getYAxis().setLabel('Rows') + self.plot.getXAxis().setLabel('Columns') + + def testPlotColormapTemperature(self): + self.plot.setGraphTitle('Temp. Linear') + + colormap = Colormap(name='temperature', + normalization='linear', + vmin=None, + vmax=None) + self.plot.addImage(DATA_2D, legend="image 1", colormap=colormap) + + def testPlotColormapGray(self): + self.plot.setKeepDataAspectRatio(False) + self.plot.setGraphTitle('Gray Linear') + + colormap = Colormap(name='gray', + normalization='linear', + vmin=None, + vmax=None) + self.plot.addImage(DATA_2D, legend="image 1", colormap=colormap) + + def testPlotColormapTemperatureLog(self): + self.plot.setGraphTitle('Temp. Log') + + colormap = Colormap(name='temperature', + normalization=Colormap.LOGARITHM, + vmin=None, + vmax=None) + self.plot.addImage(DATA_2D, legend="image 1", colormap=colormap) + + def testPlotRgbRgba(self): + self.plot.setKeepDataAspectRatio(False) + self.plot.setGraphTitle('RGB + RGBA') + + rgb = numpy.array( + (((0, 0, 0), (128, 0, 0), (255, 0, 0)), + ((0, 128, 0), (0, 128, 128), (0, 128, 255))), + dtype=numpy.uint8) + + self.plot.addImage(rgb, legend="rgb_uint8", + origin=(0, 0), scale=(1, 1), + resetzoom=False) + + rgb = numpy.array( + (((0, 0, 0), (32768, 0, 0), (65535, 0, 0)), + ((0, 32768, 0), (0, 32768, 32768), (0, 32768, 65535))), + dtype=numpy.uint16) + + self.plot.addImage(rgb, legend="rgb_uint16", + origin=(3, 2), scale=(2, 2), + resetzoom=False) + + rgba = numpy.array( + (((0, 0, 0, .5), (.5, 0, 0, 1), (1, 0, 0, .5)), + ((0, .5, 0, 1), (0, .5, .5, 1), (0, 1, 1, .5))), + dtype=numpy.float32) + + self.plot.addImage(rgba, legend="rgba_float32", + origin=(9, 6), scale=(1, 1), + resetzoom=False) + + self.plot.resetZoom() + + def testPlotColormapCustom(self): + self.plot.setKeepDataAspectRatio(False) + self.plot.setGraphTitle('Custom colormap') + + colormap = Colormap(name=None, + normalization=Colormap.LINEAR, + vmin=None, + vmax=None, + colors=((0., 0., 0.), (1., 0., 0.), + (0., 1., 0.), (0., 0., 1.))) + self.plot.addImage(DATA_2D, legend="image 1", colormap=colormap, + resetzoom=False) + + colormap = Colormap(name=None, + normalization=Colormap.LINEAR, + vmin=None, + vmax=None, + colors=numpy.array( + ((0, 0, 0, 0), (0, 0, 0, 128), + (128, 128, 128, 128), (255, 255, 255, 255)), + dtype=numpy.uint8)) + self.plot.addImage(DATA_2D, legend="image 2", colormap=colormap, + origin=(DATA_2D.shape[0], 0), + resetzoom=False) + self.plot.resetZoom() + + def testPlotColormapNaNColor(self): + self.plot.setKeepDataAspectRatio(False) + self.plot.setGraphTitle('Colormap with NaN color') + + colormap = Colormap() + colormap.setNaNColor('red') + self.assertEqual(colormap.getNaNColor(), qt.QColor(255, 0, 0)) + data = DATA_2D.astype(numpy.float32) + data[len(data)//2:] = numpy.nan + self.plot.addImage(data, legend="image 1", colormap=colormap, + resetzoom=False) + self.plot.resetZoom() + + colormap.setNaNColor((0., 1., 0., 1.)) + self.assertEqual(colormap.getNaNColor(), qt.QColor(0, 255, 0)) + self.qapp.processEvents() + + def testImageOriginScale(self): + """Test of image with different origin and scale""" + self.plot.setGraphTitle('origin and scale') + + tests = [ # (origin, scale) + ((10, 20), (1, 1)), + ((10, 20), (-1, -1)), + ((-10, 20), (2, 1)), + ((10, -20), (-1, -2)), + (100, 2), + (-100, (1, 1)), + ((10, 20), 2), + ] + + for origin, scale in tests: + with self.subTest(origin=origin, scale=scale): + self.plot.addImage(DATA_2D, origin=origin, scale=scale) + + try: + ox, oy = origin + except TypeError: + ox, oy = origin, origin + try: + sx, sy = scale + except TypeError: + sx, sy = scale, scale + xbounds = ox, ox + DATA_2D.shape[1] * sx + ybounds = oy, oy + DATA_2D.shape[0] * sy + + # Check limits without aspect ratio + xmin, xmax = self.plot.getXAxis().getLimits() + ymin, ymax = self.plot.getYAxis().getLimits() + self.assertEqual(xmin, min(xbounds)) + self.assertEqual(xmax, max(xbounds)) + self.assertEqual(ymin, min(ybounds)) + self.assertEqual(ymax, max(ybounds)) + + # Check limits with aspect ratio + self.plot.setKeepDataAspectRatio(True) + xmin, xmax = self.plot.getXAxis().getLimits() + ymin, ymax = self.plot.getYAxis().getLimits() + self.assertTrue(round(xmin, 7) <= min(xbounds)) + self.assertTrue(round(xmax, 7) >= max(xbounds)) + self.assertTrue(round(ymin, 7) <= min(ybounds)) + self.assertTrue(round(ymax, 7) >= max(ybounds)) + + self.plot.setKeepDataAspectRatio(False) # Reset aspect ratio + self.plot.clear() + self.plot.resetZoom() + + def testPlotColormapDictAPI(self): + """Test that the addImage API using a colormap dictionary is still + working""" + self.plot.setGraphTitle('Temp. Log') + + colormap = { + 'name': 'temperature', + 'normalization': 'log', + 'vmin': None, + 'vmax': None + } + self.plot.addImage(DATA_2D, legend="image 1", colormap=colormap) + + def testPlotComplexImage(self): + """Test that a complex image is displayed as its absolute value.""" + data = numpy.linspace(1, 1j, 100).reshape(10, 10) + self.plot.addImage(data, legend='complex') + + image = self.plot.getActiveImage() + retrievedData = image.getData(copy=False) + self.assertTrue( + numpy.all(numpy.equal(retrievedData, numpy.absolute(data)))) + + def testPlotBooleanImage(self): + """Test that a boolean image is displayed and converted to int8.""" + data = numpy.zeros((10, 10), dtype=bool) + data[::2, ::2] = True + self.plot.addImage(data, legend='boolean') + + image = self.plot.getActiveImage() + retrievedData = image.getData(copy=False) + self.assertTrue(numpy.all(numpy.equal(retrievedData, data))) + self.assertIs(retrievedData.dtype.type, numpy.int8) + + def testPlotAlphaImage(self): + """Test with an alpha image layer""" + data = numpy.random.random((10, 10)) + alpha = numpy.linspace(0, 1, 100).reshape(10, 10) + self.plot.addImage(data, legend='image') + image = self.plot.getActiveImage() + image.setData(data, alpha=alpha) + self.qapp.processEvents() + self.assertTrue(numpy.array_equal(alpha, image.getAlphaData())) + + +class TestPlotCurve(PlotWidgetTestCase): + """Basic tests for addCurve.""" + + # Test data sets + xData = numpy.arange(1000) + yData = -500 + 100 * numpy.sin(xData) + xData2 = xData + 1000 + yData2 = xData - 1000 + 200 * numpy.random.random(1000) + + def setUp(self): + super(TestPlotCurve, self).setUp() + self.plot.setGraphTitle('Curve') + self.plot.getYAxis().setLabel('Rows') + self.plot.getXAxis().setLabel('Columns') + + self.plot.setActiveCurveHandling(False) + + def testPlotCurveInfinite(self): + """Test plot curves with not finite data""" + tests = { + 'y all not finite': ([0, 1, 2], [numpy.inf, numpy.nan, -numpy.inf]), + 'x all not finite': ([numpy.inf, numpy.nan, -numpy.inf], [0, 1, 2]), + 'x some inf': ([0, numpy.inf, 2], [0, 1, 2]), + 'y some inf': ([0, 1, 2], [0, numpy.inf, 2]) + } + for name, args in tests.items(): + with self.subTest(name): + self.plot.addCurve(*args) + self.plot.resetZoom() + self.qapp.processEvents() + self.plot.clear() + + def testPlotCurveColorFloat(self): + color = numpy.array(numpy.random.random(3 * 1000), + dtype=numpy.float32).reshape(1000, 3) + + self.plot.addCurve(self.xData, self.yData, + legend="curve 1", + replace=False, resetzoom=False, + color=color, + linestyle="", symbol="s") + self.plot.addCurve(self.xData2, self.yData2, + legend="curve 2", + replace=False, resetzoom=False, + color='green', linestyle="-", symbol='o') + self.plot.resetZoom() + + def testPlotCurveColorByte(self): + color = numpy.array(255 * numpy.random.random(3 * 1000), + dtype=numpy.uint8).reshape(1000, 3) + + self.plot.addCurve(self.xData, self.yData, + legend="curve 1", + replace=False, resetzoom=False, + color=color, + linestyle="", symbol="s") + self.plot.addCurve(self.xData2, self.yData2, + legend="curve 2", + replace=False, resetzoom=False, + color='green', linestyle="-", symbol='o') + self.plot.resetZoom() + + def testPlotCurveColors(self): + color = numpy.array(numpy.random.random(3 * 1000), + dtype=numpy.float32).reshape(1000, 3) + + self.plot.addCurve(self.xData, self.yData, + legend="curve 2", + replace=False, resetzoom=False, + color=color, linestyle="-", symbol='o') + self.plot.resetZoom() + + # Test updating color array + + # From array to array + newColors = numpy.ones((len(self.xData), 3), dtype=numpy.float32) + self.plot.addCurve(self.xData, self.yData, + legend="curve 2", + replace=False, resetzoom=False, + color=newColors, symbol='o') + + # Array to single color + self.plot.addCurve(self.xData, self.yData, + legend="curve 2", + replace=False, resetzoom=False, + color='green', symbol='o') + + # single color to array + self.plot.addCurve(self.xData, self.yData, + legend="curve 2", + replace=False, resetzoom=False, + color=color, symbol='o') + + def testPlotBaselineNumpyArray(self): + """simple test of the API with baseline as a numpy array""" + x = numpy.arange(0, 10, step=0.1) + my_sin = numpy.sin(x) + y = numpy.arange(-4, 6, step=0.1) + my_sin + baseline = y - 1.0 + + self.plot.addCurve(x=x, y=y, color='grey', legend='curve1', fill=True, + baseline=baseline) + + def testPlotBaselineScalar(self): + """simple test of the API with baseline as an int""" + x = numpy.arange(0, 10, step=0.1) + my_sin = numpy.sin(x) + y = numpy.arange(-4, 6, step=0.1) + my_sin + + self.plot.addCurve(x=x, y=y, color='grey', legend='curve1', fill=True, + baseline=0) + + def testPlotBaselineList(self): + """simple test of the API with baseline as an int""" + x = numpy.arange(0, 10, step=0.1) + my_sin = numpy.sin(x) + y = numpy.arange(-4, 6, step=0.1) + my_sin + + self.plot.addCurve(x=x, y=y, color='grey', legend='curve1', fill=True, + baseline=list(range(0, 100, 1))) + + def testPlotCurveComplexData(self): + """Test curve with complex data""" + data = numpy.arange(100.) + 1j + self.plot.addCurve(x=data, y=data, xerror=data, yerror=data) + + +class TestPlotHistogram(PlotWidgetTestCase): + """Basic tests for add Histogram""" + def setUp(self): + super(TestPlotHistogram, self).setUp() + self.edges = numpy.arange(0, 10, step=1) + self.histogram = numpy.random.random(len(self.edges)) + + def testPlot(self): + self.plot.addHistogram(histogram=self.histogram, + edges=self.edges, + legend='histogram1') + + def testPlotBaseline(self): + self.plot.addHistogram(histogram=self.histogram, + edges=self.edges, + legend='histogram1', + color='blue', + baseline=-2, + z=2, + fill=True) + + +class TestPlotScatter(PlotWidgetTestCase, ParametricTestCase): + """Basic tests for addScatter""" + + def testScatter(self): + x = numpy.arange(100) + y = numpy.arange(100) + value = numpy.arange(100) + self.plot.addScatter(x, y, value) + self.plot.resetZoom() + + def testScatterComplexData(self): + """Test scatter item with complex data""" + data = numpy.arange(100.) + 1j + self.plot.addScatter( + x=data, y=data, value=data, xerror=data, yerror=data) + self.plot.resetZoom() + + def testScatterVisualization(self): + self.plot.addScatter((0, 1, 0, 1), (0, 0, 2, 2), (0, 1, 2, 3)) + self.plot.resetZoom() + self.qapp.processEvents() + + scatter = self.plot.getItems()[0] + + for visualization in ('solid', + 'points', + 'regular_grid', + 'irregular_grid', + 'binned_statistic', + scatter.Visualization.SOLID, + scatter.Visualization.POINTS, + scatter.Visualization.REGULAR_GRID, + scatter.Visualization.IRREGULAR_GRID, + scatter.Visualization.BINNED_STATISTIC): + with self.subTest(visualization=visualization): + scatter.setVisualization(visualization) + self.qapp.processEvents() + + def testGridVisualization(self): + """Test regular and irregular grid mode with different points""" + points = { # name: (x, y, order) + 'single point': ((1.,), (1.,), 'row'), + 'horizontal line': ((0, 1, 2), (0, 0, 0), 'row'), + 'horizontal line backward': ((2, 1, 0), (0, 0, 0), 'row'), + 'vertical line': ((0, 0, 0), (0, 1, 2), 'row'), + 'vertical line backward': ((0, 0, 0), (2, 1, 0), 'row'), + 'grid fast x, +x +y': ((0, 1, 2, 0, 1, 2), (0, 0, 0, 1, 1, 1), 'row'), + 'grid fast x, +x -y': ((0, 1, 2, 0, 1, 2), (1, 1, 1, 0, 0, 0), 'row'), + 'grid fast x, -x -y': ((2, 1, 0, 2, 1, 0), (1, 1, 1, 0, 0, 0), 'row'), + 'grid fast x, -x +y': ((2, 1, 0, 2, 1, 0), (0, 0, 0, 1, 1, 1), 'row'), + 'grid fast y, +x +y': ((0, 0, 0, 1, 1, 1), (0, 1, 2, 0, 1, 2), 'column'), + 'grid fast y, +x -y': ((0, 0, 0, 1, 1, 1), (2, 1, 0, 2, 1, 0), 'column'), + 'grid fast y, -x -y': ((1, 1, 1, 0, 0, 0), (2, 1, 0, 2, 1, 0), 'column'), + 'grid fast y, -x +y': ((1, 1, 1, 0, 0, 0), (0, 1, 2, 0, 1, 2), 'column'), + } + + self.plot.addScatter((), (), ()) + scatter = self.plot.getItems()[0] + + self.qapp.processEvents() + + for visualization in (scatter.Visualization.REGULAR_GRID, + scatter.Visualization.IRREGULAR_GRID): + scatter.setVisualization(visualization) + self.assertIs(scatter.getVisualization(), visualization) + + for name, (x, y, ref_order) in points.items(): + with self.subTest(name=name, visualization=visualization.name): + scatter.setData(x, y, numpy.arange(len(x))) + self.plot.setGraphTitle(name) + self.plot.resetZoom() + self.qapp.processEvents() + + order = scatter.getCurrentVisualizationParameter( + scatter.VisualizationParameter.GRID_MAJOR_ORDER) + self.assertEqual(ref_order, order) + + ref_bounds = (x[0], y[0]), (x[-1], y[-1]) + bounds = scatter.getCurrentVisualizationParameter( + scatter.VisualizationParameter.GRID_BOUNDS) + self.assertEqual(ref_bounds, bounds) + + shape = scatter.getCurrentVisualizationParameter( + scatter.VisualizationParameter.GRID_SHAPE) + + self.plot.getXAxis().setLimits(numpy.min(x) - 1, numpy.max(x) + 1) + self.plot.getYAxis().setLimits(numpy.min(y) - 1, numpy.max(y) + 1) + self.qapp.processEvents() + + for index, position in enumerate(zip(x, y)): + xpixel, ypixel = self.plot.dataToPixel(*position) + result = scatter.pick(xpixel, ypixel) + self.assertIsNotNone(result) + self.assertIs(result.getItem(), scatter) + self.assertEqual(result.getIndices(), (index,)) + + def testBinnedStatisticVisualization(self): + """Test binned display""" + self.plot.addScatter((), (), ()) + scatter = self.plot.getItems()[0] + scatter.setVisualization(scatter.Visualization.BINNED_STATISTIC) + self.assertIs(scatter.getVisualization(), + scatter.Visualization.BINNED_STATISTIC) + self.assertEqual( + scatter.getVisualizationParameter( + scatter.VisualizationParameter.BINNED_STATISTIC_FUNCTION), + 'mean') + + self.qapp.processEvents() + + scatter.setData(*numpy.random.random(300).reshape(3, -1)) + self.qapp.processEvents() + + # Update data + scatter.setData(*numpy.random.random(3000).reshape(3, -1)) + self.qapp.processEvents() + + for reduction in ('count', 'sum', 'mean'): + with self.subTest(reduction=reduction): + scatter.setVisualizationParameter( + scatter.VisualizationParameter.BINNED_STATISTIC_FUNCTION, + reduction) + self.assertEqual( + scatter.getVisualizationParameter( + scatter.VisualizationParameter.BINNED_STATISTIC_FUNCTION), + reduction) + + self.qapp.processEvents() + + +class TestPlotMarker(PlotWidgetTestCase): + """Basic tests for add*Marker""" + + def setUp(self): + super(TestPlotMarker, self).setUp() + self.plot.getYAxis().setLabel('Rows') + self.plot.getXAxis().setLabel('Columns') + + self.plot.getXAxis().setAutoScale(False) + self.plot.getYAxis().setAutoScale(False) + self.plot.setKeepDataAspectRatio(False) + self.plot.setLimits(0., 100., -100., 100.) + + def testPlotMarkerX(self): + self.plot.setGraphTitle('Markers X') + + markers = [ + (10., 'blue', False, False), + (20., 'red', False, False), + (40., 'green', True, False), + (60., 'gray', True, True), + (80., 'black', False, True), + ] + + for x, color, select, drag in markers: + name = str(x) + if select: + name += " sel." + if drag: + name += " drag" + self.plot.addXMarker(x, name, name, color, select, drag) + self.plot.resetZoom() + + def testPlotMarkerY(self): + self.plot.setGraphTitle('Markers Y') + + markers = [ + (-50., 'blue', False, False), + (-30., 'red', False, False), + (0., 'green', True, False), + (10., 'gray', True, True), + (80., 'black', False, True), + ] + + for y, color, select, drag in markers: + name = str(y) + if select: + name += " sel." + if drag: + name += " drag" + self.plot.addYMarker(y, name, name, color, select, drag) + self.plot.resetZoom() + + def testPlotMarkerPt(self): + self.plot.setGraphTitle('Markers Pt') + + markers = [ + (10., -50., 'blue', False, False), + (40., -30., 'red', False, False), + (50., 0., 'green', True, False), + (50., 20., 'gray', True, True), + (70., 50., 'black', False, True), + ] + for x, y, color, select, drag in markers: + name = "{0},{1}".format(x, y) + if select: + name += " sel." + if drag: + name += " drag" + self.plot.addMarker(x, y, name, name, color, select, drag) + + self.plot.resetZoom() + + def testPlotMarkerWithoutLegend(self): + self.plot.setGraphTitle('Markers without legend') + self.plot.getYAxis().setInverted(True) + + # Markers without legend + self.plot.addMarker(10, 10) + self.plot.addMarker(10, 20) + self.plot.addMarker(40, 50, text='test', symbol=None) + self.plot.addMarker(40, 50, text='test', symbol='+') + self.plot.addXMarker(25) + self.plot.addXMarker(35) + self.plot.addXMarker(45, text='test') + self.plot.addYMarker(55) + self.plot.addYMarker(65) + self.plot.addYMarker(75, text='test') + + self.plot.resetZoom() + + def testPlotMarkerYAxis(self): + # Check only the API + + legend = self.plot.addMarker(10, 10) + item = self.plot._getMarker(legend) + self.assertEqual(item.getYAxis(), "left") + + legend = self.plot.addMarker(10, 10, yaxis="right") + item = self.plot._getMarker(legend) + self.assertEqual(item.getYAxis(), "right") + + legend = self.plot.addMarker(10, 10, yaxis="left") + item = self.plot._getMarker(legend) + self.assertEqual(item.getYAxis(), "left") + + legend = self.plot.addXMarker(10, yaxis="right") + item = self.plot._getMarker(legend) + self.assertEqual(item.getYAxis(), "right") + + legend = self.plot.addXMarker(10, yaxis="left") + item = self.plot._getMarker(legend) + self.assertEqual(item.getYAxis(), "left") + + legend = self.plot.addYMarker(10, yaxis="right") + item = self.plot._getMarker(legend) + self.assertEqual(item.getYAxis(), "right") + + legend = self.plot.addYMarker(10, yaxis="left") + item = self.plot._getMarker(legend) + self.assertEqual(item.getYAxis(), "left") + + self.plot.resetZoom() + + +# TestPlotItem ################################################################ + +class TestPlotItem(PlotWidgetTestCase): + """Basic tests for addItem.""" + + # Polygon coordinates and color + POLYGONS = [ # legend, x coords, y coords, color + ('triangle', numpy.array((10, 30, 50)), + numpy.array((55, 70, 55)), 'red'), + ('square', numpy.array((10, 10, 50, 50)), + numpy.array((10, 50, 50, 10)), 'green'), + ('star', numpy.array((60, 70, 80, 60, 80)), + numpy.array((25, 50, 25, 40, 40)), 'blue'), + ('2 triangles-simple', + numpy.array((90., 95., 100., numpy.nan, 90., 95., 100.)), + numpy.array((25., 5., 25., numpy.nan, 30., 50., 30.)), + 'pink'), + ('2 triangles-extra NaN', + numpy.array((numpy.nan, 90., 95., 100., numpy.nan, 0., 90., 95., 100., numpy.nan)), + numpy.array((0., 55., 70., 55., numpy.nan, numpy.nan, 75., 90., 75., numpy.nan)), + 'black'), + ] + + # Rectangle coordinantes and color + RECTANGLES = [ # legend, x coords, y coords, color + ('square 1', numpy.array((1., 10.)), + numpy.array((1., 10.)), 'red'), + ('square 2', numpy.array((10., 20.)), + numpy.array((10., 20.)), 'green'), + ('square 3', numpy.array((20., 30.)), + numpy.array((20., 30.)), 'blue'), + ('rect 1', numpy.array((1., 30.)), + numpy.array((35., 40.)), 'black'), + ('line h', numpy.array((1., 30.)), + numpy.array((45., 45.)), 'darkRed'), + ] + + SCALES = Axis.LINEAR, Axis.LOGARITHMIC + + def setUp(self): + super(TestPlotItem, self).setUp() + + self.plot.getYAxis().setLabel('Rows') + self.plot.getXAxis().setLabel('Columns') + self.plot.getXAxis().setAutoScale(False) + self.plot.getYAxis().setAutoScale(False) + self.plot.setKeepDataAspectRatio(False) + self.plot.setLimits(0., 100., -100., 100.) + + def testPlotItemPolygonFill(self): + for scale in self.SCALES: + with self.subTest(scale=scale): + self.plot.clear() + self.plot.getXAxis().setScale(scale) + self.plot.getYAxis().setScale(scale) + self.plot.setGraphTitle('Item Fill %s' % scale) + + for legend, xList, yList, color in self.POLYGONS: + self.plot.addShape(xList, yList, legend=legend, + replace=False, linestyle='--', + shape="polygon", fill=True, color=color) + self.plot.resetZoom() + + def testPlotItemPolygonNoFill(self): + for scale in self.SCALES: + with self.subTest(scale=scale): + self.plot.clear() + self.plot.getXAxis().setScale(scale) + self.plot.getYAxis().setScale(scale) + self.plot.setGraphTitle('Item No Fill %s' % scale) + + for legend, xList, yList, color in self.POLYGONS: + self.plot.addShape(xList, yList, legend=legend, + replace=False, linestyle='--', + shape="polygon", fill=False, color=color) + self.plot.resetZoom() + + def testPlotItemRectangleFill(self): + for scale in self.SCALES: + with self.subTest(scale=scale): + self.plot.clear() + self.plot.getXAxis().setScale(scale) + self.plot.getYAxis().setScale(scale) + self.plot.setGraphTitle('Rectangle Fill %s' % scale) + + for legend, xList, yList, color in self.RECTANGLES: + self.plot.addShape(xList, yList, legend=legend, + replace=False, + shape="rectangle", fill=True, color=color) + self.plot.resetZoom() + + def testPlotItemRectangleNoFill(self): + for scale in self.SCALES: + with self.subTest(scale=scale): + self.plot.clear() + self.plot.getXAxis().setScale(scale) + self.plot.getYAxis().setScale(scale) + self.plot.setGraphTitle('Rectangle No Fill %s' % scale) + + for legend, xList, yList, color in self.RECTANGLES: + self.plot.addShape(xList, yList, legend=legend, + replace=False, + shape="rectangle", fill=False, color=color) + self.plot.resetZoom() + + +class TestPlotActiveCurveImage(PlotWidgetTestCase): + """Basic tests for active curve and image handling""" + xData = numpy.arange(1000) + yData = -500 + 100 * numpy.sin(xData) + xData2 = xData + 1000 + yData2 = xData - 1000 + 200 * numpy.random.random(1000) + + def tearDown(self): + self.plot.setActiveCurveHandling(False) + super(TestPlotActiveCurveImage, self).tearDown() + + def testActiveCurveAndLabels(self): + # Active curve handling off, no label change + self.plot.setActiveCurveHandling(False) + self.plot.getXAxis().setLabel('XLabel') + self.plot.getYAxis().setLabel('YLabel') + self.plot.addCurve((1, 2), (1, 2)) + self.assertEqual(self.plot.getXAxis().getLabel(), 'XLabel') + self.assertEqual(self.plot.getYAxis().getLabel(), 'YLabel') + + self.plot.addCurve((1, 2), (2, 3), xlabel='x1', ylabel='y1') + self.assertEqual(self.plot.getXAxis().getLabel(), 'XLabel') + self.assertEqual(self.plot.getYAxis().getLabel(), 'YLabel') + + self.plot.clear() + self.assertEqual(self.plot.getXAxis().getLabel(), 'XLabel') + self.assertEqual(self.plot.getYAxis().getLabel(), 'YLabel') + + # Active curve handling on, label changes + self.plot.setActiveCurveHandling(True) + self.plot.getXAxis().setLabel('XLabel') + self.plot.getYAxis().setLabel('YLabel') + + # labels changed as active curve + self.plot.addCurve((1, 2), (1, 2), legend='1', + xlabel='x1', ylabel='y1') + self.plot.setActiveCurve('1') + self.assertEqual(self.plot.getXAxis().getLabel(), 'x1') + self.assertEqual(self.plot.getYAxis().getLabel(), 'y1') + + # labels not changed as not active curve + self.plot.addCurve((1, 2), (2, 3), legend='2') + self.assertEqual(self.plot.getXAxis().getLabel(), 'x1') + self.assertEqual(self.plot.getYAxis().getLabel(), 'y1') + + # labels changed + self.plot.setActiveCurve('2') + self.assertEqual(self.plot.getXAxis().getLabel(), 'XLabel') + self.assertEqual(self.plot.getYAxis().getLabel(), 'YLabel') + + self.plot.setActiveCurve('1') + self.assertEqual(self.plot.getXAxis().getLabel(), 'x1') + self.assertEqual(self.plot.getYAxis().getLabel(), 'y1') + + self.plot.clear() + self.assertEqual(self.plot.getXAxis().getLabel(), 'XLabel') + self.assertEqual(self.plot.getYAxis().getLabel(), 'YLabel') + + def testPlotActiveCurveSelectionMode(self): + self.plot.clear() + self.plot.setActiveCurveHandling(True) + legend = "curve 1" + self.plot.addCurve(self.xData, self.yData, + legend=legend, + color="green") + + # active curve should be None + self.assertEqual(self.plot.getActiveCurve(just_legend=True), None) + + # active curve should be None when None is set as active curve + self.plot.setActiveCurve(legend) + current = self.plot.getActiveCurve(just_legend=True) + self.assertEqual(current, legend) + self.plot.setActiveCurve(None) + current = self.plot.getActiveCurve(just_legend=True) + self.assertEqual(current, None) + + # testing it automatically toggles if there is only one + self.plot.setActiveCurveSelectionMode("legacy") + current = self.plot.getActiveCurve(just_legend=True) + self.assertEqual(current, legend) + + # active curve should not change when None set as active curve + self.assertEqual(self.plot.getActiveCurveSelectionMode(), "legacy") + self.plot.setActiveCurve(None) + current = self.plot.getActiveCurve(just_legend=True) + self.assertEqual(current, legend) + + # situation where no curve is active + self.plot.clear() + self.plot.setActiveCurveHandling(True) + self.assertEqual(self.plot.getActiveCurveSelectionMode(), "atmostone") + self.plot.addCurve(self.xData, self.yData, + legend=legend, + color="green") + self.assertEqual(self.plot.getActiveCurve(just_legend=True), None) + self.plot.addCurve(self.xData2, self.yData2, + legend="curve 2", + color="red") + self.assertEqual(self.plot.getActiveCurve(just_legend=True), None) + self.plot.setActiveCurveSelectionMode("legacy") + self.assertEqual(self.plot.getActiveCurve(just_legend=True), None) + + # the first curve added should be active + self.plot.clear() + self.plot.addCurve(self.xData, self.yData, + legend=legend, + color="green") + self.assertEqual(self.plot.getActiveCurve(just_legend=True), legend) + self.plot.addCurve(self.xData2, self.yData2, + legend="curve 2", + color="red") + self.assertEqual(self.plot.getActiveCurve(just_legend=True), legend) + + def testActiveCurveStyle(self): + """Test change of active curve style""" + self.plot.setActiveCurveHandling(True) + self.plot.setActiveCurveStyle(color='black') + style = self.plot.getActiveCurveStyle() + self.assertEqual(style.getColor(), (0., 0., 0., 1.)) + self.assertIsNone(style.getLineStyle()) + self.assertIsNone(style.getLineWidth()) + self.assertIsNone(style.getSymbol()) + self.assertIsNone(style.getSymbolSize()) + + self.plot.addCurve(x=self.xData, y=self.yData, legend="curve1") + curve = self.plot.getCurve("curve1") + curve.setColor('blue') + curve.setLineStyle('-') + curve.setLineWidth(1) + curve.setSymbol('o') + curve.setSymbolSize(5) + + # Check default current style + defaultStyle = curve.getCurrentStyle() + self.assertEqual(defaultStyle, CurveStyle(color='blue', + linestyle='-', + linewidth=1, + symbol='o', + symbolsize=5)) + + # Activate curve with highlight color=black + self.plot.setActiveCurve("curve1") + style = curve.getCurrentStyle() + self.assertEqual(style.getColor(), (0., 0., 0., 1.)) + self.assertEqual(style.getLineStyle(), '-') + self.assertEqual(style.getLineWidth(), 1) + self.assertEqual(style.getSymbol(), 'o') + self.assertEqual(style.getSymbolSize(), 5) + + # Change highlight to linewidth=2 + self.plot.setActiveCurveStyle(linewidth=2) + style = curve.getCurrentStyle() + self.assertEqual(style.getColor(), (0., 0., 1., 1.)) + self.assertEqual(style.getLineStyle(), '-') + self.assertEqual(style.getLineWidth(), 2) + self.assertEqual(style.getSymbol(), 'o') + self.assertEqual(style.getSymbolSize(), 5) + + self.plot.setActiveCurve(None) + self.assertEqual(curve.getCurrentStyle(), defaultStyle) + + def testActiveImageAndLabels(self): + # Active image handling always on, no API for toggling it + self.plot.getXAxis().setLabel('XLabel') + self.plot.getYAxis().setLabel('YLabel') + + # labels changed as active curve + self.plot.addImage(numpy.arange(100).reshape(10, 10), + legend='1', xlabel='x1', ylabel='y1') + self.assertEqual(self.plot.getXAxis().getLabel(), 'x1') + self.assertEqual(self.plot.getYAxis().getLabel(), 'y1') + + # labels not changed as not active curve + self.plot.addImage(numpy.arange(100).reshape(10, 10), + legend='2') + self.assertEqual(self.plot.getXAxis().getLabel(), 'x1') + self.assertEqual(self.plot.getYAxis().getLabel(), 'y1') + + # labels changed + self.plot.setActiveImage('2') + self.assertEqual(self.plot.getXAxis().getLabel(), 'XLabel') + self.assertEqual(self.plot.getYAxis().getLabel(), 'YLabel') + + self.plot.setActiveImage('1') + self.assertEqual(self.plot.getXAxis().getLabel(), 'x1') + self.assertEqual(self.plot.getYAxis().getLabel(), 'y1') + + self.plot.clear() + self.assertEqual(self.plot.getXAxis().getLabel(), 'XLabel') + self.assertEqual(self.plot.getYAxis().getLabel(), 'YLabel') + + +############################################################################## +# Log +############################################################################## + +class TestPlotEmptyLog(PlotWidgetTestCase): + """Basic tests for log plot""" + def testEmptyPlotTitleLabelsLog(self): + self.plot.setGraphTitle('Empty Log Log') + self.plot.getXAxis().setLabel('X') + self.plot.getYAxis().setLabel('Y') + self.plot.getXAxis()._setLogarithmic(True) + self.plot.getYAxis()._setLogarithmic(True) + self.plot.resetZoom() + + +class TestPlotAxes(TestCaseQt, ParametricTestCase): + + # Test data + xData = numpy.arange(1, 10) + yData = xData ** 2 + + def __init__(self, methodName='runTest', backend=None): + unittest.TestCase.__init__(self, methodName) + self.__backend = backend + + def setUp(self): + super(TestPlotAxes, self).setUp() + self.plot = PlotWidget(backend=self.__backend) + # It is not needed to display the plot + # It saves a lot of time + # self.plot.show() + # self.qWaitForWindowExposed(self.plot) + + def tearDown(self): + self.qapp.processEvents() + self.plot.setAttribute(qt.Qt.WA_DeleteOnClose) + self.plot.close() + del self.plot + super(TestPlotAxes, self).tearDown() + + def testDefaultAxes(self): + axis = self.plot.getXAxis() + self.assertEqual(axis.getScale(), axis.LINEAR) + axis = self.plot.getYAxis() + self.assertEqual(axis.getScale(), axis.LINEAR) + axis = self.plot.getYAxis(axis="right") + self.assertEqual(axis.getScale(), axis.LINEAR) + + def testOldPlotAxis_getterSetter(self): + """Test silx API prior to silx 0.6""" + x = self.plot.getXAxis() + y = self.plot.getYAxis() + p = self.plot + + tests = [ + # setters + (p.setGraphXLimits, (10, 20), x.getLimits, (10, 20)), + (p.setGraphYLimits, (10, 20), y.getLimits, (10, 20)), + (p.setGraphXLabel, "foox", x.getLabel, "foox"), + (p.setGraphYLabel, "fooy", y.getLabel, "fooy"), + (p.setYAxisInverted, True, y.isInverted, True), + (p.setXAxisLogarithmic, True, x.getScale, x.LOGARITHMIC), + (p.setYAxisLogarithmic, True, y.getScale, y.LOGARITHMIC), + (p.setXAxisAutoScale, False, x.isAutoScale, False), + (p.setYAxisAutoScale, False, y.isAutoScale, False), + # getters + (x.setLimits, (11, 20), p.getGraphXLimits, (11, 20)), + (y.setLimits, (11, 20), p.getGraphYLimits, (11, 20)), + (x.setLabel, "fooxx", p.getGraphXLabel, "fooxx"), + (y.setLabel, "fooyy", p.getGraphYLabel, "fooyy"), + (y.setInverted, False, p.isYAxisInverted, False), + (x.setScale, x.LINEAR, p.isXAxisLogarithmic, False), + (y.setScale, y.LINEAR, p.isYAxisLogarithmic, False), + (x.setAutoScale, True, p.isXAxisAutoScale, True), + (y.setAutoScale, True, p.isYAxisAutoScale, True), + ] + for testCase in tests: + setter, value, getter, expected = testCase + with self.subTest(): + if setter is not None: + if not isinstance(value, tuple): + value = (value, ) + setter(*value) + if getter is not None: + self.assertEqual(getter(), expected) + + def testOldPlotAxis_Logarithmic(self): + """Test silx API prior to silx 0.6""" + x = self.plot.getXAxis() + y = self.plot.getYAxis() + yright = self.plot.getYAxis(axis="right") + + self.assertEqual(x.getScale(), x.LINEAR) + self.assertEqual(y.getScale(), x.LINEAR) + self.assertEqual(yright.getScale(), x.LINEAR) + + self.plot.setXAxisLogarithmic(True) + self.assertEqual(x.getScale(), x.LOGARITHMIC) + self.assertEqual(y.getScale(), x.LINEAR) + self.assertEqual(yright.getScale(), x.LINEAR) + self.assertEqual(self.plot.isXAxisLogarithmic(), True) + self.assertEqual(self.plot.isYAxisLogarithmic(), False) + + self.plot.setYAxisLogarithmic(True) + self.assertEqual(x.getScale(), x.LOGARITHMIC) + self.assertEqual(y.getScale(), x.LOGARITHMIC) + self.assertEqual(yright.getScale(), x.LOGARITHMIC) + self.assertEqual(self.plot.isXAxisLogarithmic(), True) + self.assertEqual(self.plot.isYAxisLogarithmic(), True) + + yright.setScale(yright.LINEAR) + self.assertEqual(x.getScale(), x.LOGARITHMIC) + self.assertEqual(y.getScale(), x.LINEAR) + self.assertEqual(yright.getScale(), x.LINEAR) + self.assertEqual(self.plot.isXAxisLogarithmic(), True) + self.assertEqual(self.plot.isYAxisLogarithmic(), False) + + def testOldPlotAxis_AutoScale(self): + """Test silx API prior to silx 0.6""" + x = self.plot.getXAxis() + y = self.plot.getYAxis() + yright = self.plot.getYAxis(axis="right") + + self.assertEqual(x.isAutoScale(), True) + self.assertEqual(y.isAutoScale(), True) + self.assertEqual(yright.isAutoScale(), True) + + self.plot.setXAxisAutoScale(False) + self.assertEqual(x.isAutoScale(), False) + self.assertEqual(y.isAutoScale(), True) + self.assertEqual(yright.isAutoScale(), True) + self.assertEqual(self.plot.isXAxisAutoScale(), False) + self.assertEqual(self.plot.isYAxisAutoScale(), True) + + self.plot.setYAxisAutoScale(False) + self.assertEqual(x.isAutoScale(), False) + self.assertEqual(y.isAutoScale(), False) + self.assertEqual(yright.isAutoScale(), False) + self.assertEqual(self.plot.isXAxisAutoScale(), False) + self.assertEqual(self.plot.isYAxisAutoScale(), False) + + yright.setAutoScale(True) + self.assertEqual(x.isAutoScale(), False) + self.assertEqual(y.isAutoScale(), True) + self.assertEqual(yright.isAutoScale(), True) + self.assertEqual(self.plot.isXAxisAutoScale(), False) + self.assertEqual(self.plot.isYAxisAutoScale(), True) + + def testOldPlotAxis_Inverted(self): + """Test silx API prior to silx 0.6""" + x = self.plot.getXAxis() + y = self.plot.getYAxis() + yright = self.plot.getYAxis(axis="right") + + self.assertEqual(x.isInverted(), False) + self.assertEqual(y.isInverted(), False) + self.assertEqual(yright.isInverted(), False) + + self.plot.setYAxisInverted(True) + self.assertEqual(x.isInverted(), False) + self.assertEqual(y.isInverted(), True) + self.assertEqual(yright.isInverted(), True) + self.assertEqual(self.plot.isYAxisInverted(), True) + + yright.setInverted(False) + self.assertEqual(x.isInverted(), False) + self.assertEqual(y.isInverted(), False) + self.assertEqual(yright.isInverted(), False) + self.assertEqual(self.plot.isYAxisInverted(), False) + + def testLogXWithData(self): + self.plot.setGraphTitle('Curve X: Log Y: Linear') + self.plot.addCurve(self.xData, self.yData, + legend="curve", + replace=False, resetzoom=True, + color='green', linestyle="-", symbol='o') + axis = self.plot.getXAxis() + axis.setScale(axis.LOGARITHMIC) + + self.assertEqual(axis.getScale(), axis.LOGARITHMIC) + + def testLogYWithData(self): + self.plot.setGraphTitle('Curve X: Linear Y: Log') + self.plot.addCurve(self.xData, self.yData, + legend="curve", + replace=False, resetzoom=True, + color='green', linestyle="-", symbol='o') + axis = self.plot.getYAxis() + axis.setScale(axis.LOGARITHMIC) + + self.assertEqual(axis.getScale(), axis.LOGARITHMIC) + axis = self.plot.getYAxis(axis="right") + self.assertEqual(axis.getScale(), axis.LOGARITHMIC) + + def testLogYRightWithData(self): + self.plot.setGraphTitle('Curve X: Linear Y: Log') + self.plot.addCurve(self.xData, self.yData, + legend="curve", + replace=False, resetzoom=True, + color='green', linestyle="-", symbol='o') + axis = self.plot.getYAxis(axis="right") + axis.setScale(axis.LOGARITHMIC) + + self.assertEqual(axis.getScale(), axis.LOGARITHMIC) + axis = self.plot.getYAxis() + self.assertEqual(axis.getScale(), axis.LOGARITHMIC) + + def testLimitsChanged_setLimits(self): + self.plot.addCurve(self.xData, self.yData, + legend="curve", + replace=False, resetzoom=False, + color='green', linestyle="-", symbol='o') + listener = SignalListener() + self.plot.getXAxis().sigLimitsChanged.connect(listener.partial(axis="x")) + self.plot.getYAxis().sigLimitsChanged.connect(listener.partial(axis="y")) + self.plot.getYAxis(axis="right").sigLimitsChanged.connect(listener.partial(axis="y2")) + self.plot.setLimits(0, 1, 0, 1, 0, 1) + # at least one event per axis + self.assertEqual(len(set(listener.karguments(argumentName="axis"))), 3) + + def testLimitsChanged_resetZoom(self): + self.plot.addCurve(self.xData, self.yData, + legend="curve", + replace=False, resetzoom=False, + color='green', linestyle="-", symbol='o') + listener = SignalListener() + self.plot.getXAxis().sigLimitsChanged.connect(listener.partial(axis="x")) + self.plot.getYAxis().sigLimitsChanged.connect(listener.partial(axis="y")) + self.plot.getYAxis(axis="right").sigLimitsChanged.connect(listener.partial(axis="y2")) + self.plot.resetZoom() + # at least one event per axis + self.assertEqual(len(set(listener.karguments(argumentName="axis"))), 3) + + def testLimitsChanged_setXLimit(self): + self.plot.addCurve(self.xData, self.yData, + legend="curve", + replace=False, resetzoom=False, + color='green', linestyle="-", symbol='o') + listener = SignalListener() + axis = self.plot.getXAxis() + axis.sigLimitsChanged.connect(listener) + axis.setLimits(20, 30) + # at least one event per axis + self.assertEqual(listener.arguments(callIndex=-1), (20.0, 30.0)) + self.assertEqual(axis.getLimits(), (20.0, 30.0)) + + def testLimitsChanged_setYLimit(self): + self.plot.addCurve(self.xData, self.yData, + legend="curve", + replace=False, resetzoom=False, + color='green', linestyle="-", symbol='o') + listener = SignalListener() + axis = self.plot.getYAxis() + axis.sigLimitsChanged.connect(listener) + axis.setLimits(20, 30) + # at least one event per axis + self.assertEqual(listener.arguments(callIndex=-1), (20.0, 30.0)) + self.assertEqual(axis.getLimits(), (20.0, 30.0)) + + def testLimitsChanged_setYRightLimit(self): + self.plot.addCurve(self.xData, self.yData, + legend="curve", + replace=False, resetzoom=False, + color='green', linestyle="-", symbol='o') + listener = SignalListener() + axis = self.plot.getYAxis(axis="right") + axis.sigLimitsChanged.connect(listener) + axis.setLimits(20, 30) + # at least one event per axis + self.assertEqual(listener.arguments(callIndex=-1), (20.0, 30.0)) + self.assertEqual(axis.getLimits(), (20.0, 30.0)) + + def testScaleProxy(self): + listener = SignalListener() + y = self.plot.getYAxis() + yright = self.plot.getYAxis(axis="right") + y.sigScaleChanged.connect(listener.partial("left")) + yright.sigScaleChanged.connect(listener.partial("right")) + yright.setScale(yright.LOGARITHMIC) + + self.assertEqual(y.getScale(), y.LOGARITHMIC) + events = listener.arguments() + self.assertEqual(len(events), 2) + self.assertIn(("left", y.LOGARITHMIC), events) + self.assertIn(("right", y.LOGARITHMIC), events) + + def testAutoScaleProxy(self): + listener = SignalListener() + y = self.plot.getYAxis() + yright = self.plot.getYAxis(axis="right") + y.sigAutoScaleChanged.connect(listener.partial("left")) + yright.sigAutoScaleChanged.connect(listener.partial("right")) + yright.setAutoScale(False) + + self.assertEqual(y.isAutoScale(), False) + events = listener.arguments() + self.assertEqual(len(events), 2) + self.assertIn(("left", False), events) + self.assertIn(("right", False), events) + + def testInvertedProxy(self): + listener = SignalListener() + y = self.plot.getYAxis() + yright = self.plot.getYAxis(axis="right") + y.sigInvertedChanged.connect(listener.partial("left")) + yright.sigInvertedChanged.connect(listener.partial("right")) + yright.setInverted(True) + + self.assertEqual(y.isInverted(), True) + events = listener.arguments() + self.assertEqual(len(events), 2) + self.assertIn(("left", True), events) + self.assertIn(("right", True), events) + + def testAxesDisplayedFalse(self): + """Test coverage on setAxesDisplayed(False)""" + self.plot.setAxesDisplayed(False) + + def testAxesDisplayedTrue(self): + """Test coverage on setAxesDisplayed(True)""" + self.plot.setAxesDisplayed(True) + + def testAxesMargins(self): + """Test PlotWidget's getAxesMargins and setAxesMargins""" + self.plot.show() + self.qWaitForWindowExposed(self.plot) + + margins = self.plot.getAxesMargins() + self.assertEqual(margins, (.15, .1, .1, .15)) + + for margins in ((0., 0., 0., 0.), (.15, .1, .1, .15)): + with self.subTest(margins=margins): + self.plot.setAxesMargins(*margins) + self.qapp.processEvents() + self.assertEqual(self.plot.getAxesMargins(), margins) + + def testBoundingRectItem(self): + item = BoundingRect() + item.setBounds((-1000, 1000, -2000, 2000)) + self.plot.addItem(item) + self.plot.resetZoom() + limits = numpy.array(self.plot.getXAxis().getLimits()) + numpy.testing.assert_almost_equal(limits, numpy.array([-1000, 1000])) + limits = numpy.array(self.plot.getYAxis().getLimits()) + numpy.testing.assert_almost_equal(limits, numpy.array([-2000, 2000])) + + def testBoundingRectRightItem(self): + item = BoundingRect() + item.setYAxis("right") + item.setBounds((-1000, 1000, -2000, 2000)) + self.plot.addItem(item) + self.plot.resetZoom() + limits = numpy.array(self.plot.getXAxis().getLimits()) + numpy.testing.assert_almost_equal(limits, numpy.array([-1000, 1000])) + limits = numpy.array(self.plot.getYAxis("right").getLimits()) + numpy.testing.assert_almost_equal(limits, numpy.array([-2000, 2000])) + + def testBoundingRectArguments(self): + item = BoundingRect() + with self.assertRaises(Exception): + item.setBounds((1000, -1000, -2000, 2000)) + with self.assertRaises(Exception): + item.setBounds((-1000, 1000, 2000, -2000)) + + def testBoundingRectWithLog(self): + item = BoundingRect() + self.plot.addItem(item) + + item.setBounds((-1000, 1000, -2000, 2000)) + self.plot.getXAxis()._setLogarithmic(True) + self.plot.getYAxis()._setLogarithmic(False) + self.assertEqual(item.getBounds(), (1000, 1000, -2000, 2000)) + + item.setBounds((-1000, 1000, -2000, 2000)) + self.plot.getXAxis()._setLogarithmic(False) + self.plot.getYAxis()._setLogarithmic(True) + self.assertEqual(item.getBounds(), (-1000, 1000, 2000, 2000)) + + item.setBounds((-1000, 0, -2000, 2000)) + self.plot.getXAxis()._setLogarithmic(True) + self.plot.getYAxis()._setLogarithmic(False) + self.assertIsNone(item.getBounds()) + + def testAxisExtent(self): + """Test XAxisExtent and yAxisExtent""" + for cls, axis in ((XAxisExtent, self.plot.getXAxis()), + (YAxisExtent, self.plot.getYAxis())): + for range_, logRange in (((2, 3), (2, 3)), + ((-2, -1), (1, 100)), + ((-1, 3), (3. * 0.9, 3. * 1.1))): + extent = cls() + extent.setRange(*range_) + self.plot.addItem(extent) + + for isLog, plotRange in ((False, range_), (True, logRange)): + with self.subTest( + cls=cls.__name__, range=range_, isLog=isLog): + axis._setLogarithmic(isLog) + self.plot.resetZoom() + self.qapp.processEvents() + self.assertEqual(axis.getLimits(), plotRange) + + axis._setLogarithmic(False) + self.plot.clear() + + def testAxisLimitOverflow(self): + """Test setting limis beyond supported range""" + xaxis, yaxis = self.plot.getXAxis(), self.plot.getYAxis() + for scale in ("linear", "log"): + xaxis.setScale(scale) + yaxis.setScale(scale) + for limits in ((1e300, 1e308), + (-1e308, 1e308), + (1e-300, 2e-300)): + with self.subTest(scale=scale, limits=limits): + xaxis.setLimits(*limits) + self.qapp.processEvents() + self.assertNotEqual(xaxis.getLimits(), limits) + yaxis.setLimits(*limits) + self.qapp.processEvents() + self.assertNotEqual(yaxis.getLimits(), limits) + + +class TestPlotCurveLog(PlotWidgetTestCase, ParametricTestCase): + """Basic tests for addCurve with log scale axes""" + + # Test data + xData = numpy.arange(1000) + 1 + yData = xData ** 2 + + def _setLabels(self): + self.plot.getXAxis().setLabel('X') + self.plot.getYAxis().setLabel('X * X') + + def testPlotCurveLogX(self): + self._setLabels() + self.plot.getXAxis()._setLogarithmic(True) + self.plot.setGraphTitle('Curve X: Log Y: Linear') + + self.plot.addCurve(self.xData, self.yData, + legend="curve", + replace=False, resetzoom=True, + color='green', linestyle="-", symbol='o') + + def testPlotCurveLogY(self): + self._setLabels() + self.plot.getYAxis()._setLogarithmic(True) + + self.plot.setGraphTitle('Curve X: Linear Y: Log') + + self.plot.addCurve(self.xData, self.yData, + legend="curve", + replace=False, resetzoom=True, + color='green', linestyle="-", symbol='o') + + def testPlotCurveLogXY(self): + self._setLabels() + self.plot.getXAxis()._setLogarithmic(True) + self.plot.getYAxis()._setLogarithmic(True) + + self.plot.setGraphTitle('Curve X: Log Y: Log') + + self.plot.addCurve(self.xData, self.yData, + legend="curve", + replace=False, resetzoom=True, + color='green', linestyle="-", symbol='o') + + def testPlotCurveErrorLogXY(self): + self.plot.getXAxis()._setLogarithmic(True) + self.plot.getYAxis()._setLogarithmic(True) + + # Every second error leads to negative number + errors = numpy.ones_like(self.xData) + errors[::2] = self.xData[::2] + 1 + + tests = [ # name, xerror, yerror + ('xerror=3', 3, None), + ('xerror=N array', errors, None), + ('xerror=Nx1 array', errors.reshape(len(errors), 1), None), + ('xerror=2xN array', numpy.array((errors, errors)), None), + ('yerror=6', None, 6), + ('yerror=N array', None, errors ** 2), + ('yerror=Nx1 array', None, (errors ** 2).reshape(len(errors), 1)), + ('yerror=2xN array', None, numpy.array((errors, errors)) ** 2), + ] + + for name, xError, yError in tests: + with self.subTest(name): + self.plot.setGraphTitle(name) + self.plot.addCurve(self.xData, self.yData, + legend=name, + xerror=xError, yerror=yError, + replace=False, resetzoom=True, + color='green', linestyle="-", symbol='o') + + self.qapp.processEvents() + + self.plot.clear() + self.plot.resetZoom() + self.qapp.processEvents() + + def testPlotCurveToggleLog(self): + """Add a curve with negative data and toggle log axis""" + arange = numpy.arange(1000) + 1 + tests = [ # name, xData, yData + ('x>0, some negative y', arange, arange - 500), + ('x>0, y<0', arange, -arange), + ('some negative x, y>0', arange - 500, arange), + ('x<0, y>0', -arange, arange), + ('some negative x and y', arange - 500, arange - 500), + ('x<0, y<0', -arange, -arange), + ] + + for name, xData, yData in tests: + with self.subTest(name): + self.plot.addCurve(xData, yData, resetzoom=True) + self.qapp.processEvents() + + # no log axis + xLim = self.plot.getXAxis().getLimits() + self.assertEqual(xLim, (min(xData), max(xData))) + yLim = self.plot.getYAxis().getLimits() + self.assertEqual(yLim, (min(yData), max(yData))) + + # x axis log + self.plot.getXAxis()._setLogarithmic(True) + self.qapp.processEvents() + + xLim = self.plot.getXAxis().getLimits() + yLim = self.plot.getYAxis().getLimits() + positives = xData > 0 + if numpy.any(positives): + self.assertTrue(numpy.allclose( + xLim, (min(xData[positives]), max(xData[positives])))) + self.assertEqual( + yLim, (min(yData[positives]), max(yData[positives]))) + else: # No positive x in the curve + self.assertEqual(xLim, (1., 100.)) + self.assertEqual(yLim, (1., 100.)) + + # x axis and y axis log + self.plot.getYAxis()._setLogarithmic(True) + self.qapp.processEvents() + + xLim = self.plot.getXAxis().getLimits() + yLim = self.plot.getYAxis().getLimits() + positives = numpy.logical_and(xData > 0, yData > 0) + if numpy.any(positives): + self.assertTrue(numpy.allclose( + xLim, (min(xData[positives]), max(xData[positives])))) + self.assertTrue(numpy.allclose( + yLim, (min(yData[positives]), max(yData[positives])))) + else: # No positive x and y in the curve + self.assertEqual(xLim, (1., 100.)) + self.assertEqual(yLim, (1., 100.)) + + # y axis log + self.plot.getXAxis()._setLogarithmic(False) + self.qapp.processEvents() + + xLim = self.plot.getXAxis().getLimits() + yLim = self.plot.getYAxis().getLimits() + positives = yData > 0 + if numpy.any(positives): + self.assertEqual( + xLim, (min(xData[positives]), max(xData[positives]))) + self.assertTrue(numpy.allclose( + yLim, (min(yData[positives]), max(yData[positives])))) + else: # No positive y in the curve + self.assertEqual(xLim, (1., 100.)) + self.assertEqual(yLim, (1., 100.)) + + # no log axis + self.plot.getYAxis()._setLogarithmic(False) + self.qapp.processEvents() + + xLim = self.plot.getXAxis().getLimits() + self.assertEqual(xLim, (min(xData), max(xData))) + yLim = self.plot.getYAxis().getLimits() + self.assertEqual(yLim, (min(yData), max(yData))) + + self.plot.clear() + self.plot.resetZoom() + self.qapp.processEvents() + + +class TestPlotImageLog(PlotWidgetTestCase): + """Basic tests for addImage with log scale axes.""" + + def setUp(self): + super(TestPlotImageLog, self).setUp() + + self.plot.getXAxis().setLabel('Columns') + self.plot.getYAxis().setLabel('Rows') + + def testPlotColormapGrayLogX(self): + self.plot.getXAxis()._setLogarithmic(True) + self.plot.setGraphTitle('CMap X: Log Y: Linear') + + colormap = Colormap(name='gray', + normalization='linear', + vmin=None, + vmax=None) + self.plot.addImage(DATA_2D, legend="image 1", + origin=(1., 1.), scale=(1., 1.), + resetzoom=False, colormap=colormap) + self.plot.resetZoom() + + def testPlotColormapGrayLogY(self): + self.plot.getYAxis()._setLogarithmic(True) + self.plot.setGraphTitle('CMap X: Linear Y: Log') + + colormap = Colormap(name='gray', + normalization='linear', + vmin=None, + vmax=None) + self.plot.addImage(DATA_2D, legend="image 1", + origin=(1., 1.), scale=(1., 1.), + resetzoom=False, colormap=colormap) + self.plot.resetZoom() + + def testPlotColormapGrayLogXY(self): + self.plot.getXAxis()._setLogarithmic(True) + self.plot.getYAxis()._setLogarithmic(True) + self.plot.setGraphTitle('CMap X: Log Y: Log') + + colormap = Colormap(name='gray', + normalization='linear', + vmin=None, + vmax=None) + self.plot.addImage(DATA_2D, legend="image 1", + origin=(1., 1.), scale=(1., 1.), + resetzoom=False, colormap=colormap) + self.plot.resetZoom() + + def testPlotRgbRgbaLogXY(self): + self.plot.getXAxis()._setLogarithmic(True) + self.plot.getYAxis()._setLogarithmic(True) + self.plot.setGraphTitle('RGB + RGBA X: Log Y: Log') + + rgb = numpy.array( + (((0, 0, 0), (128, 0, 0), (255, 0, 0)), + ((0, 128, 0), (0, 128, 128), (0, 128, 256))), + dtype=numpy.uint8) + + self.plot.addImage(rgb, legend="rgb", + origin=(1, 1), scale=(10, 10), + resetzoom=False) + + rgba = numpy.array( + (((0, 0, 0, .5), (.5, 0, 0, 1), (1, 0, 0, .5)), + ((0, .5, 0, 1), (0, .5, .5, 1), (0, 1, 1, .5))), + dtype=numpy.float32) + + self.plot.addImage(rgba, legend="rgba", + origin=(5., 5.), scale=(10., 10.), + resetzoom=False) + self.plot.resetZoom() + + +class TestPlotMarkerLog(PlotWidgetTestCase): + """Basic tests for markers on log scales""" + + # Test marker parameters + markers = [ # x, y, color, selectable, draggable + (10., 10., 'blue', False, False), + (20., 20., 'red', False, False), + (40., 100., 'green', True, False), + (40., 500., 'gray', True, True), + (60., 800., 'black', False, True), + ] + + def setUp(self): + super(TestPlotMarkerLog, self).setUp() + + self.plot.getYAxis().setLabel('Rows') + self.plot.getXAxis().setLabel('Columns') + self.plot.getXAxis().setAutoScale(False) + self.plot.getYAxis().setAutoScale(False) + self.plot.setKeepDataAspectRatio(False) + self.plot.setLimits(1., 100., 1., 1000.) + self.plot.getXAxis()._setLogarithmic(True) + self.plot.getYAxis()._setLogarithmic(True) + + def testPlotMarkerXLog(self): + self.plot.setGraphTitle('Markers X, Log axes') + + for x, _, color, select, drag in self.markers: + name = str(x) + if select: + name += " sel." + if drag: + name += " drag" + self.plot.addXMarker(x, name, name, color, select, drag) + self.plot.resetZoom() + + def testPlotMarkerYLog(self): + self.plot.setGraphTitle('Markers Y, Log axes') + + for _, y, color, select, drag in self.markers: + name = str(y) + if select: + name += " sel." + if drag: + name += " drag" + self.plot.addYMarker(y, name, name, color, select, drag) + self.plot.resetZoom() + + def testPlotMarkerPtLog(self): + self.plot.setGraphTitle('Markers Pt, Log axes') + + for x, y, color, select, drag in self.markers: + name = "{0},{1}".format(x, y) + if select: + name += " sel." + if drag: + name += " drag" + self.plot.addMarker(x, y, name, name, color, select, drag) + self.plot.resetZoom() + + +@pytest.mark.usefixtures("test_options_class_attr") +class TestPlotWidgetSwitchBackend(PlotWidgetTestCase): + """Test [get|set]Backend to switch backend""" + + @pytest.mark.usefixtures("test_options") + def testSwitchBackend(self): + """Test switching a plot with a few items""" + backends = {'none': 'BackendBase', 'mpl': 'BackendMatplotlibQt'} + if self.test_options.WITH_GL_TEST: + backends['gl'] = 'BackendOpenGL' + + self.plot.addImage(numpy.arange(100).reshape(10, 10)) + self.plot.addCurve((-3, -2, -1), (1, 2, 3)) + self.plot.resetZoom() + xlimits = self.plot.getXAxis().getLimits() + ylimits = self.plot.getYAxis().getLimits() + items = self.plot.getItems() + self.assertEqual(len(items), 2) + + for backend, className in backends.items(): + with self.subTest(backend=backend): + self.plot.setBackend(backend) + self.plot.replot() + + retrievedBackend = self.plot.getBackend() + self.assertEqual(type(retrievedBackend).__name__, className) + self.assertEqual(self.plot.getXAxis().getLimits(), xlimits) + self.assertEqual(self.plot.getYAxis().getLimits(), ylimits) + self.assertEqual(self.plot.getItems(), items) + + +class TestPlotWidgetSelection(PlotWidgetTestCase): + """Test PlotWidget.selection and active items handling""" + + def _checkSelection(self, selection, current=None, selected=()): + """Check current item and selected items.""" + self.assertIs(selection.getCurrentItem(), current) + self.assertEqual(selection.getSelectedItems(), selected) + + def testSyncWithActiveItems(self): + """Test update of PlotWidgetSelection according to active items""" + listener = SignalListener() + + selection = self.plot.selection() + selection.sigCurrentItemChanged.connect(listener) + self._checkSelection(selection) + + # Active item is current + self.plot.addImage(((0, 1), (2, 3)), legend='image') + image = self.plot.getActiveImage() + self.assertEqual(listener.callCount(), 1) + self._checkSelection(selection, image, (image,)) + + # No active = no current + self.plot.setActiveImage(None) + self.assertEqual(listener.callCount(), 2) + self._checkSelection(selection) + + # Active item is current + self.plot.setActiveImage('image') + self.assertEqual(listener.callCount(), 3) + self._checkSelection(selection, image, (image,)) + + # Mosted recently "actived" item is current + self.plot.addScatter((3, 2, 1), (0, 1, 2), (0, 1, 2), legend='scatter') + scatter = self.plot.getActiveScatter() + self.assertEqual(listener.callCount(), 4) + self._checkSelection(selection, scatter, (scatter, image)) + + # Previously mosted recently "actived" item is current + self.plot.setActiveScatter(None) + self.assertEqual(listener.callCount(), 5) + self._checkSelection(selection, image, (image,)) + + # Mosted recently "actived" item is current + self.plot.setActiveScatter('scatter') + self.assertEqual(listener.callCount(), 6) + self._checkSelection(selection, scatter, (scatter, image)) + + # No active = no current + self.plot.setActiveImage(None) + self.plot.setActiveScatter(None) + self.assertEqual(listener.callCount(), 7) + self._checkSelection(selection) + + # Mosted recently "actived" item is current + self.plot.setActiveScatter('scatter') + self.assertEqual(listener.callCount(), 8) + self.plot.setActiveImage('image') + self.assertEqual(listener.callCount(), 9) + self._checkSelection(selection, image, (image, scatter)) + + # Add a curve which is not active by default + self.plot.addCurve((0, 1, 2), (0, 1, 2), legend='curve') + curve = self.plot.getCurve('curve') + self.assertEqual(listener.callCount(), 9) + self._checkSelection(selection, image, (image, scatter)) + + # Mosted recently "actived" item is current + self.plot.setActiveCurve('curve') + self.assertEqual(listener.callCount(), 10) + self._checkSelection(selection, curve, (curve, image, scatter)) + + # Add a curve which is not active by default + self.plot.addCurve((0, 1, 2), (0, 1, 2), legend='curve2') + curve2 = self.plot.getCurve('curve2') + self.assertEqual(listener.callCount(), 10) + self._checkSelection(selection, curve, (curve, image, scatter)) + + # Mosted recently "actived" item is current, previous curve is removed + self.plot.setActiveCurve('curve2') + self.assertEqual(listener.callCount(), 11) + self._checkSelection(selection, curve2, (curve2, image, scatter)) + + # No items = no current + self.plot.clear() + self.assertEqual(listener.callCount(), 12) + self._checkSelection(selection) + + def testPlotWidgetWithItems(self): + """Test init of selection on a plot with items""" + self.plot.addImage(((0, 1), (2, 3)), legend='image') + self.plot.addScatter((3, 2, 1), (0, 1, 2), (0, 1, 2), legend='scatter') + self.plot.addCurve((0, 1, 2), (0, 1, 2), legend='curve') + self.plot.setActiveCurve('curve') + + selection = self.plot.selection() + self.assertIsNotNone(selection.getCurrentItem()) + selected = selection.getSelectedItems() + self.assertEqual(len(selected), 3) + self.assertIn(self.plot.getActiveCurve(), selected) + self.assertIn(self.plot.getActiveImage(), selected) + self.assertIn(self.plot.getActiveScatter(), selected) + + def testSetCurrentItem(self): + """Test setCurrentItem""" + # Add items to the plot + self.plot.addImage(((0, 1), (2, 3)), legend='image') + image = self.plot.getActiveImage() + self.plot.addScatter((3, 2, 1), (0, 1, 2), (0, 1, 2), legend='scatter') + scatter = self.plot.getActiveScatter() + self.plot.addCurve((0, 1, 2), (0, 1, 2), legend='curve') + self.plot.setActiveCurve('curve') + curve = self.plot.getActiveCurve() + + selection = self.plot.selection() + self.assertIsNotNone(selection.getCurrentItem()) + self.assertEqual(len(selection.getSelectedItems()), 3) + + # Set current to None reset all active items + selection.setCurrentItem(None) + self._checkSelection(selection) + self.assertIsNone(self.plot.getActiveCurve()) + self.assertIsNone(self.plot.getActiveImage()) + self.assertIsNone(self.plot.getActiveScatter()) + + # Set current to an item makes it active + selection.setCurrentItem(image) + self._checkSelection(selection, image, (image,)) + self.assertIsNone(self.plot.getActiveCurve()) + self.assertIs(self.plot.getActiveImage(), image) + self.assertIsNone(self.plot.getActiveScatter()) + + # Set current to an item makes it active and keeps other active + selection.setCurrentItem(curve) + self._checkSelection(selection, curve, (curve, image)) + self.assertIs(self.plot.getActiveCurve(), curve) + self.assertIs(self.plot.getActiveImage(), image) + self.assertIsNone(self.plot.getActiveScatter()) + + # Set current to an item makes it active and keeps other active + selection.setCurrentItem(scatter) + self._checkSelection(selection, scatter, (scatter, curve, image)) + self.assertIs(self.plot.getActiveCurve(), curve) + self.assertIs(self.plot.getActiveImage(), image) + self.assertIs(self.plot.getActiveScatter(), scatter) + + +@pytest.mark.usefixtures("use_opengl") +class TestPlotWidget_Gl(TestPlotWidget): + backend="gl" + +@pytest.mark.usefixtures("use_opengl") +class TestPlotImage_Gl(TestPlotImage): + backend="gl" + +@pytest.mark.usefixtures("use_opengl") +class TestPlotCurve_Gl(TestPlotCurve): + backend="gl" + +@pytest.mark.usefixtures("use_opengl") +class TestPlotHistogram_Gl(TestPlotHistogram): + backend="gl" + +@pytest.mark.usefixtures("use_opengl") +class TestPlotScatter_Gl(TestPlotScatter): + backend="gl" + +@pytest.mark.usefixtures("use_opengl") +class TestPlotMarker_Gl(TestPlotMarker): + backend="gl" + +@pytest.mark.usefixtures("use_opengl") +class TestPlotItem_Gl(TestPlotItem): + backend="gl" + +@pytest.mark.usefixtures("use_opengl") +class TestPlotAxes_Gl(TestPlotAxes): + backend="gl" + +@pytest.mark.usefixtures("use_opengl") +class TestPlotActiveCurveImage_Gl(TestPlotActiveCurveImage): + backend="gl" + +@pytest.mark.usefixtures("use_opengl") +class TestPlotEmptyLog_Gl(TestPlotEmptyLog): + backend="gl" + +@pytest.mark.usefixtures("use_opengl") +class TestPlotCurveLog_Gl(TestPlotCurveLog): + backend="gl" + +@pytest.mark.usefixtures("use_opengl") +class TestPlotImageLog_Gl(TestPlotImageLog): + backend="gl" + +@pytest.mark.usefixtures("use_opengl") +class TestPlotMarkerLog_Gl(TestPlotMarkerLog): + backend="gl" + +@pytest.mark.usefixtures("use_opengl") +class TestPlotWidgetSelection_Gl(TestPlotWidgetSelection): + backend="gl" + +class TestSpecial_ExplicitMplBackend(TestSpecialBackend): + backend="mpl" diff --git a/src/silx/gui/plot/test/testPlotWidgetNoBackend.py b/src/silx/gui/plot/test/testPlotWidgetNoBackend.py new file mode 100644 index 0000000..4914929 --- /dev/null +++ b/src/silx/gui/plot/test/testPlotWidgetNoBackend.py @@ -0,0 +1,618 @@ +# coding: utf-8 +# /*########################################################################## +# +# Copyright (c) 2016-2020 European Synchrotron Radiation Facility +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. +# +# ###########################################################################*/ +"""Basic tests for PlotWidget with 'none' backend""" + +__authors__ = ["T. Vincent"] +__license__ = "MIT" +__date__ = "17/01/2018" + + +import unittest +from functools import reduce +from silx.utils.testutils import ParametricTestCase + +import numpy + +from silx.gui.plot.PlotWidget import PlotWidget +from silx.gui.plot.items.histogram import _getHistogramCurve, _computeEdges + + +class TestPlot(unittest.TestCase): + """Basic tests of Plot without backend""" + + def testPlotTitleLabels(self): + """Create a Plot and set the labels""" + + plot = PlotWidget(backend='none') + + title, xlabel, ylabel = 'the title', 'x label', 'y label' + plot.setGraphTitle(title) + plot.getXAxis().setLabel(xlabel) + plot.getYAxis().setLabel(ylabel) + + self.assertEqual(plot.getGraphTitle(), title) + self.assertEqual(plot.getXAxis().getLabel(), xlabel) + self.assertEqual(plot.getYAxis().getLabel(), ylabel) + + def testAddNoRemove(self): + """add objects to the Plot""" + + plot = PlotWidget(backend='none') + plot.addCurve(x=(1, 2, 3), y=(3, 2, 1)) + plot.addImage(numpy.arange(100.).reshape(10, -1)) + plot.addShape(numpy.array((1., 10.)), + numpy.array((10., 10.)), + shape="rectangle") + plot.addXMarker(10.) + + +class TestPlotRanges(ParametricTestCase): + """Basic tests of Plot data ranges without backend""" + + _getValidValues = {True: lambda ar: ar > 0, + False: lambda ar: numpy.ones(shape=ar.shape, + dtype=bool)} + + @staticmethod + def _getRanges(arrays, are_logs): + gen = (TestPlotRanges._getValidValues[is_log](ar) + for (ar, is_log) in zip(arrays, are_logs)) + indices = numpy.where(reduce(numpy.logical_and, gen))[0] + if len(indices) > 0: + ranges = [(ar[indices[0]], ar[indices[-1]]) for ar in arrays] + else: + ranges = [None] * len(arrays) + + return ranges + + @staticmethod + def _getRangesMinmax(ranges): + # TODO : error if None in ranges. + rangeMin = numpy.min([rng[0] for rng in ranges]) + rangeMax = numpy.max([rng[1] for rng in ranges]) + return rangeMin, rangeMax + + def testDataRangeNoPlot(self): + """empty plot data range""" + + plot = PlotWidget(backend='none') + + for logX, logY in ((False, False), + (True, False), + (True, True), + (False, True), + (False, False)): + with self.subTest(logX=logX, logY=logY): + plot.getXAxis()._setLogarithmic(logX) + plot.getYAxis()._setLogarithmic(logY) + dataRange = plot.getDataRange() + self.assertIsNone(dataRange.x) + self.assertIsNone(dataRange.y) + self.assertIsNone(dataRange.yright) + + def testDataRangeLeft(self): + """left axis range""" + + plot = PlotWidget(backend='none') + + xData = numpy.arange(10) - 4.9 # range : -4.9 , 4.1 + yData = numpy.arange(10) - 6.9 # range : -6.9 , 2.1 + + plot.addCurve(x=xData, + y=yData, + legend='plot_0', + yaxis='left') + + for logX, logY in ((False, False), + (True, False), + (True, True), + (False, True), + (False, False)): + with self.subTest(logX=logX, logY=logY): + plot.getXAxis()._setLogarithmic(logX) + plot.getYAxis()._setLogarithmic(logY) + dataRange = plot.getDataRange() + xRange, yRange = self._getRanges([xData, yData], + [logX, logY]) + self.assertSequenceEqual(dataRange.x, xRange) + self.assertSequenceEqual(dataRange.y, yRange) + self.assertIsNone(dataRange.yright) + + def testDataRangeRight(self): + """right axis range""" + + plot = PlotWidget(backend='none') + xData = numpy.arange(10) - 4.9 # range : -4.9 , 4.1 + yData = numpy.arange(10) - 6.9 # range : -6.9 , 2.1 + plot.addCurve(x=xData, + y=yData, + legend='plot_0', + yaxis='right') + + for logX, logY in ((False, False), + (True, False), + (True, True), + (False, True), + (False, False)): + with self.subTest(logX=logX, logY=logY): + plot.getXAxis()._setLogarithmic(logX) + plot.getYAxis()._setLogarithmic(logY) + dataRange = plot.getDataRange() + xRange, yRange = self._getRanges([xData, yData], + [logX, logY]) + self.assertSequenceEqual(dataRange.x, xRange) + self.assertIsNone(dataRange.y) + self.assertSequenceEqual(dataRange.yright, yRange) + + def testDataRangeImage(self): + """image data range""" + + origin = (-10, 25) + scale = (3., 8.) + image = numpy.arange(100.).reshape(20, 5) + + plot = PlotWidget(backend='none') + plot.addImage(image, + origin=origin, scale=scale) + + xRange = numpy.array([0., image.shape[1] * scale[0]]) + origin[0] + yRange = numpy.array([0., image.shape[0] * scale[1]]) + origin[1] + + ranges = {(False, False): (xRange, yRange), + (True, False): (None, None), + (True, True): (None, None), + (False, True): (None, None)} + + for logX, logY in ((False, False), + (True, False), + (True, True), + (False, True), + (False, False)): + with self.subTest(logX=logX, logY=logY): + plot.getXAxis()._setLogarithmic(logX) + plot.getYAxis()._setLogarithmic(logY) + dataRange = plot.getDataRange() + xRange, yRange = ranges[logX, logY] + self.assertTrue(numpy.array_equal(dataRange.x, xRange), + msg='{0} != {1}'.format(dataRange.x, xRange)) + self.assertTrue(numpy.array_equal(dataRange.y, yRange), + msg='{0} != {1}'.format(dataRange.y, yRange)) + self.assertIsNone(dataRange.yright) + + def testDataRangeLeftRight(self): + """right+left axis range""" + + plot = PlotWidget(backend='none') + + xData_l = numpy.arange(10) - 0.9 # range : -0.9 , 8.1 + yData_l = numpy.arange(10) - 1.9 # range : -1.9 , 7.1 + plot.addCurve(x=xData_l, + y=yData_l, + legend='plot_l', + yaxis='left') + + xData_r = numpy.arange(10) - 4.9 # range : -4.9 , 4.1 + yData_r = numpy.arange(10) - 6.9 # range : -6.9 , 2.1 + plot.addCurve(x=xData_r, + y=yData_r, + legend='plot_r', + yaxis='right') + + for logX, logY in ((False, False), + (True, False), + (True, True), + (False, True), + (False, False)): + with self.subTest(logX=logX, logY=logY): + plot.getXAxis()._setLogarithmic(logX) + plot.getYAxis()._setLogarithmic(logY) + dataRange = plot.getDataRange() + xRangeL, yRangeL = self._getRanges([xData_l, yData_l], + [logX, logY]) + xRangeR, yRangeR = self._getRanges([xData_r, yData_r], + [logX, logY]) + xRangeLR = self._getRangesMinmax([xRangeL, xRangeR]) + self.assertSequenceEqual(dataRange.x, xRangeLR) + self.assertSequenceEqual(dataRange.y, yRangeL) + self.assertSequenceEqual(dataRange.yright, yRangeR) + + def testDataRangeCurveImage(self): + """right+left+image axis range""" + + # overlapping ranges : + # image sets x min and y max + # plot_left sets y min + # plot_right sets x max (and yright) + plot = PlotWidget(backend='none') + + origin = (-10, 5) + scale = (3., 8.) + image = numpy.arange(100.).reshape(20, 5) + + plot.addImage(image, + origin=origin, scale=scale, legend='image') + + xData_l = numpy.arange(10) - 0.9 # range : -0.9 , 8.1 + yData_l = numpy.arange(10) - 1.9 # range : -1.9 , 7.1 + plot.addCurve(x=xData_l, + y=yData_l, + legend='plot_l', + yaxis='left') + + xData_r = numpy.arange(10) + 4.1 # range : 4.1 , 13.1 + yData_r = numpy.arange(10) - 0.9 # range : -0.9 , 8.1 + plot.addCurve(x=xData_r, + y=yData_r, + legend='plot_r', + yaxis='right') + + imgXRange = numpy.array([0., image.shape[1] * scale[0]]) + origin[0] + imgYRange = numpy.array([0., image.shape[0] * scale[1]]) + origin[1] + + for logX, logY in ((False, False), + (True, False), + (True, True), + (False, True), + (False, False)): + with self.subTest(logX=logX, logY=logY): + plot.getXAxis()._setLogarithmic(logX) + plot.getYAxis()._setLogarithmic(logY) + dataRange = plot.getDataRange() + xRangeL, yRangeL = self._getRanges([xData_l, yData_l], + [logX, logY]) + xRangeR, yRangeR = self._getRanges([xData_r, yData_r], + [logX, logY]) + if logX or logY: + xRangeLR = self._getRangesMinmax([xRangeL, xRangeR]) + else: + xRangeLR = self._getRangesMinmax([xRangeL, + xRangeR, + imgXRange]) + yRangeL = self._getRangesMinmax([yRangeL, imgYRange]) + self.assertSequenceEqual(dataRange.x, xRangeLR) + self.assertSequenceEqual(dataRange.y, yRangeL) + self.assertSequenceEqual(dataRange.yright, yRangeR) + + def testDataRangeImageNegativeScaleX(self): + """image data range, negative scale""" + + origin = (-10, 25) + scale = (-3., 8.) + image = numpy.arange(100.).reshape(20, 5) + + plot = PlotWidget(backend='none') + plot.addImage(image, + origin=origin, scale=scale) + + xRange = numpy.array([0., image.shape[1] * scale[0]]) + origin[0] + xRange.sort() # negative scale! + yRange = numpy.array([0., image.shape[0] * scale[1]]) + origin[1] + + ranges = {(False, False): (xRange, yRange), + (True, False): (None, None), + (True, True): (None, None), + (False, True): (None, None)} + + for logX, logY in ((False, False), + (True, False), + (True, True), + (False, True), + (False, False)): + with self.subTest(logX=logX, logY=logY): + plot.getXAxis()._setLogarithmic(logX) + plot.getYAxis()._setLogarithmic(logY) + dataRange = plot.getDataRange() + xRange, yRange = ranges[logX, logY] + self.assertTrue(numpy.array_equal(dataRange.x, xRange), + msg='{0} != {1}'.format(dataRange.x, xRange)) + self.assertTrue(numpy.array_equal(dataRange.y, yRange), + msg='{0} != {1}'.format(dataRange.y, yRange)) + self.assertIsNone(dataRange.yright) + + def testDataRangeImageNegativeScaleY(self): + """image data range, negative scale""" + + origin = (-10, 25) + scale = (3., -8.) + image = numpy.arange(100.).reshape(20, 5) + + plot = PlotWidget(backend='none') + plot.addImage(image, + origin=origin, scale=scale) + + xRange = numpy.array([0., image.shape[1] * scale[0]]) + origin[0] + yRange = numpy.array([0., image.shape[0] * scale[1]]) + origin[1] + yRange.sort() # negative scale! + + ranges = {(False, False): (xRange, yRange), + (True, False): (None, None), + (True, True): (None, None), + (False, True): (None, None)} + + for logX, logY in ((False, False), + (True, False), + (True, True), + (False, True), + (False, False)): + with self.subTest(logX=logX, logY=logY): + plot.getXAxis()._setLogarithmic(logX) + plot.getYAxis()._setLogarithmic(logY) + dataRange = plot.getDataRange() + xRange, yRange = ranges[logX, logY] + self.assertTrue(numpy.array_equal(dataRange.x, xRange), + msg='{0} != {1}'.format(dataRange.x, xRange)) + self.assertTrue(numpy.array_equal(dataRange.y, yRange), + msg='{0} != {1}'.format(dataRange.y, yRange)) + self.assertIsNone(dataRange.yright) + + def testDataRangeHiddenCurve(self): + """curves with a hidden curve""" + plot = PlotWidget(backend='none') + plot.addCurve((0, 1), (0, 1), legend='shown') + plot.addCurve((0, 1, 2), (5, 5, 5), legend='hidden') + range1 = plot.getDataRange() + self.assertEqual(range1.x, (0, 2)) + self.assertEqual(range1.y, (0, 5)) + plot.hideCurve('hidden') + range2 = plot.getDataRange() + self.assertEqual(range2.x, (0, 1)) + self.assertEqual(range2.y, (0, 1)) + + +class TestPlotGetCurveImage(unittest.TestCase): + """Test of plot getCurve and getImage methods""" + + def testGetCurve(self): + """PlotWidget.getCurve and Plot.getActiveCurve tests""" + + plot = PlotWidget(backend='none') + + # No curve + curve = plot.getCurve() + self.assertIsNone(curve) # No curve + + plot.setActiveCurveHandling(True) + plot.addCurve(x=(0, 1), y=(0, 1), legend='curve 0') + plot.addCurve(x=(0, 1), y=(0, 1), legend='curve 1') + plot.addCurve(x=(0, 1), y=(0, 1), legend='curve 2') + plot.setActiveCurve('curve 0') + + # Active curve + active = plot.getActiveCurve() + self.assertEqual(active.getName(), 'curve 0') + curve = plot.getCurve() + self.assertEqual(curve.getName(), 'curve 0') + + # No active curve and curves + plot.setActiveCurveHandling(False) + active = plot.getActiveCurve() + self.assertIsNone(active) # No active curve + curve = plot.getCurve() + self.assertEqual(curve.getName(), 'curve 2') # Last added curve + + # Last curve hidden + plot.hideCurve('curve 2', True) + curve = plot.getCurve() + self.assertEqual(curve.getName(), 'curve 1') # Last added curve + + # All curves hidden + plot.hideCurve('curve 1', True) + plot.hideCurve('curve 0', True) + curve = plot.getCurve() + self.assertIsNone(curve) + + def testGetCurveOldApi(self): + """old API PlotWidget.getCurve and Plot.getActiveCurve tests""" + + plot = PlotWidget(backend='none') + + # No curve + curve = plot.getCurve() + self.assertIsNone(curve) # No curve + + plot.setActiveCurveHandling(True) + x = numpy.arange(10.).astype(numpy.float32) + y = x * x + plot.addCurve(x=x, y=y, legend='curve 0', info=["whatever"]) + plot.addCurve(x=x, y=2*x, legend='curve 1', info="anything") + plot.setActiveCurve('curve 0') + + # Active curve (4 elements) + xOut, yOut, legend, info = plot.getActiveCurve()[:4] + self.assertEqual(legend, 'curve 0') + self.assertTrue(numpy.allclose(xOut, x), 'curve 0 wrong x data') + self.assertTrue(numpy.allclose(yOut, y), 'curve 0 wrong y data') + + # Active curve (5 elements) + xOut, yOut, legend, info, params = plot.getCurve("curve 1") + self.assertEqual(legend, 'curve 1') + self.assertEqual(info, 'anything') + self.assertTrue(numpy.allclose(xOut, x), 'curve 1 wrong x data') + self.assertTrue(numpy.allclose(yOut, 2 * x), 'curve 1 wrong y data') + + def testGetImage(self): + """PlotWidget.getImage and PlotWidget.getActiveImage tests""" + + plot = PlotWidget(backend='none') + + # No image + image = plot.getImage() + self.assertIsNone(image) + + plot.addImage(((0, 1), (2, 3)), legend='image 0') + plot.addImage(((0, 1), (2, 3)), legend='image 1') + + # Active image + active = plot.getActiveImage() + self.assertEqual(active.getName(), 'image 0') + image = plot.getImage() + self.assertEqual(image.getName(), 'image 0') + + # No active image + plot.addImage(((0, 1), (2, 3)), legend='image 2') + plot.setActiveImage(None) + active = plot.getActiveImage() + self.assertIsNone(active) + image = plot.getImage() + self.assertEqual(image.getName(), 'image 2') + + # Active image + plot.setActiveImage('image 1') + active = plot.getActiveImage() + self.assertEqual(active.getName(), 'image 1') + image = plot.getImage() + self.assertEqual(image.getName(), 'image 1') + + def testGetImageOldApi(self): + """PlotWidget.getImage and PlotWidget.getActiveImage old API tests""" + + plot = PlotWidget(backend='none') + + # No image + image = plot.getImage() + self.assertIsNone(image) + + image = numpy.arange(10).astype(numpy.float32) + image.shape = 5, 2 + + plot.addImage(image, legend='image 0', info=["Hi!"]) + + # Active image + data, legend, info, something, params = plot.getActiveImage() + self.assertEqual(legend, 'image 0') + self.assertEqual(info, ["Hi!"]) + self.assertTrue(numpy.allclose(data, image), "image 0 data not correct") + + def testGetAllImages(self): + """PlotWidget.getAllImages test""" + + plot = PlotWidget(backend='none') + + # No image + images = plot.getAllImages() + self.assertEqual(len(images), 0) + + # 2 images + data = numpy.arange(100).reshape(10, 10) + plot.addImage(data, legend='1') + plot.addImage(data, origin=(10, 10), legend='2') + images = plot.getAllImages(just_legend=True) + self.assertEqual(list(images), ['1', '2']) + images = plot.getAllImages(just_legend=False) + self.assertEqual(len(images), 2) + self.assertEqual(images[0].getName(), '1') + self.assertEqual(images[1].getName(), '2') + + +class TestPlotAddScatter(unittest.TestCase): + """Test of plot addScatter""" + + def testAddGetScatter(self): + + plot = PlotWidget(backend='none') + + # No curve + scatter = plot._getItem(kind="scatter") + self.assertIsNone(scatter) # No curve + + plot.addScatter(x=(0, 1), y=(0, 1), value=(0, 1), legend='scatter 0') + plot.addScatter(x=(0, 1), y=(0, 1), value=(0, 1), legend='scatter 1') + plot.addScatter(x=(0, 1), y=(0, 1), value=(0, 1), legend='scatter 2') + plot._setActiveItem('scatter', 'scatter 0') + + # Active scatter + active = plot._getActiveItem(kind='scatter') + self.assertEqual(active.getName(), 'scatter 0') + + # check default values + self.assertAlmostEqual(active.getSymbolSize(), active._DEFAULT_SYMBOL_SIZE) + self.assertEqual(active.getSymbol(), "o") + self.assertAlmostEqual(active.getAlpha(), 1.0) + + # modify parameters + active.setSymbolSize(20.5) + active.setSymbol("d") + active.setAlpha(0.777) + + s0 = plot.getScatter("scatter 0") + + self.assertAlmostEqual(s0.getSymbolSize(), 20.5) + self.assertEqual(s0.getSymbol(), "d") + self.assertAlmostEqual(s0.getAlpha(), 0.777) + + scatter1 = plot._getItem(kind='scatter', legend='scatter 1') + self.assertEqual(scatter1.getName(), 'scatter 1') + + def testGetAllScatters(self): + """PlotWidget.getAllImages test""" + + plot = PlotWidget(backend='none') + + items = plot.getItems() + self.assertEqual(len(items), 0) + + plot.addScatter(x=(0, 1), y=(0, 1), value=(0, 1), legend='scatter 0') + plot.addScatter(x=(0, 1), y=(0, 1), value=(0, 1), legend='scatter 1') + plot.addScatter(x=(0, 1), y=(0, 1), value=(0, 1), legend='scatter 2') + + items = plot.getItems() + self.assertEqual(len(items), 3) + self.assertEqual(items[0].getName(), 'scatter 0') + self.assertEqual(items[1].getName(), 'scatter 1') + self.assertEqual(items[2].getName(), 'scatter 2') + + +class TestPlotHistogram(unittest.TestCase): + """Basic tests for histogram.""" + + def testEdges(self): + x = numpy.array([0, 1, 2]) + edgesRight = numpy.array([0, 1, 2, 3]) + edgesLeft = numpy.array([-1, 0, 1, 2]) + edgesCenter = numpy.array([-0.5, 0.5, 1.5, 2.5]) + + # testing x values for right + edges = _computeEdges(x, 'right') + numpy.testing.assert_array_equal(edges, edgesRight) + + edges = _computeEdges(x, 'center') + numpy.testing.assert_array_equal(edges, edgesCenter) + + edges = _computeEdges(x, 'left') + numpy.testing.assert_array_equal(edges, edgesLeft) + + def testHistogramCurve(self): + y = numpy.array([3, 2, 5]) + edges = numpy.array([0, 1, 2, 3]) + + xHisto, yHisto = _getHistogramCurve(y, edges) + numpy.testing.assert_array_equal( + yHisto, numpy.array([3, 3, 2, 2, 5, 5])) + + y = numpy.array([-3, 2, 5, 0]) + edges = numpy.array([-2, -1, 0, 1, 2]) + xHisto, yHisto = _getHistogramCurve(y, edges) + numpy.testing.assert_array_equal( + yHisto, numpy.array([-3, -3, 2, 2, 5, 5, 0, 0])) diff --git a/src/silx/gui/plot/test/testPlotWindow.py b/src/silx/gui/plot/test/testPlotWindow.py new file mode 100644 index 0000000..9e1497f --- /dev/null +++ b/src/silx/gui/plot/test/testPlotWindow.py @@ -0,0 +1,174 @@ +# coding: utf-8 +# /*########################################################################## +# +# Copyright (c) 2016-2020 European Synchrotron Radiation Facility +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. +# +# ###########################################################################*/ +"""Basic tests for PlotWindow""" + +__authors__ = ["T. Vincent"] +__license__ = "MIT" +__date__ = "27/06/2017" + + +import unittest +import numpy +import pytest + +from silx.gui.utils.testutils import TestCaseQt, getQToolButtonFromAction + +from silx.gui import qt +from silx.gui.plot import PlotWindow +from silx.gui.colors import Colormap + + +class TestPlotWindow(TestCaseQt): + """Base class for tests of PlotWindow.""" + + def setUp(self): + super(TestPlotWindow, self).setUp() + self.plot = PlotWindow() + self.plot.show() + self.qWaitForWindowExposed(self.plot) + + def tearDown(self): + self.plot.setAttribute(qt.Qt.WA_DeleteOnClose) + self.plot.close() + del self.plot + super(TestPlotWindow, self).tearDown() + + def testActions(self): + """Test the actions QToolButtons""" + self.plot.setLimits(1, 100, 1, 100) + + checkList = [ # QAction, Plot state getter + (self.plot.xAxisAutoScaleAction, self.plot.getXAxis().isAutoScale), + (self.plot.yAxisAutoScaleAction, self.plot.getYAxis().isAutoScale), + (self.plot.xAxisLogarithmicAction, self.plot.getXAxis()._isLogarithmic), + (self.plot.yAxisLogarithmicAction, self.plot.getYAxis()._isLogarithmic), + (self.plot.gridAction, self.plot.getGraphGrid), + ] + + for action, getter in checkList: + self.mouseMove(self.plot) + initialState = getter() + toolButton = getQToolButtonFromAction(action) + self.assertIsNot(toolButton, None) + self.mouseClick(toolButton, qt.Qt.LeftButton) + self.assertNotEqual(getter(), initialState, + msg='"%s" state not changed' % action.text()) + + self.mouseClick(toolButton, qt.Qt.LeftButton) + self.assertEqual(getter(), initialState, + msg='"%s" state not changed' % action.text()) + + # Trigger a zoom reset + self.mouseMove(self.plot) + resetZoomAction = self.plot.resetZoomAction + toolButton = getQToolButtonFromAction(resetZoomAction) + self.assertIsNot(toolButton, None) + self.mouseClick(toolButton, qt.Qt.LeftButton) + + def testDockWidgets(self): + """Test add/remove dock widgets""" + dock1 = qt.QDockWidget('Test 1') + dock1.setWidget(qt.QLabel('Test 1')) + + self.plot.addTabbedDockWidget(dock1) + self.qapp.processEvents() + + self.plot.removeDockWidget(dock1) + self.qapp.processEvents() + + dock2 = qt.QDockWidget('Test 2') + dock2.setWidget(qt.QLabel('Test 2')) + + self.plot.addTabbedDockWidget(dock2) + self.qapp.processEvents() + + if qt.BINDING != 'PySide2': + # Weird bug with PySide2 later upon gc.collect() when getting the layout + self.assertNotEqual(self.plot.layout().indexOf(dock2), + -1, + "dock2 not properly displayed") + + def testToolAspectRatio(self): + self.plot.toolBar() + self.plot.keepDataAspectRatioButton.keepDataAspectRatio() + self.assertTrue(self.plot.isKeepDataAspectRatio()) + self.plot.keepDataAspectRatioButton.dontKeepDataAspectRatio() + self.assertFalse(self.plot.isKeepDataAspectRatio()) + + def testToolYAxisOrigin(self): + self.plot.toolBar() + self.plot.yAxisInvertedButton.setYAxisUpward() + self.assertFalse(self.plot.getYAxis().isInverted()) + self.plot.yAxisInvertedButton.setYAxisDownward() + self.assertTrue(self.plot.getYAxis().isInverted()) + + def testColormapAutoscaleCache(self): + # Test that the min/max cache is not computed twice + + old = Colormap._computeAutoscaleRange + self._count = 0 + def _computeAutoscaleRange(colormap, data): + self._count = self._count + 1 + return 10, 20 + Colormap._computeAutoscaleRange = _computeAutoscaleRange + try: + colormap = Colormap(name='red') + self.plot.setVisible(True) + + # Add an image + data = numpy.arange(8**2).reshape(8, 8) + self.plot.addImage(data, legend="foo", colormap=colormap) + self.plot.setActiveImage("foo") + + # Use the colorbar + self.plot.getColorBarWidget().setVisible(True) + self.qWait(50) + + # Remove and add again the same item + image = self.plot.getImage("foo") + self.plot.removeImage("foo") + self.plot.addItem(image) + self.qWait(50) + finally: + Colormap._computeAutoscaleRange = old + self.assertEqual(self._count, 1) + del self._count + + @pytest.mark.usefixtures("use_opengl") + def testSwitchBackend(self): + """Test switching an empty plot""" + self.plot.resetZoom() + xlimits = self.plot.getXAxis().getLimits() + ylimits = self.plot.getYAxis().getLimits() + isKeepAspectRatio = self.plot.isKeepDataAspectRatio() + + for backend in ('gl', 'mpl'): + with self.subTest(): + self.plot.setBackend(backend) + self.plot.replot() + self.assertEqual(self.plot.getXAxis().getLimits(), xlimits) + self.assertEqual(self.plot.getYAxis().getLimits(), ylimits) + self.assertEqual( + self.plot.isKeepDataAspectRatio(), isKeepAspectRatio) diff --git a/src/silx/gui/plot/test/testRoiStatsWidget.py b/src/silx/gui/plot/test/testRoiStatsWidget.py new file mode 100644 index 0000000..eb29267 --- /dev/null +++ b/src/silx/gui/plot/test/testRoiStatsWidget.py @@ -0,0 +1,277 @@ +# coding: utf-8 +# /*########################################################################## +# +# Copyright (c) 2016-2019 European Synchrotron Radiation Facility +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. +# +# ###########################################################################*/ +"""Tests for ROIStatsWidget""" + + +from silx.gui.utils.testutils import TestCaseQt +from silx.gui import qt +from silx.gui.plot import PlotWindow +from silx.gui.plot.stats.stats import Stats +from silx.gui.plot.ROIStatsWidget import ROIStatsWidget +from silx.gui.plot.CurvesROIWidget import ROI +from silx.gui.plot.items.roi import RectangleROI, PolygonROI +from silx.gui.plot.StatsWidget import UpdateMode +import unittest +import numpy + + + +class _TestRoiStatsBase(TestCaseQt): + """Base class for several unittest relative to ROIStatsWidget""" + def setUp(self): + TestCaseQt.setUp(self) + # define plot + self.plot = PlotWindow() + self.plot.addImage(numpy.arange(10000).reshape(100, 100), + legend='img1') + self.img_item = self.plot.getImage('img1') + self.plot.addCurve(x=numpy.linspace(0, 10, 56), y=numpy.arange(56), + legend='curve1') + self.curve_item = self.plot.getCurve('curve1') + self.plot.addHistogram(edges=numpy.linspace(0, 10, 56), + histogram=numpy.arange(56), legend='histo1') + self.histogram_item = self.plot.getHistogram(legend='histo1') + self.plot.addScatter(x=numpy.linspace(0, 10, 56), + y=numpy.linspace(0, 10, 56), + value=numpy.arange(56), + legend='scatter1') + self.scatter_item = self.plot.getScatter(legend='scatter1') + + # stats widget + self.statsWidget = ROIStatsWidget(plot=self.plot) + + # define stats + stats = [ + ('sum', numpy.sum), + ('mean', numpy.mean), + ] + self.statsWidget.setStats(stats=stats) + + # define rois + self.roi1D = ROI(name='range1', fromdata=0, todata=4, type_='energy') + self.rectangle_roi = RectangleROI() + self.rectangle_roi.setGeometry(origin=(0, 0), size=(20, 20)) + self.rectangle_roi.setName('Initial ROI') + self.polygon_roi = PolygonROI() + points = numpy.array([[0, 5], [5, 0], [10, 5], [5, 10]]) + self.polygon_roi.setPoints(points) + + def statsTable(self): + return self.statsWidget._statsROITable + + def tearDown(self): + Stats._getContext.cache_clear() + self.statsWidget.setAttribute(qt.Qt.WA_DeleteOnClose, True) + self.statsWidget.close() + self.statsWidget = None + self.plot.setAttribute(qt.Qt.WA_DeleteOnClose, True) + self.plot.close() + self.plot = None + TestCaseQt.tearDown(self) + + +class TestRoiStatsCouple(_TestRoiStatsBase): + """ + Test different possible couple (roi, plotItem). + Check that: + + * computation is correct if couple is valid + * raise an error if couple is invalid + """ + def testROICurve(self): + """ + Test that the couple (ROI, curveItem) can be used for stats + """ + item = self.statsWidget.addItem(roi=self.roi1D, + plotItem=self.curve_item) + assert item is not None + tableItems = self.statsTable()._itemToTableItems(item) + self.assertEqual(tableItems['sum'].text(), '253') + self.assertEqual(tableItems['mean'].text(), '11.0') + + def testRectangleImage(self): + """ + Test that the couple (RectangleROI, imageItem) can be used for stats + """ + item = self.statsWidget.addItem(roi=self.rectangle_roi, + plotItem=self.img_item) + assert item is not None + self.plot.addImage(numpy.ones(10000).reshape(100, 100), + legend='img1') + self.qapp.processEvents() + tableItems = self.statsTable()._itemToTableItems(item) + self.assertEqual(tableItems['sum'].text(), str(float(21*21))) + self.assertEqual(tableItems['mean'].text(), '1.0') + + def testPolygonImage(self): + """ + Test that the couple (PolygonROI, imageItem) can be used for stats + """ + item = self.statsWidget.addItem(roi=self.polygon_roi, + plotItem=self.img_item) + assert item is not None + tableItems = self.statsTable()._itemToTableItems(item) + self.assertEqual(tableItems['sum'].text(), '22750') + self.assertEqual(tableItems['mean'].text(), '455.0') + + def testROIImage(self): + """ + Test that the couple (ROI, imageItem) is raising an error + """ + with self.assertRaises(TypeError): + self.statsWidget.addItem(roi=self.roi1D, + plotItem=self.img_item) + + def testRectangleCurve(self): + """ + Test that the couple (rectangleROI, curveItem) is raising an error + """ + with self.assertRaises(TypeError): + self.statsWidget.addItem(roi=self.rectangle_roi, + plotItem=self.curve_item) + + def testROIHistogram(self): + """ + Test that the couple (PolygonROI, imageItem) can be used for stats + """ + item = self.statsWidget.addItem(roi=self.roi1D, + plotItem=self.histogram_item) + assert item is not None + tableItems = self.statsTable()._itemToTableItems(item) + self.assertEqual(tableItems['sum'].text(), '253') + self.assertEqual(tableItems['mean'].text(), '11.0') + + def testROIScatter(self): + """ + Test that the couple (PolygonROI, imageItem) can be used for stats + """ + item = self.statsWidget.addItem(roi=self.roi1D, + plotItem=self.scatter_item) + assert item is not None + tableItems = self.statsTable()._itemToTableItems(item) + self.assertEqual(tableItems['sum'].text(), '253') + self.assertEqual(tableItems['mean'].text(), '11.0') + + +class TestRoiStatsAddRemoveItem(_TestRoiStatsBase): + """Test adding and removing (roi, plotItem) items""" + def testAddRemoveItems(self): + item1 = self.statsWidget.addItem(roi=self.roi1D, + plotItem=self.scatter_item) + self.assertTrue(item1 is not None) + self.assertEqual(self.statsTable().rowCount(), 1) + item2 = self.statsWidget.addItem(roi=self.roi1D, + plotItem=self.histogram_item) + self.assertTrue(item2 is not None) + self.assertEqual(self.statsTable().rowCount(), 2) + # try to add twice the same item + item3 = self.statsWidget.addItem(roi=self.roi1D, + plotItem=self.histogram_item) + self.assertTrue(item3 is None) + self.assertEqual(self.statsTable().rowCount(), 2) + item4 = self.statsWidget.addItem(roi=self.roi1D, + plotItem=self.curve_item) + self.assertTrue(item4 is not None) + self.assertEqual(self.statsTable().rowCount(), 3) + + self.statsWidget.removeItem(plotItem=item4._plot_item, + roi=item4._roi) + self.assertEqual(self.statsTable().rowCount(), 2) + # try to remove twice the same item + self.statsWidget.removeItem(plotItem=item4._plot_item, + roi=item4._roi) + self.assertEqual(self.statsTable().rowCount(), 2) + self.statsWidget.removeItem(plotItem=item2._plot_item, + roi=item2._roi) + self.statsWidget.removeItem(plotItem=item1._plot_item, + roi=item1._roi) + self.assertEqual(self.statsTable().rowCount(), 0) + + +class TestRoiStatsRoiUpdate(_TestRoiStatsBase): + """Test that the stats will be updated if the roi is updated""" + def testChangeRoi(self): + item = self.statsWidget.addItem(roi=self.rectangle_roi, + plotItem=self.img_item) + assert item is not None + tableItems = self.statsTable()._itemToTableItems(item) + self.assertEqual(tableItems['sum'].text(), '445410') + self.assertEqual(tableItems['mean'].text(), '1010.0') + + # update roi + self.rectangle_roi.setOrigin(position=(10, 10)) + self.assertNotEqual(tableItems['sum'].text(), '445410') + self.assertNotEqual(tableItems['mean'].text(), '1010.0') + + def testUpdateModeScenario(self): + """Test update according to a simple scenario""" + self.statsWidget._setUpdateMode(UpdateMode.AUTO) + item = self.statsWidget.addItem(roi=self.rectangle_roi, + plotItem=self.img_item) + + assert item is not None + tableItems = self.statsTable()._itemToTableItems(item) + self.assertEqual(tableItems['sum'].text(), '445410') + self.assertEqual(tableItems['mean'].text(), '1010.0') + self.statsWidget._setUpdateMode(UpdateMode.MANUAL) + self.rectangle_roi.setOrigin(position=(10, 10)) + self.qapp.processEvents() + self.assertNotEqual(tableItems['sum'].text(), '445410') + self.assertNotEqual(tableItems['mean'].text(), '1010.0') + self.statsWidget._updateAllStats(is_request=True) + self.assertNotEqual(tableItems['sum'].text(), '445410') + self.assertNotEqual(tableItems['mean'].text(), '1010.0') + + +class TestRoiStatsPlotItemUpdate(_TestRoiStatsBase): + """Test that the stats will be updated if the plot item is updated""" + def testChangeImage(self): + self.statsWidget._setUpdateMode(UpdateMode.AUTO) + item = self.statsWidget.addItem(roi=self.rectangle_roi, + plotItem=self.img_item) + + assert item is not None + tableItems = self.statsTable()._itemToTableItems(item) + self.assertEqual(tableItems['mean'].text(), '1010.0') + + # update plot + self.plot.addImage(numpy.arange(100, 10100).reshape(100, 100), + legend='img1') + self.assertNotEqual(tableItems['mean'].text(), '1059.5') + + def testUpdateModeScenario(self): + """Test update according to a simple scenario""" + self.statsWidget._setUpdateMode(UpdateMode.MANUAL) + item = self.statsWidget.addItem(roi=self.rectangle_roi, + plotItem=self.img_item) + + assert item is not None + tableItems = self.statsTable()._itemToTableItems(item) + self.assertEqual(tableItems['mean'].text(), '1010.0') + self.plot.addImage(numpy.arange(100, 10100).reshape(100, 100), + legend='img1') + self.assertEqual(tableItems['mean'].text(), '1010.0') + self.statsWidget._updateAllStats(is_request=True) + self.assertEqual(tableItems['mean'].text(), '1110.0') diff --git a/src/silx/gui/plot/test/testSaveAction.py b/src/silx/gui/plot/test/testSaveAction.py new file mode 100644 index 0000000..9280fb6 --- /dev/null +++ b/src/silx/gui/plot/test/testSaveAction.py @@ -0,0 +1,132 @@ +# coding: utf-8 +# /*########################################################################## +# +# Copyright (c) 2017-2019 European Synchrotron Radiation Facility +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. +# +# ###########################################################################*/ +"""Test the plot's save action (consistency of output)""" + +__authors__ = ["P. Knobel"] +__license__ = "MIT" +__date__ = "28/11/2017" + + +import unittest +import tempfile +import os + +from silx.gui.plot.test.utils import PlotWidgetTestCase + +from silx.gui.plot import PlotWidget +from silx.gui.plot.actions.io import SaveAction + + +class TestSaveActionSaveCurvesAsSpec(unittest.TestCase): + + def setUp(self): + self.plot = PlotWidget(backend='none') + self.saveAction = SaveAction(plot=self.plot) + + self.tempdir = tempfile.mkdtemp() + self.out_fname = os.path.join(self.tempdir, "out.dat") + + def tearDown(self): + os.unlink(self.out_fname) + os.rmdir(self.tempdir) + + def testSaveMultipleCurvesAsSpec(self): + """Test that labels are properly used.""" + self.plot.setGraphXLabel("graph x label") + self.plot.setGraphYLabel("graph y label") + + self.plot.addCurve([0, 1], [1, 2], "curve with labels", + xlabel="curve0 X", ylabel="curve0 Y") + self.plot.addCurve([-1, 3], [-6, 2], "curve with X label", + xlabel="curve1 X") + self.plot.addCurve([-2, 0], [8, 12], "curve with Y label", + ylabel="curve2 Y") + self.plot.addCurve([3, 1], [7, 6], "curve with no labels") + + self.saveAction._saveCurves(self.plot, + self.out_fname, + SaveAction.DEFAULT_ALL_CURVES_FILTERS[0]) # "All curves as SpecFile (*.dat)" + + with open(self.out_fname, "rb") as f: + file_content = f.read() + if hasattr(file_content, "decode"): + file_content = file_content.decode() + + # case with all curve labels specified + self.assertIn("#S 1 curve0 Y", file_content) + self.assertIn("#L curve0 X curve0 Y", file_content) + + # graph X&Y labels are used when no curve label is specified + self.assertIn("#S 2 graph y label", file_content) + self.assertIn("#L curve1 X graph y label", file_content) + + self.assertIn("#S 3 curve2 Y", file_content) + self.assertIn("#L graph x label curve2 Y", file_content) + + self.assertIn("#S 4 graph y label", file_content) + self.assertIn("#L graph x label graph y label", file_content) + + +class TestSaveActionExtension(PlotWidgetTestCase): + """Test SaveAction file filter API""" + + def _dummySaveFunction(self, plot, filename, nameFilter): + pass + + def testFileFilterAPI(self): + """Test addition/update of a file filter""" + saveAction = SaveAction(plot=self.plot, parent=self.plot) + + # Add a new file filter + nameFilter = 'Dummy file (*.dummy)' + saveAction.setFileFilter('all', nameFilter, self._dummySaveFunction) + self.assertTrue(nameFilter in saveAction.getFileFilters('all')) + self.assertEqual(saveAction.getFileFilters('all')[nameFilter], + self._dummySaveFunction) + + # Add a new file filter at a particular position + nameFilter = 'Dummy file2 (*.dummy)' + saveAction.setFileFilter('all', nameFilter, + self._dummySaveFunction, index=3) + self.assertTrue(nameFilter in saveAction.getFileFilters('all')) + filters = saveAction.getFileFilters('all') + self.assertEqual(filters[nameFilter], self._dummySaveFunction) + self.assertEqual(list(filters.keys()).index(nameFilter),3) + + # Update an existing file filter + nameFilter = SaveAction.IMAGE_FILTER_EDF + saveAction.setFileFilter('image', nameFilter, self._dummySaveFunction) + self.assertEqual(saveAction.getFileFilters('image')[nameFilter], + self._dummySaveFunction) + + # Change the position of an existing file filter + nameFilter = 'Dummy file2 (*.dummy)' + oldIndex = list(saveAction.getFileFilters('all')).index(nameFilter) + newIndex = oldIndex - 1 + saveAction.setFileFilter('all', nameFilter, + self._dummySaveFunction, index=newIndex) + filters = saveAction.getFileFilters('all') + self.assertEqual(filters[nameFilter], self._dummySaveFunction) + self.assertEqual(list(filters.keys()).index(nameFilter), newIndex) diff --git a/src/silx/gui/plot/test/testScatterMaskToolsWidget.py b/src/silx/gui/plot/test/testScatterMaskToolsWidget.py new file mode 100644 index 0000000..447ee58 --- /dev/null +++ b/src/silx/gui/plot/test/testScatterMaskToolsWidget.py @@ -0,0 +1,306 @@ +# coding: utf-8 +# /*########################################################################## +# +# Copyright (c) 2016-2017 European Synchrotron Radiation Facility +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. +# +# ###########################################################################*/ +"""Basic tests for MaskToolsWidget""" + +__authors__ = ["T. Vincent", "P. Knobel"] +__license__ = "MIT" +__date__ = "17/01/2018" + + +import logging +import os.path +import unittest + +import numpy + +from silx.gui import qt +from silx.test.utils import temp_dir +from silx.utils.testutils import ParametricTestCase +from silx.gui.utils.testutils import getQToolButtonFromAction +from silx.gui.plot import PlotWindow, ScatterMaskToolsWidget +from .utils import PlotWidgetTestCase + +import fabio + + +_logger = logging.getLogger(__name__) + + +class TestScatterMaskToolsWidget(PlotWidgetTestCase, ParametricTestCase): + """Basic test for MaskToolsWidget""" + + def _createPlot(self): + return PlotWindow() + + def setUp(self): + super(TestScatterMaskToolsWidget, self).setUp() + self.widget = ScatterMaskToolsWidget.ScatterMaskToolsDockWidget( + plot=self.plot, name='TEST') + self.plot.addDockWidget(qt.Qt.BottomDockWidgetArea, self.widget) + + self.maskWidget = self.widget.widget() + + def tearDown(self): + del self.maskWidget + del self.widget + super(TestScatterMaskToolsWidget, self).tearDown() + + def testEmptyPlot(self): + """Empty plot, display MaskToolsDockWidget, toggle multiple masks""" + self.maskWidget.setMultipleMasks('single') + self.qapp.processEvents() + + self.maskWidget.setMultipleMasks('exclusive') + self.qapp.processEvents() + + def _drag(self): + """Drag from plot center to offset position""" + plot = self.plot.getWidgetHandle() + xCenter, yCenter = plot.width() // 2, plot.height() // 2 + offset = min(plot.width(), plot.height()) // 10 + + pos0 = xCenter, yCenter + pos1 = xCenter + offset, yCenter + offset + + self.mouseMove(plot, pos=(0, 0)) + self.mouseMove(plot, pos=pos0) + self.qapp.processEvents() + self.mousePress(plot, qt.Qt.LeftButton, pos=pos0) + self.qapp.processEvents() + + self.mouseMove(plot, pos=(pos0[0] + offset // 2, pos0[1] + offset // 2)) + self.mouseMove(plot, pos=pos1) + self.qapp.processEvents() + self.mouseRelease(plot, qt.Qt.LeftButton, pos=pos1) + self.qapp.processEvents() + self.mouseMove(plot, pos=(0, 0)) + + def _drawPolygon(self): + """Draw a star polygon in the plot""" + plot = self.plot.getWidgetHandle() + x, y = plot.width() // 2, plot.height() // 2 + offset = min(plot.width(), plot.height()) // 10 + + star = [(x, y + offset), + (x - offset, y - offset), + (x + offset, y), + (x - offset, y), + (x + offset, y - offset), + (x, y + offset)] # Close polygon + + self.mouseMove(plot, pos=[0, 0]) + for pos in star: + self.mouseMove(plot, pos=pos) + self.qapp.processEvents() + self.mousePress(plot, qt.Qt.LeftButton, pos=pos) + self.qapp.processEvents() + self.mouseRelease(plot, qt.Qt.LeftButton, pos=pos) + self.qapp.processEvents() + + def _drawPencil(self): + """Draw a star polygon in the plot""" + plot = self.plot.getWidgetHandle() + x, y = plot.width() // 2, plot.height() // 2 + offset = min(plot.width(), plot.height()) // 10 + + star = [(x, y + offset), + (x - offset, y - offset), + (x + offset, y), + (x - offset, y), + (x + offset, y - offset)] + + self.mouseMove(plot, pos=[0, 0]) + self.mouseMove(plot, pos=star[0]) + self.mousePress(plot, qt.Qt.LeftButton, pos=star[0]) + for pos in star[1:]: + self.mouseMove(plot, pos=pos) + self.mouseRelease( + plot, qt.Qt.LeftButton, pos=star[-1]) + + def testWithAScatter(self): + """Plot with a Scatter: test MaskToolsWidget interactions""" + + # Add and remove a scatter (this should enable/disable GUI + change mask) + self.plot.addScatter( + x=numpy.arange(256), + y=numpy.arange(256), + value=numpy.random.random(256), + legend='test') + self.plot._setActiveItem(kind="scatter", legend="test") + self.qapp.processEvents() + + self.plot.remove('test', kind='scatter') + self.qapp.processEvents() + + self.plot.addScatter( + x=numpy.arange(1000), + y=1000 * (numpy.arange(1000) % 20), + value=numpy.random.random(1000), + legend='test') + self.plot._setActiveItem(kind="scatter", legend="test") + self.plot.resetZoom() + self.qapp.processEvents() + + # Test draw rectangle # + toolButton = getQToolButtonFromAction(self.maskWidget.rectAction) + self.assertIsNot(toolButton, None) + self.mouseClick(toolButton, qt.Qt.LeftButton) + + # mask + self.maskWidget.maskStateGroup.button(1).click() + self.qapp.processEvents() + self._drag() + + self.assertFalse( + numpy.all(numpy.equal(self.maskWidget.getSelectionMask(), 0))) + + # unmask same region + self.maskWidget.maskStateGroup.button(0).click() + self.qapp.processEvents() + self._drag() + self.assertTrue( + numpy.all(numpy.equal(self.maskWidget.getSelectionMask(), 0))) + + # Test draw polygon # + toolButton = getQToolButtonFromAction(self.maskWidget.polygonAction) + self.assertIsNot(toolButton, None) + self.mouseClick(toolButton, qt.Qt.LeftButton) + + # mask + self.maskWidget.maskStateGroup.button(1).click() + self.qapp.processEvents() + self._drawPolygon() + self.assertFalse( + numpy.all(numpy.equal(self.maskWidget.getSelectionMask(), 0))) + + # unmask same region + self.maskWidget.maskStateGroup.button(0).click() + self.qapp.processEvents() + self._drawPolygon() + self.assertTrue( + numpy.all(numpy.equal(self.maskWidget.getSelectionMask(), 0))) + + # Test draw pencil # + toolButton = getQToolButtonFromAction(self.maskWidget.pencilAction) + self.assertIsNot(toolButton, None) + self.mouseClick(toolButton, qt.Qt.LeftButton) + + self.maskWidget.pencilSpinBox.setValue(30) + self.qapp.processEvents() + + # mask + self.maskWidget.maskStateGroup.button(1).click() + self.qapp.processEvents() + self._drawPencil() + self.assertFalse( + numpy.all(numpy.equal(self.maskWidget.getSelectionMask(), 0))) + + # unmask same region + self.maskWidget.maskStateGroup.button(0).click() + self.qapp.processEvents() + self._drawPencil() + self.assertTrue( + numpy.all(numpy.equal(self.maskWidget.getSelectionMask(), 0))) + + # Test no draw tool # + toolButton = getQToolButtonFromAction(self.maskWidget.browseAction) + self.assertIsNot(toolButton, None) + self.mouseClick(toolButton, qt.Qt.LeftButton) + + self.plot.clear() + + def __loadSave(self, file_format): + self.plot.addScatter( + x=numpy.arange(256), + y=25 * (numpy.arange(256) % 10), + value=numpy.random.random(256), + legend='test') + self.plot._setActiveItem(kind="scatter", legend="test") + self.plot.resetZoom() + self.qapp.processEvents() + + # Draw a polygon mask + toolButton = getQToolButtonFromAction(self.maskWidget.polygonAction) + self.assertIsNot(toolButton, None) + self.mouseClick(toolButton, qt.Qt.LeftButton) + self._drawPolygon() + + ref_mask = self.maskWidget.getSelectionMask() + self.assertFalse(numpy.all(numpy.equal(ref_mask, 0))) + + with temp_dir() as tmp: + mask_filename = os.path.join(tmp, 'mask.' + file_format) + self.maskWidget.save(mask_filename, file_format) + + self.maskWidget.resetSelectionMask() + self.assertTrue( + numpy.all(numpy.equal(self.maskWidget.getSelectionMask(), 0))) + + self.maskWidget.load(mask_filename) + self.assertTrue(numpy.all(numpy.equal( + self.maskWidget.getSelectionMask(), ref_mask))) + + def testLoadSaveNpy(self): + self.__loadSave("npy") + + def testLoadSaveCsv(self): + self.__loadSave("csv") + + def testSigMaskChangedEmitted(self): + self.qapp.processEvents() + self.plot.addScatter( + x=numpy.arange(1000), + y=1000 * (numpy.arange(1000) % 20), + value=numpy.ones((1000,)), + legend='test') + self.plot._setActiveItem(kind="scatter", legend="test") + self.plot.resetZoom() + self.qapp.processEvents() + + self.plot.remove('test', kind='scatter') + self.qapp.processEvents() + + self.plot.addScatter( + x=numpy.arange(1000), + y=1000 * (numpy.arange(1000) % 20), + value=numpy.random.random(1000), + legend='test') + + l = [] + + def slot(): + l.append(1) + + self.maskWidget.sigMaskChanged.connect(slot) + + # rectangle mask + toolButton = getQToolButtonFromAction(self.maskWidget.rectAction) + self.assertIsNot(toolButton, None) + self.mouseClick(toolButton, qt.Qt.LeftButton) + self.maskWidget.maskStateGroup.button(1).click() + self.qapp.processEvents() + self._drag() + + self.assertGreater(len(l), 0) diff --git a/src/silx/gui/plot/test/testScatterView.py b/src/silx/gui/plot/test/testScatterView.py new file mode 100644 index 0000000..d11d4d8 --- /dev/null +++ b/src/silx/gui/plot/test/testScatterView.py @@ -0,0 +1,123 @@ +# coding: utf-8 +# /*########################################################################## +# +# Copyright (c) 2018 European Synchrotron Radiation Facility +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. +# +# ###########################################################################*/ +"""Basic tests for ScatterView""" + +__authors__ = ["T. Vincent"] +__license__ = "MIT" +__date__ = "06/03/2018" + + +import unittest + +import numpy + +from silx.gui.plot.items import Axis, Scatter +from silx.gui.plot import ScatterView +from silx.gui.plot.test.utils import PlotWidgetTestCase + + +class TestScatterView(PlotWidgetTestCase): + """Test of ScatterView widget""" + + def _createPlot(self): + return ScatterView() + + def test(self): + """Simple tests""" + x = numpy.arange(100) + y = numpy.arange(100) + value = numpy.arange(100) + self.plot.setData(x, y, value) + self.qapp.processEvents() + + data = self.plot.getData() + self.assertEqual(len(data), 5) + self.assertTrue(numpy.all(numpy.equal(x, data[0]))) + self.assertTrue(numpy.all(numpy.equal(y, data[1]))) + self.assertTrue(numpy.all(numpy.equal(value, data[2]))) + self.assertIsNone(data[3]) # xerror + self.assertIsNone(data[4]) # yerror + + # Test access to scatter item + self.assertIsInstance(self.plot.getScatterItem(), Scatter) + + # Test toolbar actions + + action = self.plot.getScatterToolBar().getXAxisLogarithmicAction() + action.trigger() + self.qapp.processEvents() + + maskAction = self.plot.getScatterToolBar().actions()[-1] + maskAction.trigger() + self.qapp.processEvents() + + # Test proxy API + + self.plot.resetZoom() + self.qapp.processEvents() + + scale = self.plot.getXAxis().getScale() + self.assertEqual(scale, Axis.LOGARITHMIC) + + scale = self.plot.getYAxis().getScale() + self.assertEqual(scale, Axis.LINEAR) + + title = 'Test ScatterView' + self.plot.setGraphTitle(title) + self.assertEqual(self.plot.getGraphTitle(), title) + + self.qapp.processEvents() + + # Reset scatter data + + self.plot.setData(None, None, None) + self.qapp.processEvents() + + data = self.plot.getData() + self.assertEqual(len(data), 5) + self.assertEqual(len(data[0]), 0) # x + self.assertEqual(len(data[1]), 0) # y + self.assertEqual(len(data[2]), 0) # value + self.assertIsNone(data[3]) # xerror + self.assertIsNone(data[4]) # yerror + + def testAlpha(self): + """Test alpha transparency in setData""" + _pts = 100 + _levels = 100 + _fwhm = 50 + x = numpy.random.rand(_pts)*_levels + y = numpy.random.rand(_pts)*_levels + value = numpy.random.rand(_pts)*_levels + x0 = x[int(_pts/2)] + y0 = x[int(_pts/2)] + #2D Gaussian kernel + alpha = numpy.exp(-4*numpy.log(2) * ((x-x0)**2 + (y-y0)**2) / _fwhm**2) + + self.plot.setData(x, y, value, alpha=alpha) + self.qapp.processEvents() + + alphaData = self.plot.getScatterItem().getAlphaData() + self.assertTrue(numpy.all(numpy.equal(alpha, alphaData))) diff --git a/src/silx/gui/plot/test/testStackView.py b/src/silx/gui/plot/test/testStackView.py new file mode 100644 index 0000000..0d18113 --- /dev/null +++ b/src/silx/gui/plot/test/testStackView.py @@ -0,0 +1,248 @@ +# coding: utf-8 +# /*########################################################################## +# +# Copyright (c) 2016-2020 European Synchrotron Radiation Facility +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. +# +# ###########################################################################*/ +"""Basic tests for StackView""" + +__authors__ = ["P. Knobel"] +__license__ = "MIT" +__date__ = "20/03/2017" + + +import unittest +import numpy + +from silx.gui.utils.testutils import TestCaseQt, SignalListener + +from silx.gui import qt +from silx.gui.plot import StackView +from silx.gui.plot.StackView import StackViewMainWindow + +from silx.utils.array_like import ListOfImages + + +class TestStackView(TestCaseQt): + """Base class for tests of StackView.""" + + def setUp(self): + super(TestStackView, self).setUp() + self.stackview = StackView() + self.stackview.show() + self.qWaitForWindowExposed(self.stackview) + self.mystack = numpy.fromfunction( + lambda i, j, k: numpy.sin(i/15.) + numpy.cos(j/4.) + 2 * numpy.sin(k/6.), + (10, 20, 30) + ) + + def tearDown(self): + self.stackview.setAttribute(qt.Qt.WA_DeleteOnClose) + self.stackview.close() + del self.stackview + super(TestStackView, self).tearDown() + + def testScaleColormapRangeToStack(self): + """Test scaleColormapRangeToStack""" + self.stackview.setStack(self.mystack) + self.stackview.setColormap("viridis") + colormap = self.stackview.getColormap() + + # Colormap autoscale to image + self.assertEqual(colormap.getVRange(), (None, None)) + self.stackview.scaleColormapRangeToStack() + + # Colormap range set according to stack range + self.assertEqual(colormap.getVRange(), (self.mystack.min(), self.mystack.max())) + + def testSetStack(self): + self.stackview.setStack(self.mystack) + self.stackview.setColormap("viridis", autoscale=True) + my_trans_stack, params = self.stackview.getStack() + self.assertEqual(my_trans_stack.shape, self.mystack.shape) + self.assertTrue(numpy.array_equal(self.mystack, + my_trans_stack)) + self.assertEqual(params["colormap"]["name"], + "viridis") + + def testSetStackPerspective(self): + self.stackview.setStack(self.mystack, perspective=1) + # my_orig_stack, params = self.stackview.getStack() + my_trans_stack, params = self.stackview.getCurrentView() + + # get stack returns the transposed data, depending on the perspective + self.assertEqual(my_trans_stack.shape, + (self.mystack.shape[1], self.mystack.shape[0], self.mystack.shape[2])) + self.assertTrue(numpy.array_equal(numpy.transpose(self.mystack, axes=(1, 0, 2)), + my_trans_stack)) + + def testSetStackListOfImages(self): + loi = [self.mystack[i] for i in range(self.mystack.shape[0])] + + self.stackview.setStack(loi) + my_orig_stack, params = self.stackview.getStack(returnNumpyArray=True) + my_trans_stack, params = self.stackview.getStack(returnNumpyArray=True) + self.assertEqual(my_trans_stack.shape, self.mystack.shape) + self.assertTrue(numpy.array_equal(self.mystack, + my_trans_stack)) + self.assertTrue(numpy.array_equal(self.mystack, + my_orig_stack)) + self.assertIsInstance(my_trans_stack, numpy.ndarray) + + self.stackview.setStack(loi, perspective=2) + my_orig_stack, params = self.stackview.getStack(copy=False) + my_trans_stack, params = self.stackview.getCurrentView(copy=False) + # getStack(copy=False) must return the object set in setStack + self.assertIs(my_orig_stack, loi) + # getCurrentView(copy=False) returns a ListOfImages whose .images + # attr is the original data + self.assertEqual(my_trans_stack.shape, + (self.mystack.shape[2], self.mystack.shape[0], self.mystack.shape[1])) + self.assertTrue(numpy.array_equal(numpy.array(my_trans_stack), + numpy.transpose(self.mystack, axes=(2, 0, 1)))) + self.assertIsInstance(my_trans_stack, + ListOfImages) # returnNumpyArray=False by default in getStack + self.assertIs(my_trans_stack.images, loi) + + def testPerspective(self): + self.stackview.setStack(numpy.arange(24).reshape((2, 3, 4))) + self.assertEqual(self.stackview._perspective, 0, + "Default perspective is not 0 (dim1-dim2).") + + self.stackview._StackView__planeSelection.setPerspective(1) + self.assertEqual(self.stackview._perspective, 1, + "Plane selection combobox not updating perspective") + + self.stackview.setStack(numpy.arange(6).reshape((1, 2, 3))) + self.assertEqual(self.stackview._perspective, 1, + "Perspective not preserved when calling setStack " + "without specifying the perspective parameter.") + + self.stackview.setStack(numpy.arange(24).reshape((2, 3, 4)), perspective=2) + self.assertEqual(self.stackview._perspective, 2, + "Perspective not set in setStack(..., perspective=2).") + + def testDefaultTitle(self): + """Test that the plot title contains the proper Z information""" + self.stackview.setStack(numpy.arange(24).reshape((4, 3, 2)), + calibrations=[(0, 1), (-10, 10), (3.14, 3.14)]) + self.assertEqual(self.stackview._plot.getGraphTitle(), + "Image z=0") + self.stackview.setFrameNumber(2) + self.assertEqual(self.stackview._plot.getGraphTitle(), + "Image z=2") + + self.stackview._StackView__planeSelection.setPerspective(1) + self.stackview.setFrameNumber(0) + self.assertEqual(self.stackview._plot.getGraphTitle(), + "Image z=-10") + self.stackview.setFrameNumber(2) + self.assertEqual(self.stackview._plot.getGraphTitle(), + "Image z=10") + + self.stackview._StackView__planeSelection.setPerspective(2) + self.stackview.setFrameNumber(0) + self.assertEqual(self.stackview._plot.getGraphTitle(), + "Image z=3.14") + self.stackview.setFrameNumber(1) + self.assertEqual(self.stackview._plot.getGraphTitle(), + "Image z=6.28") + + def testCustomTitle(self): + """Test setting the plot title with a user defined callback""" + self.stackview.setStack(numpy.arange(24).reshape((4, 3, 2)), + calibrations=[(0, 1), (-10, 10), (3.14, 3.14)]) + + def title_callback(frame_idx): + return "Cubed index title %d" % (frame_idx**3) + + self.stackview.setTitleCallback(title_callback) + self.assertEqual(self.stackview._plot.getGraphTitle(), + "Cubed index title 0") + self.stackview.setFrameNumber(2) + self.assertEqual(self.stackview._plot.getGraphTitle(), + "Cubed index title 8") + + # perspective should not matter, only frame index + self.stackview._StackView__planeSelection.setPerspective(1) + self.stackview.setFrameNumber(0) + self.assertEqual(self.stackview._plot.getGraphTitle(), + "Cubed index title 0") + self.stackview.setFrameNumber(2) + self.assertEqual(self.stackview._plot.getGraphTitle(), + "Cubed index title 8") + + with self.assertRaises(TypeError): + # setTitleCallback should not accept non-callable objects like strings + self.stackview.setTitleCallback( + "Là , vous faites sirop de vingt-et-un et vous dites : " + "beau sirop, mi-sirop, siroté, gagne-sirop, sirop-grelot," + " passe-montagne, sirop au bon goût.") + + def testStackFrameNumber(self): + self.stackview.setStack(self.mystack) + self.assertEqual(self.stackview.getFrameNumber(), 0) + + listener = SignalListener() + self.stackview.sigFrameChanged.connect(listener) + + self.stackview.setFrameNumber(1) + self.assertEqual(self.stackview.getFrameNumber(), 1) + self.assertEqual(listener.arguments(), [(1,)]) + + +class TestStackViewMainWindow(TestCaseQt): + """Base class for tests of StackView.""" + + def setUp(self): + super(TestStackViewMainWindow, self).setUp() + self.stackview = StackViewMainWindow() + self.stackview.show() + self.qWaitForWindowExposed(self.stackview) + self.mystack = numpy.fromfunction( + lambda i, j, k: numpy.sin(i/15.) + numpy.cos(j/4.) + 2 * numpy.sin(k/6.), + (10, 20, 30) + ) + + def tearDown(self): + self.stackview.setAttribute(qt.Qt.WA_DeleteOnClose) + self.stackview.close() + del self.stackview + super(TestStackViewMainWindow, self).tearDown() + + def testSetStack(self): + self.stackview.setStack(self.mystack) + self.stackview.setColormap("viridis", autoscale=True) + my_trans_stack, params = self.stackview.getStack() + self.assertEqual(my_trans_stack.shape, self.mystack.shape) + self.assertTrue(numpy.array_equal(self.mystack, + my_trans_stack)) + self.assertEqual(params["colormap"]["name"], + "viridis") + + def testSetStackPerspective(self): + self.stackview.setStack(self.mystack, perspective=1) + my_trans_stack, params = self.stackview.getCurrentView() + # get stack returns the transposed data, depending on the perspective + self.assertEqual(my_trans_stack.shape, + (self.mystack.shape[1], self.mystack.shape[0], self.mystack.shape[2])) + self.assertTrue(numpy.array_equal(numpy.transpose(self.mystack, axes=(1, 0, 2)), + my_trans_stack)) diff --git a/src/silx/gui/plot/test/testStats.py b/src/silx/gui/plot/test/testStats.py new file mode 100644 index 0000000..0a792a4 --- /dev/null +++ b/src/silx/gui/plot/test/testStats.py @@ -0,0 +1,1047 @@ +# coding: utf-8 +# /*########################################################################## +# +# Copyright (c) 2016-2021 European Synchrotron Radiation Facility +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. +# +# ###########################################################################*/ +"""Basic tests for CurvesROIWidget""" + +__authors__ = ["H. Payno"] +__license__ = "MIT" +__date__ = "07/03/2018" + + +from silx.gui import qt +from silx.gui.plot.stats import stats +from silx.gui.plot import StatsWidget +from silx.gui.plot.stats import statshandler +from silx.gui.utils.testutils import TestCaseQt, SignalListener +from silx.gui.plot import Plot1D, Plot2D +from silx.gui.plot3d.SceneWidget import SceneWidget +from silx.gui.plot.items.roi import RectangleROI, PolygonROI +from silx.gui.plot.tools.roi import RegionOfInterestManager +from silx.gui.plot.stats.stats import Stats +from silx.gui.plot.CurvesROIWidget import ROI +from silx.utils.testutils import ParametricTestCase +import unittest +import logging +import numpy + +_logger = logging.getLogger(__name__) + + +class TestStatsBase(object): + """Base class for stats TestCase""" + def setUp(self): + self.createCurveContext() + self.createImageContext() + self.createScatterContext() + + def tearDown(self): + self.plot1d.setAttribute(qt.Qt.WA_DeleteOnClose) + self.plot1d.close() + del self.plot1d + self.plot2d.setAttribute(qt.Qt.WA_DeleteOnClose) + self.plot2d.close() + del self.plot2d + self.scatterPlot.setAttribute(qt.Qt.WA_DeleteOnClose) + self.scatterPlot.close() + del self.scatterPlot + + def createCurveContext(self): + self.plot1d = Plot1D() + x = range(20) + y = range(20) + self.plot1d.addCurve(x, y, legend='curve0') + + self.curveContext = stats._CurveContext( + item=self.plot1d.getCurve('curve0'), + plot=self.plot1d, + onlimits=False, + roi=None) + + def createScatterContext(self): + self.scatterPlot = Plot2D() + lgd = 'scatter plot' + self.xScatterData = numpy.array([0, 2, 3, 20, 50, 60, 36]) + self.yScatterData = numpy.array([2, 3, 4, 26, 69, 6, 18]) + self.valuesScatterData = numpy.array([5, 6, 7, 10, 90, 20, 5]) + self.scatterPlot.addScatter(self.xScatterData, self.yScatterData, + self.valuesScatterData, legend=lgd) + self.scatterContext = stats._ScatterContext( + item=self.scatterPlot.getScatter(lgd), + plot=self.scatterPlot, + onlimits=False, + roi=None + ) + + def createImageContext(self): + self.plot2d = Plot2D() + self._imgLgd = 'test image' + self.imageData = numpy.arange(32*128).reshape(32, 128) + self.plot2d.addImage(data=self.imageData, + legend=self._imgLgd, replace=False) + self.imageContext = stats._ImageContext( + item=self.plot2d.getImage(self._imgLgd), + plot=self.plot2d, + onlimits=False, + roi=None + ) + + def getBasicStats(self): + return { + 'min': stats.StatMin(), + 'minCoords': stats.StatCoordMin(), + 'max': stats.StatMax(), + 'maxCoords': stats.StatCoordMax(), + 'std': stats.Stat(name='std', fct=numpy.std), + 'mean': stats.Stat(name='mean', fct=numpy.mean), + 'com': stats.StatCOM() + } + + +class TestStats(TestStatsBase, TestCaseQt): + """ + Test :class:`BaseClass` class and inheriting classes + """ + def setUp(self): + TestCaseQt.setUp(self) + TestStatsBase.setUp(self) + + def tearDown(self): + TestStatsBase.tearDown(self) + TestCaseQt.tearDown(self) + + def testBasicStatsCurve(self): + """Test result for simple stats on a curve""" + _stats = self.getBasicStats() + xData = yData = numpy.array(range(20)) + self.assertEqual(_stats['min'].calculate(self.curveContext), 0) + self.assertEqual(_stats['max'].calculate(self.curveContext), 19) + self.assertEqual(_stats['minCoords'].calculate(self.curveContext), (0,)) + self.assertEqual(_stats['maxCoords'].calculate(self.curveContext), (19,)) + self.assertEqual(_stats['std'].calculate(self.curveContext), numpy.std(yData)) + self.assertEqual(_stats['mean'].calculate(self.curveContext), numpy.mean(yData)) + com = numpy.sum(xData * yData) / numpy.sum(yData) + self.assertEqual(_stats['com'].calculate(self.curveContext), com) + + def testBasicStatsImage(self): + """Test result for simple stats on an image""" + _stats = self.getBasicStats() + self.assertEqual(_stats['min'].calculate(self.imageContext), 0) + self.assertEqual(_stats['max'].calculate(self.imageContext), 128 * 32 - 1) + self.assertEqual(_stats['minCoords'].calculate(self.imageContext), (0, 0)) + self.assertEqual(_stats['maxCoords'].calculate(self.imageContext), (127, 31)) + self.assertEqual(_stats['std'].calculate(self.imageContext), numpy.std(self.imageData)) + self.assertEqual(_stats['mean'].calculate(self.imageContext), numpy.mean(self.imageData)) + + yData = numpy.sum(self.imageData.astype(numpy.float64), axis=1) + xData = numpy.sum(self.imageData.astype(numpy.float64), axis=0) + dataXRange = range(self.imageData.shape[1]) + dataYRange = range(self.imageData.shape[0]) + + ycom = numpy.sum(yData*dataYRange) / numpy.sum(yData) + xcom = numpy.sum(xData*dataXRange) / numpy.sum(xData) + + self.assertEqual(_stats['com'].calculate(self.imageContext), (xcom, ycom)) + + def testStatsImageAdv(self): + """Test that scale and origin are taking into account for images""" + + image2Data = numpy.arange(32 * 128).reshape(32, 128) + self.plot2d.addImage(data=image2Data, legend=self._imgLgd, + replace=True, origin=(100, 10), scale=(2, 0.5)) + image2Context = stats._ImageContext( + item=self.plot2d.getImage(self._imgLgd), + plot=self.plot2d, + onlimits=False, + roi=None, + ) + _stats = self.getBasicStats() + self.assertEqual(_stats['min'].calculate(image2Context), 0) + self.assertEqual( + _stats['max'].calculate(image2Context), 128 * 32 - 1) + self.assertEqual( + _stats['minCoords'].calculate(image2Context), (100, 10)) + self.assertEqual( + _stats['maxCoords'].calculate(image2Context), (127*2. + 100, + 31 * 0.5 + 10)) + self.assertEqual(_stats['std'].calculate(image2Context), + numpy.std(self.imageData)) + self.assertEqual(_stats['mean'].calculate(image2Context), + numpy.mean(self.imageData)) + + yData = numpy.sum(self.imageData, axis=1) + xData = numpy.sum(self.imageData, axis=0) + dataXRange = numpy.arange(self.imageData.shape[1], dtype=numpy.float64) + dataYRange = numpy.arange(self.imageData.shape[0], dtype=numpy.float64) + + ycom = numpy.sum(yData * dataYRange) / numpy.sum(yData) + ycom = (ycom * 0.5) + 10 + xcom = numpy.sum(xData * dataXRange) / numpy.sum(xData) + xcom = (xcom * 2.) + 100 + self.assertTrue(numpy.allclose( + _stats['com'].calculate(image2Context), (xcom, ycom))) + + def testBasicStatsScatter(self): + """Test result for simple stats on a scatter""" + _stats = self.getBasicStats() + self.assertEqual(_stats['min'].calculate(self.scatterContext), 5) + self.assertEqual(_stats['max'].calculate(self.scatterContext), 90) + self.assertEqual(_stats['minCoords'].calculate(self.scatterContext), (0, 2)) + self.assertEqual(_stats['maxCoords'].calculate(self.scatterContext), (50, 69)) + self.assertEqual(_stats['std'].calculate(self.scatterContext), numpy.std(self.valuesScatterData)) + self.assertEqual(_stats['mean'].calculate(self.scatterContext), numpy.mean(self.valuesScatterData)) + + data = self.valuesScatterData.astype(numpy.float64) + comx = numpy.sum(self.xScatterData * data) / numpy.sum(data) + comy = numpy.sum(self.yScatterData * data) / numpy.sum(data) + self.assertEqual(_stats['com'].calculate(self.scatterContext), + (comx, comy)) + + def testKindNotManagedByStat(self): + """Make sure an exception is raised if we try to execute calculate + of the base class""" + b = stats.StatBase(name='toto', compatibleKinds='curve') + with self.assertRaises(NotImplementedError): + b.calculate(self.imageContext) + + def testKindNotManagedByContext(self): + """ + Make sure an error is raised if we try to calculate a statistic with + a context not managed + """ + myStat = stats.Stat(name='toto', fct=numpy.std, kinds=('curve')) + myStat.calculate(self.curveContext) + with self.assertRaises(ValueError): + myStat.calculate(self.scatterContext) + with self.assertRaises(ValueError): + myStat.calculate(self.imageContext) + + def testOnLimits(self): + stat = stats.StatMin() + + self.plot1d.getXAxis().setLimitsConstraints(minPos=2, maxPos=5) + curveContextOnLimits = stats._CurveContext( + item=self.plot1d.getCurve('curve0'), + plot=self.plot1d, + onlimits=True, + roi=None) + self.assertEqual(stat.calculate(curveContextOnLimits), 2) + + self.plot2d.getXAxis().setLimitsConstraints(minPos=32) + imageContextOnLimits = stats._ImageContext( + item=self.plot2d.getImage('test image'), + plot=self.plot2d, + onlimits=True, + roi=None) + self.assertEqual(stat.calculate(imageContextOnLimits), 32) + + self.scatterPlot.getXAxis().setLimitsConstraints(minPos=40) + scatterContextOnLimits = stats._ScatterContext( + item=self.scatterPlot.getScatter('scatter plot'), + plot=self.scatterPlot, + onlimits=True, + roi=None) + self.assertEqual(stat.calculate(scatterContextOnLimits), 20) + + +class TestStatsFormatter(TestCaseQt): + """Simple test to check usage of the :class:`StatsFormatter`""" + def setUp(self): + TestCaseQt.setUp(self) + self.plot1d = Plot1D() + x = range(20) + y = range(20) + self.plot1d.addCurve(x, y, legend='curve0') + + self.curveContext = stats._CurveContext( + item=self.plot1d.getCurve('curve0'), + plot=self.plot1d, + onlimits=False, + roi=None) + + self.stat = stats.StatMin() + + def tearDown(self): + self.plot1d.setAttribute(qt.Qt.WA_DeleteOnClose) + self.plot1d.close() + del self.plot1d + TestCaseQt.tearDown(self) + + def testEmptyFormatter(self): + """Make sure a formatter with no formatter definition will return a + simple cast to str""" + emptyFormatter = statshandler.StatFormatter() + self.assertEqual( + emptyFormatter.format(self.stat.calculate(self.curveContext)), '0.000') + + def testSettedFormatter(self): + """Make sure a formatter with no formatter definition will return a + simple cast to str""" + formatter= statshandler.StatFormatter(formatter='{0:.3f}') + self.assertEqual( + formatter.format(self.stat.calculate(self.curveContext)), '0.000') + + +class TestStatsHandler(TestCaseQt): + """Make sure the StatHandler is correctly making the link between + :class:`StatBase` and :class:`StatFormatter` and checking the API is valid + """ + def setUp(self): + TestCaseQt.setUp(self) + self.plot1d = Plot1D() + x = range(20) + y = range(20) + self.plot1d.addCurve(x, y, legend='curve0') + self.curveItem = self.plot1d.getCurve('curve0') + + self.stat = stats.StatMin() + + def tearDown(self): + Stats._getContext.cache_clear() + self.plot1d.setAttribute(qt.Qt.WA_DeleteOnClose) + self.plot1d.close() + self.plot1d = None + TestCaseQt.tearDown(self) + + def testConstructor(self): + """Make sure the constructor can deal will all possible arguments: + + * tuple of :class:`StatBase` derivated classes + * tuple of tuples (:class:`StatBase`, :class:`StatFormatter`) + * tuple of tuples (str, pointer to function, kind) + """ + handler0 = statshandler.StatsHandler( + (stats.StatMin(), stats.StatMax()) + ) + + res = handler0.calculate(item=self.curveItem, plot=self.plot1d, + onlimits=False) + self.assertTrue('min' in res) + self.assertEqual(res['min'], '0') + self.assertTrue('max' in res) + self.assertEqual(res['max'], '19') + + handler1 = statshandler.StatsHandler( + ( + (stats.StatMin(), statshandler.StatFormatter(formatter=None)), + (stats.StatMax(), statshandler.StatFormatter()) + ) + ) + + res = handler1.calculate(item=self.curveItem, plot=self.plot1d, + onlimits=False) + self.assertTrue('min' in res) + self.assertEqual(res['min'], '0') + self.assertTrue('max' in res) + self.assertEqual(res['max'], '19.000') + + handler2 = statshandler.StatsHandler( + ( + (stats.StatMin(), None), + (stats.StatMax(), statshandler.StatFormatter()) + )) + + res = handler2.calculate(item=self.curveItem, plot=self.plot1d, + onlimits=False) + self.assertTrue('min' in res) + self.assertEqual(res['min'], '0') + self.assertTrue('max' in res) + self.assertEqual(res['max'], '19.000') + + handler3 = statshandler.StatsHandler(( + (('amin', numpy.argmin), statshandler.StatFormatter()), + ('amax', numpy.argmax) + )) + + res = handler3.calculate(item=self.curveItem, plot=self.plot1d, + onlimits=False) + self.assertTrue('amin' in res) + self.assertEqual(res['amin'], '0.000') + self.assertTrue('amax' in res) + self.assertEqual(res['amax'], '19') + + with self.assertRaises(ValueError): + statshandler.StatsHandler(('name')) + + +class TestStatsWidgetWithCurves(TestCaseQt, ParametricTestCase): + """Basic test for StatsWidget with curves""" + def setUp(self): + TestCaseQt.setUp(self) + self.plot = Plot1D() + self.plot.show() + x = range(20) + y = range(20) + self.plot.addCurve(x, y, legend='curve0') + y = range(12, 32) + self.plot.addCurve(x, y, legend='curve1') + y = range(-2, 18) + self.plot.addCurve(x, y, legend='curve2') + self.widget = StatsWidget.StatsWidget(plot=self.plot) + self.statsTable = self.widget._statsTable + + mystats = statshandler.StatsHandler(( + stats.StatMin(), + (stats.StatCoordMin(), statshandler.StatFormatter(None, qt.QTableWidgetItem)), + stats.StatMax(), + (stats.StatCoordMax(), statshandler.StatFormatter(None, qt.QTableWidgetItem)), + stats.StatDelta(), + ('std', numpy.std), + ('mean', numpy.mean), + stats.StatCOM() + )) + + self.statsTable.setStats(mystats) + + def tearDown(self): + Stats._getContext.cache_clear() + self.plot.setAttribute(qt.Qt.WA_DeleteOnClose) + self.plot.close() + self.statsTable = None + self.widget.setAttribute(qt.Qt.WA_DeleteOnClose) + self.widget.close() + self.widget = None + self.plot = None + TestCaseQt.tearDown(self) + + def testDisplayActiveItemsSyncOptions(self): + """ + Test that the several option of the sync options are well + synchronized between the different object""" + widget = StatsWidget.StatsWidget(plot=self.plot) + table = StatsWidget.StatsTable(plot=self.plot) + + def check_display_only_active_item(only_active): + # check internal value + self.assertIs(widget._statsTable._displayOnlyActItem, only_active) + # self.assertTrue(table._displayOnlyActItem is only_active) + # check gui display + self.assertEqual(widget._options.isActiveItemMode(), only_active) + + for displayOnlyActiveItems in (True, False): + with self.subTest(displayOnlyActiveItems=displayOnlyActiveItems): + widget.setDisplayOnlyActiveItem(displayOnlyActiveItems) + # table.setDisplayOnlyActiveItem(displayOnlyActiveItems) + check_display_only_active_item(displayOnlyActiveItems) + + check_display_only_active_item(only_active=False) + widget.setAttribute(qt.Qt.WA_DeleteOnClose) + table.setAttribute(qt.Qt.WA_DeleteOnClose) + widget.close() + table.close() + + def testInit(self): + """Make sure all the curves are registred on initialization""" + self.assertEqual(self.statsTable.rowCount(), 3) + + def testRemoveCurve(self): + """Make sure the Curves stats take into account the curve removal from + plot""" + self.plot.removeCurve('curve2') + self.assertEqual(self.statsTable.rowCount(), 2) + for iRow in range(2): + self.assertTrue(self.statsTable.item(iRow, 0).text() in ('curve0', 'curve1')) + + self.plot.removeCurve('curve0') + self.assertEqual(self.statsTable.rowCount(), 1) + self.plot.removeCurve('curve1') + self.assertEqual(self.statsTable.rowCount(), 0) + + def testAddCurve(self): + """Make sure the Curves stats take into account the add curve action""" + self.plot.addCurve(legend='curve3', x=range(10), y=range(10)) + self.assertEqual(self.statsTable.rowCount(), 4) + + def testUpdateCurveFromAddCurve(self): + """Make sure the stats of the cuve will be removed after updating a + curve""" + self.plot.addCurve(legend='curve0', x=range(10), y=range(10)) + self.qapp.processEvents() + self.assertEqual(self.statsTable.rowCount(), 3) + curve = self.plot._getItem(kind='curve', legend='curve0') + tableItems = self.statsTable._itemToTableItems(curve) + self.assertEqual(tableItems['max'].text(), '9') + + def testUpdateCurveFromCurveObj(self): + self.plot.getCurve('curve0').setData(x=range(4), y=range(4)) + self.qapp.processEvents() + self.assertEqual(self.statsTable.rowCount(), 3) + curve = self.plot._getItem(kind='curve', legend='curve0') + tableItems = self.statsTable._itemToTableItems(curve) + self.assertEqual(tableItems['max'].text(), '3') + + def testSetAnotherPlot(self): + plot2 = Plot1D() + plot2.addCurve(x=range(26), y=range(26), legend='new curve') + self.statsTable.setPlot(plot2) + self.assertEqual(self.statsTable.rowCount(), 1) + self.qapp.processEvents() + plot2.setAttribute(qt.Qt.WA_DeleteOnClose) + plot2.close() + plot2 = None + + def testUpdateMode(self): + """Make sure the update modes are well take into account""" + self.plot.setActiveCurve('curve0') + for display_only_active in (True, False): + with self.subTest(display_only_active=display_only_active): + self.widget.setDisplayOnlyActiveItem(display_only_active) + self.plot.getCurve('curve0').setData(x=range(4), y=range(4)) + self.widget.setUpdateMode(StatsWidget.UpdateMode.AUTO) + update_stats_action = self.widget._options.getUpdateStatsAction() + # test from api + self.assertEqual(self.widget.getUpdateMode(), StatsWidget.UpdateMode.AUTO) + self.widget.show() + # check stats change in auto mode + self.plot.getCurve('curve0').setData(x=range(4), y=range(-1, 3)) + self.qapp.processEvents() + tableItems = self.statsTable._itemToTableItems(self.plot.getCurve('curve0')) + curve0_min = tableItems['min'].text() + self.assertTrue(float(curve0_min) == -1.) + + self.plot.getCurve('curve0').setData(x=range(4), y=range(1, 5)) + self.qapp.processEvents() + tableItems = self.statsTable._itemToTableItems(self.plot.getCurve('curve0')) + curve0_min = tableItems['min'].text() + self.assertTrue(float(curve0_min) == 1.) + + # check stats change in manual mode only if requested + self.widget.setUpdateMode(StatsWidget.UpdateMode.MANUAL) + self.assertEqual(self.widget.getUpdateMode(), StatsWidget.UpdateMode.MANUAL) + + self.plot.getCurve('curve0').setData(x=range(4), y=range(2, 6)) + self.qapp.processEvents() + tableItems = self.statsTable._itemToTableItems(self.plot.getCurve('curve0')) + curve0_min = tableItems['min'].text() + self.assertTrue(float(curve0_min) == 1.) + + update_stats_action.trigger() + tableItems = self.statsTable._itemToTableItems(self.plot.getCurve('curve0')) + curve0_min = tableItems['min'].text() + self.assertTrue(float(curve0_min) == 2.) + + def testItemHidden(self): + """Test if an item is hide, then the associated stats item is also + hide""" + curve0 = self.plot.getCurve('curve0') + curve1 = self.plot.getCurve('curve1') + curve2 = self.plot.getCurve('curve2') + + self.plot.show() + self.widget.show() + self.qWaitForWindowExposed(self.widget) + self.assertFalse(self.statsTable.isRowHidden(0)) + self.assertFalse(self.statsTable.isRowHidden(1)) + self.assertFalse(self.statsTable.isRowHidden(2)) + + curve0.setVisible(False) + self.qapp.processEvents() + self.assertTrue(self.statsTable.isRowHidden(0)) + curve0.setVisible(True) + self.qapp.processEvents() + self.assertFalse(self.statsTable.isRowHidden(0)) + curve1.setVisible(False) + self.qapp.processEvents() + self.assertTrue(self.statsTable.isRowHidden(1)) + tableItems = self.statsTable._itemToTableItems(curve2) + curve2_min = tableItems['min'].text() + self.assertTrue(float(curve2_min) == -2.) + + curve0.setVisible(False) + curve1.setVisible(False) + curve2.setVisible(False) + self.qapp.processEvents() + self.assertTrue(self.statsTable.isRowHidden(0)) + self.assertTrue(self.statsTable.isRowHidden(1)) + self.assertTrue(self.statsTable.isRowHidden(2)) + + +class TestStatsWidgetWithImages(TestCaseQt): + """Basic test for StatsWidget with images""" + + IMAGE_LEGEND = 'test image' + + def setUp(self): + TestCaseQt.setUp(self) + self.plot = Plot2D() + + self.plot.addImage(data=numpy.arange(128*128).reshape(128, 128), + legend=self.IMAGE_LEGEND, replace=False) + + self.widget = StatsWidget.StatsTable(plot=self.plot) + + mystats = statshandler.StatsHandler(( + (stats.StatMin(), statshandler.StatFormatter()), + (stats.StatCoordMin(), statshandler.StatFormatter(None, qt.QTableWidgetItem)), + (stats.StatMax(), statshandler.StatFormatter()), + (stats.StatCoordMax(), statshandler.StatFormatter(None, qt.QTableWidgetItem)), + (stats.StatDelta(), statshandler.StatFormatter()), + ('std', numpy.std), + ('mean', numpy.mean), + (stats.StatCOM(), statshandler.StatFormatter(None)) + )) + + self.widget.setStats(mystats) + + def tearDown(self): + Stats._getContext.cache_clear() + self.plot.setAttribute(qt.Qt.WA_DeleteOnClose) + self.plot.close() + self.widget.setAttribute(qt.Qt.WA_DeleteOnClose) + self.widget.close() + self.widget = None + self.plot = None + TestCaseQt.tearDown(self) + + def test(self): + image = self.plot._getItem( + kind='image', legend=self.IMAGE_LEGEND) + tableItems = self.widget._itemToTableItems(image) + + maxText = '{0:.3f}'.format((128 * 128) - 1) + self.assertEqual(tableItems['legend'].text(), self.IMAGE_LEGEND) + self.assertEqual(tableItems['min'].text(), '0.000') + self.assertEqual(tableItems['max'].text(), maxText) + self.assertEqual(tableItems['delta'].text(), maxText) + self.assertEqual(tableItems['coords min'].text(), '0.0, 0.0') + self.assertEqual(tableItems['coords max'].text(), '127.0, 127.0') + + def testItemHidden(self): + """Test if an item is hide, then the associated stats item is also + hide""" + self.widget.show() + self.plot.show() + self.qWaitForWindowExposed(self.widget) + self.assertFalse(self.widget.isRowHidden(0)) + self.plot.getImage(self.IMAGE_LEGEND).setVisible(False) + self.qapp.processEvents() + self.assertTrue(self.widget.isRowHidden(0)) + + +class TestStatsWidgetWithScatters(TestCaseQt): + + SCATTER_LEGEND = 'scatter plot' + + def setUp(self): + TestCaseQt.setUp(self) + self.scatterPlot = Plot2D() + self.scatterPlot.addScatter([0, 1, 2, 20, 50, 60], + [2, 3, 4, 26, 69, 6], + [5, 6, 7, 10, 90, 20], + legend=self.SCATTER_LEGEND) + self.widget = StatsWidget.StatsTable(plot=self.scatterPlot) + + mystats = statshandler.StatsHandler(( + stats.StatMin(), + (stats.StatCoordMin(), statshandler.StatFormatter(None, qt.QTableWidgetItem)), + stats.StatMax(), + (stats.StatCoordMax(), statshandler.StatFormatter(None, qt.QTableWidgetItem)), + stats.StatDelta(), + ('std', numpy.std), + ('mean', numpy.mean), + stats.StatCOM() + )) + + self.widget.setStats(mystats) + + def tearDown(self): + Stats._getContext.cache_clear() + self.scatterPlot.setAttribute(qt.Qt.WA_DeleteOnClose) + self.scatterPlot.close() + self.widget.setAttribute(qt.Qt.WA_DeleteOnClose) + self.widget.close() + self.widget = None + self.scatterPlot = None + TestCaseQt.tearDown(self) + + def testStats(self): + scatter = self.scatterPlot._getItem( + kind='scatter', legend=self.SCATTER_LEGEND) + tableItems = self.widget._itemToTableItems(scatter) + self.assertEqual(tableItems['legend'].text(), self.SCATTER_LEGEND) + self.assertEqual(tableItems['min'].text(), '5') + self.assertEqual(tableItems['coords min'].text(), '0, 2') + self.assertEqual(tableItems['max'].text(), '90') + self.assertEqual(tableItems['coords max'].text(), '50, 69') + self.assertEqual(tableItems['delta'].text(), '85') + + +class TestEmptyStatsWidget(TestCaseQt): + def test(self): + widget = StatsWidget.StatsWidget() + widget.show() + self.qWaitForWindowExposed(widget) + + +class TestLineWidget(TestCaseQt): + """Some test for the StatsLineWidget.""" + def setUp(self): + TestCaseQt.setUp(self) + + mystats = statshandler.StatsHandler(( + (stats.StatMin(), statshandler.StatFormatter()), + )) + + self.plot = Plot1D() + self.plot.show() + self.x = range(20) + self.y0 = range(20) + self.curve0 = self.plot.addCurve(self.x, self.y0, legend='curve0') + self.y1 = range(12, 32) + self.plot.addCurve(self.x, self.y1, legend='curve1') + self.y2 = range(-2, 18) + self.plot.addCurve(self.x, self.y2, legend='curve2') + self.widget = StatsWidget.BasicGridStatsWidget(plot=self.plot, + kind='curve', + stats=mystats) + + def tearDown(self): + Stats._getContext.cache_clear() + self.qapp.processEvents() + self.plot.setAttribute(qt.Qt.WA_DeleteOnClose) + self.plot.close() + self.widget.setPlot(None) + self.widget._lineStatsWidget._statQlineEdit.clear() + self.widget.setAttribute(qt.Qt.WA_DeleteOnClose) + self.widget.close() + self.widget = None + self.plot = None + TestCaseQt.tearDown(self) + + def testProcessing(self): + self.widget._lineStatsWidget.setStatsOnVisibleData(False) + self.qapp.processEvents() + self.plot.setActiveCurve(legend='curve0') + self.assertTrue(self.widget._lineStatsWidget._statQlineEdit['min'].text() == '0.000') + self.plot.setActiveCurve(legend='curve1') + self.assertTrue(self.widget._lineStatsWidget._statQlineEdit['min'].text() == '12.000') + self.plot.getXAxis().setLimitsConstraints(minPos=2, maxPos=5) + self.widget.setStatsOnVisibleData(True) + self.qapp.processEvents() + self.assertTrue(self.widget._lineStatsWidget._statQlineEdit['min'].text() == '14.000') + self.plot.setActiveCurve(None) + self.assertIsNone(self.plot.getActiveCurve()) + self.widget.setStatsOnVisibleData(False) + self.qapp.processEvents() + self.assertFalse(self.widget._lineStatsWidget._statQlineEdit['min'].text() == '14.000') + self.widget.setKind('image') + self.plot.addImage(numpy.arange(100*100).reshape(100, 100) + 0.312) + self.qapp.processEvents() + self.assertTrue(self.widget._lineStatsWidget._statQlineEdit['min'].text() == '0.312') + + def testUpdateMode(self): + """Make sure the update modes are well take into account""" + self.plot.setActiveCurve(self.curve0) + _autoRB = self.widget._options._autoRB + _manualRB = self.widget._options._manualRB + # test from api + self.widget.setUpdateMode(StatsWidget.UpdateMode.AUTO) + self.assertTrue(_autoRB.isChecked()) + self.assertFalse(_manualRB.isChecked()) + + # check stats change in auto mode + curve0_min = self.widget._lineStatsWidget._statQlineEdit['min'].text() + new_y = numpy.array(self.y0) - 2.56 + self.plot.addCurve(x=self.x, y=new_y, legend=self.curve0) + curve0_min2 = self.widget._lineStatsWidget._statQlineEdit['min'].text() + self.assertTrue(curve0_min != curve0_min2) + + # check stats change in manual mode only if requested + self.widget.setUpdateMode(StatsWidget.UpdateMode.MANUAL) + self.assertFalse(_autoRB.isChecked()) + self.assertTrue(_manualRB.isChecked()) + + new_y = numpy.array(self.y0) - 1.2 + self.plot.addCurve(x=self.x, y=new_y, legend=self.curve0) + curve0_min3 = self.widget._lineStatsWidget._statQlineEdit['min'].text() + self.assertTrue(curve0_min3 == curve0_min2) + self.widget._options._updateRequested() + curve0_min3 = self.widget._lineStatsWidget._statQlineEdit['min'].text() + self.assertTrue(curve0_min3 != curve0_min2) + + # test from gui + self.widget.showRadioButtons(True) + self.widget._options._autoRB.toggle() + self.assertTrue(_autoRB.isChecked()) + self.assertFalse(_manualRB.isChecked()) + + self.widget._options._manualRB.toggle() + self.assertFalse(_autoRB.isChecked()) + self.assertTrue(_manualRB.isChecked()) + + +class TestUpdateModeWidget(TestCaseQt): + """Test UpdateModeWidget""" + def setUp(self): + TestCaseQt.setUp(self) + self.widget = StatsWidget.UpdateModeWidget(parent=None) + + def tearDown(self): + self.widget.setAttribute(qt.Qt.WA_DeleteOnClose) + self.widget.close() + self.widget = None + TestCaseQt.tearDown(self) + + def testSignals(self): + """Test the signal emission of the widget""" + self.widget.setUpdateMode(StatsWidget.UpdateMode.AUTO) + modeChangedListener = SignalListener() + manualUpdateListener = SignalListener() + self.widget.sigUpdateModeChanged.connect(modeChangedListener) + self.widget.sigUpdateRequested.connect(manualUpdateListener) + self.widget.setUpdateMode(StatsWidget.UpdateMode.AUTO) + self.assertEqual(self.widget.getUpdateMode(), StatsWidget.UpdateMode.AUTO) + self.assertEqual(modeChangedListener.callCount(), 0) + self.qapp.processEvents() + + self.widget.setUpdateMode(StatsWidget.UpdateMode.MANUAL) + self.assertEqual(self.widget.getUpdateMode(), StatsWidget.UpdateMode.MANUAL) + self.qapp.processEvents() + self.assertEqual(modeChangedListener.callCount(), 1) + self.assertEqual(manualUpdateListener.callCount(), 0) + self.widget._updatePB.click() + self.widget._updatePB.click() + self.assertEqual(manualUpdateListener.callCount(), 2) + + self.widget._autoRB.setChecked(True) + self.assertEqual(modeChangedListener.callCount(), 2) + self.widget._updatePB.click() + self.assertEqual(manualUpdateListener.callCount(), 2) + + +class TestStatsROI(TestStatsBase, TestCaseQt): + """ + Test stats based on ROI + """ + def setUp(self): + TestCaseQt.setUp(self) + self.createRois() + TestStatsBase.setUp(self) + self.createHistogramContext() + + self.roiManager = RegionOfInterestManager(self.plot2d) + self.roiManager.addRoi(self._2Droi_rect) + self.roiManager.addRoi(self._2Droi_poly) + + def tearDown(self): + self.roiManager.clear() + self.roiManager = None + self._1Droi = None + self._2Droi_rect = None + self._2Droi_poly = None + self.plotHisto.setAttribute(qt.Qt.WA_DeleteOnClose) + self.plotHisto.close() + self.plotHisto = None + TestStatsBase.tearDown(self) + TestCaseQt.tearDown(self) + + def createRois(self): + self._1Droi = ROI(name='my1DRoi', fromdata=2.0, todata=5.0) + self._2Droi_rect = RectangleROI() + self._2Droi_rect.setGeometry(size=(10, 10), origin=(10, 0)) + self._2Droi_poly = PolygonROI() + points = numpy.array(((0, 20), (0, 0), (10, 0))) + self._2Droi_poly.setPoints(points=points) + + def createCurveContext(self): + TestStatsBase.createCurveContext(self) + self.curveContext = stats._CurveContext( + item=self.plot1d.getCurve('curve0'), + plot=self.plot1d, + onlimits=False, + roi=self._1Droi) + + def createHistogramContext(self): + self.plotHisto = Plot1D() + x = range(20) + y = range(20) + self.plotHisto.addHistogram(x, y, legend='histo0') + + self.histoContext = stats._HistogramContext( + item=self.plotHisto.getHistogram('histo0'), + plot=self.plotHisto, + onlimits=False, + roi=self._1Droi) + + def createScatterContext(self): + TestStatsBase.createScatterContext(self) + self.scatterContext = stats._ScatterContext( + item=self.scatterPlot.getScatter('scatter plot'), + plot=self.scatterPlot, + onlimits=False, + roi=self._1Droi + ) + + def createImageContext(self): + TestStatsBase.createImageContext(self) + + self.imageContext = stats._ImageContext( + item=self.plot2d.getImage(self._imgLgd), + plot=self.plot2d, + onlimits=False, + roi=self._2Droi_rect + ) + + self.imageContext_2 = stats._ImageContext( + item=self.plot2d.getImage(self._imgLgd), + plot=self.plot2d, + onlimits=False, + roi=self._2Droi_poly + ) + + def testErrors(self): + # test if onlimits is True and give also a roi + with self.assertRaises(ValueError): + stats._CurveContext(item=self.plot1d.getCurve('curve0'), + plot=self.plot1d, + onlimits=True, + roi=self._1Droi) + + # test if is a curve context and give an invalid 2D roi + with self.assertRaises(TypeError): + stats._CurveContext(item=self.plot1d.getCurve('curve0'), + plot=self.plot1d, + onlimits=False, + roi=self._2Droi_rect) + + def testBasicStatsCurve(self): + """Test result for simple stats on a curve""" + _stats = self.getBasicStats() + xData = yData = numpy.array(range(0, 10)) + self.assertEqual(_stats['min'].calculate(self.curveContext), 2) + self.assertEqual(_stats['max'].calculate(self.curveContext), 5) + self.assertEqual(_stats['minCoords'].calculate(self.curveContext), (2,)) + self.assertEqual(_stats['maxCoords'].calculate(self.curveContext), (5,)) + self.assertEqual(_stats['std'].calculate(self.curveContext), numpy.std(yData[2:6])) + self.assertEqual(_stats['mean'].calculate(self.curveContext), numpy.mean(yData[2:6])) + com = numpy.sum(xData[2:6] * yData[2:6]) / numpy.sum(yData[2:6]) + self.assertEqual(_stats['com'].calculate(self.curveContext), com) + + def testBasicStatsImageRectRoi(self): + """Test result for simple stats on an image""" + self.assertEqual(self.imageContext.values.compressed().size, 121) + _stats = self.getBasicStats() + self.assertEqual(_stats['min'].calculate(self.imageContext), 10) + self.assertEqual(_stats['max'].calculate(self.imageContext), 1300) + self.assertEqual(_stats['minCoords'].calculate(self.imageContext), (10, 0)) + self.assertEqual(_stats['maxCoords'].calculate(self.imageContext), (20.0, 10.0)) + self.assertAlmostEqual(_stats['std'].calculate(self.imageContext), + numpy.std(self.imageData[0:11, 10:21])) + self.assertAlmostEqual(_stats['mean'].calculate(self.imageContext), + numpy.mean(self.imageData[0:11, 10:21])) + + compressed_values = self.imageContext.values.compressed() + compressed_values = compressed_values.reshape(11, 11) + yData = numpy.sum(compressed_values.astype(numpy.float64), axis=1) + xData = numpy.sum(compressed_values.astype(numpy.float64), axis=0) + + dataYRange = range(11) + dataXRange = range(10, 21) + + ycom = numpy.sum(yData*dataYRange) / numpy.sum(yData) + xcom = numpy.sum(xData*dataXRange) / numpy.sum(xData) + self.assertEqual(_stats['com'].calculate(self.imageContext), (xcom, ycom)) + + def testBasicStatsImagePolyRoi(self): + """Test a simple rectangle ROI""" + _stats = self.getBasicStats() + self.assertEqual(_stats['min'].calculate(self.imageContext_2), 0) + self.assertEqual(_stats['max'].calculate(self.imageContext_2), 2432) + self.assertEqual(_stats['minCoords'].calculate(self.imageContext_2), (0.0, 0.0)) + # not 0.0, 19.0 because not fully in. Should all pixel have a weight, + # on to manage them in stats. For now 0 if the center is not in, else 1 + self.assertEqual(_stats['maxCoords'].calculate(self.imageContext_2), (0.0, 19.0)) + + def testBasicStatsScatter(self): + self.assertEqual(self.scatterContext.values.compressed().size, 2) + _stats = self.getBasicStats() + self.assertEqual(_stats['min'].calculate(self.scatterContext), 6) + self.assertEqual(_stats['max'].calculate(self.scatterContext), 7) + self.assertEqual(_stats['minCoords'].calculate(self.scatterContext), (2, 3)) + self.assertEqual(_stats['maxCoords'].calculate(self.scatterContext), (3, 4)) + self.assertEqual(_stats['std'].calculate(self.scatterContext), numpy.std([6, 7])) + self.assertEqual(_stats['mean'].calculate(self.scatterContext), numpy.mean([6, 7])) + + def testBasicHistogram(self): + _stats = self.getBasicStats() + xData = yData = numpy.array(range(2, 6)) + self.assertEqual(_stats['min'].calculate(self.histoContext), 2) + self.assertEqual(_stats['max'].calculate(self.histoContext), 5) + self.assertEqual(_stats['minCoords'].calculate(self.histoContext), (2,)) + self.assertEqual(_stats['maxCoords'].calculate(self.histoContext), (5,)) + self.assertEqual(_stats['std'].calculate(self.histoContext), numpy.std(yData)) + self.assertEqual(_stats['mean'].calculate(self.histoContext), numpy.mean(yData)) + com = numpy.sum(xData * yData) / numpy.sum(yData) + self.assertEqual(_stats['com'].calculate(self.histoContext), com) + + +class TestAdvancedROIImageContext(TestCaseQt): + """Test stats result on an image context with different scale and + origins""" + + def setUp(self): + TestCaseQt.setUp(self) + self.data_dims = (100, 100) + self.data = numpy.random.rand(*self.data_dims) + self.plot = Plot2D() + + def tearDown(self): + self.plot.setAttribute(qt.Qt.WA_DeleteOnClose) + self.plot.close() + self.plot = None + TestCaseQt.tearDown(self) + + def test(self): + """Test stats result on an image context with different scale and + origins""" + roi_origins = [(0, 0), (2, 10), (14, 20)] + img_origins = [(0, 0), (14, 20), (2, 10)] + img_scales = [1.0, 0.5, 2.0] + _stats = {'sum': stats.Stat(name='sum', fct=numpy.sum), } + for roi_origin in roi_origins: + for img_origin in img_origins: + for img_scale in img_scales: + with self.subTest(roi_origin=roi_origin, + img_origin=img_origin, + img_scale=img_scale): + self.plot.addImage(self.data, legend='img', + origin=img_origin, + scale=img_scale) + roi = RectangleROI() + roi.setGeometry(origin=roi_origin, size=(20, 20)) + context = stats._ImageContext( + item=self.plot.getImage('img'), + plot=self.plot, + onlimits=False, + roi=roi) + x_start = int((roi_origin[0] - img_origin[0]) / img_scale) + x_end = int(x_start + (20 / img_scale)) + 1 + y_start = int((roi_origin[1] - img_origin[1])/ img_scale) + y_end = int(y_start + (20 / img_scale)) + 1 + x_start = max(x_start, 0) + x_end = min(max(x_end, 0), self.data_dims[1]) + y_start = max(y_start, 0) + y_end = min(max(y_end, 0), self.data_dims[0]) + th_sum = numpy.sum(self.data[y_start:y_end, x_start:x_end]) + self.assertAlmostEqual(_stats['sum'].calculate(context), + th_sum) diff --git a/src/silx/gui/plot/test/testUtilsAxis.py b/src/silx/gui/plot/test/testUtilsAxis.py new file mode 100644 index 0000000..dd4a689 --- /dev/null +++ b/src/silx/gui/plot/test/testUtilsAxis.py @@ -0,0 +1,203 @@ +# coding: utf-8 +# /*########################################################################## +# +# Copyright (c) 2016 European Synchrotron Radiation Facility +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. +# +# ###########################################################################*/ +"""Basic tests for PlotWidget""" + +__authors__ = ["V. Valls"] +__license__ = "MIT" +__date__ = "20/11/2018" + + +import unittest +from silx.gui.plot import PlotWidget +from silx.gui.utils.testutils import TestCaseQt +from silx.gui.plot.utils.axis import SyncAxes + + +class TestAxisSync(TestCaseQt): + """Tests AxisSync class""" + + def setUp(self): + TestCaseQt.setUp(self) + self.plot1 = PlotWidget() + self.plot2 = PlotWidget() + self.plot3 = PlotWidget() + + def tearDown(self): + self.plot1 = None + self.plot2 = None + self.plot3 = None + TestCaseQt.tearDown(self) + + def testMoveFirstAxis(self): + """Test synchronization after construction""" + _sync = SyncAxes([self.plot1.getXAxis(), self.plot2.getXAxis(), self.plot3.getXAxis()]) + + self.plot1.getXAxis().setLimits(10, 500) + self.assertEqual(self.plot1.getXAxis().getLimits(), (10, 500)) + self.assertEqual(self.plot2.getXAxis().getLimits(), (10, 500)) + self.assertEqual(self.plot3.getXAxis().getLimits(), (10, 500)) + + def testMoveSecondAxis(self): + """Test synchronization after construction""" + _sync = SyncAxes([self.plot1.getXAxis(), self.plot2.getXAxis(), self.plot3.getXAxis()]) + + self.plot2.getXAxis().setLimits(10, 500) + self.assertEqual(self.plot1.getXAxis().getLimits(), (10, 500)) + self.assertEqual(self.plot2.getXAxis().getLimits(), (10, 500)) + self.assertEqual(self.plot3.getXAxis().getLimits(), (10, 500)) + + def testMoveTwoAxes(self): + """Test synchronization after construction""" + _sync = SyncAxes([self.plot1.getXAxis(), self.plot2.getXAxis(), self.plot3.getXAxis()]) + + self.plot1.getXAxis().setLimits(1, 50) + self.plot2.getXAxis().setLimits(10, 500) + self.assertEqual(self.plot1.getXAxis().getLimits(), (10, 500)) + self.assertEqual(self.plot2.getXAxis().getLimits(), (10, 500)) + self.assertEqual(self.plot3.getXAxis().getLimits(), (10, 500)) + + def testDestruction(self): + """Test synchronization when sync object is destroyed""" + sync = SyncAxes([self.plot1.getXAxis(), self.plot2.getXAxis(), self.plot3.getXAxis()]) + del sync + + self.plot1.getXAxis().setLimits(10, 500) + self.assertEqual(self.plot1.getXAxis().getLimits(), (10, 500)) + self.assertNotEqual(self.plot2.getXAxis().getLimits(), (10, 500)) + self.assertNotEqual(self.plot3.getXAxis().getLimits(), (10, 500)) + + def testAxisDestruction(self): + """Test synchronization when an axis disappear""" + _sync = SyncAxes([self.plot1.getXAxis(), self.plot2.getXAxis(), self.plot3.getXAxis()]) + + # Destroy the plot is possible + import weakref + plot = weakref.ref(self.plot2) + self.plot2 = None + result = self.qWaitForDestroy(plot) + if not result: + # We can't test + self.skipTest("Object not destroyed") + + self.plot1.getXAxis().setLimits(10, 500) + self.assertEqual(self.plot3.getXAxis().getLimits(), (10, 500)) + + def testStop(self): + """Test synchronization after calling stop""" + sync = SyncAxes([self.plot1.getXAxis(), self.plot2.getXAxis(), self.plot3.getXAxis()]) + sync.stop() + + self.plot1.getXAxis().setLimits(10, 500) + self.assertEqual(self.plot1.getXAxis().getLimits(), (10, 500)) + self.assertNotEqual(self.plot2.getXAxis().getLimits(), (10, 500)) + self.assertNotEqual(self.plot3.getXAxis().getLimits(), (10, 500)) + + def testStopMovingStart(self): + """Test synchronization after calling stop, moving an axis, then start again""" + sync = SyncAxes([self.plot1.getXAxis(), self.plot2.getXAxis(), self.plot3.getXAxis()]) + sync.stop() + self.plot1.getXAxis().setLimits(10, 500) + self.plot2.getXAxis().setLimits(1, 50) + self.assertEqual(self.plot1.getXAxis().getLimits(), (10, 500)) + sync.start() + + # The first axis is the reference + self.assertEqual(self.plot1.getXAxis().getLimits(), (10, 500)) + self.assertEqual(self.plot2.getXAxis().getLimits(), (10, 500)) + self.assertEqual(self.plot3.getXAxis().getLimits(), (10, 500)) + + def testDoubleStop(self): + """Test double stop""" + sync = SyncAxes([self.plot1.getXAxis(), self.plot2.getXAxis(), self.plot3.getXAxis()]) + sync.stop() + self.assertRaises(RuntimeError, sync.stop) + + def testDoubleStart(self): + """Test double stop""" + sync = SyncAxes([self.plot1.getXAxis(), self.plot2.getXAxis(), self.plot3.getXAxis()]) + self.assertRaises(RuntimeError, sync.start) + + def testScale(self): + """Test scale change""" + _sync = SyncAxes([self.plot1.getXAxis(), self.plot2.getXAxis(), self.plot3.getXAxis()]) + self.plot1.getXAxis().setScale(self.plot1.getXAxis().LOGARITHMIC) + self.assertEqual(self.plot1.getXAxis().getScale(), self.plot1.getXAxis().LOGARITHMIC) + self.assertEqual(self.plot2.getXAxis().getScale(), self.plot1.getXAxis().LOGARITHMIC) + self.assertEqual(self.plot3.getXAxis().getScale(), self.plot1.getXAxis().LOGARITHMIC) + + def testDirection(self): + """Test direction change""" + _sync = SyncAxes([self.plot1.getYAxis(), self.plot2.getYAxis(), self.plot3.getYAxis()]) + self.plot1.getYAxis().setInverted(True) + self.assertEqual(self.plot1.getYAxis().isInverted(), True) + self.assertEqual(self.plot2.getYAxis().isInverted(), True) + self.assertEqual(self.plot3.getYAxis().isInverted(), True) + + def testSyncCenter(self): + """Test direction change""" + # Not the same scale + self.plot1.getXAxis().setLimits(0, 200) + self.plot2.getXAxis().setLimits(0, 20) + self.plot3.getXAxis().setLimits(0, 2) + _sync = SyncAxes([self.plot1.getXAxis(), self.plot2.getXAxis(), self.plot3.getXAxis()], + syncLimits=False, syncCenter=True) + + self.assertEqual(self.plot1.getXAxis().getLimits(), (0, 200)) + self.assertEqual(self.plot2.getXAxis().getLimits(), (100 - 10, 100 + 10)) + self.assertEqual(self.plot3.getXAxis().getLimits(), (100 - 1, 100 + 1)) + + def testSyncCenterAndZoom(self): + """Test direction change""" + # Not the same scale + self.plot1.getXAxis().setLimits(0, 200) + self.plot2.getXAxis().setLimits(0, 20) + self.plot3.getXAxis().setLimits(0, 2) + _sync = SyncAxes([self.plot1.getXAxis(), self.plot2.getXAxis(), self.plot3.getXAxis()], + syncLimits=False, syncCenter=True, syncZoom=True) + + # Supposing all the plots use the same size + self.assertEqual(self.plot1.getXAxis().getLimits(), (0, 200)) + self.assertEqual(self.plot2.getXAxis().getLimits(), (0, 200)) + self.assertEqual(self.plot3.getXAxis().getLimits(), (0, 200)) + + def testAddAxis(self): + """Test synchronization after construction""" + sync = SyncAxes([self.plot1.getXAxis(), self.plot2.getXAxis()]) + sync.addAxis(self.plot3.getXAxis()) + + self.plot1.getXAxis().setLimits(10, 500) + self.assertEqual(self.plot1.getXAxis().getLimits(), (10, 500)) + self.assertEqual(self.plot2.getXAxis().getLimits(), (10, 500)) + self.assertEqual(self.plot3.getXAxis().getLimits(), (10, 500)) + + def testRemoveAxis(self): + """Test synchronization after construction""" + sync = SyncAxes([self.plot1.getXAxis(), self.plot2.getXAxis(), self.plot3.getXAxis()]) + sync.removeAxis(self.plot3.getXAxis()) + + self.plot1.getXAxis().setLimits(10, 500) + self.assertEqual(self.plot1.getXAxis().getLimits(), (10, 500)) + self.assertEqual(self.plot2.getXAxis().getLimits(), (10, 500)) + self.assertNotEqual(self.plot3.getXAxis().getLimits(), (10, 500)) diff --git a/src/silx/gui/plot/test/utils.py b/src/silx/gui/plot/test/utils.py new file mode 100644 index 0000000..64fca56 --- /dev/null +++ b/src/silx/gui/plot/test/utils.py @@ -0,0 +1,93 @@ +# coding: utf-8 +# /*########################################################################## +# +# Copyright (c) 2016-2021 European Synchrotron Radiation Facility +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. +# +# ###########################################################################*/ +"""Basic tests for PlotWidget""" + +__authors__ = ["T. Vincent"] +__license__ = "MIT" +__date__ = "26/01/2018" + + +import logging +import pytest +import unittest + +from silx.gui.utils.testutils import TestCaseQt + +from silx.gui import qt +from silx.gui.plot import PlotWidget + + +logger = logging.getLogger(__name__) + + +@pytest.mark.usefixtures("test_options_class_attr") +class PlotWidgetTestCase(TestCaseQt): + """Base class for tests of PlotWidget, not a TestCase in itself. + + plot attribute is the PlotWidget created for the test. + """ + __screenshot_already_taken = False + backend = None + + def _createPlot(self): + return PlotWidget(backend=self.backend) + + def setUp(self): + super(PlotWidgetTestCase, self).setUp() + self.plot = self._createPlot() + self.plot.show() + self.plotAlive = True + self.qWaitForWindowExposed(self.plot) + TestCaseQt.mouseClick(self, self.plot, button=qt.Qt.LeftButton, pos=(0, 0)) + + def __onPlotDestroyed(self): + self.plotAlive = False + + def _waitForPlotClosed(self): + self.plot.setAttribute(qt.Qt.WA_DeleteOnClose) + self.plot.destroyed.connect(self.__onPlotDestroyed) + self.plot.close() + del self.plot + for _ in range(100): + if not self.plotAlive: + break + self.qWait(10) + else: + logger.error("Plot is still alive") + + def tearDown(self): + if not self._currentTestSucceeded(): + # MPL is the only widget which uses the real system mouse. + # In case of a the windows is outside of the screen, minimzed, + # overlapped by a system popup, the MPL widget will not receive the + # mouse event. + # Taking a screenshot help debuging this cases in the continuous + # integration environement. + if not PlotWidgetTestCase.__screenshot_already_taken: + PlotWidgetTestCase.__screenshot_already_taken = True + self.logScreenShot() + self.qapp.processEvents() + self._waitForPlotClosed() + super(PlotWidgetTestCase, self).tearDown() diff --git a/src/silx/gui/plot/tools/CurveLegendsWidget.py b/src/silx/gui/plot/tools/CurveLegendsWidget.py new file mode 100644 index 0000000..4a517dd --- /dev/null +++ b/src/silx/gui/plot/tools/CurveLegendsWidget.py @@ -0,0 +1,247 @@ +# coding: utf-8 +# /*########################################################################## +# +# Copyright (c) 2018-2020 European Synchrotron Radiation Facility +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. +# +# ###########################################################################*/ +"""This module provides a widget to display :class:`PlotWidget` curve legends. +""" + +from __future__ import division + +__authors__ = ["T. Vincent"] +__license__ = "MIT" +__date__ = "20/07/2018" + + +import logging +import weakref + + +from ... import qt +from ...widgets.FlowLayout import FlowLayout as _FlowLayout +from ..LegendSelector import LegendIcon as _LegendIcon +from .. import items + + +_logger = logging.getLogger(__name__) + + +class _LegendWidget(qt.QWidget): + """Widget displaying curve style and its legend + + :param QWidget parent: See :class:`QWidget` + :param ~silx.gui.plot.items.Curve curve: Associated curve + """ + + def __init__(self, parent, curve): + super(_LegendWidget, self).__init__(parent) + layout = qt.QHBoxLayout(self) + layout.setContentsMargins(10, 0, 10, 0) + + curve.sigItemChanged.connect(self._curveChanged) + + icon = _LegendIcon(curve=curve) + layout.addWidget(icon) + + label = qt.QLabel(curve.getName()) + label.setAlignment(qt.Qt.AlignLeft | qt.Qt.AlignVCenter) + layout.addWidget(label) + + self._update() + + def getCurve(self): + """Returns curve associated to this widget + + :rtype: Union[~silx.gui.plot.items.Curve,None] + """ + icon = self.findChild(_LegendIcon) + return icon.getCurve() + + def _update(self): + """Update widget according to current curve state. + """ + curve = self.getCurve() + if curve is None: + _logger.error('Curve no more exists') + self.setVisible(False) + return + + self.setEnabled(curve.isVisible()) + + label = self.findChild(qt.QLabel) + if curve.isHighlighted(): + label.setStyleSheet("border: 1px solid black") + else: + label.setStyleSheet("") + + def _curveChanged(self, event): + """Handle update of curve item + + :param event: Kind of change + """ + if event in (items.ItemChangedType.VISIBLE, + items.ItemChangedType.HIGHLIGHTED, + items.ItemChangedType.HIGHLIGHTED_STYLE): + self._update() + + +class CurveLegendsWidget(qt.QWidget): + """Widget displaying curves legends in a plot + + :param QWidget parent: See :class:`QWidget` + """ + + sigCurveClicked = qt.Signal(object) + """Signal emitted when the legend of a curve is clicked + + It provides the corresponding curve. + """ + + def __init__(self, parent=None): + super(CurveLegendsWidget, self).__init__(parent) + self._clicked = None + self._legends = {} + self._plotRef = None + + def layout(self): + layout = super(CurveLegendsWidget, self).layout() + if layout is None: + # Lazy layout initialization to allow overloading + layout = _FlowLayout() + layout.setHorizontalSpacing(0) + self.setLayout(layout) + return layout + + def getPlotWidget(self): + """Returns the associated :class:`PlotWidget` + + :rtype: Union[~silx.gui.plot.PlotWidget,None] + """ + return None if self._plotRef is None else self._plotRef() + + def setPlotWidget(self, plot): + """Set the associated :class:`PlotWidget` + + :param ~silx.gui.plot.PlotWidget plot: Plot widget to attach + """ + previousPlot = self.getPlotWidget() + if previousPlot is not None: + previousPlot.sigItemAdded.disconnect( self._itemAdded) + previousPlot.sigItemAboutToBeRemoved.disconnect(self._itemRemoved) + for legend in list(self._legends.keys()): + self._removeLegend(legend) + + self._plotRef = None if plot is None else weakref.ref(plot) + + if plot is not None: + plot.sigItemAdded.connect(self._itemAdded) + plot.sigItemAboutToBeRemoved.connect(self._itemRemoved) + + for legend in plot.getAllCurves(just_legend=True): + self._addLegend(legend) + + def curveAt(self, *args): + """Returns the curve object represented at the given position + + Either takes a QPoint or x and y as input in widget coordinates. + + :rtype: Union[~silx.gui.plot.items.Curve,None] + """ + if len(args) == 1: + point = args[0] + elif len(args) == 2: + point = qt.QPoint(*args) + else: + raise ValueError('Unsupported arguments') + assert isinstance(point, qt.QPoint) + + widget = self.childAt(point) + while widget not in (self, None): + if isinstance(widget, _LegendWidget): + return widget.getCurve() + widget = widget.parent() + return None # No widget or not in _LegendWidget + + def _itemAdded(self, item): + """Handle item added to the plot content""" + if isinstance(item, items.Curve): + self._addLegend(item.getName()) + + def _itemRemoved(self, item): + """Handle item removed from the plot content""" + if isinstance(item, items.Curve): + self._removeLegend(item.getName()) + + def _addLegend(self, legend): + """Add a curve to the legends + + :param str legend: Curve's legend + """ + if legend in self._legends: + return # Can happen when changing curve's y axis + + plot = self.getPlotWidget() + if plot is None: + return None + + curve = plot.getCurve(legend) + if curve is None: + _logger.error('Curve not found: %s' % legend) + return + + widget = _LegendWidget(parent=self, curve=curve) + self.layout().addWidget(widget) + self._legends[legend] = widget + + def _removeLegend(self, legend): + """Remove a curve from the legends if it exists + + :param str legend: The curve's legend + """ + widget = self._legends.pop(legend, None) + if widget is None: + _logger.warning('Unknown legend: %s' % legend) + else: + self.layout().removeWidget(widget) + widget.setParent(None) + + def mousePressEvent(self, event): + if event.button() == qt.Qt.LeftButton: + self._clicked = event.pos() + + _CLICK_THRESHOLD = 5 + """Threshold for clicks""" + + def mouseMoveEvent(self, event): + if self._clicked is not None: + dx = abs(self._clicked.x() - event.pos().x()) + dy = abs(self._clicked.y() - event.pos().y()) + if dx > self._CLICK_THRESHOLD or dy > self._CLICK_THRESHOLD: + self._clicked = None # Click is cancelled + + def mouseReleaseEvent(self, event): + if event.button() == qt.Qt.LeftButton and self._clicked is not None: + curve = self.curveAt(event.pos()) + if curve is not None: + self.sigCurveClicked.emit(curve) + + self._clicked = None diff --git a/src/silx/gui/plot/tools/LimitsToolBar.py b/src/silx/gui/plot/tools/LimitsToolBar.py new file mode 100644 index 0000000..fc192a6 --- /dev/null +++ b/src/silx/gui/plot/tools/LimitsToolBar.py @@ -0,0 +1,131 @@ +# coding: utf-8 +# /*########################################################################## +# +# Copyright (c) 2016-2018 European Synchrotron Radiation Facility +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. +# +# ###########################################################################*/ +"""A toolbar to display and edit limits of a PlotWidget +""" + + +from __future__ import division + +__authors__ = ["V.A. Sole", "T. Vincent"] +__license__ = "MIT" +__date__ = "16/10/2017" + + +from ... import qt +from ...widgets.FloatEdit import FloatEdit + + +class LimitsToolBar(qt.QToolBar): + """QToolBar displaying and controlling the limits of a :class:`PlotWidget`. + + To run the following sample code, a QApplication must be initialized. + First, create a PlotWindow: + + >>> from silx.gui.plot import PlotWindow + >>> plot = PlotWindow() # Create a PlotWindow to add the toolbar to + + Then, create the LimitsToolBar and add it to the PlotWindow. + + >>> from silx.gui import qt + >>> from silx.gui.plot.tools import LimitsToolBar + + >>> toolbar = LimitsToolBar(plot=plot) # Create the toolbar + >>> plot.addToolBar(qt.Qt.BottomToolBarArea, toolbar) # Add it to the plot + >>> plot.show() # To display the PlotWindow with the limits toolbar + + :param parent: See :class:`QToolBar`. + :param plot: :class:`PlotWidget` instance on which to operate. + :param str title: See :class:`QToolBar`. + """ + + def __init__(self, parent=None, plot=None, title='Limits'): + super(LimitsToolBar, self).__init__(title, parent) + assert plot is not None + self._plot = plot + self._plot.sigPlotSignal.connect(self._plotWidgetSlot) + + self._initWidgets() + + @property + def plot(self): + """The :class:`PlotWidget` the toolbar is attached to.""" + return self._plot + + def _initWidgets(self): + """Create and init Toolbar widgets.""" + xMin, xMax = self.plot.getXAxis().getLimits() + yMin, yMax = self.plot.getYAxis().getLimits() + + self.addWidget(qt.QLabel('Limits: ')) + self.addWidget(qt.QLabel(' X: ')) + self._xMinFloatEdit = FloatEdit(self, xMin) + self._xMinFloatEdit.editingFinished[()].connect( + self._xFloatEditChanged) + self.addWidget(self._xMinFloatEdit) + + self._xMaxFloatEdit = FloatEdit(self, xMax) + self._xMaxFloatEdit.editingFinished[()].connect( + self._xFloatEditChanged) + self.addWidget(self._xMaxFloatEdit) + + self.addWidget(qt.QLabel(' Y: ')) + self._yMinFloatEdit = FloatEdit(self, yMin) + self._yMinFloatEdit.editingFinished[()].connect( + self._yFloatEditChanged) + self.addWidget(self._yMinFloatEdit) + + self._yMaxFloatEdit = FloatEdit(self, yMax) + self._yMaxFloatEdit.editingFinished[()].connect( + self._yFloatEditChanged) + self.addWidget(self._yMaxFloatEdit) + + def _plotWidgetSlot(self, event): + """Listen to :class:`PlotWidget` events.""" + if event['event'] not in ('limitsChanged',): + return + + xMin, xMax = self.plot.getXAxis().getLimits() + yMin, yMax = self.plot.getYAxis().getLimits() + + self._xMinFloatEdit.setValue(xMin) + self._xMaxFloatEdit.setValue(xMax) + self._yMinFloatEdit.setValue(yMin) + self._yMaxFloatEdit.setValue(yMax) + + def _xFloatEditChanged(self): + """Handle X limits changed from the GUI.""" + xMin, xMax = self._xMinFloatEdit.value(), self._xMaxFloatEdit.value() + if xMax < xMin: + xMin, xMax = xMax, xMin + + self.plot.getXAxis().setLimits(xMin, xMax) + + def _yFloatEditChanged(self): + """Handle Y limits changed from the GUI.""" + yMin, yMax = self._yMinFloatEdit.value(), self._yMaxFloatEdit.value() + if yMax < yMin: + yMin, yMax = yMax, yMin + + self.plot.getYAxis().setLimits(yMin, yMax) diff --git a/src/silx/gui/plot/tools/PositionInfo.py b/src/silx/gui/plot/tools/PositionInfo.py new file mode 100644 index 0000000..8b95fbc --- /dev/null +++ b/src/silx/gui/plot/tools/PositionInfo.py @@ -0,0 +1,373 @@ +# coding: utf-8 +# /*########################################################################## +# +# Copyright (c) 2016-2021 European Synchrotron Radiation Facility +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. +# +# ###########################################################################*/ +"""This module provides a widget displaying mouse coordinates in a PlotWidget. + +It can be configured to provide more information. +""" + +from __future__ import division + +__authors__ = ["V.A. Sole", "T. Vincent"] +__license__ = "MIT" +__date__ = "16/10/2017" + + +import logging +import numbers +import traceback +import weakref + +import numpy + +from ....utils.deprecation import deprecated +from ... import qt +from .. import items +from ...widgets.ElidedLabel import ElidedLabel + + +_logger = logging.getLogger(__name__) + + +class _PositionInfoLabel(ElidedLabel): + """QLabel with a default size larger than what is displayed.""" + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.setTextInteractionFlags(qt.Qt.TextSelectableByMouse) + + def sizeHint(self): + hint = super().sizeHint() + width = self.fontMetrics().boundingRect('##############').width() + return qt.QSize(max(hint.width(), width), hint.height()) + + +# PositionInfo ################################################################ + +class PositionInfo(qt.QWidget): + """QWidget displaying coords converted from data coords of the mouse. + + Provide this widget with a list of couple: + + - A name to display before the data + - A function that takes (x, y) as arguments and returns something that + gets converted to a string. + If the result is a float it is converted with '%.7g' format. + + To run the following sample code, a QApplication must be initialized. + First, create a PlotWindow and add a QToolBar where to place the + PositionInfo widget. + + >>> from silx.gui.plot import PlotWindow + >>> from silx.gui import qt + + >>> plot = PlotWindow() # Create a PlotWindow to add the widget to + >>> toolBar = qt.QToolBar() # Create a toolbar to place the widget in + >>> plot.addToolBar(qt.Qt.BottomToolBarArea, toolBar) # Add it to plot + + Then, create the PositionInfo widget and add it to the toolbar. + The PositionInfo widget is created with a list of converters, here + to display polar coordinates of the mouse position. + + >>> import numpy + >>> from silx.gui.plot.tools import PositionInfo + + >>> position = PositionInfo(plot=plot, converters=[ + ... ('Radius', lambda x, y: numpy.sqrt(x*x + y*y)), + ... ('Angle', lambda x, y: numpy.degrees(numpy.arctan2(y, x)))]) + >>> toolBar.addWidget(position) # Add the widget to the toolbar + <...> + >>> plot.show() # To display the PlotWindow with the position widget + + :param plot: The PlotWidget this widget is displaying data coords from. + :param converters: + List of 2-tuple: name to display and conversion function from (x, y) + in data coords to displayed value. + If None, the default, it displays X and Y. + :param parent: Parent widget + """ + + SNAP_THRESHOLD_DIST = 5 + + def __init__(self, parent=None, plot=None, converters=None): + assert plot is not None + self._plotRef = weakref.ref(plot) + self._snappingMode = self.SNAPPING_DISABLED + + super(PositionInfo, self).__init__(parent) + + if converters is None: + converters = (('X', lambda x, y: x), ('Y', lambda x, y: y)) + + self._fields = [] # To store (QLineEdit, name, function (x, y)->v) + + # Create a new layout with new widgets + layout = qt.QHBoxLayout() + layout.setContentsMargins(0, 0, 0, 0) + # layout.setSpacing(0) + + # Create all QLabel and store them with the corresponding converter + for name, func in converters: + layout.addWidget(qt.QLabel('<b>' + name + ':</b>')) + + contentWidget = _PositionInfoLabel(self) + contentWidget.setText('------') + layout.addWidget(contentWidget) + self._fields.append((contentWidget, name, func)) + + layout.addStretch(1) + self.setLayout(layout) + + # Connect to Plot events + plot.sigPlotSignal.connect(self._plotEvent) + + def getPlotWidget(self): + """Returns the PlotWidget this widget is attached to or None. + + :rtype: Union[~silx.gui.plot.PlotWidget,None] + """ + return self._plotRef() + + @property + @deprecated(replacement='getPlotWidget', since_version='0.8.0') + def plot(self): + return self.getPlotWidget() + + def getConverters(self): + """Return the list of converters as 2-tuple (name, function).""" + return [(name, func) for _label, name, func in self._fields] + + def _plotEvent(self, event): + """Handle events from the Plot. + + :param dict event: Plot event + """ + if event['event'] == 'mouseMoved': + x, y = event['x'], event['y'] + xPixel, yPixel = event['xpixel'], event['ypixel'] + self._updateStatusBar(x, y, xPixel, yPixel) + + def updateInfo(self): + """Update displayed information""" + plot = self.getPlotWidget() + if plot is None: + _logger.error("Trying to update PositionInfo " + "while PlotWidget no longer exists") + return + + widget = plot.getWidgetHandle() + position = widget.mapFromGlobal(qt.QCursor.pos()) + xPixel, yPixel = position.x(), position.y() + dataPos = plot.pixelToData(xPixel, yPixel, check=True) + if dataPos is not None: # Inside plot area + x, y = dataPos + self._updateStatusBar(x, y, xPixel, yPixel) + + def _updateStatusBar(self, x, y, xPixel, yPixel): + """Update information from the status bar using the definitions. + + :param float x: Position-x in data + :param float y: Position-y in data + :param float xPixel: Position-x in pixels + :param float yPixel: Position-y in pixels + """ + plot = self.getPlotWidget() + if plot is None: + return + + styleSheet = "color: rgb(0, 0, 0);" # Default style + xData, yData = x, y + + snappingMode = self.getSnappingMode() + + # Snapping when crosshair either not requested or active + if (snappingMode & (self.SNAPPING_CURVE | self.SNAPPING_SCATTER) and + (not (snappingMode & self.SNAPPING_CROSSHAIR) or + plot.getGraphCursor())): + styleSheet = "color: rgb(255, 0, 0);" # Style far from item + + if snappingMode & self.SNAPPING_ACTIVE_ONLY: + selectedItems = [] + + if snappingMode & self.SNAPPING_CURVE: + activeCurve = plot.getActiveCurve() + if activeCurve: + selectedItems.append(activeCurve) + + if snappingMode & self.SNAPPING_SCATTER: + activeScatter = plot._getActiveItem(kind='scatter') + if activeScatter: + selectedItems.append(activeScatter) + + else: + kinds = [] + if snappingMode & self.SNAPPING_CURVE: + kinds.append(items.Curve) + kinds.append(items.Histogram) + if snappingMode & self.SNAPPING_SCATTER: + kinds.append(items.Scatter) + selectedItems = [item for item in plot.getItems() + if isinstance(item, tuple(kinds)) and item.isVisible()] + + # Compute distance threshold + window = plot.window() + windowHandle = window.windowHandle() + if windowHandle is not None: + ratio = windowHandle.devicePixelRatio() + else: + ratio = qt.QGuiApplication.primaryScreen().devicePixelRatio() + + # Baseline squared distance threshold + distInPixels = (self.SNAP_THRESHOLD_DIST * ratio)**2 + + for item in selectedItems: + if (snappingMode & self.SNAPPING_SYMBOLS_ONLY and ( + not isinstance(item, items.SymbolMixIn) or + not item.getSymbol())): + # Only handled if item symbols are visible + continue + + if isinstance(item, items.Histogram): + result = item.pick(xPixel, yPixel) + if result is not None: # Histogram picked + index = result.getIndices()[0] + edges = item.getBinEdgesData(copy=False) + + # Snap to bin center and value + xData = 0.5 * (edges[index] + edges[index + 1]) + yData = item.getValueData(copy=False)[index] + + # Update label style sheet + styleSheet = "color: rgb(0, 0, 0);" + break + + else: # Curve, Scatter + xArray = item.getXData(copy=False) + yArray = item.getYData(copy=False) + closestIndex = numpy.argmin( + pow(xArray - x, 2) + pow(yArray - y, 2)) + + xClosest = xArray[closestIndex] + yClosest = yArray[closestIndex] + + if isinstance(item, items.YAxisMixIn): + axis = item.getYAxis() + else: + axis = 'left' + + closestInPixels = plot.dataToPixel( + xClosest, yClosest, axis=axis) + if closestInPixels is not None: + curveDistInPixels = ( + (closestInPixels[0] - xPixel)**2 + + (closestInPixels[1] - yPixel)**2) + + if curveDistInPixels <= distInPixels: + # Update label style sheet + styleSheet = "color: rgb(0, 0, 0);" + + # if close enough, snap to data point coord + xData, yData = xClosest, yClosest + distInPixels = curveDistInPixels + + for label, name, func in self._fields: + label.setStyleSheet(styleSheet) + + try: + value = func(xData, yData) + text = self.valueToString(value) + label.setText(text) + except: + label.setText('Error') + _logger.error( + "Error while converting coordinates (%f, %f)" + "with converter '%s'" % (xPixel, yPixel, name)) + _logger.error(traceback.format_exc()) + + def valueToString(self, value): + if isinstance(value, (tuple, list)): + value = [self.valueToString(v) for v in value] + return ", ".join(value) + elif isinstance(value, numbers.Real): + # Use this for floats and int + return '%.7g' % value + else: + # Fallback for other types + return str(value) + + # Snapping mode + + SNAPPING_DISABLED = 0 + """No snapping occurs""" + + SNAPPING_CROSSHAIR = 1 << 0 + """Snapping only enabled when crosshair cursor is enabled""" + + SNAPPING_ACTIVE_ONLY = 1 << 1 + """Snapping only enabled for active item""" + + SNAPPING_SYMBOLS_ONLY = 1 << 2 + """Snapping only when symbols are visible""" + + SNAPPING_CURVE = 1 << 3 + """Snapping on curves""" + + SNAPPING_SCATTER = 1 << 4 + """Snapping on scatter""" + + def setSnappingMode(self, mode): + """Set the snapping mode. + + The mode is a mask. + + :param int mode: The mode to use + """ + if mode != self._snappingMode: + self._snappingMode = mode + self.updateInfo() + + def getSnappingMode(self): + """Returns the snapping mode as a mask + + :rtype: int + """ + return self._snappingMode + + _SNAPPING_LEGACY = (SNAPPING_CROSSHAIR | + SNAPPING_ACTIVE_ONLY | + SNAPPING_SYMBOLS_ONLY | + SNAPPING_CURVE | + SNAPPING_SCATTER) + """Legacy snapping mode""" + + @property + @deprecated(replacement="getSnappingMode", since_version="0.8") + def autoSnapToActiveCurve(self): + return self.getSnappingMode() == self._SNAPPING_LEGACY + + @autoSnapToActiveCurve.setter + @deprecated(replacement="setSnappingMode", since_version="0.8") + def autoSnapToActiveCurve(self, flag): + self.setSnappingMode( + self._SNAPPING_LEGACY if flag else self.SNAPPING_DISABLED) diff --git a/src/silx/gui/plot/tools/RadarView.py b/src/silx/gui/plot/tools/RadarView.py new file mode 100644 index 0000000..7076835 --- /dev/null +++ b/src/silx/gui/plot/tools/RadarView.py @@ -0,0 +1,361 @@ +# coding: utf-8 +# /*########################################################################## +# +# Copyright (c) 2015-2018 European Synchrotron Radiation Facility +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. +# +# ###########################################################################*/ +"""QWidget displaying an overview of a 2D plot. + +This shows the available range of the data, and the current location of the +plot view. +""" + +__authors__ = ["T. Vincent"] +__license__ = "MIT" +__date__ = "22/02/2021" + +import logging +import weakref +from ... import qt +from ...utils import LockReentrant + +_logger = logging.getLogger(__name__) + + +class _DraggableRectItem(qt.QGraphicsRectItem): + """RectItem which signals its change through visibleRectDragged.""" + def __init__(self, *args, **kwargs): + super(_DraggableRectItem, self).__init__( + *args, **kwargs) + + self._previousCursor = None + self.setFlag(qt.QGraphicsItem.ItemIsMovable) + self.setFlag(qt.QGraphicsItem.ItemSendsGeometryChanges) + self.setAcceptHoverEvents(True) + self._ignoreChange = False + self._constraint = 0, 0, 0, 0 + + def setConstraintRect(self, left, top, width, height): + """Set the constraint rectangle for dragging. + + The coordinates are in the _DraggableRectItem coordinate system. + + This constraint only applies to modification through interaction + (i.e., this constraint is not applied to change through API). + + If the _DraggableRectItem is smaller than the constraint rectangle, + the _DraggableRectItem remains within the constraint rectangle. + If the _DraggableRectItem is wider than the constraint rectangle, + the constraint rectangle remains within the _DraggableRectItem. + """ + self._constraint = left, left + width, top, top + height + + def setPos(self, *args, **kwargs): + """Overridden to ignore changes from API in itemChange.""" + self._ignoreChange = True + super(_DraggableRectItem, self).setPos(*args, **kwargs) + self._ignoreChange = False + + def moveBy(self, *args, **kwargs): + """Overridden to ignore changes from API in itemChange.""" + self._ignoreChange = True + super(_DraggableRectItem, self).moveBy(*args, **kwargs) + self._ignoreChange = False + + def itemChange(self, change, value): + """Callback called before applying changes to the item.""" + if (change == qt.QGraphicsItem.ItemPositionChange and + not self._ignoreChange): + # Makes sure that the visible area is in the data + # or that data is in the visible area if area is too wide + x, y = value.x(), value.y() + xMin, xMax, yMin, yMax = self._constraint + + if self.rect().width() <= (xMax - xMin): + if x < xMin: + value.setX(xMin) + elif x > xMax - self.rect().width(): + value.setX(xMax - self.rect().width()) + else: + if x > xMin: + value.setX(xMin) + elif x < xMax - self.rect().width(): + value.setX(xMax - self.rect().width()) + + if self.rect().height() <= (yMax - yMin): + if y < yMin: + value.setY(yMin) + elif y > yMax - self.rect().height(): + value.setY(yMax - self.rect().height()) + else: + if y > yMin: + value.setY(yMin) + elif y < yMax - self.rect().height(): + value.setY(yMax - self.rect().height()) + + if self.pos() != value: + # Notify change through signal + views = self.scene().views() + assert len(views) == 1 + views[0].visibleRectDragged.emit( + value.x() + self.rect().left(), + value.y() + self.rect().top(), + self.rect().width(), + self.rect().height()) + + return value + + return super(_DraggableRectItem, self).itemChange( + change, value) + + def hoverEnterEvent(self, event): + """Called when the mouse enters the rectangle area""" + self._previousCursor = self.cursor() + self.setCursor(qt.Qt.OpenHandCursor) + + def hoverLeaveEvent(self, event): + """Called when the mouse leaves the rectangle area""" + if self._previousCursor is not None: + self.setCursor(self._previousCursor) + self._previousCursor = None + + +class RadarView(qt.QGraphicsView): + """Widget presenting a synthetic view of a 2D area and + the current visible area. + + Coordinates are as in QGraphicsView: + x goes from left to right and y goes from top to bottom. + This widget preserves the aspect ratio of the areas. + + The 2D area and the visible area can be set with :meth:`setDataRect` + and :meth:`setVisibleRect`. + When the visible area has been dragged by the user, its new position + is signaled by the *visibleRectDragged* signal. + + It is possible to invert the direction of the axes by using the + :meth:`scale` method of QGraphicsView. + """ + + visibleRectDragged = qt.Signal(float, float, float, float) + """Signals that the visible rectangle has been dragged. + + It provides: left, top, width, height in data coordinates. + """ + + _DATA_PEN = qt.QPen(qt.QColor('white')) + _DATA_BRUSH = qt.QBrush(qt.QColor('light gray')) + _ACTIVEDATA_PEN = qt.QPen(qt.QColor('black')) + _ACTIVEDATA_BRUSH = qt.QBrush(qt.QColor('transparent')) + _ACTIVEDATA_PEN.setWidth(2) + _ACTIVEDATA_PEN.setCosmetic(True) + _VISIBLE_PEN = qt.QPen(qt.QColor('blue')) + _VISIBLE_PEN.setWidth(2) + _VISIBLE_PEN.setCosmetic(True) + _VISIBLE_BRUSH = qt.QBrush(qt.QColor(0, 0, 0, 0)) + _TOOLTIP = 'Radar View:\nRed contour: Visible area\nGray area: The image' + + _PIXMAP_SIZE = 256 + + def __init__(self, parent=None): + self.__plotRef = None + self._scene = qt.QGraphicsScene() + self._dataRect = self._scene.addRect(0, 0, 1, 1, + self._DATA_PEN, + self._DATA_BRUSH) + self._imageRect = self._scene.addRect(0, 0, 1, 1, + self._ACTIVEDATA_PEN, + self._ACTIVEDATA_BRUSH) + self._imageRect.setVisible(False) + self._scatterRect = self._scene.addRect(0, 0, 1, 1, + self._ACTIVEDATA_PEN, + self._ACTIVEDATA_BRUSH) + self._scatterRect.setVisible(False) + self._curveRect = self._scene.addRect(0, 0, 1, 1, + self._ACTIVEDATA_PEN, + self._ACTIVEDATA_BRUSH) + self._curveRect.setVisible(False) + + self._visibleRect = _DraggableRectItem(0, 0, 1, 1) + self._visibleRect.setPen(self._VISIBLE_PEN) + self._visibleRect.setBrush(self._VISIBLE_BRUSH) + self._scene.addItem(self._visibleRect) + + super(RadarView, self).__init__(self._scene, parent) + self.setHorizontalScrollBarPolicy(qt.Qt.ScrollBarAlwaysOff) + self.setVerticalScrollBarPolicy(qt.Qt.ScrollBarAlwaysOff) + self.setFocusPolicy(qt.Qt.NoFocus) + self.setStyleSheet('border: 0px') + self.setToolTip(self._TOOLTIP) + + self.__reentrant = LockReentrant() + self.visibleRectDragged.connect(self._viewRectDragged) + + self.__timer = qt.QTimer(self) + self.__timer.timeout.connect(self._updateDataContent) + + def sizeHint(self): + # """Overridden to avoid sizeHint to depend on content size.""" + return self.minimumSizeHint() + + def wheelEvent(self, event): + # """Overridden to disable vertical scrolling with wheel.""" + event.ignore() + + def resizeEvent(self, event): + # """Overridden to fit current content to new size.""" + self.fitInView(self._scene.itemsBoundingRect(), qt.Qt.KeepAspectRatio) + super(RadarView, self).resizeEvent(event) + + def setDataRect(self, left, top, width, height): + """Set the bounds of the data rectangular area. + + This sets the coordinate system. + """ + self._dataRect.setRect(left, top, width, height) + self._visibleRect.setConstraintRect(left, top, width, height) + self.fitInView(self._scene.itemsBoundingRect(), qt.Qt.KeepAspectRatio) + + def setVisibleRect(self, left, top, width, height): + """Set the visible rectangular area. + + The coordinates are relative to the data rect. + """ + self.__visibleRect = left, top, width, height + self._visibleRect.setRect(0, 0, width, height) + self._visibleRect.setPos(left, top) + self.fitInView(self._scene.itemsBoundingRect(), qt.Qt.KeepAspectRatio) + + def __setVisibleRectFromPlot(self, plot): + """Update radar view visible area. + + Takes care of y coordinate conversion. + """ + xMin, xMax = plot.getXAxis().getLimits() + yMin, yMax = plot.getYAxis().getLimits() + self.setVisibleRect(xMin, yMin, xMax - xMin, yMax - yMin) + + def getPlotWidget(self): + """Returns the connected plot + + :rtype: Union[None,PlotWidget] + """ + if self.__plotRef is None: + return None + plot = self.__plotRef() + if plot is None: + self.__plotRef = None + return plot + + def setPlotWidget(self, plot): + """Set the PlotWidget this radar view connects to. + + As result `setDataRect` and `setVisibleRect` will be called + automatically. + + :param Union[None,PlotWidget] plot: + """ + previousPlot = self.getPlotWidget() + if previousPlot is not None: # Disconnect previous plot + plot.getXAxis().sigLimitsChanged.disconnect(self._xLimitChanged) + plot.getYAxis().sigLimitsChanged.disconnect(self._yLimitChanged) + plot.getYAxis().sigInvertedChanged.disconnect(self._updateYAxisInverted) + + # Reset plot and timer + # FIXME: It would be good to clean up the display here + self.__plotRef = None + self.__timer.stop() + + if plot is not None: # Connect new plot + self.__plotRef = weakref.ref(plot) + plot.getXAxis().sigLimitsChanged.connect(self._xLimitChanged) + plot.getYAxis().sigLimitsChanged.connect(self._yLimitChanged) + plot.getYAxis().sigInvertedChanged.connect(self._updateYAxisInverted) + self.__setVisibleRectFromPlot(plot) + self._updateYAxisInverted() + self.__timer.start(500) + + def _xLimitChanged(self, vmin, vmax): + plot = self.getPlotWidget() + self.__setVisibleRectFromPlot(plot) + + def _yLimitChanged(self, vmin, vmax): + plot = self.getPlotWidget() + self.__setVisibleRectFromPlot(plot) + + def _updateYAxisInverted(self, inverted=None): + """Sync radar view axis orientation.""" + plot = self.getPlotWidget() + if inverted is None: + # Do not perform this when called from plot signal + inverted = plot.getYAxis().isInverted() + # Use scale to invert radarView + # RadarView default Y direction is from top to bottom + # As opposed to Plot. So invert RadarView when Plot is NOT inverted. + self.resetTransform() + if not inverted: + self.scale(1., -1.) + self.update() + + def _viewRectDragged(self, left, top, width, height): + """Slot for radar view visible rectangle changes.""" + plot = self.getPlotWidget() + if plot is None: + return + + if self.__reentrant.locked(): + return + + with self.__reentrant: + plot.setLimits(left, left + width, top, top + height) + + def _updateDataContent(self): + """Update the content to the current data content""" + plot = self.getPlotWidget() + if plot is None: + return + ranges = plot.getDataRange() + xmin, xmax = ranges.x if ranges.x is not None else (0, 0) + ymin, ymax = ranges.y if ranges.y is not None else (0, 0) + self.setDataRect(xmin, ymin, xmax - xmin, ymax - ymin) + + self.__updateItem(self._imageRect, plot.getActiveImage()) + self.__updateItem(self._scatterRect, plot.getActiveScatter()) + self.__updateItem(self._curveRect, plot.getActiveCurve()) + + def __updateItem(self, rect, item): + """Sync rect with item bounds + + :param QGraphicsRectItem rect: + :param Item item: + """ + if item is None: + rect.setVisible(False) + return + ranges = item._getBounds() + if ranges is None: + rect.setVisible(False) + return + xmin, xmax, ymin, ymax = ranges + width = xmax - xmin + height = ymax - ymin + rect.setRect(xmin, ymin, width, height) + rect.setVisible(True) diff --git a/src/silx/gui/plot/tools/__init__.py b/src/silx/gui/plot/tools/__init__.py new file mode 100644 index 0000000..09f468c --- /dev/null +++ b/src/silx/gui/plot/tools/__init__.py @@ -0,0 +1,50 @@ +# coding: utf-8 +# /*########################################################################## +# +# Copyright (c) 2018 European Synchrotron Radiation Facility +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. +# +# ###########################################################################*/ +"""This package provides a set of widgets working with :class:`PlotWidget`. + +It provides some QToolBar and QWidget: + +- :class:`InteractiveModeToolBar` +- :class:`OutputToolBar` +- :class:`ImageToolBar` +- :class:`CurveToolBar` +- :class:`LimitsToolBar` +- :class:`PositionInfo` + +It also provides a :mod:`~silx.gui.plot.tools.roi` module to handle +interactive region of interest on a :class:`~silx.gui.plot.PlotWidget`. +""" + +__authors__ = ["T. Vincent"] +__license__ = "MIT" +__date__ = "01/03/2018" + + +from .toolbars import InteractiveModeToolBar # noqa +from .toolbars import OutputToolBar # noqa +from .toolbars import ImageToolBar, CurveToolBar, ScatterToolBar # noqa + +from .LimitsToolBar import LimitsToolBar # noqa +from .PositionInfo import PositionInfo # noqa diff --git a/src/silx/gui/plot/tools/profile/ScatterProfileToolBar.py b/src/silx/gui/plot/tools/profile/ScatterProfileToolBar.py new file mode 100644 index 0000000..44187ef --- /dev/null +++ b/src/silx/gui/plot/tools/profile/ScatterProfileToolBar.py @@ -0,0 +1,54 @@ +# coding: utf-8 +# /*########################################################################## +# +# Copyright (c) 2018-2019 European Synchrotron Radiation Facility +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. +# +# ###########################################################################*/ +"""This module profile tools for scatter plots. +""" + +__authors__ = ["T. Vincent"] +__license__ = "MIT" +__date__ = "28/06/2018" + + +from silx.utils import deprecation +from . import toolbar + + +class ScatterProfileToolBar(toolbar.ProfileToolBar): + """QToolBar providing scatter plot profiling tools + + :param parent: See :class:`QToolBar`. + :param plot: :class:`~silx.gui.plot.PlotWidget` on which to operate. + :param str title: See :class:`QToolBar`. + """ + + def __init__(self, parent=None, plot=None, title=None): + super(ScatterProfileToolBar, self).__init__(parent, plot) + if title is not None: + deprecation.deprecated_warning("Attribute", + name="title", + reason="removed", + since_version="0.13.0", + only_once=True, + skip_backtrace_count=1) + self.setScheme("scatter") diff --git a/src/silx/gui/plot/tools/profile/__init__.py b/src/silx/gui/plot/tools/profile/__init__.py new file mode 100644 index 0000000..d91191e --- /dev/null +++ b/src/silx/gui/plot/tools/profile/__init__.py @@ -0,0 +1,38 @@ +# coding: utf-8 +# /*########################################################################## +# +# Copyright (c) 2018 European Synchrotron Radiation Facility +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. +# +# ###########################################################################*/ +"""This module provides tools to get profiles on plot data. + +It provides: + +- :class:`ScatterProfileToolBar`: a QToolBar to handle profile on scatter data + +""" + +__authors__ = ["T. Vincent"] +__license__ = "MIT" +__date__ = "07/06/2018" + + +from .ScatterProfileToolBar import ScatterProfileToolBar # noqa diff --git a/src/silx/gui/plot/tools/profile/core.py b/src/silx/gui/plot/tools/profile/core.py new file mode 100644 index 0000000..200f5cf --- /dev/null +++ b/src/silx/gui/plot/tools/profile/core.py @@ -0,0 +1,525 @@ +# coding: utf-8 +# /*########################################################################## +# +# Copyright (c) 2018-2020 European Synchrotron Radiation Facility +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. +# +# ###########################################################################*/ +"""This module define core objects for profile tools. +""" + +__authors__ = ["V.A. Sole", "T. Vincent", "P. Knobel", "H. Payno", "V. Valls"] +__license__ = "MIT" +__date__ = "17/04/2020" + +import collections +import numpy +import weakref + +from silx.image.bilinear import BilinearImage +from silx.gui import qt + + +CurveProfileData = collections.namedtuple( + 'CurveProfileData', [ + "coords", + "profile", + "title", + "xLabel", + "yLabel", + ]) + +RgbaProfileData = collections.namedtuple( + 'RgbaProfileData', [ + "coords", + "profile", + "profile_r", + "profile_g", + "profile_b", + "profile_a", + "title", + "xLabel", + "yLabel", + ]) + +ImageProfileData = collections.namedtuple( + 'ImageProfileData', [ + 'coords', + 'profile', + 'title', + 'xLabel', + 'yLabel', + 'colormap', + ]) + + +class ProfileRoiMixIn: + """Base mix-in for ROI which can be used to select a profile. + + This mix-in have to be applied to a :class:`~silx.gui.plot.items.roi.RegionOfInterest` + in order to be usable by a :class:`~silx.gui.plot.tools.profile.manager.ProfileManager`. + """ + + ITEM_KIND = None + """Define the plot item which can be used with this profile ROI""" + + sigProfilePropertyChanged = qt.Signal() + """Emitted when a property of this profile have changed""" + + sigPlotItemChanged = qt.Signal() + """Emitted when the plot item linked to this profile have changed""" + + def __init__(self, parent=None): + self.__profileWindow = None + self.__profileManager = None + self.__plotItem = None + self.setName("Profile") + self.setEditable(True) + self.setSelectable(True) + + def invalidateProfile(self): + """Must be called by the implementation when the profile have to be + recomputed.""" + profileManager = self.getProfileManager() + if profileManager is not None: + profileManager.requestUpdateProfile(self) + + def invalidateProperties(self): + """Must be called when a property of the profile have changed.""" + self.sigProfilePropertyChanged.emit() + + def _setPlotItem(self, plotItem): + """Specify the plot item to use with this profile + + :param `~silx.gui.plot.items.item.Item` plotItem: A plot item + """ + previousPlotItem = self.getPlotItem() + if previousPlotItem is plotItem: + return + self.__plotItem = weakref.ref(plotItem) + self.sigPlotItemChanged.emit() + + def getPlotItem(self): + """Returns the plot item used by this profile + + :rtype: `~silx.gui.plot.items.item.Item` + """ + if self.__plotItem is None: + return None + plotItem = self.__plotItem() + if plotItem is None: + self.__plotItem = None + return plotItem + + def _setProfileManager(self, profileManager): + self.__profileManager = profileManager + + def getProfileManager(self): + """ + Returns the profile manager connected to this ROI. + + :rtype: ~silx.gui.plot.tools.profile.manager.ProfileManager + """ + return self.__profileManager + + def getProfileWindow(self): + """ + Returns the windows associated to this ROI, else None. + + :rtype: ProfileWindow + """ + return self.__profileWindow + + def setProfileWindow(self, profileWindow): + """ + Associate a window to this ROI. Can be None. + + :param ProfileWindow profileWindow: A main window + to display the profile. + """ + if profileWindow is self.__profileWindow: + return + if self.__profileWindow is not None: + self.__profileWindow.sigClose.disconnect(self.__profileWindowAboutToClose) + self.__profileWindow.setRoiProfile(None) + self.__profileWindow = profileWindow + if self.__profileWindow is not None: + self.__profileWindow.sigClose.connect(self.__profileWindowAboutToClose) + self.__profileWindow.setRoiProfile(self) + + def __profileWindowAboutToClose(self): + profileManager = self.getProfileManager() + roiManager = profileManager.getRoiManager() + try: + roiManager.removeRoi(self) + except ValueError: + pass + + def computeProfile(self, item): + """ + Compute the profile which will be displayed. + + This method is not called from the main Qt thread, but from a thread + pool. + + :param ~silx.gui.plot.items.Item item: A plot item + :rtype: Union[CurveProfileData,ImageProfileData] + """ + raise NotImplementedError() + + +def _alignedFullProfile(data, origin, scale, position, roiWidth, axis, method): + """Get a profile along one axis on a stack of images + + :param numpy.ndarray data: 3D volume (stack of 2D images) + The first dimension is the image index. + :param origin: Origin of image in plot (ox, oy) + :param scale: Scale of image in plot (sx, sy) + :param float position: Position of profile line in plot coords + on the axis orthogonal to the profile direction. + :param int roiWidth: Width of the profile in image pixels. + :param int axis: 0 for horizontal profile, 1 for vertical. + :param str method: method to compute the profile. Can be 'mean' or 'sum' or + 'none' + :return: profile image + effective ROI area corners in plot coords + """ + assert axis in (0, 1) + assert len(data.shape) == 3 + assert method in ('mean', 'sum', 'none') + + # Convert from plot to image coords + imgPos = int((position - origin[1 - axis]) / scale[1 - axis]) + + if axis == 1: # Vertical profile + # Transpose image to always do a horizontal profile + data = numpy.transpose(data, (0, 2, 1)) + + nimages, height, width = data.shape + + roiWidth = min(height, roiWidth) # Clip roi width to image size + + # Get [start, end[ coords of the roi in the data + start = int(int(imgPos) + 0.5 - roiWidth / 2.) + start = min(max(0, start), height - roiWidth) + end = start + roiWidth + + if method == 'none': + profile = None + else: + if start < height and end > 0: + if method == 'mean': + fct = numpy.mean + elif method == 'sum': + fct = numpy.sum + else: + raise ValueError('method not managed') + profile = fct(data[:, max(0, start):min(end, height), :], axis=1).astype(numpy.float32) + else: + profile = numpy.zeros((nimages, width), dtype=numpy.float32) + + # Compute effective ROI in plot coords + profileBounds = numpy.array( + (0, width, width, 0), + dtype=numpy.float32) * scale[axis] + origin[axis] + roiBounds = numpy.array( + (start, start, end, end), + dtype=numpy.float32) * scale[1 - axis] + origin[1 - axis] + + if axis == 0: # Horizontal profile + area = profileBounds, roiBounds + else: # vertical profile + area = roiBounds, profileBounds + + return profile, area + + +def _alignedPartialProfile(data, rowRange, colRange, axis, method): + """Mean of a rectangular region (ROI) of a stack of images + along a given axis. + + Returned values and all parameters are in image coordinates. + + :param numpy.ndarray data: 3D volume (stack of 2D images) + The first dimension is the image index. + :param rowRange: [min, max[ of ROI rows (upper bound excluded). + :type rowRange: 2-tuple of int (min, max) with min < max + :param colRange: [min, max[ of ROI columns (upper bound excluded). + :type colRange: 2-tuple of int (min, max) with min < max + :param int axis: The axis along which to take the profile of the ROI. + 0: Sum rows along columns. + 1: Sum columns along rows. + :param str method: method to compute the profile. Can be 'mean' or 'sum' + :return: Profile image along the ROI as the mean of the intersection + of the ROI and the image. + """ + assert axis in (0, 1) + assert len(data.shape) == 3 + assert rowRange[0] < rowRange[1] + assert colRange[0] < colRange[1] + assert method in ('mean', 'sum') + + nimages, height, width = data.shape + + # Range aligned with the integration direction + profileRange = colRange if axis == 0 else rowRange + + profileLength = abs(profileRange[1] - profileRange[0]) + + # Subset of the image to use as intersection of ROI and image + rowStart = min(max(0, rowRange[0]), height) + rowEnd = min(max(0, rowRange[1]), height) + colStart = min(max(0, colRange[0]), width) + colEnd = min(max(0, colRange[1]), width) + + if method == 'mean': + _fct = numpy.mean + elif method == 'sum': + _fct = numpy.sum + else: + raise ValueError('method not managed') + + imgProfile = _fct(data[:, rowStart:rowEnd, colStart:colEnd], axis=axis + 1, + dtype=numpy.float32) + + # Profile including out of bound area + profile = numpy.zeros((nimages, profileLength), dtype=numpy.float32) + + # Place imgProfile in full profile + offset = - min(0, profileRange[0]) + profile[:, offset:offset + imgProfile.shape[1]] = imgProfile + + return profile + + +def createProfile(roiInfo, currentData, origin, scale, lineWidth, method): + """Create the profile line for the the given image. + + :param roiInfo: information about the ROI: start point, end point and + type ("X", "Y", "D") + :param numpy.ndarray currentData: the 2D image or the 3D stack of images + on which we compute the profile. + :param origin: (ox, oy) the offset from origin + :type origin: 2-tuple of float + :param scale: (sx, sy) the scale to use + :type scale: 2-tuple of float + :param int lineWidth: width of the profile line + :param str method: method to compute the profile. Can be 'mean' or 'sum' + or 'none': to compute everything except the profile + :return: `coords, profile, area, profileName, xLabel`, where: + - coords is the X coordinate to use to display the profile + - profile is a 2D array of the profiles of the stack of images. + For a single image, the profile is a curve, so this parameter + has a shape *(1, len(curve))* + - area is a tuple of two 1D arrays with 4 values each. They represent + the effective ROI area corners in plot coords. + - profileName is a string describing the ROI, meant to be used as + title of the profile plot + - xLabel the label for X in the profile window + + :rtype: tuple(ndarray,ndarray,(ndarray,ndarray),str) + """ + if currentData is None or roiInfo is None or lineWidth is None: + raise ValueError("createProfile called with invalide arguments") + + # force 3D data (stack of images) + if len(currentData.shape) == 2: + currentData3D = currentData.reshape((1,) + currentData.shape) + elif len(currentData.shape) == 3: + currentData3D = currentData + + roiWidth = max(1, lineWidth) + roiStart, roiEnd, lineProjectionMode = roiInfo + + if lineProjectionMode == 'X': # Horizontal profile on the whole image + profile, area = _alignedFullProfile(currentData3D, + origin, scale, + roiStart[1], roiWidth, + axis=0, + method=method) + + if method == 'none': + coords = None + else: + coords = numpy.arange(len(profile[0]), dtype=numpy.float32) + coords = coords * scale[0] + origin[0] + + yMin, yMax = min(area[1]), max(area[1]) - 1 + if roiWidth <= 1: + profileName = '{ylabel} = %g' % yMin + else: + profileName = '{ylabel} = [%g, %g]' % (yMin, yMax) + xLabel = '{xlabel}' + + elif lineProjectionMode == 'Y': # Vertical profile on the whole image + profile, area = _alignedFullProfile(currentData3D, + origin, scale, + roiStart[0], roiWidth, + axis=1, + method=method) + + if method == 'none': + coords = None + else: + coords = numpy.arange(len(profile[0]), dtype=numpy.float32) + coords = coords * scale[1] + origin[1] + + xMin, xMax = min(area[0]), max(area[0]) - 1 + if roiWidth <= 1: + profileName = '{xlabel} = %g' % xMin + else: + profileName = '{xlabel} = [%g, %g]' % (xMin, xMax) + xLabel = '{ylabel}' + + else: # Free line profile + + # Convert start and end points in image coords as (row, col) + startPt = ((roiStart[1] - origin[1]) / scale[1], + (roiStart[0] - origin[0]) / scale[0]) + endPt = ((roiEnd[1] - origin[1]) / scale[1], + (roiEnd[0] - origin[0]) / scale[0]) + + if (int(startPt[0]) == int(endPt[0]) or + int(startPt[1]) == int(endPt[1])): + # Profile is aligned with one of the axes + + # Convert to int + startPt = int(startPt[0]), int(startPt[1]) + endPt = int(endPt[0]), int(endPt[1]) + + # Ensure startPt <= endPt + if startPt[0] > endPt[0] or startPt[1] > endPt[1]: + startPt, endPt = endPt, startPt + + if startPt[0] == endPt[0]: # Row aligned + rowRange = (int(startPt[0] + 0.5 - 0.5 * roiWidth), + int(startPt[0] + 0.5 + 0.5 * roiWidth)) + colRange = startPt[1], endPt[1] + 1 + if method == 'none': + profile = None + else: + profile = _alignedPartialProfile(currentData3D, + rowRange, colRange, + axis=0, + method=method) + + else: # Column aligned + rowRange = startPt[0], endPt[0] + 1 + colRange = (int(startPt[1] + 0.5 - 0.5 * roiWidth), + int(startPt[1] + 0.5 + 0.5 * roiWidth)) + if method == 'none': + profile = None + else: + profile = _alignedPartialProfile(currentData3D, + rowRange, colRange, + axis=1, + method=method) + # Convert ranges to plot coords to draw ROI area + area = ( + numpy.array( + (colRange[0], colRange[1], colRange[1], colRange[0]), + dtype=numpy.float32) * scale[0] + origin[0], + numpy.array( + (rowRange[0], rowRange[0], rowRange[1], rowRange[1]), + dtype=numpy.float32) * scale[1] + origin[1]) + + else: # General case: use bilinear interpolation + + # Ensure startPt <= endPt + if (startPt[1] > endPt[1] or ( + startPt[1] == endPt[1] and startPt[0] > endPt[0])): + startPt, endPt = endPt, startPt + + if method == 'none': + profile = None + else: + profile = [] + for slice_idx in range(currentData3D.shape[0]): + bilinear = BilinearImage(currentData3D[slice_idx, :, :]) + + profile.append(bilinear.profile_line( + (startPt[0] - 0.5, startPt[1] - 0.5), + (endPt[0] - 0.5, endPt[1] - 0.5), + roiWidth, + method=method)) + profile = numpy.array(profile) + + # Extend ROI with half a pixel on each end, and + # Convert back to plot coords (x, y) + length = numpy.sqrt((endPt[0] - startPt[0]) ** 2 + + (endPt[1] - startPt[1]) ** 2) + dRow = (endPt[0] - startPt[0]) / length + dCol = (endPt[1] - startPt[1]) / length + + # Extend ROI with half a pixel on each end + roiStartPt = startPt[0] - 0.5 * dRow, startPt[1] - 0.5 * dCol + roiEndPt = endPt[0] + 0.5 * dRow, endPt[1] + 0.5 * dCol + + # Rotate deltas by 90 degrees to apply line width + dRow, dCol = dCol, -dRow + + area = ( + numpy.array((roiStartPt[1] - 0.5 * roiWidth * dCol, + roiStartPt[1] + 0.5 * roiWidth * dCol, + roiEndPt[1] + 0.5 * roiWidth * dCol, + roiEndPt[1] - 0.5 * roiWidth * dCol), + dtype=numpy.float32) * scale[0] + origin[0], + numpy.array((roiStartPt[0] - 0.5 * roiWidth * dRow, + roiStartPt[0] + 0.5 * roiWidth * dRow, + roiEndPt[0] + 0.5 * roiWidth * dRow, + roiEndPt[0] - 0.5 * roiWidth * dRow), + dtype=numpy.float32) * scale[1] + origin[1]) + + # Convert start and end points back to plot coords + y0 = startPt[0] * scale[1] + origin[1] + x0 = startPt[1] * scale[0] + origin[0] + y1 = endPt[0] * scale[1] + origin[1] + x1 = endPt[1] * scale[0] + origin[0] + + if startPt[1] == endPt[1]: + profileName = '{xlabel} = %g; {ylabel} = [%g, %g]' % (x0, y0, y1) + if method == 'none': + coords = None + else: + coords = numpy.arange(len(profile[0]), dtype=numpy.float32) + coords = coords * scale[1] + y0 + xLabel = '{ylabel}' + + elif startPt[0] == endPt[0]: + profileName = '{ylabel} = %g; {xlabel} = [%g, %g]' % (y0, x0, x1) + if method == 'none': + coords = None + else: + coords = numpy.arange(len(profile[0]), dtype=numpy.float32) + coords = coords * scale[0] + x0 + xLabel = '{xlabel}' + + else: + m = (y1 - y0) / (x1 - x0) + b = y0 - m * x0 + profileName = '{ylabel} = %g * {xlabel} %+g' % (m, b) + if method == 'none': + coords = None + else: + coords = numpy.linspace(x0, x1, len(profile[0]), + endpoint=True, + dtype=numpy.float32) + xLabel = '{xlabel}' + + return coords, profile, area, profileName, xLabel diff --git a/src/silx/gui/plot/tools/profile/editors.py b/src/silx/gui/plot/tools/profile/editors.py new file mode 100644 index 0000000..80e0452 --- /dev/null +++ b/src/silx/gui/plot/tools/profile/editors.py @@ -0,0 +1,307 @@ +# coding: utf-8 +# /*########################################################################## +# +# Copyright (c) 2018-2020 European Synchrotron Radiation Facility +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. +# +# ###########################################################################*/ +"""This module provides editors which are used to custom profile ROI properties. +""" + +__authors__ = ["V. Valls"] +__license__ = "MIT" +__date__ = "28/06/2018" + +import logging + +from silx.gui import qt + +from silx.gui.utils import blockSignals +from silx.gui.plot.PlotToolButtons import ProfileOptionToolButton +from silx.gui.plot.PlotToolButtons import ProfileToolButton +from . import rois +from . import core + + +_logger = logging.getLogger(__name__) + + +class _NoProfileRoiEditor(qt.QWidget): + + sigDataCommited = qt.Signal() + + def setEditorData(self, roi): + pass + + def setRoiData(self, roi): + pass + + +class _DefaultImageProfileRoiEditor(qt.QWidget): + + sigDataCommited = qt.Signal() + + def __init__(self, parent=None): + qt.QWidget.__init__(self, parent=parent) + layout = qt.QHBoxLayout(self) + layout.setContentsMargins(0, 0, 0, 0) + self._initLayout(layout) + + def _initLayout(self, layout): + self._lineWidth = qt.QSpinBox(self) + self._lineWidth.setRange(1, 1000) + self._lineWidth.setValue(1) + self._lineWidth.valueChanged[int].connect(self._widgetChanged) + + self._methodsButton = ProfileOptionToolButton(parent=self, plot=None) + self._methodsButton.sigMethodChanged.connect(self._widgetChanged) + + label = qt.QLabel('W:') + label.setToolTip("Line width in pixels") + layout.addWidget(label) + layout.addWidget(self._lineWidth) + layout.addWidget(self._methodsButton) + + def _widgetChanged(self, value=None): + self.commitData() + + def commitData(self): + self.sigDataCommited.emit() + + def setEditorData(self, roi): + with blockSignals(self._lineWidth): + self._lineWidth.setValue(roi.getProfileLineWidth()) + with blockSignals(self._methodsButton): + method = roi.getProfileMethod() + self._methodsButton.setMethod(method) + + def setRoiData(self, roi): + lineWidth = self._lineWidth.value() + roi.setProfileLineWidth(lineWidth) + method = self._methodsButton.getMethod() + roi.setProfileMethod(method) + + +class _DefaultImageStackProfileRoiEditor(_DefaultImageProfileRoiEditor): + + def _initLayout(self, layout): + super(_DefaultImageStackProfileRoiEditor, self)._initLayout(layout) + self._profileDim = ProfileToolButton(parent=self, plot=None) + self._profileDim.sigDimensionChanged.connect(self._widgetChanged) + layout.addWidget(self._profileDim) + + def setEditorData(self, roi): + super(_DefaultImageStackProfileRoiEditor, self).setEditorData(roi) + with blockSignals(self._profileDim): + kind = roi.getProfileType() + dim = {"1D": 1, "2D": 2}[kind] + self._profileDim.setDimension(dim) + + def setRoiData(self, roi): + super(_DefaultImageStackProfileRoiEditor, self).setRoiData(roi) + dim = self._profileDim.getDimension() + kind = {1: "1D", 2: "2D"}[dim] + roi.setProfileType(kind) + + +class _DefaultScatterProfileRoiEditor(qt.QWidget): + + sigDataCommited = qt.Signal() + + def __init__(self, parent=None): + qt.QWidget.__init__(self, parent=parent) + + self._nPoints = qt.QSpinBox(self) + self._nPoints.setRange(1, 9999) + self._nPoints.setValue(1024) + self._nPoints.valueChanged[int].connect(self.__widgetChanged) + + layout = qt.QHBoxLayout(self) + layout.setContentsMargins(0, 0, 0, 0) + label = qt.QLabel('Samples:') + label.setToolTip("Number of sample points of the profile") + layout.addWidget(label) + layout.addWidget(self._nPoints) + + def __widgetChanged(self, value=None): + self.commitData() + + def commitData(self): + self.sigDataCommited.emit() + + def setEditorData(self, roi): + with blockSignals(self._nPoints): + self._nPoints.setValue(roi.getNPoints()) + + def setRoiData(self, roi): + nPoints = self._nPoints.value() + roi.setNPoints(nPoints) + + +class ProfileRoiEditorAction(qt.QWidgetAction): + """ + Action displaying GUI to edit the selected ROI. + + :param qt.QWidget parent: Parent widget + """ + def __init__(self, parent=None): + super(ProfileRoiEditorAction, self).__init__(parent) + self.__roiManager = None + self.__roi = None + self.__inhibiteReentance = None + + def createWidget(self, parent): + """Inherit the method to create a new editor""" + widget = qt.QWidget(parent) + layout = qt.QHBoxLayout(widget) + if isinstance(parent, qt.QMenu): + margins = layout.contentsMargins() + layout.setContentsMargins(margins.left(), 0, margins.right(), 0) + else: + layout.setContentsMargins(0, 0, 0, 0) + + editorClass = self.getEditorClass(self.__roi) + editor = editorClass(parent) + editor.setEditorData(self.__roi) + self.__setEditor(widget, editor) + return widget + + def deleteWidget(self, widget): + """Inherit the method to delete an editor""" + self.__setEditor(widget, None) + return qt.QWidgetAction.deleteWidget(self, widget) + + def _getEditor(self, widget): + """Returns the editor contained in the widget holder""" + layout = widget.layout() + if layout.count() == 0: + return None + return layout.itemAt(0).widget() + + def setRoiManager(self, roiManager): + """ + Connect this action to a ROI manager. + + :param RegionOfInterestManager roiManager: A ROI manager + """ + if self.__roiManager is roiManager: + return + if self.__roiManager is not None: + self.__roiManager.sigCurrentRoiChanged.disconnect(self.__currentRoiChanged) + self.__roiManager = roiManager + if self.__roiManager is not None: + self.__roiManager.sigCurrentRoiChanged.connect(self.__currentRoiChanged) + self.__currentRoiChanged(roiManager.getCurrentRoi()) + + def __currentRoiChanged(self, roi): + """Handle changes of the selected ROI""" + if roi is not None and not isinstance(roi, core.ProfileRoiMixIn): + return + self.setProfileRoi(roi) + + def setProfileRoi(self, roi): + """Set a profile ROI to edit. + + :param ProfileRoiMixIn roi: A profile ROI + """ + if self.__roi is roi: + return + if self.__roi is not None: + self.__roi.sigProfilePropertyChanged.disconnect(self.__roiPropertyChanged) + self.__roi = roi + if self.__roi is not None: + self.__roi.sigProfilePropertyChanged.connect(self.__roiPropertyChanged) + self._updateWidgets() + + def getRoiProfile(self): + """Returns the edited profile ROI. + + :rtype: ProfileRoiMixIn + """ + return self.__roi + + def __roiPropertyChanged(self): + """Handle changes on the property defining the ROI. + """ + self._updateWidgetValues() + + def __setEditor(self, widget, editor): + """Set the editor to display. + + :param qt.QWidget editor: The editor to display + """ + previousEditor = self._getEditor(widget) + if previousEditor is editor: + return + layout = widget.layout() + if previousEditor is not None: + previousEditor.sigDataCommited.disconnect(self._editorDataCommited) + layout.removeWidget(previousEditor) + previousEditor.deleteLater() + if editor is not None: + editor.sigDataCommited.connect(self._editorDataCommited) + layout.addWidget(editor) + + def getEditorClass(self, roi): + """Returns the editor class to use according to the ROI.""" + if roi is None: + editorClass = _NoProfileRoiEditor + elif isinstance(roi, (rois._DefaultImageStackProfileRoiMixIn, + rois.ProfileImageStackCrossROI)): + # Must be done before the default image ROI + # Cause ImageStack ROIs inherit from Image ROIs + editorClass = _DefaultImageStackProfileRoiEditor + elif isinstance(roi, (rois._DefaultImageProfileRoiMixIn, + rois.ProfileImageCrossROI)): + editorClass = _DefaultImageProfileRoiEditor + elif isinstance(roi, (rois._DefaultScatterProfileRoiMixIn, + rois.ProfileScatterCrossROI)): + editorClass = _DefaultScatterProfileRoiEditor + else: + # Unsupported + editorClass = _NoProfileRoiEditor + return editorClass + + def _updateWidgets(self): + """Update the kind of editor to display, according to the selected + profile ROI.""" + parent = self.parent() + editorClass = self.getEditorClass(self.__roi) + for widget in self.createdWidgets(): + editor = editorClass(parent) + editor.setEditorData(self.__roi) + self.__setEditor(widget, editor) + + def _updateWidgetValues(self): + """Update the content of the displayed editor, according to the + selected profile ROI.""" + for widget in self.createdWidgets(): + editor = self._getEditor(widget) + if self.__inhibiteReentance is editor: + continue + editor.setEditorData(self.__roi) + + def _editorDataCommited(self): + """Handle changes from the editor.""" + editor = self.sender() + if self.__roi is not None: + self.__inhibiteReentance = editor + editor.setRoiData(self.__roi) + self.__inhibiteReentance = None diff --git a/src/silx/gui/plot/tools/profile/manager.py b/src/silx/gui/plot/tools/profile/manager.py new file mode 100644 index 0000000..4a22bc0 --- /dev/null +++ b/src/silx/gui/plot/tools/profile/manager.py @@ -0,0 +1,1079 @@ +# coding: utf-8 +# /*########################################################################## +# +# Copyright (c) 2018-2021 European Synchrotron Radiation Facility +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. +# +# ###########################################################################*/ +"""This module provides a manager to compute and display profiles. +""" + +__authors__ = ["V. Valls"] +__license__ = "MIT" +__date__ = "28/06/2018" + +import logging +import weakref + +from silx.gui import qt +from silx.gui import colors +from silx.gui import utils + +from silx.utils.weakref import WeakMethodProxy +from silx.gui import icons +from silx.gui.plot import PlotWidget +from silx.gui.plot.tools.roi import RegionOfInterestManager +from silx.gui.plot.tools.roi import CreateRoiModeAction +from silx.gui.plot import items +from silx.gui.qt import silxGlobalThreadPool +from silx.gui.qt import inspect +from . import rois +from . import core +from . import editors + + +_logger = logging.getLogger(__name__) + + +class _RunnableComputeProfile(qt.QRunnable): + """Runner to process profiles + + :param qt.QThreadPool threadPool: The thread which will be used to + execute this runner. It is used to update the used signals + :param ~silx.gui.plot.items.Item item: Item in which the profile is + computed + :param ~silx.gui.plot.tools.profile.core.ProfileRoiMixIn roi: ROI + defining the profile shape and other characteristics + """ + + class _Signals(qt.QObject): + """Signal holder""" + resultReady = qt.Signal(object, object) + runnerFinished = qt.Signal(object) + + def __init__(self, threadPool, item, roi): + """Constructor + """ + super(_RunnableComputeProfile, self).__init__() + self._signals = self._Signals() + self._signals.moveToThread(threadPool.thread()) + self._item = item + self._roi = roi + self._cancelled = False + + def _lazyCancel(self): + """Cancel the runner if it is not yet started. + + The threadpool will still execute the runner, but this will process + nothing. + + This is only used with Qt<5.9 where QThreadPool.tryTake is not available. + """ + self._cancelled = True + + def autoDelete(self): + return False + + def getRoi(self): + """Returns the ROI in which the runner will compute a profile. + + :rtype: ~silx.gui.plot.tools.profile.core.ProfileRoiMixIn + """ + return self._roi + + @property + def resultReady(self): + """Signal emitted when the result of the computation is available. + + This signal provides 2 values: The ROI, and the computation result. + """ + return self._signals.resultReady + + @property + def runnerFinished(self): + """Signal emitted when runner have finished. + + This signal provides a single value: the runner itself. + """ + return self._signals.runnerFinished + + def run(self): + """Process the profile computation. + """ + if not self._cancelled: + try: + profileData = self._roi.computeProfile(self._item) + except Exception: + _logger.error("Error while computing profile", exc_info=True) + else: + self.resultReady.emit(self._roi, profileData) + self.runnerFinished.emit(self) + + +class ProfileWindow(qt.QMainWindow): + """ + Display a computed profile. + + The content can be described using :meth:`setRoiProfile` if the source of + the profile is a profile ROI, and :meth:`setProfile` for the data content. + """ + + sigClose = qt.Signal() + """Emitted by :meth:`closeEvent` (e.g. when the window is closed + through the window manager's close icon).""" + + def __init__(self, parent=None, backend=None): + qt.QMainWindow.__init__(self, parent=parent, flags=qt.Qt.Dialog) + + self.setWindowTitle('Profile window') + self._plot1D = None + self._plot2D = None + self._backend = backend + self._data = None + + widget = qt.QWidget() + self._layout = qt.QStackedLayout(widget) + self._layout.setContentsMargins(0, 0, 0, 0) + self.setCentralWidget(widget) + + def prepareWidget(self, roi): + """Called before the show to prepare the window to use with + a specific ROI.""" + if isinstance(roi, rois._DefaultImageStackProfileRoiMixIn): + profileType = roi.getProfileType() + else: + profileType = "1D" + if profileType == "1D": + self.getPlot1D() + elif profileType == "2D": + self.getPlot2D() + + def createPlot1D(self, parent, backend): + """Inherit this function to create your own plot to render 1D + profiles. The default value is a `Plot1D`. + + :param parent: The parent of this widget or None. + :param backend: The backend to use for the plot. + See :class:`PlotWidget` for the list of supported backend. + :rtype: PlotWidget + """ + # import here to avoid circular import + from ...PlotWindow import Plot1D + plot = Plot1D(parent=parent, backend=backend) + plot.setDataMargins(yMinMargin=0.1, yMaxMargin=0.1) + plot.setGraphYLabel('Profile') + plot.setGraphXLabel('') + return plot + + def createPlot2D(self, parent, backend): + """Inherit this function to create your own plot to render 2D + profiles. The default value is a `Plot2D`. + + :param parent: The parent of this widget or None. + :param backend: The backend to use for the plot. + See :class:`PlotWidget` for the list of supported backend. + :rtype: PlotWidget + """ + # import here to avoid circular import + from ...PlotWindow import Plot2D + return Plot2D(parent=parent, backend=backend) + + def getPlot1D(self, init=True): + """Return the current plot used to display curves and create it if it + does not yet exists and `init` is True. Else returns None.""" + if not init: + return self._plot1D + if self._plot1D is None: + self._plot1D = self.createPlot1D(self, self._backend) + self._layout.addWidget(self._plot1D) + return self._plot1D + + def _showPlot1D(self): + plot = self.getPlot1D() + self._layout.setCurrentWidget(plot) + + def getPlot2D(self, init=True): + """Return the current plot used to display image and create it if it + does not yet exists and `init` is True. Else returns None.""" + if not init: + return self._plot2D + if self._plot2D is None: + self._plot2D = self.createPlot2D(parent=self, backend=self._backend) + self._layout.addWidget(self._plot2D) + return self._plot2D + + def _showPlot2D(self): + plot = self.getPlot2D() + self._layout.setCurrentWidget(plot) + + def getCurrentPlotWidget(self): + return self._layout.currentWidget() + + def closeEvent(self, qCloseEvent): + self.sigClose.emit() + qCloseEvent.accept() + + def setRoiProfile(self, roi): + """Set the profile ROI which it the source of the following data + to display. + + :param ProfileRoiMixIn roi: The profile ROI data source + """ + if roi is None: + return + self.__color = colors.rgba(roi.getColor()) + + def _setImageProfile(self, data): + """ + Setup the window to display a new profile data which is represented + by an image. + + :param core.ImageProfileData data: Computed data profile + """ + plot = self.getPlot2D() + + plot.clear() + plot.setGraphTitle(data.title) + plot.getXAxis().setLabel(data.xLabel) + + + coords = data.coords + colormap = data.colormap + profileScale = (coords[-1] - coords[0]) / data.profile.shape[1], 1 + plot.addImage(data.profile, + legend="profile", + colormap=colormap, + origin=(coords[0], 0), + scale=profileScale) + plot.getYAxis().setLabel("Frame index (depth)") + + self._showPlot2D() + + def _setCurveProfile(self, data): + """ + Setup the window to display a new profile data which is represented + by a curve. + + :param core.CurveProfileData data: Computed data profile + """ + plot = self.getPlot1D() + + plot.clear() + plot.setGraphTitle(data.title) + plot.getXAxis().setLabel(data.xLabel) + plot.getYAxis().setLabel(data.yLabel) + + plot.addCurve(data.coords, + data.profile, + legend="level", + color=self.__color) + + self._showPlot1D() + + def _setRgbaProfile(self, data): + """ + Setup the window to display a new profile data which is represented + by a curve. + + :param core.RgbaProfileData data: Computed data profile + """ + plot = self.getPlot1D() + + plot.clear() + plot.setGraphTitle(data.title) + plot.getXAxis().setLabel(data.xLabel) + plot.getYAxis().setLabel(data.yLabel) + + self._showPlot1D() + + plot.addCurve(data.coords, data.profile, + legend="level", color="black") + plot.addCurve(data.coords, data.profile_r, + legend="red", color="red") + plot.addCurve(data.coords, data.profile_g, + legend="green", color="green") + plot.addCurve(data.coords, data.profile_b, + legend="blue", color="blue") + if data.profile_a is not None: + plot.addCurve(data.coords, data.profile_a, legend="alpha", color="gray") + + def clear(self): + """Clear the window profile""" + plot = self.getPlot1D(init=False) + if plot is not None: + plot.clear() + plot = self.getPlot2D(init=False) + if plot is not None: + plot.clear() + + def getProfile(self): + """Returns the profile data which is displayed""" + return self.__data + + def setProfile(self, data): + """ + Setup the window to display a new profile data. + + This method dispatch the result to a specific method according to the + data type. + + :param data: Computed data profile + """ + self.__data = data + if data is None: + self.clear() + elif isinstance(data, core.ImageProfileData): + self._setImageProfile(data) + elif isinstance(data, core.RgbaProfileData): + self._setRgbaProfile(data) + elif isinstance(data, core.CurveProfileData): + self._setCurveProfile(data) + else: + raise TypeError("Unsupported type %s" % type(data)) + + +class _ClearAction(qt.QAction): + """Action to clear the profile manager + + The action is only enabled if something can be cleaned up. + """ + + def __init__(self, parent, profileManager): + super(_ClearAction, self).__init__(parent) + self.__profileManager = weakref.ref(profileManager) + icon = icons.getQIcon('profile-clear') + self.setIcon(icon) + self.setText('Clear profile') + self.setToolTip('Clear the profiles') + self.setCheckable(False) + self.setEnabled(False) + self.triggered.connect(profileManager.clearProfile) + plot = profileManager.getPlotWidget() + roiManager = profileManager.getRoiManager() + plot.sigInteractiveModeChanged.connect(self.__modeUpdated) + roiManager.sigRoiChanged.connect(self.__roiListUpdated) + + def getProfileManager(self): + return self.__profileManager() + + def __roiListUpdated(self): + self.__update() + + def __modeUpdated(self, source): + self.__update() + + def __update(self): + profileManager = self.getProfileManager() + if profileManager is None: + return + roiManager = profileManager.getRoiManager() + if roiManager is None: + return + enabled = roiManager.isStarted() or len(roiManager.getRois()) > 0 + self.setEnabled(enabled) + + +class _StoreLastParamBehavior(qt.QObject): + """This object allow to store and restore the properties of the ROI + profiles""" + + def __init__(self, parent): + assert isinstance(parent, ProfileManager) + super(_StoreLastParamBehavior, self).__init__(parent=parent) + self.__properties = {} + self.__profileRoi = None + self.__filter = utils.LockReentrant() + + def _roi(self): + """Return the spied ROI""" + if self.__profileRoi is None: + return None + roi = self.__profileRoi() + if roi is None: + self.__profileRoi = None + return roi + + def setProfileRoi(self, roi): + """Set a profile ROI to spy. + + :param ProfileRoiMixIn roi: A profile ROI + """ + previousRoi = self._roi() + if previousRoi is roi: + return + if previousRoi is not None: + previousRoi.sigProfilePropertyChanged.disconnect(self._profilePropertyChanged) + self.__profileRoi = None if roi is None else weakref.ref(roi) + if roi is not None: + roi.sigProfilePropertyChanged.connect(self._profilePropertyChanged) + + def _profilePropertyChanged(self): + """Handle changes on the properties defining the profile ROI. + """ + if self.__filter.locked(): + return + roi = self.sender() + self.storeProperties(roi) + + def storeProperties(self, roi): + if isinstance(roi, (rois._DefaultImageStackProfileRoiMixIn, + rois.ProfileImageStackCrossROI)): + self.__properties["method"] = roi.getProfileMethod() + self.__properties["line-width"] = roi.getProfileLineWidth() + self.__properties["type"] = roi.getProfileType() + elif isinstance(roi, (rois._DefaultImageProfileRoiMixIn, + rois.ProfileImageCrossROI)): + self.__properties["method"] = roi.getProfileMethod() + self.__properties["line-width"] = roi.getProfileLineWidth() + elif isinstance(roi, (rois._DefaultScatterProfileRoiMixIn, + rois.ProfileScatterCrossROI)): + self.__properties["npoints"] = roi.getNPoints() + + def restoreProperties(self, roi): + with self.__filter: + if isinstance(roi, (rois._DefaultImageStackProfileRoiMixIn, + rois.ProfileImageStackCrossROI)): + value = self.__properties.get("method", None) + if value is not None: + roi.setProfileMethod(value) + value = self.__properties.get("line-width", None) + if value is not None: + roi.setProfileLineWidth(value) + value = self.__properties.get("type", None) + if value is not None: + roi.setProfileType(value) + elif isinstance(roi, (rois._DefaultImageProfileRoiMixIn, + rois.ProfileImageCrossROI)): + value = self.__properties.get("method", None) + if value is not None: + roi.setProfileMethod(value) + value = self.__properties.get("line-width", None) + if value is not None: + roi.setProfileLineWidth(value) + elif isinstance(roi, (rois._DefaultScatterProfileRoiMixIn, + rois.ProfileScatterCrossROI)): + value = self.__properties.get("npoints", None) + if value is not None: + roi.setNPoints(value) + + +class ProfileManager(qt.QObject): + """Base class for profile management tools + + :param plot: :class:`~silx.gui.plot.PlotWidget` on which to operate. + :param plot: :class:`~silx.gui.plot.tools.roi.RegionOfInterestManager` + on which to operate. + """ + def __init__(self, parent=None, plot=None, roiManager=None): + super(ProfileManager, self).__init__(parent) + + assert isinstance(plot, PlotWidget) + self._plotRef = weakref.ref( + plot, WeakMethodProxy(self.__plotDestroyed)) + + # Set-up interaction manager + if roiManager is None: + roiManager = RegionOfInterestManager(plot) + + self._roiManagerRef = weakref.ref(roiManager) + self._rois = [] + self._pendingRunners = [] + """List of ROIs which have to be updated""" + + self.__reentrantResults = {} + """Store reentrant result to avoid to skip some of them + cause the implementation uses a QEventLoop.""" + + self._profileWindowClass = ProfileWindow + """Class used to display the profile results""" + + self._computedProfiles = 0 + """Statistics for tests""" + + self.__itemTypes = [] + """Kind of items to use""" + + self.__tracking = False + """Is the plot active items are tracked""" + + self.__useColorFromCursor = True + """If true, force the ROI color with the colormap marker color""" + + self._item = None + """The selected item""" + + self.__singleProfileAtATime = True + """When it's true, only a single profile is displayed at a time.""" + + self._previousWindowGeometry = [] + + self._storeProperties = _StoreLastParamBehavior(self) + """If defined the profile properties of the last ROI are reused to the + new created ones""" + + # Listen to plot limits changed + plot.getXAxis().sigLimitsChanged.connect(self.requestUpdateAllProfile) + plot.getYAxis().sigLimitsChanged.connect(self.requestUpdateAllProfile) + + roiManager.sigInteractiveModeFinished.connect(self.__interactionFinished) + roiManager.sigInteractiveRoiCreated.connect(self.__roiCreated) + roiManager.sigRoiAdded.connect(self.__roiAdded) + roiManager.sigRoiAboutToBeRemoved.connect(self.__roiRemoved) + + def setSingleProfile(self, enable): + """ + Enable or disable the single profile mode. + + In single mode, the manager enforce a single ROI at the same + time. A new one will remove the previous one. + + If this mode is not enabled, many ROIs can be created, and many + profile windows will be displayed. + """ + self.__singleProfileAtATime = enable + + def isSingleProfile(self): + """ + Returns true if the manager is in a single profile mode. + + :rtype: bool + """ + return self.__singleProfileAtATime + + def __interactionFinished(self): + """Handle end of interactive mode""" + pass + + def __roiAdded(self, roi): + """Handle new ROI""" + # Filter out non profile ROIs + if not isinstance(roi, core.ProfileRoiMixIn): + return + self.__addProfile(roi) + + def __roiRemoved(self, roi): + """Handle removed ROI""" + # Filter out non profile ROIs + if not isinstance(roi, core.ProfileRoiMixIn): + return + self.__removeProfile(roi) + + def createProfileAction(self, profileRoiClass, parent=None): + """Create an action from a class of ProfileRoi + + :param core.ProfileRoiMixIn profileRoiClass: A class of a profile ROI + :param qt.QObject parent: The parent of the created action. + :rtype: qt.QAction + """ + if not issubclass(profileRoiClass, core.ProfileRoiMixIn): + raise TypeError("Type %s not expected" % type(profileRoiClass)) + roiManager = self.getRoiManager() + action = CreateRoiModeAction(parent, roiManager, profileRoiClass) + if hasattr(profileRoiClass, "ICON"): + action.setIcon(icons.getQIcon(profileRoiClass.ICON)) + if hasattr(profileRoiClass, "NAME"): + def articulify(word): + """Add an an/a article in the front of the word""" + first = word[1] if word[0] == 'h' else word[0] + if first in "aeiou": + return "an " + word + return "a " + word + action.setText('Define %s' % articulify(profileRoiClass.NAME)) + action.setToolTip('Enables %s selection mode' % profileRoiClass.NAME) + action.setSingleShot(True) + return action + + def createClearAction(self, parent): + """Create an action to clean up the plot from the profile ROIs. + + :param qt.QObject parent: The parent of the created action. + :rtype: qt.QAction + """ + action = _ClearAction(parent, self) + return action + + def createImageActions(self, parent): + """Create actions designed for image items. This actions created + new ROIs. + + :param qt.QObject parent: The parent of the created action. + :rtype: List[qt.QAction] + """ + profileClasses = [ + rois.ProfileImageHorizontalLineROI, + rois.ProfileImageVerticalLineROI, + rois.ProfileImageLineROI, + rois.ProfileImageDirectedLineROI, + rois.ProfileImageCrossROI, + ] + return [self.createProfileAction(pc, parent=parent) for pc in profileClasses] + + def createScatterActions(self, parent): + """Create actions designed for scatter items. This actions created + new ROIs. + + :param qt.QObject parent: The parent of the created action. + :rtype: List[qt.QAction] + """ + profileClasses = [ + rois.ProfileScatterHorizontalLineROI, + rois.ProfileScatterVerticalLineROI, + rois.ProfileScatterLineROI, + rois.ProfileScatterCrossROI, + ] + return [self.createProfileAction(pc, parent=parent) for pc in profileClasses] + + def createScatterSliceActions(self, parent): + """Create actions designed for regular scatter items. This actions + created new ROIs. + + This ROIs was designed to use the input data without interpolation, + like you could do with an image. + + :param qt.QObject parent: The parent of the created action. + :rtype: List[qt.QAction] + """ + profileClasses = [ + rois.ProfileScatterHorizontalSliceROI, + rois.ProfileScatterVerticalSliceROI, + rois.ProfileScatterCrossSliceROI, + ] + return [self.createProfileAction(pc, parent=parent) for pc in profileClasses] + + def createImageStackActions(self, parent): + """Create actions designed for stack image items. This actions + created new ROIs. + + This ROIs was designed to create both profile on the displayed image + and profile on the full stack (2D result). + + :param qt.QObject parent: The parent of the created action. + :rtype: List[qt.QAction] + """ + profileClasses = [ + rois.ProfileImageStackHorizontalLineROI, + rois.ProfileImageStackVerticalLineROI, + rois.ProfileImageStackLineROI, + rois.ProfileImageStackCrossROI, + ] + return [self.createProfileAction(pc, parent=parent) for pc in profileClasses] + + def createEditorAction(self, parent): + """Create an action containing GUI to edit the selected profile ROI. + + :param qt.QObject parent: The parent of the created action. + :rtype: qt.QAction + """ + action = editors.ProfileRoiEditorAction(parent) + action.setRoiManager(self.getRoiManager()) + return action + + def setItemType(self, image=False, scatter=False): + """Set the item type to use and select the active one. + + :param bool image: Image item are allowed + :param bool scatter: Scatter item are allowed + """ + self.__itemTypes = [] + plot = self.getPlotWidget() + item = None + if image: + self.__itemTypes.append("image") + item = plot.getActiveImage() + if scatter: + self.__itemTypes.append("scatter") + if item is None: + item = plot.getActiveScatter() + self.setPlotItem(item) + + def setProfileWindowClass(self, profileWindowClass): + """Set the class which will be instantiated to display profile result. + """ + self._profileWindowClass = profileWindowClass + + def setActiveItemTracking(self, tracking): + """Enable/disable the tracking of the active item of the plot. + + :param bool tracking: Tracking mode + """ + if self.__tracking == tracking: + return + plot = self.getPlotWidget() + if self.__tracking: + plot.sigActiveImageChanged.disconnect(self._activeImageChanged) + plot.sigActiveScatterChanged.disconnect(self._activeScatterChanged) + self.__tracking = tracking + if self.__tracking: + plot.sigActiveImageChanged.connect(self.__activeImageChanged) + plot.sigActiveScatterChanged.connect(self.__activeScatterChanged) + + def setDefaultColorFromCursorColor(self, enabled): + """Enabled/disable the use of the colormap cursor color to display the + ROIs. + + If set, the manager will update the color of the profile ROIs using the + current colormap cursor color from the selected item. + """ + self.__useColorFromCursor = enabled + + def __activeImageChanged(self, previous, legend): + """Handle plot item selection""" + if "image" in self.__itemTypes: + plot = self.getPlotWidget() + item = plot.getImage(legend) + self.setPlotItem(item) + + def __activeScatterChanged(self, previous, legend): + """Handle plot item selection""" + if "scatter" in self.__itemTypes: + plot = self.getPlotWidget() + item = plot.getScatter(legend) + self.setPlotItem(item) + + def __roiCreated(self, roi): + """Handle ROI creation""" + # Filter out non profile ROIs + if isinstance(roi, core.ProfileRoiMixIn): + if self._storeProperties is not None: + # Initialize the properties with the previous ones + self._storeProperties.restoreProperties(roi) + + def __addProfile(self, profileRoi): + """Add a new ROI to the manager.""" + if profileRoi.getFocusProxy() is None: + if self._storeProperties is not None: + # Follow changes on properties + self._storeProperties.setProfileRoi(profileRoi) + if self.__singleProfileAtATime: + # FIXME: It would be good to reuse the windows to avoid blinking + self.clearProfile() + + profileRoi._setProfileManager(self) + self._updateRoiColor(profileRoi) + self._rois.append(profileRoi) + self.requestUpdateProfile(profileRoi) + + def __removeProfile(self, profileRoi): + """Remove a ROI from the manager.""" + window = self._disconnectProfileWindow(profileRoi) + if window is not None: + geometry = window.geometry() + if not geometry.isEmpty(): + self._previousWindowGeometry.append(geometry) + self.clearProfileWindow(window) + if profileRoi in self._rois: + self._rois.remove(profileRoi) + + def _disconnectProfileWindow(self, profileRoi): + """Handle profile window close.""" + window = profileRoi.getProfileWindow() + profileRoi.setProfileWindow(None) + return window + + def clearProfile(self): + """Clear the associated ROI profile""" + roiManager = self.getRoiManager() + for roi in list(self._rois): + if roi.getFocusProxy() is not None: + # Skip sub ROIs, it will be removed by their parents + continue + roiManager.removeRoi(roi) + + if not roiManager.isDrawing(): + # Clean the selected mode + roiManager.stop() + + def hasPendingOperations(self): + """Returns true if a thread is still computing or displaying a profile. + + :rtype: bool + """ + return len(self.__reentrantResults) > 0 or len(self._pendingRunners) > 0 + + def requestUpdateAllProfile(self): + """Request to update the profile of all the managed ROIs. + """ + for roi in self._rois: + self.requestUpdateProfile(roi) + + def requestUpdateProfile(self, profileRoi): + """Request to update a specific profile ROI. + + :param ~core.ProfileRoiMixIn profileRoi: + """ + if profileRoi.computeProfile is None: + return + threadPool = silxGlobalThreadPool() + + # Clean up deprecated runners + for runner in list(self._pendingRunners): + if not inspect.isValid(runner): + self._pendingRunners.remove(runner) + continue + if runner.getRoi() is profileRoi: + if hasattr(threadPool, "tryTake"): + if threadPool.tryTake(runner): + self._pendingRunners.remove(runner) + else: # Support Qt<5.9 + runner._lazyCancel() + + item = self.getPlotItem() + if item is None or not isinstance(item, profileRoi.ITEM_KIND): + # This item is not compatible with this profile + profileRoi._setPlotItem(None) + profileWindow = profileRoi.getProfileWindow() + if profileWindow is not None: + profileWindow.setProfile(None) + return + + profileRoi._setPlotItem(item) + runner = _RunnableComputeProfile(threadPool, item, profileRoi) + runner.runnerFinished.connect(self.__cleanUpRunner) + runner.resultReady.connect(self.__displayResult) + self._pendingRunners.append(runner) + threadPool.start(runner) + + def __cleanUpRunner(self, runner): + """Remove a thread pool runner from the list of hold tasks. + + Called at the termination of the runner. + """ + if runner in self._pendingRunners: + self._pendingRunners.remove(runner) + + def __displayResult(self, roi, profileData): + """Display the result of a ROI. + + :param ~core.ProfileRoiMixIn profileRoi: A managed ROI + :param ~core.CurveProfileData profileData: Computed data profile + """ + if roi in self.__reentrantResults: + # Store the data to process it in the main loop + # And not a sub loop created by initProfileWindow + # This also remove the duplicated requested + self.__reentrantResults[roi] = profileData + return + + self.__reentrantResults[roi] = profileData + self._computedProfiles = self._computedProfiles + 1 + window = roi.getProfileWindow() + if window is None: + plot = self.getPlotWidget() + window = self.createProfileWindow(plot, roi) + # roi.profileWindow have to be set before initializing the window + # Cause the initialization is using QEventLoop + roi.setProfileWindow(window) + self.initProfileWindow(window, roi) + window.show() + + lastData = self.__reentrantResults.pop(roi) + window.setProfile(lastData) + + def __plotDestroyed(self, ref): + """Handle finalization of PlotWidget + + :param ref: weakref to the plot + """ + self._plotRef = None + self._roiManagerRef = None + self._pendingRunners = [] + + def setPlotItem(self, item): + """Set the plot item focused by the profile manager. + + :param ~silx.gui.plot.items.Item item: A plot item + """ + previous = self.getPlotItem() + if previous is item: + return + if item is None: + self._item = None + else: + item.sigItemChanged.connect(self.__itemChanged) + self._item = weakref.ref(item) + self._updateRoiColors() + self.requestUpdateAllProfile() + + def getDefaultColor(self, item): + """Returns the default ROI color to use according to the given item. + + :param ~silx.gui.plot.items.item.Item item: AN item + :rtype: qt.QColor + """ + color = 'pink' + if isinstance(item, items.ColormapMixIn): + colormap = item.getColormap() + name = colormap.getName() + if name is not None: + color = colors.cursorColorForColormap(name) + color = colors.asQColor(color) + return color + + def _updateRoiColors(self): + """Update ROI color according to the item selection""" + if not self.__useColorFromCursor: + return + item = self.getPlotItem() + color = self.getDefaultColor(item) + for roi in self._rois: + roi.setColor(color) + + def _updateRoiColor(self, roi): + """Update a specific ROI according to the current selected item. + + :param RegionOfInterest roi: The ROI to update + """ + if not self.__useColorFromCursor: + return + item = self.getPlotItem() + color = self.getDefaultColor(item) + roi.setColor(color) + + def __itemChanged(self, changeType): + """Handle item changes. + """ + if changeType in (items.ItemChangedType.DATA, + items.ItemChangedType.MASK, + items.ItemChangedType.POSITION, + items.ItemChangedType.SCALE): + self.requestUpdateAllProfile() + elif changeType == (items.ItemChangedType.COLORMAP): + self._updateRoiColors() + + def getPlotItem(self): + """Returns the item focused by the profile manager. + + :rtype: ~silx.gui.plot.items.Item + """ + if self._item is None: + return None + item = self._item() + if item is None: + self._item = None + return item + + def getPlotWidget(self): + """The plot associated to the profile manager. + + :rtype: ~silx.gui.plot.PlotWidget + """ + if self._plotRef is None: + return None + plot = self._plotRef() + if plot is None: + self._plotRef = None + return plot + + def getCurrentRoi(self): + """Returns the currently selected ROI, else None. + + :rtype: core.ProfileRoiMixIn + """ + roiManager = self.getRoiManager() + if roiManager is None: + return None + roi = roiManager.getCurrentRoi() + if not isinstance(roi, core.ProfileRoiMixIn): + return None + return roi + + def getRoiManager(self): + """Returns the used ROI manager + + :rtype: RegionOfInterestManager + """ + return self._roiManagerRef() + + def createProfileWindow(self, plot, roi): + """Create a new profile window. + + :param ~core.ProfileRoiMixIn roi: The plot containing the raw data + :param ~core.ProfileRoiMixIn roi: A managed ROI + :rtype: ~ProfileWindow + """ + return self._profileWindowClass(plot) + + def initProfileWindow(self, profileWindow, roi): + """This function is called just after the profile window creation in + order to initialize the window location. + + :param ~ProfileWindow profileWindow: + The profile window to initialize. + """ + # Enforce the use of one of the widgets + # To have the correct window size + profileWindow.prepareWidget(roi) + profileWindow.adjustSize() + + # Trick to avoid blinking while retrieving the right window size + # Display the window, hide it and wait for some event loops + profileWindow.show() + profileWindow.hide() + eventLoop = qt.QEventLoop(self) + for _ in range(10): + if not eventLoop.processEvents(): + break + + profileWindow.show() + if len(self._previousWindowGeometry) > 0: + geometry = self._previousWindowGeometry.pop() + profileWindow.setGeometry(geometry) + return + + window = self.getPlotWidget().window() + winGeom = window.frameGeometry() + if qt.BINDING in ("PySide2", "PyQt5"): + qapp = qt.QApplication.instance() + desktop = qapp.desktop() + screenGeom = desktop.availableGeometry(window) + else: # Qt6 (and also Qt>=5.14) + screenGeom = window.screen().availableGeometry() + spaceOnLeftSide = winGeom.left() + spaceOnRightSide = screenGeom.width() - winGeom.right() + + profileGeom = profileWindow.frameGeometry() + profileWidth = profileGeom.width() + + # Align vertically to the center of the window + top = winGeom.top() + (winGeom.height() - profileGeom.height()) // 2 + + margin = 5 + if profileWidth < spaceOnRightSide: + # Place profile on the right + left = winGeom.right() + margin + elif profileWidth < spaceOnLeftSide: + # Place profile on the left + left = max(0, winGeom.left() - profileWidth - margin) + else: + # Move it as much as possible where there is more space + if spaceOnLeftSide > spaceOnRightSide: + left = 0 + else: + left = screenGeom.width() - profileGeom.width() + profileWindow.move(left, top) + + + def clearProfileWindow(self, profileWindow): + """Called when a profile window is not anymore needed. + + By default the window will be closed. But it can be + inherited to change this behavior. + """ + profileWindow.deleteLater() diff --git a/src/silx/gui/plot/tools/profile/rois.py b/src/silx/gui/plot/tools/profile/rois.py new file mode 100644 index 0000000..9eef622 --- /dev/null +++ b/src/silx/gui/plot/tools/profile/rois.py @@ -0,0 +1,1156 @@ +# coding: utf-8 +# /*########################################################################## +# +# Copyright (c) 2018-2021 European Synchrotron Radiation Facility +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. +# +# ###########################################################################*/ +"""This module define ROIs for profile tools. + +.. inheritance-diagram:: + silx.gui.plot.tools.profile.rois + :top-classes: silx.gui.plot.tools.profile.core.ProfileRoiMixIn, silx.gui.plot.items.roi.RegionOfInterest + :parts: 1 + :private-bases: +""" + +__authors__ = ["V. Valls"] +__license__ = "MIT" +__date__ = "01/12/2020" + +import numpy +import weakref +from concurrent.futures import CancelledError + +from silx.gui import colors + +from silx.gui.plot import items +from silx.gui.plot.items import roi as roi_items +from . import core +from silx.gui import utils +from .....utils.proxy import docstring + + +def _relabelAxes(plot, text): + """Relabel {xlabel} and {ylabel} from this text using the corresponding + plot axis label. If the axis label is empty, label it with "X" and "Y". + + :rtype: str + """ + xLabel = plot.getXAxis().getLabel() + if not xLabel: + xLabel = "X" + yLabel = plot.getYAxis().getLabel() + if not yLabel: + yLabel = "Y" + return text.format(xlabel=xLabel, ylabel=yLabel) + + +def _lineProfileTitle(x0, y0, x1, y1): + """Compute corresponding plot title + + This can be overridden to change title behavior. + + :param float x0: Profile start point X coord + :param float y0: Profile start point Y coord + :param float x1: Profile end point X coord + :param float y1: Profile end point Y coord + :return: Title to use + :rtype: str + """ + if x0 == x1: + title = '{xlabel} = %g; {ylabel} = [%g, %g]' % (x0, y0, y1) + elif y0 == y1: + title = '{ylabel} = %g; {xlabel} = [%g, %g]' % (y0, x0, x1) + else: + m = (y1 - y0) / (x1 - x0) + b = y0 - m * x0 + title = '{ylabel} = %g * {xlabel} %+g' % (m, b) + + return title + + +class _ImageProfileArea(items.Shape): + """This shape displays the location of pixels used to compute the + profile.""" + + def __init__(self, parentRoi): + items.Shape.__init__(self, "polygon") + color = colors.rgba(parentRoi.getColor()) + self.setColor(color) + self.setFill(True) + self.setOverlay(True) + self.setPoints([[0, 0], [0, 0]]) # Else it segfault + + self.__parentRoi = weakref.ref(parentRoi) + parentRoi.sigItemChanged.connect(self._updateAreaProperty) + parentRoi.sigRegionChanged.connect(self._updateArea) + parentRoi.sigProfilePropertyChanged.connect(self._updateArea) + parentRoi.sigPlotItemChanged.connect(self._updateArea) + + def getParentRoi(self): + if self.__parentRoi is None: + return None + parentRoi = self.__parentRoi() + if parentRoi is None: + self.__parentRoi = None + return parentRoi + + def _updateAreaProperty(self, event=None, checkVisibility=True): + parentRoi = self.sender() + if event == items.ItemChangedType.COLOR: + parentRoi._updateItemProperty(event, parentRoi, self) + elif event == items.ItemChangedType.VISIBLE: + if self.getPlotItem() is not None: + parentRoi._updateItemProperty(event, parentRoi, self) + + def _updateArea(self): + roi = self.getParentRoi() + item = roi.getPlotItem() + if item is None: + self.setVisible(False) + return + polygon = self._computePolygon(item) + self.setVisible(True) + polygon = numpy.array(polygon).T + self.setLineStyle("--") + self.setPoints(polygon, copy=False) + + def _computePolygon(self, item): + if not isinstance(item, items.ImageBase): + raise TypeError("Unexpected class %s" % type(item)) + + currentData = item.getValueData(copy=False) + + roi = self.getParentRoi() + origin = item.getOrigin() + scale = item.getScale() + _coords, _profile, area, _profileName, _xLabel = core.createProfile( + roiInfo=roi._getRoiInfo(), + currentData=currentData, + origin=origin, + scale=scale, + lineWidth=roi.getProfileLineWidth(), + method="none") + return area + + +class _SliceProfileArea(items.Shape): + """This shape displays the location a profile in a scatter. + + Each point used to compute the slice are linked together. + """ + + def __init__(self, parentRoi): + items.Shape.__init__(self, "polygon") + color = colors.rgba(parentRoi.getColor()) + self.setColor(color) + self.setFill(True) + self.setOverlay(True) + self.setPoints([[0, 0], [0, 0]]) # Else it segfault + + self.__parentRoi = weakref.ref(parentRoi) + parentRoi.sigItemChanged.connect(self._updateAreaProperty) + parentRoi.sigRegionChanged.connect(self._updateArea) + parentRoi.sigProfilePropertyChanged.connect(self._updateArea) + parentRoi.sigPlotItemChanged.connect(self._updateArea) + + def getParentRoi(self): + if self.__parentRoi is None: + return None + parentRoi = self.__parentRoi() + if parentRoi is None: + self.__parentRoi = None + return parentRoi + + def _updateAreaProperty(self, event=None, checkVisibility=True): + parentRoi = self.sender() + if event == items.ItemChangedType.COLOR: + parentRoi._updateItemProperty(event, parentRoi, self) + elif event == items.ItemChangedType.VISIBLE: + if self.getPlotItem() is not None: + parentRoi._updateItemProperty(event, parentRoi, self) + + def _updateArea(self): + roi = self.getParentRoi() + item = roi.getPlotItem() + if item is None: + self.setVisible(False) + return + polylines = self._computePolylines(roi, item) + if polylines is None: + self.setVisible(False) + return + self.setVisible(True) + self.setLineStyle("--") + self.setPoints(polylines, copy=False) + + def _computePolylines(self, roi, item): + slicing = roi._getSlice(item) + if slicing is None: + return None + xx, yy, _values, _xx_error, _yy_error = item.getData(copy=False) + xx, yy = xx[slicing], yy[slicing] + polylines = numpy.array((xx, yy)).T + if len(polylines) == 0: + return None + return polylines + + +class _DefaultImageProfileRoiMixIn(core.ProfileRoiMixIn): + """Provide common behavior for silx default image profile ROI. + """ + + ITEM_KIND = items.ImageBase + + def __init__(self, parent=None): + core.ProfileRoiMixIn.__init__(self, parent=parent) + self.__method = "mean" + self.__width = 1 + self.sigRegionChanged.connect(self.__regionChanged) + self.sigPlotItemChanged.connect(self.__updateArea) + self.__area = _ImageProfileArea(self) + self.addItem(self.__area) + + def __regionChanged(self): + self.invalidateProfile() + self.__updateArea() + + def setProfileMethod(self, method): + """ + :param str method: method to compute the profile. Can be 'mean' or 'sum' + """ + if self.__method == method: + return + self.__method = method + self.invalidateProperties() + self.invalidateProfile() + + def getProfileMethod(self): + return self.__method + + def setProfileLineWidth(self, width): + if self.__width == width: + return + self.__width = width + self.__updateArea() + self.invalidateProperties() + self.invalidateProfile() + + def getProfileLineWidth(self): + return self.__width + + def __updateArea(self): + plotItem = self.getPlotItem() + if plotItem is None: + self.setLineStyle("-") + else: + self.setLineStyle("--") + + def _getRoiInfo(self): + """Wrapper to allow to reuse the previous Profile code. + + It would be good to remove it at one point. + """ + if isinstance(self, roi_items.HorizontalLineROI): + lineProjectionMode = 'X' + y = self.getPosition() + roiStart = (0, y) + roiEnd = (1, y) + elif isinstance(self, roi_items.VerticalLineROI): + lineProjectionMode = 'Y' + x = self.getPosition() + roiStart = (x, 0) + roiEnd = (x, 1) + elif isinstance(self, roi_items.LineROI): + lineProjectionMode = 'D' + roiStart, roiEnd = self.getEndPoints() + else: + assert False + + return roiStart, roiEnd, lineProjectionMode + + def computeProfile(self, item): + if not isinstance(item, items.ImageBase): + raise TypeError("Unexpected class %s" % type(item)) + + origin = item.getOrigin() + scale = item.getScale() + method = self.getProfileMethod() + lineWidth = self.getProfileLineWidth() + + def createProfile2(currentData): + coords, profile, _area, profileName, xLabel = core.createProfile( + roiInfo=self._getRoiInfo(), + currentData=currentData, + origin=origin, + scale=scale, + lineWidth=lineWidth, + method=method) + return coords, profile, profileName, xLabel + + currentData = item.getValueData(copy=False) + + yLabel = "%s" % str(method).capitalize() + coords, profile, title, xLabel = createProfile2(currentData) + title = title + "; width = %d" % lineWidth + + # Use the axis names from the original plot + profileManager = self.getProfileManager() + plot = profileManager.getPlotWidget() + title = _relabelAxes(plot, title) + xLabel = _relabelAxes(plot, xLabel) + + if isinstance(item, items.ImageRgba): + rgba = item.getData(copy=False) + _coords, r, _profileName, _xLabel = createProfile2(rgba[..., 0]) + _coords, g, _profileName, _xLabel = createProfile2(rgba[..., 1]) + _coords, b, _profileName, _xLabel = createProfile2(rgba[..., 2]) + if rgba.shape[-1] == 4: + _coords, a, _profileName, _xLabel = createProfile2(rgba[..., 3]) + else: + a = [None] + data = core.RgbaProfileData( + coords=coords, + profile=profile[0], + profile_r=r[0], + profile_g=g[0], + profile_b=b[0], + profile_a=a[0], + title=title, + xLabel=xLabel, + yLabel=yLabel, + ) + else: + data = core.CurveProfileData( + coords=coords, + profile=profile[0], + title=title, + xLabel=xLabel, + yLabel=yLabel, + ) + return data + + +class ProfileImageHorizontalLineROI(roi_items.HorizontalLineROI, + _DefaultImageProfileRoiMixIn): + """ROI for an horizontal profile at a location of an image""" + + ICON = 'shape-horizontal' + NAME = 'horizontal line profile' + + def __init__(self, parent=None): + roi_items.HorizontalLineROI.__init__(self, parent=parent) + _DefaultImageProfileRoiMixIn.__init__(self, parent=parent) + + +class ProfileImageVerticalLineROI(roi_items.VerticalLineROI, + _DefaultImageProfileRoiMixIn): + """ROI for a vertical profile at a location of an image""" + + ICON = 'shape-vertical' + NAME = 'vertical line profile' + + def __init__(self, parent=None): + roi_items.VerticalLineROI.__init__(self, parent=parent) + _DefaultImageProfileRoiMixIn.__init__(self, parent=parent) + + +class ProfileImageLineROI(roi_items.LineROI, + _DefaultImageProfileRoiMixIn): + """ROI for an image profile between 2 points. + + The X profile of this ROI is the projecting into one of the x/y axes, + using its scale and its orientation. + """ + + ICON = 'shape-diagonal' + NAME = 'line profile' + + def __init__(self, parent=None): + roi_items.LineROI.__init__(self, parent=parent) + _DefaultImageProfileRoiMixIn.__init__(self, parent=parent) + + +class ProfileImageDirectedLineROI(roi_items.LineROI, + _DefaultImageProfileRoiMixIn): + """ROI for an image profile between 2 points. + + The X profile of the line is displayed projected into the line itself, + using its scale and its orientation. It's the distance from the origin. + """ + + ICON = 'shape-diagonal-directed' + NAME = 'directed line profile' + + def __init__(self, parent=None): + roi_items.LineROI.__init__(self, parent=parent) + _DefaultImageProfileRoiMixIn.__init__(self, parent=parent) + self._handleStart.setSymbol('o') + + def computeProfile(self, item): + if not isinstance(item, items.ImageBase): + raise TypeError("Unexpected class %s" % type(item)) + + from silx.image.bilinear import BilinearImage + + origin = item.getOrigin() + scale = item.getScale() + method = self.getProfileMethod() + lineWidth = self.getProfileLineWidth() + currentData = item.getValueData(copy=False) + + roiInfo = self._getRoiInfo() + roiStart, roiEnd, _lineProjectionMode = roiInfo + + startPt = ((roiStart[1] - origin[1]) / scale[1], + (roiStart[0] - origin[0]) / scale[0]) + endPt = ((roiEnd[1] - origin[1]) / scale[1], + (roiEnd[0] - origin[0]) / scale[0]) + + if numpy.array_equal(startPt, endPt): + return None + + bilinear = BilinearImage(currentData) + profile = bilinear.profile_line( + (startPt[0] - 0.5, startPt[1] - 0.5), + (endPt[0] - 0.5, endPt[1] - 0.5), + lineWidth, + method=method) + + # Compute the line size + lineSize = numpy.sqrt((roiEnd[1] - roiStart[1]) ** 2 + + (roiEnd[0] - roiStart[0]) ** 2) + coords = numpy.linspace(0, lineSize, len(profile), + endpoint=True, + dtype=numpy.float32) + + title = _lineProfileTitle(*roiStart, *roiEnd) + title = title + "; width = %d" % lineWidth + xLabel = "√({xlabel}²+{ylabel}²)" + yLabel = str(method).capitalize() + + # Use the axis names from the original plot + profileManager = self.getProfileManager() + plot = profileManager.getPlotWidget() + xLabel = _relabelAxes(plot, xLabel) + title = _relabelAxes(plot, title) + + data = core.CurveProfileData( + coords=coords, + profile=profile, + title=title, + xLabel=xLabel, + yLabel=yLabel, + ) + return data + + +class _ProfileCrossROI(roi_items.HandleBasedROI, core.ProfileRoiMixIn): + + """ROI to manage a cross of profiles + + It is managed using 2 sub ROIs for vertical and horizontal. + """ + + _kind = "Cross" + """Label for this kind of ROI""" + + _plotShape = "point" + """Plot shape which is used for the first interaction""" + + def __init__(self, parent=None): + roi_items.HandleBasedROI.__init__(self, parent=parent) + core.ProfileRoiMixIn.__init__(self, parent=parent) + self.sigRegionChanged.connect(self.__regionChanged) + self.sigAboutToBeRemoved.connect(self.__aboutToBeRemoved) + self.__position = 0, 0 + self.__vline = None + self.__hline = None + self.__handle = self.addHandle() + self.__handleLabel = self.addLabelHandle() + self.__handleLabel.setText(self.getName()) + self.__inhibitReentance = utils.LockReentrant() + self.computeProfile = None + self.sigItemChanged.connect(self.__updateLineProperty) + + # Make sure the marker is over the ROIs + self.__handle.setZValue(1) + # Create the vline and the hline + self._createSubRois() + + @docstring(roi_items.HandleBasedROI) + def contains(self, position): + roiPos = self.getPosition() + return position[0] == roiPos[0] or position[1] == roiPos[1] + + def setFirstShapePoints(self, points): + pos = points[0] + self.setPosition(pos) + + def getPosition(self): + """Returns the position of this ROI + + :rtype: numpy.ndarray + """ + return self.__position + + def setPosition(self, pos): + """Set the position of this ROI + + :param numpy.ndarray pos: 2d-coordinate of this point + """ + self.__position = pos + with utils.blockSignals(self.__handle): + self.__handle.setPosition(*pos) + with utils.blockSignals(self.__handleLabel): + self.__handleLabel.setPosition(*pos) + self.sigRegionChanged.emit() + + def handleDragUpdated(self, handle, origin, previous, current): + if handle is self.__handle: + self.setPosition(current) + + def __updateLineProperty(self, event=None, checkVisibility=True): + if event == items.ItemChangedType.NAME: + self.__handleLabel.setText(self.getName()) + elif event in [items.ItemChangedType.COLOR, + items.ItemChangedType.VISIBLE]: + lines = [] + if self.__vline: + lines.append(self.__vline) + if self.__hline: + lines.append(self.__hline) + self._updateItemProperty(event, self, lines) + + def _createLines(self, parent): + """Inherit this function to return 2 ROI objects for respectivly + the horizontal, and the vertical lines.""" + raise NotImplementedError() + + def _setProfileManager(self, profileManager): + core.ProfileRoiMixIn._setProfileManager(self, profileManager) + # Connecting the vline and the hline + roiManager = profileManager.getRoiManager() + roiManager.addRoi(self.__vline) + roiManager.addRoi(self.__hline) + + def _createSubRois(self): + hline, vline = self._createLines(parent=None) + for i, line in enumerate([vline, hline]): + line.setPosition(self.__position[i]) + line.setEditable(True) + line.setSelectable(True) + line.setFocusProxy(self) + line.setName("") + self.__vline = vline + self.__hline = hline + vline.sigAboutToBeRemoved.connect(self.__vlineRemoved) + vline.sigRegionChanged.connect(self.__vlineRegionChanged) + hline.sigAboutToBeRemoved.connect(self.__hlineRemoved) + hline.sigRegionChanged.connect(self.__hlineRegionChanged) + + def _getLines(self): + return self.__hline, self.__vline + + def __regionChanged(self): + if self.__inhibitReentance.locked(): + return + x, y = self.getPosition() + hline, vline = self._getLines() + if hline is None: + return + with self.__inhibitReentance: + hline.setPosition(y) + vline.setPosition(x) + + def __vlineRegionChanged(self): + if self.__inhibitReentance.locked(): + return + pos = self.getPosition() + vline = self.__vline + pos = vline.getPosition(), pos[1] + with self.__inhibitReentance: + self.setPosition(pos) + + def __hlineRegionChanged(self): + if self.__inhibitReentance.locked(): + return + pos = self.getPosition() + hline = self.__hline + pos = pos[0], hline.getPosition() + with self.__inhibitReentance: + self.setPosition(pos) + + def __aboutToBeRemoved(self): + vline = self.__vline + hline = self.__hline + # Avoid side remove signals + if hline is not None: + hline.sigAboutToBeRemoved.disconnect(self.__hlineRemoved) + hline.sigRegionChanged.disconnect(self.__hlineRegionChanged) + if vline is not None: + vline.sigAboutToBeRemoved.disconnect(self.__vlineRemoved) + vline.sigRegionChanged.disconnect(self.__vlineRegionChanged) + # Clean up the child + profileManager = self.getProfileManager() + roiManager = profileManager.getRoiManager() + if hline is not None: + roiManager.removeRoi(hline) + self.__hline = None + if vline is not None: + roiManager.removeRoi(vline) + self.__vline = None + + def __hlineRemoved(self): + self.__lineRemoved(isHline=True) + + def __vlineRemoved(self): + self.__lineRemoved(isHline=False) + + def __lineRemoved(self, isHline): + """If any of the lines is removed: disconnect this objects, and let the + other one persist""" + hline, vline = self._getLines() + + hline.sigAboutToBeRemoved.disconnect(self.__hlineRemoved) + vline.sigAboutToBeRemoved.disconnect(self.__vlineRemoved) + hline.sigRegionChanged.disconnect(self.__hlineRegionChanged) + vline.sigRegionChanged.disconnect(self.__vlineRegionChanged) + + self.__hline = None + self.__vline = None + profileManager = self.getProfileManager() + roiManager = profileManager.getRoiManager() + if isHline: + self.__releaseLine(vline) + else: + self.__releaseLine(hline) + roiManager.removeRoi(self) + + def __releaseLine(self, line): + """Release the line in order to make it independent""" + line.setFocusProxy(None) + line.setName(self.getName()) + line.setEditable(self.isEditable()) + line.setSelectable(self.isSelectable()) + + +class ProfileImageCrossROI(_ProfileCrossROI): + """ROI to manage a cross of profiles + + It is managed using 2 sub ROIs for vertical and horizontal. + """ + + ICON = 'shape-cross' + NAME = 'cross profile' + ITEM_KIND = items.ImageBase + + def _createLines(self, parent): + vline = ProfileImageVerticalLineROI(parent=parent) + hline = ProfileImageHorizontalLineROI(parent=parent) + return hline, vline + + def setProfileMethod(self, method): + """ + :param str method: method to compute the profile. Can be 'mean' or 'sum' + """ + hline, vline = self._getLines() + hline.setProfileMethod(method) + vline.setProfileMethod(method) + self.invalidateProperties() + + def getProfileMethod(self): + hline, _vline = self._getLines() + return hline.getProfileMethod() + + def setProfileLineWidth(self, width): + hline, vline = self._getLines() + hline.setProfileLineWidth(width) + vline.setProfileLineWidth(width) + self.invalidateProperties() + + def getProfileLineWidth(self): + hline, _vline = self._getLines() + return hline.getProfileLineWidth() + + +class _DefaultScatterProfileRoiMixIn(core.ProfileRoiMixIn): + """Provide common behavior for silx default scatter profile ROI. + """ + + ITEM_KIND = items.Scatter + + def __init__(self, parent=None): + core.ProfileRoiMixIn.__init__(self, parent=parent) + self.__nPoints = 1024 + self.sigRegionChanged.connect(self.__regionChanged) + + def __regionChanged(self): + self.invalidateProfile() + + # Number of points + + def getNPoints(self): + """Returns the number of points of the profiles + + :rtype: int + """ + return self.__nPoints + + def setNPoints(self, npoints): + """Set the number of points of the profiles + + :param int npoints: + """ + npoints = int(npoints) + if npoints < 1: + raise ValueError("Unsupported number of points: %d" % npoints) + elif npoints != self.__nPoints: + self.__nPoints = npoints + self.invalidateProperties() + self.invalidateProfile() + + def _computeProfile(self, scatter, x0, y0, x1, y1): + """Compute corresponding profile + + :param float x0: Profile start point X coord + :param float y0: Profile start point Y coord + :param float x1: Profile end point X coord + :param float y1: Profile end point Y coord + :return: (points, values) profile data or None + """ + future = scatter._getInterpolator() + try: + interpolator = future.result() + except CancelledError: + return None + if interpolator is None: + return None # Cannot init an interpolator + + nPoints = self.getNPoints() + points = numpy.transpose(( + numpy.linspace(x0, x1, nPoints, endpoint=True), + numpy.linspace(y0, y1, nPoints, endpoint=True))) + + values = interpolator(points) + + if not numpy.any(numpy.isfinite(values)): + return None # Profile outside convex hull + + return points, values + + def computeProfile(self, item): + """Update profile according to current ROI""" + if not isinstance(item, items.Scatter): + raise TypeError("Unexpected class %s" % type(item)) + + # Get end points + if isinstance(self, roi_items.LineROI): + points = self.getEndPoints() + x0, y0 = points[0] + x1, y1 = points[1] + elif isinstance(self, (roi_items.VerticalLineROI, roi_items.HorizontalLineROI)): + profileManager = self.getProfileManager() + plot = profileManager.getPlotWidget() + + if isinstance(self, roi_items.HorizontalLineROI): + x0, x1 = plot.getXAxis().getLimits() + y0 = y1 = self.getPosition() + + elif isinstance(self, roi_items.VerticalLineROI): + x0 = x1 = self.getPosition() + y0, y1 = plot.getYAxis().getLimits() + else: + raise RuntimeError('Unsupported ROI for profile: {}'.format(self.__class__)) + + if x1 < x0 or (x1 == x0 and y1 < y0): + # Invert points + x0, y0, x1, y1 = x1, y1, x0, y0 + + profile = self._computeProfile(item, x0, y0, x1, y1) + if profile is None: + return None + + title = _lineProfileTitle(x0, y0, x1, y1) + points = profile[0] + values = profile[1] + + if (numpy.abs(points[-1, 0] - points[0, 0]) > + numpy.abs(points[-1, 1] - points[0, 1])): + xProfile = points[:, 0] + xLabel = '{xlabel}' + else: + xProfile = points[:, 1] + xLabel = '{ylabel}' + + # Use the axis names from the original + profileManager = self.getProfileManager() + plot = profileManager.getPlotWidget() + title = _relabelAxes(plot, title) + xLabel = _relabelAxes(plot, xLabel) + + data = core.CurveProfileData( + coords=xProfile, + profile=values, + title=title, + xLabel=xLabel, + yLabel='Profile', + ) + return data + + +class ProfileScatterHorizontalLineROI(roi_items.HorizontalLineROI, + _DefaultScatterProfileRoiMixIn): + """ROI for an horizontal profile at a location of a scatter""" + + ICON = 'shape-horizontal' + NAME = 'horizontal line profile' + + def __init__(self, parent=None): + roi_items.HorizontalLineROI.__init__(self, parent=parent) + _DefaultScatterProfileRoiMixIn.__init__(self, parent=parent) + + +class ProfileScatterVerticalLineROI(roi_items.VerticalLineROI, + _DefaultScatterProfileRoiMixIn): + """ROI for an horizontal profile at a location of a scatter""" + + ICON = 'shape-vertical' + NAME = 'vertical line profile' + + def __init__(self, parent=None): + roi_items.VerticalLineROI.__init__(self, parent=parent) + _DefaultScatterProfileRoiMixIn.__init__(self, parent=parent) + + +class ProfileScatterLineROI(roi_items.LineROI, + _DefaultScatterProfileRoiMixIn): + """ROI for an horizontal profile at a location of a scatter""" + + ICON = 'shape-diagonal' + NAME = 'line profile' + + def __init__(self, parent=None): + roi_items.LineROI.__init__(self, parent=parent) + _DefaultScatterProfileRoiMixIn.__init__(self, parent=parent) + + +class ProfileScatterCrossROI(_ProfileCrossROI): + """ROI to manage a cross of profiles for scatters. + """ + + ICON = 'shape-cross' + NAME = 'cross profile' + ITEM_KIND = items.Scatter + + def _createLines(self, parent): + vline = ProfileScatterVerticalLineROI(parent=parent) + hline = ProfileScatterHorizontalLineROI(parent=parent) + return hline, vline + + def getNPoints(self): + """Returns the number of points of the profiles + + :rtype: int + """ + hline, _vline = self._getLines() + return hline.getNPoints() + + def setNPoints(self, npoints): + """Set the number of points of the profiles + + :param int npoints: + """ + hline, vline = self._getLines() + hline.setNPoints(npoints) + vline.setNPoints(npoints) + self.invalidateProperties() + + +class _DefaultScatterProfileSliceRoiMixIn(core.ProfileRoiMixIn): + """Default ROI to allow to slice in the scatter data.""" + + ITEM_KIND = items.Scatter + + def __init__(self, parent=None): + core.ProfileRoiMixIn.__init__(self, parent=parent) + self.__area = _SliceProfileArea(self) + self.addItem(self.__area) + self.sigRegionChanged.connect(self._regionChanged) + self.sigPlotItemChanged.connect(self._updateArea) + + def _regionChanged(self): + self.invalidateProfile() + self._updateArea() + + def _updateArea(self): + plotItem = self.getPlotItem() + if plotItem is None: + self.setLineStyle("-") + else: + self.setLineStyle("--") + + def _getSlice(self, item): + position = self.getPosition() + bounds = item.getCurrentVisualizationParameter(items.Scatter.VisualizationParameter.GRID_BOUNDS) + if isinstance(self, roi_items.HorizontalLineROI): + axis = 1 + elif isinstance(self, roi_items.VerticalLineROI): + axis = 0 + else: + assert False + if bounds is None or position < bounds[0][axis] or position > bounds[1][axis]: + # ROI outside of the scatter bound + return None + + major_order = item.getCurrentVisualizationParameter(items.Scatter.VisualizationParameter.GRID_MAJOR_ORDER) + assert major_order == 'row' + max_grid_yy, max_grid_xx = item.getCurrentVisualizationParameter(items.Scatter.VisualizationParameter.GRID_SHAPE) + + xx, yy, _values, _xx_error, _yy_error = item.getData(copy=False) + if isinstance(self, roi_items.HorizontalLineROI): + axis = yy + max_grid_first = max_grid_yy + max_grid_second = max_grid_xx + major_axis = major_order == 'column' + elif isinstance(self, roi_items.VerticalLineROI): + axis = xx + max_grid_first = max_grid_xx + max_grid_second = max_grid_yy + major_axis = major_order == 'row' + else: + assert False + + def argnearest(array, value): + array = numpy.abs(array - value) + return numpy.argmin(array) + + if major_axis: + # slice in the middle of the scatter + start = max_grid_second // 2 * max_grid_first + vslice = axis[start:start + max_grid_second] + index = argnearest(vslice, position) + slicing = slice(index, None, max_grid_first) + else: + # slice in the middle of the scatter + vslice = axis[max_grid_second // 2::max_grid_second] + index = argnearest(vslice, position) + start = index * max_grid_second + slicing = slice(start, start + max_grid_second) + + return slicing + + def computeProfile(self, item): + if not isinstance(item, items.Scatter): + raise TypeError("Unsupported %s item" % type(item)) + + slicing = self._getSlice(item) + if slicing is None: + # ROI out of bounds + return None + + _xx, _yy, values, _xx_error, _yy_error = item.getData(copy=False) + profile = values[slicing] + + if isinstance(self, roi_items.HorizontalLineROI): + title = "Horizontal slice" + xLabel = "{xlabel} index" + elif isinstance(self, roi_items.VerticalLineROI): + title = "Vertical slice" + xLabel = "{ylabel} index" + else: + assert False + + # Use the axis names from the original plot + profileManager = self.getProfileManager() + plot = profileManager.getPlotWidget() + xLabel = _relabelAxes(plot, xLabel) + + data = core.CurveProfileData( + coords=numpy.arange(len(profile)), + profile=profile, + title=title, + xLabel=xLabel, + yLabel="Profile", + ) + return data + + +class ProfileScatterHorizontalSliceROI(roi_items.HorizontalLineROI, + _DefaultScatterProfileSliceRoiMixIn): + """ROI for an horizontal profile at a location of a scatter + using data slicing. + """ + + ICON = 'slice-horizontal' + NAME = 'horizontal data slice profile' + + def __init__(self, parent=None): + roi_items.HorizontalLineROI.__init__(self, parent=parent) + _DefaultScatterProfileSliceRoiMixIn.__init__(self, parent=parent) + + +class ProfileScatterVerticalSliceROI(roi_items.VerticalLineROI, + _DefaultScatterProfileSliceRoiMixIn): + """ROI for a vertical profile at a location of a scatter + using data slicing. + """ + + ICON = 'slice-vertical' + NAME = 'vertical data slice profile' + + def __init__(self, parent=None): + roi_items.VerticalLineROI.__init__(self, parent=parent) + _DefaultScatterProfileSliceRoiMixIn.__init__(self, parent=parent) + + +class ProfileScatterCrossSliceROI(_ProfileCrossROI): + """ROI to manage a cross of slicing profiles on scatters. + """ + + ICON = 'slice-cross' + NAME = 'cross data slice profile' + ITEM_KIND = items.Scatter + + def _createLines(self, parent): + vline = ProfileScatterVerticalSliceROI(parent=parent) + hline = ProfileScatterHorizontalSliceROI(parent=parent) + return hline, vline + + +class _DefaultImageStackProfileRoiMixIn(_DefaultImageProfileRoiMixIn): + + ITEM_KIND = items.ImageStack + + def __init__(self, parent=None): + super(_DefaultImageStackProfileRoiMixIn, self).__init__(parent=parent) + self.__profileType = "1D" + """Kind of profile""" + + def getProfileType(self): + return self.__profileType + + def setProfileType(self, kind): + assert kind in ["1D", "2D"] + if self.__profileType == kind: + return + self.__profileType = kind + self.invalidateProperties() + self.invalidateProfile() + + def computeProfile(self, item): + if not isinstance(item, items.ImageStack): + raise TypeError("Unexpected class %s" % type(item)) + + kind = self.getProfileType() + if kind == "1D": + result = _DefaultImageProfileRoiMixIn.computeProfile(self, item) + # z = item.getStackPosition() + return result + + assert kind == "2D" + + def createProfile2(currentData): + coords, profile, _area, profileName, xLabel = core.createProfile( + roiInfo=self._getRoiInfo(), + currentData=currentData, + origin=origin, + scale=scale, + lineWidth=self.getProfileLineWidth(), + method=method) + return coords, profile, profileName, xLabel + + currentData = numpy.array(item.getStackData(copy=False)) + origin = item.getOrigin() + scale = item.getScale() + colormap = item.getColormap() + method = self.getProfileMethod() + + coords, profile, profileName, xLabel = createProfile2(currentData) + + data = core.ImageProfileData( + coords=coords, + profile=profile, + title=profileName, + xLabel=xLabel, + yLabel="Profile", + colormap=colormap, + ) + return data + + +class ProfileImageStackHorizontalLineROI(roi_items.HorizontalLineROI, + _DefaultImageStackProfileRoiMixIn): + """ROI for an horizontal profile at a location of a stack of images""" + + ICON = 'shape-horizontal' + NAME = 'horizontal line profile' + + def __init__(self, parent=None): + roi_items.HorizontalLineROI.__init__(self, parent=parent) + _DefaultImageStackProfileRoiMixIn.__init__(self, parent=parent) + + +class ProfileImageStackVerticalLineROI(roi_items.VerticalLineROI, + _DefaultImageStackProfileRoiMixIn): + """ROI for an vertical profile at a location of a stack of images""" + + ICON = 'shape-vertical' + NAME = 'vertical line profile' + + def __init__(self, parent=None): + roi_items.VerticalLineROI.__init__(self, parent=parent) + _DefaultImageStackProfileRoiMixIn.__init__(self, parent=parent) + + +class ProfileImageStackLineROI(roi_items.LineROI, + _DefaultImageStackProfileRoiMixIn): + """ROI for an vertical profile at a location of a stack of images""" + + ICON = 'shape-diagonal' + NAME = 'line profile' + + def __init__(self, parent=None): + roi_items.LineROI.__init__(self, parent=parent) + _DefaultImageStackProfileRoiMixIn.__init__(self, parent=parent) + + +class ProfileImageStackCrossROI(ProfileImageCrossROI): + """ROI for an vertical profile at a location of a stack of images""" + + ICON = 'shape-cross' + NAME = 'cross profile' + ITEM_KIND = items.ImageStack + + def _createLines(self, parent): + vline = ProfileImageStackVerticalLineROI(parent=parent) + hline = ProfileImageStackHorizontalLineROI(parent=parent) + return hline, vline + + def getProfileType(self): + hline, _vline = self._getLines() + return hline.getProfileType() + + def setProfileType(self, kind): + hline, vline = self._getLines() + hline.setProfileType(kind) + vline.setProfileType(kind) + self.invalidateProperties() diff --git a/src/silx/gui/plot/tools/profile/toolbar.py b/src/silx/gui/plot/tools/profile/toolbar.py new file mode 100644 index 0000000..4a9a195 --- /dev/null +++ b/src/silx/gui/plot/tools/profile/toolbar.py @@ -0,0 +1,172 @@ +# coding: utf-8 +# /*########################################################################## +# +# Copyright (c) 2018-2019 European Synchrotron Radiation Facility +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. +# +# ###########################################################################*/ +"""This module provides tool bar helper. +""" + +__authors__ = ["V. Valls"] +__license__ = "MIT" +__date__ = "28/06/2018" + + +import logging +import weakref + +from silx.gui import qt +from silx.gui.widgets.MultiModeAction import MultiModeAction +from . import manager +from .. import roi as roi_mdl +from silx.gui.plot import items + + +_logger = logging.getLogger(__name__) + + +class ProfileToolBar(qt.QToolBar): + """Tool bar to provide profile for a plot. + + It is an helper class. For a dedicated application it would be better to + use an own tool bar in order in order have more flexibility. + """ + def __init__(self, parent=None, plot=None): + super(ProfileToolBar, self).__init__(parent=parent) + self.__scheme = None + self.__manager = None + self.__plot = weakref.ref(plot) + self.__multiAction = None + + def getPlotWidget(self): + """The :class:`~silx.gui.plot.PlotWidget` associated to the toolbar. + + :rtype: Union[~silx.gui.plot.PlotWidget,None] + """ + if self.__plot is None: + return None + plot = self.__plot() + if self.__plot is None: + self.__plot = None + return plot + + def setScheme(self, scheme): + """Initialize the tool bar using a configuration scheme. + + It have to be done once and only once. + + :param str scheme: One of "scatter", "image", "imagestack" + """ + assert self.__scheme is None + self.__scheme = scheme + + plot = self.getPlotWidget() + self.__manager = manager.ProfileManager(self, plot) + + if scheme == "image": + self.__manager.setItemType(image=True) + self.__manager.setActiveItemTracking(True) + + multiAction = MultiModeAction(self) + self.addAction(multiAction) + for action in self.__manager.createImageActions(self): + multiAction.addAction(action) + self.__multiAction = multiAction + + cleanAction = self.__manager.createClearAction(self) + self.addAction(cleanAction) + editorAction = self.__manager.createEditorAction(self) + self.addAction(editorAction) + + plot.sigActiveImageChanged.connect(self._activeImageChanged) + self._activeImageChanged() + + elif scheme == "scatter": + self.__manager.setItemType(scatter=True) + self.__manager.setActiveItemTracking(True) + + multiAction = MultiModeAction(self) + self.addAction(multiAction) + for action in self.__manager.createScatterActions(self): + multiAction.addAction(action) + for action in self.__manager.createScatterSliceActions(self): + multiAction.addAction(action) + self.__multiAction = multiAction + + cleanAction = self.__manager.createClearAction(self) + self.addAction(cleanAction) + editorAction = self.__manager.createEditorAction(self) + self.addAction(editorAction) + + plot.sigActiveScatterChanged.connect(self._activeScatterChanged) + self._activeScatterChanged() + + elif scheme == "imagestack": + self.__manager.setItemType(image=True) + self.__manager.setActiveItemTracking(True) + + multiAction = MultiModeAction(self) + self.addAction(multiAction) + for action in self.__manager.createImageStackActions(self): + multiAction.addAction(action) + self.__multiAction = multiAction + + cleanAction = self.__manager.createClearAction(self) + self.addAction(cleanAction) + editorAction = self.__manager.createEditorAction(self) + self.addAction(editorAction) + + plot.sigActiveImageChanged.connect(self._activeImageChanged) + self._activeImageChanged() + + else: + raise ValueError("Toolbar scheme %s unsupported" % scheme) + + def _setRoiActionEnabled(self, itemKind, enabled): + for action in self.__multiAction.getMenu().actions(): + if not isinstance(action, roi_mdl.CreateRoiModeAction): + continue + roiClass = action.getRoiClass() + if issubclass(itemKind, roiClass.ITEM_KIND): + action.setEnabled(enabled) + + def _activeImageChanged(self, previous=None, legend=None): + """Handle active image change to toggle actions""" + if legend is None: + self._setRoiActionEnabled(items.ImageStack, False) + self._setRoiActionEnabled(items.ImageBase, False) + else: + plot = self.getPlotWidget() + image = plot.getActiveImage() + # Disable for empty image + enabled = image.getData(copy=False).size > 0 + self._setRoiActionEnabled(type(image), enabled) + + def _activeScatterChanged(self, previous=None, legend=None): + """Handle active scatter change to toggle actions""" + if legend is None: + self._setRoiActionEnabled(items.Scatter, False) + else: + plot = self.getPlotWidget() + scatter = plot.getActiveScatter() + # Disable for empty image + enabled = scatter.getValueData(copy=False).size > 0 + self._setRoiActionEnabled(type(scatter), enabled) diff --git a/src/silx/gui/plot/tools/roi.py b/src/silx/gui/plot/tools/roi.py new file mode 100644 index 0000000..e4be6a7 --- /dev/null +++ b/src/silx/gui/plot/tools/roi.py @@ -0,0 +1,1417 @@ +# coding: utf-8 +# /*########################################################################## +# +# Copyright (c) 2018-2021 European Synchrotron Radiation Facility +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. +# +# ###########################################################################*/ +"""This module provides ROI interaction for :class:`~silx.gui.plot.PlotWidget`. +""" + +__authors__ = ["T. Vincent"] +__license__ = "MIT" +__date__ = "28/06/2018" + + +import enum +import logging +import time +import weakref +import functools + +import numpy + +from ... import qt, icons +from ...utils import blockSignals +from ...utils import LockReentrant +from .. import PlotWidget +from ..items import roi as roi_items + +from ...colors import rgba + + +logger = logging.getLogger(__name__) + + +class CreateRoiModeAction(qt.QAction): + """ + This action is a plot mode which allows to create new ROIs using a ROI + manager. + + A ROI is created using a specific `roiClass`. `initRoi` and `finalizeRoi` + can be inherited to custom the ROI initialization. + + :param class roiClass: The ROI class which will be created by this action. + :param qt.QObject parent: The action parent + :param RegionOfInterestManager roiManager: The ROI manager + """ + + def __init__(self, parent, roiManager, roiClass): + assert roiManager is not None + assert roiClass is not None + qt.QAction.__init__(self, parent=parent) + self._roiManager = weakref.ref(roiManager) + self._roiClass = roiClass + self._singleShot = False + self._initAction() + self.triggered[bool].connect(self._actionTriggered) + + def _initAction(self): + """Default initialization of the action""" + roiClass = self._roiClass + + name = None + iconName = None + if hasattr(roiClass, "NAME"): + name = roiClass.NAME + if hasattr(roiClass, "ICON"): + iconName = roiClass.ICON + + if iconName is None: + iconName = "add-shape-unknown" + if name is None: + name = roiClass.__name__ + text = 'Add %s' % name + self.setIcon(icons.getQIcon(iconName)) + self.setText(text) + self.setCheckable(True) + self.setToolTip(text) + + def getRoiClass(self): + """Return the ROI class used by this action to create ROIs""" + return self._roiClass + + def getRoiManager(self): + return self._roiManager() + + def setSingleShot(self, singleShot): + """Set it to True to deactivate the action after the first creation + of a ROI. + + :param bool singleShot: New single short state + """ + self._singleShot = singleShot + + def getSingleShot(self): + """If True, after the first creation of a ROI with this mode, + the mode is deactivated. + + :rtype: bool + """ + return self._singleShot + + def _actionTriggered(self, checked): + """Handle mode actions being checked by the user + + :param bool checked: + :param str kind: Corresponding shape kind + """ + roiManager = self.getRoiManager() + if roiManager is None: + return + + if checked: + roiManager.start(self._roiClass, self) + self.__interactiveModeStarted(roiManager) + else: + source = roiManager.getInteractionSource() + if source is self: + roiManager.stop() + + def __interactiveModeStarted(self, roiManager): + roiManager.sigInteractiveRoiCreated.connect(self.initRoi) + roiManager.sigInteractiveRoiFinalized.connect(self.__finalizeRoi) + roiManager.sigInteractiveModeFinished.connect(self.__interactiveModeFinished) + + def __interactiveModeFinished(self): + roiManager = self.getRoiManager() + if roiManager is not None: + roiManager.sigInteractiveRoiCreated.disconnect(self.initRoi) + roiManager.sigInteractiveRoiFinalized.disconnect(self.__finalizeRoi) + roiManager.sigInteractiveModeFinished.disconnect(self.__interactiveModeFinished) + self.setChecked(False) + + def initRoi(self, roi): + """Inherit it to custom the new ROI at it's creation during the + interaction.""" + pass + + def __finalizeRoi(self, roi): + self.finalizeRoi(roi) + if self._singleShot: + roiManager = self.getRoiManager() + if roiManager is not None: + roiManager.stop() + + def finalizeRoi(self, roi): + """Inherit it to custom the new ROI after it's creation when the + interaction is finalized.""" + pass + + +class RoiModeSelector(qt.QWidget): + def __init__(self, parent=None): + super(RoiModeSelector, self).__init__(parent=parent) + self.__roi = None + self.__reentrant = LockReentrant() + + layout = qt.QHBoxLayout(self) + if isinstance(parent, qt.QMenu): + margins = layout.contentsMargins() + layout.setContentsMargins(margins.left(), 0, margins.right(), 0) + else: + layout.setContentsMargins(0, 0, 0, 0) + + self._label = qt.QLabel(self) + self._label.setText("Mode:") + self._label.setToolTip("Select a specific interaction to edit the ROI") + self._combo = qt.QComboBox(self) + self._combo.currentIndexChanged.connect(self._modeSelected) + layout.addWidget(self._label) + layout.addWidget(self._combo) + self._updateAvailableModes() + + def getRoi(self): + """Returns the edited ROI. + + :rtype: roi_items.RegionOfInterest + """ + return self.__roi + + def setRoi(self, roi): + """Returns the edited ROI. + + :rtype: roi_items.RegionOfInterest + """ + if self.__roi is roi: + return + if not isinstance(roi, roi_items.InteractionModeMixIn): + self.__roi = None + self._updateAvailableModes() + return + + if self.__roi is not None: + self.__roi.sigInteractionModeChanged.disconnect(self._modeChanged) + self.__roi = roi + if self.__roi is not None: + self.__roi.sigInteractionModeChanged.connect(self._modeChanged) + self._updateAvailableModes() + + def isEmpty(self): + return not self._label.isVisibleTo(self) + + def _updateAvailableModes(self): + roi = self.getRoi() + if isinstance(roi, roi_items.InteractionModeMixIn): + modes = roi.availableInteractionModes() + else: + modes = [] + if len(modes) <= 1: + self._label.setVisible(False) + self._combo.setVisible(False) + else: + self._label.setVisible(True) + self._combo.setVisible(True) + with blockSignals(self._combo): + self._combo.clear() + for im, m in enumerate(modes): + self._combo.addItem(m.label, m) + self._combo.setItemData(im, m.description, qt.Qt.ToolTipRole) + mode = roi.getInteractionMode() + self._modeChanged(mode) + index = modes.index(mode) + self._combo.setCurrentIndex(index) + + def _modeChanged(self, mode): + """Triggered when the ROI interaction mode was changed externally""" + if self.__reentrant.locked(): + # This event was initialised by the widget + return + roi = self.__roi + modes = roi.availableInteractionModes() + index = modes.index(mode) + with blockSignals(self._combo): + self._combo.setCurrentIndex(index) + + def _modeSelected(self): + """Triggered when the ROI interaction mode was selected in the widget""" + index = self._combo.currentIndex() + if index == -1: + return + roi = self.getRoi() + if roi is not None: + mode = self._combo.itemData(index, qt.Qt.UserRole) + with self.__reentrant: + roi.setInteractionMode(mode) + + +class RoiModeSelectorAction(qt.QWidgetAction): + """Display the selected mode of a ROI and allow to change it""" + + def __init__(self, parent=None): + super(RoiModeSelectorAction, self).__init__(parent) + self.__roiManager = None + + def createWidget(self, parent): + """Inherit the method to create a new widget""" + widget = RoiModeSelector(parent) + manager = self.__roiManager + if manager is not None: + roi = manager.getCurrentRoi() + widget.setRoi(roi) + self.setVisible(not widget.isEmpty()) + return widget + + def deleteWidget(self, widget): + """Inherit the method to delete a widget""" + widget.setRoi(None) + return qt.QWidgetAction.deleteWidget(self, widget) + + def setRoiManager(self, roiManager): + """ + Connect this action to a ROI manager. + + :param RegionOfInterestManager roiManager: A ROI manager + """ + if self.__roiManager is roiManager: + return + if self.__roiManager is not None: + self.__roiManager.sigCurrentRoiChanged.disconnect(self.__currentRoiChanged) + self.__roiManager = roiManager + if self.__roiManager is not None: + self.__roiManager.sigCurrentRoiChanged.connect(self.__currentRoiChanged) + self.__currentRoiChanged(roiManager.getCurrentRoi()) + + def __currentRoiChanged(self, roi): + """Handle changes of the selected ROI""" + self.setRoi(roi) + + def setRoi(self, roi): + """Set a profile ROI to edit. + + :param ProfileRoiMixIn roi: A profile ROI + """ + widget = None + for widget in self.createdWidgets(): + widget.setRoi(roi) + if widget is not None: + self.setVisible(not widget.isEmpty()) + + +class RegionOfInterestManager(qt.QObject): + """Class handling ROI interaction on a PlotWidget. + + It supports the multiple ROIs: points, rectangles, polygons, + lines, horizontal and vertical lines. + + See ``plotInteractiveImageROI.py`` sample code (:ref:`sample-code`). + + :param silx.gui.plot.PlotWidget parent: + The plot widget in which to control the ROIs. + """ + + sigRoiAdded = qt.Signal(roi_items.RegionOfInterest) + """Signal emitted when a new ROI has been added. + + It provides the newly add :class:`RegionOfInterest` object. + """ + + sigRoiAboutToBeRemoved = qt.Signal(roi_items.RegionOfInterest) + """Signal emitted just before a ROI is removed. + + It provides the :class:`RegionOfInterest` object that is about to be removed. + """ + + sigRoiChanged = qt.Signal() + """Signal emitted whenever the ROIs have changed.""" + + sigCurrentRoiChanged = qt.Signal(object) + """Signal emitted whenever a ROI is selected.""" + + sigInteractiveModeStarted = qt.Signal(object) + """Signal emitted when switching to ROI drawing interactive mode. + + It provides the class of the ROI which will be created by the interactive + mode. + """ + + sigInteractiveRoiCreated = qt.Signal(object) + """Signal emitted when a ROI is created during the interaction. + The interaction is still incomplete and can be aborted. + + It provides the ROI object which was just been created. + """ + + sigInteractiveRoiFinalized = qt.Signal(object) + """Signal emitted when a ROI creation is complet. + + It provides the ROI object which was just been created. + """ + + sigInteractiveModeFinished = qt.Signal() + """Signal emitted when leaving interactive ROI drawing mode. + """ + + ROI_CLASSES = ( + roi_items.PointROI, + roi_items.CrossROI, + roi_items.RectangleROI, + roi_items.CircleROI, + roi_items.EllipseROI, + roi_items.PolygonROI, + roi_items.LineROI, + roi_items.HorizontalLineROI, + roi_items.VerticalLineROI, + roi_items.ArcROI, + roi_items.HorizontalRangeROI, + ) + + def __init__(self, parent): + assert isinstance(parent, PlotWidget) + super(RegionOfInterestManager, self).__init__(parent) + self._rois = [] # List of ROIs + self._drawnROI = None # New ROI being currently drawn + + self._roiClass = None + self._source = None + self._color = rgba('red') + + self._label = "__RegionOfInterestManager__%d" % id(self) + + self._currentRoi = None + """Hold currently selected ROI""" + + self._eventLoop = None + + self._modeActions = {} + + parent.sigPlotSignal.connect(self._plotSignals) + + parent.sigInteractiveModeChanged.connect( + self._plotInteractiveModeChanged) + + parent.sigItemRemoved.connect(self._itemRemoved) + + parent._sigDefaultContextMenu.connect(self._feedContextMenu) + + @classmethod + def getSupportedRoiClasses(cls): + """Returns the default available ROI classes + + :rtype: List[class] + """ + return tuple(cls.ROI_CLASSES) + + # Associated QActions + + def getInteractionModeAction(self, roiClass): + """Returns the QAction corresponding to a kind of ROI + + The QAction allows to enable the corresponding drawing + interactive mode. + + :param class roiClass: The ROI class which will be created by this action. + :rtype: QAction + :raise ValueError: If kind is not supported + """ + if not issubclass(roiClass, roi_items.RegionOfInterest): + raise ValueError('Unsupported ROI class %s' % roiClass) + + action = self._modeActions.get(roiClass, None) + if action is None: # Lazy-loading + action = CreateRoiModeAction(self, self, roiClass) + self._modeActions[roiClass] = action + return action + + # PlotWidget eventFilter and listeners + + def _plotInteractiveModeChanged(self, source): + """Handle change of interactive mode in the plot""" + if source is not self: + self.__roiInteractiveModeEnded() + + def _getRoiFromItem(self, item): + """Returns the ROI which own this item, else None + if this manager do not have knowledge of this ROI.""" + for roi in self._rois: + if isinstance(roi, roi_items.RegionOfInterest): + for child in roi.getItems(): + if child is item: + return roi + return None + + def _itemRemoved(self, item): + """Called after an item was removed from the plot.""" + if not hasattr(item, "_roiGroup"): + # Early break to avoid to use _getRoiFromItem + # And to avoid reentrant signal when the ROI remove the item itself + return + roi = self._getRoiFromItem(item) + if roi is not None: + self.removeRoi(roi) + + # Handle ROI interaction + + def _handleInteraction(self, event): + """Handle mouse interaction for ROI addition""" + roiClass = self.getCurrentInteractionModeRoiClass() + if roiClass is None: + return # Should not happen + + kind = roiClass.getFirstInteractionShape() + if kind == 'point': + if event['event'] == 'mouseClicked' and event['button'] == 'left': + points = numpy.array([(event['x'], event['y'])], + dtype=numpy.float64) + # Not an interactive creation + roi = self._createInteractiveRoi(roiClass, points=points) + roi.creationFinalized() + self.sigInteractiveRoiFinalized.emit(roi) + else: # other shapes + if (event['event'] in ('drawingProgress', 'drawingFinished') and + event['parameters']['label'] == self._label): + points = numpy.array((event['xdata'], event['ydata']), + dtype=numpy.float64).T + + if self._drawnROI is None: # Create new ROI + # NOTE: Set something before createRoi, so isDrawing is True + self._drawnROI = object() + self._drawnROI = self._createInteractiveRoi(roiClass, points=points) + else: + self._drawnROI.setFirstShapePoints(points) + + if event['event'] == 'drawingFinished': + if kind == 'polygon' and len(points) > 1: + self._drawnROI.setFirstShapePoints(points[:-1]) + roi = self._drawnROI + self._drawnROI = None # Stop drawing + roi.creationFinalized() + self.sigInteractiveRoiFinalized.emit(roi) + + # RegionOfInterest selection + + def __getRoiFromMarker(self, marker): + """Returns a ROI from a marker, else None""" + # This should be speed up + for roi in self._rois: + if isinstance(roi, roi_items.HandleBasedROI): + for m in roi.getHandles(): + if m is marker: + return roi + else: + for m in roi.getItems(): + if m is marker: + return roi + return None + + def setCurrentRoi(self, roi): + """Set the currently selected ROI, and emit a signal. + + :param Union[RegionOfInterest,None] roi: The ROI to select + """ + if self._currentRoi is roi: + return + if roi is not None: + # Note: Fixed range to avoid infinite loops + for _ in range(10): + target = roi.getFocusProxy() + if target is None: + break + roi = target + else: + raise RuntimeError("Max selection proxy depth (10) reached.") + + if self._currentRoi is not None: + self._currentRoi.setHighlighted(False) + self._currentRoi = roi + if self._currentRoi is not None: + self._currentRoi.setHighlighted(True) + self.sigCurrentRoiChanged.emit(roi) + + def getCurrentRoi(self): + """Returns the currently selected ROI, else None. + + :rtype: Union[RegionOfInterest,None] + """ + return self._currentRoi + + def _plotSignals(self, event): + """Handle mouse interaction for ROI addition""" + clicked = False + roi = None + if event["event"] in ("markerClicked", "markerMoving"): + plot = self.parent() + legend = event["label"] + marker = plot._getMarker(legend=legend) + roi = self.__getRoiFromMarker(marker) + elif event["event"] == "mouseClicked" and event["button"] == "left": + # Marker click is only for dnd + # This also can click on a marker + clicked = True + plot = self.parent() + marker = plot._getMarkerAt(event["xpixel"], event["ypixel"]) + roi = self.__getRoiFromMarker(marker) + else: + return + + if roi not in self._rois: + # The ROI is not own by this manager + return + + if roi is not None: + currentRoi = self.getCurrentRoi() + if currentRoi is roi: + if clicked: + self.__updateMode(roi) + elif roi.isSelectable(): + self.setCurrentRoi(roi) + else: + self.setCurrentRoi(None) + + def __updateMode(self, roi): + if isinstance(roi, roi_items.InteractionModeMixIn): + available = roi.availableInteractionModes() + mode = roi.getInteractionMode() + imode = available.index(mode) + mode = available[(imode + 1) % len(available)] + roi.setInteractionMode(mode) + + def _feedContextMenu(self, menu): + """Called when the default plot context menu is about to be displayed""" + roi = self.getCurrentRoi() + if roi is not None: + if roi.isEditable(): + # Filter by data position + # FIXME: It would be better to use GUI coords for it + plot = self.parent() + pos = plot.getWidgetHandle().mapFromGlobal(qt.QCursor.pos()) + data = plot.pixelToData(pos.x(), pos.y()) + if roi.contains(data): + if isinstance(roi, roi_items.InteractionModeMixIn): + self._contextMenuForInteractionMode(menu, roi) + + removeAction = qt.QAction(menu) + removeAction.setText("Remove %s" % roi.getName()) + callback = functools.partial(self.removeRoi, roi) + removeAction.triggered.connect(callback) + menu.addAction(removeAction) + + def _contextMenuForInteractionMode(self, menu, roi): + availableModes = roi.availableInteractionModes() + currentMode = roi.getInteractionMode() + submenu = qt.QMenu(menu) + modeGroup = qt.QActionGroup(menu) + modeGroup.setExclusive(True) + for mode in availableModes: + action = qt.QAction(menu) + action.setText(mode.label) + action.setToolTip(mode.description) + action.setCheckable(True) + if mode is currentMode: + action.setChecked(True) + else: + callback = functools.partial(roi.setInteractionMode, mode) + action.triggered.connect(callback) + modeGroup.addAction(action) + submenu.addAction(action) + submenu.setTitle("%s interaction mode" % roi.getName()) + menu.addMenu(submenu) + + # RegionOfInterest API + + def getRois(self): + """Returns the list of ROIs. + + It returns an empty tuple if there is currently no ROI. + + :return: Tuple of arrays of objects describing the ROIs + :rtype: List[RegionOfInterest] + """ + return tuple(self._rois) + + def clear(self): + """Reset current ROIs + + :return: True if ROIs were reset. + :rtype: bool + """ + if self.getRois(): # Something to reset + for roi in self._rois: + roi.sigRegionChanged.disconnect( + self._regionOfInterestChanged) + roi.setParent(None) + self._rois = [] + self._roisUpdated() + return True + + else: + return False + + def _regionOfInterestChanged(self, event=None): + """Handle ROI object changed""" + self.sigRoiChanged.emit() + + def _createInteractiveRoi(self, roiClass, points, label=None, index=None): + """Create a new ROI with interactive creation. + + :param class roiClass: The class of the ROI to create + :param numpy.ndarray points: The first shape used to create the ROI + :param str label: The label to display along with the ROI. + :param int index: The position where to insert the ROI. + By default it is appended to the end of the list. + :return: The created ROI object + :rtype: roi_items.RegionOfInterest + :raise RuntimeError: When ROI cannot be added because the maximum + number of ROIs has been reached. + """ + roi = roiClass(parent=None) + if label is not None: + roi.setName(str(label)) + roi.creationStarted() + roi.setFirstShapePoints(points) + + self.addRoi(roi, index) + if roi.isSelectable(): + self.setCurrentRoi(roi) + self.sigInteractiveRoiCreated.emit(roi) + return roi + + def containsRoi(self, roi): + """Returns true if the ROI is part of this manager. + + :param roi_items.RegionOfInterest roi: The ROI to add + :rtype: bool + """ + return roi in self._rois + + def addRoi(self, roi, index=None, useManagerColor=True): + """Add the ROI to the list of ROIs. + + :param roi_items.RegionOfInterest roi: The ROI to add + :param int index: The position where to insert the ROI, + By default it is appended to the end of the list of ROIs + :param bool useManagerColor: + Whether to set the ROI color to the default one of the manager or not. + (Default: True). + :raise RuntimeError: When ROI cannot be added because the maximum + number of ROIs has been reached. + """ + plot = self.parent() + if plot is None: + raise RuntimeError( + 'Cannot add ROI: PlotWidget no more available') + + roi.setParent(self) + + if useManagerColor: + roi.setColor(self.getColor()) + + roi.sigRegionChanged.connect(self._regionOfInterestChanged) + roi.sigItemChanged.connect(self._regionOfInterestChanged) + + if index is None: + self._rois.append(roi) + else: + self._rois.insert(index, roi) + self.sigRoiAdded.emit(roi) + self._roisUpdated() + + def removeRoi(self, roi): + """Remove a ROI from the list of ROIs. + + :param roi_items.RegionOfInterest roi: The ROI to remove + :raise ValueError: When ROI does not belong to this object + """ + if not (isinstance(roi, roi_items.RegionOfInterest) and + roi.parent() is self and + roi in self._rois): + raise ValueError( + 'RegionOfInterest does not belong to this instance') + + roi.sigAboutToBeRemoved.emit() + self.sigRoiAboutToBeRemoved.emit(roi) + + if roi is self._currentRoi: + self.setCurrentRoi(None) + + mustRestart = False + if roi is self._drawnROI: + self._drawnROI = None + mustRestart = True + self._rois.remove(roi) + roi.sigRegionChanged.disconnect(self._regionOfInterestChanged) + roi.sigItemChanged.disconnect(self._regionOfInterestChanged) + roi.setParent(None) + self._roisUpdated() + + if mustRestart: + self._restart() + + def _roisUpdated(self): + """Handle update of the ROI list""" + self.sigRoiChanged.emit() + + # RegionOfInterest parameters + + def getColor(self): + """Return the default color of created ROIs + + :rtype: QColor + """ + return qt.QColor.fromRgbF(*self._color) + + def setColor(self, color): + """Set the default color to use when creating ROIs. + + Existing ROIs are not affected. + + :param color: The color to use for displaying ROIs as + either a color name, a QColor, a list of uint8 or float in [0, 1]. + """ + self._color = rgba(color) + + # Control ROI + + def getCurrentInteractionModeRoiClass(self): + """Returns the current ROI class used by the interactive drawing mode. + + Returns None if the ROI manager is not in an interactive mode. + + :rtype: Union[class,None] + """ + return self._roiClass + + def getInteractionSource(self): + """Returns the object which have requested the ROI creation. + + Returns None if the ROI manager is not in an interactive mode. + + :rtype: Union[object,None] + """ + return self._source + + def isStarted(self): + """Returns True if an interactive ROI drawing mode is active. + + :rtype: bool + """ + return self._roiClass is not None + + def isDrawing(self): + """Returns True if an interactive ROI is drawing. + + :rtype: bool + """ + return self._drawnROI is not None + + def start(self, roiClass, source=None): + """Start an interactive ROI drawing mode. + + :param class roiClass: The ROI class to create. It have to inherite from + `roi_items.RegionOfInterest`. + :param object source: SOurce of the ROI interaction. + :return: True if interactive ROI drawing was started, False otherwise + :rtype: bool + :raise ValueError: If roiClass is not supported + """ + self.stop() + + if not issubclass(roiClass, roi_items.RegionOfInterest): + raise ValueError('Unsupported ROI class %s' % roiClass) + + plot = self.parent() + if plot is None: + return False + + self._roiClass = roiClass + self._source = source + + self._restart() + + plot.sigPlotSignal.connect(self._handleInteraction) + + self.sigInteractiveModeStarted.emit(roiClass) + + return True + + def _restart(self): + """Restart the plot interaction without changing the + source or the ROI class. + """ + roiClass = self._roiClass + plot = self.parent() + firstInteractionShapeKind = roiClass.getFirstInteractionShape() + + if firstInteractionShapeKind == 'point': + plot.setInteractiveMode(mode='select', source=self) + else: + if roiClass.showFirstInteractionShape(): + color = rgba(self.getColor()) + else: + color = None + plot.setInteractiveMode(mode='select-draw', + source=self, + shape=firstInteractionShapeKind, + color=color, + label=self._label) + + def __roiInteractiveModeEnded(self): + """Handle end of ROI draw interactive mode""" + if self.isStarted(): + self._roiClass = None + self._source = None + + if self._drawnROI is not None: + # Cancel ROI create + roi = self._drawnROI + self._drawnROI = None + self.removeRoi(roi) + + plot = self.parent() + if plot is not None: + plot.sigPlotSignal.disconnect(self._handleInteraction) + + self.sigInteractiveModeFinished.emit() + + def stop(self): + """Stop interactive ROI drawing mode. + + :return: True if an interactive ROI drawing mode was actually stopped + :rtype: bool + """ + if not self.isStarted(): + return False + + plot = self.parent() + if plot is not None: + # This leads to call __roiInteractiveModeEnded through + # interactive mode changed signal + plot.resetInteractiveMode() + else: # Fallback + self.__roiInteractiveModeEnded() + + return True + + def exec(self, roiClass): + """Block until :meth:`quit` is called. + + :param class kind: The class of the ROI which have to be created. + See `silx.gui.plot.items.roi`. + :return: The list of ROIs + :rtype: tuple + """ + self.start(roiClass) + + plot = self.parent() + plot.show() + plot.raise_() + + self._eventLoop = qt.QEventLoop() + self._eventLoop.exec() + self._eventLoop = None + + self.stop() + + rois = self.getRois() + self.clear() + return rois + + def exec_(self, roiClass): # Qt5-like compatibility + return self.exec(roiClass) + + def quit(self): + """Stop a blocking :meth:`exec` and call :meth:`stop`""" + if self._eventLoop is not None: + self._eventLoop.quit() + self._eventLoop = None + self.stop() + + +class InteractiveRegionOfInterestManager(RegionOfInterestManager): + """RegionOfInterestManager with features for use from interpreter. + + It is meant to be used through the :meth:`exec`. + It provides some messages to display in a status bar and + different modes to end blocking calls to :meth:`exec`. + + :param parent: See QObject + """ + + sigMessageChanged = qt.Signal(str) + """Signal emitted when a new message should be displayed to the user + + It provides the message as a str. + """ + + def __init__(self, parent): + super(InteractiveRegionOfInterestManager, self).__init__(parent) + self._maxROI = None + self.__timeoutEndTime = None + self.__message = '' + self.__validationMode = self.ValidationMode.ENTER + self.__execClass = None + + self.sigRoiAdded.connect(self.__added) + self.sigRoiAboutToBeRemoved.connect(self.__aboutToBeRemoved) + self.sigInteractiveModeStarted.connect(self.__started) + self.sigInteractiveModeFinished.connect(self.__finished) + + # Max ROI + + def getMaxRois(self): + """Returns the maximum number of ROIs or None if no limit. + + :rtype: Union[int,None] + """ + return self._maxROI + + def setMaxRois(self, max_): + """Set the maximum number of ROIs. + + :param Union[int,None] max_: The max limit or None for no limit. + :raise ValueError: If there is more ROIs than max value + """ + if max_ is not None: + max_ = int(max_) + if max_ <= 0: + raise ValueError('Max limit must be strictly positive') + + if len(self.getRois()) > max_: + raise ValueError( + 'Cannot set max limit: Already too many ROIs') + + self._maxROI = max_ + + def isMaxRois(self): + """Returns True if the maximum number of ROIs is reached. + + :rtype: bool + """ + max_ = self.getMaxRois() + return max_ is not None and len(self.getRois()) >= max_ + + # Validation mode + + @enum.unique + class ValidationMode(enum.Enum): + """Mode of validation to leave blocking :meth:`exec`""" + + AUTO = 'auto' + """Automatically ends the interactive mode once + the user terminates the last ROI shape.""" + + ENTER = 'enter' + """Ends the interactive mode when the *Enter* key is pressed.""" + + AUTO_ENTER = 'auto_enter' + """Ends the interactive mode when reaching max ROIs or + when the *Enter* key is pressed. + """ + + NONE = 'none' + """Do not provide the user a way to end the interactive mode. + + The end of :meth:`exec` is done through :meth:`quit` or timeout. + """ + + def getValidationMode(self): + """Returns the interactive mode validation in use. + + :rtype: ValidationMode + """ + return self.__validationMode + + def setValidationMode(self, mode): + """Set the way to perform interactive mode validation. + + See :class:`ValidationMode` enumeration for the supported + validation modes. + + :param ValidationMode mode: The interactive mode validation to use. + """ + assert isinstance(mode, self.ValidationMode) + if mode != self.__validationMode: + self.__validationMode = mode + + if self.isExec(): + if (self.isMaxRois() and self.getValidationMode() in + (self.ValidationMode.AUTO, + self.ValidationMode.AUTO_ENTER)): + self.quit() + + self.__updateMessage() + + def eventFilter(self, obj, event): + if event.type() == qt.QEvent.Hide: + self.quit() + + if event.type() == qt.QEvent.KeyPress: + key = event.key() + if (key in (qt.Qt.Key_Return, qt.Qt.Key_Enter) and + self.getValidationMode() in ( + self.ValidationMode.ENTER, + self.ValidationMode.AUTO_ENTER)): + # Stop on return key pressed + self.quit() + return True # Stop further handling of this keys + + if (key in (qt.Qt.Key_Delete, qt.Qt.Key_Backspace) or ( + key == qt.Qt.Key_Z and + event.modifiers() & qt.Qt.ControlModifier)): + rois = self.getRois() + if rois: # Something to undo + self.removeRoi(rois[-1]) + # Stop further handling of keys if something was undone + return True + + return super(InteractiveRegionOfInterestManager, self).eventFilter(obj, event) + + # Message API + + def getMessage(self): + """Returns the current status message. + + This message is meant to be displayed in a status bar. + + :rtype: str + """ + if self.__timeoutEndTime is None: + return self.__message + else: + remaining = self.__timeoutEndTime - time.time() + return self.__message + (' - %d seconds remaining' % + max(1, int(remaining))) + + # Listen to ROI updates + + def __added(self, *args, **kwargs): + """Handle new ROI added""" + max_ = self.getMaxRois() + if max_ is not None: + # When reaching max number of ROIs, redo last one + while len(self.getRois()) > max_: + self.removeRoi(self.getRois()[-2]) + + self.__updateMessage() + if (self.isMaxRois() and + self.getValidationMode() in (self.ValidationMode.AUTO, + self.ValidationMode.AUTO_ENTER)): + self.quit() + + def __aboutToBeRemoved(self, *args, **kwargs): + """Handle removal of a ROI""" + # RegionOfInterest not removed yet + self.__updateMessage(nbrois=len(self.getRois()) - 1) + + def __started(self, roiKind): + """Handle interactive mode started""" + self.__updateMessage() + + def __finished(self): + """Handle interactive mode finished""" + self.__updateMessage() + + def __updateMessage(self, nbrois=None): + """Update message""" + if not self.isExec(): + message = 'Done' + + elif not self.isStarted(): + message = 'Use %s ROI edition mode' % self.__execClass + + else: + if nbrois is None: + nbrois = len(self.getRois()) + + name = self.__execClass._getShortName() + + max_ = self.getMaxRois() + if max_ is None: + message = 'Select %ss (%d selected)' % (name, nbrois) + + elif max_ <= 1: + message = 'Select a %s' % name + else: + message = 'Select %d/%d %ss' % (nbrois, max_, name) + + if (self.getValidationMode() == self.ValidationMode.ENTER and + self.isMaxRois()): + message += ' - Press Enter to confirm' + + if message != self.__message: + self.__message = message + # Use getMessage to add timeout message + self.sigMessageChanged.emit(self.getMessage()) + + # Handle blocking call + + def __timeoutUpdate(self): + """Handle update of timeout""" + if (self.__timeoutEndTime is not None and + (self.__timeoutEndTime - time.time()) > 0): + self.sigMessageChanged.emit(self.getMessage()) + else: # Stop interactive mode and message timer + timer = self.sender() + if timer is not None: + timer.stop() + self.__timeoutEndTime = None + self.quit() + + def isExec(self): + """Returns True if :meth:`exec` is currently running. + + :rtype: bool""" + return self.__execClass is not None + + def exec(self, roiClass, timeout=0): + """Block until ROI selection is done or timeout is elapsed. + + :meth:`quit` also ends this blocking call. + + :param class roiClass: The class of the ROI which have to be created. + See `silx.gui.plot.items.roi`. + :param int timeout: Maximum duration in seconds to block. + Default: No timeout + :return: The list of ROIs + :rtype: List[RegionOfInterest] + """ + plot = self.parent() + if plot is None: + return + + self.__execClass = roiClass + + plot.installEventFilter(self) + + if timeout > 0: + self.__timeoutEndTime = time.time() + timeout + timer = qt.QTimer(self) + timer.timeout.connect(self.__timeoutUpdate) + timer.start(1000) + + rois = super(InteractiveRegionOfInterestManager, self).exec(roiClass) + + timer.stop() + self.__timeoutEndTime = None + + else: + rois = super(InteractiveRegionOfInterestManager, self).exec(roiClass) + + plot.removeEventFilter(self) + + self.__execClass = None + self.__updateMessage() + + return rois + + def exec_(self, roiClass, timeout=0): # Qt5-like compatibility + return self.exec(roiClass, timeout) + + +class _DeleteRegionOfInterestToolButton(qt.QToolButton): + """Tool button deleting a ROI object + + :param parent: See QWidget + :param RegionOfInterest roi: The ROI to delete + """ + + def __init__(self, parent, roi): + super(_DeleteRegionOfInterestToolButton, self).__init__(parent) + self.setIcon(icons.getQIcon('remove')) + self.setToolTip("Remove this ROI") + self.__roiRef = roi if roi is None else weakref.ref(roi) + self.clicked.connect(self.__clicked) + + def __clicked(self, checked): + """Handle button clicked""" + roi = None if self.__roiRef is None else self.__roiRef() + if roi is not None: + manager = roi.parent() + if manager is not None: + manager.removeRoi(roi) + self.__roiRef = None + + +class RegionOfInterestTableWidget(qt.QTableWidget): + """Widget displaying the ROIs of a :class:`RegionOfInterestManager`""" + + def __init__(self, parent=None): + super(RegionOfInterestTableWidget, self).__init__(parent) + self._roiManagerRef = None + + headers = ['Label', 'Edit', 'Kind', 'Coordinates', ''] + self.setColumnCount(len(headers)) + self.setHorizontalHeaderLabels(headers) + + horizontalHeader = self.horizontalHeader() + horizontalHeader.setDefaultAlignment(qt.Qt.AlignLeft) + + horizontalHeader.setSectionResizeMode(0, qt.QHeaderView.Interactive) + horizontalHeader.setSectionResizeMode(1, qt.QHeaderView.ResizeToContents) + horizontalHeader.setSectionResizeMode(2, qt.QHeaderView.ResizeToContents) + horizontalHeader.setSectionResizeMode(3, qt.QHeaderView.Stretch) + horizontalHeader.setSectionResizeMode(4, qt.QHeaderView.ResizeToContents) + + verticalHeader = self.verticalHeader() + verticalHeader.setVisible(False) + + self.setSelectionMode(qt.QAbstractItemView.NoSelection) + self.setFocusPolicy(qt.Qt.NoFocus) + + self.itemChanged.connect(self.__itemChanged) + + def __itemChanged(self, item): + """Handle item updates""" + column = item.column() + index = item.data(qt.Qt.UserRole) + + if index is not None: + manager = self.getRegionOfInterestManager() + roi = manager.getRois()[index] + else: + return + + if column == 0: + # First collect information from item, then update ROI + # Otherwise, this causes issues issues + checked = item.checkState() == qt.Qt.Checked + text= item.text() + roi.setVisible(checked) + roi.setName(text) + elif column == 1: + roi.setEditable(item.checkState() == qt.Qt.Checked) + elif column in (2, 3, 4): + pass # TODO + else: + logger.error('Unhandled column %d', column) + + def setRegionOfInterestManager(self, manager): + """Set the :class:`RegionOfInterestManager` object to sync with + + :param RegionOfInterestManager manager: + """ + assert manager is None or isinstance(manager, RegionOfInterestManager) + + previousManager = self.getRegionOfInterestManager() + + if previousManager is not None: + previousManager.sigRoiChanged.disconnect(self._sync) + self.setRowCount(0) + + self._roiManagerRef = weakref.ref(manager) + + self._sync() + + if manager is not None: + manager.sigRoiChanged.connect(self._sync) + + def _getReadableRoiDescription(self, roi): + """Returns modelisation of a ROI as a readable sequence of values. + + :rtype: str + """ + text = str(roi) + try: + # Extract the params from syntax "CLASSNAME(PARAMS)" + elements = text.split("(", 1) + if len(elements) != 2: + return text + result = elements[1] + result = result.strip() + if not result.endswith(")"): + return text + result = result[0:-1] + # Capitalize each words + result = result.title() + return result + except Exception: + logger.debug("Backtrace", exc_info=True) + return text + + def _sync(self): + """Update widget content according to ROI manger""" + manager = self.getRegionOfInterestManager() + + if manager is None: + self.setRowCount(0) + return + + rois = manager.getRois() + + self.setRowCount(len(rois)) + for index, roi in enumerate(rois): + baseFlags = qt.Qt.ItemIsSelectable | qt.Qt.ItemIsEnabled + + # Label and visible + label = roi.getName() + item = qt.QTableWidgetItem(label) + item.setFlags(baseFlags | qt.Qt.ItemIsEditable | qt.Qt.ItemIsUserCheckable) + item.setData(qt.Qt.UserRole, index) + item.setCheckState( + qt.Qt.Checked if roi.isVisible() else qt.Qt.Unchecked) + self.setItem(index, 0, item) + + # Editable + item = qt.QTableWidgetItem() + item.setFlags(baseFlags | qt.Qt.ItemIsUserCheckable) + item.setData(qt.Qt.UserRole, index) + item.setCheckState( + qt.Qt.Checked if roi.isEditable() else qt.Qt.Unchecked) + self.setItem(index, 1, item) + item.setTextAlignment(qt.Qt.AlignCenter) + item.setText(None) + + # Kind + label = roi._getShortName() + if label is None: + # Default value if kind is not overrided + label = roi.__class__.__name__ + item = qt.QTableWidgetItem(label.capitalize()) + item.setFlags(baseFlags) + self.setItem(index, 2, item) + + item = qt.QTableWidgetItem() + item.setFlags(baseFlags) + + # Coordinates + text = self._getReadableRoiDescription(roi) + item.setText(text) + self.setItem(index, 3, item) + + # Delete + delBtn = _DeleteRegionOfInterestToolButton(None, roi) + widget = qt.QWidget(self) + layout = qt.QHBoxLayout() + layout.setContentsMargins(2, 2, 2, 2) + layout.setSpacing(0) + widget.setLayout(layout) + layout.addStretch(1) + layout.addWidget(delBtn) + layout.addStretch(1) + self.setCellWidget(index, 4, widget) + + def getRegionOfInterestManager(self): + """Returns the :class:`RegionOfInterestManager` this widget supervise. + + It returns None if not sync with an :class:`RegionOfInterestManager`. + + :rtype: RegionOfInterestManager + """ + return None if self._roiManagerRef is None else self._roiManagerRef() diff --git a/src/silx/gui/plot/tools/test/__init__.py b/src/silx/gui/plot/tools/test/__init__.py new file mode 100644 index 0000000..aa4a601 --- /dev/null +++ b/src/silx/gui/plot/tools/test/__init__.py @@ -0,0 +1,24 @@ +# coding: utf-8 +# /*########################################################################## +# +# Copyright (c) 2018 European Synchrotron Radiation Facility +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. +# +# ###########################################################################*/ diff --git a/src/silx/gui/plot/tools/test/testCurveLegendsWidget.py b/src/silx/gui/plot/tools/test/testCurveLegendsWidget.py new file mode 100644 index 0000000..37af10e --- /dev/null +++ b/src/silx/gui/plot/tools/test/testCurveLegendsWidget.py @@ -0,0 +1,113 @@ +# coding: utf-8 +# /*########################################################################## +# +# Copyright (c) 2018 European Synchrotron Radiation Facility +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. +# +# ###########################################################################*/ +__authors__ = ["T. Vincent"] +__license__ = "MIT" +__date__ = "02/08/2018" + + +import unittest + +from silx.gui import qt +from silx.utils.testutils import ParametricTestCase +from silx.gui.utils.testutils import TestCaseQt +from silx.gui.plot import PlotWindow +from silx.gui.plot.tools import CurveLegendsWidget + + +class TestCurveLegendsWidget(TestCaseQt, ParametricTestCase): + """Tests for CurveLegendsWidget class""" + + def setUp(self): + super(TestCurveLegendsWidget, self).setUp() + self.plot = PlotWindow() + + self.legends = CurveLegendsWidget.CurveLegendsWidget() + self.legends.setPlotWidget(self.plot) + + dock = qt.QDockWidget() + dock.setWindowTitle('Curve Legends') + dock.setWidget(self.legends) + self.plot.addTabbedDockWidget(dock) + + self.plot.show() + self.qWaitForWindowExposed(self.plot) + + def tearDown(self): + del self.legends + self.qapp.processEvents() + self.plot.setAttribute(qt.Qt.WA_DeleteOnClose) + self.plot.close() + del self.plot + super(TestCurveLegendsWidget, self).tearDown() + + def _assertNbLegends(self, count): + """Check the number of legends in the CurveLegendsWidget""" + children = self.legends.findChildren(CurveLegendsWidget._LegendWidget) + self.assertEqual(len(children), count) + + def testAddRemoveCurves(self): + """Test CurveLegendsWidget while adding/removing curves""" + self.plot.addCurve((0, 1), (1, 2), legend='a') + self._assertNbLegends(1) + self.plot.addCurve((0, 1), (2, 3), legend='b') + self._assertNbLegends(2) + + # Detached/attach + self.legends.setPlotWidget(None) + self._assertNbLegends(0) + + self.legends.setPlotWidget(self.plot) + self._assertNbLegends(2) + + self.plot.clear() + self._assertNbLegends(0) + + def testUpdateCurves(self): + """Test CurveLegendsWidget while updating curves """ + self.plot.addCurve((0, 1), (1, 2), legend='a') + self._assertNbLegends(1) + self.plot.addCurve((0, 1), (2, 3), legend='b') + self._assertNbLegends(2) + + # Activate curve + self.plot.setActiveCurve('a') + self.qapp.processEvents() + self.plot.setActiveCurve('b') + self.qapp.processEvents() + + # Change curve style + curve = self.plot.getCurve('a') + curve.setLineWidth(2) + for linestyle in (':', '', '--', '-'): + with self.subTest(linestyle=linestyle): + curve.setLineStyle(linestyle) + self.qapp.processEvents() + self.qWait(1000) + + for symbol in ('o', 'd', '', 's'): + with self.subTest(symbol=symbol): + curve.setSymbol(symbol) + self.qapp.processEvents() + self.qWait(1000) diff --git a/src/silx/gui/plot/tools/test/testProfile.py b/src/silx/gui/plot/tools/test/testProfile.py new file mode 100644 index 0000000..829f49e --- /dev/null +++ b/src/silx/gui/plot/tools/test/testProfile.py @@ -0,0 +1,654 @@ +# coding: utf-8 +# /*########################################################################## +# +# Copyright (c) 2018-2021 European Synchrotron Radiation Facility +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. +# +# ###########################################################################*/ +__authors__ = ["T. Vincent"] +__license__ = "MIT" +__date__ = "28/06/2018" + + +import unittest +import contextlib +import numpy +import logging + +from silx.gui import qt +from silx.utils import deprecation +from silx.utils import testutils + +from silx.gui.utils.testutils import TestCaseQt +from silx.utils.testutils import ParametricTestCase +from silx.gui.plot import PlotWindow, Plot1D, Plot2D, Profile +from silx.gui.plot.StackView import StackView +from silx.gui.plot.tools.profile import rois +from silx.gui.plot.tools.profile import editors +from silx.gui.plot.items import roi as roi_items +from silx.gui.plot.tools.profile import manager +from silx.gui import plot as silx_plot + +_logger = logging.getLogger(__name__) + + +class TestRois(TestCaseQt): + + def test_init(self): + """Check that the constructor is not called twice""" + roi = rois.ProfileImageVerticalLineROI() + if qt.BINDING == "PyQt5": + # the profile ROI + the shape + self.assertEqual(roi.receivers(roi.sigRegionChanged), 2) + + +class TestInteractions(TestCaseQt): + + @contextlib.contextmanager + def defaultPlot(self): + try: + widget = silx_plot.PlotWidget() + widget.show() + self.qWaitForWindowExposed(widget) + yield widget + finally: + widget.close() + widget = None + self.qWait() + + @contextlib.contextmanager + def imagePlot(self): + try: + widget = silx_plot.Plot2D() + image = numpy.arange(10 * 10).reshape(10, -1) + widget.addImage(image) + widget.show() + self.qWaitForWindowExposed(widget) + yield widget + finally: + widget.close() + widget = None + self.qWait() + + @contextlib.contextmanager + def scatterPlot(self): + try: + widget = silx_plot.ScatterView() + + nbX, nbY = 7, 5 + yy = numpy.atleast_2d(numpy.ones(nbY)).T + xx = numpy.atleast_2d(numpy.ones(nbX)) + positionX = numpy.linspace(10, 50, nbX) * yy + positionX = positionX.reshape(nbX * nbY) + positionY = numpy.atleast_2d(numpy.linspace(20, 60, nbY)).T * xx + positionY = positionY.reshape(nbX * nbY) + values = numpy.arange(nbX * nbY) + + widget.setData(positionX, positionY, values) + widget.resetZoom() + widget.show() + self.qWaitForWindowExposed(widget) + yield widget.getPlotWidget() + finally: + widget.close() + widget = None + self.qWait() + + @contextlib.contextmanager + def stackPlot(self): + try: + widget = silx_plot.StackView() + image = numpy.arange(10 * 10).reshape(10, -1) + cube = numpy.array([image, image, image]) + widget.setStack(cube) + widget.resetZoom() + widget.show() + self.qWaitForWindowExposed(widget) + yield widget.getPlotWidget() + finally: + widget.close() + widget = None + self.qWait() + + def waitPendingOperations(self, proflie): + for _ in range(10): + if not proflie.hasPendingOperations(): + return + self.qWait(100) + _logger.error("The profile manager still have pending operations") + + def genericRoiTest(self, plot, roiClass): + profileManager = manager.ProfileManager(plot, plot) + profileManager.setItemType(image=True, scatter=True) + + try: + action = profileManager.createProfileAction(roiClass, plot) + action.triggered[bool].emit(True) + widget = plot.getWidgetHandle() + + # Do the mouse interaction + pos1 = widget.width() * 0.4, widget.height() * 0.4 + self.mouseMove(widget, pos=pos1) + self.mouseClick(widget, qt.Qt.LeftButton, pos=pos1) + + if issubclass(roiClass, roi_items.LineROI): + pos2 = widget.width() * 0.6, widget.height() * 0.6 + self.mouseMove(widget, pos=pos2) + self.mouseClick(widget, qt.Qt.LeftButton, pos=pos2) + + self.waitPendingOperations(profileManager) + + # Test that something was computed + if issubclass(roiClass, rois._ProfileCrossROI): + self.assertEqual(profileManager._computedProfiles, 2) + elif issubclass(roiClass, roi_items.LineROI): + self.assertGreaterEqual(profileManager._computedProfiles, 1) + else: + self.assertEqual(profileManager._computedProfiles, 1) + + # Test the created ROIs + profileRois = profileManager.getRoiManager().getRois() + if issubclass(roiClass, rois._ProfileCrossROI): + self.assertEqual(len(profileRois), 3) + else: + self.assertEqual(len(profileRois), 1) + # The first one should be the expected one + roi = profileRois[0] + + # Test that something was displayed + if issubclass(roiClass, rois._ProfileCrossROI): + profiles = roi._getLines() + window = profiles[0].getProfileWindow() + self.assertIsNotNone(window) + window = profiles[1].getProfileWindow() + self.assertIsNotNone(window) + else: + window = roi.getProfileWindow() + self.assertIsNotNone(window) + finally: + profileManager.clearProfile() + + def testImageActions(self): + roiClasses = [ + rois.ProfileImageHorizontalLineROI, + rois.ProfileImageVerticalLineROI, + rois.ProfileImageLineROI, + rois.ProfileImageCrossROI, + ] + with self.imagePlot() as plot: + for roiClass in roiClasses: + with self.subTest(roiClass=roiClass): + self.genericRoiTest(plot, roiClass) + + def testScatterActions(self): + roiClasses = [ + rois.ProfileScatterHorizontalLineROI, + rois.ProfileScatterVerticalLineROI, + rois.ProfileScatterLineROI, + rois.ProfileScatterCrossROI, + rois.ProfileScatterHorizontalSliceROI, + rois.ProfileScatterVerticalSliceROI, + rois.ProfileScatterCrossSliceROI, + ] + with self.scatterPlot() as plot: + for roiClass in roiClasses: + with self.subTest(roiClass=roiClass): + self.genericRoiTest(plot, roiClass) + + def testStackActions(self): + roiClasses = [ + rois.ProfileImageStackHorizontalLineROI, + rois.ProfileImageStackVerticalLineROI, + rois.ProfileImageStackLineROI, + rois.ProfileImageStackCrossROI, + ] + with self.stackPlot() as plot: + for roiClass in roiClasses: + with self.subTest(roiClass=roiClass): + self.genericRoiTest(plot, roiClass) + + def genericEditorTest(self, plot, roi, editor): + if isinstance(editor, editors._NoProfileRoiEditor): + pass + elif isinstance(editor, editors._DefaultImageStackProfileRoiEditor): + # GUI to ROI + editor._lineWidth.setValue(2) + self.assertEqual(roi.getProfileLineWidth(), 2) + editor._methodsButton.setMethod("sum") + self.assertEqual(roi.getProfileMethod(), "sum") + editor._profileDim.setDimension(1) + self.assertEqual(roi.getProfileType(), "1D") + # ROI to GUI + roi.setProfileLineWidth(3) + self.assertEqual(editor._lineWidth.value(), 3) + roi.setProfileMethod("mean") + self.assertEqual(editor._methodsButton.getMethod(), "mean") + roi.setProfileType("2D") + self.assertEqual(editor._profileDim.getDimension(), 2) + elif isinstance(editor, editors._DefaultImageProfileRoiEditor): + # GUI to ROI + editor._lineWidth.setValue(2) + self.assertEqual(roi.getProfileLineWidth(), 2) + editor._methodsButton.setMethod("sum") + self.assertEqual(roi.getProfileMethod(), "sum") + # ROI to GUI + roi.setProfileLineWidth(3) + self.assertEqual(editor._lineWidth.value(), 3) + roi.setProfileMethod("mean") + self.assertEqual(editor._methodsButton.getMethod(), "mean") + elif isinstance(editor, editors._DefaultScatterProfileRoiEditor): + # GUI to ROI + editor._nPoints.setValue(100) + self.assertEqual(roi.getNPoints(), 100) + # ROI to GUI + roi.setNPoints(200) + self.assertEqual(editor._nPoints.value(), 200) + else: + assert False + + def testEditors(self): + roiClasses = [ + (rois.ProfileImageHorizontalLineROI, editors._DefaultImageProfileRoiEditor), + (rois.ProfileImageVerticalLineROI, editors._DefaultImageProfileRoiEditor), + (rois.ProfileImageLineROI, editors._DefaultImageProfileRoiEditor), + (rois.ProfileImageCrossROI, editors._DefaultImageProfileRoiEditor), + (rois.ProfileScatterHorizontalLineROI, editors._DefaultScatterProfileRoiEditor), + (rois.ProfileScatterVerticalLineROI, editors._DefaultScatterProfileRoiEditor), + (rois.ProfileScatterLineROI, editors._DefaultScatterProfileRoiEditor), + (rois.ProfileScatterCrossROI, editors._DefaultScatterProfileRoiEditor), + (rois.ProfileScatterHorizontalSliceROI, editors._NoProfileRoiEditor), + (rois.ProfileScatterVerticalSliceROI, editors._NoProfileRoiEditor), + (rois.ProfileScatterCrossSliceROI, editors._NoProfileRoiEditor), + (rois.ProfileImageStackHorizontalLineROI, editors._DefaultImageStackProfileRoiEditor), + (rois.ProfileImageStackVerticalLineROI, editors._DefaultImageStackProfileRoiEditor), + (rois.ProfileImageStackLineROI, editors._DefaultImageStackProfileRoiEditor), + (rois.ProfileImageStackCrossROI, editors._DefaultImageStackProfileRoiEditor), + ] + with self.defaultPlot() as plot: + profileManager = manager.ProfileManager(plot, plot) + editorAction = profileManager.createEditorAction(parent=plot) + for roiClass, editorClass in roiClasses: + with self.subTest(roiClass=roiClass): + roi = roiClass() + roi._setProfileManager(profileManager) + try: + # Force widget creation + menu = qt.QMenu(plot) + menu.addAction(editorAction) + widgets = editorAction.createdWidgets() + self.assertGreater(len(widgets), 0) + + editorAction.setProfileRoi(roi) + editorWidget = editorAction._getEditor(widgets[0]) + self.assertIsInstance(editorWidget, editorClass) + self.genericEditorTest(plot, roi, editorWidget) + finally: + editorAction.setProfileRoi(None) + menu.deleteLater() + menu = None + self.qapp.processEvents() + + +class TestProfileToolBar(TestCaseQt, ParametricTestCase): + """Tests for ProfileToolBar widget.""" + + def setUp(self): + super(TestProfileToolBar, self).setUp() + self.plot = PlotWindow() + self.toolBar = Profile.ProfileToolBar(plot=self.plot) + self.plot.addToolBar(self.toolBar) + + self.plot.show() + self.qWaitForWindowExposed(self.plot) + + self.mouseMove(self.plot) # Move to center + self.qapp.processEvents() + deprecation.FORCE = True + + def tearDown(self): + deprecation.FORCE = False + self.qapp.processEvents() + profileManager = self.toolBar.getProfileManager() + profileManager.clearProfile() + profileManager = None + self.plot.setAttribute(qt.Qt.WA_DeleteOnClose) + self.plot.close() + del self.plot + del self.toolBar + + super(TestProfileToolBar, self).tearDown() + + def testAlignedProfile(self): + """Test horizontal and vertical profile, without and with image""" + # Use Plot backend widget to submit mouse events + widget = self.plot.getWidgetHandle() + for method in ('sum', 'mean'): + with self.subTest(method=method): + # 2 positions to use for mouse events + pos1 = widget.width() * 0.4, widget.height() * 0.4 + pos2 = widget.width() * 0.6, widget.height() * 0.6 + + for action in (self.toolBar.hLineAction, self.toolBar.vLineAction): + with self.subTest(mode=action.text()): + # Trigger tool button for mode + action.trigger() + # Without image + self.mouseMove(widget, pos=pos1) + self.mouseClick(widget, qt.Qt.LeftButton, pos=pos1) + + # with image + self.plot.addImage( + numpy.arange(100 * 100).reshape(100, -1)) + self.mousePress(widget, qt.Qt.LeftButton, pos=pos1) + self.mouseMove(widget, pos=pos2) + self.mouseRelease(widget, qt.Qt.LeftButton, pos=pos2) + + self.mouseMove(widget) + self.mouseClick(widget, qt.Qt.LeftButton) + + manager = self.toolBar.getProfileManager() + for _ in range(20): + self.qWait(200) + if not manager.hasPendingOperations(): + break + + @testutils.validate_logging(deprecation.depreclog.name, warning=4) + def testDiagonalProfile(self): + """Test diagonal profile, without and with image""" + # Use Plot backend widget to submit mouse events + widget = self.plot.getWidgetHandle() + + self.plot.addImage( + numpy.arange(100 * 100).reshape(100, -1)) + + for method in ('sum', 'mean'): + with self.subTest(method=method): + # 2 positions to use for mouse events + pos1 = widget.width() * 0.4, widget.height() * 0.4 + pos2 = widget.width() * 0.6, widget.height() * 0.6 + + # Trigger tool button for diagonal profile mode + self.toolBar.lineAction.trigger() + + # draw profile line + widget.setFocus(qt.Qt.OtherFocusReason) + self.mouseMove(widget, pos=pos1) + self.qWait(100) + self.mousePress(widget, qt.Qt.LeftButton, pos=pos1) + self.qWait(100) + self.mouseMove(widget, pos=pos2) + self.qWait(100) + self.mouseRelease(widget, qt.Qt.LeftButton, pos=pos2) + self.qWait(100) + + manager = self.toolBar.getProfileManager() + + for _ in range(20): + self.qWait(200) + if not manager.hasPendingOperations(): + break + + roi = manager.getCurrentRoi() + self.assertIsNotNone(roi) + roi.setProfileLineWidth(3) + roi.setProfileMethod(method) + + for _ in range(20): + self.qWait(200) + if not manager.hasPendingOperations(): + break + + curveItem = self.toolBar.getProfilePlot().getAllCurves()[0] + if method == 'sum': + self.assertTrue(curveItem.getData()[1].max() > 10000) + elif method == 'mean': + self.assertTrue(curveItem.getData()[1].max() < 10000) + + # Remove the ROI so the profile window is also removed + roiManager = manager.getRoiManager() + roiManager.removeRoi(roi) + self.qWait(100) + + +class TestDeprecatedProfileToolBar(TestCaseQt): + """Tests old features of the ProfileToolBar widget.""" + + def setUp(self): + self.plot = None + super(TestDeprecatedProfileToolBar, self).setUp() + + def tearDown(self): + if self.plot is not None: + self.plot.setAttribute(qt.Qt.WA_DeleteOnClose) + self.plot.close() + self.plot = None + self.qWait() + + super(TestDeprecatedProfileToolBar, self).tearDown() + + @testutils.validate_logging(deprecation.depreclog.name, warning=2) + def testCustomProfileWindow(self): + from silx.gui.plot import ProfileMainWindow + + self.plot = PlotWindow() + profileWindow = ProfileMainWindow.ProfileMainWindow(self.plot) + toolBar = Profile.ProfileToolBar(parent=self.plot, + plot=self.plot, + profileWindow=profileWindow) + + self.plot.show() + self.qWaitForWindowExposed(self.plot) + profileWindow.show() + self.qWaitForWindowExposed(profileWindow) + self.qapp.processEvents() + + self.plot.addImage(numpy.arange(10 * 10).reshape(10, -1)) + profile = rois.ProfileImageHorizontalLineROI() + profile.setPosition(5) + toolBar.getProfileManager().getRoiManager().addRoi(profile) + toolBar.getProfileManager().getRoiManager().setCurrentRoi(profile) + + for _ in range(20): + self.qWait(200) + if not toolBar.getProfileManager().hasPendingOperations(): + break + + # There is a displayed profile + self.assertIsNotNone(profileWindow.getProfile()) + self.assertIs(toolBar.getProfileMainWindow(), profileWindow) + + # There is nothing anymore but the window is still there + toolBar.getProfileManager().clearProfile() + self.qapp.processEvents() + self.assertIsNone(profileWindow.getProfile()) + + +class TestProfile3DToolBar(TestCaseQt): + """Tests for Profile3DToolBar widget. + """ + def setUp(self): + super(TestProfile3DToolBar, self).setUp() + self.plot = StackView() + self.plot.show() + self.qWaitForWindowExposed(self.plot) + + self.plot.setStack(numpy.array([ + [[0, 1, 2], [3, 4, 5]], + [[6, 7, 8], [9, 10, 11]], + [[12, 13, 14], [15, 16, 17]] + ])) + deprecation.FORCE = True + + def tearDown(self): + deprecation.FORCE = False + profileManager = self.plot.getProfileToolbar().getProfileManager() + profileManager.clearProfile() + profileManager = None + self.plot.setAttribute(qt.Qt.WA_DeleteOnClose) + self.plot.close() + self.plot = None + + super(TestProfile3DToolBar, self).tearDown() + + @testutils.validate_logging(deprecation.depreclog.name, warning=2) + def testMethodProfile2D(self): + """Test that the profile can have a different method if we want to + compute then in 1D or in 2D""" + + toolBar = self.plot.getProfileToolbar() + + toolBar.vLineAction.trigger() + plot2D = self.plot.getPlotWidget().getWidgetHandle() + pos1 = plot2D.width() * 0.5, plot2D.height() * 0.5 + self.mouseClick(plot2D, qt.Qt.LeftButton, pos=pos1) + + manager = toolBar.getProfileManager() + roi = manager.getCurrentRoi() + roi.setProfileMethod("mean") + roi.setProfileType("2D") + roi.setProfileLineWidth(3) + + for _ in range(20): + self.qWait(200) + if not manager.hasPendingOperations(): + break + + # check 2D 'mean' profile + profilePlot = toolBar.getProfilePlot() + data = profilePlot.getAllImages()[0].getData() + expected = numpy.array([[1, 4], [7, 10], [13, 16]]) + numpy.testing.assert_almost_equal(data, expected) + + @testutils.validate_logging(deprecation.depreclog.name, warning=2) + def testMethodSumLine(self): + """Simple interaction test to make sure the sum is correctly computed + """ + toolBar = self.plot.getProfileToolbar() + + toolBar.lineAction.trigger() + plot2D = self.plot.getPlotWidget().getWidgetHandle() + pos1 = plot2D.width() * 0.5, plot2D.height() * 0.2 + pos2 = plot2D.width() * 0.5, plot2D.height() * 0.8 + + self.mouseMove(plot2D, pos=pos1) + self.mousePress(plot2D, qt.Qt.LeftButton, pos=pos1) + self.mouseMove(plot2D, pos=pos2) + self.mouseRelease(plot2D, qt.Qt.LeftButton, pos=pos2) + + manager = toolBar.getProfileManager() + roi = manager.getCurrentRoi() + roi.setProfileMethod("sum") + roi.setProfileType("2D") + roi.setProfileLineWidth(3) + + for _ in range(20): + self.qWait(200) + if not manager.hasPendingOperations(): + break + + # check 2D 'sum' profile + profilePlot = toolBar.getProfilePlot() + data = profilePlot.getAllImages()[0].getData() + expected = numpy.array([[3, 12], [21, 30], [39, 48]]) + numpy.testing.assert_almost_equal(data, expected) + + +class TestGetProfilePlot(TestCaseQt): + + def setUp(self): + self.plot = None + super(TestGetProfilePlot, self).setUp() + + def tearDown(self): + if self.plot is not None: + manager = self.plot.getProfileToolbar().getProfileManager() + manager.clearProfile() + manager = None + self.plot.setAttribute(qt.Qt.WA_DeleteOnClose) + self.plot.close() + self.plot = None + + super(TestGetProfilePlot, self).tearDown() + + def testProfile1D(self): + self.plot = Plot2D() + self.plot.show() + self.qWaitForWindowExposed(self.plot) + self.plot.addImage([[0, 1], [2, 3]]) + + toolBar = self.plot.getProfileToolbar() + + manager = toolBar.getProfileManager() + roiManager = manager.getRoiManager() + + roi = rois.ProfileImageHorizontalLineROI() + roi.setPosition(0.5) + roiManager.addRoi(roi) + roiManager.setCurrentRoi(roi) + + for _ in range(20): + self.qWait(200) + if not manager.hasPendingOperations(): + break + + profileWindow = roi.getProfileWindow() + self.assertIsInstance(roi.getProfileWindow(), qt.QMainWindow) + self.assertIsInstance(profileWindow.getCurrentPlotWidget(), Plot1D) + + def testProfile2D(self): + """Test that the profile plot associated to a stack view is either a + Plot1D or a plot 2D instance.""" + self.plot = StackView() + self.plot.show() + self.qWaitForWindowExposed(self.plot) + + self.plot.setStack(numpy.array([[[0, 1], [2, 3]], + [[4, 5], [6, 7]]])) + + toolBar = self.plot.getProfileToolbar() + + manager = toolBar.getProfileManager() + roiManager = manager.getRoiManager() + + roi = rois.ProfileImageStackHorizontalLineROI() + roi.setPosition(0.5) + roi.setProfileType("2D") + roiManager.addRoi(roi) + roiManager.setCurrentRoi(roi) + + for _ in range(20): + self.qWait(200) + if not manager.hasPendingOperations(): + break + + profileWindow = roi.getProfileWindow() + self.assertIsInstance(roi.getProfileWindow(), qt.QMainWindow) + self.assertIsInstance(profileWindow.getCurrentPlotWidget(), Plot2D) + + roi.setProfileType("1D") + + for _ in range(20): + self.qWait(200) + if not manager.hasPendingOperations(): + break + + profileWindow = roi.getProfileWindow() + self.assertIsInstance(roi.getProfileWindow(), qt.QMainWindow) + self.assertIsInstance(profileWindow.getCurrentPlotWidget(), Plot1D) diff --git a/src/silx/gui/plot/tools/test/testROI.py b/src/silx/gui/plot/tools/test/testROI.py new file mode 100644 index 0000000..21697d1 --- /dev/null +++ b/src/silx/gui/plot/tools/test/testROI.py @@ -0,0 +1,682 @@ +# coding: utf-8 +# /*########################################################################## +# +# Copyright (c) 2018-2020 European Synchrotron Radiation Facility +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. +# +# ###########################################################################*/ +__authors__ = ["T. Vincent"] +__license__ = "MIT" +__date__ = "28/06/2018" + + +import unittest +import numpy.testing + +from silx.gui import qt +from silx.utils.testutils import ParametricTestCase +from silx.gui.utils.testutils import TestCaseQt, SignalListener +from silx.gui.plot import PlotWindow +import silx.gui.plot.items.roi as roi_items +from silx.gui.plot.tools import roi + + +class TestRoiItems(TestCaseQt): + + def testLine_geometry(self): + item = roi_items.LineROI() + startPoint = numpy.array([1, 2]) + endPoint = numpy.array([3, 4]) + item.setEndPoints(startPoint, endPoint) + numpy.testing.assert_allclose(item.getEndPoints()[0], startPoint) + numpy.testing.assert_allclose(item.getEndPoints()[1], endPoint) + + def testHLine_geometry(self): + item = roi_items.HorizontalLineROI() + item.setPosition(15) + self.assertEqual(item.getPosition(), 15) + + def testVLine_geometry(self): + item = roi_items.VerticalLineROI() + item.setPosition(15) + self.assertEqual(item.getPosition(), 15) + + def testPoint_geometry(self): + point = numpy.array([1, 2]) + item = roi_items.PointROI() + item.setPosition(point) + numpy.testing.assert_allclose(item.getPosition(), point) + + def testRectangle_originGeometry(self): + origin = numpy.array([0, 0]) + size = numpy.array([10, 20]) + center = numpy.array([5, 10]) + item = roi_items.RectangleROI() + item.setGeometry(origin=origin, size=size) + numpy.testing.assert_allclose(item.getOrigin(), origin) + numpy.testing.assert_allclose(item.getSize(), size) + numpy.testing.assert_allclose(item.getCenter(), center) + + def testRectangle_centerGeometry(self): + origin = numpy.array([0, 0]) + size = numpy.array([10, 20]) + center = numpy.array([5, 10]) + item = roi_items.RectangleROI() + item.setGeometry(center=center, size=size) + numpy.testing.assert_allclose(item.getOrigin(), origin) + numpy.testing.assert_allclose(item.getSize(), size) + numpy.testing.assert_allclose(item.getCenter(), center) + + def testRectangle_setCenterGeometry(self): + origin = numpy.array([0, 0]) + size = numpy.array([10, 20]) + item = roi_items.RectangleROI() + item.setGeometry(origin=origin, size=size) + newCenter = numpy.array([0, 0]) + item.setCenter(newCenter) + expectedOrigin = numpy.array([-5, -10]) + numpy.testing.assert_allclose(item.getOrigin(), expectedOrigin) + numpy.testing.assert_allclose(item.getCenter(), newCenter) + numpy.testing.assert_allclose(item.getSize(), size) + + def testRectangle_setOriginGeometry(self): + origin = numpy.array([0, 0]) + size = numpy.array([10, 20]) + item = roi_items.RectangleROI() + item.setGeometry(origin=origin, size=size) + newOrigin = numpy.array([10, 10]) + item.setOrigin(newOrigin) + expectedCenter = numpy.array([15, 20]) + numpy.testing.assert_allclose(item.getOrigin(), newOrigin) + numpy.testing.assert_allclose(item.getCenter(), expectedCenter) + numpy.testing.assert_allclose(item.getSize(), size) + + def testCircle_geometry(self): + center = numpy.array([0, 0]) + radius = 10. + item = roi_items.CircleROI() + item.setGeometry(center=center, radius=radius) + numpy.testing.assert_allclose(item.getCenter(), center) + numpy.testing.assert_allclose(item.getRadius(), radius) + + def testCircle_setCenter(self): + center = numpy.array([0, 0]) + radius = 10. + item = roi_items.CircleROI() + item.setGeometry(center=center, radius=radius) + newCenter = numpy.array([-10, 0]) + item.setCenter(newCenter) + numpy.testing.assert_allclose(item.getCenter(), newCenter) + numpy.testing.assert_allclose(item.getRadius(), radius) + + def testCircle_setRadius(self): + center = numpy.array([0, 0]) + radius = 10. + item = roi_items.CircleROI() + item.setGeometry(center=center, radius=radius) + newRadius = 5.1 + item.setRadius(newRadius) + numpy.testing.assert_allclose(item.getCenter(), center) + numpy.testing.assert_allclose(item.getRadius(), newRadius) + + def testCircle_contains(self): + center = numpy.array([2, -1]) + radius = 1. + item = roi_items.CircleROI() + item.setGeometry(center=center, radius=radius) + self.assertTrue(item.contains([1, -1])) + self.assertFalse(item.contains([0, 0])) + self.assertTrue(item.contains([2, 0])) + self.assertFalse(item.contains([3.01, -1])) + + def testEllipse_contains(self): + center = numpy.array([-2, 0]) + item = roi_items.EllipseROI() + item.setCenter(center) + item.setOrientation(numpy.pi / 4.0) + item.setMajorRadius(2) + item.setMinorRadius(1) + print(item.getMinorRadius(), item.getMajorRadius()) + self.assertFalse(item.contains([0, 0])) + self.assertTrue(item.contains([-1, 1])) + self.assertTrue(item.contains([-3, 0])) + self.assertTrue(item.contains([-2, 0])) + self.assertTrue(item.contains([-2, 1])) + self.assertFalse(item.contains([-4, 1])) + + def testRectangle_isIn(self): + origin = numpy.array([0, 0]) + size = numpy.array([10, 20]) + item = roi_items.RectangleROI() + item.setGeometry(origin=origin, size=size) + self.assertTrue(item.contains(position=(0, 0))) + self.assertTrue(item.contains(position=(2, 14))) + self.assertFalse(item.contains(position=(14, 12))) + + def testPolygon_emptyGeometry(self): + points = numpy.empty((0, 2)) + item = roi_items.PolygonROI() + item.setPoints(points) + numpy.testing.assert_allclose(item.getPoints(), points) + + def testPolygon_geometry(self): + points = numpy.array([[10, 10], [12, 10], [50, 1]]) + item = roi_items.PolygonROI() + item.setPoints(points) + numpy.testing.assert_allclose(item.getPoints(), points) + + def testPolygon_isIn(self): + points = numpy.array([[0, 0], [0, 10], [5, 10]]) + item = roi_items.PolygonROI() + item.setPoints(points) + self.assertTrue(item.contains((0, 0))) + self.assertFalse(item.contains((6, 2))) + self.assertFalse(item.contains((-2, 5))) + self.assertFalse(item.contains((2, -1))) + self.assertFalse(item.contains((8, 1))) + self.assertTrue(item.contains((1, 8))) + + def testArc_getToSetGeometry(self): + """Test that we can use getGeometry as input to setGeometry""" + item = roi_items.ArcROI() + item.setFirstShapePoints(numpy.array([[5, 10], [50, 100]])) + item.setGeometry(*item.getGeometry()) + + def testArc_degenerated_point(self): + item = roi_items.ArcROI() + center = numpy.array([10, 20]) + innerRadius, outerRadius, startAngle, endAngle = 0, 0, 0, 0 + item.setGeometry(center, innerRadius, outerRadius, startAngle, endAngle) + + def testArc_degenerated_line(self): + item = roi_items.ArcROI() + center = numpy.array([10, 20]) + innerRadius, outerRadius, startAngle, endAngle = 0, 100, numpy.pi, numpy.pi + item.setGeometry(center, innerRadius, outerRadius, startAngle, endAngle) + + def testArc_special_circle(self): + item = roi_items.ArcROI() + center = numpy.array([10, 20]) + innerRadius, outerRadius, startAngle, endAngle = 0, 100, numpy.pi, 3 * numpy.pi + item.setGeometry(center, innerRadius, outerRadius, startAngle, endAngle) + numpy.testing.assert_allclose(item.getCenter(), center) + self.assertAlmostEqual(item.getInnerRadius(), innerRadius) + self.assertAlmostEqual(item.getOuterRadius(), outerRadius) + self.assertAlmostEqual(item.getStartAngle(), item.getEndAngle() - numpy.pi * 2.0) + self.assertTrue(item.isClosed()) + + def testArc_special_donut(self): + item = roi_items.ArcROI() + center = numpy.array([10, 20]) + innerRadius, outerRadius, startAngle, endAngle = 1, 100, numpy.pi, 3 * numpy.pi + item.setGeometry(center, innerRadius, outerRadius, startAngle, endAngle) + numpy.testing.assert_allclose(item.getCenter(), center) + self.assertAlmostEqual(item.getInnerRadius(), innerRadius) + self.assertAlmostEqual(item.getOuterRadius(), outerRadius) + self.assertAlmostEqual(item.getStartAngle(), item.getEndAngle() - numpy.pi * 2.0) + self.assertTrue(item.isClosed()) + + def testArc_clockwiseGeometry(self): + """Test that we can use getGeometry as input to setGeometry""" + item = roi_items.ArcROI() + center = numpy.array([10, 20]) + innerRadius, outerRadius, startAngle, endAngle = 1, 100, numpy.pi * 0.5, numpy.pi + item.setGeometry(center, innerRadius, outerRadius, startAngle, endAngle) + numpy.testing.assert_allclose(item.getCenter(), center) + self.assertAlmostEqual(item.getInnerRadius(), innerRadius) + self.assertAlmostEqual(item.getOuterRadius(), outerRadius) + self.assertAlmostEqual(item.getStartAngle(), startAngle) + self.assertAlmostEqual(item.getEndAngle(), endAngle) + self.assertAlmostEqual(item.isClosed(), False) + + def testArc_anticlockwiseGeometry(self): + """Test that we can use getGeometry as input to setGeometry""" + item = roi_items.ArcROI() + center = numpy.array([10, 20]) + innerRadius, outerRadius, startAngle, endAngle = 1, 100, numpy.pi * 0.5, -numpy.pi * 0.5 + item.setGeometry(center, innerRadius, outerRadius, startAngle, endAngle) + numpy.testing.assert_allclose(item.getCenter(), center) + self.assertAlmostEqual(item.getInnerRadius(), innerRadius) + self.assertAlmostEqual(item.getOuterRadius(), outerRadius) + self.assertAlmostEqual(item.getStartAngle(), startAngle) + self.assertAlmostEqual(item.getEndAngle(), endAngle) + self.assertAlmostEqual(item.isClosed(), False) + + def testHRange_geometry(self): + item = roi_items.HorizontalRangeROI() + vmin = 1 + vmax = 3 + item.setRange(vmin, vmax) + self.assertAlmostEqual(item.getMin(), vmin) + self.assertAlmostEqual(item.getMax(), vmax) + self.assertAlmostEqual(item.getCenter(), 2) + + +class TestRegionOfInterestManager(TestCaseQt, ParametricTestCase): + """Tests for RegionOfInterestManager class""" + + def setUp(self): + super(TestRegionOfInterestManager, self).setUp() + self.plot = PlotWindow() + + self.roiTableWidget = roi.RegionOfInterestTableWidget() + dock = qt.QDockWidget() + dock.setWidget(self.roiTableWidget) + self.plot.addDockWidget(qt.Qt.BottomDockWidgetArea, dock) + + self.plot.show() + self.qWaitForWindowExposed(self.plot) + + def tearDown(self): + del self.roiTableWidget + self.qapp.processEvents() + self.plot.setAttribute(qt.Qt.WA_DeleteOnClose) + self.plot.close() + del self.plot + super(TestRegionOfInterestManager, self).tearDown() + + def test(self): + """Test ROI of different shapes""" + tests = ( # shape, points=[list of (x, y), list of (x, y)] + (roi_items.PointROI, numpy.array(([(10., 15.)], [(20., 25.)]))), + (roi_items.RectangleROI, + numpy.array((((1., 10.), (11., 20.)), + ((2., 3.), (12., 13.))))), + (roi_items.PolygonROI, + numpy.array((((0., 1.), (0., 10.), (10., 0.)), + ((5., 6.), (5., 16.), (15., 6.))))), + (roi_items.LineROI, + numpy.array((((10., 20.), (10., 30.)), + ((30., 40.), (30., 50.))))), + (roi_items.HorizontalLineROI, + numpy.array((((10., 20.), (10., 30.)), + ((30., 40.), (30., 50.))))), + (roi_items.VerticalLineROI, + numpy.array((((10., 20.), (10., 30.)), + ((30., 40.), (30., 50.))))), + (roi_items.HorizontalLineROI, + numpy.array((((10., 20.), (10., 30.)), + ((30., 40.), (30., 50.))))), + ) + + for roiClass, points in tests: + with self.subTest(roiClass=roiClass): + manager = roi.RegionOfInterestManager(self.plot) + self.roiTableWidget.setRegionOfInterestManager(manager) + manager.start(roiClass) + + self.assertEqual(manager.getRois(), ()) + + finishListener = SignalListener() + manager.sigInteractiveModeFinished.connect(finishListener) + + changedListener = SignalListener() + manager.sigRoiChanged.connect(changedListener) + + # Add a point + r = roiClass() + r.setFirstShapePoints(points[0]) + manager.addRoi(r) + self.qapp.processEvents() + self.assertTrue(len(manager.getRois()), 1) + self.assertEqual(changedListener.callCount(), 1) + + # Remove it + manager.removeRoi(manager.getRois()[0]) + self.assertEqual(manager.getRois(), ()) + self.assertEqual(changedListener.callCount(), 2) + + # Add two point + r = roiClass() + r.setFirstShapePoints(points[0]) + manager.addRoi(r) + self.qapp.processEvents() + r = roiClass() + r.setFirstShapePoints(points[1]) + manager.addRoi(r) + self.qapp.processEvents() + self.assertTrue(len(manager.getRois()), 2) + self.assertEqual(changedListener.callCount(), 4) + + # Reset it + result = manager.clear() + self.assertTrue(result) + self.assertEqual(manager.getRois(), ()) + self.assertEqual(changedListener.callCount(), 5) + + changedListener.clear() + + # Add two point + r = roiClass() + r.setFirstShapePoints(points[0]) + manager.addRoi(r) + self.qapp.processEvents() + r = roiClass() + r.setFirstShapePoints(points[1]) + manager.addRoi(r) + self.qapp.processEvents() + self.assertTrue(len(manager.getRois()), 2) + self.assertEqual(changedListener.callCount(), 2) + + # stop + result = manager.stop() + self.assertTrue(result) + self.assertTrue(len(manager.getRois()), 1) + self.qapp.processEvents() + self.assertEqual(finishListener.callCount(), 1) + + manager.clear() + + def testRoiDisplay(self): + rois = [] + + # Line + item = roi_items.LineROI() + startPoint = numpy.array([1, 2]) + endPoint = numpy.array([3, 4]) + item.setEndPoints(startPoint, endPoint) + rois.append(item) + # Horizontal line + item = roi_items.HorizontalLineROI() + item.setPosition(15) + rois.append(item) + # Vertical line + item = roi_items.VerticalLineROI() + item.setPosition(15) + rois.append(item) + # Point + item = roi_items.PointROI() + point = numpy.array([1, 2]) + item.setPosition(point) + rois.append(item) + # Rectangle + item = roi_items.RectangleROI() + origin = numpy.array([0, 0]) + size = numpy.array([10, 20]) + item.setGeometry(origin=origin, size=size) + rois.append(item) + # Polygon + item = roi_items.PolygonROI() + points = numpy.array([[10, 10], [12, 10], [50, 1]]) + item.setPoints(points) + rois.append(item) + # Degenerated polygon: No points + item = roi_items.PolygonROI() + points = numpy.empty((0, 2)) + item.setPoints(points) + rois.append(item) + # Degenerated polygon: A single point + item = roi_items.PolygonROI() + points = numpy.array([[5, 10]]) + item.setPoints(points) + rois.append(item) + # Degenerated arc: it's a point + item = roi_items.ArcROI() + center = numpy.array([10, 20]) + innerRadius, outerRadius, startAngle, endAngle = 0, 0, 0, 0 + item.setGeometry(center, innerRadius, outerRadius, startAngle, endAngle) + rois.append(item) + # Degenerated arc: it's a line + item = roi_items.ArcROI() + center = numpy.array([10, 20]) + innerRadius, outerRadius, startAngle, endAngle = 0, 100, numpy.pi, numpy.pi + item.setGeometry(center, innerRadius, outerRadius, startAngle, endAngle) + rois.append(item) + # Special arc: it's a donut + item = roi_items.ArcROI() + center = numpy.array([10, 20]) + innerRadius, outerRadius, startAngle, endAngle = 1, 100, numpy.pi, 3 * numpy.pi + item.setGeometry(center, innerRadius, outerRadius, startAngle, endAngle) + rois.append(item) + # Arc + item = roi_items.ArcROI() + center = numpy.array([10, 20]) + innerRadius, outerRadius, startAngle, endAngle = 1, 100, numpy.pi * 0.5, numpy.pi + item.setGeometry(center, innerRadius, outerRadius, startAngle, endAngle) + rois.append(item) + # Horizontal Range + item = roi_items.HorizontalRangeROI() + item.setRange(-1, 3) + rois.append(item) + + manager = roi.RegionOfInterestManager(self.plot) + self.roiTableWidget.setRegionOfInterestManager(manager) + for item in rois: + with self.subTest(roi=str(item)): + manager.addRoi(item) + self.qapp.processEvents() + item.setEditable(True) + self.qapp.processEvents() + item.setEditable(False) + self.qapp.processEvents() + manager.removeRoi(item) + self.qapp.processEvents() + + def testSelectionProxy(self): + item1 = roi_items.PointROI() + item1.setSelectable(True) + item2 = roi_items.PointROI() + item2.setSelectable(True) + item1.setFocusProxy(item2) + manager = roi.RegionOfInterestManager(self.plot) + manager.setCurrentRoi(item1) + self.assertIs(manager.getCurrentRoi(), item2) + + def testRemovedSelection(self): + item1 = roi_items.PointROI() + item1.setSelectable(True) + manager = roi.RegionOfInterestManager(self.plot) + manager.addRoi(item1) + manager.setCurrentRoi(item1) + manager.removeRoi(item1) + self.assertIs(manager.getCurrentRoi(), None) + + def testMaxROI(self): + """Test Max ROI""" + origin1 = numpy.array([1., 10.]) + size1 = numpy.array([10., 10.]) + origin2 = numpy.array([2., 3.]) + size2 = numpy.array([10., 10.]) + + manager = roi.InteractiveRegionOfInterestManager(self.plot) + self.roiTableWidget.setRegionOfInterestManager(manager) + self.assertEqual(manager.getRois(), ()) + + changedListener = SignalListener() + manager.sigRoiChanged.connect(changedListener) + + # Add two point + item = roi_items.RectangleROI() + item.setGeometry(origin=origin1, size=size1) + manager.addRoi(item) + item = roi_items.RectangleROI() + item.setGeometry(origin=origin2, size=size2) + manager.addRoi(item) + self.qapp.processEvents() + self.assertEqual(changedListener.callCount(), 2) + self.assertEqual(len(manager.getRois()), 2) + + # Try to set max ROI to 1 while there is 2 ROIs + with self.assertRaises(ValueError): + manager.setMaxRois(1) + + manager.clear() + self.assertEqual(len(manager.getRois()), 0) + self.assertEqual(changedListener.callCount(), 3) + + # Set max limit to 1 + manager.setMaxRois(1) + + # Add a point + item = roi_items.RectangleROI() + item.setGeometry(origin=origin1, size=size1) + manager.addRoi(item) + self.qapp.processEvents() + self.assertEqual(changedListener.callCount(), 4) + + # Add a 2nd point while max ROI is 1 + item = roi_items.RectangleROI() + item.setGeometry(origin=origin1, size=size1) + manager.addRoi(item) + self.qapp.processEvents() + self.assertEqual(changedListener.callCount(), 6) + self.assertEqual(len(manager.getRois()), 1) + + def testChangeInteractionMode(self): + """Test change of interaction mode""" + manager = roi.RegionOfInterestManager(self.plot) + self.roiTableWidget.setRegionOfInterestManager(manager) + manager.start(roi_items.PointROI) + + interactiveModeToolBar = self.plot.getInteractiveModeToolBar() + panAction = interactiveModeToolBar.getPanModeAction() + + for roiClass in manager.getSupportedRoiClasses(): + with self.subTest(roiClass=roiClass): + # Change to pan mode + panAction.trigger() + + # Change to interactive ROI mode + action = manager.getInteractionModeAction(roiClass) + action.trigger() + + self.assertEqual(roiClass, manager.getCurrentInteractionModeRoiClass()) + + manager.clear() + + def testLineInteraction(self): + """This test make sure that a ROI based on handles can be edited with + the mouse.""" + xlimit = self.plot.getXAxis().getLimits() + ylimit = self.plot.getYAxis().getLimits() + points = numpy.array([xlimit, ylimit]).T + center = numpy.mean(points, axis=0) + + # Create the line + manager = roi.RegionOfInterestManager(self.plot) + item = roi_items.LineROI() + item.setEndPoints(points[0], points[1]) + item.setEditable(True) + manager.addRoi(item) + self.qapp.processEvents() + + # Drag the center + widget = self.plot.getWidgetHandle() + mx, my = self.plot.dataToPixel(*center) + self.mouseMove(widget, pos=(mx, my)) + self.mousePress(widget, qt.Qt.LeftButton, pos=(mx, my)) + self.mouseMove(widget, pos=(mx, my+25)) + self.mouseMove(widget, pos=(mx, my+50)) + self.mouseRelease(widget, qt.Qt.LeftButton, pos=(mx, my+50)) + + result = numpy.array(item.getEndPoints()) + # x location is still the same + numpy.testing.assert_allclose(points[:, 0], result[:, 0], atol=0.5) + # size is still the same + numpy.testing.assert_allclose(points[1] - points[0], + result[1] - result[0], atol=0.5) + # But Y is not the same + self.assertNotEqual(points[0, 1], result[0, 1]) + self.assertNotEqual(points[1, 1], result[1, 1]) + item = None + manager.clear() + self.qapp.processEvents() + + def testPlotWhenCleared(self): + """PlotWidget.clear should clean up the available ROIs""" + manager = roi.RegionOfInterestManager(self.plot) + item = roi_items.LineROI() + item.setEndPoints((0, 0), (1, 1)) + item.setEditable(True) + manager.addRoi(item) + self.qWait() + try: + # Make sure the test setup is fine + self.assertNotEqual(len(manager.getRois()), 0) + self.assertNotEqual(len(self.plot.getItems()), 0) + + # Call clear and test the expected state + self.plot.clear() + self.assertEqual(len(manager.getRois()), 0) + self.assertEqual(len(self.plot.getItems()), 0) + finally: + # Clean up + manager.clear() + + def testPlotWhenRoiRemoved(self): + """Make sure there is no remaining items in the plot when a ROI is removed""" + manager = roi.RegionOfInterestManager(self.plot) + item = roi_items.LineROI() + item.setEndPoints((0, 0), (1, 1)) + item.setEditable(True) + manager.addRoi(item) + self.qWait() + try: + # Make sure the test setup is fine + self.assertNotEqual(len(manager.getRois()), 0) + self.assertNotEqual(len(self.plot.getItems()), 0) + + # Call clear and test the expected state + manager.removeRoi(item) + self.assertEqual(len(manager.getRois()), 0) + self.assertEqual(len(self.plot.getItems()), 0) + finally: + # Clean up + manager.clear() + + def testArcRoiSwitchMode(self): + """Make sure we can switch mode by clicking on the ROI""" + xlimit = self.plot.getXAxis().getLimits() + ylimit = self.plot.getYAxis().getLimits() + points = numpy.array([xlimit, ylimit]).T + center = numpy.mean(points, axis=0) + size = numpy.abs(points[1] - points[0]) + + # Create the line + manager = roi.RegionOfInterestManager(self.plot) + item = roi_items.ArcROI() + item.setGeometry(center, size[1] / 10, size[1] / 2, 0, 3) + item.setEditable(True) + item.setSelectable(True) + manager.addRoi(item) + self.qapp.processEvents() + + # Initial state + self.assertIs(item.getInteractionMode(), roi_items.ArcROI.ThreePointMode) + self.qWait(500) + + # Click on the center + widget = self.plot.getWidgetHandle() + mx, my = self.plot.dataToPixel(*center) + + # Select the ROI + self.mouseMove(widget, pos=(mx, my)) + self.mouseClick(widget, qt.Qt.LeftButton, pos=(mx, my)) + self.qWait(500) + self.assertIs(item.getInteractionMode(), roi_items.ArcROI.ThreePointMode) + + # Change the mode + self.mouseMove(widget, pos=(mx, my)) + self.mouseClick(widget, qt.Qt.LeftButton, pos=(mx, my)) + self.qWait(500) + self.assertIs(item.getInteractionMode(), roi_items.ArcROI.PolarMode) + + manager.clear() + self.qapp.processEvents() diff --git a/src/silx/gui/plot/tools/test/testScatterProfileToolBar.py b/src/silx/gui/plot/tools/test/testScatterProfileToolBar.py new file mode 100644 index 0000000..582a276 --- /dev/null +++ b/src/silx/gui/plot/tools/test/testScatterProfileToolBar.py @@ -0,0 +1,184 @@ +# coding: utf-8 +# /*########################################################################## +# +# Copyright (c) 2018 European Synchrotron Radiation Facility +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. +# +# ###########################################################################*/ +__authors__ = ["T. Vincent"] +__license__ = "MIT" +__date__ = "28/06/2018" + + +import unittest +import numpy + +from silx.gui import qt +from silx.utils.testutils import ParametricTestCase +from silx.gui.utils.testutils import TestCaseQt +from silx.gui.plot import PlotWindow +from silx.gui.plot.tools.profile import manager +from silx.gui.plot.tools.profile import core +from silx.gui.plot.tools.profile import rois + + +class TestScatterProfileToolBar(TestCaseQt, ParametricTestCase): + """Tests for ScatterProfileToolBar class""" + + def setUp(self): + super(TestScatterProfileToolBar, self).setUp() + self.plot = PlotWindow() + + self.manager = manager.ProfileManager(plot=self.plot) + self.manager.setItemType(scatter=True) + self.manager.setActiveItemTracking(True) + + self.plot.show() + self.qWaitForWindowExposed(self.plot) + + def tearDown(self): + del self.manager + self.qapp.processEvents() + self.plot.setAttribute(qt.Qt.WA_DeleteOnClose) + self.plot.close() + del self.plot + super(TestScatterProfileToolBar, self).tearDown() + + def testHorizontalProfile(self): + """Test ScatterProfileToolBar horizontal profile""" + + roiManager = self.manager.getRoiManager() + + # Add a scatter plot + self.plot.addScatter( + x=(0., 1., 1., 0.), y=(0., 0., 1., 1.), value=(0., 1., 2., 3.)) + self.plot.resetZoom(dataMargins=(.1, .1, .1, .1)) + self.qapp.processEvents() + + # Set a ROI profile + roi = rois.ProfileScatterHorizontalLineROI() + roi.setPosition(0.5) + roi.setNPoints(8) + roiManager.addRoi(roi) + + # Wait for async interpolator init + for _ in range(20): + self.qWait(200) + if not self.manager.hasPendingOperations(): + break + self.qapp.processEvents() + + window = roi.getProfileWindow() + self.assertIsNotNone(window) + data = window.getProfile() + self.assertIsInstance(data, core.CurveProfileData) + self.assertEqual(len(data.coords), 8) + + # Check that profile has same limits than Plot + xLimits = self.plot.getXAxis().getLimits() + self.assertEqual(data.coords[0], xLimits[0]) + self.assertEqual(data.coords[-1], xLimits[1]) + + # Clear the profile + self.manager.clearProfile() + self.qapp.processEvents() + self.assertIsNone(roi.getProfileWindow()) + + def testVerticalProfile(self): + """Test ScatterProfileToolBar vertical profile""" + + roiManager = self.manager.getRoiManager() + + # Add a scatter plot + self.plot.addScatter( + x=(0., 1., 1., 0.), y=(0., 0., 1., 1.), value=(0., 1., 2., 3.)) + self.plot.resetZoom(dataMargins=(.1, .1, .1, .1)) + self.qapp.processEvents() + + # Set a ROI profile + roi = rois.ProfileScatterVerticalLineROI() + roi.setPosition(0.5) + roi.setNPoints(8) + roiManager.addRoi(roi) + + # Wait for async interpolator init + for _ in range(10): + self.qWait(200) + if not self.manager.hasPendingOperations(): + break + + window = roi.getProfileWindow() + self.assertIsNotNone(window) + data = window.getProfile() + self.assertIsInstance(data, core.CurveProfileData) + self.assertEqual(len(data.coords), 8) + + # Check that profile has same limits than Plot + yLimits = self.plot.getYAxis().getLimits() + self.assertEqual(data.coords[0], yLimits[0]) + self.assertEqual(data.coords[-1], yLimits[1]) + + # Check that profile limits are updated when changing limits + self.plot.getYAxis().setLimits(yLimits[0] + 1, yLimits[1] + 10) + + # Wait for async interpolator init + for _ in range(10): + self.qWait(200) + if not self.manager.hasPendingOperations(): + break + + yLimits = self.plot.getYAxis().getLimits() + data = window.getProfile() + self.assertEqual(data.coords[0], yLimits[0]) + self.assertEqual(data.coords[-1], yLimits[1]) + + # Clear the profile + self.manager.clearProfile() + self.qapp.processEvents() + self.assertIsNone(roi.getProfileWindow()) + + def testLineProfile(self): + """Test ScatterProfileToolBar line profile""" + + roiManager = self.manager.getRoiManager() + + # Add a scatter plot + self.plot.addScatter( + x=(0., 1., 1., 0.), y=(0., 0., 1., 1.), value=(0., 1., 2., 3.)) + self.plot.resetZoom(dataMargins=(.1, .1, .1, .1)) + self.qapp.processEvents() + + # Set a ROI profile + roi = rois.ProfileScatterLineROI() + roi.setEndPoints(numpy.array([0., 0.]), numpy.array([1., 1.])) + roi.setNPoints(8) + roiManager.addRoi(roi) + + # Wait for async interpolator init + for _ in range(10): + self.qWait(200) + if not self.manager.hasPendingOperations(): + break + + window = roi.getProfileWindow() + self.assertIsNotNone(window) + data = window.getProfile() + self.assertIsInstance(data, core.CurveProfileData) + self.assertEqual(len(data.coords), 8) diff --git a/src/silx/gui/plot/tools/test/testTools.py b/src/silx/gui/plot/tools/test/testTools.py new file mode 100644 index 0000000..846f641 --- /dev/null +++ b/src/silx/gui/plot/tools/test/testTools.py @@ -0,0 +1,135 @@ +# coding: utf-8 +# /*########################################################################## +# +# Copyright (c) 2016-2021 European Synchrotron Radiation Facility +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. +# +# ###########################################################################*/ +"""Basic tests for silx.gui.plot.tools package""" + +__authors__ = ["T. Vincent"] +__license__ = "MIT" +__date__ = "02/03/2018" + + +import functools +import unittest +import numpy + +from silx.utils.testutils import LoggingValidator +from silx.gui.utils.testutils import qWaitForWindowExposedAndActivate +from silx.gui import qt +from silx.gui.plot import PlotWindow +from silx.gui.plot import tools +from silx.gui.plot.test.utils import PlotWidgetTestCase + + +class TestPositionInfo(PlotWidgetTestCase): + """Tests for PositionInfo widget.""" + + def _createPlot(self): + return PlotWindow() + + def setUp(self): + super(TestPositionInfo, self).setUp() + self.mouseMove(self.plot, pos=(0, 0)) + self.qapp.processEvents() + self.qWait(100) + + def tearDown(self): + super(TestPositionInfo, self).tearDown() + + def _test(self, positionWidget, converterNames, **kwargs): + """General test of PositionInfo. + + - Add it to a toolbar and + - Move mouse around the center of the PlotWindow. + """ + toolBar = qt.QToolBar() + self.plot.addToolBar(qt.Qt.BottomToolBarArea, toolBar) + + toolBar.addWidget(positionWidget) + + converters = positionWidget.getConverters() + self.assertEqual(len(converters), len(converterNames)) + for index, name in enumerate(converterNames): + self.assertEqual(converters[index][0], name) + + self.qapp.processEvents() + with LoggingValidator(tools.__name__, **kwargs): + # Move mouse to center + center = self.plot.size() / 2 + self.mouseMove(self.plot, pos=(center.width(), center.height())) + # Move out + self.mouseMove(self.plot, pos=(1, 1)) + + def testDefaultConverters(self): + """Test PositionInfo with default converters""" + positionWidget = tools.PositionInfo(plot=self.plot) + self._test(positionWidget, ('X', 'Y')) + + def testCustomConverters(self): + """Test PositionInfo with custom converters""" + converters = [ + ('Coords', lambda x, y: (int(x), int(y))), + ('Radius', lambda x, y: numpy.sqrt(x * x + y * y)), + ('Angle', lambda x, y: numpy.degrees(numpy.arctan2(y, x))) + ] + positionWidget = tools.PositionInfo(plot=self.plot, + converters=converters) + self._test(positionWidget, ('Coords', 'Radius', 'Angle')) + + def testFailingConverters(self): + """Test PositionInfo with failing custom converters""" + def raiseException(x, y): + raise RuntimeError() + + positionWidget = tools.PositionInfo( + plot=self.plot, + converters=[('Exception', raiseException)]) + self._test(positionWidget, ['Exception'], error=2) + + def testUpdate(self): + """Test :meth:`PositionInfo.updateInfo`""" + calls = [] + + def update(calls, x, y): # Get number of calls + calls.append((x, y)) + return len(calls) + + positionWidget = tools.PositionInfo( + plot=self.plot, + converters=[('Call count', functools.partial(update, calls))]) + + positionWidget.updateInfo() + self.assertEqual(len(calls), 1) + + +class TestPlotToolsToolbars(PlotWidgetTestCase): + """Tests toolbars from silx.gui.plot.tools""" + + def test(self): + """"Add all toolbars""" + for tbClass in (tools.InteractiveModeToolBar, + tools.ImageToolBar, + tools.CurveToolBar, + tools.OutputToolBar): + tb = tbClass(parent=self.plot, plot=self.plot) + self.plot.addToolBar(tb) diff --git a/src/silx/gui/plot/tools/toolbars.py b/src/silx/gui/plot/tools/toolbars.py new file mode 100644 index 0000000..3df7d06 --- /dev/null +++ b/src/silx/gui/plot/tools/toolbars.py @@ -0,0 +1,362 @@ +# coding: utf-8 +# /*########################################################################## +# +# Copyright (c) 2018-2020 European Synchrotron Radiation Facility +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. +# +# ###########################################################################*/ +"""This module provides toolbars that work with :class:`PlotWidget`. +""" + +__authors__ = ["T. Vincent"] +__license__ = "MIT" +__date__ = "01/03/2018" + + +from ... import qt +from .. import actions +from ..PlotWidget import PlotWidget +from .. import PlotToolButtons +from ....utils.deprecation import deprecated + + +class InteractiveModeToolBar(qt.QToolBar): + """Toolbar with interactive mode actions + + :param parent: See :class:`QWidget` + :param silx.gui.plot.PlotWidget plot: PlotWidget to control + :param str title: Title of the toolbar. + """ + + def __init__(self, parent=None, plot=None, title='Plot Interaction'): + super(InteractiveModeToolBar, self).__init__(title, parent) + + assert isinstance(plot, PlotWidget) + + self._zoomModeAction = actions.mode.ZoomModeAction( + parent=self, plot=plot) + self.addAction(self._zoomModeAction) + + self._panModeAction = actions.mode.PanModeAction( + parent=self, plot=plot) + self.addAction(self._panModeAction) + + def getZoomModeAction(self): + """Returns the zoom mode QAction. + + :rtype: PlotAction + """ + return self._zoomModeAction + + def getPanModeAction(self): + """Returns the pan mode QAction + + :rtype: PlotAction + """ + return self._panModeAction + + +class OutputToolBar(qt.QToolBar): + """Toolbar providing icons to copy, save and print a PlotWidget + + :param parent: See :class:`QWidget` + :param silx.gui.plot.PlotWidget plot: PlotWidget to control + :param str title: Title of the toolbar. + """ + + def __init__(self, parent=None, plot=None, title='Plot Output'): + super(OutputToolBar, self).__init__(title, parent) + + assert isinstance(plot, PlotWidget) + + self._copyAction = actions.io.CopyAction(parent=self, plot=plot) + self.addAction(self._copyAction) + + self._saveAction = actions.io.SaveAction(parent=self, plot=plot) + self.addAction(self._saveAction) + + self._printAction = actions.io.PrintAction(parent=self, plot=plot) + self.addAction(self._printAction) + + def getCopyAction(self): + """Returns the QAction performing copy to clipboard of the PlotWidget + + :rtype: PlotAction + """ + return self._copyAction + + def getSaveAction(self): + """Returns the QAction performing save to file of the PlotWidget + + :rtype: PlotAction + """ + return self._saveAction + + def getPrintAction(self): + """Returns the QAction performing printing of the PlotWidget + + :rtype: PlotAction + """ + return self._printAction + + +class ImageToolBar(qt.QToolBar): + """Toolbar providing PlotAction suited when displaying images + + :param parent: See :class:`QWidget` + :param silx.gui.plot.PlotWidget plot: PlotWidget to control + :param str title: Title of the toolbar. + """ + + def __init__(self, parent=None, plot=None, title='Image'): + super(ImageToolBar, self).__init__(title, parent) + + assert isinstance(plot, PlotWidget) + + self._resetZoomAction = actions.control.ResetZoomAction( + parent=self, plot=plot) + self.addAction(self._resetZoomAction) + + self._colormapAction = actions.control.ColormapAction( + parent=self, plot=plot) + self.addAction(self._colormapAction) + + self._keepDataAspectRatioButton = PlotToolButtons.AspectToolButton( + parent=self, plot=plot) + self.addWidget(self._keepDataAspectRatioButton) + + self._yAxisInvertedButton = PlotToolButtons.YAxisOriginToolButton( + parent=self, plot=plot) + self.addWidget(self._yAxisInvertedButton) + + def getResetZoomAction(self): + """Returns the QAction to reset the zoom. + + :rtype: PlotAction + """ + return self._resetZoomAction + + def getColormapAction(self): + """Returns the QAction to control the colormap. + + :rtype: PlotAction + """ + return self._colormapAction + + def getKeepDataAspectRatioButton(self): + """Returns the QToolButton controlling data aspect ratio. + + :rtype: QToolButton + """ + return self._keepDataAspectRatioButton + + def getYAxisInvertedButton(self): + """Returns the QToolButton controlling Y axis orientation. + + :rtype: QToolButton + """ + return self._yAxisInvertedButton + + +class CurveToolBar(qt.QToolBar): + """Toolbar providing PlotAction suited when displaying curves + + :param parent: See :class:`QWidget` + :param silx.gui.plot.PlotWidget plot: PlotWidget to control + :param str title: Title of the toolbar. + """ + + def __init__(self, parent=None, plot=None, title='Image'): + super(CurveToolBar, self).__init__(title, parent) + + assert isinstance(plot, PlotWidget) + + self._resetZoomAction = actions.control.ResetZoomAction( + parent=self, plot=plot) + self.addAction(self._resetZoomAction) + + self._xAxisAutoScaleAction = actions.control.XAxisAutoScaleAction( + parent=self, plot=plot) + self.addAction(self._xAxisAutoScaleAction) + + self._yAxisAutoScaleAction = actions.control.YAxisAutoScaleAction( + parent=self, plot=plot) + self.addAction(self._yAxisAutoScaleAction) + + self._xAxisLogarithmicAction = actions.control.XAxisLogarithmicAction( + parent=self, plot=plot) + self.addAction(self._xAxisLogarithmicAction) + + self._yAxisLogarithmicAction = actions.control.YAxisLogarithmicAction( + parent=self, plot=plot) + self.addAction(self._yAxisLogarithmicAction) + + self._gridAction = actions.control.GridAction( + parent=self, plot=plot) + self.addAction(self._gridAction) + + self._curveStyleAction = actions.control.CurveStyleAction( + parent=self, plot=plot) + self.addAction(self._curveStyleAction) + + def getResetZoomAction(self): + """Returns the QAction to reset the zoom. + + :rtype: PlotAction + """ + return self._resetZoomAction + + def getXAxisAutoScaleAction(self): + """Returns the QAction to toggle X axis autoscale. + + :rtype: PlotAction + """ + return self._xAxisAutoScaleAction + + def getYAxisAutoScaleAction(self): + """Returns the QAction to toggle Y axis autoscale. + + :rtype: PlotAction + """ + return self._yAxisAutoScaleAction + + def getXAxisLogarithmicAction(self): + """Returns the QAction to toggle X axis log/linear scale. + + :rtype: PlotAction + """ + return self._xAxisLogarithmicAction + + def getYAxisLogarithmicAction(self): + """Returns the QAction to toggle Y axis log/linear scale. + + :rtype: PlotAction + """ + return self._yAxisLogarithmicAction + + def getGridAction(self): + """Returns the action to toggle the plot grid. + + :rtype: PlotAction + """ + return self._gridAction + + def getCurveStyleAction(self): + """Returns the QAction to change the style of all curves. + + :rtype: PlotAction + """ + return self._curveStyleAction + + +class ScatterToolBar(qt.QToolBar): + """Toolbar providing PlotAction suited when displaying scatter plot + + :param parent: See :class:`QWidget` + :param silx.gui.plot.PlotWidget plot: PlotWidget to control + :param str title: Title of the toolbar. + """ + + def __init__(self, parent=None, plot=None, title='Scatter Tools'): + super(ScatterToolBar, self).__init__(title, parent) + + assert isinstance(plot, PlotWidget) + + self._resetZoomAction = actions.control.ResetZoomAction( + parent=self, plot=plot) + self.addAction(self._resetZoomAction) + + self._xAxisLogarithmicAction = actions.control.XAxisLogarithmicAction( + parent=self, plot=plot) + self.addAction(self._xAxisLogarithmicAction) + + self._yAxisLogarithmicAction = actions.control.YAxisLogarithmicAction( + parent=self, plot=plot) + self.addAction(self._yAxisLogarithmicAction) + + self._keepDataAspectRatioButton = PlotToolButtons.AspectToolButton( + parent=self, plot=plot) + self.addWidget(self._keepDataAspectRatioButton) + + self._gridAction = actions.control.GridAction( + parent=self, plot=plot) + self.addAction(self._gridAction) + + self._colormapAction = actions.control.ColormapAction( + parent=self, plot=plot) + self.addAction(self._colormapAction) + + self._visualizationToolButton = \ + PlotToolButtons.ScatterVisualizationToolButton(parent=self, plot=plot) + self.addWidget(self._visualizationToolButton) + + def getResetZoomAction(self): + """Returns the QAction to reset the zoom. + + :rtype: PlotAction + """ + return self._resetZoomAction + + def getXAxisLogarithmicAction(self): + """Returns the QAction to toggle X axis log/linear scale. + + :rtype: PlotAction + """ + return self._xAxisLogarithmicAction + + def getYAxisLogarithmicAction(self): + """Returns the QAction to toggle Y axis log/linear scale. + + :rtype: PlotAction + """ + return self._yAxisLogarithmicAction + + def getGridAction(self): + """Returns the action to toggle the plot grid. + + :rtype: PlotAction + """ + return self._gridAction + + def getColormapAction(self): + """Returns the QAction to control the colormap. + + :rtype: PlotAction + """ + return self._colormapAction + + def getKeepDataAspectRatioButton(self): + """Returns the QToolButton controlling data aspect ratio. + + :rtype: QToolButton + """ + return self._keepDataAspectRatioButton + + def getScatterVisualizationToolButton(self): + """Returns the QToolButton controlling the visualization mode. + + :rtype: ScatterVisualizationToolButton + """ + return self._visualizationToolButton + + @deprecated(replacement='getScatterVisualizationToolButton', + since_version='0.11.0') + def getSymbolToolButton(self): + return self.getScatterVisualizationToolButton() diff --git a/src/silx/gui/plot/utils/__init__.py b/src/silx/gui/plot/utils/__init__.py new file mode 100644 index 0000000..3187f6b --- /dev/null +++ b/src/silx/gui/plot/utils/__init__.py @@ -0,0 +1,30 @@ +# coding: utf-8 +# /*########################################################################## +# +# Copyright (c) 2016-2017 European Synchrotron Radiation Facility +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. +# +# ###########################################################################*/ +"""Utils module for plot. +""" + +__authors__ = ["V. Valls"] +__license__ = "MIT" +__date__ = "29/06/2017" diff --git a/src/silx/gui/plot/utils/axis.py b/src/silx/gui/plot/utils/axis.py new file mode 100644 index 0000000..5cf8ad9 --- /dev/null +++ b/src/silx/gui/plot/utils/axis.py @@ -0,0 +1,398 @@ +# coding: utf-8 +# /*########################################################################## +# +# Copyright (c) 2017-2021 European Synchrotron Radiation Facility +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. +# +# ###########################################################################*/ +"""This module contains utils class for axes management. +""" + +__authors__ = ["V. Valls"] +__license__ = "MIT" +__date__ = "20/11/2018" + +import functools +import logging +from contextlib import contextmanager +import weakref +import silx.utils.weakref as silxWeakref +from silx.gui.plot.items.axis import Axis, XAxis, YAxis +from ...qt.inspect import isValid as _isQObjectValid + + +_logger = logging.getLogger(__name__) + + +class SyncAxes(object): + """Synchronize a set of plot axes together. + + It is created with the expected axes and starts to synchronize them. + + It can be customized to synchronize limits, scale, and direction of axes + together. By default everything is synchronized. + + The API :meth:`start` and :meth:`stop` can be used to enable/disable the + synchronization while this object is still alive. + + If this object is destroyed the synchronization stop. + + .. versionadded:: 0.6 + """ + + def __init__(self, axes, + syncLimits=True, + syncScale=True, + syncDirection=True, + syncCenter=False, + syncZoom=False, + filterHiddenPlots=False + ): + """ + Constructor + + :param list(Axis) axes: A list of axes to synchronize together + :param bool syncLimits: Synchronize axes limits + :param bool syncScale: Synchronize axes scale + :param bool syncDirection: Synchronize axes direction + :param bool syncCenter: Synchronize the center of the axes in the center + of the plots + :param bool syncZoom: Synchronize the zoom of the plot + :param bool filterHiddenPlots: True to avoid updating hidden plots. + Default: False. + """ + object.__init__(self) + + def implies(x, y): return bool(y ** x) + + assert(implies(syncZoom, not syncLimits)) + assert(implies(syncCenter, not syncLimits)) + assert(implies(syncLimits, not syncCenter)) + assert(implies(syncLimits, not syncZoom)) + + self.__filterHiddenPlots = filterHiddenPlots + self.__locked = False + self.__axisRefs = [] + self.__syncLimits = syncLimits + self.__syncScale = syncScale + self.__syncDirection = syncDirection + self.__syncCenter = syncCenter + self.__syncZoom = syncZoom + self.__callbacks = None + self.__lastMainAxis = None + + for axis in axes: + self.addAxis(axis) + + self.start() + + def start(self): + """Start synchronizing axes together. + + The first axis is used as the reference for the first synchronization. + After that, any changes to any axes will be used to synchronize other + axes. + """ + if self.isSynchronizing(): + raise RuntimeError("Axes already synchronized") + self.__callbacks = {} + + axes = self.__getAxes() + + # register callback for further sync + for axis in axes: + self.__connectAxes(axis) + self.synchronize() + + def isSynchronizing(self): + """Returns true if events are connected to the axes to synchronize them + all together + + :rtype: bool + """ + return self.__callbacks is not None + + def __connectAxes(self, axis): + refAxis = weakref.ref(axis) + callbacks = [] + if self.__syncLimits: + # the weakref is needed to be able ignore self references + callback = silxWeakref.WeakMethodProxy(self.__axisLimitsChanged) + callback = functools.partial(callback, refAxis) + sig = axis.sigLimitsChanged + sig.connect(callback) + callbacks.append(("sigLimitsChanged", callback)) + elif self.__syncCenter and self.__syncZoom: + # the weakref is needed to be able ignore self references + callback = silxWeakref.WeakMethodProxy(self.__axisCenterAndZoomChanged) + callback = functools.partial(callback, refAxis) + sig = axis.sigLimitsChanged + sig.connect(callback) + callbacks.append(("sigLimitsChanged", callback)) + elif self.__syncZoom: + raise NotImplementedError() + elif self.__syncCenter: + # the weakref is needed to be able ignore self references + callback = silxWeakref.WeakMethodProxy(self.__axisCenterChanged) + callback = functools.partial(callback, refAxis) + sig = axis.sigLimitsChanged + sig.connect(callback) + callbacks.append(("sigLimitsChanged", callback)) + if self.__syncScale: + # the weakref is needed to be able ignore self references + callback = silxWeakref.WeakMethodProxy(self.__axisScaleChanged) + callback = functools.partial(callback, refAxis) + sig = axis.sigScaleChanged + sig.connect(callback) + callbacks.append(("sigScaleChanged", callback)) + if self.__syncDirection: + # the weakref is needed to be able ignore self references + callback = silxWeakref.WeakMethodProxy(self.__axisInvertedChanged) + callback = functools.partial(callback, refAxis) + sig = axis.sigInvertedChanged + sig.connect(callback) + callbacks.append(("sigInvertedChanged", callback)) + + if self.__filterHiddenPlots: + # the weakref is needed to be able ignore self references + callback = silxWeakref.WeakMethodProxy(self.__axisVisibilityChanged) + callback = functools.partial(callback, refAxis) + plot = axis._getPlot() + plot.sigVisibilityChanged.connect(callback) + callbacks.append(("sigVisibilityChanged", callback)) + + self.__callbacks[refAxis] = callbacks + + def __disconnectAxes(self, axis): + if axis is not None and _isQObjectValid(axis): + ref = weakref.ref(axis) + callbacks = self.__callbacks.pop(ref) + for sigName, callback in callbacks: + if sigName == "sigVisibilityChanged": + obj = axis._getPlot() + else: + obj = axis + if obj is not None: + sig = getattr(obj, sigName) + sig.disconnect(callback) + + def addAxis(self, axis): + """Add a new axes to synchronize. + + :param ~silx.gui.plot.items.Axis axis: The axis to synchronize + """ + self.__axisRefs.append(weakref.ref(axis)) + if self.isSynchronizing(): + self.__connectAxes(axis) + # This could be done faster as only this axis have to be fixed + self.synchronize() + + def removeAxis(self, axis): + """Remove an axis from the synchronized axes. + + :param ~silx.gui.plot.items.Axis axis: The axis to remove + """ + ref = weakref.ref(axis) + self.__axisRefs.remove(ref) + if self.isSynchronizing(): + self.__disconnectAxes(axis) + + def synchronize(self, mainAxis=None): + """Synchronize programatically all the axes. + + :param ~silx.gui.plot.items.Axis mainAxis: + The axis to take as reference (Default: the first axis). + """ + # sync the current state + axes = self.__getAxes() + if len(axes) == 0: + return + + if mainAxis is None: + mainAxis = axes[0] + + refMainAxis = weakref.ref(mainAxis) + if self.__syncLimits: + self.__axisLimitsChanged(refMainAxis, *mainAxis.getLimits()) + elif self.__syncCenter and self.__syncZoom: + self.__axisCenterAndZoomChanged(refMainAxis, *mainAxis.getLimits()) + elif self.__syncCenter: + self.__axisCenterChanged(refMainAxis, *mainAxis.getLimits()) + if self.__syncScale: + self.__axisScaleChanged(refMainAxis, mainAxis.getScale()) + if self.__syncDirection: + self.__axisInvertedChanged(refMainAxis, mainAxis.isInverted()) + + def stop(self): + """Stop the synchronization of the axes""" + if not self.isSynchronizing(): + raise RuntimeError("Axes not synchronized") + for ref in list(self.__callbacks.keys()): + axis = ref() + self.__disconnectAxes(axis) + self.__callbacks = None + + def __del__(self): + """Destructor""" + # clean up references + if self.__callbacks is not None: + self.stop() + + def __getAxes(self): + """Returns list of existing axes. + + :rtype: List[Axis] + """ + axes = [ref() for ref in self.__axisRefs] + return [axis for axis in axes if axis is not None] + + @contextmanager + def __inhibitSignals(self): + self.__locked = True + yield + self.__locked = False + + def __axesToUpdate(self, changedAxis): + for axis in self.__getAxes(): + if axis is changedAxis: + continue + if self.__filterHiddenPlots: + plot = axis._getPlot() + if not plot.isVisible(): + continue + yield axis + + def __axisVisibilityChanged(self, changedAxis, isVisible): + if not isVisible: + return + if self.__locked: + return + changedAxis = changedAxis() + if self.__lastMainAxis is None: + self.__lastMainAxis = self.__axisRefs[0] + mainAxis = self.__lastMainAxis + mainAxis = mainAxis() + self.synchronize(mainAxis=mainAxis) + # force back the main axis + self.__lastMainAxis = weakref.ref(mainAxis) + + def __getAxesCenter(self, axis, vmin, vmax): + """Returns the value displayed in the center of this axis range. + + :rtype: float + """ + scale = axis.getScale() + if scale == Axis.LINEAR: + center = (vmin + vmax) * 0.5 + else: + raise NotImplementedError("Log scale not implemented") + return center + + def __getRangeInPixel(self, axis): + """Returns the size of the axis in pixel""" + bounds = axis._getPlot().getPlotBoundsInPixels() + # bounds: left, top, width, height + if isinstance(axis, XAxis): + return bounds[2] + elif isinstance(axis, YAxis): + return bounds[3] + else: + assert(False) + + def __getLimitsFromCenter(self, axis, pos, pixelSize=None): + """Returns the limits to apply to this axis to move the `pos` into the + center of this axis. + + :param Axis axis: + :param float pos: Position in the center of the computed limits + :param Union[None,float] pixelSize: Pixel size to apply to compute the + limits. If `None` the current pixel size is applyed. + """ + scale = axis.getScale() + if scale == Axis.LINEAR: + if pixelSize is None: + # Use the current pixel size of the axis + limits = axis.getLimits() + valueRange = limits[0] - limits[1] + a = pos - valueRange * 0.5 + b = pos + valueRange * 0.5 + else: + pixelRange = self.__getRangeInPixel(axis) + a = pos - pixelRange * 0.5 * pixelSize + b = pos + pixelRange * 0.5 * pixelSize + + else: + raise NotImplementedError("Log scale not implemented") + if a > b: + return b, a + return a, b + + def __axisLimitsChanged(self, changedAxis, vmin, vmax): + if self.__locked: + return + self.__lastMainAxis = changedAxis + changedAxis = changedAxis() + with self.__inhibitSignals(): + for axis in self.__axesToUpdate(changedAxis): + axis.setLimits(vmin, vmax) + + def __axisCenterAndZoomChanged(self, changedAxis, vmin, vmax): + if self.__locked: + return + self.__lastMainAxis = changedAxis + changedAxis = changedAxis() + with self.__inhibitSignals(): + center = self.__getAxesCenter(changedAxis, vmin, vmax) + pixelRange = self.__getRangeInPixel(changedAxis) + if pixelRange == 0: + return + pixelSize = (vmax - vmin) / pixelRange + for axis in self.__axesToUpdate(changedAxis): + vmin, vmax = self.__getLimitsFromCenter(axis, center, pixelSize) + axis.setLimits(vmin, vmax) + + def __axisCenterChanged(self, changedAxis, vmin, vmax): + if self.__locked: + return + self.__lastMainAxis = changedAxis + changedAxis = changedAxis() + with self.__inhibitSignals(): + center = self.__getAxesCenter(changedAxis, vmin, vmax) + for axis in self.__axesToUpdate(changedAxis): + vmin, vmax = self.__getLimitsFromCenter(axis, center) + axis.setLimits(vmin, vmax) + + def __axisScaleChanged(self, changedAxis, scale): + if self.__locked: + return + self.__lastMainAxis = changedAxis + changedAxis = changedAxis() + with self.__inhibitSignals(): + for axis in self.__axesToUpdate(changedAxis): + axis.setScale(scale) + + def __axisInvertedChanged(self, changedAxis, isInverted): + if self.__locked: + return + self.__lastMainAxis = changedAxis + changedAxis = changedAxis() + with self.__inhibitSignals(): + for axis in self.__axesToUpdate(changedAxis): + axis.setInverted(isInverted) diff --git a/src/silx/gui/plot/utils/intersections.py b/src/silx/gui/plot/utils/intersections.py new file mode 100644 index 0000000..53f2546 --- /dev/null +++ b/src/silx/gui/plot/utils/intersections.py @@ -0,0 +1,101 @@ +# coding: utf-8 +# /*########################################################################## +# +# Copyright (c) 2017-2018 European Synchrotron Radiation Facility +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. +# +# ###########################################################################*/ +"""This module contains utils class for axes management. +""" + +__authors__ = ["H. Payno", ] +__license__ = "MIT" +__date__ = "18/05/2020" + + +import numpy + + +def lines_intersection(line1_pt1, line1_pt2, line2_pt1, line2_pt2): + """ + line segment intersection using vectors (Computer Graphics by F.S. Hill) + + :param tuple line1_pt1: + :param tuple line1_pt2: + :param tuple line2_pt1: + :param tuple line2_pt2: + :return: Union[None,numpy.array] + """ + dir_line1 = line1_pt2[0] - line1_pt1[0], line1_pt2[1] - line1_pt1[1] + dir_line2 = line2_pt2[0] - line2_pt1[0], line2_pt2[1] - line2_pt1[1] + dp = line1_pt1 - line2_pt1 + + def perp(a): + b = numpy.empty_like(a) + b[0] = -a[1] + b[1] = a[0] + return b + + dap = perp(dir_line1) + denom = numpy.dot(dap, dir_line2) + num = numpy.dot(dap, dp) + if denom == 0: + return None + return ( + (num / denom.astype(float)) * dir_line2[0] + line2_pt1[0], + (num / denom.astype(float)) * dir_line2[1] + line2_pt1[1]) + + +def segments_intersection(seg1_start_pt, seg1_end_pt, seg2_start_pt, + seg2_end_pt): + """ + Compute intersection between two segments + + :param seg1_start_pt: + :param seg1_end_pt: + :param seg2_start_pt: + :param seg2_end_pt: + :return: numpy.array if an intersection exists, else None + :rtype: Union[None,numpy.array] + """ + intersection = lines_intersection(line1_pt1=seg1_start_pt, + line1_pt2=seg1_end_pt, + line2_pt1=seg2_start_pt, + line2_pt2=seg2_end_pt) + if intersection is not None: + max_x_seg1 = max(seg1_start_pt[0], seg1_end_pt[0]) + max_x_seg2 = max(seg2_start_pt[0], seg2_end_pt[0]) + max_y_seg1 = max(seg1_start_pt[1], seg1_end_pt[1]) + max_y_seg2 = max(seg2_start_pt[1], seg2_end_pt[1]) + + min_x_seg1 = min(seg1_start_pt[0], seg1_end_pt[0]) + min_x_seg2 = min(seg2_start_pt[0], seg2_end_pt[0]) + min_y_seg1 = min(seg1_start_pt[1], seg1_end_pt[1]) + min_y_seg2 = min(seg2_start_pt[1], seg2_end_pt[1]) + + min_tmp_x = max(min_x_seg1, min_x_seg2) + max_tmp_x = min(max_x_seg1, max_x_seg2) + min_tmp_y = max(min_y_seg1, min_y_seg2) + max_tmp_y = min(max_y_seg1, max_y_seg2) + if (min_tmp_x <= intersection[0] <= max_tmp_x and + min_tmp_y <= intersection[1] <= max_tmp_y): + return intersection + else: + return None |