From a763e5d1b3921b3194f3d4e94ab9de3fbe08bbdd Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Picca=20Fr=C3=A9d=C3=A9ric-Emmanuel?= Date: Tue, 28 May 2019 08:16:16 +0200 Subject: New upstream version 0.10.1+dfsg --- silx/gui/colors.py | 454 +++-- silx/gui/data/DataViewer.py | 92 +- silx/gui/data/DataViewerFrame.py | 5 +- silx/gui/data/DataViewerSelector.py | 6 +- silx/gui/data/DataViews.py | 499 +++++- silx/gui/data/Hdf5TableView.py | 13 +- silx/gui/data/HexaTableView.py | 8 +- silx/gui/data/NXdataWidgets.py | 439 ++++- silx/gui/data/TextFormatter.py | 50 +- silx/gui/data/test/test_arraywidget.py | 6 +- silx/gui/data/test/test_dataviewer.py | 14 +- silx/gui/data/test/test_numpyaxesselector.py | 7 +- silx/gui/data/test/test_textformatter.py | 12 +- silx/gui/dialog/AbstractDataFileDialog.py | 56 +- silx/gui/dialog/ColormapDialog.py | 355 ++-- silx/gui/dialog/DataFileDialog.py | 10 +- silx/gui/dialog/FileTypeComboBox.py | 39 +- silx/gui/dialog/ImageFileDialog.py | 5 +- silx/gui/dialog/SafeFileSystemModel.py | 6 +- silx/gui/dialog/test/test_colormapdialog.py | 36 +- silx/gui/dialog/test/test_datafiledialog.py | 115 +- silx/gui/dialog/test/test_imagefiledialog.py | 117 +- silx/gui/dialog/utils.py | 6 +- silx/gui/hdf5/Hdf5Formatter.py | 17 +- silx/gui/hdf5/Hdf5Item.py | 38 +- silx/gui/hdf5/Hdf5TreeModel.py | 41 +- silx/gui/hdf5/NexusSortFilterProxyModel.py | 4 +- silx/gui/hdf5/_utils.py | 74 +- silx/gui/hdf5/test/test_hdf5.py | 59 +- silx/gui/icons.py | 8 +- silx/gui/plot/ColorBar.py | 19 +- silx/gui/plot/CompareImages.py | 2 +- silx/gui/plot/ComplexImageView.py | 11 +- silx/gui/plot/CurvesROIWidget.py | 1834 +++++++++++++------- silx/gui/plot/MaskToolsWidget.py | 108 +- silx/gui/plot/PlotInteraction.py | 148 +- silx/gui/plot/PlotToolButtons.py | 29 +- silx/gui/plot/PlotWidget.py | 492 +++--- silx/gui/plot/PlotWindow.py | 87 +- silx/gui/plot/PrintPreviewToolButton.py | 61 +- silx/gui/plot/Profile.py | 109 +- silx/gui/plot/ScatterMaskToolsWidget.py | 65 +- silx/gui/plot/ScatterView.py | 12 +- silx/gui/plot/StackView.py | 10 +- silx/gui/plot/StatsWidget.py | 1594 ++++++++++++----- silx/gui/plot/_BaseMaskToolsWidget.py | 157 +- silx/gui/plot/_utils/dtime_ticklayout.py | 4 +- silx/gui/plot/actions/control.py | 11 +- silx/gui/plot/actions/io.py | 38 +- silx/gui/plot/backends/BackendBase.py | 37 +- silx/gui/plot/backends/BackendMatplotlib.py | 223 ++- silx/gui/plot/backends/BackendOpenGL.py | 364 ++-- silx/gui/plot/backends/glutils/GLPlotCurve.py | 119 +- silx/gui/plot/backends/glutils/GLPlotFrame.py | 124 +- silx/gui/plot/backends/glutils/GLSupport.py | 63 +- silx/gui/plot/items/__init__.py | 2 +- silx/gui/plot/items/axis.py | 6 +- silx/gui/plot/items/complex.py | 8 +- silx/gui/plot/items/core.py | 37 +- silx/gui/plot/items/curve.py | 5 +- silx/gui/plot/items/histogram.py | 6 +- silx/gui/plot/items/roi.py | 72 +- silx/gui/plot/items/scatter.py | 5 +- silx/gui/plot/items/shape.py | 45 +- silx/gui/plot/matplotlib/Colormap.py | 16 +- silx/gui/plot/stats/stats.py | 400 +++-- silx/gui/plot/stats/statshandler.py | 124 +- silx/gui/plot/test/testCurvesROIWidget.py | 219 ++- silx/gui/plot/test/testMaskToolsWidget.py | 7 +- silx/gui/plot/test/testPlotWidget.py | 61 +- silx/gui/plot/test/testSaveAction.py | 20 +- silx/gui/plot/test/testScatterMaskToolsWidget.py | 5 +- silx/gui/plot/test/testStats.py | 284 +-- silx/gui/plot/test/testUtilsAxis.py | 49 +- silx/gui/plot/tools/roi.py | 20 +- .../plot/tools/test/testScatterProfileToolBar.py | 2 +- silx/gui/plot/utils/axis.py | 288 ++- silx/gui/plot3d/ParamTreeView.py | 2 +- silx/gui/plot3d/ScalarFieldView.py | 21 + silx/gui/plot3d/SceneWidget.py | 30 +- silx/gui/plot3d/_model/items.py | 67 +- silx/gui/plot3d/items/__init__.py | 4 +- silx/gui/plot3d/items/core.py | 4 +- silx/gui/plot3d/items/mesh.py | 281 ++- silx/gui/plot3d/items/mixins.py | 21 +- silx/gui/plot3d/items/scatter.py | 39 +- silx/gui/plot3d/items/volume.py | 12 +- silx/gui/plot3d/scene/primitives.py | 8 +- silx/gui/plot3d/test/__init__.py | 4 +- silx/gui/plot3d/test/testSceneWidgetPicking.py | 53 +- silx/gui/plot3d/test/testStatsWidget.py | 213 +++ silx/gui/qt/_pyside_dynamic.py | 54 +- silx/gui/test/test_colors.py | 110 +- silx/gui/utils/concurrent.py | 4 +- silx/gui/utils/projecturl.py | 77 + silx/gui/utils/test/test_async.py | 4 +- silx/gui/utils/testutils.py | 10 +- silx/gui/widgets/PrintPreview.py | 74 +- silx/gui/widgets/RangeSlider.py | 198 ++- silx/gui/widgets/UrlSelectionTable.py | 164 ++ 100 files changed, 7828 insertions(+), 3619 deletions(-) create mode 100644 silx/gui/plot3d/test/testStatsWidget.py create mode 100644 silx/gui/utils/projecturl.py create mode 100644 silx/gui/widgets/UrlSelectionTable.py (limited to 'silx/gui') diff --git a/silx/gui/colors.py b/silx/gui/colors.py index a51bcdc..f1f34c9 100644 --- a/silx/gui/colors.py +++ b/silx/gui/colors.py @@ -29,18 +29,28 @@ from __future__ import absolute_import __authors__ = ["T. Vincent", "H.Payno"] __license__ = "MIT" -__date__ = "05/10/2018" +__date__ = "29/01/2019" -from silx.gui import qt -import copy as copy_mdl import numpy import logging +import collections +from silx.gui import qt +from silx import config from silx.math.combo import min_max from silx.math.colormap import cmap as _cmap from silx.utils.exceptions import NotEditableError +from silx.utils import deprecation +from silx.resources import resource_filename as _resource_filename + _logger = logging.getLogger(__file__) +try: + from matplotlib import cm as _matplotlib_cm +except ImportError: + _logger.info("matplotlib not available, only embedded colormaps available") + _matplotlib_cm = None + _COLORDICT = {} """Dictionary of common colors.""" @@ -67,12 +77,44 @@ _COLORDICT['darkBrown'] = '#660000' _COLORDICT['darkCyan'] = '#008080' _COLORDICT['darkYellow'] = '#808000' _COLORDICT['darkMagenta'] = '#800080' +_COLORDICT['transparent'] = '#00000000' # FIXME: It could be nice to expose a functional API instead of that attribute COLORDICT = _COLORDICT +_LUT_DESCRIPTION = collections.namedtuple("_LUT_DESCRIPTION", ["source", "cursor_color", "preferred"]) +"""Description of a LUT for internal purpose.""" + + +_AVAILABLE_LUTS = collections.OrderedDict([ + ('gray', _LUT_DESCRIPTION('builtin', 'pink', True)), + ('reversed gray', _LUT_DESCRIPTION('builtin', 'pink', True)), + ('temperature', _LUT_DESCRIPTION('builtin', 'pink', True)), + ('red', _LUT_DESCRIPTION('builtin', 'green', True)), + ('green', _LUT_DESCRIPTION('builtin', 'pink', True)), + ('blue', _LUT_DESCRIPTION('builtin', 'yellow', True)), + ('jet', _LUT_DESCRIPTION('matplotlib', 'pink', True)), + ('viridis', _LUT_DESCRIPTION('resource', 'pink', True)), + ('magma', _LUT_DESCRIPTION('resource', 'green', True)), + ('inferno', _LUT_DESCRIPTION('resource', 'green', True)), + ('plasma', _LUT_DESCRIPTION('resource', 'green', True)), + ('hsv', _LUT_DESCRIPTION('matplotlib', 'black', True)), +]) +"""Description for internal porpose of all the default LUT provided by the library.""" + + +DEFAULT_MIN_LIN = 0 +"""Default min value if in linear normalization""" +DEFAULT_MAX_LIN = 1 +"""Default max value if in linear normalization""" +DEFAULT_MIN_LOG = 1 +"""Default min value if in log normalization""" +DEFAULT_MAX_LOG = 10 +"""Default max value if in log normalization""" + + def rgba(color, colorDict=None): """Convert color code '#RRGGBB' and '#RRGGBBAA' to (R, G, B, A) @@ -121,19 +163,21 @@ def rgba(color, colorDict=None): return r, g, b, a -_COLORMAP_CURSOR_COLORS = { - 'gray': 'pink', - 'reversed gray': 'pink', - 'temperature': 'pink', - 'red': 'green', - 'green': 'pink', - 'blue': 'yellow', - 'jet': 'pink', - 'viridis': 'pink', - 'magma': 'green', - 'inferno': 'green', - 'plasma': 'green', -} +def greyed(color, colorDict=None): + """Convert color code '#RRGGBB' and '#RRGGBBAA' to a grey color + (R, G, B, A). + + It also convert RGB(A) values from uint8 to float in [0, 1] and + accept a QColor as color argument. + + :param str color: The color to convert + :param dict colorDict: A dictionary of color name conversion to color code + :returns: RGBA colors as floats in [0., 1.] + :rtype: tuple + """ + r, g, b, a = rgba(color=color, colorDict=colorDict) + g = 0.21 * r + 0.72 * g + 0.07 * b + return g, g, g, a def cursorColorForColormap(colormapName): @@ -143,26 +187,140 @@ def cursorColorForColormap(colormapName): :return: Name of the color. :rtype: str """ - return _COLORMAP_CURSOR_COLORS.get(colormapName, 'black') + description = _AVAILABLE_LUTS.get(colormapName, None) + if description is not None: + color = description.cursor_color + if color is not None: + return color + return 'black' -DEFAULT_COLORMAPS = ( - 'gray', 'reversed gray', 'temperature', 'red', 'green', 'blue') -"""Tuple of supported colormap names.""" +# Colormap loader -DEFAULT_MIN_LIN = 0 -"""Default min value if in linear normalization""" -DEFAULT_MAX_LIN = 1 -"""Default max value if in linear normalization""" -DEFAULT_MIN_LOG = 1 -"""Default min value if in log normalization""" -DEFAULT_MAX_LOG = 10 -"""Default max value if in log normalization""" +_COLORMAP_CACHE = {} +"""Cache already used colormaps as name: color LUT""" + + +def _arrayToRgba8888(colors): + """Convert colors from a numpy array using float (0..1) int or uint + (0..255) to uint8 RGBA. + + :param numpy.ndarray colors: Array of float int or uint colors to convert + :return: colors as uint8 + :rtype: numpy.ndarray + """ + assert len(colors.shape) == 2 + assert colors.shape[1] in (3, 4) + + if colors.dtype == numpy.uint8: + pass + elif colors.dtype.kind == 'f': + # Each bin is [N, N+1[ except the last one: [255, 256] + colors = numpy.clip(colors.astype(numpy.float64) * 256, 0., 255.) + colors = colors.astype(numpy.uint8) + elif colors.dtype.kind in 'iu': + colors = numpy.clip(colors, 0, 255) + colors = colors.astype(numpy.uint8) + + if colors.shape[1] == 3: + tmp = numpy.empty((len(colors), 4), dtype=numpy.uint8) + tmp[:, 0:3] = colors + tmp[:, 3] = 255 + colors = tmp + + return colors + + +def _createColormapLut(name): + """Returns the color LUT corresponding to a colormap name + + :param str name: Name of the colormap to load + :returns: Corresponding table of colors + :rtype: numpy.ndarray + :raise ValueError: If no colormap corresponds to name + """ + description = _AVAILABLE_LUTS.get(name) + use_mpl = False + if description is not None: + if description.source == "builtin": + # Build colormap LUT + lut = numpy.zeros((256, 4), dtype=numpy.uint8) + lut[:, 3] = 255 + + if name == 'gray': + lut[:, :3] = numpy.arange(256, dtype=numpy.uint8).reshape(-1, 1) + elif name == 'reversed gray': + lut[:, :3] = numpy.arange(255, -1, -1, dtype=numpy.uint8).reshape(-1, 1) + elif name == 'red': + lut[:, 0] = numpy.arange(256, dtype=numpy.uint8) + elif name == 'green': + lut[:, 1] = numpy.arange(256, dtype=numpy.uint8) + elif name == 'blue': + lut[:, 2] = numpy.arange(256, dtype=numpy.uint8) + elif name == 'temperature': + # Red + lut[128:192, 0] = numpy.arange(2, 255, 4, dtype=numpy.uint8) + lut[192:, 0] = 255 + # Green + lut[:64, 1] = numpy.arange(0, 255, 4, dtype=numpy.uint8) + lut[64:192, 1] = 255 + lut[192:, 1] = numpy.arange(252, -1, -4, dtype=numpy.uint8) + # Blue + lut[:64, 2] = 255 + lut[64:128, 2] = numpy.arange(254, 0, -4, dtype=numpy.uint8) + else: + raise RuntimeError("Built-in colormap not implemented") + return lut + + elif description.source == "resource": + # Load colormap LUT + colors = numpy.load(_resource_filename("gui/colormaps/%s.npy" % name)) + # Convert to uint8 and add alpha channel + lut = _arrayToRgba8888(colors) + return lut + + elif description.source == "matplotlib": + use_mpl = True + + else: + raise RuntimeError("Internal LUT source '%s' unsupported" % description.source) + + # Here it expect a matplotlib LUTs + + if use_mpl: + # matplotlib is mandatory + if _matplotlib_cm is None: + raise ValueError("The colormap '%s' expect matplotlib, but matplotlib is not installed" % name) + + if _matplotlib_cm is not None: # Try to load with matplotlib + colormap = _matplotlib_cm.get_cmap(name) + lut = colormap(numpy.linspace(0, 1, colormap.N, endpoint=True)) + lut = _arrayToRgba8888(lut) + return lut + + raise ValueError("Unknown colormap '%s'" % name) + + +def _getColormap(name): + """Returns the color LUT corresponding to a colormap name + + :param str name: Name of the colormap to load + :returns: Corresponding table of colors + :rtype: numpy.ndarray + :raise ValueError: If no colormap corresponds to name + """ + name = str(name) + if name not in _COLORMAP_CACHE: + lut = _createColormapLut(name) + _COLORMAP_CACHE[name] = lut + return _COLORMAP_CACHE[name] class Colormap(qt.QObject): """Description of a colormap + If no `name` nor `colors` are provided, a default gray LUT is used. + :param str name: Name of the colormap :param tuple colors: optional, custom colormap. Nx3 or Nx4 numpy array of RGB(A) colors, @@ -187,10 +345,11 @@ class Colormap(qt.QObject): sigChanged = qt.Signal() """Signal emitted when the colormap has changed.""" - def __init__(self, name='gray', colors=None, normalization=LINEAR, vmin=None, vmax=None): + def __init__(self, name=None, colors=None, normalization=LINEAR, vmin=None, vmax=None): qt.QObject.__init__(self) + self._editable = True + assert normalization in Colormap.NORMALIZATIONS - assert not (name is None and colors is None) if normalization is Colormap.LOGARITHM: if (vmin is not None and vmin < 0) or (vmax is not None and vmax < 0): m = "Unsuported vmin (%s) and/or vmax (%s) given for a log scale." @@ -200,78 +359,76 @@ class Colormap(qt.QObject): vmin = None vmax = None - self._name = str(name) if name is not None else None - self._setColors(colors) + self._name = None + self._colors = None + + if colors is not None and name is not None: + deprecation.deprecated_warning("Argument", + name="silx.gui.plot.Colors", + reason="name and colors can't be used at the same time", + since_version="0.10.0", + skip_backtrace_count=1) + + colors = None + + if name is not None: + self.setName(name) # And resets colormap LUT + elif colors is not None: + self.setColormapLUT(colors) + else: + # Default colormap is grey + self.setName("gray") + self._normalization = str(normalization) self._vmin = float(vmin) if vmin is not None else None self._vmax = float(vmax) if vmax is not None else None - self._editable = True - - def isAutoscale(self): - """Return True if both min and max are in autoscale mode""" - return self._vmin is None and self._vmax is None - def getName(self): - """Return the name of the colormap - :rtype: str - """ - return self._name - - @staticmethod - def _convertColorsFromFloatToUint8(colors): - """Convert colors from float in [0, 1] to uint8 + def setFromColormap(self, other): + """Set this colormap using information from the `other` colormap. - :param numpy.ndarray colors: Array of float colors to convert - :return: colors as uint8 - :rtype: numpy.ndarray + :param Colormap other: Colormap to use as reference. """ - # Each bin is [N, N+1[ except the last one: [255, 256] - return numpy.clip( - colors.astype(numpy.float64) * 256, 0., 255.).astype(numpy.uint8) - - def _setColors(self, colors): - if colors is None: - self._colors = None + if not self.isEditable(): + raise NotEditableError('Colormap is not editable') + if self == other: + return + old = self.blockSignals(True) + name = other.getName() + if name is not None: + self.setName(name) else: - colors = numpy.array(colors, copy=False) - if colors.shape == (): - raise TypeError("An array is expected for 'colors' argument. '%s' was found." % type(colors)) - colors.shape = -1, colors.shape[-1] - if colors.dtype.kind == 'f': - colors = self._convertColorsFromFloatToUint8(colors) - - # Makes sure it is RGBA8888 - self._colors = numpy.zeros((len(colors), 4), dtype=numpy.uint8) - self._colors[:, 3] = 255 # Alpha channel - self._colors[:, :colors.shape[1]] = colors # Copy colors + self.setColormapLUT(other.getColormapLUT()) + self.setNormalization(other.getNormalization()) + self.setVRange(other.getVMin(), other.getVMax()) + self.blockSignals(old) + self.sigChanged.emit() def getNColors(self, nbColors=None): """Returns N colors computed by sampling the colormap regularly. :param nbColors: The number of colors in the returned array or None for the default value. - The default value is 256 for colormap with a name (see :meth:`setName`) and - it is the size of the LUT for colormap defined with :meth:`setColormapLUT`. + The default value is the size of the colormap LUT. :type nbColors: int or None :return: 2D array of uint8 of shape (nbColors, 4) :rtype: numpy.ndarray """ # Handle default value for nbColors if nbColors is None: - lut = self.getColormapLUT() - if lut is not None: # In this case uses LUT length - nbColors = len(lut) - else: # Default to 256 - nbColors = 256 - - nbColors = int(nbColors) + return numpy.array(self._colors, copy=True) + else: + colormap = self.copy() + colormap.setNormalization(Colormap.LINEAR) + colormap.setVRange(vmin=None, vmax=None) + colors = colormap.applyToData( + numpy.arange(int(nbColors), dtype=numpy.int)) + return colors - colormap = self.copy() - colormap.setNormalization(Colormap.LINEAR) - colormap.setVRange(vmin=None, vmax=None) - colors = colormap.applyToData( - numpy.arange(nbColors, dtype=numpy.int)) - return colors + def getName(self): + """Return the name of the colormap + :rtype: str + """ + return self._name def setName(self, name): """Set the name of the colormap to use. @@ -281,23 +438,31 @@ class Colormap(qt.QObject): 'reversed gray', 'temperature', 'red', 'green', 'blue', 'jet', 'viridis', 'magma', 'inferno', 'plasma'. """ + name = str(name) + if self._name == name: + return if self.isEditable() is False: raise NotEditableError('Colormap is not editable') - assert name in self.getSupportedColormaps() - self._name = str(name) - self._colors = None + if name not in self.getSupportedColormaps(): + raise ValueError("Colormap name '%s' is not supported" % name) + self._name = name + self._colors = _getColormap(self._name) self.sigChanged.emit() - def getColormapLUT(self): - """Return the list of colors for the colormap or None if not set + def getColormapLUT(self, copy=True): + """Return the list of colors for the colormap or None if not set. + This returns None if the colormap was set with :meth:`setName`. + Use :meth:`getNColors` to get the colormap LUT for any colormap. + + :param bool copy: If true a copy of the numpy array is provided :return: the list of colors for the colormap or None if not set :rtype: numpy.ndarray or None """ - if self._colors is None: - return None + if self._name is None: + return numpy.array(self._colors, copy=copy) else: - return numpy.array(self._colors, copy=True) + return None def setColormapLUT(self, colors): """Set the colors of the colormap. @@ -310,10 +475,15 @@ class Colormap(qt.QObject): """ if self.isEditable() is False: raise NotEditableError('Colormap is not editable') - self._setColors(colors) - if len(colors) is 0: - self._colors = None - + assert colors is not None + + colors = numpy.array(colors, copy=False) + if colors.shape == (): + raise TypeError("An array is expected for 'colors' argument. '%s' was found." % type(colors)) + assert len(colors) != 0 + assert colors.ndim >= 2 + colors.shape = -1, colors.shape[-1] + self._colors = _arrayToRgba8888(colors) self._name = None self.sigChanged.emit() @@ -335,6 +505,10 @@ class Colormap(qt.QObject): self._normalization = str(norm) self.sigChanged.emit() + def isAutoscale(self): + """Return True if both min and max are in autoscale mode""" + return self._vmin is None and self._vmax is None + def getVMin(self): """Return the lower bound of the colormap @@ -504,7 +678,7 @@ class Colormap(qt.QObject): """ return { 'name': self._name, - 'colors': copy_mdl.copy(self._colors), + 'colors': self.getColormapLUT(), 'vmin': self._vmin, 'vmax': self._vmax, 'autoscale': self.isAutoscale(), @@ -546,8 +720,10 @@ class Colormap(qt.QObject): if dic.get('autoscale', False): vmin, vmax = None, None - self._name = name - self._colors = colors + if name is not None: + self.setName(name) + else: + self.setColormapLUT(colors) self._vmin = vmin self._vmax = vmax self._autoscale = True if (vmin is None and vmax is None) else False @@ -557,7 +733,7 @@ class Colormap(qt.QObject): @staticmethod def _fromDict(dic): - colormap = Colormap(name="") + colormap = Colormap() colormap._setFromDict(dic) return colormap @@ -567,7 +743,7 @@ class Colormap(qt.QObject): :rtype: silx.gui.colors.Colormap """ return Colormap(name=self._name, - colors=copy_mdl.copy(self._colors), + colors=self.getColormapLUT(), vmin=self._vmin, vmax=self._vmax, normalization=self._normalization) @@ -577,34 +753,30 @@ class Colormap(qt.QObject): :param numpy.ndarray data: The data to convert. """ - name = self.getName() - if name is not None: # Get colormap definition from matplotlib - # FIXME: If possible remove dependency to the plot - from .plot.matplotlib import Colormap as MPLColormap - mplColormap = MPLColormap.getColormap(name) - colors = mplColormap(numpy.linspace(0, 1, 256, endpoint=True)) - colors = self._convertColorsFromFloatToUint8(colors) - - else: # Use user defined LUT - colors = self.getColormapLUT() - vmin, vmax = self.getColormapRange(data) normalization = self.getNormalization() - - return _cmap(data, colors, vmin, vmax, normalization) + return _cmap(data, self._colors, vmin, vmax, normalization) @staticmethod def getSupportedColormaps(): """Get the supported colormap names as a tuple of str. The list should at least contain and start by: - ('gray', 'reversed gray', 'temperature', 'red', 'green', 'blue') + + ('gray', 'reversed gray', 'temperature', 'red', 'green', 'blue', + 'viridis', 'magma', 'inferno', 'plasma') + :rtype: tuple """ - # FIXME: If possible remove dependency to the plot - from .plot.matplotlib import Colormap as MPLColormap - maps = MPLColormap.getSupportedColormaps() - return DEFAULT_COLORMAPS + maps + colormaps = set() + if _matplotlib_cm is not None: + colormaps.update(_matplotlib_cm.cmap_d.keys()) + colormaps.update(_AVAILABLE_LUTS.keys()) + + colormaps = tuple(cmap for cmap in sorted(colormaps) + if cmap not in _AVAILABLE_LUTS.keys()) + + return tuple(_AVAILABLE_LUTS.keys()) + colormaps def __str__(self): return str(self._toDict()) @@ -617,6 +789,10 @@ class Colormap(qt.QObject): def __eq__(self, other): """Compare colormap values and not pointers""" + if other is None: + return False + if not isinstance(other, Colormap): + return False return (self.getName() == other.getName() and self.getNormalization() == other.getNormalization() and self.getVMin() == other.getVMin() and @@ -710,13 +886,10 @@ def preferredColormaps(): """ global _PREFERRED_COLORMAPS if _PREFERRED_COLORMAPS is None: - _PREFERRED_COLORMAPS = DEFAULT_COLORMAPS # Initialize preferred colormaps - setPreferredColormaps(('gray', 'reversed gray', - 'temperature', 'red', 'green', 'blue', 'jet', - 'viridis', 'magma', 'inferno', 'plasma', - 'hsv')) - return _PREFERRED_COLORMAPS + default_preferred = [k for k in _AVAILABLE_LUTS.keys() if _AVAILABLE_LUTS[k].preferred] + setPreferredColormaps(default_preferred) + return tuple(_PREFERRED_COLORMAPS) def setPreferredColormaps(colormaps): @@ -730,10 +903,41 @@ def setPreferredColormaps(colormaps): :raise ValueError: if the list of available preferred colormaps is empty. """ supportedColormaps = Colormap.getSupportedColormaps() - colormaps = tuple( - cmap for cmap in colormaps if cmap in supportedColormaps) + colormaps = [cmap for cmap in colormaps if cmap in supportedColormaps] if len(colormaps) == 0: raise ValueError("Cannot set preferred colormaps to an empty list") global _PREFERRED_COLORMAPS _PREFERRED_COLORMAPS = colormaps + + +def registerLUT(name, colors, cursor_color='black', preferred=True): + """Register a custom LUT to be used with `Colormap` objects. + + It can override existing LUT names. + + :param str name: Name of the LUT as defined to configure colormaps + :param numpy.ndarray colors: The custom LUT to register. + Nx3 or Nx4 numpy array of RGB(A) colors, + either uint8 or float in [0, 1]. + :param bool preferred: If true, this LUT will be displayed as part of the + preferred colormaps in dialogs. + :param str cursor_color: Color used to display overlay over images using + colormap with this LUT. + """ + description = _LUT_DESCRIPTION('user', cursor_color, preferred=preferred) + colors = _arrayToRgba8888(colors) + _AVAILABLE_LUTS[name] = description + + if preferred: + # Invalidate the preferred cache + global _PREFERRED_COLORMAPS + if _PREFERRED_COLORMAPS is not None: + if name not in _PREFERRED_COLORMAPS: + _PREFERRED_COLORMAPS.append(name) + else: + # The cache is not yet loaded, it's fine + pass + + # Register the cache as the LUT was already loaded + _COLORMAP_CACHE[name] = colors diff --git a/silx/gui/data/DataViewer.py b/silx/gui/data/DataViewer.py index 4db2863..b33a931 100644 --- a/silx/gui/data/DataViewer.py +++ b/silx/gui/data/DataViewer.py @@ -32,12 +32,10 @@ from silx.gui.data.DataViews import _normalizeData import logging from silx.gui import qt from silx.gui.data.NumpyAxesSelector import NumpyAxesSelector -from silx.utils import deprecation -from silx.utils.property import classproperty __authors__ = ["V. Valls"] __license__ = "MIT" -__date__ = "24/04/2018" +__date__ = "12/02/2019" _logger = logging.getLogger(__name__) @@ -70,66 +68,6 @@ class DataViewer(qt.QFrame): viewer.setVisible(True) """ - # TODO: Can be removed for silx 0.8 - @classproperty - @deprecation.deprecated(replacement="DataViews.EMPTY_MODE", since_version="0.7", skip_backtrace_count=2) - def EMPTY_MODE(self): - return DataViews.EMPTY_MODE - - # TODO: Can be removed for silx 0.8 - @classproperty - @deprecation.deprecated(replacement="DataViews.PLOT1D_MODE", since_version="0.7", skip_backtrace_count=2) - def PLOT1D_MODE(self): - return DataViews.PLOT1D_MODE - - # TODO: Can be removed for silx 0.8 - @classproperty - @deprecation.deprecated(replacement="DataViews.PLOT2D_MODE", since_version="0.7", skip_backtrace_count=2) - def PLOT2D_MODE(self): - return DataViews.PLOT2D_MODE - - # TODO: Can be removed for silx 0.8 - @classproperty - @deprecation.deprecated(replacement="DataViews.PLOT3D_MODE", since_version="0.7", skip_backtrace_count=2) - def PLOT3D_MODE(self): - return DataViews.PLOT3D_MODE - - # TODO: Can be removed for silx 0.8 - @classproperty - @deprecation.deprecated(replacement="DataViews.RAW_MODE", since_version="0.7", skip_backtrace_count=2) - def RAW_MODE(self): - return DataViews.RAW_MODE - - # TODO: Can be removed for silx 0.8 - @classproperty - @deprecation.deprecated(replacement="DataViews.RAW_ARRAY_MODE", since_version="0.7", skip_backtrace_count=2) - def RAW_ARRAY_MODE(self): - return DataViews.RAW_ARRAY_MODE - - # TODO: Can be removed for silx 0.8 - @classproperty - @deprecation.deprecated(replacement="DataViews.RAW_RECORD_MODE", since_version="0.7", skip_backtrace_count=2) - def RAW_RECORD_MODE(self): - return DataViews.RAW_RECORD_MODE - - # TODO: Can be removed for silx 0.8 - @classproperty - @deprecation.deprecated(replacement="DataViews.RAW_SCALAR_MODE", since_version="0.7", skip_backtrace_count=2) - def RAW_SCALAR_MODE(self): - return DataViews.RAW_SCALAR_MODE - - # TODO: Can be removed for silx 0.8 - @classproperty - @deprecation.deprecated(replacement="DataViews.STACK_MODE", since_version="0.7", skip_backtrace_count=2) - def STACK_MODE(self): - return DataViews.STACK_MODE - - # TODO: Can be removed for silx 0.8 - @classproperty - @deprecation.deprecated(replacement="DataViews.HDF5_MODE", since_version="0.7", skip_backtrace_count=2) - def HDF5_MODE(self): - return DataViews.HDF5_MODE - displayedViewChanged = qt.Signal(object) """Emitted when the displayed view changes""" @@ -288,6 +226,7 @@ class DataViewer(qt.QFrame): else: self.__displayedData = self.__data + # TODO: would be good to avoid that, it should be synchonous qt.QTimer.singleShot(10, self.__setDataInView) def __setDataInView(self): @@ -405,18 +344,16 @@ class DataViewer(qt.QFrame): data = self.__data info = self._getInfo() # sort available views according to priority - priorities = [v.getDataPriority(data, info) for v in self.__views] - views = zip(priorities, self.__views) + views = [] + for v in self.__views: + views.extend(v.getMatchingViews(data, info)) + views = [(v.getCachedDataPriority(data, info), v) for v in views] views = filter(lambda t: t[0] > DataViews.DataView.UNSUPPORTED, views) views = sorted(views, reverse=True) + views = [v[1] for v in views] # store available views - if len(views) == 0: - self.__setCurrentAvailableViews([]) - available = [] - else: - available = [v[1] for v in views] - self.__setCurrentAvailableViews(available) + self.__setCurrentAvailableViews(views) def __updateView(self): """Display the data using the widget which fit the best""" @@ -447,7 +384,7 @@ class DataViewer(qt.QFrame): priority to lowest. :rtype: DataView """ - hdf5View = self.getViewFromModeId(DataViewer.HDF5_MODE) + hdf5View = self.getViewFromModeId(DataViews.HDF5_MODE) if hdf5View in available: return hdf5View return self.getViewFromModeId(DataViews.EMPTY_MODE) @@ -487,6 +424,17 @@ class DataViewer(qt.QFrame): """ return self.__currentAvailableViews + def getReachableViews(self): + """Returns the list of reachable views from the registred available + views. + + :rtype: List[DataView] + """ + views = [] + for v in self.availableViews(): + views.extend(v.getReachableViews()) + return views + def availableViews(self): """Returns the list of registered views diff --git a/silx/gui/data/DataViewerFrame.py b/silx/gui/data/DataViewerFrame.py index 4e6d2e8..9bfb95b 100644 --- a/silx/gui/data/DataViewerFrame.py +++ b/silx/gui/data/DataViewerFrame.py @@ -27,7 +27,7 @@ __authors__ = ["V. Valls"] __license__ = "MIT" -__date__ = "24/04/2018" +__date__ = "12/02/2019" from silx.gui import qt from .DataViewer import DataViewer @@ -120,6 +120,9 @@ class DataViewerFrame(qt.QWidget): """ self.__dataViewer.setGlobalHooks(hooks) + def getReachableViews(self): + return self.__dataViewer.getReachableViews() + def availableViews(self): """Returns the list of registered views diff --git a/silx/gui/data/DataViewerSelector.py b/silx/gui/data/DataViewerSelector.py index 35bbe99..a1e9947 100644 --- a/silx/gui/data/DataViewerSelector.py +++ b/silx/gui/data/DataViewerSelector.py @@ -29,7 +29,7 @@ from __future__ import division __authors__ = ["V. Valls"] __license__ = "MIT" -__date__ = "23/01/2018" +__date__ = "12/02/2019" import weakref import functools @@ -85,7 +85,7 @@ class DataViewerSelector(qt.QWidget): iconSize = qt.QSize(16, 16) - for view in self.__dataViewer.availableViews(): + for view in self.__dataViewer.getReachableViews(): label = view.label() icon = view.icon() button = qt.QPushButton(label) @@ -155,7 +155,7 @@ class DataViewerSelector(qt.QWidget): self.__dataViewer.setDisplayedView(view) def __checkAvailableButtons(self): - views = set(self.__dataViewer.availableViews()) + views = set(self.__dataViewer.getReachableViews()) if views == set(self.__buttons.keys()): return # Recreate all the buttons diff --git a/silx/gui/data/DataViews.py b/silx/gui/data/DataViews.py index 2291e87..6575d0d 100644 --- a/silx/gui/data/DataViews.py +++ b/silx/gui/data/DataViews.py @@ -1,7 +1,7 @@ # coding: utf-8 # /*########################################################################## # -# Copyright (c) 2016-2018 European Synchrotron Radiation Facility +# Copyright (c) 2016-2019 European Synchrotron Radiation Facility # # Permission is hereby granted, free of charge, to any person obtaining a copy # of this software and associated documentation files (the "Software"), to deal @@ -31,6 +31,7 @@ import numbers import numpy import silx.io +from silx.utils import deprecation from silx.gui import qt, icons from silx.gui.data.TextFormatter import TextFormatter from silx.io import nxdata @@ -41,7 +42,7 @@ from silx.gui.dialog.ColormapDialog import ColormapDialog __authors__ = ["V. Valls", "P. Knobel"] __license__ = "MIT" -__date__ = "23/05/2018" +__date__ = "19/02/2019" _logger = logging.getLogger(__name__) @@ -67,6 +68,8 @@ NXDATA_CURVE_MODE = 73 NXDATA_XYVSCATTER_MODE = 74 NXDATA_IMAGE_MODE = 75 NXDATA_STACK_MODE = 76 +NXDATA_VOLUME_MODE = 77 +NXDATA_VOLUME_AS_STACK_MODE = 78 def _normalizeData(data): @@ -100,6 +103,7 @@ class DataInfo(object): """Store extracted information from a data""" def __init__(self, data): + self.__priorities = {} data = self.normalizeData(data) self.isArray = False self.interpretation = None @@ -131,9 +135,6 @@ class DataInfo(object): elif nx_class == "NXdata": # group claiming to be NXdata could not be parsed self.isInvalidNXdata = True - elif nx_class == "NXentry" and "default" in data.attrs: - # entry claiming to have a default NXdata could not be parsed - self.isInvalidNXdata = True elif nx_class == "NXroot" or silx.io.is_file(data): # root claiming to have a default entry if "default" in data.attrs: @@ -141,6 +142,9 @@ class DataInfo(object): if def_entry in data and "default" in data[def_entry].attrs: # and entry claims to have default NXdata self.isInvalidNXdata = True + elif "default" in data.attrs: + # group claiming to have a default NXdata could not be parsed + self.isInvalidNXdata = True if isinstance(data, numpy.ndarray): self.isArray = True @@ -201,6 +205,12 @@ class DataInfo(object): Else returns the data.""" return _normalizeData(data) + def cachePriority(self, view, priority): + self.__priorities[view] = priority + + def getPriority(self, view): + return self.__priorities[view] + class DataViewHooks(object): """A set of hooks defined to custom the behaviour of the data views.""" @@ -357,6 +367,35 @@ class DataView(object): """ return [] + def getReachableViews(self): + """Returns the views that can be returned by `getMatchingViews`. + + :param object data: Any object to be displayed + :param DataInfo info: Information cached about this data + :rtype: List[DataView] + """ + return [self] + + def getMatchingViews(self, data, info): + """Returns the views according to data and info from the data. + + :param object data: Any object to be displayed + :param DataInfo info: Information cached about this data + :rtype: List[DataView] + """ + priority = self.getCachedDataPriority(data, info) + if priority == DataView.UNSUPPORTED: + return [] + return [self] + + def getCachedDataPriority(self, data, info): + try: + priority = info.getPriority(self) + except KeyError: + priority = self.getDataPriority(data, info) + info.cachePriority(self, priority) + return priority + def getDataPriority(self, data, info): """ Returns the priority of using this view according to a data. @@ -377,7 +416,53 @@ class DataView(object): return str(self) < str(other) -class CompositeDataView(DataView): +class _CompositeDataView(DataView): + """Contains sub views""" + + def getViews(self): + """Returns the direct sub views registered in this view. + + :rtype: List[DataView] + """ + raise NotImplementedError() + + def getReachableViews(self): + """Returns all views that can be reachable at on point. + + This method return any sub view provided (recursivly). + + :rtype: List[DataView] + """ + raise NotImplementedError() + + def getMatchingViews(self, data, info): + """Returns sub views matching this data and info. + + This method return any sub view provided (recursivly). + + :param object data: Any object to be displayed + :param DataInfo info: Information cached about this data + :rtype: List[DataView] + """ + raise NotImplementedError() + + @deprecation.deprecated(replacement="getReachableViews", since_version="0.10") + def availableViews(self): + return self.getViews() + + def isSupportedData(self, data, info): + """If true, the composite view allow sub views to access to this data. + Else this this data is considered as not supported by any of sub views + (incliding this composite view). + + :param object data: Any object to be displayed + :param DataInfo info: Information cached about this data + :rtype: bool + """ + return True + + +class SelectOneDataView(_CompositeDataView): """Data view which can display a data using different view according to the kind of the data.""" @@ -386,7 +471,7 @@ class CompositeDataView(DataView): :param qt.QWidget parent: Parent of the hold widget """ - super(CompositeDataView, self).__init__(parent, modeId, icon, label) + super(SelectOneDataView, self).__init__(parent, modeId, icon, label) self.__views = OrderedDict() self.__currentView = None @@ -395,7 +480,7 @@ class CompositeDataView(DataView): :param DataViewHooks hooks: The data view hooks to use """ - super(CompositeDataView, self).setHooks(hooks) + super(SelectOneDataView, self).setHooks(hooks) if hooks is not None: for v in self.__views: v.setHooks(hooks) @@ -407,16 +492,40 @@ class CompositeDataView(DataView): dataView.setHooks(hooks) self.__views[dataView] = None - def availableViews(self): + def getReachableViews(self): + views = [] + addSelf = False + for v in self.__views: + if isinstance(v, SelectManyDataView): + views.extend(v.getReachableViews()) + else: + addSelf = True + if addSelf: + # Single views are hidden by this view + views.insert(0, self) + return views + + def getMatchingViews(self, data, info): + if not self.isSupportedData(data, info): + return [] + view = self.__getBestView(data, info) + if isinstance(view, SelectManyDataView): + return view.getMatchingViews(data, info) + else: + return [self] + + def getViews(self): """Returns the list of registered views :rtype: List[DataView] """ return list(self.__views.keys()) - def getBestView(self, data, info): + def __getBestView(self, data, info): """Returns the best view according to priorities.""" - views = [(v.getDataPriority(data, info), v) for v in self.__views.keys()] + if not self.isSupportedData(data, info): + return None + views = [(v.getCachedDataPriority(data, info), v) for v in self.__views.keys()] views = filter(lambda t: t[0] > DataView.UNSUPPORTED, views) views = sorted(views, key=lambda t: t[0], reverse=True) @@ -471,17 +580,17 @@ class CompositeDataView(DataView): self.__currentView.setData(data) def axesNames(self, data, info): - view = self.getBestView(data, info) + view = self.__getBestView(data, info) self.__currentView = view return view.axesNames(data, info) def getDataPriority(self, data, info): - view = self.getBestView(data, info) + view = self.__getBestView(data, info) self.__currentView = view if view is None: return DataView.UNSUPPORTED else: - return view.getDataPriority(data, info) + return view.getCachedDataPriority(data, info) def replaceView(self, modeId, newView): """Replace a data view with a custom view. @@ -502,7 +611,7 @@ class CompositeDataView(DataView): if view.modeId() == modeId: oldView = view break - elif isinstance(view, CompositeDataView): + elif isinstance(view, _CompositeDataView): # recurse hooks = self.getHooks() if hooks is not None: @@ -519,6 +628,135 @@ class CompositeDataView(DataView): return True +# NOTE: SelectOneDataView was introduced with silx 0.10 +CompositeDataView = SelectOneDataView + + +class SelectManyDataView(_CompositeDataView): + """Data view which can select a set of sub views according to + the kind of the data. + + This view itself is abstract and is not exposed. + """ + + def __init__(self, parent, views=None): + """Constructor + + :param qt.QWidget parent: Parent of the hold widget + """ + super(SelectManyDataView, self).__init__(parent, modeId=None, icon=None, label=None) + if views is None: + views = [] + self.__views = views + + def setHooks(self, hooks): + """Set the data context to use with this view. + + :param DataViewHooks hooks: The data view hooks to use + """ + super(SelectManyDataView, self).setHooks(hooks) + if hooks is not None: + for v in self.__views: + v.setHooks(hooks) + + def addView(self, dataView): + """Add a new dataview to the available list.""" + hooks = self.getHooks() + if hooks is not None: + dataView.setHooks(hooks) + self.__views.append(dataView) + + def getViews(self): + """Returns the list of registered views + + :rtype: List[DataView] + """ + return list(self.__views) + + def getReachableViews(self): + views = [] + for v in self.__views: + views.extend(v.getReachableViews()) + return views + + def getMatchingViews(self, data, info): + """Returns the views according to data and info from the data. + + :param object data: Any object to be displayed + :param DataInfo info: Information cached about this data + """ + if not self.isSupportedData(data, info): + return [] + views = [v for v in self.__views if v.getCachedDataPriority(data, info) != DataView.UNSUPPORTED] + return views + + def customAxisNames(self): + raise RuntimeError("Abstract view") + + def setCustomAxisValue(self, name, value): + raise RuntimeError("Abstract view") + + def select(self): + raise RuntimeError("Abstract view") + + def createWidget(self, parent): + raise RuntimeError("Abstract view") + + def clear(self): + for v in self.__views: + v.clear() + + def setData(self, data): + raise RuntimeError("Abstract view") + + def axesNames(self, data, info): + raise RuntimeError("Abstract view") + + def getDataPriority(self, data, info): + if not self.isSupportedData(data, info): + return DataView.UNSUPPORTED + priorities = [v.getCachedDataPriority(data, info) for v in self.__views] + priorities = [v for v in priorities if v != DataView.UNSUPPORTED] + priorities = sorted(priorities) + if len(priorities) == 0: + return DataView.UNSUPPORTED + return priorities[-1] + + def replaceView(self, modeId, newView): + """Replace a data view with a custom view. + Return True in case of success, False in case of failure. + + .. note:: + + This method must be called just after instantiation, before + the viewer is used. + + :param int modeId: Unique mode ID identifying the DataView to + be replaced. + :param DataViews.DataView newView: New data view + :return: True if replacement was successful, else False + """ + oldView = None + for iview, view in enumerate(self.__views): + if view.modeId() == modeId: + oldView = view + break + elif isinstance(view, CompositeDataView): + # recurse + hooks = self.getHooks() + if hooks is not None: + newView.setHooks(hooks) + if view.replaceView(modeId, newView): + return True + + if oldView is None: + return False + + # replace oldView with new view in dict + self.__views[iview] = newView + return True + + class _EmptyView(DataView): """Dummy view to display nothing""" @@ -1096,17 +1334,6 @@ class _InvalidNXdataView(DataView): # invalid: could not even be parsed by NXdata self._msg = "Group has @NX_class = NXdata, but could not be interpreted" self._msg += " as valid NXdata." - elif nx_class == "NXentry": - self._msg = "NXentry group provides a @default attribute," - default_nxdata_name = data.attrs["default"] - if default_nxdata_name not in data: - self._msg += " but no corresponding NXdata group exists." - elif get_attr_as_unicode(data[default_nxdata_name], "NX_class") != "NXdata": - self._msg += " but the corresponding item is not a " - self._msg += "NXdata group." - else: - self._msg += " but the corresponding NXdata seems to be" - self._msg += " malformed." elif nx_class == "NXroot" or silx.io.is_file(data): default_entry = data[data.attrs["default"]] default_nxdata_name = default_entry.attrs["default"] @@ -1122,6 +1349,17 @@ class _InvalidNXdataView(DataView): else: self._msg += " but the corresponding NXdata seems to be" self._msg += " malformed." + else: + self._msg = "Group provides a @default attribute," + default_nxdata_name = data.attrs["default"] + if default_nxdata_name not in data: + self._msg += " but no corresponding NXdata group exists." + elif get_attr_as_unicode(data[default_nxdata_name], "NX_class") != "NXdata": + self._msg += " but the corresponding item is not a " + self._msg += "NXdata group." + else: + self._msg += " but the corresponding NXdata seems to be" + self._msg += " malformed." return 100 @@ -1277,7 +1515,7 @@ class _NXdataXYVScatterView(DataView): class _NXdataImageView(DataView): """DataView using a Plot2D for displaying NXdata images: - 2-D signal or n-D signals with *@interpretation=spectrum*.""" + 2-D signal or n-D signals with *@interpretation=image*.""" def __init__(self, parent): DataView.__init__(self, parent, modeId=NXDATA_IMAGE_MODE) @@ -1323,6 +1561,53 @@ class _NXdataImageView(DataView): return DataView.UNSUPPORTED +class _NXdataComplexImageView(DataView): + """DataView using a ComplexImageView for displaying NXdata complex images: + 2-D signal or n-D signals with *@interpretation=image*.""" + def __init__(self, parent): + DataView.__init__(self, parent, + modeId=NXDATA_IMAGE_MODE) + + def createWidget(self, parent): + from silx.gui.data.NXdataWidgets import ArrayComplexImagePlot + widget = ArrayComplexImagePlot(parent, colormap=self.defaultColormap()) + widget.getPlot().getColormapAction().setColorDialog(self.defaultColorDialog()) + return widget + + def clear(self): + self.getWidget().clear() + + def setData(self, data): + data = self.normalizeData(data) + nxd = nxdata.get_default(data, validate=False) + + # last two axes are Y & X + img_slicing = slice(-2, None) + y_axis, x_axis = nxd.axes[img_slicing] + y_label, x_label = nxd.axes_names[img_slicing] + + self.getWidget().setImageData( + [nxd.signal] + nxd.auxiliary_signals, + x_axis=x_axis, y_axis=y_axis, + signals_names=[nxd.signal_name] + nxd.auxiliary_signals_names, + xlabel=x_label, ylabel=y_label, + title=nxd.title) + + def axesNames(self, data, info): + # disabled (used by default axis selector widget in Hdf5Viewer) + return None + + def getDataPriority(self, data, info): + data = self.normalizeData(data) + + if info.hasNXdata and not info.isInvalidNXdata: + nxd = nxdata.get_default(data, validate=False) + if nxd.is_image and numpy.iscomplexobj(nxd.signal): + return 100 + + return DataView.UNSUPPORTED + + class _NXdataStackView(DataView): def __init__(self, parent): DataView.__init__(self, parent, @@ -1368,6 +1653,154 @@ class _NXdataStackView(DataView): return DataView.UNSUPPORTED +class _NXdataVolumeView(DataView): + def __init__(self, parent): + DataView.__init__(self, parent, + label="NXdata (3D)", + icon=icons.getQIcon("view-nexus"), + modeId=NXDATA_VOLUME_MODE) + try: + import silx.gui.plot3d # noqa + except ImportError: + _logger.warning("Plot3dView is not available") + _logger.debug("Backtrace", exc_info=True) + raise + + def normalizeData(self, data): + data = DataView.normalizeData(self, data) + data = _normalizeComplex(data) + return data + + def createWidget(self, parent): + from silx.gui.data.NXdataWidgets import ArrayVolumePlot + widget = ArrayVolumePlot(parent) + return widget + + def axesNames(self, data, info): + # disabled (used by default axis selector widget in Hdf5Viewer) + return None + + def clear(self): + self.getWidget().clear() + + def setData(self, data): + data = self.normalizeData(data) + nxd = nxdata.get_default(data, validate=False) + signal_name = nxd.signal_name + z_axis, y_axis, x_axis = nxd.axes[-3:] + z_label, y_label, x_label = nxd.axes_names[-3:] + title = nxd.title or signal_name + + widget = self.getWidget() + widget.setData( + nxd.signal, x_axis=x_axis, y_axis=y_axis, z_axis=z_axis, + signal_name=signal_name, + xlabel=x_label, ylabel=y_label, zlabel=z_label, + title=title) + + def getDataPriority(self, data, info): + data = self.normalizeData(data) + if info.hasNXdata and not info.isInvalidNXdata: + if nxdata.get_default(data, validate=False).is_volume: + return 150 + + return DataView.UNSUPPORTED + + +class _NXdataVolumeAsStackView(DataView): + def __init__(self, parent): + DataView.__init__(self, parent, + label="NXdata (2D)", + icon=icons.getQIcon("view-nexus"), + modeId=NXDATA_VOLUME_AS_STACK_MODE) + + def createWidget(self, parent): + from silx.gui.data.NXdataWidgets import ArrayStackPlot + widget = ArrayStackPlot(parent) + widget.getStackView().setColormap(self.defaultColormap()) + widget.getStackView().getPlot().getColormapAction().setColorDialog(self.defaultColorDialog()) + return widget + + def axesNames(self, data, info): + # disabled (used by default axis selector widget in Hdf5Viewer) + return None + + def clear(self): + self.getWidget().clear() + + def setData(self, data): + data = self.normalizeData(data) + nxd = nxdata.get_default(data, validate=False) + signal_name = nxd.signal_name + z_axis, y_axis, x_axis = nxd.axes[-3:] + z_label, y_label, x_label = nxd.axes_names[-3:] + title = nxd.title or signal_name + + widget = self.getWidget() + widget.setStackData( + nxd.signal, x_axis=x_axis, y_axis=y_axis, z_axis=z_axis, + signal_name=signal_name, + xlabel=x_label, ylabel=y_label, zlabel=z_label, + title=title) + # Override the colormap, while setStack overwrite it + widget.getStackView().setColormap(self.defaultColormap()) + + def getDataPriority(self, data, info): + data = self.normalizeData(data) + if info.isComplex: + return DataView.UNSUPPORTED + if info.hasNXdata and not info.isInvalidNXdata: + if nxdata.get_default(data, validate=False).is_volume: + return 200 + + return DataView.UNSUPPORTED + +class _NXdataComplexVolumeAsStackView(DataView): + def __init__(self, parent): + DataView.__init__(self, parent, + label="NXdata (2D)", + icon=icons.getQIcon("view-nexus"), + modeId=NXDATA_VOLUME_AS_STACK_MODE) + self._is_complex_data = False + + def createWidget(self, parent): + from silx.gui.data.NXdataWidgets import ArrayComplexImagePlot + widget = ArrayComplexImagePlot(parent, colormap=self.defaultColormap()) + widget.getPlot().getColormapAction().setColorDialog(self.defaultColorDialog()) + return widget + + def axesNames(self, data, info): + # disabled (used by default axis selector widget in Hdf5Viewer) + return None + + def clear(self): + self.getWidget().clear() + + def setData(self, data): + data = self.normalizeData(data) + nxd = nxdata.get_default(data, validate=False) + signal_name = nxd.signal_name + z_axis, y_axis, x_axis = nxd.axes[-3:] + z_label, y_label, x_label = nxd.axes_names[-3:] + title = nxd.title or signal_name + + self.getWidget().setImageData( + [nxd.signal] + nxd.auxiliary_signals, + x_axis=x_axis, y_axis=y_axis, + signals_names=[nxd.signal_name] + nxd.auxiliary_signals_names, + xlabel=x_label, ylabel=y_label, title=nxd.title) + + def getDataPriority(self, data, info): + data = self.normalizeData(data) + if not info.isComplex: + return DataView.UNSUPPORTED + if info.hasNXdata and not info.isInvalidNXdata: + if nxdata.get_default(data, validate=False).is_volume: + return 200 + + return DataView.UNSUPPORTED + + class _NXdataView(CompositeDataView): """Composite view displaying NXdata groups using the most adequate widget depending on the dimensionality.""" @@ -1382,5 +1815,17 @@ class _NXdataView(CompositeDataView): self.addView(_NXdataScalarView(parent)) self.addView(_NXdataCurveView(parent)) self.addView(_NXdataXYVScatterView(parent)) + self.addView(_NXdataComplexImageView(parent)) self.addView(_NXdataImageView(parent)) self.addView(_NXdataStackView(parent)) + + # The 3D view can be displayed using 2 ways + nx3dViews = SelectManyDataView(parent) + nx3dViews.addView(_NXdataVolumeAsStackView(parent)) + nx3dViews.addView(_NXdataComplexVolumeAsStackView(parent)) + try: + nx3dViews.addView(_NXdataVolumeView(parent)) + except Exception: + _logger.warning("NXdataVolumeView is not available") + _logger.debug("Backtrace", exc_info=True) + self.addView(nx3dViews) diff --git a/silx/gui/data/Hdf5TableView.py b/silx/gui/data/Hdf5TableView.py index 9e28fbf..d7c33f3 100644 --- a/silx/gui/data/Hdf5TableView.py +++ b/silx/gui/data/Hdf5TableView.py @@ -30,12 +30,14 @@ from __future__ import division __authors__ = ["V. Valls"] __license__ = "MIT" -__date__ = "05/07/2018" +__date__ = "12/02/2019" import collections import functools import os.path import logging +import h5py + from silx.gui import qt import silx.io from .TextFormatter import TextFormatter @@ -44,11 +46,6 @@ from silx.gui.widgets import HierarchicalTableView from ..hdf5.Hdf5Formatter import Hdf5Formatter from ..hdf5._utils import htmlFromDict -try: - import h5py -except ImportError: - h5py = None - _logger = logging.getLogger(__name__) @@ -198,11 +195,9 @@ class _CellFilterAvailableData(_CellData): } def __init__(self, filterId): - import h5py.version if h5py.version.hdf5_version_tuple >= (1, 10, 2): # Previous versions only returns True if the filter was first used # to decode a dataset - import h5py.h5z self.__availability = h5py.h5z.filter_avail(filterId) else: self.__availability = "na" @@ -416,7 +411,7 @@ class Hdf5TableModel(HierarchicalTableView.HierarchicalTableModel): self.__data.addHeaderRow(headerLabel="Data info") - if h5py is not None and hasattr(obj, "id") and hasattr(obj.id, "get_type"): + if hasattr(obj, "id") and hasattr(obj.id, "get_type"): # display the HDF5 type self.__data.addHeaderValueRow("HDF5 type", self.__formatHdf5Type) self.__data.addHeaderValueRow("dtype", self.__formatDType) diff --git a/silx/gui/data/HexaTableView.py b/silx/gui/data/HexaTableView.py index c86c0af..1617f0a 100644 --- a/silx/gui/data/HexaTableView.py +++ b/silx/gui/data/HexaTableView.py @@ -1,7 +1,7 @@ # coding: utf-8 # /*########################################################################## # -# Copyright (c) 2017 European Synchrotron Radiation Facility +# Copyright (c) 2017-2018 European Synchrotron Radiation Facility # # Permission is hereby granted, free of charge, to any person obtaining a copy # of this software and associated documentation files (the "Software"), to deal @@ -28,11 +28,13 @@ hexadecimal viewer. """ from __future__ import division -import numpy import collections + +import numpy +import six + from silx.gui import qt import silx.io.utils -from silx.third_party import six from silx.gui.widgets.TableWidget import CopySelectedCellsAction __authors__ = ["V. Valls"] diff --git a/silx/gui/data/NXdataWidgets.py b/silx/gui/data/NXdataWidgets.py index f7c479d..e5a2550 100644 --- a/silx/gui/data/NXdataWidgets.py +++ b/silx/gui/data/NXdataWidgets.py @@ -1,7 +1,7 @@ # coding: utf-8 # /*########################################################################## # -# Copyright (c) 2017-2018 European Synchrotron Radiation Facility +# Copyright (c) 2017-2019 European Synchrotron Radiation Facility # # Permission is hereby granted, free of charge, to any person obtaining a copy # of this software and associated documentation files (the "Software"), to deal @@ -26,19 +26,25 @@ """ __authors__ = ["P. Knobel"] __license__ = "MIT" -__date__ = "10/10/2018" +__date__ = "12/11/2018" +import logging +import numbers import numpy from silx.gui import qt from silx.gui.data.NumpyAxesSelector import NumpyAxesSelector from silx.gui.plot import Plot1D, Plot2D, StackView, ScatterView +from silx.gui.plot.ComplexImageView import ComplexImageView from silx.gui.colors import Colormap from silx.gui.widgets.FrameBrowser import HorizontalSliderWithBrowser from silx.math.calibration import ArrayCalibration, NoCalibration, LinearCalibration +_logger = logging.getLogger(__name__) + + class ArrayCurvePlot(qt.QWidget): """ Widget for plotting a curve from a multi-dimensional signal array @@ -72,21 +78,16 @@ class ArrayCurvePlot(qt.QWidget): self._plot = Plot1D(self) - self.selectorDock = qt.QDockWidget("Data selector", self._plot) - # not closable - self.selectorDock.setFeatures(qt.QDockWidget.DockWidgetMovable | - qt.QDockWidget.DockWidgetFloatable) - self._selector = NumpyAxesSelector(self.selectorDock) + self._selector = NumpyAxesSelector(self) self._selector.setNamedAxesSelectorVisibility(False) self.__selector_is_connected = False - self.selectorDock.setWidget(self._selector) - self._plot.addTabbedDockWidget(self.selectorDock) self._plot.sigActiveCurveChanged.connect(self._setYLabelFromActiveLegend) - layout = qt.QGridLayout() + layout = qt.QVBoxLayout() layout.setContentsMargins(0, 0, 0, 0) - layout.addWidget(self._plot, 0, 0) + layout.addWidget(self._plot) + layout.addWidget(self._selector) self.setLayout(layout) @@ -130,9 +131,9 @@ class ArrayCurvePlot(qt.QWidget): self._selector.setAxisNames(["Y"]) if len(ys[0].shape) < 2: - self.selectorDock.hide() + self._selector.hide() else: - self.selectorDock.show() + self._selector.show() self._plot.setGraphTitle(title or "") self._updateCurve() @@ -182,6 +183,9 @@ class ArrayCurvePlot(qt.QWidget): break def clear(self): + old = self._selector.blockSignals(True) + self._selector.clear() + self._selector.blockSignals(old) self._plot.clear() @@ -339,11 +343,8 @@ class ArrayImagePlot(qt.QWidget): normalization=Colormap.LINEAR)) self._plot.getIntensityHistogramAction().setVisible(True) - self.selectorDock = qt.QDockWidget("Data selector", self._plot) # not closable - self.selectorDock.setFeatures(qt.QDockWidget.DockWidgetMovable | - qt.QDockWidget.DockWidgetFloatable) - self._selector = NumpyAxesSelector(self.selectorDock) + self._selector = NumpyAxesSelector(self) self._selector.setNamedAxesSelectorVisibility(False) self._selector.selectionChanged.connect(self._updateImage) @@ -355,9 +356,8 @@ class ArrayImagePlot(qt.QWidget): layout = qt.QVBoxLayout() layout.addWidget(self._plot) + layout.addWidget(self._selector) layout.addWidget(self._auxSigSlider) - self.selectorDock.setWidget(self._selector) - self._plot.addTabbedDockWidget(self.selectorDock) self.setLayout(layout) @@ -413,9 +413,9 @@ class ArrayImagePlot(qt.QWidget): self._selector.setData(signals[0]) if len(signals[0].shape) <= img_ndim: - self.selectorDock.hide() + self._selector.hide() else: - self.selectorDock.show() + self._selector.show() self._auxSigSlider.setMaximum(len(signals) - 1) if len(signals) > 1: @@ -425,6 +425,7 @@ class ArrayImagePlot(qt.QWidget): self._auxSigSlider.setValue(0) self._updateImage() + self._plot.resetZoom() self._selector.selectionChanged.connect(self._updateImage) self._auxSigSlider.valueChanged.connect(self._sliderIdxChanged) @@ -492,12 +493,202 @@ class ArrayImagePlot(qt.QWidget): self._plot.setGraphTitle(title) self._plot.getXAxis().setLabel(self.__x_axis_name) self._plot.getYAxis().setLabel(self.__y_axis_name) - self._plot.resetZoom() def clear(self): + old = self._selector.blockSignals(True) + self._selector.clear() + self._selector.blockSignals(old) self._plot.clear() +class ArrayComplexImagePlot(qt.QWidget): + """ + Widget for plotting an image of complex from a multi-dimensional signal array + and two 1D axes array. + + The signal array can have an arbitrary number of dimensions, the only + limitation being that the last two dimensions must have the same length as + the axes arrays. + + Sliders are provided to select indices on the first (n - 2) dimensions of + the signal array, and the plot is updated to show the image corresponding + to the selection. + + If one or both of the axes does not have regularly spaced values, the + the image is plotted as a coloured scatter plot. + """ + def __init__(self, parent=None, colormap=None): + """ + + :param parent: Parent QWidget + """ + super(ArrayComplexImagePlot, self).__init__(parent) + + self.__signals = None + self.__signals_names = None + self.__x_axis = None + self.__x_axis_name = None + self.__y_axis = None + self.__y_axis_name = None + + self._plot = ComplexImageView(self) + if colormap is not None: + for mode in (ComplexImageView.Mode.ABSOLUTE, + ComplexImageView.Mode.SQUARE_AMPLITUDE, + ComplexImageView.Mode.REAL, + ComplexImageView.Mode.IMAGINARY): + self._plot.setColormap(colormap, mode) + + self._plot.getPlot().getIntensityHistogramAction().setVisible(True) + self._plot.setKeepDataAspectRatio(True) + + # not closable + self._selector = NumpyAxesSelector(self) + self._selector.setNamedAxesSelectorVisibility(False) + self._selector.selectionChanged.connect(self._updateImage) + + self._auxSigSlider = HorizontalSliderWithBrowser(parent=self) + self._auxSigSlider.setMinimum(0) + self._auxSigSlider.setValue(0) + self._auxSigSlider.valueChanged[int].connect(self._sliderIdxChanged) + self._auxSigSlider.setToolTip("Select auxiliary signals") + + layout = qt.QVBoxLayout() + layout.addWidget(self._plot) + layout.addWidget(self._selector) + layout.addWidget(self._auxSigSlider) + + self.setLayout(layout) + + def _sliderIdxChanged(self, value): + self._updateImage() + + def getPlot(self): + """Returns the plot used for the display + + :rtype: PlotWidget + """ + return self._plot.getPlot() + + def setImageData(self, signals, + x_axis=None, y_axis=None, + signals_names=None, + xlabel=None, ylabel=None, + title=None): + """ + + :param signals: list of n-D datasets, whose last 2 dimensions are used as the + image's values, or list of 3D datasets interpreted as RGBA image. + :param x_axis: 1-D dataset used as the image's x coordinates. If + provided, its lengths must be equal to the length of the last + dimension of ``signal``. + :param y_axis: 1-D dataset used as the image's y. If provided, + its lengths must be equal to the length of the 2nd to last + dimension of ``signal``. + :param signals_names: Names for each image, used as subtitle and legend. + :param xlabel: Label for X axis + :param ylabel: Label for Y axis + :param title: Graph title + """ + self._selector.selectionChanged.disconnect(self._updateImage) + self._auxSigSlider.valueChanged.disconnect(self._sliderIdxChanged) + + self.__signals = signals + self.__signals_names = signals_names + self.__x_axis = x_axis + self.__x_axis_name = xlabel + self.__y_axis = y_axis + self.__y_axis_name = ylabel + self.__title = title + + self._selector.clear() + self._selector.setAxisNames(["Y", "X"]) + self._selector.setData(signals[0]) + + if len(signals[0].shape) <= 2: + self._selector.hide() + else: + self._selector.show() + + self._auxSigSlider.setMaximum(len(signals) - 1) + if len(signals) > 1: + self._auxSigSlider.show() + else: + self._auxSigSlider.hide() + self._auxSigSlider.setValue(0) + + self._updateImage() + self._plot.getPlot().resetZoom() + + self._selector.selectionChanged.connect(self._updateImage) + self._auxSigSlider.valueChanged.connect(self._sliderIdxChanged) + + def _updateImage(self): + selection = self._selector.selection() + auxSigIdx = self._auxSigSlider.value() + + images = [img[selection] for img in self.__signals] + image = images[auxSigIdx] + + x_axis = self.__x_axis + y_axis = self.__y_axis + + if x_axis is None and y_axis is None: + xcalib = NoCalibration() + ycalib = NoCalibration() + else: + if x_axis is None: + # no calibration + x_axis = numpy.arange(image.shape[1]) + elif numpy.isscalar(x_axis) or len(x_axis) == 1: + # constant axis + x_axis = x_axis * numpy.ones((image.shape[1], )) + elif len(x_axis) == 2: + # linear calibration + x_axis = x_axis[0] * numpy.arange(image.shape[1]) + x_axis[1] + + if y_axis is None: + y_axis = numpy.arange(image.shape[0]) + elif numpy.isscalar(y_axis) or len(y_axis) == 1: + y_axis = y_axis * numpy.ones((image.shape[0], )) + elif len(y_axis) == 2: + y_axis = y_axis[0] * numpy.arange(image.shape[0]) + y_axis[1] + + xcalib = ArrayCalibration(x_axis) + ycalib = ArrayCalibration(y_axis) + + self._plot.setData(image) + if xcalib.is_affine(): + xorigin, xscale = xcalib(0), xcalib.get_slope() + else: + _logger.warning("Unsupported complex image X axis calibration") + xorigin, xscale = 0., 1. + + if ycalib.is_affine(): + yorigin, yscale = ycalib(0), ycalib.get_slope() + else: + _logger.warning("Unsupported complex image Y axis calibration") + yorigin, yscale = 0., 1. + + self._plot.setOrigin((xorigin, yorigin)) + self._plot.setScale((xscale, yscale)) + + title = "" + if self.__title: + title += self.__title + if not title.strip().endswith(self.__signals_names[auxSigIdx]): + title += "\n" + self.__signals_names[auxSigIdx] + self._plot.setGraphTitle(title) + self._plot.getXAxis().setLabel(self.__x_axis_name) + self._plot.getYAxis().setLabel(self.__y_axis_name) + + def clear(self): + old = self._selector.blockSignals(True) + self._selector.clear() + self._selector.blockSignals(old) + self._plot.setData(None) + + class ArrayStackPlot(qt.QWidget): """ Widget for plotting a n-D array (n >= 3) as a stack of images. @@ -665,4 +856,208 @@ class ArrayStackPlot(qt.QWidget): self.__x_axis_name]) def clear(self): + old = self._selector.blockSignals(True) + self._selector.clear() + self._selector.blockSignals(old) self._stack_view.clear() + + +class ArrayVolumePlot(qt.QWidget): + """ + Widget for plotting a n-D array (n >= 3) as a 3D scalar field. + Three axis arrays can be provided to calibrate the axes. + + The signal array can have an arbitrary number of dimensions, the only + limitation being that the last 3 dimensions must have the same length as + the axes arrays. + + Sliders are provided to select indices on the first (n - 3) dimensions of + the signal array, and the plot is updated to load the stack corresponding + to the selection. + """ + def __init__(self, parent=None): + """ + + :param parent: Parent QWidget + """ + super(ArrayVolumePlot, self).__init__(parent) + + self.__signal = None + self.__signal_name = None + # the Z, Y, X axes apply to the last three dimensions of the signal + # (in that order) + self.__z_axis = None + self.__z_axis_name = None + self.__y_axis = None + self.__y_axis_name = None + self.__x_axis = None + self.__x_axis_name = None + + from silx.gui.plot3d.ScalarFieldView import ScalarFieldView + from silx.gui.plot3d import SFViewParamTree + + self._view = ScalarFieldView(self) + + def computeIsolevel(data): + data = data[numpy.isfinite(data)] + if len(data) == 0: + return 0 + else: + return numpy.mean(data) + numpy.std(data) + + self._view.addIsosurface(computeIsolevel, '#FF0000FF') + + # Create a parameter tree for the scalar field view + options = SFViewParamTree.TreeView(self._view) + options.setSfView(self._view) + + # Add the parameter tree to the main window in a dock widget + dock = qt.QDockWidget() + dock.setWidget(options) + self._view.addDockWidget(qt.Qt.RightDockWidgetArea, dock) + + self._hline = qt.QFrame(self) + self._hline.setFrameStyle(qt.QFrame.HLine) + self._hline.setFrameShadow(qt.QFrame.Sunken) + self._legend = qt.QLabel(self) + self._selector = NumpyAxesSelector(self) + self._selector.setNamedAxesSelectorVisibility(False) + self.__selector_is_connected = False + + layout = qt.QVBoxLayout() + layout.addWidget(self._view) + layout.addWidget(self._hline) + layout.addWidget(self._legend) + layout.addWidget(self._selector) + + self.setLayout(layout) + + def getVolumeView(self): + """Returns the plot used for the display + + :rtype: ScalarFieldView + """ + return self._view + + def normalizeComplexData(self, data): + """ + Converts a complex data array to its amplitude, if necessary. + :param data: the data to normalize + :return: + """ + if hasattr(data, "dtype"): + isComplex = numpy.issubdtype(data.dtype, numpy.complexfloating) + else: + isComplex = isinstance(data, numbers.Complex) + if isComplex: + data = numpy.absolute(data) + return data + + def setData(self, signal, + x_axis=None, y_axis=None, z_axis=None, + signal_name=None, + xlabel=None, ylabel=None, zlabel=None, + title=None): + """ + + :param signal: n-D dataset, whose last 3 dimensions are used as the + 3D stack values. + :param x_axis: 1-D dataset used as the image's x coordinates. If + provided, its lengths must be equal to the length of the last + dimension of ``signal``. + :param y_axis: 1-D dataset used as the image's y. If provided, + its lengths must be equal to the length of the 2nd to last + dimension of ``signal``. + :param z_axis: 1-D dataset used as the image's z. If provided, + its lengths must be equal to the length of the 3rd to last + dimension of ``signal``. + :param signal_name: Label used in the legend + :param xlabel: Label for X axis + :param ylabel: Label for Y axis + :param zlabel: Label for Z axis + :param title: Graph title + """ + signal = self.normalizeComplexData(signal) + if self.__selector_is_connected: + self._selector.selectionChanged.disconnect(self._updateVolume) + self.__selector_is_connected = False + + self.__signal = signal + self.__signal_name = signal_name or "" + self.__x_axis = x_axis + self.__x_axis_name = xlabel + self.__y_axis = y_axis + self.__y_axis_name = ylabel + self.__z_axis = z_axis + self.__z_axis_name = zlabel + + self._selector.setData(signal) + self._selector.setAxisNames(["Y", "X", "Z"]) + + self._view.setAxesLabels(self.__x_axis_name or 'X', + self.__y_axis_name or 'Y', + self.__z_axis_name or 'Z') + self._updateVolume() + + # the legend label shows the selection slice producing the volume + # (only interesting for ndim > 3) + if signal.ndim > 3: + self._selector.setVisible(True) + self._legend.setVisible(True) + self._hline.setVisible(True) + else: + self._selector.setVisible(False) + self._legend.setVisible(False) + self._hline.setVisible(False) + + if not self.__selector_is_connected: + self._selector.selectionChanged.connect(self._updateVolume) + self.__selector_is_connected = True + + def _updateVolume(self): + """Update displayed stack according to the current axes selector + data.""" + data = self._selector.selectedData() + x_axis = self.__x_axis + y_axis = self.__y_axis + z_axis = self.__z_axis + + offset = [] + scale = [] + for axis in [x_axis, y_axis, z_axis]: + if axis is None: + calibration = NoCalibration() + elif len(axis) == 2: + calibration = LinearCalibration( + y_intercept=axis[0], slope=axis[1]) + else: + calibration = ArrayCalibration(axis) + if not calibration.is_affine(): + _logger.warning("Axis has not linear values, ignored") + offset.append(0.) + scale.append(1.) + else: + offset.append(calibration(0)) + scale.append(calibration.get_slope()) + + legend = self.__signal_name + "[" + for sl in self._selector.selection(): + if sl == slice(None): + legend += ":, " + else: + legend += str(sl) + ", " + legend = legend[:-2] + "]" + self._legend.setText("Displayed data: " + legend) + + self._view.setData(data, copy=False) + self._view.setScale(*scale) + self._view.setTranslation(*offset) + self._view.setAxesLabels(self.__x_axis_name, + self.__y_axis_name, + self.__z_axis_name) + + def clear(self): + old = self._selector.blockSignals(True) + self._selector.clear() + self._selector.blockSignals(old) + self._view.setData(None) diff --git a/silx/gui/data/TextFormatter.py b/silx/gui/data/TextFormatter.py index 1401634..98c37d7 100644 --- a/silx/gui/data/TextFormatter.py +++ b/silx/gui/data/TextFormatter.py @@ -1,7 +1,7 @@ # coding: utf-8 # /*########################################################################## # -# Copyright (c) 2017 European Synchrotron Radiation Facility +# Copyright (c) 2017-2018 European Synchrotron Radiation Facility # # Permission is hereby granted, free of charge, to any person obtaining a copy # of this software and associated documentation files (the "Software"), to deal @@ -29,16 +29,15 @@ __authors__ = ["V. Valls"] __license__ = "MIT" __date__ = "24/07/2018" -import numpy +import logging import numbers -from silx.third_party import six + +import numpy +import six + from silx.gui import qt -import logging -try: - import h5py -except ImportError: - h5py = None +import h5py _logger = logging.getLogger(__name__) @@ -322,10 +321,9 @@ class TextFormatter(qt.QObject): if dtype.kind == 'S': return self.__formatCharString(data) elif dtype.kind == 'O': - if h5py is not None: - text = self.__formatH5pyObject(data, dtype) - if text is not None: - return text + text = self.__formatH5pyObject(data, dtype) + if text is not None: + return text try: # Try ascii/utf-8 text = "%s" % data.decode("utf-8") @@ -339,15 +337,14 @@ class TextFormatter(qt.QObject): elif isinstance(data, (numpy.integer)): if dtype is None: dtype = data.dtype - if h5py is not None: - enumType = h5py.check_dtype(enum=dtype) - if enumType is not None: - for key, value in enumType.items(): - if value == data: - result = {} - result["name"] = key - result["value"] = data - return self.__enumFormat % result + enumType = h5py.check_dtype(enum=dtype) + if enumType is not None: + for key, value in enumType.items(): + if value == data: + result = {} + result["name"] = key + result["value"] = data + return self.__enumFormat % result return self.__integerFormat % data elif isinstance(data, (numbers.Integral)): return self.__integerFormat % data @@ -373,21 +370,20 @@ class TextFormatter(qt.QObject): template = self.__floatFormat params = (data.real) return template % params - elif h5py is not None and isinstance(data, h5py.h5r.Reference): + elif isinstance(data, h5py.h5r.Reference): dtype = h5py.special_dtype(ref=h5py.Reference) text = self.__formatH5pyObject(data, dtype) return text - elif h5py is not None and isinstance(data, h5py.h5r.RegionReference): + elif isinstance(data, h5py.h5r.RegionReference): dtype = h5py.special_dtype(ref=h5py.RegionReference) text = self.__formatH5pyObject(data, dtype) return text elif isinstance(data, numpy.object_) or dtype is not None: if dtype is None: dtype = data.dtype - if h5py is not None: - text = self.__formatH5pyObject(data, dtype) - if text is not None: - return text + text = self.__formatH5pyObject(data, dtype) + if text is not None: + return text # That's a numpy object return str(data) return str(data) diff --git a/silx/gui/data/test/test_arraywidget.py b/silx/gui/data/test/test_arraywidget.py index 50ffc84..6bcbbd3 100644 --- a/silx/gui/data/test/test_arraywidget.py +++ b/silx/gui/data/test/test_arraywidget.py @@ -36,10 +36,7 @@ from silx.gui import qt from silx.gui.data import ArrayTableWidget from silx.gui.utils.testutils import TestCaseQt -try: - import h5py -except ImportError: - h5py = None +import h5py class TestArrayWidget(TestCaseQt): @@ -190,7 +187,6 @@ class TestArrayWidget(TestCaseQt): self.assertIs(b0, b1) -@unittest.skipIf(h5py is None, "Could not import h5py") class TestH5pyArrayWidget(TestCaseQt): """Basic test for ArrayTableWidget with a dataset. diff --git a/silx/gui/data/test/test_dataviewer.py b/silx/gui/data/test/test_dataviewer.py index a681f33..dc6fee8 100644 --- a/silx/gui/data/test/test_dataviewer.py +++ b/silx/gui/data/test/test_dataviewer.py @@ -24,7 +24,7 @@ # ###########################################################################*/ __authors__ = ["V. Valls"] __license__ = "MIT" -__date__ = "23/04/2018" +__date__ = "19/02/2019" import os import tempfile @@ -42,10 +42,7 @@ from silx.gui.data.DataViewerFrame import DataViewerFrame from silx.gui.utils.testutils import SignalListener from silx.gui.utils.testutils import TestCaseQt -try: - import h5py -except ImportError: - h5py = None +import h5py class _DataViewMock(DataView): @@ -170,8 +167,6 @@ class AbstractDataViewerTests(TestCaseQt): self.assertEqual(DataViews.RAW_MODE, widget.displayedView().modeId()) def test_3d_h5_dataset(self): - if h5py is None: - self.skipTest("h5py library is not available") with self.h5_temporary_file() as h5file: dataset = h5file["data"] widget = self.create_widget() @@ -242,8 +237,9 @@ class AbstractDataViewerTests(TestCaseQt): # replace a view that is a child of a composite view widget = self.create_widget() view = _DataViewMock(widget) - widget.replaceView(DataViews.NXDATA_INVALID_MODE, - view) + replaced = widget.replaceView(DataViews.NXDATA_INVALID_MODE, + view) + self.assertTrue(replaced) nxdata_view = widget.getViewFromModeId(DataViews.NXDATA_MODE) self.assertNotIn(DataViews.NXDATA_INVALID_MODE, [v.modeId() for v in nxdata_view.availableViews()]) diff --git a/silx/gui/data/test/test_numpyaxesselector.py b/silx/gui/data/test/test_numpyaxesselector.py index 6b7b58c..df11c1a 100644 --- a/silx/gui/data/test/test_numpyaxesselector.py +++ b/silx/gui/data/test/test_numpyaxesselector.py @@ -37,10 +37,7 @@ from silx.gui.data.NumpyAxesSelector import NumpyAxesSelector from silx.gui.utils.testutils import SignalListener from silx.gui.utils.testutils import TestCaseQt -try: - import h5py -except ImportError: - h5py = None +import h5py class TestNumpyAxesSelector(TestCaseQt): @@ -121,8 +118,6 @@ class TestNumpyAxesSelector(TestCaseQt): os.unlink(tmp_name) def test_h5py_dataset(self): - if h5py is None: - self.skipTest("h5py library is not available") with self.h5_temporary_file() as h5file: dataset = h5file["data"] expectedResult = dataset[0] diff --git a/silx/gui/data/test/test_textformatter.py b/silx/gui/data/test/test_textformatter.py index 850aa00..935344a 100644 --- a/silx/gui/data/test/test_textformatter.py +++ b/silx/gui/data/test/test_textformatter.py @@ -1,7 +1,7 @@ # coding: utf-8 # /*########################################################################## # -# Copyright (c) 2016-2017 European Synchrotron Radiation Facility +# Copyright (c) 2016-2018 European Synchrotron Radiation Facility # # Permission is hereby granted, free of charge, to any person obtaining a copy # of this software and associated documentation files (the "Software"), to deal @@ -29,17 +29,15 @@ __date__ = "12/12/2017" import unittest import shutil import tempfile + import numpy +import six from silx.gui.utils.testutils import TestCaseQt from silx.gui.utils.testutils import SignalListener from ..TextFormatter import TextFormatter -from silx.third_party import six -try: - import h5py -except ImportError: - h5py = None +import h5py class TestTextFormatter(TestCaseQt): @@ -108,8 +106,6 @@ class TestTextFormatterWithH5py(TestCaseQt): @classmethod def setUpClass(cls): super(TestTextFormatterWithH5py, cls).setUpClass() - if h5py is None: - raise unittest.SkipTest("h5py is not available") cls.tmpDirectory = tempfile.mkdtemp() cls.h5File = h5py.File("%s/formatter.h5" % cls.tmpDirectory, mode="w") diff --git a/silx/gui/dialog/AbstractDataFileDialog.py b/silx/gui/dialog/AbstractDataFileDialog.py index 40045fe..c660cd7 100644 --- a/silx/gui/dialog/AbstractDataFileDialog.py +++ b/silx/gui/dialog/AbstractDataFileDialog.py @@ -1,7 +1,7 @@ # coding: utf-8 # /*########################################################################## # -# Copyright (c) 2016 European Synchrotron Radiation Facility +# Copyright (c) 2016-2018 European Synchrotron Radiation Facility # # Permission is hereby granted, free of charge, to any person obtaining a copy # of this software and associated documentation files (the "Software"), to deal @@ -28,29 +28,36 @@ This module contains an :class:`AbstractDataFileDialog`. __authors__ = ["V. Valls"] __license__ = "MIT" -__date__ = "05/03/2018" +__date__ = "03/12/2018" import sys import os import logging -import numpy import functools +from distutils.version import LooseVersion + +import numpy +import six + import silx.io.url from silx.gui import qt from silx.gui.hdf5.Hdf5TreeModel import Hdf5TreeModel from . import utils -from silx.third_party import six from .FileTypeComboBox import FileTypeComboBox -try: - import fabio -except ImportError: - fabio = None + +import fabio _logger = logging.getLogger(__name__) +DEFAULT_SIDEBAR_URL = True +"""Set it to false to disable initilializing of the sidebar urls with the +default Qt list. This could allow to disable a behaviour known to segfault on +some version of PyQt.""" + + class _IconProvider(object): FileDialogToParentDir = qt.QStyle.SP_CustomBase + 1 @@ -143,14 +150,22 @@ class _SideBar(qt.QListView): :rtype: List[str] """ urls = [] - if qt.qVersion().startswith("5.") and sys.platform in ["linux", "linux2"]: + version = LooseVersion(qt.qVersion()) + feed_sidebar = True + + if not DEFAULT_SIDEBAR_URL: + _logger.debug("Skip default sidebar URLs (from setted variable)") + feed_sidebar = False + elif version.version[0] == 4 and sys.platform in ["win32"]: + # Avoid locking the GUI 5min in case of use of network driver + _logger.debug("Skip default sidebar URLs (avoid lock when using network drivers)") + feed_sidebar = False + elif version < LooseVersion("5.11.2") and qt.BINDING == "PyQt5" and sys.platform in ["linux", "linux2"]: # Avoid segfault on PyQt5 + gtk _logger.debug("Skip default sidebar URLs (avoid PyQt5 segfault)") - pass - elif qt.qVersion().startswith("4.") and sys.platform in ["win32"]: - # Avoid 5min of locked GUI relative to network driver - _logger.debug("Skip default sidebar URLs (avoid lock when using network drivers)") - else: + feed_sidebar = False + + if feed_sidebar: # Get default shortcut # There is no other way d = qt.QFileDialog(self) @@ -1061,8 +1076,6 @@ class AbstractDataFileDialog(qt.QDialog): def __openFabioFile(self, filename): self.__closeFile() try: - if fabio is None: - raise ImportError("Fabio module is not available") self.__fabio = fabio.open(filename) self.__openedFiles.append(self.__fabio) self.__selectedFile = filename @@ -1108,10 +1121,10 @@ class AbstractDataFileDialog(qt.QDialog): if codec.is_autodetect(): if self.__isSilxHavePriority(filename): openners.append(self.__openSilxFile) - if fabio is not None and self._isFabioFilesSupported(): + if self._isFabioFilesSupported(): openners.append(self.__openFabioFile) else: - if fabio is not None and self._isFabioFilesSupported(): + if self._isFabioFilesSupported(): openners.append(self.__openFabioFile) openners.append(self.__openSilxFile) elif codec.is_silx_codec(): @@ -1159,10 +1172,9 @@ class AbstractDataFileDialog(qt.QDialog): is_fabio_have_priority = not codec.is_silx_codec() and not self.__isSilxHavePriority(path) if is_fabio_decoder or is_fabio_have_priority: # Then it's flat frame container - if fabio is not None: - self.__openFabioFile(path) - if self.__fabio is not None: - selectedData = _FabioData(self.__fabio) + self.__openFabioFile(path) + if self.__fabio is not None: + selectedData = _FabioData(self.__fabio) else: assert(False) diff --git a/silx/gui/dialog/ColormapDialog.py b/silx/gui/dialog/ColormapDialog.py index cbbfa5a..9950ad4 100644 --- a/silx/gui/dialog/ColormapDialog.py +++ b/silx/gui/dialog/ColormapDialog.py @@ -63,9 +63,10 @@ from __future__ import division __authors__ = ["V.A. Sole", "T. Vincent", "H. Payno"] __license__ = "MIT" -__date__ = "23/05/2018" +__date__ = "27/11/2018" +import enum import logging import numpy @@ -73,10 +74,10 @@ import numpy from .. import qt from ..colors import Colormap, preferredColormaps from ..plot import PlotWidget +from ..plot.items.axis import Axis from silx.gui.widgets.FloatEdit import FloatEdit import weakref from silx.math.combo import min_max -from silx.third_party import enum from silx.gui import icons from silx.math.histogram import Histogramnd @@ -154,39 +155,59 @@ class _ColormapNameCombox(qt.QComboBox): qt.QComboBox.__init__(self, parent) self.__initItems() - ORIGINAL_NAME = qt.Qt.UserRole + 1 + LUT_NAME = qt.Qt.UserRole + 1 + LUT_COLORS = qt.Qt.UserRole + 2 def __initItems(self): for colormapName in preferredColormaps(): index = self.count() self.addItem(str.title(colormapName)) - self.setItemIcon(index, self.getIconPreview(colormapName)) - self.setItemData(index, colormapName, role=self.ORIGINAL_NAME) + self.setItemIcon(index, self.getIconPreview(name=colormapName)) + self.setItemData(index, colormapName, role=self.LUT_NAME) - def getIconPreview(self, colormapName): + def getIconPreview(self, name=None, colors=None): """Return an icon preview from a LUT name. This icons are cached into a global structure. - :param str colormapName: str + :param str name: Name of the LUT + :param numpy.ndarray colors: Colors identify the LUT :rtype: qt.QIcon """ - if colormapName not in _colormapIconPreview: - icon = self.createIconPreview(colormapName) - _colormapIconPreview[colormapName] = icon - return _colormapIconPreview[colormapName] - - def createIconPreview(self, colormapName): + if name is not None: + iconKey = name + else: + iconKey = tuple(colors) + icon = _colormapIconPreview.get(iconKey, None) + if icon is None: + icon = self.createIconPreview(name, colors) + _colormapIconPreview[iconKey] = icon + return icon + + def createIconPreview(self, name=None, colors=None): """Create and return an icon preview from a LUT name. This icons are cached into a global structure. - :param str colormapName: Name of the LUT + :param str name: Name of the LUT + :param numpy.ndarray colors: Colors identify the LUT :rtype: qt.QIcon """ - colormap = Colormap(colormapName) + colormap = Colormap(name) size = 32 - lut = colormap.getNColors(size) + if name is not None: + lut = colormap.getNColors(size) + else: + lut = colors + if len(lut) > size: + # Down sample + step = int(len(lut) / size) + lut = lut[::step] + elif len(lut) < size: + # Over sample + indexes = numpy.arange(size) / float(size) * (len(lut) - 1) + indexes = indexes.astype("int") + lut = lut[indexes] if lut is None or len(lut) == 0: return qt.QIcon() @@ -204,18 +225,50 @@ class _ColormapNameCombox(qt.QComboBox): return qt.QIcon(pixmap) def getCurrentName(self): - return self.itemData(self.currentIndex(), self.ORIGINAL_NAME) + return self.itemData(self.currentIndex(), self.LUT_NAME) + + def getCurrentColors(self): + return self.itemData(self.currentIndex(), self.LUT_COLORS) + + def findLutName(self, name): + return self.findData(name, role=self.LUT_NAME) + + def findLutColors(self, lut): + for index in range(self.count()): + if self.itemData(index, role=self.LUT_NAME) is not None: + continue + colors = self.itemData(index, role=self.LUT_COLORS) + if colors is None: + continue + if numpy.array_equal(colors, lut): + return index + return -1 + + def setCurrentLut(self, colormap): + name = colormap.getName() + if name is not None: + self._setCurrentName(name) + else: + lut = colormap.getColormapLUT() + self._setCurrentLut(lut) - def findColormap(self, name): - return self.findData(name, role=self.ORIGINAL_NAME) + def _setCurrentLut(self, lut): + index = self.findLutColors(lut) + if index == -1: + index = self.count() + self.addItem("Custom") + self.setItemIcon(index, self.getIconPreview(colors=lut)) + self.setItemData(index, None, role=self.LUT_NAME) + self.setItemData(index, lut, role=self.LUT_COLORS) + self.setCurrentIndex(index) - def setCurrentName(self, name): - index = self.findColormap(name) + def _setCurrentName(self, name): + index = self.findLutName(name) if index < 0: index = self.count() self.addItem(str.title(name)) - self.setItemIcon(index, self.getIconPreview(name)) - self.setItemData(index, name, role=self.ORIGINAL_NAME) + self.setItemIcon(index, self.getIconPreview(name=name)) + self.setItemData(index, name, role=self.LUT_NAME) self.setCurrentIndex(index) @@ -255,6 +308,7 @@ class ColormapDialog(qt.QDialog): the self.setcolormap is a callback) """ + self.__displayInvalidated = False self._histogramData = None self._minMaxWasEdited = False self._initialRange = None @@ -276,20 +330,19 @@ class ColormapDialog(qt.QDialog): # Colormap row self._comboBoxColormap = _ColormapNameCombox(parent=formWidget) - self._comboBoxColormap.currentIndexChanged[int].connect(self._updateName) + self._comboBoxColormap.currentIndexChanged[int].connect(self._updateLut) formLayout.addRow('Colormap:', self._comboBoxColormap) # Normalization row self._normButtonLinear = qt.QRadioButton('Linear') self._normButtonLinear.setChecked(True) self._normButtonLog = qt.QRadioButton('Log') - self._normButtonLog.toggled.connect(self._activeLogNorm) normButtonGroup = qt.QButtonGroup(self) normButtonGroup.setExclusive(True) normButtonGroup.addButton(self._normButtonLinear) normButtonGroup.addButton(self._normButtonLog) - self._normButtonLinear.toggled[bool].connect(self._updateLinearNorm) + normButtonGroup.buttonClicked[qt.QAbstractButton].connect(self._updateNormalization) normLayout = qt.QHBoxLayout() normLayout.setContentsMargins(0, 0, 0, 0) @@ -388,9 +441,17 @@ class ColormapDialog(qt.QDialog): self.setFixedSize(self.sizeHint()) self._applyColormap() + def _displayLater(self): + self.__displayInvalidated = True + def showEvent(self, event): self.visibleChanged.emit(True) super(ColormapDialog, self).showEvent(event) + if self.isVisible(): + if self.__displayInvalidated: + self._applyColormap() + self._updateDataInPlot() + self.__displayInvalidated = False def closeEvent(self, event): if not self.isModal(): @@ -434,6 +495,54 @@ class ColormapDialog(qt.QDialog): def sizeHint(self): return self.layout().minimumSize() + def _computeView(self, dataMin, dataMax): + """Compute the location of the view according to the bound of the data + + :rtype: Tuple(float, float) + """ + marginRatio = 1.0 / 6.0 + scale = self._plot.getXAxis().getScale() + + if self._dataRange is not None: + if scale == Axis.LOGARITHMIC: + minRange = self._dataRange[1] + else: + minRange = self._dataRange[0] + maxRange = self._dataRange[2] + if minRange is not None: + dataMin = min(dataMin, minRange) + dataMax = max(dataMax, maxRange) + + if self._histogramData is not None: + info = min_max(self._histogramData[1]) + if scale == Axis.LOGARITHMIC: + minHisto = info.min_positive + else: + minHisto = info.minimum + maxHisto = info.maximum + if minHisto is not None: + dataMin = min(dataMin, minHisto) + dataMax = max(dataMax, maxHisto) + + if scale == Axis.LOGARITHMIC: + epsilon = numpy.finfo(numpy.float32).eps + if dataMin == 0: + dataMin = epsilon + if dataMax < dataMin: + dataMax = dataMin + epsilon + marge = marginRatio * abs(numpy.log10(dataMax) - numpy.log10(dataMin)) + viewMin = 10**(numpy.log10(dataMin) - marge) + viewMax = 10**(numpy.log10(dataMax) + marge) + else: # scale == Axis.LINEAR: + marge = marginRatio * abs(dataMax - dataMin) + if marge < 0.0001: + # Smaller that the QLineEdit precision + marge = 0.0001 + viewMin = dataMin - marge + viewMax = dataMax + marge + + return viewMin, viewMax + def _plotUpdate(self, updateMarkers=True): """Update the plot content @@ -454,27 +563,8 @@ class ColormapDialog(qt.QDialog): if minData > maxData: # avoid a full collapse minData, maxData = maxData, minData - minimum = minData - maximum = maxData - - if self._dataRange is not None: - minRange = self._dataRange[0] - maxRange = self._dataRange[2] - minimum = min(minimum, minRange) - maximum = max(maximum, maxRange) - if self._histogramData is not None: - minHisto = self._histogramData[1][0] - maxHisto = self._histogramData[1][-1] - minimum = min(minimum, minHisto) - maximum = max(maximum, maxHisto) - - marge = abs(maximum - minimum) / 6.0 - if marge < 0.0001: - # Smaller that the QLineEdit precision - marge = 0.0001 - - minView, maxView = minimum - marge, maximum + marge + minView, maxView = self._computeView(minData, maxData) if updateMarkers: # Save the state in we are not moving the markers @@ -483,6 +573,9 @@ class ColormapDialog(qt.QDialog): minView = min(minView, self._initialRange[0]) maxView = max(maxView, self._initialRange[1]) + if minView > minData: + # Hide the min range + minData = minView x = [minView, minData, maxData, maxView] y = [0, 0, 1, 1] @@ -493,26 +586,37 @@ class ColormapDialog(qt.QDialog): linestyle='-', resetzoom=False) + scale = self._plot.getXAxis().getScale() + if updateMarkers: - minDraggable = (self._colormap().isEditable() and - not self._minValue.isAutoChecked()) - self._plot.addXMarker( - self._minValue.getFiniteValue(), - legend='Min', - text='Min', - draggable=minDraggable, - color='blue', - constraint=self._plotMinMarkerConstraint) - - maxDraggable = (self._colormap().isEditable() and - not self._maxValue.isAutoChecked()) - self._plot.addXMarker( - self._maxValue.getFiniteValue(), - legend='Max', - text='Max', - draggable=maxDraggable, - color='blue', - constraint=self._plotMaxMarkerConstraint) + posMin = self._minValue.getFiniteValue() + posMax = self._maxValue.getFiniteValue() + + def isDisplayable(pos): + if scale == Axis.LOGARITHMIC: + return pos > 0.0 + return True + + if isDisplayable(posMin): + minDraggable = (self._colormap().isEditable() and + not self._minValue.isAutoChecked()) + self._plot.addXMarker( + posMin, + legend='Min', + text='Min', + draggable=minDraggable, + color='blue', + constraint=self._plotMinMarkerConstraint) + if isDisplayable(posMax): + maxDraggable = (self._colormap().isEditable() and + not self._maxValue.isAutoChecked()) + self._plot.addXMarker( + posMax, + legend='Max', + text='Max', + draggable=maxDraggable, + color='blue', + constraint=self._plotMaxMarkerConstraint) self._plot.resetZoom() @@ -546,7 +650,7 @@ class ColormapDialog(qt.QDialog): """Compute the data range as used by :meth:`setDataRange`. :param data: The data to process - :rtype: Tuple(float, float, float) + :rtype: List[Union[None,float]] """ if data is None or len(data) == 0: return None, None, None @@ -558,8 +662,6 @@ class ColormapDialog(qt.QDialog): if dataRange is not None: min_positive = dataRange.min_positive - if min_positive is None: - min_positive = float('nan') dataRange = dataRange.minimum, min_positive, dataRange.maximum if dataRange is None or len(dataRange) != 3: @@ -571,7 +673,7 @@ class ColormapDialog(qt.QDialog): return dataRange @staticmethod - def computeHistogram(data): + def computeHistogram(data, scale=Axis.LINEAR): """Compute the data histogram as used by :meth:`setHistogram`. :param data: The data to process @@ -588,7 +690,12 @@ class ColormapDialog(qt.QDialog): if len(_data) == 0: return None, None + if scale == Axis.LOGARITHMIC: + _data = numpy.log10(_data) xmin, xmax = min_max(_data, min_positive=False, finite=True) + if xmin is None: + return None, None + nbins = min(256, int(numpy.sqrt(_data.size))) data_range = xmin, xmax @@ -601,7 +708,10 @@ class ColormapDialog(qt.QDialog): _data = _data.ravel().astype(numpy.float32) histogram = Histogramnd(_data, n_bins=nbins, histo_range=data_range) - return histogram.histo, histogram.edges[0] + bins = histogram.edges[0] + if scale == Axis.LOGARITHMIC: + bins = 10**bins + return histogram.histo, bins def _getData(self): if self._data is None: @@ -624,7 +734,10 @@ class ColormapDialog(qt.QDialog): else: self._data = weakref.ref(data, self._dataAboutToFinalize) - self._updateDataInPlot() + if self.isVisible(): + self._updateDataInPlot() + else: + self._displayLater() def _setDataInPlotMode(self, mode): if self._dataInPlotMode == mode: @@ -660,10 +773,15 @@ class ColormapDialog(qt.QDialog): self.setDataRange(*result) elif mode == _DataInPlotMode.HISTOGRAM: # The histogram should be done in a worker thread - result = self.computeHistogram(data) + result = self.computeHistogram(data, scale=self._plot.getXAxis().getScale()) self.setHistogram(*result) self.setDataRange() + def _invalidateHistogram(self): + """Recompute the histogram if it is displayed""" + if self._dataInPlotMode == _DataInPlotMode.HISTOGRAM: + self._updateDataInPlot() + def _colormapAboutToFinalize(self, weakrefColormap): """Callback when the data weakref is about to be finalized.""" if self._colormap is weakrefColormap: @@ -727,9 +845,9 @@ class ColormapDialog(qt.QDialog): """ colormap = self.getColormap() if colormap is not None and self._colormapStoredState is not None: - if self._colormap()._toDict() != self._colormapStoredState: + if colormap != self._colormapStoredState: self._ignoreColormapChange = True - colormap._setFromDict(self._colormapStoredState) + colormap.setFromColormap(self._colormapStoredState) self._ignoreColormapChange = False self._applyColormap() @@ -740,12 +858,18 @@ class ColormapDialog(qt.QDialog): :param float positiveMin: The positive minimum of the data :param float maximum: The maximum of the data """ - if minimum is None or positiveMin is None or maximum is None: + scale = self._plot.getXAxis().getScale() + if scale == Axis.LOGARITHMIC: + dataMin, dataMax = positiveMin, maximum + else: + dataMin, dataMax = minimum, maximum + + if dataMin is None or dataMax is None: self._dataRange = None self._plot.remove(legend='Range', kind='histogram') else: hist = numpy.array([1]) - bin_edges = numpy.array([minimum, maximum]) + bin_edges = numpy.array([dataMin, dataMax]) self._plot.addHistogram(hist, bin_edges, legend="Range", @@ -801,7 +925,7 @@ class ColormapDialog(qt.QDialog): """ colormap = self.getColormap() if colormap is not None: - self._colormapStoredState = colormap._toDict() + self._colormapStoredState = colormap.copy() else: self._colormapStoredState = None @@ -830,8 +954,11 @@ class ColormapDialog(qt.QDialog): self._colormap = colormap self.storeCurrentState() - self._updateResetButton() - self._applyColormap() + if self.isVisible(): + self._applyColormap() + else: + self._updateResetButton() + self._displayLater() def _updateResetButton(self): resetButton = self._buttonsNonModal.button(qt.QDialogButtonBox.Reset) @@ -839,7 +966,7 @@ class ColormapDialog(qt.QDialog): colormap = self.getColormap() if colormap is not None and colormap.isEditable(): # can reset only in the case the colormap changed - rStateEnabled = colormap._toDict() != self._colormapStoredState + rStateEnabled = colormap != self._colormapStoredState resetButton.setEnabled(rStateEnabled) def _applyColormap(self): @@ -856,12 +983,8 @@ class ColormapDialog(qt.QDialog): self._maxValue.setEnabled(False) else: self._ignoreColormapChange = True - - if colormap.getName() is not None: - name = colormap.getName() - self._comboBoxColormap.setCurrentName(name) - self._comboBoxColormap.setEnabled(self._colormap().isEditable()) - + self._comboBoxColormap.setCurrentLut(colormap) + self._comboBoxColormap.setEnabled(colormap.isEditable()) assert colormap.getNormalization() in Colormap.NORMALIZATIONS self._normButtonLinear.setChecked( colormap.getNormalization() == Colormap.LINEAR) @@ -870,12 +993,17 @@ class ColormapDialog(qt.QDialog): vmin = colormap.getVMin() vmax = colormap.getVMax() dataRange = colormap.getColormapRange() - self._normButtonLinear.setEnabled(self._colormap().isEditable()) - self._normButtonLog.setEnabled(self._colormap().isEditable()) + self._normButtonLinear.setEnabled(colormap.isEditable()) + self._normButtonLog.setEnabled(colormap.isEditable()) self._minValue.setValue(vmin or dataRange[0], isAuto=vmin is None) self._maxValue.setValue(vmax or dataRange[1], isAuto=vmax is None) - self._minValue.setEnabled(self._colormap().isEditable()) - self._maxValue.setEnabled(self._colormap().isEditable()) + self._minValue.setEnabled(colormap.isEditable()) + self._maxValue.setEnabled(colormap.isEditable()) + + axis = self._plot.getXAxis() + scale = axis.LINEAR if colormap.getNormalization() == Colormap.LINEAR else axis.LOGARITHMIC + axis.setScale(scale) + self._ignoreColormapChange = False self._plotUpdate() @@ -908,26 +1036,47 @@ class ColormapDialog(qt.QDialog): self._plotUpdate() self._updateResetButton() - def _updateName(self): + def _updateLut(self): if self._ignoreColormapChange is True: return - if self._colormap(): + colormap = self._colormap() + if colormap is not None: self._ignoreColormapChange = True - self._colormap().setName( - self._comboBoxColormap.getCurrentName()) + name = self._comboBoxColormap.getCurrentName() + if name is not None: + colormap.setName(name) + else: + lut = self._comboBoxColormap.getCurrentColors() + colormap.setColormapLUT(lut) self._ignoreColormapChange = False - def _updateLinearNorm(self, isNormLinear): + def _updateNormalization(self, button): if self._ignoreColormapChange is True: return + if not button.isChecked(): + return + + if button is self._normButtonLinear: + norm = Colormap.LINEAR + scale = Axis.LINEAR + elif button is self._normButtonLog: + norm = Colormap.LOGARITHM + scale = Axis.LOGARITHMIC + else: + assert(False) - if self._colormap(): + colormap = self.getColormap() + if colormap is not None: self._ignoreColormapChange = True - norm = Colormap.LINEAR if isNormLinear else Colormap.LOGARITHM - self._colormap().setNormalization(norm) + colormap.setNormalization(norm) + axis = self._plot.getXAxis() + axis.setScale(scale) self._ignoreColormapChange = False + self._invalidateHistogram() + self._updateMinMaxData() + def _minMaxTextEdited(self, text): """Handle _minValue and _maxValue textEdited signal""" self._minMaxWasEdited = True @@ -975,13 +1124,3 @@ class ColormapDialog(qt.QDialog): else: # Use QDialog keyPressEvent super(ColormapDialog, self).keyPressEvent(event) - - def _activeLogNorm(self, isLog): - if self._ignoreColormapChange is True: - return - if self._colormap(): - self._ignoreColormapChange = True - norm = Colormap.LOGARITHM if isLog is True else Colormap.LINEAR - self._colormap().setNormalization(norm) - self._ignoreColormapChange = False - self._updateMinMaxData() diff --git a/silx/gui/dialog/DataFileDialog.py b/silx/gui/dialog/DataFileDialog.py index 7ff1258..d2d76a3 100644 --- a/silx/gui/dialog/DataFileDialog.py +++ b/silx/gui/dialog/DataFileDialog.py @@ -1,7 +1,7 @@ # coding: utf-8 # /*########################################################################## # -# Copyright (c) 2016 European Synchrotron Radiation Facility +# Copyright (c) 2016-2018 European Synchrotron Radiation Facility # # Permission is hereby granted, free of charge, to any person obtaining a copy # of this software and associated documentation files (the "Software"), to deal @@ -30,16 +30,14 @@ __authors__ = ["V. Valls"] __license__ = "MIT" __date__ = "14/02/2018" +import enum import logging from silx.gui import qt from silx.gui.hdf5.Hdf5Formatter import Hdf5Formatter import silx.io from .AbstractDataFileDialog import AbstractDataFileDialog -from silx.third_party import enum -try: - import fabio -except ImportError: - fabio = None + +import fabio _logger = logging.getLogger(__name__) diff --git a/silx/gui/dialog/FileTypeComboBox.py b/silx/gui/dialog/FileTypeComboBox.py index 07b11cf..92529bc 100644 --- a/silx/gui/dialog/FileTypeComboBox.py +++ b/silx/gui/dialog/FileTypeComboBox.py @@ -28,12 +28,9 @@ This module contains utilitaries used by other dialog modules. __authors__ = ["V. Valls"] __license__ = "MIT" -__date__ = "06/02/2018" +__date__ = "17/01/2019" -try: - import fabio -except ImportError: - fabio = None +import fabio import silx.io from silx.gui import qt @@ -82,7 +79,7 @@ class FileTypeComboBox(qt.QComboBox): def __initItems(self): self.clear() - if fabio is not None and self.__fabioUrlSupported: + if self.__fabioUrlSupported: self.__insertFabioFormats() self.__insertSilxFormats() self.__insertAllSupported() @@ -138,21 +135,36 @@ class FileTypeComboBox(qt.QComboBox): def __insertFabioFormats(self): formats = fabio.fabioformats.get_classes(reader=True) + from fabio import fabioutils + if hasattr(fabioutils, "COMPRESSED_EXTENSIONS"): + compressedExtensions = fabioutils.COMPRESSED_EXTENSIONS + else: + # Support for fabio < 0.9 + compressedExtensions = set(["gz", "bz2"]) + extensions = [] allExtensions = set([]) + def extensionsIterator(reader): + for extension in reader.DEFAULT_EXTENSIONS: + yield "*.%s" % extension + for compressedExtension in compressedExtensions: + for extension in reader.DEFAULT_EXTENSIONS: + yield "*.%s.%s" % (extension, compressedExtension) + for reader in formats: if not hasattr(reader, "DESCRIPTION"): continue if not hasattr(reader, "DEFAULT_EXTENSIONS"): continue - ext = reader.DEFAULT_EXTENSIONS - ext = ["*.%s" % e for e in ext] + displayext = reader.DEFAULT_EXTENSIONS + displayext = ["*.%s" % e for e in displayext] + ext = list(extensionsIterator(reader)) allExtensions.update(ext) if ext == []: ext = ["*"] - extensions.append((reader.DESCRIPTION, ext, reader.codec_name())) + extensions.append((reader.DESCRIPTION, displayext, ext, reader.codec_name())) extensions = list(sorted(extensions)) allExtensions = list(sorted(list(allExtensions))) @@ -162,13 +174,14 @@ class FileTypeComboBox(qt.QComboBox): self.setItemData(index, Codec(any_fabio=True), role=self.CODEC_ROLE) for e in extensions: + description, displayExt, allExt, _codecName = e index = self.count() if len(e[1]) < 10: - self.addItem("%s%s (%s)" % (self.INDENTATION, e[0], " ".join(e[1]))) + self.addItem("%s%s (%s)" % (self.INDENTATION, description, " ".join(displayExt))) else: - self.addItem(e[0]) - codec = Codec(fabio_codec=e[2]) - self.setItemData(index, e[1], role=self.EXTENSIONS_ROLE) + self.addItem("%s%s" % (self.INDENTATION, description)) + codec = Codec(fabio_codec=_codecName) + self.setItemData(index, allExt, role=self.EXTENSIONS_ROLE) self.setItemData(index, codec, role=self.CODEC_ROLE) def itemExtensions(self, index): diff --git a/silx/gui/dialog/ImageFileDialog.py b/silx/gui/dialog/ImageFileDialog.py index c324071..ef6b472 100644 --- a/silx/gui/dialog/ImageFileDialog.py +++ b/silx/gui/dialog/ImageFileDialog.py @@ -36,10 +36,7 @@ from silx.gui import qt from silx.gui.plot.PlotWidget import PlotWidget from .AbstractDataFileDialog import AbstractDataFileDialog import silx.io -try: - import fabio -except ImportError: - fabio = None +import fabio _logger = logging.getLogger(__name__) diff --git a/silx/gui/dialog/SafeFileSystemModel.py b/silx/gui/dialog/SafeFileSystemModel.py index 198e089..26954e3 100644 --- a/silx/gui/dialog/SafeFileSystemModel.py +++ b/silx/gui/dialog/SafeFileSystemModel.py @@ -1,7 +1,7 @@ # coding: utf-8 # /*########################################################################## # -# Copyright (c) 2016 European Synchrotron Radiation Facility +# Copyright (c) 2016-2018 European Synchrotron Radiation Facility # # Permission is hereby granted, free of charge, to any person obtaining a copy # of this software and associated documentation files (the "Software"), to deal @@ -34,8 +34,10 @@ import sys import os.path import logging import weakref + +import six + from silx.gui import qt -from silx.third_party import six from .SafeFileIconProvider import SafeFileIconProvider _logger = logging.getLogger(__name__) diff --git a/silx/gui/dialog/test/test_colormapdialog.py b/silx/gui/dialog/test/test_colormapdialog.py index 6e50193..cbc9de1 100644 --- a/silx/gui/dialog/test/test_colormapdialog.py +++ b/silx/gui/dialog/test/test_colormapdialog.py @@ -26,13 +26,11 @@ __authors__ = ["T. Vincent"] __license__ = "MIT" -__date__ = "23/05/2018" +__date__ = "09/11/2018" -import doctest import unittest -from silx.gui.utils.testutils import qWaitForWindowExposedAndActivate from silx.gui import qt from silx.gui.dialog import ColormapDialog from silx.gui.utils.testutils import TestCaseQt @@ -47,23 +45,6 @@ import numpy.random _qapp = qt.QApplication.instance() or qt.QApplication([]) -def _tearDownQt(docTest): - """Tear down to use for test from docstring. - - Checks that dialog widget is displayed - """ - dialogWidget = docTest.globs['dialog'] - qWaitForWindowExposedAndActivate(dialogWidget) - dialogWidget.setAttribute(qt.Qt.WA_DeleteOnClose) - dialogWidget.close() - del dialogWidget - _qapp.processEvents() - - -cmapDocTestSuite = doctest.DocTestSuite(ColormapDialog, tearDown=_tearDownQt) -"""Test suite of tests from the module's docstrings.""" - - class TestColormapDialog(TestCaseQt, ParametricTestCase): """Test the ColormapDialog.""" def setUp(self): @@ -86,10 +67,12 @@ class TestColormapDialog(TestCaseQt, ParametricTestCase): editing the same colormap""" colormapDiag2 = ColormapDialog.ColormapDialog() colormapDiag2.setColormap(self.colormap) + colormapDiag2.show() self.colormapDiag.setColormap(self.colormap) + self.colormapDiag.show() - self.colormapDiag._comboBoxColormap.setCurrentName('red') - self.colormapDiag._normButtonLog.setChecked(True) + self.colormapDiag._comboBoxColormap._setCurrentName('red') + self.colormapDiag._normButtonLog.click() self.assertTrue(self.colormap.getName() == 'red') self.assertTrue(self.colormapDiag.getColormap().getName() == 'red') self.assertTrue(self.colormap.getNormalization() == 'log') @@ -178,6 +161,7 @@ class TestColormapDialog(TestCaseQt, ParametricTestCase): def testSetColormapIsCorrect(self): """Make sure the interface fir the colormap when set a new colormap""" self.colormap.setName('red') + self.colormapDiag.show() for norm in (Colormap.NORMALIZATIONS): for autoscale in (True, False): if autoscale is True: @@ -211,7 +195,7 @@ class TestColormapDialog(TestCaseQt, ParametricTestCase): self.colormapDiag.show() del self.colormap self.assertTrue(self.colormapDiag.getColormap() is None) - self.colormapDiag._comboBoxColormap.setCurrentName('blue') + self.colormapDiag._comboBoxColormap._setCurrentName('blue') def testColormapEditedOutside(self): """Make sure the GUI is still up to date if the colormap is modified @@ -274,7 +258,7 @@ class TestColormapDialog(TestCaseQt, ParametricTestCase): cb = self.colormapDiag._comboBoxColormap self.assertTrue(cb.getCurrentName() == colormapName) cb.setCurrentIndex(0) - index = cb.findColormap(colormapName) + index = cb.findLutName(colormapName) assert index is not 0 # if 0 then the rest of the test has no sense cb.setCurrentIndex(index) self.assertTrue(cb.getCurrentName() == colormapName) @@ -283,6 +267,7 @@ class TestColormapDialog(TestCaseQt, ParametricTestCase): """Test that the colormapDialog is correctly updated when changing the colormap editable status""" colormap = Colormap(normalization='linear', vmin=1.0, vmax=10.0) + self.colormapDiag.show() self.colormapDiag.setColormap(colormap) for editable in (True, False): with self.subTest(editable=editable): @@ -302,7 +287,7 @@ class TestColormapDialog(TestCaseQt, ParametricTestCase): # False self.colormapDiag.setModal(False) colormap.setEditable(True) - self.colormapDiag._normButtonLog.setChecked(True) + self.colormapDiag._normButtonLog.click() resetButton = self.colormapDiag._buttonsNonModal.button(qt.QDialogButtonBox.Reset) self.assertTrue(resetButton.isEnabled()) colormap.setEditable(False) @@ -387,7 +372,6 @@ class TestColormapAction(TestCaseQt): def suite(): test_suite = unittest.TestSuite() - test_suite.addTest(cmapDocTestSuite) for testClass in (TestColormapDialog, TestColormapAction): test_suite.addTest(unittest.defaultTestLoader.loadTestsFromTestCase( testClass)) diff --git a/silx/gui/dialog/test/test_datafiledialog.py b/silx/gui/dialog/test/test_datafiledialog.py index aff6bc4..06f8961 100644 --- a/silx/gui/dialog/test/test_datafiledialog.py +++ b/silx/gui/dialog/test/test_datafiledialog.py @@ -36,16 +36,8 @@ import shutil import os import io import weakref - -try: - import fabio -except ImportError: - fabio = None -try: - import h5py -except ImportError: - h5py = None - +import fabio +import h5py import silx.io.url from silx.gui import qt from silx.gui.utils import testutils @@ -62,36 +54,33 @@ def setUpModule(): data = numpy.arange(100 * 100) data.shape = 100, 100 - if fabio is not None: - filename = _tmpDirectory + "/singleimage.edf" - image = fabio.edfimage.EdfImage(data=data) - image.write(filename) - - if h5py is not None: - filename = _tmpDirectory + "/data.h5" - f = h5py.File(filename, "w") - f["scalar"] = 10 - f["image"] = data - f["cube"] = [data, data + 1, data + 2] - f["complex_image"] = data * 1j - f["group/image"] = data - f["nxdata/foo"] = 10 - f["nxdata"].attrs["NX_class"] = u"NXdata" - f.close() - - if h5py is not None: - directory = os.path.join(_tmpDirectory, "data") - os.mkdir(directory) - filename = os.path.join(directory, "data.h5") - f = h5py.File(filename, "w") - f["scalar"] = 10 - f["image"] = data - f["cube"] = [data, data + 1, data + 2] - f["complex_image"] = data * 1j - f["group/image"] = data - f["nxdata/foo"] = 10 - f["nxdata"].attrs["NX_class"] = u"NXdata" - f.close() + filename = _tmpDirectory + "/singleimage.edf" + image = fabio.edfimage.EdfImage(data=data) + image.write(filename) + + filename = _tmpDirectory + "/data.h5" + f = h5py.File(filename, "w") + f["scalar"] = 10 + f["image"] = data + f["cube"] = [data, data + 1, data + 2] + f["complex_image"] = data * 1j + f["group/image"] = data + f["nxdata/foo"] = 10 + f["nxdata"].attrs["NX_class"] = u"NXdata" + f.close() + + directory = os.path.join(_tmpDirectory, "data") + os.mkdir(directory) + filename = os.path.join(directory, "data.h5") + f = h5py.File(filename, "w") + f["scalar"] = 10 + f["image"] = data + f["cube"] = [data, data + 1, data + 2] + f["complex_image"] = data * 1j + f["group/image"] = data + f["nxdata/foo"] = 10 + f["nxdata"].attrs["NX_class"] = u"NXdata" + f.close() filename = _tmpDirectory + "/badformat.h5" with io.open(filename, "wb") as f: @@ -185,8 +174,6 @@ class TestDataFileDialogInteraction(testutils.TestCaseQt, _UtilsMixin): self.assertEqual(dialog.result(), qt.QDialog.Rejected) def testSelectRoot_Activate(self): - if fabio is None: - self.skipTest("fabio is missing") dialog = self.createDialog() browser = testutils.findChildren(dialog, qt.QWidget, name="browser")[0] dialog.show() @@ -211,8 +198,6 @@ class TestDataFileDialogInteraction(testutils.TestCaseQt, _UtilsMixin): self.assertEqual(dialog.result(), qt.QDialog.Accepted) def testSelectGroup_Activate(self): - if fabio is None: - self.skipTest("fabio is missing") dialog = self.createDialog() browser = testutils.findChildren(dialog, qt.QWidget, name="browser")[0] dialog.show() @@ -243,8 +228,6 @@ class TestDataFileDialogInteraction(testutils.TestCaseQt, _UtilsMixin): self.assertEqual(dialog.result(), qt.QDialog.Accepted) def testSelectDataset_Activate(self): - if fabio is None: - self.skipTest("fabio is missing") dialog = self.createDialog() browser = testutils.findChildren(dialog, qt.QWidget, name="browser")[0] dialog.show() @@ -275,8 +258,6 @@ class TestDataFileDialogInteraction(testutils.TestCaseQt, _UtilsMixin): self.assertEqual(dialog.result(), qt.QDialog.Accepted) def testClickOnBackToParentTool(self): - if h5py is None: - self.skipTest("h5py is missing") dialog = self.createDialog() dialog.show() self.qWaitForWindowExposed(dialog) @@ -307,8 +288,6 @@ class TestDataFileDialogInteraction(testutils.TestCaseQt, _UtilsMixin): self.assertSamePath(url.text(), _tmpDirectory) def testClickOnBackToRootTool(self): - if h5py is None: - self.skipTest("h5py is missing") dialog = self.createDialog() dialog.show() self.qWaitForWindowExposed(dialog) @@ -332,8 +311,6 @@ class TestDataFileDialogInteraction(testutils.TestCaseQt, _UtilsMixin): # self.assertFalse(button.isEnabled()) def testClickOnBackToDirectoryTool(self): - if h5py is None: - self.skipTest("h5py is missing") dialog = self.createDialog() dialog.show() self.qWaitForWindowExposed(dialog) @@ -361,8 +338,6 @@ class TestDataFileDialogInteraction(testutils.TestCaseQt, _UtilsMixin): self.allowedLeakingWidgets = 1 def testClickOnHistoryTools(self): - if h5py is None: - self.skipTest("h5py is missing") dialog = self.createDialog() dialog.show() self.qWaitForWindowExposed(dialog) @@ -402,8 +377,6 @@ class TestDataFileDialogInteraction(testutils.TestCaseQt, _UtilsMixin): self.assertSamePath(url.text(), path3) def testSelectImageFromEdf(self): - if fabio is None: - self.skipTest("fabio is missing") dialog = self.createDialog() dialog.show() self.qWaitForWindowExposed(dialog) @@ -417,8 +390,6 @@ class TestDataFileDialogInteraction(testutils.TestCaseQt, _UtilsMixin): self.assertSamePath(dialog.selectedUrl(), url.path()) def testSelectImage(self): - if h5py is None: - self.skipTest("h5py is missing") dialog = self.createDialog() dialog.show() self.qWaitForWindowExposed(dialog) @@ -433,8 +404,6 @@ class TestDataFileDialogInteraction(testutils.TestCaseQt, _UtilsMixin): self.assertSamePath(dialog.selectedUrl(), path) def testSelectScalar(self): - if h5py is None: - self.skipTest("h5py is missing") dialog = self.createDialog() dialog.show() self.qWaitForWindowExposed(dialog) @@ -449,8 +418,6 @@ class TestDataFileDialogInteraction(testutils.TestCaseQt, _UtilsMixin): self.assertSamePath(dialog.selectedUrl(), path) def testSelectGroup(self): - if h5py is None: - self.skipTest("h5py is missing") dialog = self.createDialog() dialog.show() self.qWaitForWindowExposed(dialog) @@ -467,8 +434,6 @@ class TestDataFileDialogInteraction(testutils.TestCaseQt, _UtilsMixin): self.assertSamePath(uri.data_path(), "/group") def testSelectRoot(self): - if h5py is None: - self.skipTest("h5py is missing") dialog = self.createDialog() dialog.show() self.qWaitForWindowExposed(dialog) @@ -485,8 +450,6 @@ class TestDataFileDialogInteraction(testutils.TestCaseQt, _UtilsMixin): self.assertSamePath(uri.data_path(), "/") def testSelectH5_Activate(self): - if h5py is None: - self.skipTest("h5py is missing") dialog = self.createDialog() dialog.show() self.qWaitForWindowExposed(dialog) @@ -533,10 +496,6 @@ class TestDataFileDialogInteraction(testutils.TestCaseQt, _UtilsMixin): return selectable def testFilterExtensions(self): - if h5py is None: - self.skipTest("h5py is missing") - if fabio is None: - self.skipTest("fabio is missing") dialog = self.createDialog() browser = testutils.findChildren(dialog, qt.QWidget, name="browser")[0] dialog.show() @@ -558,8 +517,6 @@ class TestDataFileDialog_FilterDataset(testutils.TestCaseQt, _UtilsMixin): return dialog def testSelectGroup_Activate(self): - if fabio is None: - self.skipTest("fabio is missing") dialog = self.createDialog() browser = testutils.findChildren(dialog, qt.QWidget, name="browser")[0] dialog.show() @@ -585,8 +542,6 @@ class TestDataFileDialog_FilterDataset(testutils.TestCaseQt, _UtilsMixin): self.assertFalse(button.isEnabled()) def testSelectDataset_Activate(self): - if fabio is None: - self.skipTest("fabio is missing") dialog = self.createDialog() browser = testutils.findChildren(dialog, qt.QWidget, name="browser")[0] dialog.show() @@ -632,8 +587,6 @@ class TestDataFileDialog_FilterGroup(testutils.TestCaseQt, _UtilsMixin): return dialog def testSelectGroup_Activate(self): - if fabio is None: - self.skipTest("fabio is missing") dialog = self.createDialog() browser = testutils.findChildren(dialog, qt.QWidget, name="browser")[0] dialog.show() @@ -666,8 +619,6 @@ class TestDataFileDialog_FilterGroup(testutils.TestCaseQt, _UtilsMixin): self.assertRaises(Exception, dialog.selectedData) def testSelectDataset_Activate(self): - if fabio is None: - self.skipTest("fabio is missing") dialog = self.createDialog() browser = testutils.findChildren(dialog, qt.QWidget, name="browser")[0] dialog.show() @@ -711,8 +662,6 @@ class TestDataFileDialog_FilterNXdata(testutils.TestCaseQt, _UtilsMixin): return dialog def testSelectGroupRefused_Activate(self): - if fabio is None: - self.skipTest("fabio is missing") dialog = self.createDialog() browser = testutils.findChildren(dialog, qt.QWidget, name="browser")[0] dialog.show() @@ -740,8 +689,6 @@ class TestDataFileDialog_FilterNXdata(testutils.TestCaseQt, _UtilsMixin): self.assertRaises(Exception, dialog.selectedData) def testSelectNXdataAccepted_Activate(self): - if fabio is None: - self.skipTest("fabio is missing") dialog = self.createDialog() browser = testutils.findChildren(dialog, qt.QWidget, name="browser")[0] dialog.show() @@ -944,8 +891,6 @@ class TestDataFileDialogApi(testutils.TestCaseQt, _UtilsMixin): self.assertIsNone(dialog._selectedData()) def testBadSubpath(self): - if h5py is None: - self.skipTest("h5py is missing") dialog = self.createDialog() self.qWaitForPendingActions(dialog) @@ -965,8 +910,6 @@ class TestDataFileDialogApi(testutils.TestCaseQt, _UtilsMixin): self.assertEqual(url.data_path(), "/group") def testUnsupportedSlicingPath(self): - if h5py is None: - self.skipTest("h5py is missing") dialog = self.createDialog() self.qWaitForPendingActions(dialog) dialog.selectUrl(_tmpDirectory + "/data.h5?path=/cube&slice=0") diff --git a/silx/gui/dialog/test/test_imagefiledialog.py b/silx/gui/dialog/test/test_imagefiledialog.py index 66469f3..068dcb9 100644 --- a/silx/gui/dialog/test/test_imagefiledialog.py +++ b/silx/gui/dialog/test/test_imagefiledialog.py @@ -36,16 +36,8 @@ import shutil import os import io import weakref - -try: - import fabio -except ImportError: - fabio = None -try: - import h5py -except ImportError: - h5py = None - +import fabio +import h5py import silx.io.url from silx.gui import qt from silx.gui.utils import testutils @@ -63,42 +55,39 @@ def setUpModule(): data = numpy.arange(100 * 100) data.shape = 100, 100 - if fabio is not None: - filename = _tmpDirectory + "/singleimage.edf" - image = fabio.edfimage.EdfImage(data=data) - image.write(filename) - - filename = _tmpDirectory + "/multiframe.edf" - image = fabio.edfimage.EdfImage(data=data) - image.appendFrame(data=data + 1) - image.appendFrame(data=data + 2) - image.write(filename) - - filename = _tmpDirectory + "/singleimage.msk" - image = fabio.fit2dmaskimage.Fit2dMaskImage(data=data % 2 == 1) - image.write(filename) - - if h5py is not None: - filename = _tmpDirectory + "/data.h5" - f = h5py.File(filename, "w") - f["scalar"] = 10 - f["image"] = data - f["cube"] = [data, data + 1, data + 2] - f["complex_image"] = data * 1j - f["group/image"] = data - f.close() - - if h5py is not None: - directory = os.path.join(_tmpDirectory, "data") - os.mkdir(directory) - filename = os.path.join(directory, "data.h5") - f = h5py.File(filename, "w") - f["scalar"] = 10 - f["image"] = data - f["cube"] = [data, data + 1, data + 2] - f["complex_image"] = data * 1j - f["group/image"] = data - f.close() + filename = _tmpDirectory + "/singleimage.edf" + image = fabio.edfimage.EdfImage(data=data) + image.write(filename) + + filename = _tmpDirectory + "/multiframe.edf" + image = fabio.edfimage.EdfImage(data=data) + image.appendFrame(data=data + 1) + image.appendFrame(data=data + 2) + image.write(filename) + + filename = _tmpDirectory + "/singleimage.msk" + image = fabio.fit2dmaskimage.Fit2dMaskImage(data=data % 2 == 1) + image.write(filename) + + filename = _tmpDirectory + "/data.h5" + f = h5py.File(filename, "w") + f["scalar"] = 10 + f["image"] = data + f["cube"] = [data, data + 1, data + 2] + f["complex_image"] = data * 1j + f["group/image"] = data + f.close() + + directory = os.path.join(_tmpDirectory, "data") + os.mkdir(directory) + filename = os.path.join(directory, "data.h5") + f = h5py.File(filename, "w") + f["scalar"] = 10 + f["image"] = data + f["cube"] = [data, data + 1, data + 2] + f["complex_image"] = data * 1j + f["group/image"] = data + f.close() filename = _tmpDirectory + "/badformat.edf" with io.open(filename, "wb") as f: @@ -192,8 +181,6 @@ class TestImageFileDialogInteraction(testutils.TestCaseQt, _UtilsMixin): self.assertEqual(dialog.result(), qt.QDialog.Rejected) def testDisplayAndClickOpen(self): - if fabio is None: - self.skipTest("fabio is missing") dialog = self.createDialog() dialog.show() self.qWaitForWindowExposed(dialog) @@ -259,8 +246,6 @@ class TestImageFileDialogInteraction(testutils.TestCaseQt, _UtilsMixin): self.assertEqual(dialog.viewMode(), qt.QFileDialog.List) def testClickOnBackToParentTool(self): - if h5py is None: - self.skipTest("h5py is missing") dialog = self.createDialog() dialog.show() self.qWaitForWindowExposed(dialog) @@ -291,8 +276,6 @@ class TestImageFileDialogInteraction(testutils.TestCaseQt, _UtilsMixin): self.assertSamePath(url.text(), _tmpDirectory) def testClickOnBackToRootTool(self): - if h5py is None: - self.skipTest("h5py is missing") dialog = self.createDialog() dialog.show() self.qWaitForWindowExposed(dialog) @@ -316,8 +299,6 @@ class TestImageFileDialogInteraction(testutils.TestCaseQt, _UtilsMixin): # self.assertFalse(button.isEnabled()) def testClickOnBackToDirectoryTool(self): - if h5py is None: - self.skipTest("h5py is missing") dialog = self.createDialog() dialog.show() self.qWaitForWindowExposed(dialog) @@ -345,8 +326,6 @@ class TestImageFileDialogInteraction(testutils.TestCaseQt, _UtilsMixin): self.allowedLeakingWidgets = 1 def testClickOnHistoryTools(self): - if h5py is None: - self.skipTest("h5py is missing") dialog = self.createDialog() dialog.show() self.qWaitForWindowExposed(dialog) @@ -386,8 +365,6 @@ class TestImageFileDialogInteraction(testutils.TestCaseQt, _UtilsMixin): self.assertSamePath(url.text(), path3) def testSelectImageFromEdf(self): - if fabio is None: - self.skipTest("fabio is missing") dialog = self.createDialog() dialog.show() self.qWaitForWindowExposed(dialog) @@ -402,8 +379,6 @@ class TestImageFileDialogInteraction(testutils.TestCaseQt, _UtilsMixin): self.assertSamePath(dialog.selectedUrl(), path) def testSelectImageFromEdf_Activate(self): - if fabio is None: - self.skipTest("fabio is missing") dialog = self.createDialog() dialog.show() self.qWaitForWindowExposed(dialog) @@ -426,8 +401,6 @@ class TestImageFileDialogInteraction(testutils.TestCaseQt, _UtilsMixin): self.assertSamePath(dialog.selectedUrl(), path) def testSelectFrameFromEdf(self): - if fabio is None: - self.skipTest("fabio is missing") dialog = self.createDialog() dialog.show() self.qWaitForWindowExposed(dialog) @@ -444,8 +417,6 @@ class TestImageFileDialogInteraction(testutils.TestCaseQt, _UtilsMixin): self.assertSamePath(dialog.selectedUrl(), path) def testSelectImageFromMsk(self): - if fabio is None: - self.skipTest("fabio is missing") dialog = self.createDialog() dialog.show() self.qWaitForWindowExposed(dialog) @@ -460,8 +431,6 @@ class TestImageFileDialogInteraction(testutils.TestCaseQt, _UtilsMixin): self.assertSamePath(dialog.selectedUrl(), path) def testSelectImageFromH5(self): - if h5py is None: - self.skipTest("h5py is missing") dialog = self.createDialog() dialog.show() self.qWaitForWindowExposed(dialog) @@ -476,8 +445,6 @@ class TestImageFileDialogInteraction(testutils.TestCaseQt, _UtilsMixin): self.assertSamePath(dialog.selectedUrl(), path) def testSelectH5_Activate(self): - if h5py is None: - self.skipTest("h5py is missing") dialog = self.createDialog() dialog.show() self.qWaitForWindowExposed(dialog) @@ -498,8 +465,6 @@ class TestImageFileDialogInteraction(testutils.TestCaseQt, _UtilsMixin): self.assertSamePath(dialog.selectedUrl(), path) def testSelectFrameFromH5(self): - if h5py is None: - self.skipTest("h5py is missing") dialog = self.createDialog() dialog.show() self.qWaitForWindowExposed(dialog) @@ -541,10 +506,6 @@ class TestImageFileDialogInteraction(testutils.TestCaseQt, _UtilsMixin): return selectable def testFilterExtensions(self): - if h5py is None: - self.skipTest("h5py is missing") - if fabio is None: - self.skipTest("fabio is missing") dialog = self.createDialog() browser = testutils.findChildren(dialog, qt.QWidget, name="browser")[0] filters = testutils.findChildren(dialog, qt.QWidget, name="fileTypeCombo")[0] @@ -745,16 +706,12 @@ class TestImageFileDialogApi(testutils.TestCaseQt, _UtilsMixin): self.assertSamePath(dialog.directory(), _tmpDirectory) def testBadDataType(self): - if h5py is None: - self.skipTest("h5py is missing") dialog = self.createDialog() dialog.selectUrl(_tmpDirectory + "/data.h5::/complex_image") self.qWaitForPendingActions(dialog) self.assertIsNone(dialog._selectedData()) def testBadDataShape(self): - if h5py is None: - self.skipTest("h5py is missing") dialog = self.createDialog() dialog.selectUrl(_tmpDirectory + "/data.h5::/unknown") self.qWaitForPendingActions(dialog) @@ -773,8 +730,6 @@ class TestImageFileDialogApi(testutils.TestCaseQt, _UtilsMixin): self.assertIsNone(dialog._selectedData()) def testBadSubpath(self): - if h5py is None: - self.skipTest("h5py is missing") dialog = self.createDialog() self.qWaitForPendingActions(dialog) @@ -794,8 +749,6 @@ class TestImageFileDialogApi(testutils.TestCaseQt, _UtilsMixin): self.assertEqual(url.data_path(), "/group") def testBadSlicingPath(self): - if h5py is None: - self.skipTest("h5py is missing") dialog = self.createDialog() self.qWaitForPendingActions(dialog) dialog.selectUrl(_tmpDirectory + "/data.h5::/cube[a;45,-90]") diff --git a/silx/gui/dialog/utils.py b/silx/gui/dialog/utils.py index 1c16b44..e2334f9 100644 --- a/silx/gui/dialog/utils.py +++ b/silx/gui/dialog/utils.py @@ -1,7 +1,7 @@ # coding: utf-8 # /*########################################################################## # -# Copyright (c) 2016 European Synchrotron Radiation Facility +# Copyright (c) 2016-2018 European Synchrotron Radiation Facility # # Permission is hereby granted, free of charge, to any person obtaining a copy # of this software and associated documentation files (the "Software"), to deal @@ -33,8 +33,10 @@ __date__ = "25/10/2017" import os import sys import types + +import six + from silx.gui import qt -from silx.third_party import six def samefile(path1, path2): diff --git a/silx/gui/hdf5/Hdf5Formatter.py b/silx/gui/hdf5/Hdf5Formatter.py index 6802142..5754fe8 100644 --- a/silx/gui/hdf5/Hdf5Formatter.py +++ b/silx/gui/hdf5/Hdf5Formatter.py @@ -1,7 +1,7 @@ # coding: utf-8 # /*########################################################################## # -# Copyright (c) 2017 European Synchrotron Radiation Facility +# Copyright (c) 2017-2018 European Synchrotron Radiation Facility # # Permission is hereby granted, free of charge, to any person obtaining a copy # of this software and associated documentation files (the "Software"), to deal @@ -30,14 +30,12 @@ __license__ = "MIT" __date__ = "06/06/2018" import numpy -from silx.third_party import six +import six + from silx.gui import qt from silx.gui.data.TextFormatter import TextFormatter -try: - import h5py -except ImportError: - h5py = None +import h5py class Hdf5Formatter(qt.QObject): @@ -162,10 +160,9 @@ class Hdf5Formatter(qt.QObject): compound = [self.humanReadableDType(d) for d in compound] return "compound(%s)" % ", ".join(compound) elif numpy.issubdtype(dtype, numpy.integer): - if h5py is not None: - enumType = h5py.check_dtype(enum=dtype) - if enumType is not None: - return "enum" + enumType = h5py.check_dtype(enum=dtype) + if enumType is not None: + return "enum" text = str(dtype.newbyteorder('N')) if numpy.issubdtype(dtype, numpy.floating): diff --git a/silx/gui/hdf5/Hdf5Item.py b/silx/gui/hdf5/Hdf5Item.py index b3c313e..6ea870f 100644 --- a/silx/gui/hdf5/Hdf5Item.py +++ b/silx/gui/hdf5/Hdf5Item.py @@ -25,11 +25,12 @@ __authors__ = ["V. Valls"] __license__ = "MIT" -__date__ = "03/09/2018" +__date__ = "17/01/2019" import logging import collections + from .. import qt from .. import icons from . import _utils @@ -37,7 +38,6 @@ from .Hdf5Node import Hdf5Node import silx.io.utils from silx.gui.data.TextFormatter import TextFormatter from ..hdf5.Hdf5Formatter import Hdf5Formatter -from ...third_party import six _logger = logging.getLogger(__name__) _formatter = TextFormatter() _hdf5Formatter = Hdf5Formatter(textFormatter=_formatter) @@ -217,14 +217,32 @@ class Hdf5Item(Hdf5Node): def _populateChild(self, populateAll=False): if self.isGroupObj(): - for name in self.obj: + keys = [] + try: + for name in self.obj: + keys.append(name) + except Exception: + lib_name = self.obj.__class__.__module__.split(".")[0] + _logger.error("Internal %s error. The file is corrupted.", lib_name) + _logger.debug("Backtrace", exc_info=True) + if keys == []: + # If the file was open in READ_ONLY we still can reach something + # https://github.com/silx-kit/silx/issues/2262 + try: + for name in self.obj: + keys.append(name) + except Exception: + lib_name = self.obj.__class__.__module__.split(".")[0] + _logger.error("Internal %s error (second time). The file is corrupted.", lib_name) + _logger.debug("Backtrace", exc_info=True) + for name in keys: try: class_ = self.obj.get(name, getclass=True) link = self.obj.get(name, getclass=True, getlink=True) link = silx.io.utils.get_h5_class(class_=link) except Exception: lib_name = self.obj.__class__.__module__.split(".")[0] - _logger.warning("Internal %s error", lib_name, exc_info=True) + _logger.error("Internal %s error", lib_name) _logger.debug("Backtrace", exc_info=True) class_ = None try: @@ -344,14 +362,12 @@ class Hdf5Item(Hdf5Node): def nexusClassName(self): """Returns the Nexus class name""" if self.__nx_class is None: - self.__nx_class = self.obj.attrs.get("NX_class", None) - if self.__nx_class is None: - self.__nx_class = "" + obj = self.obj.attrs.get("NX_class", None) + if obj is None: + text = "" else: - if six.PY2: - self.__nx_class = self.__nx_class.decode() - elif not isinstance(self.__nx_class, str): - self.__nx_class = str(self.__nx_class, "UTF-8") + text = self._getFormatter().textFormatter().toString(obj) + self.__nx_class = text.strip('"') return self.__nx_class def dataName(self, role): diff --git a/silx/gui/hdf5/Hdf5TreeModel.py b/silx/gui/hdf5/Hdf5TreeModel.py index 438200b..152f3e5 100644 --- a/silx/gui/hdf5/Hdf5TreeModel.py +++ b/silx/gui/hdf5/Hdf5TreeModel.py @@ -25,7 +25,7 @@ __authors__ = ["V. Valls"] __license__ = "MIT" -__date__ = "08/10/2018" +__date__ = "12/03/2019" import os @@ -360,9 +360,7 @@ class Hdf5TreeModel(qt.QAbstractItemModel): def mimeTypes(self): types = [] - if self.__fileMoveEnabled: - types.append(_utils.Hdf5NodeMimeData.MIME_TYPE) - if self.__datasetDragEnabled: + if self.__fileMoveEnabled or self.__datasetDragEnabled: types.append(_utils.Hdf5DatasetMimeData.MIME_TYPE) return types @@ -386,7 +384,7 @@ class Hdf5TreeModel(qt.QAbstractItemModel): node = self.nodeFromIndex(indexes[0]) if self.__fileMoveEnabled and node.parent is self.__root: - mimeData = _utils.Hdf5NodeMimeData(node=node) + mimeData = _utils.Hdf5DatasetMimeData(node=node, isRoot=True) elif self.__datasetDragEnabled: mimeData = _utils.Hdf5DatasetMimeData(node=node) else: @@ -413,23 +411,24 @@ class Hdf5TreeModel(qt.QAbstractItemModel): if action == qt.Qt.IgnoreAction: return True - if self.__fileMoveEnabled and mimedata.hasFormat(_utils.Hdf5NodeMimeData.MIME_TYPE): - dragNode = mimedata.node() - parentNode = self.nodeFromIndex(parentIndex) - if parentNode is not dragNode.parent: - return False + if self.__fileMoveEnabled and mimedata.hasFormat(_utils.Hdf5DatasetMimeData.MIME_TYPE): + if mimedata.isRoot(): + dragNode = mimedata.node() + parentNode = self.nodeFromIndex(parentIndex) + if parentNode is not dragNode.parent: + return False - if row == -1: - # append to the parent - row = parentNode.childCount() - else: - # insert at row - pass + if row == -1: + # append to the parent + row = parentNode.childCount() + else: + # insert at row + pass - dragNodeParent = dragNode.parent - sourceRow = dragNodeParent.indexOfChild(dragNode) - self.moveRow(parentIndex, sourceRow, parentIndex, row) - return True + dragNodeParent = dragNode.parent + sourceRow = dragNodeParent.indexOfChild(dragNode) + self.moveRow(parentIndex, sourceRow, parentIndex, row) + return True if self.__fileDropEnabled and mimedata.hasFormat("text/uri-list"): @@ -571,7 +570,7 @@ class Hdf5TreeModel(qt.QAbstractItemModel): drag-and-drop""" obj = node.obj for f in self.__openedFiles: - if f in obj: + if f is obj: _logger.debug("Close file %s", obj.filename) obj.close() self.__openedFiles.remove(obj) diff --git a/silx/gui/hdf5/NexusSortFilterProxyModel.py b/silx/gui/hdf5/NexusSortFilterProxyModel.py index 216e992..9c3533f 100644 --- a/silx/gui/hdf5/NexusSortFilterProxyModel.py +++ b/silx/gui/hdf5/NexusSortFilterProxyModel.py @@ -25,7 +25,7 @@ __authors__ = ["V. Valls"] __license__ = "MIT" -__date__ = "24/07/2018" +__date__ = "29/11/2018" import logging @@ -108,6 +108,8 @@ class NexusSortFilterProxyModel(qt.QSortFilterProxyModel): def __isNXnode(self, node): """Returns true if the node is an NX concept""" + if not hasattr(node, "h5Class"): + return False class_ = node.h5Class if class_ is None or class_ != silx.io.utils.H5Type.GROUP: return False diff --git a/silx/gui/hdf5/_utils.py b/silx/gui/hdf5/_utils.py index 6a34933..aaab228 100644 --- a/silx/gui/hdf5/_utils.py +++ b/silx/gui/hdf5/_utils.py @@ -28,12 +28,15 @@ package `silx.gui.hdf5` package. __authors__ = ["V. Valls"] __license__ = "MIT" -__date__ = "04/05/2018" +__date__ = "17/01/2019" import logging -from .. import qt +import os.path + import silx.io.utils +import silx.io.url +from .. import qt from silx.utils.html import escape _logger = logging.getLogger(__name__) @@ -107,11 +110,22 @@ class Hdf5DatasetMimeData(qt.QMimeData): MIME_TYPE = "application/x-internal-h5py-dataset" - def __init__(self, node=None, dataset=None): + SILX_URI_TYPE = "application/x-silx-uri" + + def __init__(self, node=None, dataset=None, isRoot=False): qt.QMimeData.__init__(self) self.__dataset = dataset self.__node = node + self.__isRoot = isRoot self.setData(self.MIME_TYPE, "".encode(encoding='utf-8')) + if node is not None: + h5Node = H5Node(node) + silxUrl = h5Node.url + self.setText(silxUrl) + self.setData(self.SILX_URI_TYPE, silxUrl.encode(encoding='utf-8')) + + def isRoot(self): + return self.__isRoot def node(self): return self.__node @@ -122,20 +136,6 @@ class Hdf5DatasetMimeData(qt.QMimeData): return self.__dataset -class Hdf5NodeMimeData(qt.QMimeData): - """Mimedata class to identify an internal drag and drop of a Hdf5Node.""" - - MIME_TYPE = "application/x-internal-h5py-node" - - def __init__(self, node=None): - qt.QMimeData.__init__(self) - self.__node = node - self.setData(self.MIME_TYPE, "".encode(encoding='utf-8')) - - def node(self): - return self.__node - - class H5Node(object): """Adapter over an h5py object to provide missing informations from h5py nodes, like internal node path and filename (which are not provided by @@ -419,3 +419,43 @@ class H5Node(object): :rtype: str """ return self.physical_name.split("/")[-1] + + @property + def data_url(self): + """Returns a :class:`silx.io.url.DataUrl` object identify this node in the file + system. + + :rtype: ~silx.io.url.DataUrl + """ + absolute_filename = os.path.abspath(self.local_filename) + return silx.io.url.DataUrl(scheme="silx", + file_path=absolute_filename, + data_path=self.local_name) + + @property + def url(self): + """Returns an URL object identifying this node in the file + system. + + This URL can be used in different ways. + + .. code-block:: python + + # Parsing the URL + import silx.io.url + dataurl = silx.io.url.DataUrl(item.url) + # dataurl provides access to URL fields + + # Open a numpy array + import silx.io + dataset = silx.io.get_data(item.url) + + # Open an hdf5 object (URL targetting a file or a group) + import silx.io + with silx.io.open(item.url) as h5: + ...your stuff... + + :rtype: str + """ + data_url = self.data_url + return data_url.path() diff --git a/silx/gui/hdf5/test/test_hdf5.py b/silx/gui/hdf5/test/test_hdf5.py index 1751a21..f22d4ae 100644 --- a/silx/gui/hdf5/test/test_hdf5.py +++ b/silx/gui/hdf5/test/test_hdf5.py @@ -26,7 +26,7 @@ __authors__ = ["V. Valls"] __license__ = "MIT" -__date__ = "03/05/2018" +__date__ = "12/03/2019" import time @@ -43,10 +43,7 @@ from silx.gui.utils.testutils import SignalListener from silx.io import commonh5 import weakref -try: - import h5py -except ImportError: - h5py = None +import h5py _tmpDirectory = None @@ -56,14 +53,13 @@ def setUpModule(): global _tmpDirectory _tmpDirectory = tempfile.mkdtemp(prefix=__name__) - if h5py is not None: - filename = _tmpDirectory + "/data.h5" + filename = _tmpDirectory + "/data.h5" - # create h5 data - f = h5py.File(filename, "w") - g = f.create_group("arrays") - g.create_dataset("scalar", data=10) - f.close() + # create h5 data + f = h5py.File(filename, "w") + g = f.create_group("arrays") + g.create_dataset("scalar", data=10) + f.close() def tearDownModule(): @@ -91,8 +87,6 @@ class TestHdf5TreeModel(TestCaseQt): def setUp(self): super(TestHdf5TreeModel, self).setUp() - if h5py is None: - self.skipTest("h5py is not available") def waitForPendingOperations(self, model): for _ in range(10): @@ -127,8 +121,6 @@ class TestHdf5TreeModel(TestCaseQt): model.appendFile(filename) self.assertEqual(model.rowCount(qt.QModelIndex()), 1) # clean up - index = model.index(0, 0, qt.QModelIndex()) - h5File = model.data(index, hdf5.Hdf5TreeModel.H5PY_OBJECT_ROLE) ref = weakref.ref(model) model = None self.qWaitForDestroy(ref) @@ -246,6 +238,37 @@ class TestHdf5TreeModel(TestCaseQt): model.setFileDropEnabled(False) self.assertNotEquals(model.supportedDropActions(), 0) + def testCloseFile(self): + """A file inserted as a filename is open and closed internally.""" + filename = _tmpDirectory + "/data.h5" + model = hdf5.Hdf5TreeModel() + self.assertEqual(model.rowCount(qt.QModelIndex()), 0) + model.insertFile(filename) + self.assertEqual(model.rowCount(qt.QModelIndex()), 1) + index = model.index(0, 0) + h5File = model.data(index, role=hdf5.Hdf5TreeModel.H5PY_OBJECT_ROLE) + model.removeIndex(index) + self.assertEqual(model.rowCount(qt.QModelIndex()), 0) + self.assertFalse(bool(h5File.id.valid), "The HDF5 file was not closed") + + def testNotCloseFile(self): + """A file inserted as an h5py object is not open (then not closed) + internally.""" + filename = _tmpDirectory + "/data.h5" + try: + h5File = h5py.File(filename) + model = hdf5.Hdf5TreeModel() + self.assertEqual(model.rowCount(qt.QModelIndex()), 0) + model.insertH5pyObject(h5File) + self.assertEqual(model.rowCount(qt.QModelIndex()), 1) + index = model.index(0, 0) + h5File = model.data(index, role=hdf5.Hdf5TreeModel.H5PY_OBJECT_ROLE) + model.removeIndex(index) + self.assertEqual(model.rowCount(qt.QModelIndex()), 0) + self.assertTrue(bool(h5File.id.valid), "The HDF5 file was unexpetedly closed") + finally: + h5File.close() + def testDropExternalFile(self): filename = _tmpDirectory + "/data.h5" model = hdf5.Hdf5TreeModel() @@ -571,8 +594,6 @@ class TestH5Node(TestCaseQt): @classmethod def setUpClass(cls): super(TestH5Node, cls).setUpClass() - if h5py is None: - raise unittest.SkipTest("h5py is not available") cls.tmpDirectory = tempfile.mkdtemp() cls.h5Filename = cls.createResource(cls.tmpDirectory) @@ -809,8 +830,6 @@ class TestHdf5TreeView(TestCaseQt): def setUp(self): super(TestHdf5TreeView, self).setUp() - if h5py is None: - self.skipTest("h5py is not available") def testCreate(self): view = hdf5.Hdf5TreeView() diff --git a/silx/gui/icons.py b/silx/gui/icons.py index ef99591..1493b92 100644 --- a/silx/gui/icons.py +++ b/silx/gui/icons.py @@ -29,7 +29,7 @@ Use :func:`getQIcon` to create Qt QIcon from the name identifying an icon. __authors__ = ["T. Vincent"] __license__ = "MIT" -__date__ = "05/10/2018" +__date__ = "07/01/2019" import os @@ -213,11 +213,12 @@ class MultiImageAnimatedIcon(AbstractAnimatedIcon): self.__frames = [] for i in range(100): try: - filename = getQFile("%s/%02d" % (filename, i)) + frame_filename = os.sep.join((filename, ("%02d" %i))) + frame_file = getQFile(frame_filename) except ValueError: break try: - icon = qt.QIcon(filename.fileName()) + icon = qt.QIcon(frame_file.fileName()) except ValueError: break self.__frames.append(icon) @@ -420,4 +421,5 @@ def getQFile(name): qfile = qt.QFile(filename) if qfile.exists(): return qfile + _logger.debug("File '%s' not found.", filename) raise ValueError('Not an icon name: %s' % name) diff --git a/silx/gui/plot/ColorBar.py b/silx/gui/plot/ColorBar.py index fd4d34e..9798123 100644 --- a/silx/gui/plot/ColorBar.py +++ b/silx/gui/plot/ColorBar.py @@ -1,7 +1,7 @@ # coding: utf-8 # /*########################################################################## # -# Copyright (c) 2016-2018 European Synchrotron Radiation Facility +# Copyright (c) 2016-2019 European Synchrotron Radiation Facility # # Permission is hereby granted, free of charge, to any person obtaining a copy # of this software and associated documentation files (the "Software"), to deal @@ -251,10 +251,13 @@ class ColorBarWidget(qt.QWidget): def _defaultColormapChanged(self, event): """Handle plot default colormap changed""" - if (event['event'] == 'defaultColormapChanged' and - self.getPlot().getActiveImage() is None): - # No active image, take default colormap update into account - self._syncWithDefaultColormap() + if event['event'] == 'defaultColormapChanged': + plot = self.getPlot() + if (plot is not None and + plot.getActiveImage() is None and + plot._getActiveItem(kind='scatter') is None): + # No active item, take default colormap update into account + self._syncWithDefaultColormap() def _syncWithDefaultColormap(self, data=None): """Update colorbar according to plot default colormap""" @@ -801,7 +804,7 @@ class _TickBar(qt.QWidget): if self._norm == colors.Colormap.LINEAR: return 1 - (val - self._vmin) / (self._vmax - self._vmin) elif self._norm == colors.Colormap.LOGARITHM: - return 1 - (numpy.log10(val) - numpy.log10(self._vmin)) / (numpy.log10(self._vmax) - numpy.log(self._vmin)) + return 1 - (numpy.log10(val) - numpy.log10(self._vmin)) / (numpy.log10(self._vmax) - numpy.log10(self._vmin)) else: raise ValueError('Norm is not recognized') @@ -864,7 +867,7 @@ class _TickBar(qt.QWidget): def _guessType(self, font): """Try fo find the better format to display the tick's labels - :param QFont font: the font we want want to use durint the painting + :param QFont font: the font we want to use during the painting """ form = self._getStandardFormat() @@ -873,7 +876,7 @@ class _TickBar(qt.QWidget): for tick in self.ticks: width = max(fm.width(form.format(tick)), width) - # if the length of the string are too long we are mooving to scientific + # if the length of the string are too long we are moving to scientific # display if width > _TickBar._WIDTH_DISP_VAL - _TickBar._LINE_WIDTH: return self._getScientificForm() diff --git a/silx/gui/plot/CompareImages.py b/silx/gui/plot/CompareImages.py index 88b257d..f7c4899 100644 --- a/silx/gui/plot/CompareImages.py +++ b/silx/gui/plot/CompareImages.py @@ -30,6 +30,7 @@ __license__ = "MIT" __date__ = "23/07/2018" +import enum import logging import numpy import weakref @@ -42,7 +43,6 @@ from silx.gui import plot from silx.gui import icons from silx.gui.colors import Colormap from silx.gui.plot import tools -from silx.third_party import enum _logger = logging.getLogger(__name__) diff --git a/silx/gui/plot/ComplexImageView.py b/silx/gui/plot/ComplexImageView.py index bbcb0a5..2523cde 100644 --- a/silx/gui/plot/ComplexImageView.py +++ b/silx/gui/plot/ComplexImageView.py @@ -1,7 +1,7 @@ # coding: utf-8 # /*########################################################################## # -# Copyright (c) 2017-2018 European Synchrotron Radiation Facility +# Copyright (c) 2017-2019 European Synchrotron Radiation Facility # # Permission is hereby granted, free of charge, to any person obtaining a copy # of this software and associated documentation files (the "Software"), to deal @@ -365,7 +365,7 @@ class ComplexImageView(qt.QWidget): - log10_amplitude_phase: Color-coded phase with log10(amplitude) as alpha. - :rtype: tuple of str + :rtype: List[Mode] """ return tuple(ImageComplexData.Mode) @@ -375,7 +375,12 @@ class ComplexImageView(qt.QWidget): See :meth:`getSupportedVisualizationModes` for the list of supported modes. - :param str mode: The mode to use. + How-to change visualization mode:: + + widget = ComplexImageView() + widget.setVisualizationMode(ComplexImageView.Mode.PHASE) + + :param Mode mode: The mode to use. """ self._plotImage.setVisualizationMode(mode) diff --git a/silx/gui/plot/CurvesROIWidget.py b/silx/gui/plot/CurvesROIWidget.py index 81e684e..b426a23 100644 --- a/silx/gui/plot/CurvesROIWidget.py +++ b/silx/gui/plot/CurvesROIWidget.py @@ -1,7 +1,7 @@ # coding: utf-8 # /*########################################################################## # -# Copyright (c) 2004-2018 European Synchrotron Radiation Facility +# Copyright (c) 2004-2019 European Synchrotron Radiation Facility # # Permission is hereby granted, free of charge, to any person obtaining a copy # of this software and associated documentation files (the "Software"), to deal @@ -22,50 +22,43 @@ # THE SOFTWARE. # # ###########################################################################*/ -"""Widget to handle regions of interest (ROI) on curves displayed in a PlotWindow. +""" +Widget to handle regions of interest (:class:`ROI`) on curves displayed in a +:class:`PlotWindow`. This widget is meant to work with :class:`PlotWindow`. - -ROI are defined by : - -- A name (`ROI` column) -- A type. The type is the label of the x axis. - This can be used to apply or not some ROI to a curve and do some post processing. -- The x coordinate of the left limit (`from` column) -- The x coordinate of the right limit (`to` column) -- Raw counts: Sum of the curve's values in the defined Region Of Intereset. - - .. image:: img/rawCounts.png - -- Net counts: Raw counts minus background - - .. image:: img/netCounts.png """ -__authors__ = ["V.A. Sole", "T. Vincent"] +__authors__ = ["V.A. Sole", "T. Vincent", "H. Payno"] __license__ = "MIT" -__date__ = "13/11/2017" +__date__ = "13/03/2018" from collections import OrderedDict - import logging import os import sys -import weakref - +import functools import numpy - from silx.io import dictdump from silx.utils import deprecation - +from silx.utils.weakref import WeakMethodProxy from .. import icons, qt +from silx.gui.plot.items.curve import Curve +from silx.math.combo import min_max +import weakref +from silx.gui.widgets.TableWidget import TableWidget _logger = logging.getLogger(__name__) class CurvesROIWidget(qt.QWidget): - """Widget displaying a table of ROI information. + """ + Widget displaying a table of ROI information. + + Implements also the following behavior: + + * if the roiTable has no ROI when showing create the default ICR one :param parent: See :class:`QWidget` :param str name: The title of this widget @@ -73,19 +66,18 @@ class CurvesROIWidget(qt.QWidget): sigROIWidgetSignal = qt.Signal(object) """Signal of ROIs modifications. - - Modification information if given as a dict with an 'event' key - providing the type of events. - - Type of events: - - - AddROI, DelROI, LoadROI and ResetROI with keys: 'roilist', 'roidict' - - - selectionChanged with keys: 'row', 'col' 'roi', 'key', 'colheader', - 'rowheader' + Modification information if given as a dict with an 'event' key + providing the type of events. + Type of events: + - AddROI, DelROI, LoadROI and ResetROI with keys: 'roilist', 'roidict' + - selectionChanged with keys: 'row', 'col' 'roi', 'key', 'colheader', + 'rowheader' """ sigROISignal = qt.Signal(object) + """Deprecated signal for backward compatibility with silx < 0.7. + Prefer connecting directly to :attr:`CurvesRoiWidget.sigRoiSignal` + """ def __init__(self, parent=None, name=None, plot=None): super(CurvesROIWidget, self).__init__(parent) @@ -93,6 +85,8 @@ class CurvesROIWidget(qt.QWidget): self.setWindowTitle(name) assert plot is not None self._plotRef = weakref.ref(plot) + self._showAllMarkers = False + self.currentROI = None layout = qt.QVBoxLayout(self) layout.setContentsMargins(0, 0, 0, 0) @@ -103,13 +97,22 @@ class CurvesROIWidget(qt.QWidget): self.setHeader() layout.addWidget(self.headerLabel) ############## - self.roiTable = ROITable(self) + widgetAllCheckbox = qt.QWidget(parent=self) + self._showAllCheckBox = qt.QCheckBox("show all ROI", + parent=widgetAllCheckbox) + widgetAllCheckbox.setLayout(qt.QHBoxLayout()) + spacer = qt.QWidget(parent=widgetAllCheckbox) + spacer.setSizePolicy(qt.QSizePolicy.Expanding, qt.QSizePolicy.Fixed) + widgetAllCheckbox.layout().addWidget(spacer) + widgetAllCheckbox.layout().addWidget(self._showAllCheckBox) + layout.addWidget(widgetAllCheckbox) + ############## + self.roiTable = ROITable(self, plot=plot) rheight = self.roiTable.horizontalHeader().sizeHint().height() self.roiTable.setMinimumHeight(4 * rheight) - self.fillFromROIDict = self.roiTable.fillFromROIDict - self.getROIListAndDict = self.roiTable.getROIListAndDict layout.addWidget(self.roiTable) self._roiFileDir = qt.QDir.home().absolutePath() + self._showAllCheckBox.toggled.connect(self.roiTable.showAllMarkers) ################# hbox = qt.QWidget(self) @@ -127,7 +130,8 @@ class CurvesROIWidget(qt.QWidget): self.addButton.setToolTip('Remove the selected ROI') self.resetButton = qt.QPushButton(hbox) self.resetButton.setText("Reset") - self.addButton.setToolTip('Clear all created ROIs. We only let the default ROI') + self.addButton.setToolTip('Clear all created ROIs. We only let the ' + 'default ROI') hboxlayout.addWidget(self.addButton) hboxlayout.addWidget(self.delButton) @@ -149,19 +153,22 @@ class CurvesROIWidget(qt.QWidget): layout.addWidget(hbox) + # Signal / Slot connections self.addButton.clicked.connect(self._add) self.delButton.clicked.connect(self._del) self.resetButton.clicked.connect(self._reset) self.loadButton.clicked.connect(self._load) self.saveButton.clicked.connect(self._save) - self.roiTable.sigROITableSignal.connect(self._forward) - self.currentROI = None - self._middleROIMarkerFlag = False + self.roiTable.activeROIChanged.connect(self._emitCurrentROISignal) + self._isConnected = False # True if connected to plot signals self._isInit = False + # expose API + self.getROIListAndDict = self.roiTable.getROIListAndDict + def getPlotWidget(self): """Returns the associated PlotWidget or None @@ -173,10 +180,6 @@ class CurvesROIWidget(qt.QWidget): self._visibilityChangedHandler(visible=True) qt.QWidget.showEvent(self, event) - def hideEvent(self, event): - self._visibilityChangedHandler(visible=False) - qt.QWidget.hideEvent(self, event) - @property def roiFileDir(self): """The directory from which to load/save ROI from/to files.""" @@ -188,135 +191,81 @@ class CurvesROIWidget(qt.QWidget): def roiFileDir(self, roiFileDir): self._roiFileDir = str(roiFileDir) - def setRois(self, roidict, order=None): - """Set the ROIs by providing a dictionary of ROI information. - - The dictionary keys are the ROI names. - Each value is a sub-dictionary of ROI info with the following fields: - - - ``"from"``: x coordinate of the left limit, as a float - - ``"to"``: x coordinate of the right limit, as a float - - ``"type"``: type of ROI, as a string (e.g "channels", "energy") - - - :param roidict: Dictionary of ROIs - :param str order: Field used for ordering the ROIs. - One of "from", "to", "type". - None (default) for no ordering, or same order as specified - in parameter ``roidict`` if provided as an OrderedDict. - """ - if order is None or order.lower() == "none": - roilist = list(roidict.keys()) - else: - assert order in ["from", "to", "type"] - roilist = sorted(roidict.keys(), - key=lambda roi_name: roidict[roi_name].get(order)) - - return self.roiTable.fillFromROIDict(roilist, roidict) + def setRois(self, rois, order=None): + return self.roiTable.setRois(rois, order) def getRois(self, order=None): - """Return the currently defined ROIs, as an ordered dict. + return self.roiTable.getRois(order) - The dictionary keys are the ROI names. - Each value is a sub-dictionary of ROI info with the following fields: + def setMiddleROIMarkerFlag(self, flag=True): + return self.roiTable.setMiddleROIMarkerFlag(flag) - - ``"from"``: x coordinate of the left limit, as a float - - ``"to"``: x coordinate of the right limit, as a float - - ``"type"``: type of ROI, as a string (e.g "channels", "energy") + def _add(self): + """Add button clicked handler""" + def getNextRoiName(): + rois = self.roiTable.getRois(order=None) + roisNames = [] + [roisNames.append(roiName) for roiName in rois] + nrois = len(rois) + if nrois == 0: + return "ICR" + else: + i = 1 + newroi = "newroi %d" % i + while newroi in roisNames: + i += 1 + newroi = "newroi %d" % i + return newroi + roi = ROI(name=getNextRoiName()) - :param order: Field used for ordering the ROIs. - One of "from", "to", "type", "netcounts", "rawcounts". - None (default) to get the same order as displayed in the widget. - :return: Ordered dictionary of ROI information - """ - roilist, roidict = self.roiTable.getROIListAndDict() - if order is None or order.lower() == "none": - ordered_roilist = roilist + if roi.getName() == "ICR": + roi.setType("Default") else: - assert order in ["from", "to", "type", "netcounts", "rawcounts"] - ordered_roilist = sorted(roidict.keys(), - key=lambda roi_name: roidict[roi_name].get(order)) - - return OrderedDict([(name, roidict[name]) for name in ordered_roilist]) + roi.setType(self.getPlotWidget().getXAxis().getLabel()) - def setMiddleROIMarkerFlag(self, flag=True): - """Activate or deactivate middle marker. + xmin, xmax = self.getPlotWidget().getXAxis().getLimits() + fromdata = xmin + 0.25 * (xmax - xmin) + todata = xmin + 0.75 * (xmax - xmin) + if roi.isICR(): + fromdata, dummy0, todata, dummy1 = self._getAllLimits() + roi.setFrom(fromdata) + roi.setTo(todata) - This allows shifting both min and max limits at once, by dragging - a marker located in the middle. - - :param bool flag: True to activate middle ROI marker - """ - if flag: - self._middleROIMarkerFlag = True - else: - self._middleROIMarkerFlag = False + self.roiTable.addRoi(roi) - def _add(self): - """Add button clicked handler""" + # back compatibility pymca roi signals ddict = {} ddict['event'] = "AddROI" - roilist, roidict = self.roiTable.getROIListAndDict() - ddict['roilist'] = roilist - ddict['roidict'] = roidict + ddict['roilist'] = self.roiTable.roidict.values() + ddict['roidict'] = self.roiTable.roidict self.sigROIWidgetSignal.emit(ddict) + # end back compatibility pymca roi signals def _del(self): """Delete button clicked handler""" - row = self.roiTable.currentRow() - if row >= 0: - index = self.roiTable.labels.index('Type') - text = str(self.roiTable.item(row, index).text()) - if text.upper() != 'DEFAULT': - index = self.roiTable.labels.index('ROI') - key = str(self.roiTable.item(row, index).text()) - else: - # This is to prevent deleting ICR ROI, that is - # usually initialized as "Default" type. - return - roilist, roidict = self.roiTable.getROIListAndDict() - row = roilist.index(key) - del roilist[row] - del roidict[key] - if len(roilist) > 0: - currentroi = roilist[0] - else: - currentroi = None - - self.roiTable.fillFromROIDict(roilist=roilist, - roidict=roidict, - currentroi=currentroi) - ddict = {} - ddict['event'] = "DelROI" - ddict['roilist'] = roilist - ddict['roidict'] = roidict - self.sigROIWidgetSignal.emit(ddict) - - def _forward(self, ddict): - """Broadcast events from ROITable signal""" + self.roiTable.deleteActiveRoi() + + # back compatibility pymca roi signals + ddict = {} + ddict['event'] = "DelROI" + ddict['roilist'] = self.roiTable.roidict.values() + ddict['roidict'] = self.roiTable.roidict self.sigROIWidgetSignal.emit(ddict) + # end back compatibility pymca roi signals def _reset(self): """Reset button clicked handler""" + self.roiTable.clear() + self._add() + + # back compatibility pymca roi signals ddict = {} ddict['event'] = "ResetROI" - roilist0, roidict0 = self.roiTable.getROIListAndDict() - index = 0 - for key in roilist0: - if roidict0[key]['type'].upper() == 'DEFAULT': - index = roilist0.index(key) - break - roilist = [] - roidict = {} - if len(roilist0): - roilist.append(roilist0[index]) - roidict[roilist[0]] = {} - roidict[roilist[0]].update(roidict0[roilist[0]]) - self.roiTable.fillFromROIDict(roilist=roilist, roidict=roidict) - ddict['roilist'] = roilist - ddict['roidict'] = roidict + ddict['roilist'] = self.roiTable.roidict.values() + ddict['roidict'] = self.roiTable.roidict self.sigROIWidgetSignal.emit(ddict) + # end back compatibility pymca roi signals def _load(self): """Load button clicked handler""" @@ -334,32 +283,22 @@ class CurvesROIWidget(qt.QWidget): dialog.close() self.roiFileDir = os.path.dirname(outputFile) - self.load(outputFile) + self.roiTable.load(outputFile) + + # back compatibility pymca roi signals + ddict = {} + ddict['event'] = "LoadROI" + ddict['roilist'] = self.roiTable.roidict.values() + ddict['roidict'] = self.roiTable.roidict + self.sigROIWidgetSignal.emit(ddict) + # end back compatibility pymca roi signals def load(self, filename): """Load ROI widget information from a file storing a dict of ROI. :param str filename: The file from which to load ROI """ - rois = dictdump.load(filename) - currentROI = None - if self.roiTable.rowCount(): - item = self.roiTable.item(self.roiTable.currentRow(), 0) - if item is not None: - currentROI = str(item.text()) - - # Remove rawcounts and netcounts from ROIs - for roi in rois['ROI']['roidict'].values(): - roi.pop('rawcounts', None) - roi.pop('netcounts', None) - - self.roiTable.fillFromROIDict(roilist=rois['ROI']['roilist'], - roidict=rois['ROI']['roidict'], - currentroi=currentROI) - - roilist, roidict = self.roiTable.getROIListAndDict() - event = {'event': 'LoadROI', 'roilist': roilist, 'roidict': roidict} - self.sigROIWidgetSignal.emit(event) + self.roiTable.load(filename) def _save(self): """Save button clicked handler""" @@ -396,142 +335,24 @@ class CurvesROIWidget(qt.QWidget): :param str filename: The file to which to save the ROIs """ - roilist, roidict = self.roiTable.getROIListAndDict() - datadict = {'ROI': {'roilist': roilist, 'roidict': roidict}} - dictdump.dump(datadict, filename) + self.roiTable.save(filename) def setHeader(self, text='ROIs'): """Set the header text of this widget""" self.headerLabel.setText("%s<\b>" % text) - def _roiSignal(self, ddict): - """Handle ROI widget signal""" - _logger.debug("CurvesROIWidget._roiSignal %s", str(ddict)) - plot = self.getPlotWidget() - if plot is None: - return - - if ddict['event'] == "AddROI": - xmin, xmax = plot.getXAxis().getLimits() - fromdata = xmin + 0.25 * (xmax - xmin) - todata = xmin + 0.75 * (xmax - xmin) - plot.remove('ROI min', kind='marker') - plot.remove('ROI max', kind='marker') - if self._middleROIMarkerFlag: - plot.remove('ROI middle', kind='marker') - roiList, roiDict = self.roiTable.getROIListAndDict() - nrois = len(roiList) - if nrois == 0: - newroi = "ICR" - fromdata, dummy0, todata, dummy1 = self._getAllLimits() - draggable = False - color = 'black' - else: - # find the next index free for newroi. - for i in range(nrois): - i += 1 - newroi = "newroi %d" % i - if newroi not in roiList: - break - color = 'blue' - draggable = True - plot.addXMarker(fromdata, - legend='ROI min', - text='ROI min', - color=color, - draggable=draggable) - plot.addXMarker(todata, - legend='ROI max', - text='ROI max', - color=color, - draggable=draggable) - if draggable and self._middleROIMarkerFlag: - pos = 0.5 * (fromdata + todata) - plot.addXMarker(pos, - legend='ROI middle', - text="", - color='yellow', - draggable=draggable) - roiList.append(newroi) - roiDict[newroi] = {} - if newroi == "ICR": - roiDict[newroi]['type'] = "Default" - else: - roiDict[newroi]['type'] = plot.getXAxis().getLabel() - roiDict[newroi]['from'] = fromdata - roiDict[newroi]['to'] = todata - self.roiTable.fillFromROIDict(roilist=roiList, - roidict=roiDict, - currentroi=newroi) - self.currentROI = newroi - self.calculateRois() - elif ddict['event'] in ['DelROI', "ResetROI"]: - plot.remove('ROI min', kind='marker') - plot.remove('ROI max', kind='marker') - if self._middleROIMarkerFlag: - plot.remove('ROI middle', kind='marker') - roiList, roiDict = self.roiTable.getROIListAndDict() - roiDictKeys = list(roiDict.keys()) - if len(roiDictKeys): - currentroi = roiDictKeys[0] - else: - # create again the ICR - ddict = {"event": "AddROI"} - return self._roiSignal(ddict) - - self.roiTable.fillFromROIDict(roilist=roiList, - roidict=roiDict, - currentroi=currentroi) - self.currentROI = currentroi - - elif ddict['event'] == 'LoadROI': - self.calculateRois() + @deprecation.deprecated(replacement="calculateRois", + reason="CamelCase convention", + since_version="0.7") + def calculateROIs(self, *args, **kw): + self.calculateRois(*args, **kw) - elif ddict['event'] == 'selectionChanged': - _logger.debug("Selection changed") - self.roilist, self.roidict = self.roiTable.getROIListAndDict() - fromdata = ddict['roi']['from'] - todata = ddict['roi']['to'] - plot.remove('ROI min', kind='marker') - plot.remove('ROI max', kind='marker') - if ddict['key'] == 'ICR': - draggable = False - color = 'black' - else: - draggable = True - color = 'blue' - plot.addXMarker(fromdata, - legend='ROI min', - text='ROI min', - color=color, - draggable=draggable) - plot.addXMarker(todata, - legend='ROI max', - text='ROI max', - color=color, - draggable=draggable) - if draggable and self._middleROIMarkerFlag: - pos = 0.5 * (fromdata + todata) - plot.addXMarker(pos, - legend='ROI middle', - text="", - color='yellow', - draggable=True) - self.currentROI = ddict['key'] - if ddict['colheader'] in ['From', 'To']: - dict0 = {} - dict0['event'] = "SetActiveCurveEvent" - dict0['legend'] = plot.getActiveCurve(just_legend=1) - plot.setActiveCurve(dict0['legend']) - elif ddict['colheader'] == 'Raw Counts': - pass - elif ddict['colheader'] == 'Net Counts': - pass - else: - self._emitCurrentROISignal() + def calculateRois(self, roiList=None, roiDict=None): + """Compute ROI information""" + return self.roiTable.calculateRois() - else: - _logger.debug("Unknown or ignored event %s", ddict['event']) + def showAllMarkers(self, _show=True): + self.roiTable.showAllMarkers(_show) def _getAllLimits(self): """Retrieve the limits based on the curves.""" @@ -565,429 +386,1121 @@ class CurvesROIWidget(qt.QWidget): return xmin, ymin, xmax, ymax - @deprecation.deprecated(replacement="calculateRois", - reason="CamelCase convention") - def calculateROIs(self, *args, **kw): - self.calculateRois(*args, **kw) + def showEvent(self, event): + self._visibilityChangedHandler(visible=True) + qt.QWidget.showEvent(self, event) - def calculateRois(self, roiList=None, roiDict=None): - """Compute ROI information""" - if roiList is None or roiDict is None: - roiList, roiDict = self.roiTable.getROIListAndDict() + def hideEvent(self, event): + self._visibilityChangedHandler(visible=False) + qt.QWidget.hideEvent(self, event) - plot = self.getPlotWidget() - if plot is None: - activeCurve = None - else: - activeCurve = plot.getActiveCurve(just_legend=False) + def _visibilityChangedHandler(self, visible): + """Handle widget's visibility updates. - if activeCurve is None: - xproc = None - yproc = None - self.setHeader() - else: - x = activeCurve.getXData(copy=False) - y = activeCurve.getYData(copy=False) - legend = activeCurve.getLegend() - idx = numpy.argsort(x, kind='mergesort') - xproc = numpy.take(x, idx) - yproc = numpy.take(y, idx) - self.setHeader('ROIs of %s' % legend) - - for key in roiList: - if key == 'ICR': - if xproc is not None: - roiDict[key]['from'] = xproc.min() - roiDict[key]['to'] = xproc.max() - else: - roiDict[key]['from'] = 0 - roiDict[key]['to'] = -1 - fromData = roiDict[key]['from'] - toData = roiDict[key]['to'] - if xproc is not None: - idx = numpy.nonzero((fromData <= xproc) & - (xproc <= toData))[0] - if len(idx): - xw = xproc[idx] - yw = yproc[idx] - rawCounts = yw.sum(dtype=numpy.float) - deltaX = xw[-1] - xw[0] - deltaY = yw[-1] - yw[0] - if deltaX > 0.0: - slope = (deltaY / deltaX) - background = yw[0] + slope * (xw - xw[0]) - netCounts = (rawCounts - - background.sum(dtype=numpy.float)) - else: - netCounts = 0.0 - else: - rawCounts = 0.0 - netCounts = 0.0 - roiDict[key]['rawcounts'] = rawCounts - roiDict[key]['netcounts'] = netCounts - else: - roiDict[key].pop('rawcounts', None) - roiDict[key].pop('netcounts', None) + It is connected to plot signals only when visible. + """ + if visible: + # if no ROI existing yet, add the default one + if self.roiTable.rowCount() is 0: + self._add() + self.calculateRois() - self.roiTable.fillFromROIDict( - roilist=roiList, - roidict=roiDict, - currentroi=self.currentROI if self.currentROI in roiList else None) + def fillFromROIDict(self, *args, **kwargs): + self.roiTable.fillFromROIDict(*args, **kwargs) def _emitCurrentROISignal(self): ddict = {} ddict['event'] = "currentROISignal" - _roiList, roiDict = self.roiTable.getROIListAndDict() - if self.currentROI in roiDict: - ddict['ROI'] = roiDict[self.currentROI] + if self.roiTable.activeRoi is not None: + ddict['ROI'] = self.roiTable.activeRoi.toDict() + ddict['current'] = self.roiTable.activeRoi.getName() else: - self.currentROI = None - ddict['current'] = self.currentROI + ddict['current'] = None self.sigROISignal.emit(ddict) - def _handleROIMarkerEvent(self, ddict): - """Handle plot signals related to marker events.""" - if ddict['event'] == 'markerMoved': + @property + def currentRoi(self): + return self.roiTable.activeRoi - label = ddict['label'] - if label not in ['ROI min', 'ROI max', 'ROI middle']: - return - roiList, roiDict = self.roiTable.getROIListAndDict() - if self.currentROI is None: - return - if self.currentROI not in roiDict: - return +class _FloatItem(qt.QTableWidgetItem): + """ + Simple QTableWidgetItem overloading the < operator to deal with ordering + """ + def __init__(self): + qt.QTableWidgetItem.__init__(self, type=qt.QTableWidgetItem.Type) - plot = self.getPlotWidget() - if plot is None: - return + def __lt__(self, other): + if self.text() in ('', ROITable.INFO_NOT_FOUND): + return False + if other.text() in ('', ROITable.INFO_NOT_FOUND): + return True + return float(self.text()) < float(other.text()) + + +class ROITable(TableWidget): + """Table widget displaying ROI information. + + See :class:`QTableWidget` for constructor arguments. - x = ddict['x'] - - if label == 'ROI min': - roiDict[self.currentROI]['from'] = x - if self._middleROIMarkerFlag: - pos = 0.5 * (roiDict[self.currentROI]['to'] + - roiDict[self.currentROI]['from']) - plot.addXMarker(pos, - legend='ROI middle', - text='', - color='yellow', - draggable=True) - elif label == 'ROI max': - roiDict[self.currentROI]['to'] = x - if self._middleROIMarkerFlag: - pos = 0.5 * (roiDict[self.currentROI]['to'] + - roiDict[self.currentROI]['from']) - plot.addXMarker(pos, - legend='ROI middle', - text='', - color='yellow', - draggable=True) - elif label == 'ROI middle': - delta = x - 0.5 * (roiDict[self.currentROI]['from'] + - roiDict[self.currentROI]['to']) - roiDict[self.currentROI]['from'] += delta - roiDict[self.currentROI]['to'] += delta - plot.addXMarker(roiDict[self.currentROI]['from'], - legend='ROI min', - text='ROI min', - color='blue', - draggable=True) - plot.addXMarker(roiDict[self.currentROI]['to'], - legend='ROI max', - text='ROI max', - color='blue', - draggable=True) + Behavior: listen at the active curve changed only when the widget is + visible. Otherwise won't compute the row and net counts... + """ + + activeROIChanged = qt.Signal() + """Signal emitted when the active roi changed or when the value of the + active roi are changing""" + + COLUMNS_INDEX = OrderedDict([ + ('ID', 0), + ('ROI', 1), + ('Type', 2), + ('From', 3), + ('To', 4), + ('Raw Counts', 5), + ('Net Counts', 6), + ('Raw Area', 7), + ('Net Area', 8), + ]) + + COLUMNS = list(COLUMNS_INDEX.keys()) + + INFO_NOT_FOUND = '????????' + + def __init__(self, parent=None, plot=None, rois=None): + super(ROITable, self).__init__(parent) + self._showAllMarkers = False + self._userIsEditingRoi = False + """bool used to avoid conflict when editing the ROI object""" + self._isConnected = False + self._roiToItems = {} + self._roiDict = {} + """dict of ROI object. Key is ROi id, value is the ROI object""" + self._markersHandler = _RoiMarkerManager() + + """ + Associate for each marker legend used when the `_showAllMarkers` option + is active a roi. + """ + self.setColumnCount(len(self.COLUMNS)) + self.setPlot(plot) + self.__setTooltip() + self.setSortingEnabled(True) + self.itemChanged.connect(self._itemChanged) + + @property + def roidict(self): + return self._getRoiDict() + + @property + def activeRoi(self): + return self._markersHandler._activeRoi + + def _getRoiDict(self): + ddict = {} + for id in self._roiDict: + ddict[self._roiDict[id].getName()] = self._roiDict[id] + return ddict + + def clear(self): + """ + .. note:: clear the interface only. keep the roidict... + """ + self._markersHandler.clear() + self._roiToItems = {} + self._roiDict = {} + + qt.QTableWidget.clear(self) + self.setRowCount(0) + self.setHorizontalHeaderLabels(self.COLUMNS) + header = self.horizontalHeader() + if hasattr(header, 'setSectionResizeMode'): # Qt5 + header.setSectionResizeMode(qt.QHeaderView.ResizeToContents) + else: # Qt4 + header.setResizeMode(qt.QHeaderView.ResizeToContents) + self.sortByColumn(0, qt.Qt.AscendingOrder) + self.hideColumn(self.COLUMNS_INDEX['ID']) + + def setPlot(self, plot): + self.clear() + self.plot = plot + + def __setTooltip(self): + self.horizontalHeaderItem(self.COLUMNS_INDEX['ROI']).setToolTip( + 'Region of interest identifier') + self.horizontalHeaderItem(self.COLUMNS_INDEX['Type']).setToolTip( + 'Type of the ROI') + self.horizontalHeaderItem(self.COLUMNS_INDEX['From']).setToolTip( + 'X-value of the min point') + self.horizontalHeaderItem(self.COLUMNS_INDEX['To']).setToolTip( + 'X-value of the max point') + self.horizontalHeaderItem(self.COLUMNS_INDEX['Raw Counts']).setToolTip( + 'Estimation of the integral between y=0 and the selected curve') + self.horizontalHeaderItem(self.COLUMNS_INDEX['Net Counts']).setToolTip( + 'Estimation of the integral between the segment [maxPt, minPt] ' + 'and the selected curve') + + def setRois(self, rois, order=None): + """Set the ROIs by providing a dictionary of ROI information. + + The dictionary keys are the ROI names. + Each value is a sub-dictionary of ROI info with the following fields: + + - ``"from"``: x coordinate of the left limit, as a float + - ``"to"``: x coordinate of the right limit, as a float + - ``"type"``: type of ROI, as a string (e.g "channels", "energy") + + + :param roidict: Dictionary of ROIs + :param str order: Field used for ordering the ROIs. + One of "from", "to", "type". + None (default) for no ordering, or same order as specified + in parameter ``roidict`` if provided as an OrderedDict. + """ + assert order in [None, "from", "to", "type"] + self.clear() + + # backward compatibility since 0.10.0 + if isinstance(rois, dict): + for roiName, roi in rois.items(): + roi['name'] = roiName + _roi = ROI._fromDict(roi) + self.addRoi(_roi) + else: + for roi in rois: + assert isinstance(roi, ROI) + self.addRoi(roi) + self._updateMarkers() + + def addRoi(self, roi): + """ + + :param :class:`ROI` roi: roi to add to the table + """ + assert isinstance(roi, ROI) + self._getItem(name='ID', row=None, roi=roi) + self._roiDict[roi.getID()] = roi + self._markersHandler.add(roi, _RoiMarkerHandler(roi, self.plot)) + self._updateRoiInfo(roi.getID()) + callback = functools.partial(WeakMethodProxy(self._updateRoiInfo), + roi.getID()) + roi.sigChanged.connect(callback) + # set it as the active one + self.setActiveRoi(roi) + + def _getItem(self, name, row, roi): + if row: + item = self.item(row, self.COLUMNS_INDEX[name]) + else: + item = None + if item: + return item + else: + if name == 'ID': + assert roi + if roi.getID() in self._roiToItems: + return self._roiToItems[roi.getID()] + else: + # create a new row + row = self.rowCount() + self.setRowCount(self.rowCount() + 1) + item = qt.QTableWidgetItem(str(roi.getID()), + type=qt.QTableWidgetItem.Type) + self._roiToItems[roi.getID()] = item + elif name == 'ROI': + item = qt.QTableWidgetItem(roi.getName() if roi else '', + type=qt.QTableWidgetItem.Type) + if roi.getName().upper() in ('ICR', 'DEFAULT'): + item.setFlags(qt.Qt.ItemIsSelectable | qt.Qt.ItemIsEnabled) + else: + item.setFlags(qt.Qt.ItemIsSelectable | + qt.Qt.ItemIsEnabled | + qt.Qt.ItemIsEditable) + elif name == 'Type': + item = qt.QTableWidgetItem(type=qt.QTableWidgetItem.Type) + item.setFlags((qt.Qt.ItemIsSelectable | qt.Qt.ItemIsEnabled)) + elif name in ('To', 'From'): + item = _FloatItem() + if roi.getName().upper() in ('ICR', 'DEFAULT'): + item.setFlags(qt.Qt.ItemIsSelectable | qt.Qt.ItemIsEnabled) + else: + item.setFlags(qt.Qt.ItemIsSelectable | + qt.Qt.ItemIsEnabled | + qt.Qt.ItemIsEditable) + elif name in ('Raw Counts', 'Net Counts', 'Raw Area', 'Net Area'): + item = _FloatItem() + item.setFlags((qt.Qt.ItemIsSelectable | qt.Qt.ItemIsEnabled)) else: - return - self.calculateRois(roiList, roiDict) - self._emitCurrentROISignal() + raise ValueError('item type not recognized') + + self.setItem(row, self.COLUMNS_INDEX[name], item) + return item + + def _itemChanged(self, item): + def getRoi(): + IDItem = self.item(item.row(), self.COLUMNS_INDEX['ID']) + assert IDItem + id = int(IDItem.text()) + assert id in self._roiDict + roi = self._roiDict[id] + return roi + + def signalChanged(roi): + if self.activeRoi and roi.getID() == self.activeRoi.getID(): + self.activeROIChanged.emit() + + self._userIsEditingRoi = True + if item.column() in (self.COLUMNS_INDEX['To'], self.COLUMNS_INDEX['From']): + roi = getRoi() + + if item.text() not in ('', self.INFO_NOT_FOUND): + try: + value = float(item.text()) + except ValueError: + value = 0 + changed = False + if item.column() == self.COLUMNS_INDEX['To']: + if value != roi.getTo(): + roi.setTo(value) + changed = True + else: + assert(item.column() == self.COLUMNS_INDEX['From']) + if value != roi.getFrom(): + roi.setFrom(value) + changed = True + if changed: + self._updateMarker(roi.getName()) + signalChanged(roi) + + if item.column() is self.COLUMNS_INDEX['ROI']: + roi = getRoi() + if roi.getName() != item.text(): + roi.setName(item.text()) + self._markersHandler.getMarkerHandler(roi.getID()).updateTexts() + signalChanged(roi) + + self._userIsEditingRoi = False + + def deleteActiveRoi(self): + """ + remove the current active roi + """ + activeItems = self.selectedItems() + if len(activeItems) is 0: + return + roiToRm = set() + for item in activeItems: + row = item.row() + itemID = self.item(row, self.COLUMNS_INDEX['ID']) + roiToRm.add(self._roiDict[int(itemID.text())]) + [self.removeROI(roi) for roi in roiToRm] + self.setActiveRoi(None) + + def removeROI(self, roi): + """ + remove the requested roi - def _visibilityChangedHandler(self, visible): - """Handle widget's visibility updates. + :param str name: the name of the roi to remove from the table + """ + if roi and roi.getID() in self._roiToItems: + item = self._roiToItems[roi.getID()] + self.removeRow(item.row()) + del self._roiToItems[roi.getID()] - It is connected to plot signals only when visible. + assert roi.getID() in self._roiDict + del self._roiDict[roi.getID()] + self._markersHandler.remove(roi) + + callback = functools.partial(WeakMethodProxy(self._updateRoiInfo), + roi.getID()) + roi.sigChanged.connect(callback) + + def setActiveRoi(self, roi): """ - plot = self.getPlotWidget() + Define the given roi as the active one. - if visible: - if not self._isInit: - # Deferred ROI widget init finalization - self._finalizeInit() - - if not self._isConnected and plot is not None: - plot.sigPlotSignal.connect(self._handleROIMarkerEvent) - plot.sigActiveCurveChanged.connect( - self._activeCurveChanged) - self._isConnected = True + .. warning:: this roi should already be registred / added to the table - self.calculateRois() + :param :class:`ROI` roi: the roi to defined as active + """ + if roi is None: + self.clearSelection() + self._markersHandler.setActiveRoi(None) + self.activeROIChanged.emit() else: - if self._isConnected: - if plot is not None: - plot.sigPlotSignal.disconnect(self._handleROIMarkerEvent) - plot.sigActiveCurveChanged.disconnect( - self._activeCurveChanged) - self._isConnected = False + assert isinstance(roi, ROI) + if roi and roi.getID() in self._roiToItems.keys(): + self.selectRow(self._roiToItems[roi.getID()].row()) + self._markersHandler.setActiveRoi(roi) + self.activeROIChanged.emit() + + def _updateRoiInfo(self, roiID): + if self._userIsEditingRoi is True: + return + if roiID not in self._roiDict: + return + roi = self._roiDict[roiID] + if roi.isICR(): + activeCurve = self.plot.getActiveCurve() + if activeCurve: + xData = activeCurve.getXData() + if len(xData) > 0: + min, max = min_max(xData) + roi.blockSignals(True) + roi.setFrom(min) + roi.setTo(max) + roi.blockSignals(False) + + itemID = self._getItem(name='ID', roi=roi, row=None) + itemName = self._getItem(name='ROI', row=itemID.row(), roi=roi) + itemName.setText(roi.getName()) + + itemType = self._getItem(name='Type', row=itemID.row(), roi=roi) + itemType.setText(roi.getType() or self.INFO_NOT_FOUND) + + itemFrom = self._getItem(name='From', row=itemID.row(), roi=roi) + fromdata = str(roi.getFrom()) if roi.getFrom() is not None else self.INFO_NOT_FOUND + itemFrom.setText(fromdata) + + itemTo = self._getItem(name='To', row=itemID.row(), roi=roi) + todata = str(roi.getTo()) if roi.getTo() is not None else self.INFO_NOT_FOUND + itemTo.setText(todata) + + rawCounts, netCounts = roi.computeRawAndNetCounts( + curve=self.plot.getActiveCurve(just_legend=False)) + itemRawCounts = self._getItem(name='Raw Counts', row=itemID.row(), + roi=roi) + rawCounts = str(rawCounts) if rawCounts is not None else self.INFO_NOT_FOUND + itemRawCounts.setText(rawCounts) + + itemNetCounts = self._getItem(name='Net Counts', row=itemID.row(), + roi=roi) + netCounts = str(netCounts) if netCounts is not None else self.INFO_NOT_FOUND + itemNetCounts.setText(netCounts) + + rawArea, netArea = roi.computeRawAndNetArea( + curve=self.plot.getActiveCurve(just_legend=False)) + itemRawArea = self._getItem(name='Raw Area', row=itemID.row(), + roi=roi) + rawArea = str(rawArea) if rawArea is not None else self.INFO_NOT_FOUND + itemRawArea.setText(rawArea) + + itemNetArea = self._getItem(name='Net Area', row=itemID.row(), + roi=roi) + netArea = str(netArea) if netArea is not None else self.INFO_NOT_FOUND + itemNetArea.setText(netArea) + + if self.activeRoi and roi.getID() == self.activeRoi.getID(): + self.activeROIChanged.emit() + + def currentChanged(self, current, previous): + if previous and current.row() != previous.row() and current.row() >= 0: + roiItem = self.item(current.row(), + self.COLUMNS_INDEX['ID']) + + assert roiItem + self.setActiveRoi(self._roiDict[int(roiItem.text())]) + self._markersHandler.updateAllMarkers() + qt.QTableWidget.currentChanged(self, current, previous) + + @deprecation.deprecated(reason="Removed", + replacement="roidict and roidict.values()", + since_version="0.10.0") + def getROIListAndDict(self): + """ - def _activeCurveChanged(self, *args): - """Recompute ROIs when active curve changed.""" - self.calculateRois() + :return: the list of roi objects and the dictionary of roi name to roi + object. + """ + roidict = self._roiDict + return list(roidict.values()), roidict - def _finalizeInit(self): - self._isInit = True - self.sigROIWidgetSignal.connect(self._roiSignal) - # initialize with the ICR if no ROi existing yet - if len(self.getRois()) is 0: - self._roiSignal({'event': "AddROI"}) + def calculateRois(self, roiList=None, roiDict=None): + """ + Update values of all registred rois (raw and net counts in particular) + :param roiList: deprecated parameter + :param roiDict: deprecated parameter + """ + if roiDict: + deprecation.deprecated_warning(name='roiDict', type_='Parameter', + reason='Unused parameter', + since_version="0.10.0") + if roiList: + deprecation.deprecated_warning(name='roiList', type_='Parameter', + reason='Unused parameter', + since_version="0.10.0") + + for roiID in self._roiDict: + self._updateRoiInfo(roiID) + + def _updateMarker(self, roiID): + """Make sure the marker of the given roi name is updated""" + if self._showAllMarkers or (self.activeRoi + and self.activeRoi.getName() == roiID): + self._updateMarkers() + + def _updateMarkers(self): + if self._showAllMarkers is True: + self._markersHandler.updateMarkers() + else: + if not self.activeRoi or not self.plot: + return + assert isinstance(self.activeRoi, ROI) + markerHandler = self._markersHandler.getMarkerHandler(self.activeRoi.getID()) + if markerHandler is not None: + markerHandler.updateMarkers() -class ROITable(qt.QTableWidget): - """Table widget displaying ROI information. + def getRois(self, order): + """ + Return the currently defined ROIs, as an ordered dict. - See :class:`QTableWidget` for constructor arguments. - """ + The dictionary keys are the ROI names. + Each value is a :class:`ROI` object.. - sigROITableSignal = qt.Signal(object) - """Signal of ROI table modifications. - """ + :param order: Field used for ordering the ROIs. + One of "from", "to", "type", "netcounts", "rawcounts". + None (default) to get the same order as displayed in the widget. + :return: Ordered dictionary of ROI information + """ - def __init__(self, *args, **kwargs): - super(ROITable, self).__init__(*args, **kwargs) - self.setRowCount(1) - self.labels = 'ROI', 'Type', 'From', 'To', 'Raw Counts', 'Net Counts' - self.setColumnCount(len(self.labels)) - self.setSortingEnabled(False) + if order is None or order.lower() == "none": + ordered_roilist = list(self._roiDict.values()) + res = OrderedDict([(roi.getName(), self._roiDict[roi.getID()]) for roi in ordered_roilist]) + else: + assert order in ["from", "to", "type", "netcounts", "rawcounts"] + ordered_roilist = sorted(self._roiDict.keys(), + key=lambda roi_id: self._roiDict[roi_id].get(order)) + res = OrderedDict([(roi.getName(), self._roiDict[id]) for id in ordered_roilist]) + + return res + + def save(self, filename): + """ + Save current ROIs of the widget as a dict of ROI to a file. + + :param str filename: The file to which to save the ROIs + """ + roilist = [] + roidict = {} + for roiID, roi in self._roiDict.items(): + roilist.append(roi.toDict()) + roidict[roi.getName()] = roi.toDict() + datadict = {'ROI': {'roilist': roilist, 'roidict': roidict}} + dictdump.dump(datadict, filename) - for index, label in enumerate(self.labels): - item = self.horizontalHeaderItem(index) - if item is None: - item = qt.QTableWidgetItem(label, - qt.QTableWidgetItem.Type) - item.setText(label) - self.setHorizontalHeaderItem(index, item) + def load(self, filename): + """ + Load ROI widget information from a file storing a dict of ROI. - self.roidict = {} - self.roilist = [] + :param str filename: The file from which to load ROI + """ + roisDict = dictdump.load(filename) + rois = [] - self.building = False - self.fillFromROIDict(roilist=self.roilist, roidict=self.roidict) + # Remove rawcounts and netcounts from ROIs + for roiDict in roisDict['ROI']['roidict'].values(): + roiDict.pop('rawcounts', None) + roiDict.pop('netcounts', None) + rois.append(ROI._fromDict(roiDict)) - self.cellClicked[(int, int)].connect(self._cellClickedSlot) - self.cellChanged[(int, int)].connect(self._cellChangedSlot) - verticalHeader = self.verticalHeader() - verticalHeader.sectionClicked[int].connect(self._rowChangedSlot) + self.setRois(rois) - self.__setTooltip() + def showAllMarkers(self, _show=True): + """ - def __setTooltip(self): - assert(self.labels[0] == 'ROI') - self.horizontalHeaderItem(0).setToolTip('Region of interest identifier') - assert(self.labels[1] == 'Type') - self.horizontalHeaderItem(1).setToolTip('Type of the ROI') - assert(self.labels[2] == 'From') - self.horizontalHeaderItem(2).setToolTip('X-value of the min point') - assert(self.labels[3] == 'To') - self.horizontalHeaderItem(3).setToolTip('X-value of the max point') - assert(self.labels[4] == 'Raw Counts') - self.horizontalHeaderItem(4).setToolTip('Estimation of the integral \ - between y=0 and the selected curve') - assert(self.labels[5] == 'Net Counts') - self.horizontalHeaderItem(5).setToolTip('Estimation of the integral \ - between the segment [maxPt, minPt] and the selected curve') + :param bool _show: if true show all the markers of all the ROIs + boundaries otherwise will only show the one of + the active ROI. + """ + self._markersHandler.setShowAllMarkers(_show) + + def setMiddleROIMarkerFlag(self, flag=True): + """ + Activate or deactivate middle marker. + + This allows shifting both min and max limits at once, by dragging + a marker located in the middle. + + :param bool flag: True to activate middle ROI marker + """ + self._markersHandler._middleROIMarkerFlag = flag + + def _handleROIMarkerEvent(self, ddict): + """Handle plot signals related to marker events.""" + if ddict['event'] == 'markerMoved': + label = ddict['label'] + roiID = self._markersHandler.getRoiID(markerID=label) + if roiID: + self._markersHandler.changePosition(markerID=label, + x=ddict['x']) + self._updateRoiInfo(roiID) + + def showEvent(self, event): + self._visibilityChangedHandler(visible=True) + qt.QWidget.showEvent(self, event) + + def hideEvent(self, event): + self._visibilityChangedHandler(visible=False) + qt.QWidget.hideEvent(self, event) + + def _visibilityChangedHandler(self, visible): + """Handle widget's visibility updates. + + It is connected to plot signals only when visible. + """ + if visible: + assert self.plot + if self._isConnected is False: + self.plot.sigPlotSignal.connect(self._handleROIMarkerEvent) + self.plot.sigActiveCurveChanged.connect(self._activeCurveChanged) + self._isConnected = True + self.calculateRois() + else: + if self._isConnected: + self.plot.sigPlotSignal.disconnect(self._handleROIMarkerEvent) + self.plot.sigActiveCurveChanged.disconnect(self._activeCurveChanged) + self._isConnected = False + + def _activeCurveChanged(self, curve): + self.calculateRois() + + def setCountsVisible(self, visible): + """ + Display the columns relative to areas or not + + :param bool visible: True if the columns 'Raw Area' and 'Net Area' + should be visible. + """ + if visible is True: + self.showColumn(self.COLUMNS_INDEX['Raw Counts']) + self.showColumn(self.COLUMNS_INDEX['Net Counts']) + else: + self.hideColumn(self.COLUMNS_INDEX['Raw Counts']) + self.hideColumn(self.COLUMNS_INDEX['Net Counts']) + + def setAreaVisible(self, visible): + """ + Display the columns relative to areas or not + + :param bool visible: True if the columns 'Raw Area' and 'Net Area' + should be visible. + """ + if visible is True: + self.showColumn(self.COLUMNS_INDEX['Raw Area']) + self.showColumn(self.COLUMNS_INDEX['Net Area']) + else: + self.hideColumn(self.COLUMNS_INDEX['Raw Area']) + self.hideColumn(self.COLUMNS_INDEX['Net Area']) def fillFromROIDict(self, roilist=(), roidict=None, currentroi=None): - """Set the ROIs by providing a list of ROI names and a dictionary - of ROI information for each ROI. + """ + This function API is kept for compatibility. + But `setRois` should be preferred. + Set the ROIs by providing a list of ROI names and a dictionary + of ROI information for each ROI. The ROI names must match an existing dictionary key. The name list is used to provide an order for the ROIs. - The dictionary's values are sub-dictionaries containing 3 mandatory fields: - - ``"from"``: x coordinate of the left limit, as a float - - ``"to"``: x coordinate of the right limit, as a float - - ``"type"``: type of ROI, as a string (e.g "channels", "energy") + - ``"from"``: x coordinate of the left limit, as a float + - ``"to"``: x coordinate of the right limit, as a float + - ``"type"``: type of ROI, as a string (e.g "channels", "energy") :param roilist: List of ROI names (keys of roidict) :type roilist: List :param dict roidict: Dict of ROI information :param currentroi: Name of the selected ROI or None (no selection) """ - if roidict is None: - roidict = {} - - self.building = True - line0 = 0 - self.roilist = [] - self.roidict = {} - for key in roilist: - if key in roidict.keys(): - roi = roidict[key] - self.roilist.append(key) - self.roidict[key] = {} - self.roidict[key].update(roi) - line0 = line0 + 1 - nlines = self.rowCount() - if (line0 > nlines): - self.setRowCount(line0) - line = line0 - 1 - self.roidict[key]['line'] = line - ROI = key - roitype = "%s" % roi['type'] - fromdata = "%6g" % (roi['from']) - todata = "%6g" % (roi['to']) - if 'rawcounts' in roi: - rawcounts = "%6g" % (roi['rawcounts']) - else: - rawcounts = " ?????? " - if 'netcounts' in roi: - netcounts = "%6g" % (roi['netcounts']) - else: - netcounts = " ?????? " - fields = [ROI, roitype, fromdata, todata, rawcounts, netcounts] - col = 0 - for field in fields: - key2 = self.item(line, col) - if key2 is None: - key2 = qt.QTableWidgetItem(field, - qt.QTableWidgetItem.Type) - self.setItem(line, col, key2) - else: - key2.setText(field) - if (ROI.upper() == 'ICR') or (ROI.upper() == 'DEFAULT'): - key2.setFlags(qt.Qt.ItemIsSelectable | - qt.Qt.ItemIsEnabled) - else: - if col in [0, 2, 3]: - key2.setFlags(qt.Qt.ItemIsSelectable | - qt.Qt.ItemIsEnabled | - qt.Qt.ItemIsEditable) - else: - key2.setFlags(qt.Qt.ItemIsSelectable | - qt.Qt.ItemIsEnabled) - col = col + 1 - self.setRowCount(line0) - i = 0 - for _label in self.labels: - self.resizeColumnToContents(i) - i = i + 1 - self.sortByColumn(2, qt.Qt.AscendingOrder) - for i in range(len(self.roilist)): - key = str(self.item(i, 0).text()) - self.roilist[i] = key - self.roidict[key]['line'] = i - if len(self.roilist) == 1: - self.selectRow(0) + if roidict is not None: + self.setRois(roidict) else: - if currentroi in self.roidict.keys(): - self.selectRow(self.roidict[currentroi]['line']) - _logger.debug("Qt4 ensureCellVisible to be implemented") - self.building = False + self.setRois(roilist) + if currentroi: + self.setActiveRoi(currentroi) - def getROIListAndDict(self): - """Return the currently defined ROIs, as a 2-tuple - ``(roiList, roiDict)`` - ``roiList`` is a list of ROI names. - ``roiDict`` is a dictionary of ROI info. +_indexNextROI = 0 - The ROI names must match an existing dictionary key. - The name list is used to provide an order for the ROIs. - The dictionary's values are sub-dictionaries containing 3 - fields: +class ROI(qt.QObject): + """The Region Of Interest is defined by: - - ``"from"``: x coordinate of the left limit, as a float - - ``"to"``: x coordinate of the right limit, as a float - - ``"type"``: type of ROI, as a string (e.g "channels", "energy") + - A name + - A type. The type is the label of the x axis. This can be used to apply or + not some ROI to a curve and do some post processing. + - The x coordinate of the left limit (fromdata) + - The x coordinate of the right limit (todata) + :param str: name of the ROI + :param fromdata: left limit of the roi + :param todata: right limit of the roi + :param type: type of the ROI + """ + + sigChanged = qt.Signal() + """Signal emitted when the ROI is edited""" + + def __init__(self, name, fromdata=None, todata=None, type_=None): + qt.QObject.__init__(self) + assert type(name) is str + global _indexNextROI + self._id = _indexNextROI + _indexNextROI += 1 + + self._name = name + self._fromdata = fromdata + self._todata = todata + self._type = type_ or 'Default' - :return: ordered dict as a tuple of (list of ROI names, dict of info) + def getID(self): """ - return self.roilist, self.roidict - def _cellClickedSlot(self, *var, **kw): - # selection changed event, get the current selection - row = self.currentRow() - col = self.currentColumn() - if row >= 0 and row < len(self.roilist): - item = self.item(row, 0) - text = '' if item is None else str(item.text()) - self.roilist[row] = text - self._emitSelectionChangedSignal(row, col) + :return int: the unique ID of the ROI + """ + return self._id - def _rowChangedSlot(self, row): - self._emitSelectionChangedSignal(row, 0) + def setType(self, type_): + """ - def _cellChangedSlot(self, row, col): - _logger.debug("_cellChangedSlot(%d, %d)", row, col) - if self.building: - return - if col == 0: - self.nameSlot(row, col) + :param str type_: + """ + if self._type != type_: + self._type = type_ + self.sigChanged.emit() + + def getType(self): + """ + + :return str: the type of the ROI. + """ + return self._type + + def setName(self, name): + """ + Set the name of the :class:`ROI` + + :param str name: + """ + if self._name != name: + self._name = name + self.sigChanged.emit() + + def getName(self): + """ + + :return str: name of the :class:`ROI` + """ + return self._name + + def setFrom(self, frm): + """ + + :param frm: set x coordinate of the left limit + """ + if self._fromdata != frm: + self._fromdata = frm + self.sigChanged.emit() + + def getFrom(self): + """ + + :return: x coordinate of the left limit + """ + return self._fromdata + + def setTo(self, to): + """ + + :param to: x coordinate of the right limit + """ + if self._todata != to: + self._todata = to + self.sigChanged.emit() + + def getTo(self): + """ + + :return: x coordinate of the right limit + """ + return self._todata + + def getMiddle(self): + """ + + :return: middle position between 'from' and 'to' values + """ + return 0.5 * (self.getFrom() + self.getTo()) + + def toDict(self): + """ + + :return: dict containing the roi parameters + """ + ddict = { + 'type': self._type, + 'name': self._name, + 'from': self._fromdata, + 'to': self._todata, + } + if hasattr(self, '_extraInfo'): + ddict.update(self._extraInfo) + return ddict + + @staticmethod + def _fromDict(dic): + assert 'name' in dic + roi = ROI(name=dic['name']) + roi._extraInfo = {} + for key in dic: + if key == 'from': + roi.setFrom(dic['from']) + elif key == 'to': + roi.setTo(dic['to']) + elif key == 'type': + roi.setType(dic['type']) + else: + roi._extraInfo[key] = dic[key] + + return roi + + def isICR(self): + """ + + :return: True if the ROI is the `ICR` + """ + return self._name == 'ICR' + + def computeRawAndNetCounts(self, curve): + """Compute the Raw and net counts in the ROI for the given curve. + + - Raw count: Points values sum of the curve in the defined Region Of + Interest. + + .. image:: img/rawCounts.png + + - Net count: Raw counts minus background + + .. image:: img/netCounts.png + + :param CurveItem curve: + :return tuple: rawCount, netCount + """ + assert isinstance(curve, Curve) or curve is None + + if curve is None: + return None, None + + x = curve.getXData(copy=False) + y = curve.getYData(copy=False) + + idx = numpy.nonzero((self._fromdata <= x) & + (x <= self._todata))[0] + if len(idx): + xw = x[idx] + yw = y[idx] + rawCounts = yw.sum(dtype=numpy.float) + deltaX = xw[-1] - xw[0] + deltaY = yw[-1] - yw[0] + if deltaX > 0.0: + slope = (deltaY / deltaX) + background = yw[0] + slope * (xw - xw[0]) + netCounts = (rawCounts - + background.sum(dtype=numpy.float)) + else: + netCounts = 0.0 else: - self._valueChanged(row, col) + rawCounts = 0.0 + netCounts = 0.0 + return rawCounts, netCounts + + def computeRawAndNetArea(self, curve): + """Compute the Raw and net counts in the ROI for the given curve. + + - Raw area: integral of the curve between the min ROI point and the + max ROI point to the y = 0 line. - def _valueChanged(self, row, col): - if col not in [2, 3]: + .. image:: img/rawArea.png + + - Net area: Raw counts minus background + + .. image:: img/netArea.png + + :param CurveItem curve: + :return tuple: rawArea, netArea + """ + assert isinstance(curve, Curve) or curve is None + + if curve is None: + return None, None + + x = curve.getXData(copy=False) + y = curve.getYData(copy=False) + + y = y[(x >= self._fromdata) & (x <= self._todata)] + x = x[(x >= self._fromdata) & (x <= self._todata)] + + if x.size is 0: + return 0.0, 0.0 + + rawArea = numpy.trapz(y, x=x) + # to speed up and avoid an intersection calculation we are taking the + # closest index to the ROI + closestXLeftIndex = (numpy.abs(x - self.getFrom())).argmin() + closestXRightIndex = (numpy.abs(x - self.getTo())).argmin() + yBackground = y[closestXLeftIndex], y[closestXRightIndex] + background = numpy.trapz(yBackground, x=x) + netArea = rawArea - background + return rawArea, netArea + + +class _RoiMarkerManager(object): + """ + Deal with all the ROI markers + """ + def __init__(self): + self._roiMarkerHandlers = {} + self._middleROIMarkerFlag = False + self._showAllMarkers = False + self._activeRoi = None + + def setActiveRoi(self, roi): + self._activeRoi = roi + self.updateAllMarkers() + + def setShowAllMarkers(self, show): + if show != self._showAllMarkers: + self._showAllMarkers = show + self.updateAllMarkers() + + def add(self, roi, markersHandler): + assert isinstance(roi, ROI) + assert isinstance(markersHandler, _RoiMarkerHandler) + if roi.getID() in self._roiMarkerHandlers: + raise ValueError('roi with the same ID already existing') + else: + self._roiMarkerHandlers[roi.getID()] = markersHandler + + def getMarkerHandler(self, roiID): + if roiID in self._roiMarkerHandlers: + return self._roiMarkerHandlers[roiID] + else: + return None + + def clear(self): + roisHandler = list(self._roiMarkerHandlers.values()) + for roiHandler in roisHandler: + self.remove(roiHandler.roi) + + def remove(self, roi): + if roi is None: return - item = self.item(row, col) - if item is None: + assert isinstance(roi, ROI) + if roi.getID() in self._roiMarkerHandlers: + self._roiMarkerHandlers[roi.getID()].clear() + del self._roiMarkerHandlers[roi.getID()] + + def hasMarker(self, markerID): + assert type(markerID) is str + return self.getMarker(markerID) is not None + + def changePosition(self, markerID, x): + markerHandler = self.getMarker(markerID) + if markerHandler is None: + raise ValueError('Marker %s not register' % markerID) + markerHandler.changePosition(markerID=markerID, x=x) + + def updateMarker(self, markerID): + markerHandler = self.getMarker(markerID) + if markerHandler is None: + raise ValueError('Marker %s not register' % markerID) + roiID = self.getRoiID(markerID) + visible = (self._activeRoi and self._activeRoi.getID() == roiID) or self._showAllMarkers is True + markerHandler.setVisible(visible) + markerHandler.updateAllMarkers() + + def updateRoiMarkers(self, roiID): + if roiID in self._roiMarkerHandlers: + visible = ((self._activeRoi and self._activeRoi.getID() == roiID) + or self._showAllMarkers is True) + _roi = self._roiMarkerHandlers[roiID]._roi() + if _roi and not _roi.isICR(): + self._roiMarkerHandlers[roiID].showMiddleMarker(self._middleROIMarkerFlag) + self._roiMarkerHandlers[roiID].setVisible(visible) + self._roiMarkerHandlers[roiID].updateMarkers() + + def getMarker(self, markerID): + assert type(markerID) is str + for marker in list(self._roiMarkerHandlers.values()): + if marker.hasMarker(markerID): + return marker + + def updateMarkers(self): + for markerHandler in list(self._roiMarkerHandlers.values()): + markerHandler.updateMarkers() + + def getRoiID(self, markerID): + for roiID, markerHandler in self._roiMarkerHandlers.items(): + if markerHandler.hasMarker(markerID): + return roiID + return None + + def setShowMiddleMarkers(self, show): + self._middleROIMarkerFlag = show + self._roiMarkerHandlers.updateAllMarkers() + + def updateAllMarkers(self): + for roiID in self._roiMarkerHandlers: + self.updateRoiMarkers(roiID) + + def getVisibleRois(self): + res = {} + for roiID, roiHandler in self._roiMarkerHandlers.items(): + markers = (roiHandler.getMarker('min'), roiHandler.getMarker('max'), + roiHandler.getMarker('middle')) + for marker in markers: + if marker.isVisible(): + if roiID not in res: + res[roiID] = [] + res[roiID].append(marker) + return res + + +class _RoiMarkerHandler(object): + """Used to deal with ROI markers used in ROITable""" + def __init__(self, roi, plot): + assert roi and isinstance(roi, ROI) + assert plot + + self._roi = weakref.ref(roi) + self._plot = weakref.ref(plot) + self._draggable = False if roi.isICR() else True + self._color = 'black' if roi.isICR() else 'blue' + self._displayMidMarker = False + self._visible = True + + @property + def draggable(self): + return self._draggable + + @property + def plot(self): + return self._plot() + + def clear(self): + if self.plot and self.roi: + self.plot.removeMarker(self._markerID('min')) + self.plot.removeMarker(self._markerID('max')) + self.plot.removeMarker(self._markerID('middle')) + + @property + def roi(self): + return self._roi() + + def setVisible(self, visible): + if visible != self._visible: + self._visible = visible + self.updateMarkers() + + def showMiddleMarker(self, visible): + if self.draggable is False and visible is True: + _logger.warning("ROI is not draggable. Won't display middle marker") return - text = str(item.text()) - try: - value = float(text) - except: + self._displayMidMarker = visible + self.getMarker('middle').setVisible(self._displayMidMarker) + + def updateMarkers(self): + if self.roi is None: return - if row >= len(self.roilist): - _logger.debug("deleting???") + self._updateMinMarkerPos() + self._updateMaxMarkerPos() + self._updateMiddleMarkerPos() + + def _updateMinMarkerPos(self): + self.getMarker('min').setPosition(x=self.roi.getFrom(), y=None) + self.getMarker('min').setVisible(self._visible) + + def _updateMaxMarkerPos(self): + self.getMarker('max').setPosition(x=self.roi.getTo(), y=None) + self.getMarker('max').setVisible(self._visible) + + def _updateMiddleMarkerPos(self): + self.getMarker('middle').setPosition(x=self.roi.getMiddle(), y=None) + self.getMarker('middle').setVisible(self._displayMidMarker and self._visible) + + def getMarker(self, markerType): + if self.plot is None: + return None + assert markerType in ('min', 'max', 'middle') + if self.plot._getMarker(self._markerID(markerType)) is None: + assert self.roi + if markerType == 'min': + val = self.roi.getFrom() + elif markerType == 'max': + val = self.roi.getTo() + else: + val = self.roi.getMiddle() + + _color = self._color + if markerType == 'middle': + _color = 'yellow' + self.plot.addXMarker(val, + legend=self._markerID(markerType), + text=self.getMarkerName(markerType), + color=_color, + draggable=self.draggable) + return self.plot._getMarker(self._markerID(markerType)) + + def _markerID(self, markerType): + assert markerType in ('min', 'max', 'middle') + assert self.roi + return '_'.join((str(self.roi.getID()), markerType)) + + def getMarkerName(self, markerType): + assert markerType in ('min', 'max', 'middle') + assert self.roi + return ' '.join((self.roi.getName(), markerType)) + + def updateTexts(self): + self.getMarker('min').setText(self.getMarkerName('min')) + self.getMarker('max').setText(self.getMarkerName('max')) + self.getMarker('middle').setText(self.getMarkerName('middle')) + + def changePosition(self, markerID, x): + assert self.hasMarker(markerID) + markerType = self._getMarkerType(markerID) + assert markerType is not None + if self.roi is None: return - item = self.item(row, 0) - if item is None: - text = "" + if markerType == 'min': + self.roi.setFrom(x) + self._updateMiddleMarkerPos() + elif markerType == 'max': + self.roi.setTo(x) + self._updateMiddleMarkerPos() else: - text = str(item.text()) - if not len(text): - return - if col == 2: - self.roidict[text]['from'] = value - elif col == 3: - self.roidict[text]['to'] = value - self._emitSelectionChangedSignal(row, col) - - def nameSlot(self, row, col): - if col != 0: - return - if row >= len(self.roilist): - _logger.debug("deleting???") - return - item = self.item(row, col) - if item is None: - text = "" + delta = x - 0.5 * (self.roi.getFrom() + self.roi.getTo()) + self.roi.setFrom(self.roi.getFrom() + delta) + self.roi.setTo(self.roi.getTo() + delta) + self._updateMinMarkerPos() + self._updateMaxMarkerPos() + + def hasMarker(self, marker): + return marker in (self._markerID('min'), + self._markerID('max'), + self._markerID('middle')) + + def _getMarkerType(self, markerID): + if markerID.endswith('_min'): + return 'min' + elif markerID.endswith('_max'): + return 'max' + elif markerID.endswith('_middle'): + return 'middle' else: - text = str(item.text()) - if len(text) and (text not in self.roilist): - old = self.roilist[row] - self.roilist[row] = text - self.roidict[text] = {} - self.roidict[text].update(self.roidict[old]) - del self.roidict[old] - self._emitSelectionChangedSignal(row, col) - - def _emitSelectionChangedSignal(self, row, col): - ddict = {} - ddict['event'] = "selectionChanged" - ddict['row'] = row - ddict['col'] = col - ddict['roi'] = self.roidict[self.roilist[row]] - ddict['key'] = self.roilist[row] - ddict['colheader'] = self.labels[col] - ddict['rowheader'] = "%d" % row - self.sigROITableSignal.emit(ddict) + return None class CurvesROIDockWidget(qt.QDockWidget): @@ -1007,6 +1520,8 @@ class CurvesROIDockWidget(qt.QDockWidget): def __init__(self, parent=None, plot=None, name=None): super(CurvesROIDockWidget, self).__init__(name, parent) + assert plot is not None + self.plot = plot self.roiWidget = CurvesROIWidget(self, name, plot=plot) """Main widget of type :class:`CurvesROIWidget`""" @@ -1016,12 +1531,15 @@ class CurvesROIDockWidget(qt.QDockWidget): self.calculateROIs = self.calculateRois = self.roiWidget.calculateRois self.setRois = self.roiWidget.setRois self.getRois = self.roiWidget.getRois + self.roiWidget.sigROISignal.connect(self._forwardSigROISignal) - self.currentROI = self.roiWidget.currentROI self.layout().setContentsMargins(0, 0, 0, 0) self.setWidget(self.roiWidget) + self.setAreaVisible = self.roiWidget.roiTable.setAreaVisible + self.setCountsVisible = self.roiWidget.roiTable.setCountsVisible + def _forwardSigROISignal(self, ddict): # emit deprecated signal for backward compatibility (silx < 0.7) self.sigROISignal.emit(ddict) @@ -1042,3 +1560,7 @@ class CurvesROIDockWidget(qt.QDockWidget): """ self.raise_() qt.QDockWidget.showEvent(self, event) + + @property + def currentROI(self): + return self.roiWidget.currentRoi diff --git a/silx/gui/plot/MaskToolsWidget.py b/silx/gui/plot/MaskToolsWidget.py index 990e479..9d727e7 100644 --- a/silx/gui/plot/MaskToolsWidget.py +++ b/silx/gui/plot/MaskToolsWidget.py @@ -35,7 +35,7 @@ from __future__ import division __authors__ = ["T. Vincent", "P. Knobel"] __license__ = "MIT" -__date__ = "29/08/2018" +__date__ = "15/02/2019" import os @@ -57,10 +57,7 @@ from .. import qt from silx.third_party.EdfFile import EdfFile from silx.third_party.TiffIO import TiffIO -try: - import fabio -except ImportError: - fabio = None +import fabio _logger = logging.getLogger(__name__) @@ -135,8 +132,6 @@ class ImageMask(BaseMask): self._saveToHdf5(filename, self.getMask(copy=False)) elif kind == 'msk': - if fabio is None: - raise ImportError("Fit2d mask files can't be written: Fabio module is not available") try: data = self.getMask(copy=False) image = fabio.fabioimage.FabioImage(data=data) @@ -250,6 +245,19 @@ class ImageMask(BaseMask): rows, cols = shapes.circle_fill(crow, ccol, radius) self.updatePoints(level, rows, cols, mask) + def updateEllipse(self, level, crow, ccol, radius_r, radius_c, mask=True): + """Mask/Unmask an ellipse of the given mask level. + + :param int level: Mask level to update. + :param int crow: Row of the center of the ellipse + :param int ccol: Column of the center of the ellipse + :param float radius_r: Radius of the ellipse in the row + :param float radius_c: Radius of the ellipse in the column + :param bool mask: True to mask (default), False to unmask. + """ + rows, cols = shapes.ellipse_fill(crow, ccol, radius_r, radius_c) + self.updatePoints(level, rows, cols, mask) + def updateLine(self, level, row0, col0, row1, col1, width, mask=True): """Mask/Unmask a line of the given mask level. @@ -300,6 +308,10 @@ class MaskToolsWidget(BaseMaskToolsWidget): _logger.error('Not an image, shape: %d', len(mask.shape)) return None + # Handle mask with single level + if self.multipleMasks() == 'single': + mask = numpy.array(mask != 0, dtype=numpy.uint8) + # if mask has not changed, do nothing if numpy.array_equal(mask, self.getSelectionMask()): return mask.shape @@ -501,8 +513,6 @@ class MaskToolsWidget(BaseMaskToolsWidget): _logger.debug("Backtrace", exc_info=True) raise e elif extension == "msk": - if fabio is None: - raise ImportError("Fit2d mask files can't be read: Fabio module is not available") try: mask = fabio.open(filename).data except Exception as e: @@ -682,41 +692,51 @@ class MaskToolsWidget(BaseMaskToolsWidget): level = self.levelSpinBox.value() - if (self._drawingMode == 'rectangle' and - event['event'] == 'drawingFinished'): - # Convert from plot to array coords - doMask = self._isMasking() - ox, oy = self._origin - sx, sy = self._scale - - height = int(abs(event['height'] / sy)) - width = int(abs(event['width'] / sx)) - - row = int((event['y'] - oy) / sy) - if sy < 0: - row -= height - - col = int((event['x'] - ox) / sx) - if sx < 0: - col -= width - - self._mask.updateRectangle( - level, - row=row, - col=col, - height=height, - width=width, - mask=doMask) - self._mask.commit() + if self._drawingMode == 'rectangle': + if event['event'] == 'drawingFinished': + # Convert from plot to array coords + doMask = self._isMasking() + ox, oy = self._origin + sx, sy = self._scale + + height = int(abs(event['height'] / sy)) + width = int(abs(event['width'] / sx)) + + row = int((event['y'] - oy) / sy) + if sy < 0: + row -= height + + col = int((event['x'] - ox) / sx) + if sx < 0: + col -= width + + self._mask.updateRectangle( + level, + row=row, + col=col, + height=height, + width=width, + mask=doMask) + self._mask.commit() - elif (self._drawingMode == 'polygon' and - event['event'] == 'drawingFinished'): - doMask = self._isMasking() - # Convert from plot to array coords - vertices = (event['points'] - self._origin) / self._scale - vertices = vertices.astype(numpy.int)[:, (1, 0)] # (row, col) - self._mask.updatePolygon(level, vertices, doMask) - self._mask.commit() + elif self._drawingMode == 'ellipse': + if event['event'] == 'drawingFinished': + doMask = self._isMasking() + # Convert from plot to array coords + center = (event['points'][0] - self._origin) / self._scale + size = event['points'][1] / self._scale + center = center.astype(numpy.int) # (row, col) + self._mask.updateEllipse(level, center[1], center[0], size[1], size[0], doMask) + self._mask.commit() + + elif self._drawingMode == 'polygon': + if event['event'] == 'drawingFinished': + doMask = self._isMasking() + # Convert from plot to array coords + vertices = (event['points'] - self._origin) / self._scale + vertices = vertices.astype(numpy.int)[:, (1, 0)] # (row, col) + self._mask.updatePolygon(level, vertices, doMask) + self._mask.commit() elif self._drawingMode == 'pencil': doMask = self._isMasking() @@ -743,6 +763,8 @@ class MaskToolsWidget(BaseMaskToolsWidget): self._lastPencilPos = None else: self._lastPencilPos = row, col + else: + _logger.error("Drawing mode %s unsupported", self._drawingMode) def _loadRangeFromColormapTriggered(self): """Set range from active image colormap range""" diff --git a/silx/gui/plot/PlotInteraction.py b/silx/gui/plot/PlotInteraction.py index 356bda6..27abd10 100644 --- a/silx/gui/plot/PlotInteraction.py +++ b/silx/gui/plot/PlotInteraction.py @@ -26,7 +26,7 @@ __authors__ = ["T. Vincent"] __license__ = "MIT" -__date__ = "24/04/2018" +__date__ = "15/02/2019" import math @@ -96,10 +96,18 @@ class _PlotInteraction(object): fill = fill != 'none' # TODO not very nice either + greyed = colors.greyed(color)[0] + if greyed < 0.5: + color2 = "white" + else: + color2 = "black" + self.plot.addItem(points[:, 0], points[:, 1], legend=legend, replace=False, - shape=shape, color=color, fill=fill, + shape=shape, fill=fill, + color=color, linebgcolor=color2, linestyle="--", overlay=True) + self._selectionAreas.add(legend) def resetSelectionArea(self): @@ -274,6 +282,8 @@ class Zoom(_ZoomOnWheel): and zoom on mouse wheel. """ + SURFACE_THRESHOLD = 5 + def __init__(self, plot, color): self.color = color @@ -347,35 +357,44 @@ class Zoom(_ZoomOnWheel): self.setSelectionArea(corners, fill='none', color=self.color) - def endDrag(self, startPos, endPos): - x0, y0 = startPos - x1, y1 = endPos + def _zoom(self, x0, y0, x1, y1): + """Zoom to the rectangle view x0,y0 x1,y1. + """ + startPos = x0, y0 + endPos = x1, y1 + + # Store current zoom state in stack + self.plot.getLimitsHistory().push() - if x0 != x1 or y0 != y1: # Avoid empty zoom area - # Store current zoom state in stack - self.plot.getLimitsHistory().push() + if self.plot.isKeepDataAspectRatio(): + x0, y0, x1, y1 = self._areaWithAspectRatio(x0, y0, x1, y1) + + # Convert to data space and set limits + x0, y0 = self.plot.pixelToData(x0, y0, check=False) - if self.plot.isKeepDataAspectRatio(): - x0, y0, x1, y1 = self._areaWithAspectRatio(x0, y0, x1, y1) + dataPos = self.plot.pixelToData( + startPos[0], startPos[1], axis="right", check=False) + y2_0 = dataPos[1] - # Convert to data space and set limits - x0, y0 = self.plot.pixelToData(x0, y0, check=False) + x1, y1 = self.plot.pixelToData(x1, y1, check=False) - dataPos = self.plot.pixelToData( - startPos[0], startPos[1], axis="right", check=False) - y2_0 = dataPos[1] + dataPos = self.plot.pixelToData( + endPos[0], endPos[1], axis="right", check=False) + y2_1 = dataPos[1] - x1, y1 = self.plot.pixelToData(x1, y1, check=False) + xMin, xMax = min(x0, x1), max(x0, x1) + yMin, yMax = min(y0, y1), max(y0, y1) + y2Min, y2Max = min(y2_0, y2_1), max(y2_0, y2_1) - dataPos = self.plot.pixelToData( - endPos[0], endPos[1], axis="right", check=False) - y2_1 = dataPos[1] + self.plot.setLimits(xMin, xMax, yMin, yMax, y2Min, y2Max) - xMin, xMax = min(x0, x1), max(x0, x1) - yMin, yMax = min(y0, y1), max(y0, y1) - y2Min, y2Max = min(y2_0, y2_1), max(y2_0, y2_1) + def endDrag(self, startPos, endPos): + x0, y0 = startPos + x1, y1 = endPos - self.plot.setLimits(xMin, xMax, yMin, yMax, y2Min, y2Max) + if abs(x0 - x1) * abs(y0 - y1) >= self.SURFACE_THRESHOLD: + # Avoid empty zoom area + self._zoom(x0, y0, x1, y1) self.resetSelectionArea() @@ -544,7 +563,6 @@ class SelectPolygon(Select): return self.DRAG_THRESHOLD_DIST * ratio - class Select2Points(Select): """Base class for drawing selection based on 2 input points.""" class Idle(State): @@ -603,6 +621,87 @@ class Select2Points(Select): self.cancelSelect() +class SelectEllipse(Select2Points): + """Drawing ellipse selection area state machine.""" + def beginSelect(self, x, y): + self.center = self.plot.pixelToData(x, y) + assert self.center is not None + + def _getEllipseSize(self, pointInEllipse): + """ + Returns the size from the center to the bounding box of the ellipse. + + :param Tuple[float,float] pointInEllipse: A point of the ellipse + :rtype: Tuple[float,float] + """ + x = abs(self.center[0] - pointInEllipse[0]) + y = abs(self.center[1] - pointInEllipse[1]) + if x == 0 or y == 0: + return x, y + # Ellipse definitions + # e: eccentricity + # a: length fron center to bounding box width + # b: length fron center to bounding box height + # Equations + # (1) b < a + # (2) For x,y a point in the ellipse: x^2/a^2 + y^2/b^2 = 1 + # (3) b = a * sqrt(1-e^2) + # (4) e = sqrt(a^2 - b^2) / a + + # The eccentricity of the ellipse defined by a,b=x,y is the same + # as the one we are searching for. + swap = x < y + if swap: + x, y = y, x + e = math.sqrt(x**2 - y**2) / x + # From (2) using (3) to replace b + # a^2 = x^2 + y^2 / (1-e^2) + a = math.sqrt(x**2 + y**2 / (1.0 - e**2)) + b = a * math.sqrt(1 - e**2) + if swap: + a, b = b, a + return a, b + + def select(self, x, y): + dataPos = self.plot.pixelToData(x, y) + assert dataPos is not None + width, height = self._getEllipseSize(dataPos) + + # Circle used for circle preview + nbpoints = 27. + angles = numpy.arange(nbpoints) * numpy.pi * 2.0 / nbpoints + circleShape = numpy.array((numpy.cos(angles) * width, + numpy.sin(angles) * height)).T + circleShape += numpy.array(self.center) + + self.setSelectionArea(circleShape, + shape="polygon", + fill='hatch', + color=self.color) + + eventDict = prepareDrawingSignal('drawingProgress', + 'ellipse', + (self.center, (width, height)), + self.parameters) + self.plot.notify(**eventDict) + + def endSelect(self, x, y): + self.resetSelectionArea() + + dataPos = self.plot.pixelToData(x, y) + assert dataPos is not None + width, height = self._getEllipseSize(dataPos) + + eventDict = prepareDrawingSignal('drawingFinished', + 'ellipse', + (self.center, (width, height)), + self.parameters) + self.plot.notify(**eventDict) + + def cancelSelect(self): + self.resetSelectionArea() + + class SelectRectangle(Select2Points): """Drawing rectangle selection area state machine.""" def beginSelect(self, x, y): @@ -1488,6 +1587,7 @@ class PlotInteraction(object): _DRAW_MODES = { 'polygon': SelectPolygon, 'rectangle': SelectRectangle, + 'ellipse': SelectEllipse, 'line': SelectLine, 'vline': SelectVLine, 'hline': SelectHLine, diff --git a/silx/gui/plot/PlotToolButtons.py b/silx/gui/plot/PlotToolButtons.py index f6291b5..bf6b8ce 100644 --- a/silx/gui/plot/PlotToolButtons.py +++ b/silx/gui/plot/PlotToolButtons.py @@ -1,7 +1,7 @@ # coding: utf-8 # /*########################################################################## # -# Copyright (c) 2004-2018 European Synchrotron Radiation Facility +# Copyright (c) 2004-2019 European Synchrotron Radiation Facility # # Permission is hereby granted, free of charge, to any person obtaining a copy # of this software and associated documentation files (the "Software"), to deal @@ -45,6 +45,7 @@ import weakref from .. import icons from .. import qt +from ... import config from .items import SymbolMixIn @@ -250,24 +251,24 @@ class ProfileOptionToolButton(PlotToolButton): self.STATE = {} # is down self.STATE['sum', "icon"] = icons.getQIcon('math-sigma') - self.STATE['sum', "state"] = "compute profile sum" - self.STATE['sum', "action"] = "compute profile sum" + self.STATE['sum', "state"] = "Compute profile sum" + self.STATE['sum', "action"] = "Compute profile sum" # keep ration self.STATE['mean', "icon"] = icons.getQIcon('math-mean') - self.STATE['mean', "state"] = "compute profile mean" - self.STATE['mean', "action"] = "compute profile mean" + self.STATE['mean', "state"] = "Compute profile mean" + self.STATE['mean', "action"] = "Compute profile mean" - sumAction = self._createAction('sum') - sumAction.triggered.connect(self.setSum) - sumAction.setIconVisibleInMenu(True) + self.sumAction = self._createAction('sum') + self.sumAction.triggered.connect(self.setSum) + self.sumAction.setIconVisibleInMenu(True) - meanAction = self._createAction('mean') - meanAction.triggered.connect(self.setMean) - meanAction.setIconVisibleInMenu(True) + self.meanAction = self._createAction('mean') + self.meanAction.triggered.connect(self.setMean) + self.meanAction.setIconVisibleInMenu(True) menu = qt.QMenu(self) - menu.addAction(sumAction) - menu.addAction(meanAction) + menu.addAction(self.sumAction) + menu.addAction(self.meanAction) self.setMenu(menu) self.setPopupMode(qt.QToolButton.InstantPopup) self.setMean() @@ -370,7 +371,7 @@ class SymbolToolButton(PlotToolButton): slider = qt.QSlider(qt.Qt.Horizontal) slider.setRange(1, 20) - slider.setValue(SymbolMixIn._DEFAULT_SYMBOL_SIZE) + slider.setValue(config.DEFAULT_PLOT_SYMBOL_SIZE) slider.setTracking(False) slider.valueChanged.connect(self._sizeChanged) widgetAction = qt.QWidgetAction(menu) diff --git a/silx/gui/plot/PlotWidget.py b/silx/gui/plot/PlotWidget.py index e023a21..cfe39fa 100644 --- a/silx/gui/plot/PlotWidget.py +++ b/silx/gui/plot/PlotWidget.py @@ -1,7 +1,7 @@ # coding: utf-8 # /*########################################################################## # -# Copyright (c) 2004-2018 European Synchrotron Radiation Facility +# Copyright (c) 2004-2019 European Synchrotron Radiation Facility # # Permission is hereby granted, free of charge, to any person obtaining a copy # of this software and associated documentation files (the "Software"), to deal @@ -31,7 +31,7 @@ from __future__ import division __authors__ = ["V.A. Sole", "T. Vincent"] __license__ = "MIT" -__date__ = "12/10/2018" +__date__ = "21/12/2018" from collections import OrderedDict, namedtuple @@ -44,7 +44,6 @@ import numpy import silx from silx.utils.weakref import WeakMethodProxy -from silx.utils import deprecation from silx.utils.property import classproperty from silx.utils.deprecation import deprecated # Import matplotlib backend here to init matplotlib our way @@ -99,7 +98,7 @@ class PlotWidget(qt.QMainWindow): # TODO: Can be removed for silx 0.10 @classproperty - @deprecation.deprecated(replacement="silx.config.DEFAULT_PLOT_BACKEND", since_version="0.8", skip_backtrace_count=2) + @deprecated(replacement="silx.config.DEFAULT_PLOT_BACKEND", since_version="0.8", skip_backtrace_count=2) def DEFAULT_BACKEND(self): """Class attribute setting the default backend for all instances.""" return silx.config.DEFAULT_PLOT_BACKEND @@ -193,21 +192,12 @@ class PlotWidget(qt.QMainWindow): It provides the visible state. """ - def __init__(self, parent=None, backend=None, - legends=False, callback=None, **kw): + def __init__(self, parent=None, backend=None): self._autoreplot = False self._dirty = False self._cursorInPlot = False self.__muteActiveItemChanged = False - if kw: - _logger.warning( - 'deprecated: __init__ extra arguments: %s', str(kw)) - if legends: - _logger.warning('deprecated: __init__ legend argument') - if callback: - _logger.warning('deprecated: __init__ callback argument') - self._panWithArrowKeys = True self._viewConstrains = None @@ -218,27 +208,8 @@ class PlotWidget(qt.QMainWindow): else: self.setWindowTitle('PlotWidget') - if backend is None: - backend = silx.config.DEFAULT_PLOT_BACKEND - - if hasattr(backend, "__call__"): - self._backend = backend(self, parent) - - elif hasattr(backend, "lower"): - lowerCaseString = backend.lower() - if lowerCaseString in ("matplotlib", "mpl"): - backendClass = BackendMatplotlibQt - elif lowerCaseString in ('gl', 'opengl'): - from .backends.BackendOpenGL import BackendOpenGL - backendClass = BackendOpenGL - elif lowerCaseString == 'none': - from .backends.BackendBase import BackendBase as backendClass - else: - raise ValueError("Backend not supported %s" % backend) - self._backend = backendClass(self, parent) - - else: - raise ValueError("Backend not supported %s" % str(backend)) + self._backend = None + self._setBackend(backend) self.setCallback() # set _callback @@ -258,6 +229,12 @@ class PlotWidget(qt.QMainWindow): self._activeLegend = {'curve': None, 'image': None, 'scatter': None} + # plot colors (updated later to sync backend) + self._foregroundColor = 0., 0., 0., 1. + self._gridColor = .7, .7, .7, 1. + self._backgroundColor = 1., 1., 1., 1. + self._dataBackgroundColor = None + # default properties self._cursorConfiguration = None @@ -275,7 +252,7 @@ class PlotWidget(qt.QMainWindow): self.setDefaultColormap() # Init default colormap - self.setDefaultPlotPoints(False) + self.setDefaultPlotPoints(silx.config.DEFAULT_PLOT_CURVE_SYMBOL_MODE) self.setDefaultPlotLines(True) self._limitsHistory = LimitsHistory(self) @@ -306,9 +283,41 @@ class PlotWidget(qt.QMainWindow): self.setGraphYLimits(0., 100., axis='right') self.setGraphYLimits(0., 100., axis='left') + # Sync backend colors with default ones + self._foregroundColorsUpdated() + self._backgroundColorsUpdated() + + def _setBackend(self, backend): + """Setup a new backend""" + assert(self._backend is None) + + if backend is None: + backend = silx.config.DEFAULT_PLOT_BACKEND + + if hasattr(backend, "__call__"): + backend = backend(self, self) + + elif hasattr(backend, "lower"): + lowerCaseString = backend.lower() + if lowerCaseString in ("matplotlib", "mpl"): + backendClass = BackendMatplotlibQt + elif lowerCaseString in ('gl', 'opengl'): + from .backends.BackendOpenGL import BackendOpenGL + backendClass = BackendOpenGL + elif lowerCaseString == 'none': + from .backends.BackendBase import BackendBase as backendClass + else: + raise ValueError("Backend not supported %s" % backend) + backend = backendClass(self, self) + + else: + raise ValueError("Backend not supported %s" % str(backend)) + + self._backend = backend + # TODO: Can be removed for silx 0.10 @staticmethod - @deprecation.deprecated(replacement="silx.config.DEFAULT_PLOT_BACKEND", since_version="0.8", skip_backtrace_count=2) + @deprecated(replacement="silx.config.DEFAULT_PLOT_BACKEND", since_version="0.8", skip_backtrace_count=2) def setDefaultBackend(backend): """Set system wide default plot backend. @@ -349,6 +358,119 @@ class PlotWidget(qt.QMainWindow): if self._autoreplot and not wasDirty and self.isVisible(): self._backend.postRedisplay() + def _foregroundColorsUpdated(self): + """Handle change of foreground/grid color""" + if self._gridColor is None: + gridColor = self._foregroundColor + else: + gridColor = self._gridColor + self._backend.setForegroundColors( + self._foregroundColor, gridColor) + self._setDirtyPlot() + + def getForegroundColor(self): + """Returns the RGBA colors used to display the foreground of this widget + + :rtype: qt.QColor + """ + return qt.QColor.fromRgbF(*self._foregroundColor) + + def setForegroundColor(self, color): + """Set the foreground color of this widget. + + :param Union[List[int],List[float],QColor] color: + The new RGB(A) color. + """ + color = colors.rgba(color) + if self._foregroundColor != color: + self._foregroundColor = color + self._foregroundColorsUpdated() + + def getGridColor(self): + """Returns the RGBA colors used to display the grid lines + + An invalid QColor is returned if there is no grid color, + in which case the foreground color is used. + + :rtype: qt.QColor + """ + if self._gridColor is None: + return qt.QColor() # An invalid color + else: + return qt.QColor.fromRgbF(*self._gridColor) + + def setGridColor(self, color): + """Set the grid lines color + + :param Union[List[int],List[float],QColor,None] color: + The new RGB(A) color. + """ + if isinstance(color, qt.QColor) and not color.isValid(): + color = None + if color is not None: + color = colors.rgba(color) + if self._gridColor != color: + self._gridColor = color + self._foregroundColorsUpdated() + + def _backgroundColorsUpdated(self): + """Handle change of background/data background color""" + if self._dataBackgroundColor is None: + dataBGColor = self._backgroundColor + else: + dataBGColor = self._dataBackgroundColor + self._backend.setBackgroundColors( + self._backgroundColor, dataBGColor) + self._setDirtyPlot() + + def getBackgroundColor(self): + """Returns the RGBA colors used to display the background of this widget. + + :rtype: qt.QColor + """ + return qt.QColor.fromRgbF(*self._backgroundColor) + + def setBackgroundColor(self, color): + """Set the background color of this widget. + + :param Union[List[int],List[float],QColor] color: + The new RGB(A) color. + """ + color = colors.rgba(color) + if self._backgroundColor != color: + self._backgroundColor = color + self._backgroundColorsUpdated() + + def getDataBackgroundColor(self): + """Returns the RGBA colors used to display the background of the plot + view displaying the data. + + An invalid QColor is returned if there is no data background color. + + :rtype: qt.QColor + """ + if self._dataBackgroundColor is None: + # An invalid color + return qt.QColor() + else: + return qt.QColor.fromRgbF(*self._dataBackgroundColor) + + def setDataBackgroundColor(self, color): + """Set the background color of this widget. + + Set to None or an invalid QColor to use the background color. + + :param Union[List[int],List[float],QColor,None] color: + The new RGB(A) color. + """ + if isinstance(color, qt.QColor) and not color.isValid(): + color = None + if color is not None: + color = colors.rgba(color) + if self._dataBackgroundColor != color: + self._dataBackgroundColor = color + self._backgroundColorsUpdated() + def showEvent(self, event): if self._autoreplot and self._dirty: self._backend.postRedisplay() @@ -528,13 +650,13 @@ class PlotWidget(qt.QMainWindow): # This value is used when curve is updated either internally or by user. def addCurve(self, x, y, legend=None, info=None, - replace=False, replot=None, + replace=False, color=None, symbol=None, linewidth=None, linestyle=None, xlabel=None, ylabel=None, yaxis=None, xerror=None, yerror=None, z=None, selectable=None, fill=None, resetzoom=True, - histogram=None, copy=True, **kw): + histogram=None, copy=True): """Add a 1D curve given by x an y to the graph. Curves are uniquely identified by their legend. @@ -617,15 +739,6 @@ class PlotWidget(qt.QMainWindow): False to use provided arrays. :returns: The key string identify this curve """ - # Deprecation warnings - if replot is not None: - _logger.warning( - 'addCurve deprecated replot argument, use resetzoom instead') - resetzoom = replot and resetzoom - - if kw: - _logger.warning('addCurve: deprecated extra arguments') - # This is an histogram, use addHistogram if histogram is not None: histoLegend = self.addHistogram(histogram=y, @@ -825,13 +938,13 @@ class PlotWidget(qt.QMainWindow): return legend def addImage(self, data, legend=None, info=None, - replace=False, replot=None, - xScale=None, yScale=None, z=None, + replace=False, + z=None, selectable=None, draggable=None, colormap=None, pixmap=None, xlabel=None, ylabel=None, origin=None, scale=None, - resetzoom=True, copy=True, **kw): + resetzoom=True, copy=True): """Add a 2D dataset or an image to the plot. It displays either an array of data using a colormap or a RGB(A) image. @@ -883,28 +996,6 @@ class PlotWidget(qt.QMainWindow): False to use provided arrays. :returns: The key string identify this image """ - # Deprecation warnings - if xScale is not None or yScale is not None: - _logger.warning( - 'addImage deprecated xScale and yScale arguments,' - 'use origin, scale arguments instead.') - if origin is None and scale is None: - origin = xScale[0], yScale[0] - scale = xScale[1], yScale[1] - else: - _logger.warning( - 'addCurve: xScale, yScale and origin, scale arguments' - ' are conflicting. xScale and yScale are ignored.' - ' Use only origin, scale arguments.') - - if replot is not None: - _logger.warning( - 'addImage deprecated replot argument, use resetzoom instead') - resetzoom = replot and resetzoom - - if kw: - _logger.warning('addImage: deprecated extra arguments') - legend = "Unnamed Image 1.1" if legend is None else str(legend) # Check if image was previously active @@ -1090,7 +1181,8 @@ class PlotWidget(qt.QMainWindow): def addItem(self, xdata, ydata, legend=None, info=None, replace=False, shape="polygon", color='black', fill=True, - overlay=False, z=None, **kw): + overlay=False, z=None, linestyle="-", linewidth=1.0, + linebgcolor=None): """Add an item (i.e. a shape) to the plot. Items are uniquely identified by their legend. @@ -1114,13 +1206,23 @@ class PlotWidget(qt.QMainWindow): This allows for rendering optimization if this item is changed often. :param int z: Layer on which to draw the item (default: 2) + :param str linestyle: Style of the line. + Only relevant for line markers where X or Y is None. + Value in: + + - ' ' no line + - '-' solid line + - '--' dashed line + - '-.' dash-dot line + - ':' dotted line + :param float linewidth: Width of the line. + Only relevant for line markers where X or Y is None. + :param str linebgcolor: Background color of the line, e.g., 'blue', 'b', + '#FF0000'. It is used to draw dotted line using a second color. :returns: The key string identify this item """ # expected to receive the same parameters as the signal - if kw: - _logger.warning('addItem deprecated parameters: %s', str(kw)) - legend = "Unnamed Item 1.1" if legend is None else str(legend) z = int(z) if z is not None else 2 @@ -1138,6 +1240,9 @@ class PlotWidget(qt.QMainWindow): item.setOverlay(overlay) item.setZValue(z) item.setPoints(numpy.array((xdata, ydata)).T) + item.setLineStyle(linestyle) + item.setLineWidth(linewidth) + item.setLineBgColor(linebgcolor) self._add(item) @@ -1148,8 +1253,7 @@ class PlotWidget(qt.QMainWindow): color=None, selectable=False, draggable=False, - constraint=None, - **kw): + constraint=None): """Add a vertical line marker to the plot. Markers are uniquely identified by their legend. @@ -1177,10 +1281,6 @@ class PlotWidget(qt.QMainWindow): and that returns the filtered coordinates. :return: The key string identify this marker """ - if kw: - _logger.warning( - 'addXMarker deprecated extra parameters: %s', str(kw)) - return self._addMarker(x=x, y=None, legend=legend, text=text, color=color, selectable=selectable, draggable=draggable, @@ -1192,8 +1292,7 @@ class PlotWidget(qt.QMainWindow): color=None, selectable=False, draggable=False, - constraint=None, - **kw): + constraint=None): """Add a horizontal line marker to the plot. Markers are uniquely identified by their legend. @@ -1221,10 +1320,6 @@ class PlotWidget(qt.QMainWindow): and that returns the filtered coordinates. :return: The key string identify this marker """ - if kw: - _logger.warning( - 'addYMarker deprecated extra parameters: %s', str(kw)) - return self._addMarker(x=None, y=y, legend=legend, text=text, color=color, selectable=selectable, draggable=draggable, @@ -1236,8 +1331,7 @@ class PlotWidget(qt.QMainWindow): selectable=False, draggable=False, symbol='+', - constraint=None, - **kw): + constraint=None): """Add a point marker to the plot. Markers are uniquely identified by their legend. @@ -1277,10 +1371,6 @@ class PlotWidget(qt.QMainWindow): and that returns the filtered coordinates. :return: The key string identify this marker """ - if kw: - _logger.warning( - 'addMarker deprecated extra parameters: %s', str(kw)) - if x is None: xmin, xmax = self._xAxis.getLimits() x = 0.5 * (xmax + xmin) @@ -1368,7 +1458,7 @@ class PlotWidget(qt.QMainWindow): curve = self._getItem('curve', legend) return curve is not None and not curve.isVisible() - def hideCurve(self, legend, flag=True, replot=None): + def hideCurve(self, legend, flag=True): """Show/Hide the curve associated to legend. Even when hidden, the curve is kept in the list of curves. @@ -1376,9 +1466,6 @@ class PlotWidget(qt.QMainWindow): :param str legend: The legend associated to the curve to be hidden :param bool flag: True (default) to hide the curve, False to show it """ - if replot is not None: - _logger.warning('hideCurve deprecated replot parameter') - curve = self._getItem('curve', legend) if curve is None: _logger.warning('Curve not in plot: %s', legend) @@ -1660,16 +1747,13 @@ class PlotWidget(qt.QMainWindow): return self._getActiveItem(kind='curve', just_legend=just_legend) - def setActiveCurve(self, legend, replot=None): + def setActiveCurve(self, legend): """Make the curve associated to legend the active curve. :param legend: The legend associated to the curve or None to have no active curve. :type legend: str or None """ - if replot is not None: - _logger.warning('setActiveCurve deprecated replot parameter') - if not self.isActiveCurveHandling(): return if legend is None and self.getActiveCurveSelectionMode() == "legacy": @@ -1723,15 +1807,12 @@ class PlotWidget(qt.QMainWindow): """ return self._getActiveItem(kind='image', just_legend=just_legend) - def setActiveImage(self, legend, replot=None): + def setActiveImage(self, legend): """Make the image associated to legend the active image. :param str legend: The legend associated to the image or None to have no active image. """ - if replot is not None: - _logger.warning('setActiveImage deprecated replot parameter') - return self._setActiveItem(kind='image', legend=legend) def _getActiveItem(self, kind, just_legend=False): @@ -2028,14 +2109,12 @@ class PlotWidget(qt.QMainWindow): """ return self._backend.getGraphXLimits() - def setGraphXLimits(self, xmin, xmax, replot=None): + def setGraphXLimits(self, xmin, xmax): """Set the graph X (bottom) limits. :param float xmin: minimum bottom axis value :param float xmax: maximum bottom axis value """ - if replot is not None: - _logger.warning('setGraphXLimits deprecated replot parameter') self._xAxis.setLimits(xmin, xmax) def getGraphYLimits(self, axis='left'): @@ -2049,7 +2128,7 @@ class PlotWidget(qt.QMainWindow): yAxis = self._yAxis if axis == 'left' else self._yRightAxis return yAxis.getLimits() - def setGraphYLimits(self, ymin, ymax, axis='left', replot=None): + def setGraphYLimits(self, ymin, ymax, axis='left'): """Set the graph Y limits. :param float ymin: minimum bottom axis value @@ -2057,8 +2136,6 @@ class PlotWidget(qt.QMainWindow): :param str axis: The axis for which to get the limits: Either 'left' or 'right' """ - if replot is not None: - _logger.warning('setGraphYLimits deprecated replot parameter') assert axis in ('left', 'right') yAxis = self._yAxis if axis == 'left' else self._yRightAxis return yAxis.setLimits(ymin, ymax) @@ -2192,36 +2269,6 @@ class PlotWidget(qt.QMainWindow): def _isAxesDisplayed(self): return self._backend.isAxesDisplayed() - @property - @deprecated(since_version='0.6') - def sigSetYAxisInverted(self): - """Signal emitted when Y axis orientation has changed""" - return self._yAxis.sigInvertedChanged - - @property - @deprecated(since_version='0.6') - def sigSetXAxisLogarithmic(self): - """Signal emitted when X axis scale has changed""" - return self._xAxis._sigLogarithmicChanged - - @property - @deprecated(since_version='0.6') - def sigSetYAxisLogarithmic(self): - """Signal emitted when Y axis scale has changed""" - return self._yAxis._sigLogarithmicChanged - - @property - @deprecated(since_version='0.6') - def sigSetXAxisAutoScale(self): - """Signal emitted when X axis autoscale has changed""" - return self._xAxis.sigAutoScaleChanged - - @property - @deprecated(since_version='0.6') - def sigSetYAxisAutoScale(self): - """Signal emitted when Y axis autoscale has changed""" - return self._yAxis.sigAutoScaleChanged - def setYAxisInverted(self, flag=True): """Set the Y axis orientation. @@ -2290,6 +2337,8 @@ class PlotWidget(qt.QMainWindow): :param bool flag: True to respect data aspect ratio """ flag = bool(flag) + if flag == self.isKeepDataAspectRatio(): + return self._backend.setKeepDataAspectRatio(flag=flag) self._setDirtyPlot() self._forceResetZoom() @@ -2323,8 +2372,8 @@ class PlotWidget(qt.QMainWindow): # Defaults def isDefaultPlotPoints(self): - """Return True if default Curve symbol is 'o', False for no symbol.""" - return self._defaultPlotPoints == 'o' + """Return True if the default Curve symbol is set and False if not.""" + return self._defaultPlotPoints == silx.config.DEFAULT_PLOT_SYMBOL def setDefaultPlotPoints(self, flag): """Set the default symbol of all curves. @@ -2334,7 +2383,7 @@ class PlotWidget(qt.QMainWindow): :param bool flag: True to use 'o' as the default curve symbol, False to use no symbol. """ - self._defaultPlotPoints = 'o' if flag else '' + self._defaultPlotPoints = silx.config.DEFAULT_PLOT_SYMBOL if flag else '' # Reset symbol of all curves curves = self.getAllCurves(just_legend=False, withhidden=True) @@ -2510,7 +2559,7 @@ class PlotWidget(qt.QMainWindow): elif ddict['event'] == 'mouseClicked' and ddict['button'] == 'left': self.setActiveCurve(None) - def saveGraph(self, filename, fileFormat=None, dpi=None, **kw): + def saveGraph(self, filename, fileFormat=None, dpi=None): """Save a snapshot of the plot. Supported file formats depends on the backend in use. @@ -2523,9 +2572,6 @@ class PlotWidget(qt.QMainWindow): :param str fileFormat: String specifying the format :return: False if cannot save the plot, True otherwise """ - if kw: - _logger.warning('Extra parameters ignored: %s', str(kw)) - if fileFormat is None: if not hasattr(filename, 'lower'): _logger.warning( @@ -3080,149 +3126,3 @@ class PlotWidget(qt.QMainWindow): # Only call base class implementation when key is not handled. # See QWidget.keyPressEvent for details. super(PlotWidget, self).keyPressEvent(event) - - # Deprecated # - - def isDrawModeEnabled(self): - """Deprecated, use :meth:`getInteractiveMode` instead. - - Return True if the current interactive state is drawing.""" - _logger.warning( - 'isDrawModeEnabled deprecated, use getInteractiveMode instead') - return self.getInteractiveMode()['mode'] == 'draw' - - def setDrawModeEnabled(self, flag=True, shape='polygon', label=None, - color=None, **kwargs): - """Deprecated, use :meth:`setInteractiveMode` instead. - - Set the drawing mode if flag is True and its parameters. - - If flag is False, only item selection is enabled. - - Warning: Zoom and drawing are not compatible and cannot be enabled - simultaneously. - - :param bool flag: True to enable drawing and disable zoom and select. - :param str shape: Type of item to be drawn in: - hline, vline, rectangle, polygon (default) - :param str label: Associated text for identifying draw signals - :param color: The color to use to draw the selection area - :type color: string ("#RRGGBB") or 4 column unsigned byte array or - one of the predefined color names defined in colors.py - """ - _logger.warning( - 'setDrawModeEnabled deprecated, use setInteractiveMode instead') - - if kwargs: - _logger.warning('setDrawModeEnabled ignores additional parameters') - - if color is None: - color = 'black' - - if flag: - self.setInteractiveMode('draw', shape=shape, - label=label, color=color) - elif self.getInteractiveMode()['mode'] == 'draw': - self.setInteractiveMode('select') - - def getDrawMode(self): - """Deprecated, use :meth:`getInteractiveMode` instead. - - Return the draw mode parameters as a dict of None. - - It returns None if the interactive mode is not a drawing mode, - otherwise, it returns a dict containing the drawing mode parameters - as provided to :meth:`setDrawModeEnabled`. - """ - _logger.warning( - 'getDrawMode deprecated, use getInteractiveMode instead') - mode = self.getInteractiveMode() - return mode if mode['mode'] == 'draw' else None - - def isZoomModeEnabled(self): - """Deprecated, use :meth:`getInteractiveMode` instead. - - Return True if the current interactive state is zooming.""" - _logger.warning( - 'isZoomModeEnabled deprecated, use getInteractiveMode instead') - return self.getInteractiveMode()['mode'] == 'zoom' - - def setZoomModeEnabled(self, flag=True, color=None): - """Deprecated, use :meth:`setInteractiveMode` instead. - - Set the zoom mode if flag is True, else item selection is enabled. - - Warning: Zoom and drawing are not compatible and cannot be enabled - simultaneously - - :param bool flag: If True, enable zoom and select mode. - :param color: The color to use to draw the selection area. - (Default: 'black') - :param color: The color to use to draw the selection area - :type color: string ("#RRGGBB") or 4 column unsigned byte array or - one of the predefined color names defined in colors.py - """ - _logger.warning( - 'setZoomModeEnabled deprecated, use setInteractiveMode instead') - if color is None: - color = 'black' - - if flag: - self.setInteractiveMode('zoom', color=color) - elif self.getInteractiveMode()['mode'] == 'zoom': - self.setInteractiveMode('select') - - def insertMarker(self, *args, **kwargs): - """Deprecated, use :meth:`addMarker` instead.""" - _logger.warning( - 'insertMarker deprecated, use addMarker instead.') - return self.addMarker(*args, **kwargs) - - def insertXMarker(self, *args, **kwargs): - """Deprecated, use :meth:`addXMarker` instead.""" - _logger.warning( - 'insertXMarker deprecated, use addXMarker instead.') - return self.addXMarker(*args, **kwargs) - - def insertYMarker(self, *args, **kwargs): - """Deprecated, use :meth:`addYMarker` instead.""" - _logger.warning( - 'insertYMarker deprecated, use addYMarker instead.') - return self.addYMarker(*args, **kwargs) - - def isActiveCurveHandlingEnabled(self): - """Deprecated, use :meth:`isActiveCurveHandling` instead.""" - _logger.warning( - 'isActiveCurveHandlingEnabled deprecated, ' - 'use isActiveCurveHandling instead.') - return self.isActiveCurveHandling() - - def enableActiveCurveHandling(self, *args, **kwargs): - """Deprecated, use :meth:`setActiveCurveHandling` instead.""" - _logger.warning( - 'enableActiveCurveHandling deprecated, ' - 'use setActiveCurveHandling instead.') - return self.setActiveCurveHandling(*args, **kwargs) - - def invertYAxis(self, *args, **kwargs): - """Deprecated, use :meth:`Axis.setInverted` instead.""" - _logger.warning('invertYAxis deprecated, ' - 'use getYAxis().setInverted instead.') - return self.getYAxis().setInverted(*args, **kwargs) - - def showGrid(self, flag=True): - """Deprecated, use :meth:`setGraphGrid` instead.""" - _logger.warning("showGrid deprecated, use setGraphGrid instead") - if flag in (0, False): - flag = None - elif flag in (1, True): - flag = 'major' - else: - flag = 'both' - return self.setGraphGrid(flag) - - def keepDataAspectRatio(self, *args, **kwargs): - """Deprecated, use :meth:`setKeepDataAspectRatio`.""" - _logger.warning('keepDataAspectRatio deprecated,' - 'use setKeepDataAspectRatio instead') - return self.setKeepDataAspectRatio(*args, **kwargs) diff --git a/silx/gui/plot/PlotWindow.py b/silx/gui/plot/PlotWindow.py index 23ea399..b44a512 100644 --- a/silx/gui/plot/PlotWindow.py +++ b/silx/gui/plot/PlotWindow.py @@ -1,7 +1,7 @@ # coding: utf-8 # /*########################################################################## # -# Copyright (c) 2004-2018 European Synchrotron Radiation Facility +# Copyright (c) 2004-2019 European Synchrotron Radiation Facility # # Permission is hereby granted, free of charge, to any person obtaining a copy # of this software and associated documentation files (the "Software"), to deal @@ -29,7 +29,7 @@ The :class:`PlotWindow` is a subclass of :class:`.PlotWidget`. __authors__ = ["V.A. Sole", "T. Vincent"] __license__ = "MIT" -__date__ = "12/10/2018" +__date__ = "21/12/2018" import collections import logging @@ -217,10 +217,8 @@ class PlotWindow(PlotWidget): # Make colorbar background white self._colorbar.setAutoFillBackground(True) - palette = self._colorbar.palette() - palette.setColor(qt.QPalette.Background, qt.Qt.white) - palette.setColor(qt.QPalette.Window, qt.Qt.white) - self._colorbar.setPalette(palette) + self._sigAxesVisibilityChanged.connect(self._updateColorBarBackground) + self._updateColorBarBackground() gridLayout = qt.QGridLayout() gridLayout.setSpacing(0) @@ -294,6 +292,43 @@ class PlotWindow(PlotWidget): for action in toolbar.actions(): self.addAction(action) + def setBackgroundColor(self, color): + super(PlotWindow, self).setBackgroundColor(color) + self._updateColorBarBackground() + + setBackgroundColor.__doc__ = PlotWidget.setBackgroundColor.__doc__ + + def setDataBackgroundColor(self, color): + super(PlotWindow, self).setDataBackgroundColor(color) + self._updateColorBarBackground() + + setDataBackgroundColor.__doc__ = PlotWidget.setDataBackgroundColor.__doc__ + + def setForegroundColor(self, color): + super(PlotWindow, self).setForegroundColor(color) + self._updateColorBarBackground() + + setForegroundColor.__doc__ = PlotWidget.setForegroundColor.__doc__ + + def _updateColorBarBackground(self): + """Update the colorbar background according to the state of the plot""" + if self._isAxesDisplayed(): + color = self.getBackgroundColor() + else: + color = self.getDataBackgroundColor() + if not color.isValid(): + # If no color defined, use the background one + color = self.getBackgroundColor() + + foreground = self.getForegroundColor() + + palette = self._colorbar.palette() + palette.setColor(qt.QPalette.Background, color) + palette.setColor(qt.QPalette.Window, color) + palette.setColor(qt.QPalette.WindowText, foreground) + palette.setColor(qt.QPalette.Text, foreground) + self._colorbar.setPalette(palette) + def getInteractiveModeToolBar(self): """Returns QToolBar controlling interactive mode. @@ -457,10 +492,6 @@ class PlotWindow(PlotWidget): return self._colorbar # getters for dock widgets - @property - @deprecated(replacement="getLegendsDockWidget()", since_version="0.4.0") - def legendsDockWidget(self): - return self.getLegendsDockWidget() def getLegendsDockWidget(self): """DockWidget with Legend panel""" @@ -470,11 +501,6 @@ class PlotWindow(PlotWidget): self.addTabbedDockWidget(self._legendsDockWidget) return self._legendsDockWidget - @property - @deprecated(replacement="getCurvesRoiWidget()", since_version="0.4.0") - def curvesROIDockWidget(self): - return self.getCurvesRoiDockWidget() - def getCurvesRoiDockWidget(self): # Undocumented for a "soft deprecation" in version 0.7.0 # (still used internally for lazy loading) @@ -496,11 +522,6 @@ class PlotWindow(PlotWidget): """ return self.getCurvesRoiDockWidget().roiWidget - @property - @deprecated(replacement="getMaskToolsDockWidget()", since_version="0.4.0") - def maskToolsDockWidget(self): - return self.getMaskToolsDockWidget() - def getMaskToolsDockWidget(self): """DockWidget with image mask panel (lazy-loaded).""" if self._maskToolsDockWidget is None: @@ -539,11 +560,6 @@ class PlotWindow(PlotWidget): def panModeAction(self): return self.getInteractiveModeToolBar().getPanModeAction() - @property - @deprecated(replacement="getConsoleAction()", since_version="0.4.0") - def consoleAction(self): - return self.getConsoleAction() - def getConsoleAction(self): """QAction handling the IPython console activation. @@ -563,11 +579,6 @@ class PlotWindow(PlotWidget): self._consoleAction.setEnabled(False) return self._consoleAction - @property - @deprecated(replacement="getCrosshairAction()", since_version="0.4.0") - def crosshairAction(self): - return self.getCrosshairAction() - def getCrosshairAction(self): """Action toggling crosshair cursor mode. @@ -577,11 +588,6 @@ class PlotWindow(PlotWidget): self._crosshairAction = actions.control.CrosshairAction(self, color='red') return self._crosshairAction - @property - @deprecated(replacement="getMaskAction()", since_version="0.4.0") - def maskAction(self): - return self.getMaskAction() - def getMaskAction(self): """QAction toggling image mask dock widget @@ -589,12 +595,6 @@ class PlotWindow(PlotWidget): """ return self.getMaskToolsDockWidget().toggleViewAction() - @property - @deprecated(replacement="getPanWithArrowKeysAction()", - since_version="0.4.0") - def panWithArrowKeysAction(self): - return self.getPanWithArrowKeysAction() - def getPanWithArrowKeysAction(self): """Action toggling pan with arrow keys. @@ -604,11 +604,6 @@ class PlotWindow(PlotWidget): self._panWithArrowKeysAction = actions.control.PanWithArrowKeysAction(self) return self._panWithArrowKeysAction - @property - @deprecated(replacement="getRoiAction()", since_version="0.4.0") - def roiAction(self): - return self.getRoiAction() - def getStatsAction(self): if self._statsAction is None: self._statsAction = qt.QAction('Curves stats', self) diff --git a/silx/gui/plot/PrintPreviewToolButton.py b/silx/gui/plot/PrintPreviewToolButton.py index b48505d..d857c18 100644 --- a/silx/gui/plot/PrintPreviewToolButton.py +++ b/silx/gui/plot/PrintPreviewToolButton.py @@ -111,10 +111,11 @@ from .. import icons from . import PlotWidget from ..widgets.PrintPreview import PrintPreviewDialog, SingletonPrintPreviewDialog from ..widgets.PrintGeometryDialog import PrintGeometryDialog +from silx.utils.deprecation import deprecated __authors__ = ["P. Knobel"] __license__ = "MIT" -__date__ = "18/07/2017" +__date__ = "20/12/2018" _logger = logging.getLogger(__name__) # _logger.setLevel(logging.DEBUG) @@ -132,19 +133,19 @@ class PrintPreviewToolButton(qt.QToolButton): if not isinstance(plot, PlotWidget): raise TypeError("plot parameter must be a PlotWidget") - self.plot = plot + self._plot = plot self.setIcon(icons.getQIcon('document-print')) printGeomAction = qt.QAction("Print geometry", self) printGeomAction.setToolTip("Define a print geometry prior to sending " "the plot to the print preview dialog") - printGeomAction.setIcon(icons.getQIcon('shape-rectangle')) # fixme: icon not displayed in menu + printGeomAction.setIcon(icons.getQIcon('shape-rectangle')) printGeomAction.triggered.connect(self._setPrintConfiguration) printPreviewAction = qt.QAction("Print preview", self) printPreviewAction.setToolTip("Send plot to the print preview dialog") - printPreviewAction.setIcon(icons.getQIcon('document-print')) # fixme: icon not displayed + printPreviewAction.setIcon(icons.getQIcon('document-print')) printPreviewAction.triggered.connect(self._plotToPrintPreview) menu = qt.QMenu(self) @@ -172,24 +173,64 @@ class PrintPreviewToolButton(qt.QToolButton): self._printPreviewDialog = PrintPreviewDialog(self.parent()) return self._printPreviewDialog + def getTitle(self): + """Implement this method to fetch the title in the plot. + + :return: Title to be printed above the plot, or None (no title added) + :rtype: str or None + """ + return None + + def getCommentAndPosition(self): + """Implement this method to fetch the legend to be printed below the + figure and its position. + + :return: Legend to be printed below the figure and its position: + "CENTER", "LEFT" or "RIGHT" + :rtype: (str, str) or (None, None) + """ + return None, None + + @property + @deprecated(since_version="0.10", + replacement="getPlot()") + def plot(self): + return self._plot + + def getPlot(self): + """Return the :class:`.PlotWidget` associated with this tool button. + + :rtype: :class:`.PlotWidget` + """ + return self._plot + def _plotToPrintPreview(self): """Grab the plot widget and send it to the print preview dialog. Make sure the print preview dialog is shown and raised.""" if not self.printPreviewDialog.ensurePrinterIsSet(): return + comment, commentPosition = self.getCommentAndPosition() + if qt.HAS_SVG: svgRenderer, viewBox = self._getSvgRendererAndViewbox() self.printPreviewDialog.addSvgItem(svgRenderer, - viewBox=viewBox) + title=self.getTitle(), + comment=comment, + commentPosition=commentPosition, + viewBox=viewBox, + keepRatio=self._printGeometry["keepAspectRatio"]) else: _logger.warning("Missing QtSvg library, using a raster image") if qt.BINDING in ["PyQt4", "PySide"]: - pixmap = qt.QPixmap.grabWidget(self.plot.centralWidget()) + pixmap = qt.QPixmap.grabWidget(self._plot.centralWidget()) else: # PyQt5 and hopefully PyQt6+ - pixmap = self.plot.centralWidget().grab() - self.printPreviewDialog.addPixmap(pixmap) + pixmap = self._plot.centralWidget().grab() + self.printPreviewDialog.addPixmap(pixmap, + title=self.getTitle(), + comment=comment, + commentPosition=commentPosition) self.printPreviewDialog.show() self.printPreviewDialog.raise_() @@ -201,7 +242,7 @@ class PrintPreviewToolButton(qt.QToolButton): and to the geometry configuration (width, height, ratio) specified by the user.""" imgData = StringIO() - assert self.plot.saveGraph(imgData, fileFormat="svg"), \ + assert self._plot.saveGraph(imgData, fileFormat="svg"), \ "Unable to save graph" imgData.flush() imgData.seek(0) @@ -310,7 +351,7 @@ class PrintPreviewToolButton(qt.QToolButton): self._printGeometry = self._printConfigurationDialog.getPrintGeometry() def _getPlotAspectRatio(self): - widget = self.plot.centralWidget() + widget = self._plot.centralWidget() graphWidth = float(widget.width()) graphHeight = float(widget.height()) return graphHeight / graphWidth diff --git a/silx/gui/plot/Profile.py b/silx/gui/plot/Profile.py index 182cf60..46e4523 100644 --- a/silx/gui/plot/Profile.py +++ b/silx/gui/plot/Profile.py @@ -1,7 +1,7 @@ # coding: utf-8 # /*########################################################################## # -# Copyright (c) 2004-2018 European Synchrotron Radiation Facility +# Copyright (c) 2004-2019 European Synchrotron Radiation Facility # # Permission is hereby granted, free of charge, to any person obtaining a copy # of this software and associated documentation files (the "Software"), to deal @@ -180,7 +180,8 @@ def createProfile(roiInfo, currentData, origin, scale, lineWidth, method): :type scale: 2-tuple of float :param int lineWidth: width of the profile line :param str method: method to compute the profile. Can be 'mean' or 'sum' - :return: `profile, area, profileName, xLabel`, where: + :return: `coords, profile, area, profileName, xLabel`, where: + - coords is the X coordinate to use to display the profile - profile is a 2D array of the profiles of the stack of images. For a single image, the profile is a curve, so this parameter has a shape *(1, len(curve))* @@ -188,10 +189,9 @@ def createProfile(roiInfo, currentData, origin, scale, lineWidth, method): the effective ROI area corners in plot coords. - profileName is a string describing the ROI, meant to be used as title of the profile plot - - xLabel is a string describing the meaning of the X axis on the - profile plot ("rows", "columns", "distance") + - xLabel the label for X in the profile window - :rtype: tuple(ndarray, (ndarray, ndarray), str, str) + :rtype: tuple(ndarray,ndarray,(ndarray,ndarray),str) """ if currentData is None or roiInfo is None or lineWidth is None: raise ValueError("createProfile called with invalide arguments") @@ -212,12 +212,15 @@ def createProfile(roiInfo, currentData, origin, scale, lineWidth, method): axis=0, method=method) + coords = numpy.arange(len(profile[0]), dtype=numpy.float32) + coords = coords * scale[0] + origin[0] + yMin, yMax = min(area[1]), max(area[1]) - 1 if roiWidth <= 1: profileName = 'Y = %g' % yMin else: profileName = 'Y = [%g, %g]' % (yMin, yMax) - xLabel = 'Columns' + xLabel = 'X' elif lineProjectionMode == 'Y': # Vertical profile on the whole image profile, area = _alignedFullProfile(currentData3D, @@ -226,12 +229,15 @@ def createProfile(roiInfo, currentData, origin, scale, lineWidth, method): axis=1, method=method) + coords = numpy.arange(len(profile[0]), dtype=numpy.float32) + coords = coords * scale[1] + origin[1] + xMin, xMax = min(area[0]), max(area[0]) - 1 if roiWidth <= 1: profileName = 'X = %g' % xMin else: profileName = 'X = [%g, %g]' % (xMin, xMax) - xLabel = 'Rows' + xLabel = 'Y' else: # Free line profile @@ -306,35 +312,52 @@ def createProfile(roiInfo, currentData, origin, scale, lineWidth, method): dCol = (endPt[1] - startPt[1]) / length # Extend ROI with half a pixel on each end - startPt = startPt[0] - 0.5 * dRow, startPt[1] - 0.5 * dCol - endPt = endPt[0] + 0.5 * dRow, endPt[1] + 0.5 * dCol + roiStartPt = startPt[0] - 0.5 * dRow, startPt[1] - 0.5 * dCol + roiEndPt = endPt[0] + 0.5 * dRow, endPt[1] + 0.5 * dCol # Rotate deltas by 90 degrees to apply line width dRow, dCol = dCol, -dRow area = ( - numpy.array((startPt[1] - 0.5 * roiWidth * dCol, - startPt[1] + 0.5 * roiWidth * dCol, - endPt[1] + 0.5 * roiWidth * dCol, - endPt[1] - 0.5 * roiWidth * dCol), + numpy.array((roiStartPt[1] - 0.5 * roiWidth * dCol, + roiStartPt[1] + 0.5 * roiWidth * dCol, + roiEndPt[1] + 0.5 * roiWidth * dCol, + roiEndPt[1] - 0.5 * roiWidth * dCol), dtype=numpy.float32) * scale[0] + origin[0], - numpy.array((startPt[0] - 0.5 * roiWidth * dRow, - startPt[0] + 0.5 * roiWidth * dRow, - endPt[0] + 0.5 * roiWidth * dRow, - endPt[0] - 0.5 * roiWidth * dRow), + numpy.array((roiStartPt[0] - 0.5 * roiWidth * dRow, + roiStartPt[0] + 0.5 * roiWidth * dRow, + roiEndPt[0] + 0.5 * roiWidth * dRow, + roiEndPt[0] - 0.5 * roiWidth * dRow), dtype=numpy.float32) * scale[1] + origin[1]) - y0, x0 = startPt - y1, x1 = endPt - if x1 == x0 or y1 == y0: - profileName = 'From (%g, %g) to (%g, %g)' % (x0, y0, x1, y1) + # Convert start and end points back to plot coords + y0 = startPt[0] * scale[1] + origin[1] + x0 = startPt[1] * scale[0] + origin[0] + y1 = endPt[0] * scale[1] + origin[1] + x1 = endPt[1] * scale[0] + origin[0] + + if startPt[1] == endPt[1]: + profileName = 'X = %g; Y = [%g, %g]' % (x0, y0, y1) + coords = numpy.arange(len(profile[0]), dtype=numpy.float32) + coords = coords * scale[1] + y0 + xLabel = 'Y' + + elif startPt[0] == endPt[0]: + profileName = 'Y = %g; X = [%g, %g]' % (y0, x0, x1) + coords = numpy.arange(len(profile[0]), dtype=numpy.float32) + coords = coords * scale[0] + x0 + xLabel = 'X' + else: m = (y1 - y0) / (x1 - x0) b = y0 - m * x0 profileName = 'y = %g * x %+g ; width=%d' % (m, b, roiWidth) - xLabel = 'Distance' + coords = numpy.linspace(x0, x1, len(profile[0]), + endpoint=True, + dtype=numpy.float32) + xLabel = 'X' - return profile, area, profileName, xLabel + return coords, profile, area, profileName, xLabel # ProfileToolBar ############################################################## @@ -458,7 +481,7 @@ class ProfileToolBar(qt.QToolBar): self.addWidget(self.lineWidthSpinBox) self.methodsButton = ProfileOptionToolButton(parent=self, plot=self) - self.addWidget(self.methodsButton) + self.__profileOptionToolAction = self.addWidget(self.methodsButton) # TODO: add connection with the signal self.methodsButton.sigMethodChanged.connect(self.setProfileMethod) @@ -650,7 +673,7 @@ class ProfileToolBar(qt.QToolBar): if self._roiInfo is None: return - profile, area, profileName, xLabel = createProfile( + coords, profile, area, profileName, xLabel = createProfile( roiInfo=self._roiInfo, currentData=currentData, origin=origin, @@ -658,28 +681,25 @@ class ProfileToolBar(qt.QToolBar): lineWidth=self.lineWidthSpinBox.value(), method=method) - self.getProfilePlot().setGraphTitle(profileName) + profilePlot = self.getProfilePlot() + + profilePlot.setGraphTitle(profileName) + profilePlot.getXAxis().setLabel(xLabel) dataIs3D = len(currentData.shape) > 2 if dataIs3D: - self.getProfilePlot().addImage(profile, - legend=profileName, - xlabel=xLabel, - ylabel="Frame index (depth)", - colormap=colormap) + profileScale = (coords[-1] - coords[0]) / profile.shape[1], 1 + profilePlot.addImage(profile, + legend=profileName, + colormap=colormap, + origin=(coords[0], 0), + scale=profileScale) + profilePlot.getYAxis().setLabel("Frame index (depth)") else: - coords = numpy.arange(len(profile[0]), dtype=numpy.float32) - # Scale horizontal and vertical profile coordinates - if self._roiInfo[2] == 'X': - coords = coords * scale[0] + origin[0] - elif self._roiInfo[2] == 'Y': - coords = coords * scale[1] + origin[1] - - self.getProfilePlot().addCurve(coords, - profile[0], - legend=profileName, - xlabel=xLabel, - color=self.overlayColor) + profilePlot.addCurve(coords, + profile[0], + legend=profileName, + color=self.overlayColor) self.plot.addItem(area[0], area[1], legend=self._POLYGON_LEGEND, @@ -732,6 +752,9 @@ class ProfileToolBar(qt.QToolBar): def getProfileMethod(self): return self._method + def getProfileOptionToolAction(self): + return self.__profileOptionToolAction + class Profile3DToolBar(ProfileToolBar): def __init__(self, parent=None, stackview=None, diff --git a/silx/gui/plot/ScatterMaskToolsWidget.py b/silx/gui/plot/ScatterMaskToolsWidget.py index de645be..0c6797f 100644 --- a/silx/gui/plot/ScatterMaskToolsWidget.py +++ b/silx/gui/plot/ScatterMaskToolsWidget.py @@ -35,7 +35,7 @@ from __future__ import division __authors__ = ["P. Knobel"] __license__ = "MIT" -__date__ = "24/04/2018" +__date__ = "15/02/2019" import math @@ -152,6 +152,22 @@ class ScatterMask(BaseMask): stencil = (y - cy)**2 + (x - cx)**2 < radius**2 self.updateStencil(level, stencil, mask) + def updateEllipse(self, level, crow, ccol, radius_r, radius_c, mask=True): + """Mask/Unmask an ellipse of the given mask level. + + :param int level: Mask level to update. + :param int crow: Row of the center of the ellipse + :param int ccol: Column of the center of the ellipse + :param float radius_r: Radius of the ellipse in the row + :param float radius_c: Radius of the ellipse in the column + :param bool mask: True to mask (default), False to unmask. + """ + def is_inside(px, py): + return (px - ccol)**2 / radius_c**2 + (py - crow)**2 / radius_r**2 <= 1.0 + x, y = self._getXY() + indices_inside = [idx for idx in range(len(x)) if is_inside(x[idx], y[idx])] + self.updatePoints(level, indices_inside, mask) + def updateLine(self, level, y0, x0, y1, x1, width, mask=True): """Mask/Unmask points inside a rectangle defined by a line (two end points) and a width. @@ -490,26 +506,35 @@ class ScatterMaskToolsWidget(BaseMaskToolsWidget): level = self.levelSpinBox.value() - if (self._drawingMode == 'rectangle' and - event['event'] == 'drawingFinished'): - doMask = self._isMasking() + if self._drawingMode == 'rectangle': + if event['event'] == 'drawingFinished': + doMask = self._isMasking() + + self._mask.updateRectangle( + level, + y=event['y'], + x=event['x'], + height=abs(event['height']), + width=abs(event['width']), + mask=doMask) + self._mask.commit() - self._mask.updateRectangle( - level, - y=event['y'], - x=event['x'], - height=abs(event['height']), - width=abs(event['width']), - mask=doMask) - self._mask.commit() + elif self._drawingMode == 'ellipse': + if event['event'] == 'drawingFinished': + doMask = self._isMasking() + center = event['points'][0] + size = event['points'][1] + self._mask.updateEllipse(level, center[1], center[0], + size[1], size[0], doMask) + self._mask.commit() - elif (self._drawingMode == 'polygon' and - event['event'] == 'drawingFinished'): - doMask = self._isMasking() - vertices = event['points'] - vertices = vertices[:, (1, 0)] # (y, x) - self._mask.updatePolygon(level, vertices, doMask) - self._mask.commit() + elif self._drawingMode == 'polygon': + if event['event'] == 'drawingFinished': + doMask = self._isMasking() + vertices = event['points'] + vertices = vertices[:, (1, 0)] # (y, x) + self._mask.updatePolygon(level, vertices, doMask) + self._mask.commit() elif self._drawingMode == 'pencil': doMask = self._isMasking() @@ -536,6 +561,8 @@ class ScatterMaskToolsWidget(BaseMaskToolsWidget): self._lastPencilPos = None else: self._lastPencilPos = y, x + else: + _logger.error("Drawing mode %s unsupported", self._drawingMode) def _loadRangeFromColormapTriggered(self): """Set range from active scatter colormap range""" diff --git a/silx/gui/plot/ScatterView.py b/silx/gui/plot/ScatterView.py index ae79cf9..5fc66ef 100644 --- a/silx/gui/plot/ScatterView.py +++ b/silx/gui/plot/ScatterView.py @@ -1,7 +1,7 @@ # coding: utf-8 # /*########################################################################## # -# Copyright (c) 2018 European Synchrotron Radiation Facility +# Copyright (c) 2018-2019 European Synchrotron Radiation Facility # # Permission is hereby granted, free of charge, to any person obtaining a copy # of this software and associated documentation files (the "Software"), to deal @@ -353,3 +353,13 @@ class ScatterView(qt.QMainWindow): return self.getPlotWidget().resetZoom(*args, **kwargs) resetZoom.__doc__ = PlotWidget.resetZoom.__doc__ + + def getSelectionMask(self, *args, **kwargs): + return self.getMaskToolsWidget().getSelectionMask(*args, **kwargs) + + getSelectionMask.__doc__ = ScatterMaskToolsWidget.getSelectionMask.__doc__ + + def setSelectionMask(self, *args, **kwargs): + return self.getMaskToolsWidget().setSelectionMask(*args, **kwargs) + + setSelectionMask.__doc__ = ScatterMaskToolsWidget.setSelectionMask.__doc__ diff --git a/silx/gui/plot/StackView.py b/silx/gui/plot/StackView.py index 72b6cd4..2a3d7e8 100644 --- a/silx/gui/plot/StackView.py +++ b/silx/gui/plot/StackView.py @@ -89,14 +89,8 @@ from silx.utils.array_like import DatasetView, ListOfImages from silx.math import calibration from silx.utils.deprecation import deprecated_warning -try: - import h5py -except ImportError: - def is_dataset(obj): - return False - h5py = None -else: - from silx.io.utils import is_dataset +import h5py +from silx.io.utils import is_dataset _logger = logging.getLogger(__name__) diff --git a/silx/gui/plot/StatsWidget.py b/silx/gui/plot/StatsWidget.py index bb66613..4ba4fab 100644 --- a/silx/gui/plot/StatsWidget.py +++ b/silx/gui/plot/StatsWidget.py @@ -1,7 +1,7 @@ # coding: utf-8 # /*########################################################################## # -# Copyright (c) 2017-2018 European Synchrotron Radiation Facility +# Copyright (c) 2017-2019 European Synchrotron Radiation Facility # # Permission is hereby granted, free of charge, to any person obtaining a copy # of this software and associated documentation files (the "Software"), to deal @@ -31,552 +31,1266 @@ __license__ = "MIT" __date__ = "24/07/2018" -import functools +from collections import OrderedDict +from contextlib import contextmanager import logging +import weakref + import numpy -from collections import OrderedDict -import silx.utils.weakref from silx.gui import qt from silx.gui import icons -from silx.gui.plot.items.curve import Curve as CurveItem -from silx.gui.plot.items.histogram import Histogram as HistogramItem -from silx.gui.plot.items.image import ImageBase as ImageItem -from silx.gui.plot.items.scatter import Scatter as ScatterItem from silx.gui.plot import stats as statsmdl from silx.gui.widgets.TableWidget import TableWidget from silx.gui.plot.stats.statshandler import StatsHandler, StatFormatter +from silx.gui.plot.items.core import ItemChangedType +from silx.gui.widgets.FlowLayout import FlowLayout +from . import PlotWidget +from . import items as plotitems -logger = logging.getLogger(__name__) +_logger = logging.getLogger(__name__) -class StatsWidget(qt.QWidget): + +# Helper class to handle specific calls to PlotWidget and SceneWidget + +class _Wrapper(qt.QObject): + """Base class for connection with PlotWidget and SceneWidget. + + This class is used when no PlotWidget or SceneWidget is connected. + + :param plot: The plot to be used """ - Widget displaying a set of :class:`Stat` to be displayed on a - :class:`StatsTable` and to be apply on items contained in the :class:`Plot` - Also contains options to: - * compute statistics on all the data or on visible data only - * show statistics of all items or only the active one + sigItemAdded = qt.Signal(object) + """Signal emitted when a new item is added. - :param parent: Qt parent - :param plot: the plot containing items on which we want statistics. + It provides the added item. """ - sigVisibilityChanged = qt.Signal(bool) + sigItemRemoved = qt.Signal(object) + """Signal emitted when an item is (about to be) removed. - NUMBER_FORMAT = '{0:.3f}' + It provides the removed item. + """ - class OptionsWidget(qt.QToolBar): - - def __init__(self, parent=None): - qt.QToolBar.__init__(self, parent) - self.setIconSize(qt.QSize(16, 16)) - - action = qt.QAction(self) - action.setIcon(icons.getQIcon("stats-active-items")) - action.setText("Active items only") - action.setToolTip("Display stats for active items only.") - action.setCheckable(True) - action.setChecked(True) - self.__displayActiveItems = action - - action = qt.QAction(self) - action.setIcon(icons.getQIcon("stats-whole-items")) - action.setText("All items") - action.setToolTip("Display stats for all available items.") - action.setCheckable(True) - self.__displayWholeItems = action - - action = qt.QAction(self) - action.setIcon(icons.getQIcon("stats-visible-data")) - action.setText("Use the visible data range") - action.setToolTip("Use the visible data range.
" - "If activated the data is filtered to only use" - "visible data of the plot." - "The filtering is a data sub-sampling." - "No interpolation is made to fit data to" - "boundaries.") - action.setCheckable(True) - self.__useVisibleData = action - - action = qt.QAction(self) - action.setIcon(icons.getQIcon("stats-whole-data")) - action.setText("Use the full data range") - action.setToolTip("Use the full data range.") - action.setCheckable(True) - action.setChecked(True) - self.__useWholeData = action - - self.addAction(self.__displayWholeItems) - self.addAction(self.__displayActiveItems) - self.addSeparator() - self.addAction(self.__useVisibleData) - self.addAction(self.__useWholeData) - - self.itemSelection = qt.QActionGroup(self) - self.itemSelection.setExclusive(True) - self.itemSelection.addAction(self.__displayActiveItems) - self.itemSelection.addAction(self.__displayWholeItems) - - self.dataRangeSelection = qt.QActionGroup(self) - self.dataRangeSelection.setExclusive(True) - self.dataRangeSelection.addAction(self.__useWholeData) - self.dataRangeSelection.addAction(self.__useVisibleData) - - def isActiveItemMode(self): - return self.itemSelection.checkedAction() is self.__displayActiveItems - - def isVisibleDataRangeMode(self): - return self.dataRangeSelection.checkedAction() is self.__useVisibleData + sigCurrentChanged = qt.Signal(object) + """Signal emitted when the current item has changed. - def __init__(self, parent=None, plot=None, stats=None): - qt.QWidget.__init__(self, parent) - self.setLayout(qt.QVBoxLayout()) - self.layout().setContentsMargins(0, 0, 0, 0) - self._options = self.OptionsWidget(parent=self) - self.layout().addWidget(self._options) - self._statsTable = StatsTable(parent=self, plot=plot) - self.setStats = self._statsTable.setStats - self.setStats(stats) + It provides the current item. + """ - self.layout().addWidget(self._statsTable) - self.setPlot = self._statsTable.setPlot + sigVisibleDataChanged = qt.Signal() + """Signal emitted when the visible data area has changed""" - self._options.itemSelection.triggered.connect( - self._optSelectionChanged) - self._options.dataRangeSelection.triggered.connect( - self._optDataRangeChanged) - self._optSelectionChanged() - self._optDataRangeChanged() + def __init__(self, plot=None): + super(_Wrapper, self).__init__(parent=None) + self._plotRef = None if plot is None else weakref.ref(plot) - self.setDisplayOnlyActiveItem = self._statsTable.setDisplayOnlyActiveItem - self.setStatsOnVisibleData = self._statsTable.setStatsOnVisibleData + def getPlot(self): + """Returns the plot attached to this widget""" + return None if self._plotRef is None else self._plotRef() - def showEvent(self, event): - self.sigVisibilityChanged.emit(True) - qt.QWidget.showEvent(self, event) + def getItems(self): + """Returns the list of items in the plot - def hideEvent(self, event): - self.sigVisibilityChanged.emit(False) - qt.QWidget.hideEvent(self, event) + :rtype: List[object] + """ + return () - def _optSelectionChanged(self, action=None): - self._statsTable.setDisplayOnlyActiveItem(self._options.isActiveItemMode()) + def getSelectedItems(self): + """Returns the list of selected items in the plot - def _optDataRangeChanged(self, action=None): - self._statsTable.setStatsOnVisibleData(self._options.isVisibleDataRangeMode()) + :rtype: List[object] + """ + return () + def setCurrentItem(self, item): + """Set the current/active item in the plot -class BasicStatsWidget(StatsWidget): + :param item: The plot item to set as active/current + """ + pass + + def getLabel(self, item): + """Returns the label of the given item. + + :param item: + :rtype: str + """ + return '' + + def getKind(self, item): + """Returns the kind of an item or None if not supported + + :param item: + :rtype: Union[str,None] + """ + return None + + +class _PlotWidgetWrapper(_Wrapper): + """Class handling PlotWidget specific calls and signal connections + + See :class:`._Wrapper` for documentation + + :param PlotWidget plot: """ - Widget defining a simple set of :class:`Stat` to be displayed on a - :class:`StatsWidget`. - :param parent: Qt parent - :param plot: the plot containing items on which we want statistics. + def __init__(self, plot): + assert isinstance(plot, PlotWidget) + super(_PlotWidgetWrapper, self).__init__(plot) + plot.sigItemAdded.connect(self.sigItemAdded.emit) + plot.sigItemAboutToBeRemoved.connect(self.sigItemRemoved.emit) + plot.sigActiveCurveChanged.connect(self._activeCurveChanged) + plot.sigActiveImageChanged.connect(self._activeImageChanged) + plot.sigActiveScatterChanged.connect(self._activeScatterChanged) + plot.sigPlotSignal.connect(self._limitsChanged) + + def _activeChanged(self, kind): + """Handle change of active curve/image/scatter""" + plot = self.getPlot() + if plot is not None: + item = plot._getActiveItem(kind=kind) + if item is None or self.getKind(item) is not None: + self.sigCurrentChanged.emit(item) + + def _activeCurveChanged(self, previous, current): + self._activeChanged(kind='curve') + + def _activeImageChanged(self, previous, current): + self._activeChanged(kind='image') + + def _activeScatterChanged(self, previous, current): + self._activeChanged(kind='scatter') + + def _limitsChanged(self, event): + """Handle change of plot area limits.""" + if event['event'] == 'limitsChanged': + self.sigVisibleDataChanged.emit() + + def getItems(self): + plot = self.getPlot() + return () if plot is None else plot._getItems() + + def getSelectedItems(self): + plot = self.getPlot() + items = [] + if plot is not None: + for kind in plot._ACTIVE_ITEM_KINDS: + item = plot._getActiveItem(kind=kind) + if item is not None: + items.append(item) + return tuple(items) + + def setCurrentItem(self, item): + plot = self.getPlot() + if plot is not None: + kind = self.getKind(item) + if kind in plot._ACTIVE_ITEM_KINDS: + if plot._getActiveItem(kind) != item: + plot._setActiveItem(kind, item.getLegend()) + + def getLabel(self, item): + return item.getLegend() + + def getKind(self, item): + if isinstance(item, plotitems.Curve): + return 'curve' + elif isinstance(item, plotitems.ImageData): + return 'image' + elif isinstance(item, plotitems.Scatter): + return 'scatter' + elif isinstance(item, plotitems.Histogram): + return 'histogram' + else: + return None + + +class _SceneWidgetWrapper(_Wrapper): + """Class handling SceneWidget specific calls and signal connections + + See :class:`._Wrapper` for documentation + + :param SceneWidget plot: """ - STATS = StatsHandler(( - (statsmdl.StatMin(), StatFormatter()), - statsmdl.StatCoordMin(), - (statsmdl.StatMax(), StatFormatter()), - statsmdl.StatCoordMax(), - (('std', numpy.std), StatFormatter()), - (('mean', numpy.mean), StatFormatter()), - statsmdl.StatCOM() - )) + def __init__(self, plot): + # Lazy-import to avoid circular imports + from ..plot3d.SceneWidget import SceneWidget - def __init__(self, parent=None, plot=None): - StatsWidget.__init__(self, parent=parent, plot=plot, stats=self.STATS) + assert isinstance(plot, SceneWidget) + super(_SceneWidgetWrapper, self).__init__(plot) + plot.getSceneGroup().sigItemAdded.connect(self.sigItemAdded) + plot.getSceneGroup().sigItemRemoved.connect(self.sigItemRemoved) + plot.selection().sigCurrentChanged.connect(self._currentChanged) + # sigVisibleDataChanged is never emitted + + def _currentChanged(self, current, previous): + self.sigCurrentChanged.emit(current) + + def getItems(self): + plot = self.getPlot() + return () if plot is None else tuple(plot.getSceneGroup().visit()) + + def getSelectedItems(self): + plot = self.getPlot() + return () if plot is None else (plot.selection().getCurrentItem(),) + def setCurrentItem(self, item): + plot = self.getPlot() + if plot is not None: + plot.selection().setCurrentItem(item) -class StatsTable(TableWidget): + def getLabel(self, item): + return item.getLabel() + + def getKind(self, item): + from ..plot3d import items as plot3ditems + + if isinstance(item, (plot3ditems.ImageData, + plot3ditems.ScalarField3D)): + return 'image' + elif isinstance(item, (plot3ditems.Scatter2D, + plot3ditems.Scatter3D)): + return 'scatter' + else: + return None + + +class _ScalarFieldViewWrapper(_Wrapper): + """Class handling ScalarFieldView specific calls and signal connections + + See :class:`._Wrapper` for documentation + + :param SceneWidget plot: """ - TableWidget displaying for each curves contained by the Plot some - information: - * legend - * minimal value - * maximal value - * standard deviation (std) + def __init__(self, plot): + # Lazy-import to avoid circular imports + from ..plot3d.ScalarFieldView import ScalarFieldView + from ..plot3d.items import ScalarField3D + + assert isinstance(plot, ScalarFieldView) + super(_ScalarFieldViewWrapper, self).__init__(plot) + self._item = ScalarField3D() + self._dataChanged() + plot.sigDataChanged.connect(self._dataChanged) + # sigItemAdded, sigItemRemoved, sigVisibleDataChanged are never emitted + + def _dataChanged(self): + plot = self.getPlot() + if plot is not None: + self._item.setData(plot.getData(copy=False), copy=False) + self.sigCurrentChanged.emit(self._item) - :param parent: The widget's parent. - :param plot: :class:`.PlotWidget` instance on which to operate + def getItems(self): + plot = self.getPlot() + return () if plot is None else (self._item,) + + def getSelectedItems(self): + return self.getItems() + + def setCurrentItem(self, item): + pass + + def getLabel(self, item): + return 'Data' + + def getKind(self, item): + return 'image' + + +class _Container(object): + """Class to contain a plot item. + + This is apparently needed for compatibility with PySide2, + + :param QObject obj: """ + def __init__(self, obj): + self._obj = obj - COMPATIBLE_KINDS = { - 'curve': CurveItem, - 'image': ImageItem, - 'scatter': ScatterItem, - 'histogram': HistogramItem - } + def __call__(self): + return self._obj - COMPATIBLE_ITEMS = tuple(COMPATIBLE_KINDS.values()) - def __init__(self, parent=None, plot=None): - TableWidget.__init__(self, parent) - """Next freeID for the curve""" - self.plot = None - self._displayOnlyActItem = False - self._statsOnVisibleData = False - self._lgdAndKindToItems = {} - """Associate to a tuple(legend, kind) the items legend""" - self.callbackImage = None - self.callbackScatter = None - self.callbackCurve = None - """Associate the curve legend to his first item""" +class _StatsWidgetBase(object): + """ + Base class for all widgets which want to display statistics + """ + def __init__(self, statsOnVisibleData, displayOnlyActItem): + self._displayOnlyActItem = displayOnlyActItem + self._statsOnVisibleData = statsOnVisibleData self._statsHandler = None - self._legendsSet = [] - """list of legends actually displayed""" - self._resetColumns() - self.setColumnCount(len(self._columns)) - self.setSelectionBehavior(qt.QAbstractItemView.SelectRows) - self.setPlot(plot) - self.setSortingEnabled(True) + self.__default_skipped_events = ( + ItemChangedType.ALPHA, + ItemChangedType.COLOR, + ItemChangedType.COLORMAP, + ItemChangedType.SYMBOL, + ItemChangedType.SYMBOL_SIZE, + ItemChangedType.LINE_WIDTH, + ItemChangedType.LINE_STYLE, + ItemChangedType.LINE_BG_COLOR, + ItemChangedType.FILL, + ItemChangedType.HIGHLIGHTED_COLOR, + ItemChangedType.HIGHLIGHTED_STYLE, + ItemChangedType.TEXT, + ItemChangedType.OVERLAY, + ItemChangedType.VISUALIZATION_MODE, + ) + + self._plotWrapper = _Wrapper() + self._dealWithPlotConnection(create=True) - def _resetColumns(self): - self._columns_index = OrderedDict([('legend', 0), ('kind', 1)]) - self._columns = self._columns_index.keys() - self.setColumnCount(len(self._columns)) + def setPlot(self, plot): + """Define the plot to interact with - def setStats(self, statsHandler): + :param Union[PlotWidget,SceneWidget,None] plot: + The plot containing the items on which statistics are applied """ + try: + import OpenGL + except ImportError: + has_opengl = False + else: + has_opengl = True + from ..plot3d.SceneWidget import SceneWidget # Lazy import + self._dealWithPlotConnection(create=False) + self.clear() + if plot is None: + self._plotWrapper = _Wrapper() + elif isinstance(plot, PlotWidget): + self._plotWrapper = _PlotWidgetWrapper(plot) + else: + if has_opengl is True: + if isinstance(plot, SceneWidget): + self._plotWrapper = _SceneWidgetWrapper(plot) + else: # Expect a ScalarFieldView + self._plotWrapper = _ScalarFieldViewWrapper(plot) + else: + _logger.warning('OpenGL not installed, %s not managed' % ('SceneWidget qnd ScalarFieldView')) + self._dealWithPlotConnection(create=True) + + def setStats(self, statsHandler): + """Set which stats to display and the associated formatting. - :param statsHandler: Set the statistics to be displayed and how to - format them using - :rtype: :class:`StatsHandler` + :param StatsHandler statsHandler: + Set the statistics to be displayed and how to format them using """ - _statsHandler = statsHandler if statsHandler is None: - _statsHandler = StatsHandler(statFormatters=()) - if isinstance(_statsHandler, (list, tuple)): - _statsHandler = StatsHandler(_statsHandler) - assert isinstance(_statsHandler, StatsHandler) - self._resetColumns() - self.clear() - - for statName, stat in list(_statsHandler.stats.items()): - assert isinstance(stat, statsmdl.StatBase) - self._columns_index[statName] = len(self._columns_index) - self._statsHandler = _statsHandler - self._columns = self._columns_index.keys() - self.setColumnCount(len(self._columns)) + statsHandler = StatsHandler(statFormatters=()) + elif isinstance(statsHandler, (list, tuple)): + statsHandler = StatsHandler(statsHandler) + assert isinstance(statsHandler, StatsHandler) - self._updateItemObserve() - self._updateAllStats() + self._statsHandler = statsHandler def getStatsHandler(self): + """Returns the :class:`StatsHandler` in use. + + :rtype: StatsHandler + """ return self._statsHandler - def _updateAllStats(self): - for (legend, kind) in self._lgdAndKindToItems: - self._updateStats(legend, kind) + def getPlot(self): + """Returns the plot attached to this widget - @staticmethod - def _getKind(myItem): - if isinstance(myItem, CurveItem): - return 'curve' - elif isinstance(myItem, ImageItem): - return 'image' - elif isinstance(myItem, ScatterItem): - return 'scatter' - elif isinstance(myItem, HistogramItem): - return 'histogram' + :rtype: Union[PlotWidget,SceneWidget,None] + """ + return self._plotWrapper.getPlot() + + def _dealWithPlotConnection(self, create=True): + """Manage connection to plot signals + + Note: connection on Item are managed by _addItem and _removeItem methods + """ + connections = [] # List of (signal, slot) to connect/disconnect + if self._statsOnVisibleData: + connections.append( + (self._plotWrapper.sigVisibleDataChanged, self._updateAllStats)) + + if self._displayOnlyActItem: + connections.append( + (self._plotWrapper.sigCurrentChanged, self._updateItemObserve)) else: - return None + connections += [ + (self._plotWrapper.sigItemAdded, self._addItem), + (self._plotWrapper.sigItemRemoved, self._removeItem), + (self._plotWrapper.sigCurrentChanged, self._plotCurrentChanged)] + + for signal, slot in connections: + if create: + signal.connect(slot) + else: + signal.disconnect(slot) - def setPlot(self, plot): + def _updateItemObserve(self, *args): + """Reload table depending on mode""" + raise NotImplementedError('Base class') + + def _updateStats(self, item): + """Update displayed information for given plot item + + :param item: The plot item + """ + raise NotImplementedError('Base class') + + def _updateAllStats(self): + """Update stats for all rows in the table""" + raise NotImplementedError('Base class') + + def setDisplayOnlyActiveItem(self, displayOnlyActItem): + """Toggle display off all items or only the active/selected one + + :param bool displayOnlyActItem: + True if we want to only show active item """ - Define the plot to interact with + self._displayOnlyActItem = displayOnlyActItem + + def setStatsOnVisibleData(self, b): + """Toggle computation of statistics on whole data or only visible ones. + + .. warning:: When visible data is activated we will process to a simple + filtering of visible data by the user. The filtering is a + simple data sub-sampling. No interpolation is made to fit + data to boundaries. - :param plot: the plot containing the items on which statistics are - applied - :rtype: :class:`.PlotWidget` + :param bool b: True if we want to apply statistics only on visible data """ - if self.plot: + if self._statsOnVisibleData != b: self._dealWithPlotConnection(create=False) - self.plot = plot - self.clear() - if self.plot: + self._statsOnVisibleData = b self._dealWithPlotConnection(create=True) - self._updateItemObserve() + self._updateAllStats() - def _updateItemObserve(self): - if self.plot: - self.clear() - if self._displayOnlyActItem is True: - activeCurve = self.plot.getActiveCurve(just_legend=False) - activeScatter = self.plot._getActiveItem(kind='scatter', - just_legend=False) - activeImage = self.plot.getActiveImage(just_legend=False) - if activeCurve: - self._addItem(activeCurve) - if activeImage: - self._addItem(activeImage) - if activeScatter: - self._addItem(activeScatter) - else: - [self._addItem(curve) for curve in self.plot.getAllCurves()] - [self._addItem(image) for image in self.plot.getAllImages()] - scatters = self.plot._getItems(kind='scatter', - just_legend=False, - withhidden=True) - [self._addItem(scatter) for scatter in scatters] - histograms = self.plot._getItems(kind='histogram', - just_legend=False, - withhidden=True) - [self._addItem(histogram) for histogram in histograms] + def _addItem(self, item): + """Add a plot item to the table - def _dealWithPlotConnection(self, create=True): + If item is not supported, it is ignored. + + :param item: The plot item + :returns: True if the item is added to the widget. + :rtype: bool """ - Manage connection to plot signals + raise NotImplementedError('Base class') - Note: connection on Item are managed by the _removeItem function + def _removeItem(self, item): + """Remove table items corresponding to given plot item from the table. + + :param item: The plot item """ - if self.plot is None: - return - if self._displayOnlyActItem: - if create is True: - if self.callbackImage is None: - self.callbackImage = functools.partial(self._activeItemChanged, 'image') - self.callbackScatter = functools.partial(self._activeItemChanged, 'scatter') - self.callbackCurve = functools.partial(self._activeItemChanged, 'curve') - self.plot.sigActiveImageChanged.connect(self.callbackImage) - self.plot.sigActiveScatterChanged.connect(self.callbackScatter) - self.plot.sigActiveCurveChanged.connect(self.callbackCurve) - else: - if self.callbackImage is not None: - self.plot.sigActiveImageChanged.disconnect(self.callbackImage) - self.plot.sigActiveScatterChanged.disconnect(self.callbackScatter) - self.plot.sigActiveCurveChanged.disconnect(self.callbackCurve) - self.callbackImage = None - self.callbackScatter = None - self.callbackCurve = None - else: - if create is True: - self.plot.sigContentChanged.connect(self._plotContentChanged) - else: - self.plot.sigContentChanged.disconnect(self._plotContentChanged) - if create is True: - self.plot.sigPlotSignal.connect(self._zoomPlotChanged) - else: - self.plot.sigPlotSignal.disconnect(self._zoomPlotChanged) + raise NotImplementedError('Base class') + + def _plotCurrentChanged(self, current): + """Handle change of current item and update selection in table + + :param current: + """ + raise NotImplementedError('Base class') def clear(self): + """clear GUI""" + pass + + def _skipPlotItemChangedEvent(self, event): """ - Clear all existing items + + :param ItemChangedtype event: event to filter or not + :return: True if we want to ignore this ItemChangedtype + :rtype: bool """ - lgdsAndKinds = list(self._lgdAndKindToItems.keys()) - for lgdAndKind in lgdsAndKinds: - self._removeItem(legend=lgdAndKind[0], kind=lgdAndKind[1]) - self._lgdAndKindToItems = {} - qt.QTableWidget.clear(self) + return event in self.__default_skipped_events + + +class StatsTable(_StatsWidgetBase, TableWidget): + """ + TableWidget displaying for each curves contained by the Plot some + information: + + * legend + * minimal value + * maximal value + * standard deviation (std) + + :param QWidget parent: The widget's parent. + :param Union[PlotWidget,SceneWidget] plot: + :class:`PlotWidget` or :class:`SceneWidget` instance on which to operate + """ + + _LEGEND_HEADER_DATA = 'legend' + _KIND_HEADER_DATA = 'kind' + + def __init__(self, parent=None, plot=None): + TableWidget.__init__(self, parent) + _StatsWidgetBase.__init__(self, statsOnVisibleData=False, + displayOnlyActItem=False) + + # Init for _displayOnlyActItem == False + assert self._displayOnlyActItem is False + self.setSelectionBehavior(qt.QAbstractItemView.SelectRows) + self.setSelectionMode(qt.QAbstractItemView.SingleSelection) + self.currentItemChanged.connect(self._currentItemChanged) + self.setRowCount(0) + self.setColumnCount(2) - # It have to called befor3e accessing to the header items - self.setHorizontalHeaderLabels(list(self._columns)) - - if self._statsHandler is not None: - for columnId, name in enumerate(self._columns): - item = self.horizontalHeaderItem(columnId) - if name in self._statsHandler.stats: - stat = self._statsHandler.stats[name] - text = stat.name[0].upper() + stat.name[1:] - if stat.description is not None: - tooltip = stat.description - else: - tooltip = "" - else: - text = name[0].upper() + name[1:] - tooltip = "" - item.setToolTip(tooltip) - item.setText(text) + # Init headers + headerItem = qt.QTableWidgetItem('Legend') + headerItem.setData(qt.Qt.UserRole, self._LEGEND_HEADER_DATA) + self.setHorizontalHeaderItem(0, headerItem) + headerItem = qt.QTableWidgetItem('Kind') + headerItem.setData(qt.Qt.UserRole, self._KIND_HEADER_DATA) + self.setHorizontalHeaderItem(1, headerItem) + + self.setSortingEnabled(True) + self.setPlot(plot) - if hasattr(self.horizontalHeader(), 'setSectionResizeMode'): # Qt5 - self.horizontalHeader().setSectionResizeMode(qt.QHeaderView.ResizeToContents) + @contextmanager + def _disableSorting(self): + """Context manager that disables table sorting + + Previous state is restored when leaving + """ + sorting = self.isSortingEnabled() + if sorting: + self.setSortingEnabled(False) + yield + if sorting: + self.setSortingEnabled(sorting) + + def setStats(self, statsHandler): + """Set which stats to display and the associated formatting. + + :param StatsHandler statsHandler: + Set the statistics to be displayed and how to format them using + """ + self._removeAllItems() + _StatsWidgetBase.setStats(self, statsHandler) + + self.setRowCount(0) + self.setColumnCount(len(self._statsHandler.stats) + 2) # + legend and kind + + for index, stat in enumerate(self._statsHandler.stats.values()): + headerItem = qt.QTableWidgetItem(stat.name.capitalize()) + headerItem.setData(qt.Qt.UserRole, stat.name) + if stat.description is not None: + headerItem.setToolTip(stat.description) + self.setHorizontalHeaderItem(2 + index, headerItem) + + horizontalHeader = self.horizontalHeader() + if hasattr(horizontalHeader, 'setSectionResizeMode'): # Qt5 + horizontalHeader.setSectionResizeMode(qt.QHeaderView.ResizeToContents) else: # Qt4 - self.horizontalHeader().setResizeMode(qt.QHeaderView.ResizeToContents) - self.setColumnHidden(self._columns_index['kind'], True) + horizontalHeader.setResizeMode(qt.QHeaderView.ResizeToContents) - def _addItem(self, item): - assert isinstance(item, self.COMPATIBLE_ITEMS) - if (item.getLegend(), self._getKind(item)) in self._lgdAndKindToItems: - self._updateStats(item.getLegend(), self._getKind(item)) - return + self._updateItemObserve() + + def setPlot(self, plot): + """Define the plot to interact with + + :param Union[PlotWidget,SceneWidget,None] plot: + The plot containing the items on which statistics are applied + """ + _StatsWidgetBase.setPlot(self, plot) + self._updateItemObserve() + + def clear(self): + """Define the plot to interact with + + :param Union[PlotWidget,SceneWidget,None] plot: + The plot containing the items on which statistics are applied + """ + self._removeAllItems() + + def _updateItemObserve(self, *args): + """Reload table depending on mode""" + self._removeAllItems() + + # Get selected or all items from the plot + if self._displayOnlyActItem: # Only selected + items = self._plotWrapper.getSelectedItems() + else: # All items + items = self._plotWrapper.getItems() + + # Add items to the plot + for item in items: + self._addItem(item) + + def _plotCurrentChanged(self, current): + """Handle change of current item and update selection in table - self.setRowCount(self.rowCount() + 1) - indexTable = self.rowCount() - 1 - kind = self._getKind(item) - - self._lgdAndKindToItems[(item.getLegend(), kind)] = {} - - # the get item will manage the item creation of not existing - _createItem = self._getItem - for itemName in self._columns: - _createItem(name=itemName, legend=item.getLegend(), kind=kind, - indexTable=indexTable) - - self._updateStats(legend=item.getLegend(), kind=kind) - - callback = functools.partial( - silx.utils.weakref.WeakMethodProxy(self._updateStats), - item.getLegend(), kind) - item.sigItemChanged.connect(callback) - self.setColumnHidden(self._columns_index['kind'], - item.getLegend() not in self._legendsSet) - self._legendsSet.append(item.getLegend()) - - def _getItem(self, name, legend, kind, indexTable): - if (legend, kind) not in self._lgdAndKindToItems: - self._lgdAndKindToItems[(legend, kind)] = {} - if not (name in self._lgdAndKindToItems[(legend, kind)] and - self._lgdAndKindToItems[(legend, kind)]): - if name in ('legend', 'kind'): - _item = qt.QTableWidgetItem(type=qt.QTableWidgetItem.Type) - if name == 'legend': - _item.setText(legend) + :param current: + """ + row = self._itemToRow(current) + if row is None: + if self.currentRow() >= 0: + self.setCurrentCell(-1, -1) + elif row != self.currentRow(): + self.setCurrentCell(row, 0) + + def _tableItemToItem(self, tableItem): + """Find the plot item corresponding to a table item + + :param QTableWidgetItem tableItem: + :rtype: QObject + """ + container = tableItem.data(qt.Qt.UserRole) + return container() + + def _itemToRow(self, item): + """Find the row corresponding to a plot item + + :param item: The plot item + :return: The corresponding row index + :rtype: Union[int,None] + """ + for row in range(self.rowCount()): + tableItem = self.item(row, 0) + if self._tableItemToItem(tableItem) == item: + return row + return None + + def _itemToTableItems(self, item): + """Find all table items corresponding to a plot item + + :param item: The plot item + :return: An ordered dict of column name to QTableWidgetItem mapping + for the given plot item. + :rtype: OrderedDict + """ + result = OrderedDict() + row = self._itemToRow(item) + if row is not None: + for column in range(self.columnCount()): + tableItem = self.item(row, column) + if self._tableItemToItem(tableItem) != item: + _logger.error("Table item/plot item mismatch") else: - assert name == 'kind' - _item.setText(kind) + header = self.horizontalHeaderItem(column) + name = header.data(qt.Qt.UserRole) + result[name] = tableItem + return result + + def _plotItemChanged(self, event): + """Handle modifications of the items. + + :param event: + """ + if self._skipPlotItemChangedEvent(event) is True: + return + else: + item = self.sender() + self._updateStats(item) + + def _addItem(self, item): + """Add a plot item to the table + + If item is not supported, it is ignored. + + :param item: The plot item + :returns: True if the item is added to the widget. + :rtype: bool + """ + if self._itemToRow(item) is not None: + _logger.info("Item already present in the table") + self._updateStats(item) + return True + + kind = self._plotWrapper.getKind(item) + if kind not in statsmdl.BASIC_COMPATIBLE_KINDS: + _logger.info("Item has not a supported type: %s", item) + return False + + # Prepare table items + tableItems = [ + qt.QTableWidgetItem(), # Legend + qt.QTableWidgetItem()] # Kind + + for column in range(2, self.columnCount()): + header = self.horizontalHeaderItem(column) + name = header.data(qt.Qt.UserRole) + + formatter = self._statsHandler.formatters[name] + if formatter: + tableItem = formatter.tabWidgetItemClass() else: - if self._statsHandler.formatters[name]: - _item = self._statsHandler.formatters[name].tabWidgetItemClass() - else: - _item = qt.QTableWidgetItem() - tooltip = self._statsHandler.stats[name].getToolTip(kind=kind) - if tooltip is not None: - _item.setToolTip(tooltip) + tableItem = qt.QTableWidgetItem() - _item.setFlags(qt.Qt.ItemIsEnabled | qt.Qt.ItemIsSelectable) - self.setItem(indexTable, self._columns_index[name], _item) - self._lgdAndKindToItems[(legend, kind)][name] = _item + tooltip = self._statsHandler.stats[name].getToolTip(kind=kind) + if tooltip is not None: + tableItem.setToolTip(tooltip) - return self._lgdAndKindToItems[(legend, kind)][name] + tableItems.append(tableItem) - def _removeItem(self, legend, kind): - if (legend, kind) not in self._lgdAndKindToItems or not self.plot: - return + # Disable sorting while adding table items + with self._disableSorting(): + # Add a row to the table + self.setRowCount(self.rowCount() + 1) - self.firstItem = self._lgdAndKindToItems[(legend, kind)]['legend'] - del self._lgdAndKindToItems[(legend, kind)] - self.removeRow(self.firstItem.row()) - self._legendsSet.remove(legend) - self.setColumnHidden(self._columns_index['kind'], - legend not in self._legendsSet) + # Add table items to the last row + row = self.rowCount() - 1 + for column, tableItem in enumerate(tableItems): + tableItem.setData(qt.Qt.UserRole, _Container(item)) + tableItem.setFlags( + qt.Qt.ItemIsEnabled | qt.Qt.ItemIsSelectable) + self.setItem(row, column, tableItem) - def _updateCurrentStats(self): - for lgdAndKind in self._lgdAndKindToItems: - self._updateStats(lgdAndKind[0], lgdAndKind[1]) + # Update table items content + self._updateStats(item) - def _updateStats(self, legend, kind, event=None): - if self._statsHandler is None: + # Listen for item changes + # Using queued connection to avoid issue with sender + # being that of the signal calling the signal + item.sigItemChanged.connect(self._plotItemChanged, + qt.Qt.QueuedConnection) + + return True + + def _removeItem(self, item): + """Remove table items corresponding to given plot item from the table. + + :param item: The plot item + """ + row = self._itemToRow(item) + if row is None: + kind = self._plotWrapper.getKind(item) + if kind in statsmdl.BASIC_COMPATIBLE_KINDS: + _logger.error("Removing item that is not in table: %s", str(item)) return + item.sigItemChanged.disconnect(self._plotItemChanged) + self.removeRow(row) + + def _removeAllItems(self): + """Remove content of the table""" + for row in range(self.rowCount()): + tableItem = self.item(row, 0) + item = self._tableItemToItem(tableItem) + item.sigItemChanged.disconnect(self._plotItemChanged) + self.clearContents() + self.setRowCount(0) - assert kind in ('curve', 'image', 'scatter', 'histogram') - if kind == 'curve': - item = self.plot.getCurve(legend) - elif kind == 'image': - item = self.plot.getImage(legend) - elif kind == 'scatter': - item = self.plot.getScatter(legend) - elif kind == 'histogram': - item = self.plot.getHistogram(legend) - else: - raise ValueError('kind not managed') + def _updateStats(self, item): + """Update displayed information for given plot item - if not item or (item.getLegend(), kind) not in self._lgdAndKindToItems: + :param item: The plot item + """ + if item is None: + return + plot = self.getPlot() + if plot is None: + _logger.info("Plot not available") return - assert isinstance(item, self.COMPATIBLE_ITEMS) - - statsValDict = self._statsHandler.calculate(item, self.plot, - self._statsOnVisibleData) - - lgdItem = self._lgdAndKindToItems[(item.getLegend(), kind)]['legend'] - assert lgdItem - rowStat = lgdItem.row() - - for statName, statVal in list(statsValDict.items()): - assert statName in self._lgdAndKindToItems[(item.getLegend(), kind)] - tableItem = self._getItem(name=statName, legend=item.getLegend(), - kind=kind, indexTable=rowStat) - tableItem.setText(str(statVal)) - - def currentChanged(self, current, previous): - if current.row() >= 0: - legendItem = self.item(current.row(), self._columns_index['legend']) - assert legendItem - kindItem = self.item(current.row(), self._columns_index['kind']) - kind = kindItem.text() - if kind == 'curve': - self.plot.setActiveCurve(legendItem.text()) - elif kind == 'image': - self.plot.setActiveImage(legendItem.text()) - elif kind == 'scatter': - self.plot._setActiveItem('scatter', legendItem.text()) - elif kind == 'histogram': - # active histogram not managed by the plot actually - pass - else: - raise ValueError('kind not managed') - qt.QTableWidget.currentChanged(self, current, previous) + row = self._itemToRow(item) + if row is None: + _logger.error("This item is not in the table: %s", str(item)) + return - def setDisplayOnlyActiveItem(self, displayOnlyActItem): + statsHandler = self.getStatsHandler() + if statsHandler is not None: + stats = statsHandler.calculate( + item, plot, self._statsOnVisibleData) + else: + stats = {} + + with self._disableSorting(): + for name, tableItem in self._itemToTableItems(item).items(): + if name == self._LEGEND_HEADER_DATA: + text = self._plotWrapper.getLabel(item) + tableItem.setText(text) + elif name == self._KIND_HEADER_DATA: + tableItem.setText(self._plotWrapper.getKind(item)) + else: + value = stats.get(name) + if value is None: + _logger.error("Value not found for: %s", name) + tableItem.setText('-') + else: + tableItem.setText(str(value)) + + def _updateAllStats(self): + """Update stats for all rows in the table""" + with self._disableSorting(): + for row in range(self.rowCount()): + tableItem = self.item(row, 0) + item = self._tableItemToItem(tableItem) + self._updateStats(item) + + def _currentItemChanged(self, current, previous): + """Handle change of selection in table and sync plot selection + + :param QTableWidgetItem current: + :param QTableWidgetItem previous: """ + if current and current.row() >= 0: + item = self._tableItemToItem(current) + self._plotWrapper.setCurrentItem(item) - :param bool displayOnlyActItem: True if we want to only show active - item + def setDisplayOnlyActiveItem(self, displayOnlyActItem): + """Toggle display off all items or only the active/selected one + + :param bool displayOnlyActItem: + True if we want to only show active item """ if self._displayOnlyActItem == displayOnlyActItem: return - self._displayOnlyActItem = displayOnlyActItem self._dealWithPlotConnection(create=False) + if not self._displayOnlyActItem: + self.currentItemChanged.disconnect(self._currentItemChanged) + + _StatsWidgetBase.setDisplayOnlyActiveItem(self, displayOnlyActItem) + self._updateItemObserve() self._dealWithPlotConnection(create=True) + if not self._displayOnlyActItem: + self.currentItemChanged.connect(self._currentItemChanged) + self.setSelectionMode(qt.QAbstractItemView.SingleSelection) + else: + self.setSelectionMode(qt.QAbstractItemView.NoSelection) + + +class _OptionsWidget(qt.QToolBar): + + def __init__(self, parent=None): + qt.QToolBar.__init__(self, parent) + self.setIconSize(qt.QSize(16, 16)) + + action = qt.QAction(self) + action.setIcon(icons.getQIcon("stats-active-items")) + action.setText("Active items only") + action.setToolTip("Display stats for active items only.") + action.setCheckable(True) + action.setChecked(True) + self.__displayActiveItems = action + + action = qt.QAction(self) + action.setIcon(icons.getQIcon("stats-whole-items")) + action.setText("All items") + action.setToolTip("Display stats for all available items.") + action.setCheckable(True) + self.__displayWholeItems = action + + action = qt.QAction(self) + action.setIcon(icons.getQIcon("stats-visible-data")) + action.setText("Use the visible data range") + action.setToolTip("Use the visible data range.
" + "If activated the data is filtered to only use" + "visible data of the plot." + "The filtering is a data sub-sampling." + "No interpolation is made to fit data to" + "boundaries.") + action.setCheckable(True) + self.__useVisibleData = action + + action = qt.QAction(self) + action.setIcon(icons.getQIcon("stats-whole-data")) + action.setText("Use the full data range") + action.setToolTip("Use the full data range.") + action.setCheckable(True) + action.setChecked(True) + self.__useWholeData = action + + self.addAction(self.__displayWholeItems) + self.addAction(self.__displayActiveItems) + self.addSeparator() + self.addAction(self.__useVisibleData) + self.addAction(self.__useWholeData) + + self.itemSelection = qt.QActionGroup(self) + self.itemSelection.setExclusive(True) + self.itemSelection.addAction(self.__displayActiveItems) + self.itemSelection.addAction(self.__displayWholeItems) + + self.dataRangeSelection = qt.QActionGroup(self) + self.dataRangeSelection.setExclusive(True) + self.dataRangeSelection.addAction(self.__useWholeData) + self.dataRangeSelection.addAction(self.__useVisibleData) + + def isActiveItemMode(self): + return self.itemSelection.checkedAction() is self.__displayActiveItems + + def isVisibleDataRangeMode(self): + return self.dataRangeSelection.checkedAction() is self.__useVisibleData + + def setVisibleDataRangeModeEnabled(self, enabled): + """Enable/Disable the visible data range mode + + :param bool enabled: True to allow user to choose + stats on visible data + """ + self.__useVisibleData.setEnabled(enabled) + if not enabled: + self.__useWholeData.setChecked(True) + + +class StatsWidget(qt.QWidget): + """ + Widget displaying a set of :class:`Stat` to be displayed on a + :class:`StatsTable` and to be apply on items contained in the :class:`Plot` + Also contains options to: + + * compute statistics on all the data or on visible data only + * show statistics of all items or only the active one + + :param QWidget parent: Qt parent + :param Union[PlotWidget,SceneWidget] plot: + The plot containing items on which we want statistics. + :param StatsHandler stats: + Set the statistics to be displayed and how to format them using + """ + + sigVisibilityChanged = qt.Signal(bool) + """Signal emitted when the visibility of this widget changes. + + It Provides the visibility of the widget. + """ + + NUMBER_FORMAT = '{0:.3f}' + + def __init__(self, parent=None, plot=None, stats=None): + qt.QWidget.__init__(self, parent) + self.setLayout(qt.QVBoxLayout()) + self.layout().setContentsMargins(0, 0, 0, 0) + self._options = _OptionsWidget(parent=self) + self.layout().addWidget(self._options) + self._statsTable = StatsTable(parent=self, plot=plot) + self.setStats(stats) + + self.layout().addWidget(self._statsTable) + + self._options.itemSelection.triggered.connect( + self._optSelectionChanged) + self._options.dataRangeSelection.triggered.connect( + self._optDataRangeChanged) + self._optSelectionChanged() + self._optDataRangeChanged() + + def _getStatsTable(self): + """Returns the :class:`StatsTable` used by this widget. + + :rtype: StatsTable + """ + return self._statsTable + + def showEvent(self, event): + self.sigVisibilityChanged.emit(True) + qt.QWidget.showEvent(self, event) + + def hideEvent(self, event): + self.sigVisibilityChanged.emit(False) + qt.QWidget.hideEvent(self, event) + + def _optSelectionChanged(self, action=None): + self._getStatsTable().setDisplayOnlyActiveItem( + self._options.isActiveItemMode()) + + def _optDataRangeChanged(self, action=None): + self._getStatsTable().setStatsOnVisibleData( + self._options.isVisibleDataRangeMode()) + + # Proxy methods + + def setStats(self, statsHandler): + return self._getStatsTable().setStats(statsHandler=statsHandler) + + setStats.__doc__ = StatsTable.setStats.__doc__ + + def setPlot(self, plot): + self._options.setVisibleDataRangeModeEnabled( + plot is None or isinstance(plot, PlotWidget)) + return self._getStatsTable().setPlot(plot=plot) + + setPlot.__doc__ = StatsTable.setPlot.__doc__ + + def getPlot(self): + return self._getStatsTable().getPlot() + + getPlot.__doc__ = StatsTable.getPlot.__doc__ + + def setDisplayOnlyActiveItem(self, displayOnlyActItem): + return self._getStatsTable().setDisplayOnlyActiveItem( + displayOnlyActItem=displayOnlyActItem) + + setDisplayOnlyActiveItem.__doc__ = StatsTable.setDisplayOnlyActiveItem.__doc__ + def setStatsOnVisibleData(self, b): + return self._getStatsTable().setStatsOnVisibleData(b=b) + + setStatsOnVisibleData.__doc__ = StatsTable.setStatsOnVisibleData.__doc__ + + +DEFAULT_STATS = StatsHandler(( + (statsmdl.StatMin(), StatFormatter()), + statsmdl.StatCoordMin(), + (statsmdl.StatMax(), StatFormatter()), + statsmdl.StatCoordMax(), + statsmdl.StatCOM(), + (('mean', numpy.mean), StatFormatter()), + (('std', numpy.std), StatFormatter()), +)) + + +class BasicStatsWidget(StatsWidget): + """ + Widget defining a simple set of :class:`Stat` to be displayed on a + :class:`StatsWidget`. + + :param QWidget parent: Qt parent + :param PlotWidget plot: + The plot containing items on which we want statistics. + :param StatsHandler stats: + Set the statistics to be displayed and how to format them using + + .. snapshotqt:: img/BasicStatsWidget.png + :width: 300px + :align: center + + from silx.gui.plot import Plot1D + from silx.gui.plot.StatsWidget import BasicStatsWidget + + plot = Plot1D() + x = range(100) + y = x + plot.addCurve(x, y, legend='curve_0') + plot.setActiveCurve('curve_0') + + widget = BasicStatsWidget(plot=plot) + widget.show() + """ + def __init__(self, parent=None, plot=None): + StatsWidget.__init__(self, parent=parent, plot=plot, + stats=DEFAULT_STATS) + + +class _BaseLineStatsWidget(_StatsWidgetBase, qt.QWidget): + """ + Widget made to display stats into a QLayout with for all stat a couple + (QLabel, QLineEdit) created. + The the layout can be defined prior of adding any statistic. + + :param QWidget parent: Qt parent + :param Union[PlotWidget,SceneWidget] plot: + The plot containing items on which we want statistics. + :param str kind: the kind of plotitems we want to display + :param StatsHandler stats: + Set the statistics to be displayed and how to format them using + :param bool statsOnVisibleData: compute statistics for the whole data or + only visible ones. + """ + + def __init__(self, parent=None, plot=None, kind='curve', stats=None, + statsOnVisibleData=False): + self._item_kind = kind + """The item displayed""" + self._statQlineEdit = {} + """list of legends actually displayed""" + self._n_statistics_per_line = 4 + """number of statistics displayed per line in the grid layout""" + qt.QWidget.__init__(self, parent) + _StatsWidgetBase.__init__(self, + statsOnVisibleData=statsOnVisibleData, + displayOnlyActItem=True) + self.setLayout(self._createLayout()) + self.setPlot(plot) + if stats is not None: + self.setStats(stats) + + def _addItemForStatistic(self, statistic): + assert isinstance(statistic, statsmdl.StatBase) + assert statistic.name in self._statsHandler.stats + + self.layout().setSpacing(2) + self.layout().setContentsMargins(2, 2, 2, 2) + + if isinstance(self.layout(), qt.QGridLayout): + parent = self + else: + widget = qt.QWidget(parent=self) + parent = widget + + qLabel = qt.QLabel(statistic.name + ':', parent=parent) + qLineEdit = qt.QLineEdit('', parent=parent) + qLineEdit.setReadOnly(True) + + self._addStatsWidgetsToLayout(qLabel=qLabel, qLineEdit=qLineEdit) + self._statQlineEdit[statistic.name] = qLineEdit + + def setPlot(self, plot): + """Define the plot to interact with + + :param Union[PlotWidget,SceneWidget,None] plot: + The plot containing the items on which statistics are applied """ - .. warning:: When visible data is activated we will process to a simple - filtering of visible data by the user. The filtering is a - simple data sub-sampling. No interpolation is made to fit - data to boundaries. + _StatsWidgetBase.setPlot(self, plot) + self._updateAllStats() - :param bool b: True if we want to apply statistics only on visible data + def _addStatsWidgetsToLayout(self, qLabel, qLineEdit): + raise NotImplementedError('Base class') + + def setStats(self, statsHandler): + """Set which stats to display and the associated formatting. + :param StatsHandler statsHandler: + Set the statistics to be displayed and how to format them using """ - if self._statsOnVisibleData != b: - self._statsOnVisibleData = b - self._updateCurrentStats() + _StatsWidgetBase.setStats(self, statsHandler) + for statName, stat in list(self._statsHandler.stats.items()): + self._addItemForStatistic(stat) + self._updateAllStats() def _activeItemChanged(self, kind, previous, current): - """Callback used when plotting only the active item""" - assert kind in ('curve', 'image', 'scatter', 'histogram') - self._updateItemObserve() + if kind == self._item_kind: + self._updateAllStats() - def _plotContentChanged(self, action, kind, legend): - """Callback used when plotting all the plot items""" - if kind not in ('curve', 'image', 'scatter', 'histogram'): - return - if kind == 'curve': - item = self.plot.getCurve(legend) - elif kind == 'image': - item = self.plot.getImage(legend) - elif kind == 'scatter': - item = self.plot.getScatter(legend) - elif kind == 'histogram': - item = self.plot.getHistogram(legend) - else: - raise ValueError('kind not managed') + def _updateAllStats(self): + plot = self.getPlot() + if plot is not None: + _items = self._plotWrapper.getSelectedItems() + def kind_filter(_item): + return self._plotWrapper.getKind(_item) == self.getKind() + + items = list(filter(kind_filter, _items)) + assert len(items) in (0, 1) + if len(items) is 1: + self._setItem(items[0]) + + def setKind(self, kind): + """Change the kind of active item to display + :param str kind: kind of item to display information for ('curve' ...) + """ + if self._item_kind != kind: + self._item_kind = kind + self._updateItemObserve() - if action == 'add': - if item is None: - raise ValueError('Item from legend "%s" do not exists' % legend) - self._addItem(item) - elif action == 'remove': - self._removeItem(legend, kind) + def getKind(self): + """ + :return: kind of item we want to compute statistic for + :rtype: str + """ + return self._item_kind + + def _setItem(self, item): + if item is None: + for stat_name, stat_widget in self._statQlineEdit.items(): + stat_widget.setText('') + elif (self._statsHandler is not None and len( + self._statsHandler.stats) > 0): + plot = self.getPlot() + if plot is not None: + statsValDict = self._statsHandler.calculate(item, + plot, + self._statsOnVisibleData) + for statName, statVal in list(statsValDict.items()): + self._statQlineEdit[statName].setText(statVal) + + def _updateItemObserve(self, *argv): + assert self._displayOnlyActItem + _items = self._plotWrapper.getSelectedItems() + def kind_filter(_item): + return self._plotWrapper.getKind(_item) == self.getKind() + items = list(filter(kind_filter, _items)) + assert len(items) in (0, 1) + _item = items[0] if len(items) is 1 else None + self._setItem(_item) + + def _createLayout(self): + """create an instance of the main QLayout""" + raise NotImplementedError('Base class') + + def _addItem(self, item): + raise NotImplementedError('Display only the active item') + + def _removeItem(self, item): + raise NotImplementedError('Display only the active item') + + def _plotCurrentChanged(selfself, current): + raise NotImplementedError('Display only the active item') + + +class BasicLineStatsWidget(_BaseLineStatsWidget): + """ + Widget defining a simple set of :class:`Stat` to be displayed on a + :class:`LineStatsWidget`. + + :param QWidget parent: Qt parent + :param Union[PlotWidget,SceneWidget] plot: + The plot containing items on which we want statistics. + :param str kind: the kind of plotitems we want to display + :param StatsHandler stats: + Set the statistics to be displayed and how to format them using + :param bool statsOnVisibleData: compute statistics for the whole data or + only visible ones. + """ + + def __init__(self, parent=None, plot=None, kind='curve', + stats=DEFAULT_STATS, statsOnVisibleData=False): + _BaseLineStatsWidget.__init__(self, parent=parent, kind=kind, + plot=plot, stats=stats, + statsOnVisibleData=statsOnVisibleData) + + def _createLayout(self): + return FlowLayout() + + def _addStatsWidgetsToLayout(self, qLabel, qLineEdit): + # create a mother widget to make sure both qLabel & qLineEdit will + # always be displayed side by side + widget = qt.QWidget(parent=self) + widget.setLayout(qt.QHBoxLayout()) + widget.layout().setSpacing(0) + widget.layout().setContentsMargins(0, 0, 0, 0) + + widget.layout().addWidget(qLabel) + widget.layout().addWidget(qLineEdit) + + self.layout().addWidget(widget) + + +class BasicGridStatsWidget(_BaseLineStatsWidget): + """ + pymca design like widget + + :param QWidget parent: Qt parent + :param Union[PlotWidget,SceneWidget] plot: + The plot containing items on which we want statistics. + :param StatsHandler stats: + Set the statistics to be displayed and how to format them using + :param str kind: the kind of plotitems we want to display + :param bool statsOnVisibleData: compute statistics for the whole data or + only visible ones. + :param int statsPerLine: number of statistic to be displayed per line + + .. snapshotqt:: img/BasicGridStatsWidget.png + :width: 600px + :align: center + + from silx.gui.plot import Plot1D + from silx.gui.plot.StatsWidget import BasicGridStatsWidget + + plot = Plot1D() + x = range(100) + y = x + plot.addCurve(x, y, legend='curve_0') + plot.setActiveCurve('curve_0') + + widget = BasicGridStatsWidget(plot=plot, kind='curve') + widget.show() + """ - def _zoomPlotChanged(self, event): - if self._statsOnVisibleData is True: - if 'event' in event and event['event'] == 'limitsChanged': - self._updateCurrentStats() + def __init__(self, parent=None, plot=None, kind='curve', + stats=DEFAULT_STATS, statsOnVisibleData=False, + statsPerLine=4): + _BaseLineStatsWidget.__init__(self, parent=parent, kind=kind, + plot=plot, stats=stats, + statsOnVisibleData=statsOnVisibleData) + self._n_statistics_per_line = statsPerLine + + def _addStatsWidgetsToLayout(self, qLabel, qLineEdit): + column = len(self._statQlineEdit) % self._n_statistics_per_line + row = len(self._statQlineEdit) // self._n_statistics_per_line + self.layout().addWidget(qLabel, row, column * 2) + self.layout().addWidget(qLineEdit, row, column * 2 + 1) + + def _createLayout(self): + return qt.QGridLayout() diff --git a/silx/gui/plot/_BaseMaskToolsWidget.py b/silx/gui/plot/_BaseMaskToolsWidget.py index e087354..0d11f17 100644 --- a/silx/gui/plot/_BaseMaskToolsWidget.py +++ b/silx/gui/plot/_BaseMaskToolsWidget.py @@ -29,7 +29,7 @@ from __future__ import division __authors__ = ["T. Vincent", "P. Knobel"] __license__ = "MIT" -__date__ = "29/08/2018" +__date__ = "15/02/2019" import os import weakref @@ -141,7 +141,7 @@ class BaseMask(qt.QObject): def commit(self): """Append the current mask to history if changed""" if (not self._history or self._redo or - not numpy.all(numpy.equal(self._mask, self._history[-1]))): + not numpy.array_equal(self._mask, self._history[-1])): if self._redo: self._redo = [] # Reset redo as a new action as been performed self.sigRedoable[bool].emit(False) @@ -325,7 +325,7 @@ class BaseMask(qt.QObject): raise NotImplementedError("To be implemented in subclass") def updateDisk(self, level, crow, ccol, radius, mask=True): - """Mask/Unmask data located inside a disk of the given mask level. + """Mask/Unmask data located inside a dick of the given mask level. :param int level: Mask level to update. :param crow: Disk center row/ordinate (y). @@ -335,6 +335,18 @@ class BaseMask(qt.QObject): """ raise NotImplementedError("To be implemented in subclass") + def updateEllipse(self, level, crow, ccol, radius_r, radius_c, mask=True): + """Mask/Unmask a disk of the given mask level. + + :param int level: Mask level to update. + :param int crow: Row of the center of the ellipse + :param int ccol: Column of the center of the ellipse + :param float radius_r: Radius of the ellipse in the row + :param float radius_c: Radius of the ellipse in the column + :param bool mask: True to mask (default), False to unmask. + """ + raise NotImplementedError("To be implemented in subclass") + def updateLine(self, level, row0, col0, row1, col1, width, mask=True): """Mask/Unmask a line of the given mask level. @@ -376,13 +388,11 @@ class BaseMaskToolsWidget(qt.QWidget): self._plotRef = weakref.ref(plot) self._maskName = '__MASK_TOOLS_%d' % id(self) # Legend of the mask - self._colormap = Colormap(name="", - normalization='linear', + self._colormap = Colormap(normalization='linear', vmin=0, - vmax=self._maxLevelNumber, - colors=None) + vmax=self._maxLevelNumber) self._defaultOverlayColor = rgba('gray') # Color of the mask - self._setMaskColors(1, 0.5) + self._setMaskColors(1, 0.5) # Set the colormap LUT if not isinstance(mask, BaseMask): raise TypeError("mask is not an instance of BaseMask") @@ -482,6 +492,7 @@ class BaseMaskToolsWidget(qt.QWidget): layout.addWidget(self._initMaskGroupBox()) layout.addWidget(self._initDrawGroupBox()) layout.addWidget(self._initThresholdGroupBox()) + layout.addWidget(self._initOtherToolsGroupBox()) layout.addStretch(1) self.setLayout(layout) @@ -617,6 +628,15 @@ class BaseMaskToolsWidget(qt.QWidget): self.rectAction.triggered.connect(self._activeRectMode) self.addAction(self.rectAction) + self.ellipseAction = qt.QAction( + icons.getQIcon('shape-ellipse'), 'Circle selection', None) + self.ellipseAction.setToolTip( + 'Rectangle selection tool: (Un)Mask a circle region R') + self.ellipseAction.setShortcut(qt.QKeySequence(qt.Qt.Key_R)) + self.ellipseAction.setCheckable(True) + self.ellipseAction.triggered.connect(self._activeEllipseMode) + self.addAction(self.ellipseAction) + self.polygonAction = qt.QAction( icons.getQIcon('shape-polygon'), 'Polygon selection', None) self.polygonAction.setShortcut(qt.QKeySequence(qt.Qt.Key_S)) @@ -640,10 +660,11 @@ class BaseMaskToolsWidget(qt.QWidget): self.drawActionGroup = qt.QActionGroup(self) self.drawActionGroup.setExclusive(True) self.drawActionGroup.addAction(self.rectAction) + self.drawActionGroup.addAction(self.ellipseAction) self.drawActionGroup.addAction(self.polygonAction) self.drawActionGroup.addAction(self.pencilAction) - actions = (self.browseAction, self.rectAction, + actions = (self.browseAction, self.rectAction, self.ellipseAction, self.polygonAction, self.pencilAction) drawButtons = [] for action in actions: @@ -711,36 +732,28 @@ class BaseMaskToolsWidget(qt.QWidget): def _initThresholdGroupBox(self): """Init thresholding widgets""" - layout = qt.QVBoxLayout() - - # Thresholing self.belowThresholdAction = qt.QAction( icons.getQIcon('plot-roi-below'), 'Mask below threshold', None) self.belowThresholdAction.setToolTip( 'Mask image where values are below given threshold') self.belowThresholdAction.setCheckable(True) - self.belowThresholdAction.triggered[bool].connect( - self._belowThresholdActionTriggered) + self.belowThresholdAction.setChecked(True) self.betweenThresholdAction = qt.QAction( icons.getQIcon('plot-roi-between'), 'Mask within range', None) self.betweenThresholdAction.setToolTip( 'Mask image where values are within given range') self.betweenThresholdAction.setCheckable(True) - self.betweenThresholdAction.triggered[bool].connect( - self._betweenThresholdActionTriggered) self.aboveThresholdAction = qt.QAction( icons.getQIcon('plot-roi-above'), 'Mask above threshold', None) self.aboveThresholdAction.setToolTip( 'Mask image where values are above given threshold') self.aboveThresholdAction.setCheckable(True) - self.aboveThresholdAction.triggered[bool].connect( - self._aboveThresholdActionTriggered) self.thresholdActionGroup = qt.QActionGroup(self) - self.thresholdActionGroup.setExclusive(False) + self.thresholdActionGroup.setExclusive(True) self.thresholdActionGroup.addAction(self.belowThresholdAction) self.thresholdActionGroup.addAction(self.betweenThresholdAction) self.thresholdActionGroup.addAction(self.aboveThresholdAction) @@ -770,41 +783,50 @@ class BaseMaskToolsWidget(qt.QWidget): loadColormapRangeBtn.setDefaultAction(self.loadColormapRangeAction) widgets.append(loadColormapRangeBtn) - container = self._hboxWidget(*widgets, stretch=False) - layout.addWidget(container) + toolBar = self._hboxWidget(*widgets, stretch=False) - form = qt.QFormLayout() + config = qt.QGridLayout() + config.setContentsMargins(0, 0, 0, 0) + self.minLineLabel = qt.QLabel("Min:", self) self.minLineEdit = FloatEdit(self, value=0) - self.minLineEdit.setEnabled(False) - form.addRow('Min:', self.minLineEdit) + config.addWidget(self.minLineLabel, 0, 0) + config.addWidget(self.minLineEdit, 0, 1) + self.maxLineLabel = qt.QLabel("Max:", self) self.maxLineEdit = FloatEdit(self, value=0) - self.maxLineEdit.setEnabled(False) - form.addRow('Max:', self.maxLineEdit) + config.addWidget(self.maxLineLabel, 1, 0) + config.addWidget(self.maxLineEdit, 1, 1) self.applyMaskBtn = qt.QPushButton('Apply mask') self.applyMaskBtn.clicked.connect(self._maskBtnClicked) - self.applyMaskBtn.setEnabled(False) - form.addRow(self.applyMaskBtn) - - self.maskNanBtn = qt.QPushButton('Mask not finite values') - self.maskNanBtn.setToolTip('Mask Not a Number and infinite values') - self.maskNanBtn.clicked.connect(self._maskNotFiniteBtnClicked) - form.addRow(self.maskNanBtn) - thresholdWidget = qt.QWidget() - thresholdWidget.setLayout(form) - layout.addWidget(thresholdWidget) - - layout.addStretch(1) + layout = qt.QVBoxLayout() + layout.addWidget(toolBar) + layout.addLayout(config) + layout.addWidget(self.applyMaskBtn) self.thresholdGroup = qt.QGroupBox('Threshold') self.thresholdGroup.setLayout(layout) + + # Init widget state + self._thresholdActionGroupTriggered(self.belowThresholdAction) return self.thresholdGroup # track widget visibility and plot active image changes + def _initOtherToolsGroupBox(self): + layout = qt.QVBoxLayout() + + self.maskNanBtn = qt.QPushButton('Mask not finite values') + self.maskNanBtn.setToolTip('Mask Not a Number and infinite values') + self.maskNanBtn.clicked.connect(self._maskNotFiniteBtnClicked) + layout.addWidget(self.maskNanBtn) + + self.otherToolGroup = qt.QGroupBox('Other tools') + self.otherToolGroup.setLayout(layout) + return self.otherToolGroup + def changeEvent(self, event): """Reset drawing action when disabling widget""" if (event.type() == qt.QEvent.EnabledChange and @@ -883,6 +905,7 @@ class BaseMaskToolsWidget(qt.QWidget): The index of the mask for which we want to change the color. If none set this color for all the masks """ + rgb = rgba(rgb)[0:3] if level is None: self._overlayColors[:] = rgb self._defaultColors[:] = False @@ -925,6 +948,8 @@ class BaseMaskToolsWidget(qt.QWidget): """ if self._drawingMode == 'rectangle': self._activeRectMode() + elif self._drawingMode == 'ellipse': + self._activeEllipseMode() elif self._drawingMode == 'polygon': self._activePolygonMode() elif self._drawingMode == 'pencil': @@ -971,6 +996,16 @@ class BaseMaskToolsWidget(qt.QWidget): 'draw', shape='rectangle', source=self, color=color) self._updateDrawingModeWidgets() + def _activeEllipseMode(self): + """Handle circle action mode triggering""" + self._releaseDrawingMode() + self._drawingMode = 'ellipse' + self.plot.sigPlotSignal.connect(self._plotDrawEvent) + color = self.getCurrentMaskColor() + self.plot.setInteractiveMode( + 'draw', shape='ellipse', source=self, color=color) + self._updateDrawingModeWidgets() + def _activePolygonMode(self): """Handle polygon action mode triggering""" self._releaseDrawingMode() @@ -1016,36 +1051,28 @@ class BaseMaskToolsWidget(qt.QWidget): return doMask # Handle threshold UI events - def _belowThresholdActionTriggered(self, triggered): - if triggered: - self.minLineEdit.setEnabled(True) - self.maxLineEdit.setEnabled(False) - self.applyMaskBtn.setEnabled(True) - - def _betweenThresholdActionTriggered(self, triggered): - if triggered: - self.minLineEdit.setEnabled(True) - self.maxLineEdit.setEnabled(True) - self.applyMaskBtn.setEnabled(True) - - def _aboveThresholdActionTriggered(self, triggered): - if triggered: - self.minLineEdit.setEnabled(False) - self.maxLineEdit.setEnabled(True) - self.applyMaskBtn.setEnabled(True) def _thresholdActionGroupTriggered(self, triggeredAction): """Threshold action group listener.""" - if triggeredAction.isChecked(): - # Uncheck other actions - for action in self.thresholdActionGroup.actions(): - if action is not triggeredAction and action.isChecked(): - action.setChecked(False) - else: - # Disable min/max edit - self.minLineEdit.setEnabled(False) - self.maxLineEdit.setEnabled(False) - self.applyMaskBtn.setEnabled(False) + if triggeredAction is self.belowThresholdAction: + self.minLineLabel.setVisible(True) + self.maxLineLabel.setVisible(False) + self.minLineEdit.setVisible(True) + self.maxLineEdit.setVisible(False) + self.applyMaskBtn.setText("Mask bellow") + elif triggeredAction is self.betweenThresholdAction: + self.minLineLabel.setVisible(True) + self.maxLineLabel.setVisible(True) + self.minLineEdit.setVisible(True) + self.maxLineEdit.setVisible(True) + self.applyMaskBtn.setText("Mask between") + elif triggeredAction is self.aboveThresholdAction: + self.minLineLabel.setVisible(False) + self.maxLineLabel.setVisible(True) + self.minLineEdit.setVisible(False) + self.maxLineEdit.setVisible(True) + self.applyMaskBtn.setText("Mask above") + self.applyMaskBtn.setToolTip(triggeredAction.toolTip()) def _maskBtnClicked(self): if self.belowThresholdAction.isChecked(): diff --git a/silx/gui/plot/_utils/dtime_ticklayout.py b/silx/gui/plot/_utils/dtime_ticklayout.py index 95fc235..23c9dce 100644 --- a/silx/gui/plot/_utils/dtime_ticklayout.py +++ b/silx/gui/plot/_utils/dtime_ticklayout.py @@ -1,7 +1,7 @@ # coding: utf-8 # /*########################################################################## # -# Copyright (c) 2014-2017 European Synchrotron Radiation Facility +# Copyright (c) 2014-2018 European Synchrotron Radiation Facility # # Permission is hereby granted, free of charge, to any person obtaining a copy # of this software and associated documentation files (the "Software"), to deal @@ -32,6 +32,7 @@ __date__ = "04/04/2018" import datetime as dt +import enum import logging import math import time @@ -40,7 +41,6 @@ import dateutil.tz from dateutil.relativedelta import relativedelta -from silx.third_party import enum from .ticklayout import niceNumGeneric _logger = logging.getLogger(__name__) diff --git a/silx/gui/plot/actions/control.py b/silx/gui/plot/actions/control.py index 10df130..2d01ef1 100644 --- a/silx/gui/plot/actions/control.py +++ b/silx/gui/plot/actions/control.py @@ -1,7 +1,7 @@ # coding: utf-8 # /*########################################################################## # -# Copyright (c) 2004-2018 European Synchrotron Radiation Facility +# Copyright (c) 2004-2019 European Synchrotron Radiation Facility # # Permission is hereby granted, free of charge, to any person obtaining a copy # of this software and associated documentation files (the "Software"), to deal @@ -303,9 +303,12 @@ class CurveStyleAction(PlotAction): currentState = (self.plot.isDefaultPlotLines(), self.plot.isDefaultPlotPoints()) - # line only, line and symbol, symbol only - states = (True, False), (True, True), (False, True) - newState = states[(states.index(currentState) + 1) % 3] + if currentState == (False, False): + newState = True, False + else: + # line only, line and symbol, symbol only + states = (True, False), (True, True), (False, True) + newState = states[(states.index(currentState) + 1) % 3] self.plot.setDefaultPlotLines(newState[0]) self.plot.setDefaultPlotPoints(newState[1]) diff --git a/silx/gui/plot/actions/io.py b/silx/gui/plot/actions/io.py index 97de527..09e4a99 100644 --- a/silx/gui/plot/actions/io.py +++ b/silx/gui/plot/actions/io.py @@ -1,7 +1,7 @@ # coding: utf-8 # /*########################################################################## # -# Copyright (c) 2004-2018 European Synchrotron Radiation Facility +# Copyright (c) 2004-2019 European Synchrotron Radiation Facility # # Permission is hereby granted, free of charge, to any person obtaining a copy # of this software and associated documentation files (the "Software"), to deal @@ -502,7 +502,7 @@ class SaveAction(PlotAction): axes_errors=[xerror, yerror], title=plot.getGraphTitle()) - def setFileFilter(self, dataKind, nameFilter, func): + def setFileFilter(self, dataKind, nameFilter, func, index=None): """Set a name filter to add/replace a file format support :param str dataKind: @@ -513,10 +513,44 @@ class SaveAction(PlotAction): :param callable func: The function to call to perform saving. Expected signature is: bool func(PlotWidget plot, str filename, str nameFilter) + :param integer index: Index of the filter in the final list (or None) """ assert dataKind in ('all', 'curve', 'curves', 'image', 'scatter') + # first append or replace the new filter to prevent colissions self._filters[dataKind][nameFilter] = func + if index is None: + # we are already done + return + + # get the current ordered list of keys + keyList = list(self._filters[dataKind].keys()) + + # deal with negative indices + if index < 0: + index = len(keyList) + index + if index < 0: + index = 0 + + if index >= len(keyList): + # nothing to be done, already at the end + txt = 'Requested index %d impossible, already at the end' % index + _logger.info(txt) + return + + # get the new ordered list + oldIndex = keyList.index(nameFilter) + del keyList[oldIndex] + keyList.insert(index, nameFilter) + + # build the new filters + newFilters = OrderedDict() + for key in keyList: + newFilters[key] = self._filters[dataKind][key] + + # and update the filters + self._filters[dataKind] = newFilters + return def getFileFilters(self, dataKind): """Returns the nameFilter and associated function for a kind of data. diff --git a/silx/gui/plot/backends/BackendBase.py b/silx/gui/plot/backends/BackendBase.py index 7fb8be0..0514c85 100644 --- a/silx/gui/plot/backends/BackendBase.py +++ b/silx/gui/plot/backends/BackendBase.py @@ -1,7 +1,7 @@ # coding: utf-8 # /*########################################################################## # -# Copyright (c) 2004-2018 European Synchrotron Radiation Facility +# Copyright (c) 2004-2019 European Synchrotron Radiation Facility # # Permission is hereby granted, free of charge, to any person obtaining a copy # of this software and associated documentation files (the "Software"), to deal @@ -31,7 +31,7 @@ This API is a simplified version of PyMca PlotBackend API. __authors__ = ["V.A. Sole", "T. Vincent"] __license__ = "MIT" -__date__ = "24/04/2018" +__date__ = "21/12/2018" import weakref from ... import qt @@ -170,7 +170,8 @@ class BackendBase(object): """ return legend - def addItem(self, x, y, legend, shape, color, fill, overlay, z): + def addItem(self, x, y, legend, shape, color, fill, overlay, z, + linestyle, linewidth, linebgcolor): """Add an item (i.e. a shape) to the plot. :param numpy.ndarray x: The X coords of the points of the shape @@ -182,6 +183,19 @@ class BackendBase(object): :param bool fill: True to fill the shape :param bool overlay: True if item is an overlay, False otherwise :param int z: Layer on which to draw the item + :param str linestyle: Style of the line. + Only relevant for line markers where X or Y is None. + Value in: + + - ' ' no line + - '-' solid line + - '--' dashed line + - '-.' dash-dot line + - ':' dotted line + :param float linewidth: Width of the line. + Only relevant for line markers where X or Y is None. + :param str linebgcolor: Background color of the line, e.g., 'blue', 'b', + '#FF0000'. It is used to draw dotted line using a second color. :returns: The handle used by the backend to univocally access the item """ return legend @@ -546,3 +560,20 @@ class BackendBase(object): This only check status set to axes from the public API """ return self._axesDisplayed + + def setForegroundColors(self, foregroundColor, gridColor): + """Set foreground and grid colors used to display this widget. + + :param List[float] foregroundColor: RGBA foreground color of the widget + :param List[float] gridColor: RGBA grid color of the data view + """ + pass + + def setBackgroundColors(self, backgroundColor, dataBackgroundColor): + """Set background colors used to display this widget. + + :param List[float] backgroundColor: RGBA background color of the widget + :param Union[Tuple[float],None] dataBackgroundColor: + RGBA background color of the data view + """ + pass diff --git a/silx/gui/plot/backends/BackendMatplotlib.py b/silx/gui/plot/backends/BackendMatplotlib.py index 3b1d6dd..726a839 100644 --- a/silx/gui/plot/backends/BackendMatplotlib.py +++ b/silx/gui/plot/backends/BackendMatplotlib.py @@ -1,7 +1,7 @@ # coding: utf-8 # /*########################################################################## # -# Copyright (c) 2004-2018 European Synchrotron Radiation Facility +# Copyright (c) 2004-2019 European Synchrotron Radiation Facility # # Permission is hereby granted, free of charge, to any person obtaining a copy # of this software and associated documentation files (the "Software"), to deal @@ -28,7 +28,7 @@ from __future__ import division __authors__ = ["V.A. Sole", "T. Vincent, H. Payno"] __license__ = "MIT" -__date__ = "01/08/2018" +__date__ = "21/12/2018" import logging @@ -56,12 +56,26 @@ from matplotlib.collections import PathCollection, LineCollection from matplotlib.ticker import Formatter, ScalarFormatter, Locator -from ....third_party.modest_image import ModestImage from . import BackendBase from .._utils import FLOAT32_MINPOS from .._utils.dtime_ticklayout import calcTicks, bestFormatString, timestamp +_PATCH_LINESTYLE = { + "-": 'solid', + "--": 'dashed', + '-.': 'dashdot', + ':': 'dotted', + '': "solid", + None: "solid", +} +"""Patches do not uses the same matplotlib syntax""" + + +def normalize_linestyle(linestyle): + """Normalize known old-style linestyle, else return the provided value.""" + return _PATCH_LINESTYLE.get(linestyle, linestyle) + class NiceDateLocator(Locator): """ @@ -115,7 +129,6 @@ class NiceDateLocator(Locator): return ticks - class NiceAutoDateFormatter(Formatter): """ Matplotlib FuncFormatter that is linked to a NiceDateLocator and gives the @@ -139,7 +152,6 @@ class NiceAutoDateFormatter(Formatter): else: return bestFormatString(self.locator.spacing, self.locator.unit) - def __call__(self, x, pos=None): """Return the format for tick val *x* at position *pos* Expects x to be a POSIX timestamp (seconds since 1 Jan 1970) @@ -149,8 +161,6 @@ class NiceAutoDateFormatter(Formatter): return tickStr - - class _MarkerContainer(Container): """Marker artists container supporting draw/remove and text position update @@ -204,6 +214,57 @@ class _MarkerContainer(Container): self.text.set_x(xmax) +class _DoubleColoredLinePatch(matplotlib.patches.Patch): + """Matplotlib patch to display any patch using double color.""" + + def __init__(self, patch): + super(_DoubleColoredLinePatch, self).__init__() + self.__patch = patch + self.linebgcolor = None + + def __getattr__(self, name): + return getattr(self.__patch, name) + + def draw(self, renderer): + oldLineStype = self.__patch.get_linestyle() + if self.linebgcolor is not None and oldLineStype != "solid": + oldLineColor = self.__patch.get_edgecolor() + oldHatch = self.__patch.get_hatch() + self.__patch.set_linestyle("solid") + self.__patch.set_edgecolor(self.linebgcolor) + self.__patch.set_hatch(None) + self.__patch.draw(renderer) + self.__patch.set_linestyle(oldLineStype) + self.__patch.set_edgecolor(oldLineColor) + self.__patch.set_hatch(oldHatch) + self.__patch.draw(renderer) + + def set_transform(self, transform): + self.__patch.set_transform(transform) + + def get_path(self): + return self.__patch.get_path() + + def contains(self, mouseevent, radius=None): + return self.__patch.contains(mouseevent, radius) + + def contains_point(self, point, radius=None): + return self.__patch.contains_point(point, radius) + + +class Image(AxesImage): + """An AxesImage with a fast path for uint8 RGBA images""" + + def set_data(self, A): + A = numpy.array(A, copy=False) + if A.ndim != 3 or A.shape[2] != 4 or A.dtype != numpy.uint8: + super(Image, self).set_data(A) + else: + # Call AxesImage.set_data with small data to set attributes + super(Image, self).set_data(numpy.zeros((2, 2, 4), dtype=A.dtype)) + self._A = A # Override stored data + + class BackendMatplotlib(BackendBase.BackendBase): """Base class for Matplotlib backend without a FigureCanvas. @@ -231,6 +292,8 @@ class BackendMatplotlib(BackendBase.BackendBase): self.ax = self.fig.add_axes([.15, .15, .75, .75], label="left") self.ax2 = self.ax.twinx() self.ax2.set_label("right") + # Make sure background of Axes is displayed + self.ax2.patch.set_visible(True) # disable the use of offsets try: @@ -239,9 +302,9 @@ class BackendMatplotlib(BackendBase.BackendBase): self.ax2.get_yaxis().get_major_formatter().set_useOffset(False) self.ax2.get_xaxis().get_major_formatter().set_useOffset(False) except: - _logger.warning('Cannot disabled axes offsets in %s ' \ + _logger.warning('Cannot disabled axes offsets in %s ' % matplotlib.__version__) - + # critical for picking!!!! self.ax2.set_zorder(0) self.ax2.set_autoscaley_on(True) @@ -376,44 +439,13 @@ class BackendMatplotlib(BackendBase.BackendBase): picker = (selectable or draggable) - # Debian 7 specific support - # No transparent colormap with matplotlib < 1.2.0 - # Add support for transparent colormap for uint8 data with - # colormap with 256 colors, linear norm, [0, 255] range - if self._matplotlibVersion < _parse_version('1.2.0'): - if (len(data.shape) == 2 and colormap.getName() is None and - colormap.getColormapLUT() is not None): - colors = colormap.getColormapLUT() - if (colors.shape[-1] == 4 and - not numpy.all(numpy.equal(colors[3], 255))): - # This is a transparent colormap - if (colors.shape == (256, 4) and - colormap.getNormalization() == 'linear' and - not colormap.isAutoscale() and - colormap.getVMin() == 0 and - colormap.getVMax() == 255 and - data.dtype == numpy.uint8): - # Supported case, convert data to RGBA - data = colors[data.reshape(-1)].reshape( - data.shape + (4,)) - else: - _logger.warning( - 'matplotlib %s does not support transparent ' - 'colormap.', matplotlib.__version__) - - if ((height * width) > 5.0e5 and - origin == (0., 0.) and scale == (1., 1.)): - imageClass = ModestImage - else: - imageClass = AxesImage - # All image are shown as RGBA image - image = imageClass(self.ax, - label="__IMAGE__" + legend, - interpolation='nearest', - picker=picker, - zorder=z, - origin='lower') + image = Image(self.ax, + label="__IMAGE__" + legend, + interpolation='nearest', + picker=picker, + zorder=z, + origin='lower') if alpha < 1: image.set_alpha(alpha) @@ -438,40 +470,41 @@ class BackendMatplotlib(BackendBase.BackendBase): ystep = 1 if scale[1] >= 0. else -1 data = data[::ystep, ::xstep] - if self._matplotlibVersion < _parse_version('2.1'): - # matplotlib 1.4.2 do not support float128 - dtype = data.dtype - if dtype.kind == "f" and dtype.itemsize >= 16: - _logger.warning("Your matplotlib version do not support " - "float128. Data converted to float64.") - data = data.astype(numpy.float64) - if data.ndim == 2: # Data image, convert to RGBA image data = colormap.applyToData(data) image.set_data(data) - self.ax.add_artist(image) - return image - def addItem(self, x, y, legend, shape, color, fill, overlay, z): + def addItem(self, x, y, legend, shape, color, fill, overlay, z, + linestyle, linewidth, linebgcolor): + if (linebgcolor is not None and + shape not in ('rectangle', 'polygon', 'polylines')): + _logger.warning( + 'linebgcolor not implemented for %s with matplotlib backend', + shape) xView = numpy.array(x, copy=False) yView = numpy.array(y, copy=False) + linestyle = normalize_linestyle(linestyle) + if shape == "line": item = self.ax.plot(x, y, label=legend, color=color, - linestyle='-', marker=None)[0] + linestyle=linestyle, linewidth=linewidth, + marker=None)[0] elif shape == "hline": if hasattr(y, "__len__"): y = y[-1] - item = self.ax.axhline(y, label=legend, color=color) + item = self.ax.axhline(y, label=legend, color=color, + linestyle=linestyle, linewidth=linewidth) elif shape == "vline": if hasattr(x, "__len__"): x = x[-1] - item = self.ax.axvline(x, label=legend, color=color) + item = self.ax.axvline(x, label=legend, color=color, + linestyle=linestyle, linewidth=linewidth) elif shape == 'rectangle': xMin = numpy.nanmin(xView) @@ -484,10 +517,16 @@ class BackendMatplotlib(BackendBase.BackendBase): width=w, height=h, fill=False, - color=color) + color=color, + linestyle=linestyle, + linewidth=linewidth) if fill: item.set_hatch('.') + if linestyle != "solid" and linebgcolor is not None: + item = _DoubleColoredLinePatch(item) + item.linebgcolor = linebgcolor + self.ax.add_patch(item) elif shape in ('polygon', 'polylines'): @@ -500,10 +539,16 @@ class BackendMatplotlib(BackendBase.BackendBase): closed=closed, fill=False, label=legend, - color=color) + color=color, + linestyle=linestyle, + linewidth=linewidth) if fill and shape == 'polygon': item.set_hatch('/') + if linestyle != "solid" and linebgcolor is not None: + item = _DoubleColoredLinePatch(item) + item.linebgcolor = linebgcolor + self.ax.add_patch(item) else: @@ -908,8 +953,56 @@ class BackendMatplotlib(BackendBase.BackendBase): # remove external margins self.ax.set_position([0, 0, 1, 1]) self.ax2.set_position([0, 0, 1, 1]) + self._synchronizeBackgroundColors() + self._synchronizeForegroundColors() self._plot._setDirtyPlot() + def _synchronizeBackgroundColors(self): + backgroundColor = self._plot.getBackgroundColor().getRgbF() + + dataBackgroundColor = self._plot.getDataBackgroundColor() + if dataBackgroundColor.isValid(): + dataBackgroundColor = dataBackgroundColor.getRgbF() + else: + dataBackgroundColor = backgroundColor + + if self.ax2.axison: + self.fig.patch.set_facecolor(backgroundColor) + if self._matplotlibVersion < _parse_version('2'): + self.ax2.set_axis_bgcolor(dataBackgroundColor) + else: + self.ax2.set_facecolor(dataBackgroundColor) + else: + self.fig.patch.set_facecolor(dataBackgroundColor) + + def _synchronizeForegroundColors(self): + foregroundColor = self._plot.getForegroundColor().getRgbF() + + gridColor = self._plot.getGridColor() + if gridColor.isValid(): + gridColor = gridColor.getRgbF() + else: + gridColor = foregroundColor + + for axes in (self.ax, self.ax2): + if axes.axison: + axes.spines['bottom'].set_color(foregroundColor) + axes.spines['top'].set_color(foregroundColor) + axes.spines['right'].set_color(foregroundColor) + axes.spines['left'].set_color(foregroundColor) + axes.tick_params(axis='x', colors=foregroundColor) + axes.tick_params(axis='y', colors=foregroundColor) + axes.yaxis.label.set_color(foregroundColor) + axes.xaxis.label.set_color(foregroundColor) + axes.title.set_color(foregroundColor) + + for line in axes.get_xgridlines(): + line.set_color(gridColor) + + for line in axes.get_ygridlines(): + line.set_color(gridColor) + # axes.grid().set_markeredgecolor(gridColor) + class BackendMatplotlibQt(FigureCanvasQTAgg, BackendMatplotlib): """QWidget matplotlib backend using a QtAgg canvas. @@ -1137,3 +1230,9 @@ class BackendMatplotlibQt(FigureCanvasQTAgg, BackendMatplotlib): else: cursor = self._QT_CURSORS[cursor] FigureCanvasQTAgg.setCursor(self, qt.QCursor(cursor)) + + def setBackgroundColors(self, backgroundColor, dataBackgroundColor): + self._synchronizeBackgroundColors() + + def setForegroundColors(self, foregroundColor, gridColor): + self._synchronizeForegroundColors() diff --git a/silx/gui/plot/backends/BackendOpenGL.py b/silx/gui/plot/backends/BackendOpenGL.py index 9e2cb73..e33d03c 100644 --- a/silx/gui/plot/backends/BackendOpenGL.py +++ b/silx/gui/plot/backends/BackendOpenGL.py @@ -1,7 +1,7 @@ # coding: utf-8 # /*########################################################################## # -# Copyright (c) 2014-2018 European Synchrotron Radiation Facility +# Copyright (c) 2014-2019 European Synchrotron Radiation Facility # # Permission is hereby granted, free of charge, to any person obtaining a copy # of this software and associated documentation files (the "Software"), to deal @@ -28,7 +28,7 @@ from __future__ import division __authors__ = ["T. Vincent"] __license__ = "MIT" -__date__ = "01/08/2018" +__date__ = "21/12/2018" from collections import OrderedDict, namedtuple from ctypes import c_void_p @@ -44,10 +44,11 @@ from ... import qt from ..._glutils import gl from ... import _glutils as glu from .glutils import ( + GLLines2D, GLPlotCurve2D, GLPlotColormap, GLPlotRGBAImage, GLPlotFrame2D, mat4Ortho, mat4Identity, LEFT, RIGHT, BOTTOM, TOP, - Text2D, Shape2D) + Text2D, FilledShape2D) from .glutils.PlotImageFile import saveImageToFile _logger = logging.getLogger(__name__) @@ -338,6 +339,9 @@ class BackendOpenGL(BackendBase.BackendBase, glu.OpenGLWidget): f=f) BackendBase.BackendBase.__init__(self, plot, parent) + self._backgroundColor = 1., 1., 1., 1. + self._dataBackgroundColor = 1., 1., 1., 1. + self.matScreenProj = mat4Identity() self._progBase = glu.Program( @@ -357,6 +361,8 @@ class BackendOpenGL(BackendBase.BackendBase, glu.OpenGLWidget): self._glGarbageCollector = [] self._plotFrame = GLPlotFrame2D( + foregroundColor=(0., 0., 0., 1.), + gridColor=(.7, .7, .7, 1.), margins={'left': 100, 'right': 50, 'top': 50, 'bottom': 50}) # Make postRedisplay asynchronous using Qt signal @@ -432,7 +438,6 @@ class BackendOpenGL(BackendBase.BackendBase, glu.OpenGLWidget): def initializeGL(self): gl.testGL() - gl.glClearColor(1., 1., 1., 1.) gl.glClearStencil(0) gl.glEnable(gl.GL_BLEND) @@ -482,6 +487,7 @@ class BackendOpenGL(BackendBase.BackendBase, glu.OpenGLWidget): self._plotFBOs[context] = plotFBOTex with plotFBOTex: + gl.glClearColor(*self._backgroundColor) gl.glClear(gl.GL_COLOR_BUFFER_BIT | gl.GL_STENCIL_BUFFER_BIT) self._renderPlotAreaGL() self._plotFrame.render() @@ -530,6 +536,7 @@ class BackendOpenGL(BackendBase.BackendBase, glu.OpenGLWidget): item.discard() self._glGarbageCollector = [] + gl.glClearColor(*self._backgroundColor) gl.glClear(gl.GL_COLOR_BUFFER_BIT | gl.GL_STENCIL_BUFFER_BIT) # Check if window is large enough @@ -543,100 +550,6 @@ class BackendOpenGL(BackendBase.BackendBase, glu.OpenGLWidget): glu.setGLContextGetter() _current_context = None - def _nonOrthoAxesLineMarkerPrimitives(self, marker, pixelOffset): - """Generates the vertices and label for a line marker. - - :param dict marker: Description of a line marker - :param int pixelOffset: Offset of text from borders in pixels - :return: Line vertices and Text label or None - :rtype: 2-tuple (2x2 numpy.array of float, Text2D) - """ - label, vertices = None, None - - xCoord, yCoord = marker['x'], marker['y'] - assert xCoord is None or yCoord is None # Specific to line markers - - # Get plot corners in data coords - plotLeft, plotTop, plotWidth, plotHeight = self.getPlotBoundsInPixels() - - corners = [(plotLeft, plotTop), - (plotLeft, plotTop + plotHeight), - (plotLeft + plotWidth, plotTop + plotHeight), - (plotLeft + plotWidth, plotTop)] - corners = numpy.array([self.pixelToData(x, y, axis='left', check=False) - for (x, y) in corners]) - - borders = { - 'right': (corners[3], corners[2]), - 'top': (corners[0], corners[3]), - 'bottom': (corners[2], corners[1]), - 'left': (corners[1], corners[0]) - } - - textLayouts = { # align, valign, offsets - 'right': (RIGHT, BOTTOM, (-1., -1.)), - 'top': (LEFT, TOP, (1., 1.)), - 'bottom': (LEFT, BOTTOM, (1., -1.)), - 'left': (LEFT, BOTTOM, (1., -1.)) - } - - if xCoord is None: # Horizontal line in data space - if marker['text'] is not None: - # Find intersection of hline with borders in data - # Order is important as it stops at first intersection - for border_name in ('right', 'top', 'bottom', 'left'): - (x0, y0), (x1, y1) = borders[border_name] - - if min(y0, y1) <= yCoord < max(y0, y1): - xIntersect = (yCoord - y0) * (x1 - x0) / (y1 - y0) + x0 - - # Add text label - pixelPos = self.dataToPixel( - xIntersect, yCoord, axis='left', check=False) - - align, valign, offsets = textLayouts[border_name] - - x = pixelPos[0] + offsets[0] * pixelOffset - y = pixelPos[1] + offsets[1] * pixelOffset - label = Text2D(marker['text'], x, y, - color=marker['color'], - bgColor=(1., 1., 1., 0.5), - align=align, valign=valign) - break # Stop at first intersection - - xMin, xMax = corners[:, 0].min(), corners[:, 0].max() - vertices = numpy.array( - ((xMin, yCoord), (xMax, yCoord)), dtype=numpy.float32) - - else: # yCoord is None: vertical line in data space - if marker['text'] is not None: - # Find intersection of hline with borders in data - # Order is important as it stops at first intersection - for border_name in ('top', 'bottom', 'right', 'left'): - (x0, y0), (x1, y1) = borders[border_name] - if min(x0, x1) <= xCoord < max(x0, x1): - yIntersect = (xCoord - x0) * (y1 - y0) / (x1 - x0) + y0 - - # Add text label - pixelPos = self.dataToPixel( - xCoord, yIntersect, axis='left', check=False) - - align, valign, offsets = textLayouts[border_name] - - x = pixelPos[0] + offsets[0] * pixelOffset - y = pixelPos[1] + offsets[1] * pixelOffset - label = Text2D(marker['text'], x, y, - color=marker['color'], - bgColor=(1., 1., 1., 0.5), - align=align, valign=valign) - break # Stop at first intersection - - yMin, yMax = corners[:, 1].min(), corners[:, 1].max() - vertices = numpy.array( - ((xCoord, yMin), (xCoord, yMax)), dtype=numpy.float32) - - return vertices, label - def _renderMarkersGL(self): if len(self._markers) == 0: return @@ -651,16 +564,6 @@ class BackendOpenGL(BackendBase.BackendBase, glu.OpenGLWidget): gl.glViewport(0, 0, self._plotFrame.size[0], self._plotFrame.size[1]) - # Prepare vertical and horizontal markers rendering - self._progBase.use() - gl.glUniformMatrix4fv( - self._progBase.uniforms['matrix'], 1, gl.GL_TRUE, - self.matScreenProj.astype(numpy.float32)) - gl.glUniform2i(self._progBase.uniforms['isLog'], False, False) - gl.glUniform1i(self._progBase.uniforms['hatchStep'], 0) - gl.glUniform1f(self._progBase.uniforms['tickLen'], 0.) - posAttrib = self._progBase.attributes['position'] - labels = [] pixelOffset = 3 @@ -677,59 +580,43 @@ class BackendOpenGL(BackendBase.BackendBase, glu.OpenGLWidget): continue if xCoord is None or yCoord is None: - if not self.isDefaultBaseVectors(): # Non-orthogonal axes - vertices, label = self._nonOrthoAxesLineMarkerPrimitives( - marker, pixelOffset) - if label is not None: - labels.append(label) + pixelPos = self.dataToPixel( + xCoord, yCoord, axis='left', check=False) - else: # Orthogonal axes - pixelPos = self.dataToPixel( - xCoord, yCoord, axis='left', check=False) - - if xCoord is None: # Horizontal line in data space - if marker['text'] is not None: - x = self._plotFrame.size[0] - \ - self._plotFrame.margins.right - pixelOffset - y = pixelPos[1] - pixelOffset - label = Text2D(marker['text'], x, y, - color=marker['color'], - bgColor=(1., 1., 1., 0.5), - align=RIGHT, valign=BOTTOM) - labels.append(label) - - width = self._plotFrame.size[0] - vertices = numpy.array(((0, pixelPos[1]), - (width, pixelPos[1])), - dtype=numpy.float32) - - else: # yCoord is None: vertical line in data space - if marker['text'] is not None: - x = pixelPos[0] + pixelOffset - y = self._plotFrame.margins.top + pixelOffset - label = Text2D(marker['text'], x, y, - color=marker['color'], - bgColor=(1., 1., 1., 0.5), - align=LEFT, valign=TOP) - labels.append(label) - - height = self._plotFrame.size[1] - vertices = numpy.array(((pixelPos[0], 0), - (pixelPos[0], height)), - dtype=numpy.float32) + if xCoord is None: # Horizontal line in data space + if marker['text'] is not None: + x = self._plotFrame.size[0] - \ + self._plotFrame.margins.right - pixelOffset + y = pixelPos[1] - pixelOffset + label = Text2D(marker['text'], x, y, + color=marker['color'], + bgColor=(1., 1., 1., 0.5), + align=RIGHT, valign=BOTTOM) + labels.append(label) - self._progBase.use() - gl.glUniform4f(self._progBase.uniforms['color'], - *marker['color']) + width = self._plotFrame.size[0] + lines = GLLines2D((0, width), (pixelPos[1], pixelPos[1]), + style=marker['linestyle'], + color=marker['color'], + width=marker['linewidth']) + lines.render(self.matScreenProj) + + else: # yCoord is None: vertical line in data space + if marker['text'] is not None: + x = pixelPos[0] + pixelOffset + y = self._plotFrame.margins.top + pixelOffset + label = Text2D(marker['text'], x, y, + color=marker['color'], + bgColor=(1., 1., 1., 0.5), + align=LEFT, valign=TOP) + labels.append(label) - gl.glEnableVertexAttribArray(posAttrib) - gl.glVertexAttribPointer(posAttrib, - 2, - gl.GL_FLOAT, - gl.GL_FALSE, - 0, vertices) - gl.glLineWidth(1) - gl.glDrawArrays(gl.GL_LINES, 0, len(vertices)) + height = self._plotFrame.size[1] + lines = GLLines2D((pixelPos[0], pixelPos[0]), (0, height), + style=marker['linestyle'], + color=marker['color'], + width=marker['linewidth']) + lines.render(self.matScreenProj) else: pixelPos = self.dataToPixel( @@ -820,13 +707,17 @@ class BackendOpenGL(BackendBase.BackendBase, glu.OpenGLWidget): def _renderPlotAreaGL(self): plotWidth, plotHeight = self.getPlotBoundsInPixels()[2:] - self._plotFrame.renderGrid() - gl.glScissor(self._plotFrame.margins.left, self._plotFrame.margins.bottom, plotWidth, plotHeight) gl.glEnable(gl.GL_SCISSOR_TEST) + if self._dataBackgroundColor != self._backgroundColor: + gl.glClearColor(*self._dataBackgroundColor) + gl.glClear(gl.GL_COLOR_BUFFER_BIT) + + self._plotFrame.renderGrid() + # Matrix trBounds = self._plotFrame.transformedDataRanges if trBounds.x[0] == trBounds.x[1] or \ @@ -853,32 +744,61 @@ class BackendOpenGL(BackendBase.BackendBase, glu.OpenGLWidget): # Render Items gl.glViewport(0, 0, self._plotFrame.size[0], self._plotFrame.size[1]) - self._progBase.use() - gl.glUniformMatrix4fv(self._progBase.uniforms['matrix'], 1, gl.GL_TRUE, - self.matScreenProj.astype(numpy.float32)) - gl.glUniform2i(self._progBase.uniforms['isLog'], False, False) - gl.glUniform1f(self._progBase.uniforms['tickLen'], 0.) - for item in self._items.values(): if ((isXLog and numpy.min(item['x']) < FLOAT32_MINPOS) or (isYLog and numpy.min(item['y']) < FLOAT32_MINPOS)): # Ignore items <= 0. on log axes continue - closed = item['shape'] != 'polylines' - points = [self.dataToPixel(x, y, axis='left', check=False) - for (x, y) in zip(item['x'], item['y'])] - shape2D = Shape2D(points, - fill=item['fill'], - fillColor=item['color'], - stroke=True, - strokeColor=item['color'], - strokeClosed=closed) + if item['shape'] == 'hline': + width = self._plotFrame.size[0] + _, yPixel = self.dataToPixel( + None, item['y'], axis='left', check=False) + points = numpy.array(((0., yPixel), (width, yPixel)), + dtype=numpy.float32) - posAttrib = self._progBase.attributes['position'] - colorUnif = self._progBase.uniforms['color'] - hatchStepUnif = self._progBase.uniforms['hatchStep'] - shape2D.render(posAttrib, colorUnif, hatchStepUnif) + elif item['shape'] == 'vline': + xPixel, _ = self.dataToPixel( + item['x'], None, axis='left', check=False) + height = self._plotFrame.size[1] + points = numpy.array(((xPixel, 0), (xPixel, height)), + dtype=numpy.float32) + + else: + points = numpy.array([ + self.dataToPixel(x, y, axis='left', check=False) + for (x, y) in zip(item['x'], item['y'])]) + + # Draw the fill + if (item['fill'] is not None and + item['shape'] not in ('hline', 'vline')): + self._progBase.use() + gl.glUniformMatrix4fv( + self._progBase.uniforms['matrix'], 1, gl.GL_TRUE, + self.matScreenProj.astype(numpy.float32)) + gl.glUniform2i(self._progBase.uniforms['isLog'], False, False) + gl.glUniform1f(self._progBase.uniforms['tickLen'], 0.) + + shape2D = FilledShape2D( + points, style=item['fill'], color=item['color']) + shape2D.render( + posAttrib=self._progBase.attributes['position'], + colorUnif=self._progBase.uniforms['color'], + hatchStepUnif=self._progBase.uniforms['hatchStep']) + + # Draw the stroke + if item['linestyle'] not in ('', ' ', None): + if item['shape'] != 'polylines': + # close the polyline + points = numpy.append(points, + numpy.atleast_2d(points[0]), axis=0) + + lines = GLLines2D(points[:, 0], points[:, 1], + style=item['linestyle'], + color=item['color'], + dash2ndColor=item['linebgcolor'], + width=item['linewidth']) + lines.render(self.matScreenProj) gl.glDisable(gl.GL_SCISSOR_TEST) @@ -1123,7 +1043,8 @@ class BackendOpenGL(BackendBase.BackendBase, glu.OpenGLWidget): return legend, 'image' - def addItem(self, x, y, legend, shape, color, fill, overlay, z): + def addItem(self, x, y, legend, shape, color, fill, overlay, z, + linestyle, linewidth, linebgcolor): # TODO handle overlay if shape not in ('polygon', 'rectangle', 'line', 'vline', 'hline', 'polylines'): @@ -1154,7 +1075,10 @@ class BackendOpenGL(BackendBase.BackendBase, glu.OpenGLWidget): 'color': colors.rgba(color), 'fill': 'hatch' if fill else None, 'x': x, - 'y': y + 'y': y, + 'linestyle': linestyle, + 'linewidth': linewidth, + 'linebgcolor': linebgcolor, } return legend, 'item' @@ -1166,10 +1090,6 @@ class BackendOpenGL(BackendBase.BackendBase, glu.OpenGLWidget): if symbol is None: symbol = '+' - if linestyle != '-' or linewidth != 1: - _logger.warning( - 'OpenGL backend does not support marker line style and width.') - behaviors = set() if selectable: behaviors.add('selectable') @@ -1191,6 +1111,8 @@ class BackendOpenGL(BackendBase.BackendBase, glu.OpenGLWidget): 'behaviors': behaviors, 'constraint': constraint if isConstraint else None, 'symbol': symbol, + 'linestyle': linestyle, + 'linewidth': linewidth, } return legend, 'marker' @@ -1441,37 +1363,6 @@ class BackendOpenGL(BackendBase.BackendBase, glu.OpenGLWidget): if label: _logger.warning('Right axis label not implemented') - # Non orthogonal axes - - def setBaseVectors(self, x=(1., 0.), y=(0., 1.)): - """Set base vectors. - - Useful for non-orthogonal axes. - If an axis is in log scale, skew is applied to log transformed values. - - Base vector does not work well with log axes, to investi - """ - if x != (1., 0.) and y != (0., 1.): - if self._plotFrame.xAxis.isLog: - _logger.warning("setBaseVectors disables X axis logarithmic.") - self.setXAxisLogarithmic(False) - if self._plotFrame.yAxis.isLog: - _logger.warning("setBaseVectors disables Y axis logarithmic.") - self.setYAxisLogarithmic(False) - - if self.isKeepDataAspectRatio(): - _logger.warning("setBaseVectors disables keepDataAspectRatio.") - self.keepDataAspectRatio(False) - - self._plotFrame.baseVectors = x, y - - def getBaseVectors(self): - return self._plotFrame.baseVectors - - def isDefaultBaseVectors(self): - return self._plotFrame.baseVectors == \ - self._plotFrame.DEFAULT_BASE_VECTORS - # Graph limits def _setDataRanges(self, xlim=None, ylim=None, y2lim=None): @@ -1486,26 +1377,6 @@ class BackendOpenGL(BackendBase.BackendBase, glu.OpenGLWidget): # Update axes range with a clipped range if too wide self._plotFrame.setDataRanges(xlim, ylim, y2lim) - if not self.isDefaultBaseVectors(): - # Update axes range with axes bounds in data coords - plotLeft, plotTop, plotWidth, plotHeight = \ - self.getPlotBoundsInPixels() - - self._plotFrame.xAxis.dataRange = sorted([ - self.pixelToData(x, y, axis='left', check=False)[0] - for (x, y) in ((plotLeft, plotTop + plotHeight), - (plotLeft + plotWidth, plotTop + plotHeight))]) - - self._plotFrame.yAxis.dataRange = sorted([ - self.pixelToData(x, y, axis='left', check=False)[1] - for (x, y) in ((plotLeft, plotTop + plotHeight), - (plotLeft, plotTop))]) - - self._plotFrame.y2Axis.dataRange = sorted([ - self.pixelToData(x, y, axis='right', check=False)[1] - for (x, y) in ((plotLeft + plotWidth, plotTop + plotHeight), - (plotLeft + plotWidth, plotTop))]) - def _ensureAspectRatio(self, keepDim=None): """Update plot bounds in order to keep aspect ratio. @@ -1619,11 +1490,6 @@ class BackendOpenGL(BackendBase.BackendBase, glu.OpenGLWidget): _logger.warning( "KeepDataAspectRatio is ignored with log axes") - if flag and not self.isDefaultBaseVectors(): - _logger.warning( - "setXAxisLogarithmic ignored because baseVectors are set") - return - self._plotFrame.xAxis.isLog = flag def setYAxisLogarithmic(self, flag): @@ -1633,11 +1499,6 @@ class BackendOpenGL(BackendBase.BackendBase, glu.OpenGLWidget): _logger.warning( "KeepDataAspectRatio is ignored with log axes") - if flag and not self.isDefaultBaseVectors(): - _logger.warning( - "setYAxisLogarithmic ignored because baseVectors are set") - return - self._plotFrame.yAxis.isLog = flag self._plotFrame.y2Axis.isLog = flag @@ -1658,9 +1519,6 @@ class BackendOpenGL(BackendBase.BackendBase, glu.OpenGLWidget): if flag and (self._plotFrame.xAxis.isLog or self._plotFrame.yAxis.isLog): _logger.warning("KeepDataAspectRatio is ignored with log axes") - if flag and not self.isDefaultBaseVectors(): - _logger.warning( - "keepDataAspectRatio ignored because baseVectors are set") self._keepDataAspectRatio = flag @@ -1723,3 +1581,11 @@ class BackendOpenGL(BackendBase.BackendBase, glu.OpenGLWidget): def setAxesDisplayed(self, displayed): BackendBase.BackendBase.setAxesDisplayed(self, displayed) self._plotFrame.displayed = displayed + + def setForegroundColors(self, foregroundColor, gridColor): + self._plotFrame.foregroundColor = foregroundColor + self._plotFrame.gridColor = gridColor + + def setBackgroundColors(self, backgroundColor, dataBackgroundColor): + self._backgroundColor = backgroundColor + self._dataBackgroundColor = dataBackgroundColor diff --git a/silx/gui/plot/backends/glutils/GLPlotCurve.py b/silx/gui/plot/backends/glutils/GLPlotCurve.py index 12b6bbe..5f8d652 100644 --- a/silx/gui/plot/backends/glutils/GLPlotCurve.py +++ b/silx/gui/plot/backends/glutils/GLPlotCurve.py @@ -1,7 +1,7 @@ # coding: utf-8 # /*########################################################################## # -# Copyright (c) 2014-2018 European Synchrotron Radiation Facility +# Copyright (c) 2014-2019 European Synchrotron Radiation Facility # # Permission is hereby granted, free of charge, to any person obtaining a copy # of this software and associated documentation files (the "Software"), to deal @@ -42,7 +42,7 @@ import numpy from silx.math.combo import min_max from ...._glutils import gl -from ...._glutils import Program, vertexBuffer +from ...._glutils import Program, vertexBuffer, VertexBufferAttrib from .GLSupport import buildFillMaskIndices, mat4Identity, mat4Translate @@ -245,7 +245,7 @@ class _Fill2D(object): SOLID, DASHED, DASHDOT, DOTTED = '-', '--', '-.', ':' -class _Lines2D(object): +class GLLines2D(object): """Object rendering curve as a polyline :param xVboData: X coordinates VBO @@ -323,6 +323,7 @@ class _Lines2D(object): /* Dashes: [0, x], [y, z] Dash period: w */ uniform vec4 dash; + uniform vec4 dash2ndColor; varying float vDist; varying vec4 vColor; @@ -330,25 +331,52 @@ class _Lines2D(object): void main(void) { float dist = mod(vDist, dash.w); if ((dist > dash.x && dist < dash.y) || dist > dash.z) { - discard; + if (dash2ndColor.a == 0.) { + discard; // Discard full transparent bg color + } else { + gl_FragColor = dash2ndColor; + } + } else { + gl_FragColor = vColor; } - gl_FragColor = vColor; } """, attrib0='xPos') def __init__(self, xVboData=None, yVboData=None, colorVboData=None, distVboData=None, - style=SOLID, color=(0., 0., 0., 1.), - width=1, dashPeriod=20, drawMode=None, + style=SOLID, color=(0., 0., 0., 1.), dash2ndColor=None, + width=1, dashPeriod=10., drawMode=None, offset=(0., 0.)): + if (xVboData is not None and + not isinstance(xVboData, VertexBufferAttrib)): + xVboData = numpy.array(xVboData, copy=False, dtype=numpy.float32) self.xVboData = xVboData + + if (yVboData is not None and + not isinstance(yVboData, VertexBufferAttrib)): + yVboData = numpy.array(yVboData, copy=False, dtype=numpy.float32) self.yVboData = yVboData + + # Compute distances if not given while providing numpy array coordinates + if (isinstance(self.xVboData, numpy.ndarray) and + isinstance(self.yVboData, numpy.ndarray) and + distVboData is None): + distVboData = distancesFromArrays(self.xVboData, self.yVboData) + + if (distVboData is not None and + not isinstance(distVboData, VertexBufferAttrib)): + distVboData = numpy.array( + distVboData, copy=False, dtype=numpy.float32) self.distVboData = distVboData + + if colorVboData is not None: + assert isinstance(colorVboData, VertexBufferAttrib) self.colorVboData = colorVboData self.useColorVboData = colorVboData is not None self.color = color + self.dash2ndColor = dash2ndColor self.width = width self._style = None self.style = style @@ -396,29 +424,46 @@ class _Lines2D(object): gl.glUniform2f(program.uniforms['halfViewportSize'], 0.5 * viewWidth, 0.5 * viewHeight) + dashPeriod = self.dashPeriod * self.width if self.style == DOTTED: - dash = (0.1 * self.dashPeriod, - 0.6 * self.dashPeriod, - 0.7 * self.dashPeriod, - self.dashPeriod) + dash = (0.2 * dashPeriod, + 0.5 * dashPeriod, + 0.7 * dashPeriod, + dashPeriod) elif self.style == DASHDOT: - dash = (0.3 * self.dashPeriod, - 0.5 * self.dashPeriod, - 0.6 * self.dashPeriod, - self.dashPeriod) + dash = (0.3 * dashPeriod, + 0.5 * dashPeriod, + 0.6 * dashPeriod, + dashPeriod) else: - dash = (0.5 * self.dashPeriod, - self.dashPeriod, - self.dashPeriod, - self.dashPeriod) + dash = (0.5 * dashPeriod, + dashPeriod, + dashPeriod, + dashPeriod) gl.glUniform4f(program.uniforms['dash'], *dash) + if self.dash2ndColor is None: + # Use fully transparent color which gets discarded in shader + dash2ndColor = (0., 0., 0., 0.) + else: + dash2ndColor = self.dash2ndColor + gl.glUniform4f(program.uniforms['dash2ndColor'], *dash2ndColor) + distAttrib = program.attributes['distance'] gl.glEnableVertexAttribArray(distAttrib) - self.distVboData.setVertexAttrib(distAttrib) + if isinstance(self.distVboData, VertexBufferAttrib): + self.distVboData.setVertexAttrib(distAttrib) + else: + gl.glVertexAttribPointer(distAttrib, + 1, + gl.GL_FLOAT, + False, + 0, + self.distVboData) - gl.glEnable(gl.GL_LINE_SMOOTH) + if self.width != 1: + gl.glEnable(gl.GL_LINE_SMOOTH) matrix = numpy.dot(matrix, mat4Translate(*self.offset)).astype(numpy.float32) @@ -435,11 +480,27 @@ class _Lines2D(object): xPosAttrib = program.attributes['xPos'] gl.glEnableVertexAttribArray(xPosAttrib) - self.xVboData.setVertexAttrib(xPosAttrib) + if isinstance(self.xVboData, VertexBufferAttrib): + self.xVboData.setVertexAttrib(xPosAttrib) + else: + gl.glVertexAttribPointer(xPosAttrib, + 1, + gl.GL_FLOAT, + False, + 0, + self.xVboData) yPosAttrib = program.attributes['yPos'] gl.glEnableVertexAttribArray(yPosAttrib) - self.yVboData.setVertexAttrib(yPosAttrib) + if isinstance(self.yVboData, VertexBufferAttrib): + self.yVboData.setVertexAttrib(yPosAttrib) + else: + gl.glVertexAttribPointer(yPosAttrib, + 1, + gl.GL_FLOAT, + False, + 0, + self.yVboData) gl.glLineWidth(self.width) gl.glDrawArrays(self._drawMode, 0, self.xVboData.size) @@ -447,7 +508,7 @@ class _Lines2D(object): gl.glDisable(gl.GL_LINE_SMOOTH) -def _distancesFromArrays(xData, yData): +def distancesFromArrays(xData, yData): """Returns distances between each points :param numpy.ndarray xData: X coordinate of points @@ -711,7 +772,7 @@ class _ErrorBars(object): This is using its own VBO as opposed to fill/points/lines. There is no picking on error bars. - It uses 2 vertices per error bars and uses :class:`_Lines2D` to + It uses 2 vertices per error bars and uses :class:`GLLines2D` to render error bars and :class:`_Points2D` to render the ends. :param numpy.ndarray xData: X coordinates of the data. @@ -753,7 +814,7 @@ class _ErrorBars(object): self._xData, self._yData = None, None self._xError, self._yError = None, None - self._lines = _Lines2D( + self._lines = GLLines2D( None, None, color=color, drawMode=gl.GL_LINES, offset=offset) self._xErrPoints = _Points2D( None, None, color=color, marker=V_LINE, offset=offset) @@ -957,7 +1018,7 @@ class GLPlotCurve2D(object): self.xMin, self.yMin, offset=self.offset) - self.lines = _Lines2D() + self.lines = GLLines2D() self.lines.style = lineStyle self.lines.color = lineColor self.lines.width = lineWidth @@ -999,7 +1060,7 @@ class GLPlotCurve2D(object): @classmethod def init(cls): """OpenGL context initialization""" - _Lines2D.init() + GLLines2D.init() _Points2D.init() def prepare(self): @@ -1007,7 +1068,7 @@ class GLPlotCurve2D(object): if self.xVboData is None: xAttrib, yAttrib, cAttrib, dAttrib = None, None, None, None if self.lineStyle in (DASHED, DASHDOT, DOTTED): - dists = _distancesFromArrays(self.xData, self.yData) + dists = distancesFromArrays(self.xData, self.yData) if self.colorData is None: xAttrib, yAttrib, dAttrib = vertexBuffer( (self.xData, self.yData, dists)) diff --git a/silx/gui/plot/backends/glutils/GLPlotFrame.py b/silx/gui/plot/backends/glutils/GLPlotFrame.py index 4ad1547..43f6e10 100644 --- a/silx/gui/plot/backends/glutils/GLPlotFrame.py +++ b/silx/gui/plot/backends/glutils/GLPlotFrame.py @@ -63,6 +63,7 @@ class PlotAxis(object): def __init__(self, plot, tickLength=(0., 0.), + foregroundColor=(0., 0., 0., 1.0), labelAlign=CENTER, labelVAlign=CENTER, titleAlign=CENTER, titleVAlign=CENTER, titleRotate=0, titleOffset=(0., 0.)): @@ -78,6 +79,7 @@ class PlotAxis(object): self._title = '' self._tickLength = tickLength + self._foregroundColor = foregroundColor self._labelAlign = labelAlign self._labelVAlign = labelVAlign self._titleAlign = titleAlign @@ -168,6 +170,20 @@ class PlotAxis(object): if plot is not None: plot._dirty() + @property + def foregroundColor(self): + """Color used for frame and labels""" + return self._foregroundColor + + @foregroundColor.setter + def foregroundColor(self, color): + """Color used for frame and labels""" + assert len(color) == 4, \ + "foregroundColor must have length 4, got {}".format(len(self._foregroundColor)) + if self._foregroundColor != color: + self._foregroundColor = color + self._dirtyTicks() + @property def ticks(self): """Ticks as tuples: ((x, y) in display, dataPos, textLabel).""" @@ -192,6 +208,7 @@ class PlotAxis(object): tickScale = 1. label = Text2D(text=text, + color=self._foregroundColor, x=xPixel - xTickLength, y=yPixel - yTickLength, align=self._labelAlign, @@ -223,6 +240,7 @@ class PlotAxis(object): # yOffset -= 3 * yTickLength axisTitle = Text2D(text=self.title, + color=self._foregroundColor, x=xAxisCenter + xOffset, y=yAxisCenter + yOffset, align=self._titleAlign, @@ -373,15 +391,21 @@ class GLPlotFrame(object): # Margins used when plot frame is not displayed _NoDisplayMargins = _Margins(0, 0, 0, 0) - def __init__(self, margins): + def __init__(self, margins, foregroundColor, gridColor): """ :param margins: The margins around plot area for axis and labels. :type margins: dict with 'left', 'right', 'top', 'bottom' keys and values as ints. + :param foregroundColor: color used for the frame and labels. + :type foregroundColor: tuple with RGBA values ranging from 0.0 to 1.0 + :param gridColor: color used for grid lines. + :type gridColor: tuple RGBA with RGBA values ranging from 0.0 to 1.0 """ self._renderResources = None self._margins = self._Margins(**margins) + self._foregroundColor = foregroundColor + self._gridColor = gridColor self.axes = [] # List of PlotAxis to be updated by subclasses @@ -400,6 +424,36 @@ class GLPlotFrame(object): GRID_SUB_TICKS = 2 GRID_ALL_TICKS = (GRID_MAIN_TICKS + GRID_SUB_TICKS) + @property + def foregroundColor(self): + """Color used for frame and labels""" + return self._foregroundColor + + @foregroundColor.setter + def foregroundColor(self, color): + """Color used for frame and labels""" + assert len(color) == 4, \ + "foregroundColor must have length 4, got {}".format(len(self._foregroundColor)) + if self._foregroundColor != color: + self._foregroundColor = color + for axis in self.axes: + axis.foregroundColor = color + self._dirty() + + @property + def gridColor(self): + """Color used for frame and labels""" + return self._gridColor + + @gridColor.setter + def gridColor(self, color): + """Color used for frame and labels""" + assert len(color) == 4, \ + "gridColor must have length 4, got {}".format(len(self._gridColor)) + if self._gridColor != color: + self._gridColor = color + self._dirty() + @property def displayed(self): """Whether axes and their labels are displayed or not (bool)""" @@ -522,6 +576,7 @@ class GLPlotFrame(object): self.margins.right) // 2 yTitle = self.margins.top - self._TICK_LENGTH_IN_PIXELS labels.append(Text2D(text=self.title, + color=self._foregroundColor, x=xTitle, y=yTitle, align=CENTER, @@ -556,7 +611,7 @@ class GLPlotFrame(object): gl.glUniformMatrix4fv(prog.uniforms['matrix'], 1, gl.GL_TRUE, matProj.astype(numpy.float32)) - gl.glUniform4f(prog.uniforms['color'], 0., 0., 0., 1.) + gl.glUniform4f(prog.uniforms['color'], *self._foregroundColor) gl.glUniform1f(prog.uniforms['tickFactor'], 0.) gl.glEnableVertexAttribArray(prog.attributes['position']) @@ -590,7 +645,7 @@ class GLPlotFrame(object): gl.glLineWidth(self._LINE_WIDTH) gl.glUniformMatrix4fv(prog.uniforms['matrix'], 1, gl.GL_TRUE, matProj.astype(numpy.float32)) - gl.glUniform4f(prog.uniforms['color'], 0.7, 0.7, 0.7, 1.) + gl.glUniform4f(prog.uniforms['color'], *self._gridColor) gl.glUniform1f(prog.uniforms['tickFactor'], 0.) # 1/2.) # 1/tickLen gl.glEnableVertexAttribArray(prog.attributes['position']) @@ -606,15 +661,21 @@ class GLPlotFrame(object): # GLPlotFrame2D ############################################################### class GLPlotFrame2D(GLPlotFrame): - def __init__(self, margins): + def __init__(self, margins, foregroundColor, gridColor): """ :param margins: The margins around plot area for axis and labels. :type margins: dict with 'left', 'right', 'top', 'bottom' keys and values as ints. + :param foregroundColor: color used for the frame and labels. + :type foregroundColor: tuple with RGBA values ranging from 0.0 to 1.0 + :param gridColor: color used for grid lines. + :type gridColor: tuple RGBA with RGBA values ranging from 0.0 to 1.0 + """ - super(GLPlotFrame2D, self).__init__(margins) + super(GLPlotFrame2D, self).__init__(margins, foregroundColor, gridColor) self.axes.append(PlotAxis(self, tickLength=(0., -5.), + foregroundColor=self._foregroundColor, labelAlign=CENTER, labelVAlign=TOP, titleAlign=CENTER, titleVAlign=TOP, titleRotate=0, @@ -624,6 +685,7 @@ class GLPlotFrame2D(GLPlotFrame): self.axes.append(PlotAxis(self, tickLength=(5., 0.), + foregroundColor=self._foregroundColor, labelAlign=RIGHT, labelVAlign=CENTER, titleAlign=CENTER, titleVAlign=BOTTOM, titleRotate=ROTATE_270, @@ -632,6 +694,7 @@ class GLPlotFrame2D(GLPlotFrame): self._y2Axis = PlotAxis(self, tickLength=(-5., 0.), + foregroundColor=self._foregroundColor, labelAlign=LEFT, labelVAlign=CENTER, titleAlign=CENTER, titleVAlign=TOP, titleRotate=ROTATE_270, @@ -825,23 +888,6 @@ class GLPlotFrame2D(GLPlotFrame): _logger.info('yMax: warning log10(%f)', y2Max) y2Max = 0. - # Non-orthogonal axes - if self.baseVectors != self.DEFAULT_BASE_VECTORS: - (xx, xy), (yx, yy) = self.baseVectors - skew_mat = numpy.array(((xx, yx), (xy, yy))) - - corners = [(xMin, yMin), (xMin, yMax), - (xMax, yMin), (xMax, yMax), - (xMin, y2Min), (xMin, y2Max), - (xMax, y2Min), (xMax, y2Max)] - - corners = numpy.array( - [numpy.dot(skew_mat, corner) for corner in corners], - dtype=numpy.float32) - xMin, xMax = corners[:, 0].min(), corners[:, 0].max() - yMin, yMax = corners[0:4, 1].min(), corners[0:4, 1].max() - y2Min, y2Max = corners[4:, 1].min(), corners[4:, 1].max() - self._transformedDataRanges = self._DataRanges( (xMin, xMax), (yMin, yMax), (y2Min, y2Max)) @@ -861,16 +907,6 @@ class GLPlotFrame2D(GLPlotFrame): mat = mat4Ortho(xMin, xMax, yMax, yMin, 1, -1) else: mat = mat4Ortho(xMin, xMax, yMin, yMax, 1, -1) - - # Non-orthogonal axes - if self.baseVectors != self.DEFAULT_BASE_VECTORS: - (xx, xy), (yx, yy) = self.baseVectors - mat = numpy.dot(mat, numpy.array(( - (xx, yx, 0., 0.), - (xy, yy, 0., 0.), - (0., 0., 1., 0.), - (0., 0., 0., 1.)), dtype=numpy.float64)) - self._transformedDataProjMat = mat return self._transformedDataProjMat @@ -890,16 +926,6 @@ class GLPlotFrame2D(GLPlotFrame): mat = mat4Ortho(xMin, xMax, y2Max, y2Min, 1, -1) else: mat = mat4Ortho(xMin, xMax, y2Min, y2Max, 1, -1) - - # Non-orthogonal axes - if self.baseVectors != self.DEFAULT_BASE_VECTORS: - (xx, xy), (yx, yy) = self.baseVectors - mat = numpy.dot(mat, numpy.matrix(( - (xx, yx, 0., 0.), - (xy, yy, 0., 0.), - (0., 0., 1., 0.), - (0., 0., 0., 1.)), dtype=numpy.float64)) - self._transformedDataY2ProjMat = mat return self._transformedDataY2ProjMat @@ -1114,3 +1140,17 @@ class GLPlotFrame2D(GLPlotFrame): vertices = numpy.append(vertices, extraVertices, axis=0) self._renderResources = (vertices, gridVertices, labels) + + @property + def foregroundColor(self): + """Color used for frame and labels""" + return self._foregroundColor + + @foregroundColor.setter + def foregroundColor(self, color): + """Color used for frame and labels""" + assert len(color) == 4, \ + "foregroundColor must have length 4, got {}".format(len(self._foregroundColor)) + if self._foregroundColor != color: + self._y2Axis.foregroundColor = color + GLPlotFrame.foregroundColor.fset(self, color) # call parent property diff --git a/silx/gui/plot/backends/glutils/GLSupport.py b/silx/gui/plot/backends/glutils/GLSupport.py index 18c5eb7..da6dffa 100644 --- a/silx/gui/plot/backends/glutils/GLSupport.py +++ b/silx/gui/plot/backends/glutils/GLSupport.py @@ -60,16 +60,12 @@ def buildFillMaskIndices(nIndices, dtype=None): return indices -class Shape2D(object): +class FilledShape2D(object): _NO_HATCH = 0 _HATCH_STEP = 20 - def __init__(self, points, fill='solid', stroke=True, - fillColor=(0., 0., 0., 1.), strokeColor=(0., 0., 0., 1.), - strokeClosed=True): + def __init__(self, points, style='solid', color=(0., 0., 0., 1.)): self.vertices = numpy.array(points, dtype=numpy.float32, copy=False) - self.strokeClosed = strokeClosed - self._indices = buildFillMaskIndices(len(self.vertices)) tVertex = numpy.transpose(self.vertices) @@ -81,28 +77,16 @@ class Shape2D(object): self._xMin, self._xMax = xMin, xMax self._yMin, self._yMax = yMin, yMax - self.fill = fill - self.fillColor = fillColor - self.stroke = stroke - self.strokeColor = strokeColor - - @property - def xMin(self): - return self._xMin - - @property - def xMax(self): - return self._xMax - - @property - def yMin(self): - return self._yMin + self.style = style + self.color = color - @property - def yMax(self): - return self._yMax + def render(self, posAttrib, colorUnif, hatchStepUnif): + assert self.style in ('hatch', 'solid') + gl.glUniform4f(colorUnif, *self.color) + step = self._HATCH_STEP if self.style == 'hatch' else self._NO_HATCH + gl.glUniform1i(hatchStepUnif, step) - def prepareFillMask(self, posAttrib): + # Prepare fill mask gl.glEnableVertexAttribArray(posAttrib) gl.glVertexAttribPointer(posAttrib, 2, @@ -126,9 +110,6 @@ class Shape2D(object): gl.glColorMask(gl.GL_TRUE, gl.GL_TRUE, gl.GL_TRUE, gl.GL_TRUE) gl.glDepthMask(gl.GL_TRUE) - def renderFill(self, posAttrib): - self.prepareFillMask(posAttrib) - gl.glVertexAttribPointer(posAttrib, 2, gl.GL_FLOAT, @@ -138,30 +119,6 @@ class Shape2D(object): gl.glDisable(gl.GL_STENCIL_TEST) - def renderStroke(self, posAttrib): - gl.glEnableVertexAttribArray(posAttrib) - gl.glVertexAttribPointer(posAttrib, - 2, - gl.GL_FLOAT, - gl.GL_FALSE, - 0, self.vertices) - gl.glLineWidth(1) - drawMode = gl.GL_LINE_LOOP if self.strokeClosed else gl.GL_LINE_STRIP - gl.glDrawArrays(drawMode, 0, len(self.vertices)) - - def render(self, posAttrib, colorUnif, hatchStepUnif): - assert self.fill in ['hatch', 'solid', None] - if self.fill is not None: - gl.glUniform4f(colorUnif, *self.fillColor) - step = self._HATCH_STEP if self.fill == 'hatch' else self._NO_HATCH - gl.glUniform1i(hatchStepUnif, step) - self.renderFill(posAttrib) - - if self.stroke: - gl.glUniform4f(colorUnif, *self.strokeColor) - gl.glUniform1i(hatchStepUnif, self._NO_HATCH) - self.renderStroke(posAttrib) - # matrix ###################################################################### diff --git a/silx/gui/plot/items/__init__.py b/silx/gui/plot/items/__init__.py index e7957ac..f829f78 100644 --- a/silx/gui/plot/items/__init__.py +++ b/silx/gui/plot/items/__init__.py @@ -36,7 +36,7 @@ from .core import (Item, LabelsMixIn, DraggableMixIn, ColormapMixIn, # noqa SymbolMixIn, ColorMixIn, YAxisMixIn, FillMixIn, # noqa AlphaMixIn, LineMixIn, ItemChangedType) # noqa from .complex import ImageComplexData # noqa -from .curve import Curve # noqa +from .curve import Curve, CurveStyle # noqa from .histogram import Histogram # noqa from .image import ImageBase, ImageData, ImageRgba, MaskImageData # noqa from .shape import Shape # noqa diff --git a/silx/gui/plot/items/axis.py b/silx/gui/plot/items/axis.py index 3d9fe14..8ea5c7a 100644 --- a/silx/gui/plot/items/axis.py +++ b/silx/gui/plot/items/axis.py @@ -27,16 +27,16 @@ __authors__ = ["V. Valls"] __license__ = "MIT" -__date__ = "06/12/2017" +__date__ = "22/11/2018" import datetime as dt +import enum import logging import dateutil.tz from ... import qt -from silx.third_party import enum _logger = logging.getLogger(__name__) @@ -448,6 +448,8 @@ class YAxis(Axis): False for Y axis going from bottom to top """ flag = bool(flag) + if self.isInverted() == flag: + return self._getBackend().setYAxisInverted(flag) self._getPlot()._setDirtyPlot() self.sigInvertedChanged.emit(flag) diff --git a/silx/gui/plot/items/complex.py b/silx/gui/plot/items/complex.py index 535b0a9..7fffd77 100644 --- a/silx/gui/plot/items/complex.py +++ b/silx/gui/plot/items/complex.py @@ -1,7 +1,7 @@ # coding: utf-8 # /*########################################################################## # -# Copyright (c) 2017-2018 European Synchrotron Radiation Facility +# Copyright (c) 2017-2019 European Synchrotron Radiation Facility # # Permission is hereby granted, free of charge, to any person obtaining a copy # of this software and associated documentation files (the "Software"), to deal @@ -33,9 +33,9 @@ __date__ = "14/06/2018" import logging -import numpy +import enum -from silx.third_party import enum +import numpy from ...colors import Colormap from .core import ColormapMixIn, ItemChangedType @@ -137,7 +137,6 @@ class ImageComplexData(ImageBase, ColormapMixIn): name='hsv', vmin=-numpy.pi, vmax=numpy.pi) - phaseColormap.setEditable(False) self._colormaps = { # Default colormaps for all modes self.Mode.ABSOLUTE: colormap, @@ -180,7 +179,6 @@ class ImageComplexData(ImageBase, ColormapMixIn): colormap=colormap, alpha=self.getAlpha()) - def setVisualizationMode(self, mode): """Set the visualization mode to use. diff --git a/silx/gui/plot/items/core.py b/silx/gui/plot/items/core.py index e000751..bf3b719 100644 --- a/silx/gui/plot/items/core.py +++ b/silx/gui/plot/items/core.py @@ -1,7 +1,7 @@ # coding: utf-8 # /*########################################################################## # -# Copyright (c) 2017-2018 European Synchrotron Radiation Facility +# Copyright (c) 2017-2019 European Synchrotron Radiation Facility # # Permission is hereby granted, free of charge, to any person obtaining a copy # of this software and associated documentation files (the "Software"), to deal @@ -27,20 +27,23 @@ __authors__ = ["T. Vincent"] __license__ = "MIT" -__date__ = "14/06/2018" +__date__ = "29/01/2019" import collections from copy import deepcopy import logging +import enum import warnings import weakref + import numpy -from silx.third_party import six, enum +import six from ... import qt from ... import colors from ...colors import Colormap +from silx import config _logger = logging.getLogger(__name__) @@ -82,6 +85,9 @@ class ItemChangedType(enum.Enum): COLOR = 'colorChanged' """Item's color changed flag.""" + LINE_BG_COLOR = 'lineBgColorChanged' + """Item's line background color changed flag.""" + YAXIS = 'yAxisChanged' """Item's Y axis binding changed flag.""" @@ -411,10 +417,12 @@ class ColormapMixIn(ItemMixInBase): return self._colormap def setColormap(self, colormap): - """Set the colormap of this image + """Set the colormap of this item :param silx.gui.colors.Colormap colormap: colormap description """ + if self._colormap is colormap: + return if isinstance(colormap, dict): colormap = Colormap._fromDict(colormap) @@ -433,10 +441,10 @@ class ColormapMixIn(ItemMixInBase): class SymbolMixIn(ItemMixInBase): """Mix-in class for items with symbol type""" - _DEFAULT_SYMBOL = '' + _DEFAULT_SYMBOL = None """Default marker of the item""" - _DEFAULT_SYMBOL_SIZE = 6.0 + _DEFAULT_SYMBOL_SIZE = config.DEFAULT_PLOT_SYMBOL_SIZE """Default marker size of the item""" _SUPPORTED_SYMBOLS = collections.OrderedDict(( @@ -451,8 +459,15 @@ class SymbolMixIn(ItemMixInBase): """Dict of supported symbols""" def __init__(self): - self._symbol = self._DEFAULT_SYMBOL - self._symbol_size = self._DEFAULT_SYMBOL_SIZE + if self._DEFAULT_SYMBOL is None: # Use default from config + self._symbol = config.DEFAULT_PLOT_SYMBOL + else: + self._symbol = self._DEFAULT_SYMBOL + + if self._DEFAULT_SYMBOL_SIZE is None: # Use default from config + self._symbol_size = config.DEFAULT_PLOT_SYMBOL_SIZE + else: + self._symbol_size = self._DEFAULT_SYMBOL_SIZE @classmethod def getSupportedSymbols(cls): @@ -892,14 +907,14 @@ class Points(Item, SymbolMixIn, AlphaMixIn): # use the getData class method because instance method can be # overloaded to return additional arrays data = Points.getData(self, copy=False, - displayed=True) + displayed=True) if len(data) == 5: # hack to avoid duplicating caching mechanism in Scatter # (happens when cached data is used, caching done using # Scatter._logFilterData) - x, y, xerror, yerror = data[0], data[1], data[3], data[4] + x, y, _xerror, _yerror = data[0], data[1], data[3], data[4] else: - x, y, xerror, yerror = data + x, y, _xerror, _yerror = data self._boundsCache[(xPositive, yPositive)] = ( numpy.nanmin(x), diff --git a/silx/gui/plot/items/curve.py b/silx/gui/plot/items/curve.py index 80d9dea..79def55 100644 --- a/silx/gui/plot/items/curve.py +++ b/silx/gui/plot/items/curve.py @@ -1,7 +1,7 @@ # coding: utf-8 # /*########################################################################## # -# Copyright (c) 2017-2018 European Synchrotron Radiation Facility +# Copyright (c) 2017-2019 European Synchrotron Radiation Facility # # Permission is hereby granted, free of charge, to any person obtaining a copy # of this software and associated documentation files (the "Software"), to deal @@ -31,9 +31,10 @@ __date__ = "24/04/2018" import logging + import numpy +import six -from silx.third_party import six from ....utils.deprecation import deprecated from ... import colors from .core import (Points, LabelsMixIn, ColorMixIn, YAxisMixIn, diff --git a/silx/gui/plot/items/histogram.py b/silx/gui/plot/items/histogram.py index 389e8a6..a1d6586 100644 --- a/silx/gui/plot/items/histogram.py +++ b/silx/gui/plot/items/histogram.py @@ -1,7 +1,7 @@ # coding: utf-8 # /*########################################################################## # -# Copyright (c) 2017 European Synchrotron Radiation Facility +# Copyright (c) 2017-2019 European Synchrotron Radiation Facility # # Permission is hereby granted, free of charge, to any person obtaining a copy # of this software and associated documentation files (the "Software"), to deal @@ -197,13 +197,13 @@ class Histogram(Item, AlphaMixIn, ColorMixIn, FillMixIn, values[clipped_values] = numpy.nan - if xPositive or yPositive: + if yPositive: return (numpy.nanmin(edges), numpy.nanmax(edges), numpy.nanmin(values), numpy.nanmax(values)) - else: # No log scale, include 0 in bounds + else: # No log scale on y axis, include 0 in bounds return (numpy.nanmin(edges), numpy.nanmax(edges), min(0, numpy.nanmin(values)), diff --git a/silx/gui/plot/items/roi.py b/silx/gui/plot/items/roi.py index f55ef91..0169439 100644 --- a/silx/gui/plot/items/roi.py +++ b/silx/gui/plot/items/roi.py @@ -65,7 +65,7 @@ class RegionOfInterest(qt.QObject): # Avoid circular dependancy from ..tools import roi as roi_tools assert parent is None or isinstance(parent, roi_tools.RegionOfInterestManager) - super(RegionOfInterest, self).__init__(parent) + qt.QObject.__init__(self, parent) self._color = rgba('red') self._items = WeakList() self._editAnchors = WeakList() @@ -108,7 +108,7 @@ class RegionOfInterest(qt.QObject): return qt.QColor.fromRgbF(*self._color) def _getAnchorColor(self, color): - """Returns the anchor color from the base ROI color + """Returns the anchor color from the base ROI color :param Union[numpy.array,Tuple,List]: color :rtype: Union[numpy.array,Tuple,List] @@ -209,7 +209,7 @@ class RegionOfInterest(qt.QObject): def setFirstShapePoints(self, points): """"Initialize the ROI using the points from the first interaction. - This interaction is constains by the plot API and only supports few + This interaction is constrained by the plot API and only supports few shapes. """ points = self._createControlPointsFromFirstShape(points) @@ -410,6 +410,13 @@ class RegionOfInterest(qt.QObject): plot._remove(item) self._labelItem = None + def _updated(self, event=None, checkVisibility=True): + """Implement Item mix-in update method by updating the plot items + + See :class:`~silx.gui.plot.items.Item._updated` + """ + self._createPlotItems() + def __str__(self): """Returns parameters of the ROI as a string.""" points = self._getControlPoints() @@ -417,7 +424,7 @@ class RegionOfInterest(qt.QObject): return "%s(%s)" % (self.__class__.__name__, params) -class PointROI(RegionOfInterest): +class PointROI(RegionOfInterest, items.SymbolMixIn): """A ROI identifying a point in a 2D plot.""" _kind = "Point" @@ -426,6 +433,10 @@ class PointROI(RegionOfInterest): _plotShape = "point" """Plot shape which is used for the first interaction""" + def __init__(self, parent=None): + items.SymbolMixIn.__init__(self) + RegionOfInterest.__init__(self, parent=parent) + def getPosition(self): """Returns the position of this ROI @@ -458,6 +469,8 @@ class PointROI(RegionOfInterest): marker.setPosition(points[0][0], points[0][1]) marker.setText(self.getLabel()) marker.setColor(rgba(self.getColor())) + marker.setSymbol(self.getSymbol()) + marker.setSymbolSize(self.getSymbolSize()) marker._setDraggable(False) return [marker] @@ -466,6 +479,8 @@ class PointROI(RegionOfInterest): marker.setPosition(points[0][0], points[0][1]) marker.setText(self.getLabel()) marker._setDraggable(self.isEditable()) + marker.setSymbol(self.getSymbol()) + marker.setSymbolSize(self.getSymbolSize()) return [marker] def __str__(self): @@ -474,7 +489,7 @@ class PointROI(RegionOfInterest): return "%s(%s)" % (self.__class__.__name__, params) -class LineROI(RegionOfInterest): +class LineROI(RegionOfInterest, items.LineMixIn): """A ROI identifying a line in a 2D plot. This ROI provides 1 anchor for each boundary of the line, plus an center @@ -487,6 +502,10 @@ class LineROI(RegionOfInterest): _plotShape = "line" """Plot shape which is used for the first interaction""" + def __init__(self, parent=None): + items.LineMixIn.__init__(self) + RegionOfInterest.__init__(self, parent=parent) + def _createControlPointsFromFirstShape(self, points): center = numpy.mean(points, axis=0) controlPoints = numpy.array([points[0], points[1], center]) @@ -535,6 +554,8 @@ class LineROI(RegionOfInterest): item.setColor(rgba(self.getColor())) item.setFill(False) item.setOverlay(True) + item.setLineStyle(self.getLineStyle()) + item.setLineWidth(self.getLineWidth()) return [item] def _createAnchorItems(self, points): @@ -582,7 +603,7 @@ class LineROI(RegionOfInterest): return "%s(%s)" % (self.__class__.__name__, params) -class HorizontalLineROI(RegionOfInterest): +class HorizontalLineROI(RegionOfInterest, items.LineMixIn): """A ROI identifying an horizontal line in a 2D plot.""" _kind = "HLine" @@ -591,6 +612,10 @@ class HorizontalLineROI(RegionOfInterest): _plotShape = "hline" """Plot shape which is used for the first interaction""" + def __init__(self, parent=None): + items.LineMixIn.__init__(self) + RegionOfInterest.__init__(self, parent=parent) + def _createControlPointsFromFirstShape(self, points): points = numpy.array([(float('nan'), points[0, 1])], dtype=numpy.float64) @@ -636,6 +661,8 @@ class HorizontalLineROI(RegionOfInterest): marker.setText(self.getLabel()) marker.setColor(rgba(self.getColor())) marker._setDraggable(False) + marker.setLineWidth(self.getLineWidth()) + marker.setLineStyle(self.getLineStyle()) return [marker] def _createAnchorItems(self, points): @@ -643,6 +670,8 @@ class HorizontalLineROI(RegionOfInterest): marker.setPosition(points[0][0], points[0][1]) marker.setText(self.getLabel()) marker._setDraggable(self.isEditable()) + marker.setLineWidth(self.getLineWidth()) + marker.setLineStyle(self.getLineStyle()) return [marker] def __str__(self): @@ -651,7 +680,7 @@ class HorizontalLineROI(RegionOfInterest): return "%s(%s)" % (self.__class__.__name__, params) -class VerticalLineROI(RegionOfInterest): +class VerticalLineROI(RegionOfInterest, items.LineMixIn): """A ROI identifying a vertical line in a 2D plot.""" _kind = "VLine" @@ -660,6 +689,10 @@ class VerticalLineROI(RegionOfInterest): _plotShape = "vline" """Plot shape which is used for the first interaction""" + def __init__(self, parent=None): + items.LineMixIn.__init__(self) + RegionOfInterest.__init__(self, parent=parent) + def _createControlPointsFromFirstShape(self, points): points = numpy.array([(points[0, 0], float('nan'))], dtype=numpy.float64) @@ -705,6 +738,8 @@ class VerticalLineROI(RegionOfInterest): marker.setText(self.getLabel()) marker.setColor(rgba(self.getColor())) marker._setDraggable(False) + marker.setLineWidth(self.getLineWidth()) + marker.setLineStyle(self.getLineStyle()) return [marker] def _createAnchorItems(self, points): @@ -712,6 +747,8 @@ class VerticalLineROI(RegionOfInterest): marker.setPosition(points[0][0], points[0][1]) marker.setText(self.getLabel()) marker._setDraggable(self.isEditable()) + marker.setLineWidth(self.getLineWidth()) + marker.setLineStyle(self.getLineStyle()) return [marker] def __str__(self): @@ -720,7 +757,7 @@ class VerticalLineROI(RegionOfInterest): return "%s(%s)" % (self.__class__.__name__, params) -class RectangleROI(RegionOfInterest): +class RectangleROI(RegionOfInterest, items.LineMixIn): """A ROI identifying a rectangle in a 2D plot. This ROI provides 1 anchor for each corner, plus an anchor in the @@ -733,6 +770,10 @@ class RectangleROI(RegionOfInterest): _plotShape = "rectangle" """Plot shape which is used for the first interaction""" + def __init__(self, parent=None): + items.LineMixIn.__init__(self) + RegionOfInterest.__init__(self, parent=parent) + def _createControlPointsFromFirstShape(self, points): point0 = points[0] point1 = points[1] @@ -838,6 +879,8 @@ class RectangleROI(RegionOfInterest): item.setColor(rgba(self.getColor())) item.setFill(False) item.setOverlay(True) + item.setLineStyle(self.getLineStyle()) + item.setLineWidth(self.getLineWidth()) return [item] def _createAnchorItems(self, points): @@ -894,7 +937,7 @@ class RectangleROI(RegionOfInterest): return "%s(%s)" % (self.__class__.__name__, params) -class PolygonROI(RegionOfInterest): +class PolygonROI(RegionOfInterest, items.LineMixIn): """A ROI identifying a closed polygon in a 2D plot. This ROI provides 1 anchor for each point of the polygon. @@ -906,6 +949,10 @@ class PolygonROI(RegionOfInterest): _plotShape = "polygon" """Plot shape which is used for the first interaction""" + def __init__(self, parent=None): + items.LineMixIn.__init__(self) + RegionOfInterest.__init__(self, parent=parent) + def getPoints(self): """Returns the list of the points of this polygon. @@ -948,6 +995,8 @@ class PolygonROI(RegionOfInterest): item.setColor(rgba(self.getColor())) item.setFill(False) item.setOverlay(True) + item.setLineStyle(self.getLineStyle()) + item.setLineWidth(self.getLineWidth()) return [item] def _createAnchorItems(self, points): @@ -967,7 +1016,7 @@ class PolygonROI(RegionOfInterest): return "%s(%s)" % (self.__class__.__name__, params) -class ArcROI(RegionOfInterest): +class ArcROI(RegionOfInterest, items.LineMixIn): """A ROI identifying an arc of a circle with a width. This ROI provides 3 anchors to control the curvature, 1 anchor to control @@ -986,6 +1035,7 @@ class ArcROI(RegionOfInterest): 'startAngle', 'endAngle']) def __init__(self, parent=None): + items.LineMixIn.__init__(self) RegionOfInterest.__init__(self, parent=parent) self._geometry = None @@ -1357,6 +1407,8 @@ class ArcROI(RegionOfInterest): item.setColor(rgba(self.getColor())) item.setFill(False) item.setOverlay(True) + item.setLineStyle(self.getLineStyle()) + item.setLineWidth(self.getLineWidth()) return [item] def _createAnchorItems(self, points): diff --git a/silx/gui/plot/items/scatter.py b/silx/gui/plot/items/scatter.py index acc74b4..707dd3d 100644 --- a/silx/gui/plot/items/scatter.py +++ b/silx/gui/plot/items/scatter.py @@ -1,7 +1,7 @@ # coding: utf-8 # /*########################################################################## # -# Copyright (c) 2017-2018 European Synchrotron Radiation Facility +# Copyright (c) 2017-2019 European Synchrotron Radiation Facility # # Permission is hereby granted, free of charge, to any person obtaining a copy # of this software and associated documentation files (the "Software"), to deal @@ -46,9 +46,6 @@ class Scatter(Points, ColormapMixIn): _DEFAULT_SELECTABLE = True """Default selectable state for scatter plots""" - _DEFAULT_SYMBOL = 'o' - """Default symbol of the scatter plots""" - def __init__(self): Points.__init__(self) ColormapMixIn.__init__(self) diff --git a/silx/gui/plot/items/shape.py b/silx/gui/plot/items/shape.py index 65b26a1..9fc1306 100644 --- a/silx/gui/plot/items/shape.py +++ b/silx/gui/plot/items/shape.py @@ -1,7 +1,7 @@ # coding: utf-8 # /*########################################################################## # -# Copyright (c) 2017 European Synchrotron Radiation Facility +# Copyright (c) 2017-2019 European Synchrotron Radiation Facility # # Permission is hereby granted, free of charge, to any person obtaining a copy # of this software and associated documentation files (the "Software"), to deal @@ -27,14 +27,16 @@ __authors__ = ["T. Vincent"] __license__ = "MIT" -__date__ = "17/05/2017" +__date__ = "21/12/2018" import logging import numpy +import six -from .core import (Item, ColorMixIn, FillMixIn, ItemChangedType) +from ... import colors +from .core import Item, ColorMixIn, FillMixIn, ItemChangedType, LineMixIn _logger = logging.getLogger(__name__) @@ -42,7 +44,7 @@ _logger = logging.getLogger(__name__) # TODO probably make one class for each kind of shape # TODO check fill:polygon/polyline + fill = duplicated -class Shape(Item, ColorMixIn, FillMixIn): +class Shape(Item, ColorMixIn, FillMixIn, LineMixIn): """Description of a shape item :param str type_: The type of shape in: @@ -53,10 +55,12 @@ class Shape(Item, ColorMixIn, FillMixIn): Item.__init__(self) ColorMixIn.__init__(self) FillMixIn.__init__(self) + LineMixIn.__init__(self) self._overlay = False assert type_ in ('hline', 'polygon', 'rectangle', 'vline', 'polylines') self._type = type_ self._points = () + self._lineBgColor = None self._handle = None @@ -71,7 +75,10 @@ class Shape(Item, ColorMixIn, FillMixIn): color=self.getColor(), fill=self.isFill(), overlay=self.isOverlay(), - z=self.getZValue()) + z=self.getZValue(), + linestyle=self.getLineStyle(), + linewidth=self.getLineWidth(), + linebgcolor=self.getLineBgColor()) def isOverlay(self): """Return true if shape is drawn as an overlay @@ -119,3 +126,31 @@ class Shape(Item, ColorMixIn, FillMixIn): """ self._points = numpy.array(points, copy=copy) self._updated(ItemChangedType.DATA) + + def getLineBgColor(self): + """Returns the RGBA color of the item + :rtype: 4-tuple of float in [0, 1] or array of colors + """ + return self._lineBgColor + + def setLineBgColor(self, color, copy=True): + """Set item color + :param color: color(s) to be used + :type color: str ("#RRGGBB") or (npoints, 4) unsigned byte array or + one of the predefined color names defined in colors.py + :param bool copy: True (Default) to get a copy, + False to use internal representation (do not modify!) + """ + if color is not None: + if isinstance(color, six.string_types): + color = colors.rgba(color) + else: + color = numpy.array(color, copy=copy) + # TODO more checks + improve color array support + if color.ndim == 1: # Single RGBA color + color = colors.rgba(color) + else: # Array of colors + assert color.ndim == 2 + + self._lineBgColor = color + self._updated(ItemChangedType.LINE_BG_COLOR) diff --git a/silx/gui/plot/matplotlib/Colormap.py b/silx/gui/plot/matplotlib/Colormap.py index 772a473..38f3b55 100644 --- a/silx/gui/plot/matplotlib/Colormap.py +++ b/silx/gui/plot/matplotlib/Colormap.py @@ -29,7 +29,13 @@ from matplotlib.colors import ListedColormap import matplotlib.colors import matplotlib.cm import silx.resources -from silx.utils.deprecation import deprecated +from silx.utils.deprecation import deprecated, deprecated_warning + + +deprecated_warning(type_='module', + name=__file__, + replacement='silx.gui.colors.Colormap', + since_version='0.10.0') _logger = logging.getLogger(__name__) @@ -46,25 +52,30 @@ _CMAPS = {} @property +@deprecated(since_version='0.10.0') def magma(): return getColormap('magma') @property +@deprecated(since_version='0.10.0') def inferno(): return getColormap('inferno') @property +@deprecated(since_version='0.10.0') def plasma(): return getColormap('plasma') @property +@deprecated(since_version='0.10.0') def viridis(): return getColormap('viridis') +@deprecated(since_version='0.10.0') def getColormap(name): """Returns matplotlib colormap corresponding to given name @@ -143,6 +154,7 @@ def getColormap(name): return matplotlib.cm.get_cmap(name) +@deprecated(since_version='0.10.0') def getScalarMappable(colormap, data=None): """Returns matplotlib ScalarMappable corresponding to colormap @@ -223,6 +235,8 @@ def applyColormapToData(data, colormap): return rgbaImage +@deprecated(replacement='silx.colors.Colormap.getSupportedColormaps', + since_version='0.10.0') def getSupportedColormaps(): """Get the supported colormap names as a tuple of str. """ diff --git a/silx/gui/plot/stats/stats.py b/silx/gui/plot/stats/stats.py index a753989..ad61536 100644 --- a/silx/gui/plot/stats/stats.py +++ b/silx/gui/plot/stats/stats.py @@ -1,7 +1,7 @@ # coding: utf-8 # /*########################################################################## # -# Copyright (c) 2017-2018 European Synchrotron Radiation Facility +# Copyright (c) 2017-2019 European Synchrotron Radiation Facility # # Permission is hereby granted, free of charge, to any person obtaining a copy # of this software and associated documentation files (the "Software"), to deal @@ -30,15 +30,15 @@ __license__ = "MIT" __date__ = "06/06/2018" -import numpy -from silx.gui.plot.items.curve import Curve as CurveItem -from silx.gui.plot.items.image import ImageBase as ImageItem -from silx.gui.plot.items.scatter import Scatter as ScatterItem -from silx.gui.plot.items.histogram import Histogram as HistogramItem -from silx.math.combo import min_max from collections import OrderedDict import logging +import numpy + +from .. import items +from ....math.combo import min_max + + logger = logging.getLogger(__name__) @@ -62,7 +62,7 @@ class Stats(OrderedDict): def calculate(self, item, plot, onlimits): """ - Call all :class:`Stat` object registred and return the result of the + Call all :class:`Stat` object registered and return the result of the computation. :param item: the item for which we want statistics @@ -72,17 +72,29 @@ class Stats(OrderedDict): :return dict: dictionary with :class:`Stat` name as ket and result of the calculation as value """ - res = {} - if isinstance(item, CurveItem): + context = None + # Check for PlotWidget items + if isinstance(item, items.Curve): context = _CurveContext(item, plot, onlimits) - elif isinstance(item, ImageItem): + elif isinstance(item, items.ImageData): context = _ImageContext(item, plot, onlimits) - elif isinstance(item, ScatterItem): + elif isinstance(item, items.Scatter): context = _ScatterContext(item, plot, onlimits) - elif isinstance(item, HistogramItem): + elif isinstance(item, items.Histogram): context = _HistogramContext(item, plot, onlimits) else: - raise ValueError('Item type not managed') + # Check for SceneWidget items + from ...plot3d import items as items3d # Lazy import + + if isinstance(item, (items3d.Scatter2D, items3d.Scatter3D)): + context = _plot3DScatterContext(item, plot, onlimits) + elif isinstance(item, (items3d.ImageData, items3d.ScalarField3D)): + context = _plot3DArrayContext(item, plot, onlimits) + + if context is None: + raise ValueError('Item type not managed') + + res = {} for statName, stat in list(self.items()): if context.kind not in stat.compatibleKinds: logger.debug('kind %s not managed by statistic %s' @@ -124,12 +136,54 @@ class _StatsContext(object): self.min = None self.max = None self.data = None + self.values = None + """The array of data""" + + self.axes = None + """A list of array of position on each axis. + + If the signal is an array, + then each axis has the length of that dimension, + and the order is (z, y, x) (i.e., as the array shape). + If the signal is not an array, + then each axis has the same length as the signal, + and the order is (x, y, z). + """ + self.createContext(item, plot, onlimits) def createContext(self, item, plot, onlimits): raise NotImplementedError("Base class") + def isStructuredData(self): + """Returns True if data as an array-like structure. + + :rtype: bool + """ + if self.values is None or self.axes is None: + return False + + if numpy.prod([len(axis) for axis in self.axes]) == self.values.size: + return True + else: + # Make sure there is the right number of value in axes + for axis in self.axes: + assert len(axis) == self.values.size + return False + + def isScalarData(self): + """Returns True if data is a scalar. + + :rtype: bool + """ + if self.values is None or self.axes is None: + return False + if self.isStructuredData(): + return len(self.axes) == self.values.ndim + else: + return self.values.ndim == 1 + class _CurveContext(_StatsContext): """ @@ -149,8 +203,9 @@ class _CurveContext(_StatsContext): if onlimits: minX, maxX = plot.getXAxis().getLimits() - yData = yData[(minX <= xData) & (xData <= maxX)] - xData = xData[(minX <= xData) & (xData <= maxX)] + mask = (minX <= xData) & (xData <= maxX) + yData = yData[mask] + xData = xData[mask] self.xData = xData self.yData = yData @@ -160,11 +215,12 @@ class _CurveContext(_StatsContext): self.min, self.max = None, None self.data = (xData, yData) self.values = yData + self.axes = (xData,) class _HistogramContext(_StatsContext): """ - StatsContext for :class:`Curve` + StatsContext for :class:`Histogram` :param item: the item for which we want to compute the context :param plot: the plot containing the item @@ -176,12 +232,13 @@ class _HistogramContext(_StatsContext): plot=plot, onlimits=onlimits) def createContext(self, item, plot, onlimits): - xData, edges = item.getData(copy=True)[0:2] - yData = item._revertComputeEdges(x=edges, histogramType=item.getAlignment()) + yData, edges = item.getData(copy=True)[0:2] + xData = item._revertComputeEdges(x=edges, histogramType=item.getAlignment()) if onlimits: minX, maxX = plot.getXAxis().getLimits() - yData = yData[(minX <= xData) & (xData <= maxX)] - xData = xData[(minX <= xData) & (xData <= maxX)] + mask = (minX <= xData) & (xData <= maxX) + yData = yData[mask] + xData = xData[mask] self.xData = xData self.yData = yData @@ -191,11 +248,13 @@ class _HistogramContext(_StatsContext): self.min, self.max = None, None self.data = (xData, yData) self.values = yData + self.axes = (xData,) class _ScatterContext(_StatsContext): - """ - StatsContext for :class:`Scatter` + """StatsContext scatter plots. + + It supports :class:`~silx.gui.plot.items.Scatter`. :param item: the item for which we want to compute the context :param plot: the plot containing the item @@ -207,11 +266,14 @@ class _ScatterContext(_StatsContext): onlimits=onlimits) def createContext(self, item, plot, onlimits): - xData, yData, valueData, xerror, yerror = item.getData(copy=True) - assert plot + valueData = item.getValueData(copy=True) + xData = item.getXData(copy=True) + yData = item.getYData(copy=True) + if onlimits: minX, maxX = plot.getXAxis().getLimits() minY, maxY = plot.getYAxis().getLimits() + # filter on X axis valueData = valueData[(minX <= xData) & (xData <= maxX)] yData = yData[(minX <= xData) & (xData <= maxX)] @@ -220,17 +282,20 @@ class _ScatterContext(_StatsContext): valueData = valueData[(minY <= yData) & (yData <= maxY)] xData = xData[(minY <= yData) & (yData <= maxY)] yData = yData[(minY <= yData) & (yData <= maxY)] + if len(valueData) > 0: self.min, self.max = min_max(valueData) else: self.min, self.max = None, None self.data = (xData, yData, valueData) self.values = valueData + self.axes = (xData, yData) class _ImageContext(_StatsContext): - """ - StatsContext for :class:`ImageBase` + """StatsContext for images. + + It supports :class:`~silx.gui.plot.items.ImageData`. :param item: the item for which we want to compute the context :param plot: the plot containing the item @@ -244,7 +309,8 @@ class _ImageContext(_StatsContext): def createContext(self, item, plot, onlimits): self.origin = item.getOrigin() self.scale = item.getScale() - self.data = item.getData() + + self.data = item.getData(copy=True) if onlimits: minX, maxX = plot.getXAxis().getLimits() @@ -259,25 +325,88 @@ class _ImageContext(_StatsContext): YMinBound = max(YMinBound, 0) if XMaxBound <= XMinBound or YMaxBound <= YMinBound: - return self.noDataSelected() - data = item.getData() - self.data = data[YMinBound:YMaxBound + 1, XMinBound:XMaxBound + 1] - else: - self.data = item.getData() - + self.data = None + else: + self.data = self.data[YMinBound:YMaxBound + 1, + XMinBound:XMaxBound + 1] if self.data.size > 0: self.min, self.max = min_max(self.data) else: self.min, self.max = None, None self.values = self.data + if self.values is not None: + self.axes = (self.origin[1] + self.scale[1] * numpy.arange(self.data.shape[0]), + self.origin[0] + self.scale[0] * numpy.arange(self.data.shape[1])) + + +class _plot3DScatterContext(_StatsContext): + """StatsContext for 3D scatter plots. + + It supports :class:`~silx.gui.plot3d.items.Scatter2D` and + :class:`~silx.gui.plot3d.items.Scatter3D`. + + :param item: the item for which we want to compute the context + :param plot: the plot containing the item + :param bool onlimits: True if we want to apply statistic only on + visible data. + """ + def __init__(self, item, plot, onlimits): + _StatsContext.__init__(self, kind='scatter', item=item, plot=plot, + onlimits=onlimits) + + def createContext(self, item, plot, onlimits): + if onlimits: + raise RuntimeError("Unsupported plot %s" % str(plot)) + + values = item.getValueData(copy=False) + + if values is not None and len(values) > 0: + self.values = values + axes = [item.getXData(copy=False), item.getYData(copy=False)] + if self.values.ndim == 3: + axes.append(item.getZData(copy=False)) + self.axes = tuple(axes) + + self.min, self.max = min_max(self.values) + else: + self.values = None + self.axes = None + self.min, self.max = None, None + + +class _plot3DArrayContext(_StatsContext): + """StatsContext for 3D scalar field and data image. + + It supports :class:`~silx.gui.plot3d.items.ScalarField3D` and + :class:`~silx.gui.plot3d.items.ImageData`. + + :param item: the item for which we want to compute the context + :param plot: the plot containing the item + :param bool onlimits: True if we want to apply statistic only on + visible data. + """ + def __init__(self, item, plot, onlimits): + _StatsContext.__init__(self, kind='image', item=item, plot=plot, + onlimits=onlimits) + + def createContext(self, item, plot, onlimits): + if onlimits: + raise RuntimeError("Unsupported plot %s" % str(plot)) + + values = item.getData(copy=False) + + if values is not None and len(values) > 0: + self.values = values + self.axes = tuple([numpy.arange(size) for size in self.values.shape]) + self.min, self.max = min_max(self.values) + else: + self.values = None + self.axes = None + self.min, self.max = None, None + -BASIC_COMPATIBLE_KINDS = { - 'curve': CurveItem, - 'image': ImageItem, - 'scatter': ScatterItem, - 'histogram': HistogramItem, -} +BASIC_COMPATIBLE_KINDS = 'curve', 'image', 'scatter', 'histogram' class StatBase(object): @@ -285,9 +414,8 @@ class StatBase(object): Base class for defining a statistic. :param str name: the name of the statistic. Must be unique. - :param compatibleKinds: the kind of items (curve, scatter...) for which - the statistic apply. - :rtype: List or tuple + :param List[str] compatibleKinds: + The kind of items (curve, scatter...) for which the statistic apply. """ def __init__(self, name, compatibleKinds=BASIC_COMPATIBLE_KINDS, description=None): self.name = name @@ -298,7 +426,7 @@ class StatBase(object): """ compute the statistic for the given :class:`StatsContext` - :param context: + :param _StatsContext context: :return dict: key is stat name, statistic computed is the dict value """ raise NotImplementedError('Base class') @@ -307,7 +435,7 @@ class StatBase(object): """ If necessary add a tooltip for a stat kind - :param str kinf: the kind of item the statistic is compute for. + :param str kind: the kind of item the statistic is compute for. :return: tooltip or None if no tooltip """ return None @@ -329,17 +457,18 @@ class Stat(StatBase): self._fct = fct def calculate(self, context): - if context.kind in self.compatibleKinds: - return self._fct(context.values) + if context.values is not None: + if context.kind in self.compatibleKinds: + return self._fct(context.values) + else: + raise ValueError('Kind %s not managed by %s' + '' % (context.kind, self.name)) else: - raise ValueError('Kind %s not managed by %s' - '' % (context.kind, self.name)) + return None class StatMin(StatBase): - """ - Compute the minimal value on data - """ + """Compute the minimal value on data""" def __init__(self): StatBase.__init__(self, name='min') @@ -348,9 +477,7 @@ class StatMin(StatBase): class StatMax(StatBase): - """ - Compute the maximal value on data - """ + """Compute the maximal value on data""" def __init__(self): StatBase.__init__(self, name='max') @@ -359,9 +486,7 @@ class StatMax(StatBase): class StatDelta(StatBase): - """ - Compute the delta between minimal and maximal on data - """ + """Compute the delta between minimal and maximal on data""" def __init__(self): StatBase.__init__(self, name='delta') @@ -369,123 +494,84 @@ class StatDelta(StatBase): return context.max - context.min -class StatCoordMin(StatBase): - """ - Compute the first coordinates of the data minimal value - """ +class _StatCoord(StatBase): + """Base class for argmin and argmax stats""" + + def _indexToCoordinates(self, context, index): + """Returns the coordinates of data point at given index + + If data is an array, coordinates are in reverse order from data shape. + + :param _StatsContext context: + :param int index: Index in the flattened data array + :rtype: List[int] + """ + if context.isStructuredData(): + coordinates = [] + for axis in reversed(context.axes): + coordinates.append(axis[index % len(axis)]) + index = index // len(axis) + return tuple(coordinates) + else: + return tuple(axis[index] for axis in context.axes) + + +class StatCoordMin(_StatCoord): + """Compute the coordinates of the first minimum value of the data""" def __init__(self): - StatBase.__init__(self, name='coords min') + _StatCoord.__init__(self, name='coords min') def calculate(self, context): - if context.kind in ('curve', 'histogram'): - return context.xData[numpy.argmin(context.yData)] - elif context.kind == 'scatter': - xData, yData, valueData = context.data - return (xData[numpy.argmin(valueData)], - yData[numpy.argmin(valueData)]) - elif context.kind == 'image': - scaleX, scaleY = context.scale - originX, originY = context.origin - index1D = numpy.argmin(context.data) - ySize = (context.data.shape[1]) - x = index1D % context.data.shape[1] - y = (index1D - x) / ySize - x = x * scaleX + originX - y = y * scaleY + originY - return (x, y) - else: - raise ValueError('kind not managed') + if context.values is None or not context.isScalarData(): + return None + + index = numpy.argmin(context.values) + return self._indexToCoordinates(context, index) def getToolTip(self, kind): - if kind in ('scatter', 'image'): - return '(x, y)' - else: - return None + return "Coordinates of the first minimum value of the data" -class StatCoordMax(StatBase): - """ - Compute the first coordinates of the data minimal value - """ + +class StatCoordMax(_StatCoord): + """Compute the coordinates of the first maximum value of the data""" def __init__(self): - StatBase.__init__(self, name='coords max') + _StatCoord.__init__(self, name='coords max') def calculate(self, context): - if context.kind in ('curve', 'histogram'): - return context.xData[numpy.argmax(context.yData)] - elif context.kind == 'scatter': - xData, yData, valueData = context.data - return (xData[numpy.argmax(valueData)], - yData[numpy.argmax(valueData)]) - elif context.kind == 'image': - scaleX, scaleY = context.scale - originX, originY = context.origin - index1D = numpy.argmax(context.data) - ySize = (context.data.shape[1]) - x = index1D % context.data.shape[1] - y = (index1D - x) / ySize - x = x * scaleX + originX - y = y * scaleY + originY - return (x, y) - else: - raise ValueError('kind not managed') + if context.values is None or not context.isScalarData(): + return None + + index = numpy.argmax(context.values) + return self._indexToCoordinates(context, index) def getToolTip(self, kind): - if kind in ('scatter', 'image'): - return '(x, y)' - else: - return None + return "Coordinates of the first maximum value of the data" + class StatCOM(StatBase): - """ - Compute data center of mass - """ + """Compute data center of mass""" def __init__(self): StatBase.__init__(self, name='COM', description='Center of mass') def calculate(self, context): - if context.kind in ('curve', 'histogram'): - xData, yData = context.data - deno = numpy.sum(yData).astype(numpy.float32) - if deno == 0.: - return numpy.nan - else: - return numpy.sum(xData * yData).astype(numpy.float32) / deno - elif context.kind == 'scatter': - xData, yData, values = context.data - deno = numpy.sum(values).astype(numpy.float32) - if deno == 0.: - return numpy.nan, numpy.nan - else: - xcom = numpy.sum(xData * values).astype(numpy.float32) / deno - ycom = numpy.sum(yData * values).astype(numpy.float32) / deno - return (xcom, ycom) - elif context.kind == 'image': - yData = numpy.sum(context.data, axis=1) - xData = numpy.sum(context.data, axis=0) - dataXRange = range(context.data.shape[1]) - dataYRange = range(context.data.shape[0]) - xScale, yScale = context.scale - xOrigin, yOrigin = context.origin - - denoY = numpy.sum(yData) - if denoY == 0.: - ycom = numpy.nan - else: - ycom = numpy.sum(yData * dataYRange) / denoY - ycom = ycom * yScale + yOrigin + if context.values is None or not context.isScalarData(): + return None - denoX = numpy.sum(xData) - if denoX == 0.: - xcom = numpy.nan - else: - xcom = numpy.sum(xData * dataXRange) / denoX - xcom = xcom * xScale + xOrigin - return (xcom, ycom) + values = numpy.array(context.values, dtype=numpy.float64) + sum_ = numpy.sum(values) + if sum_ == 0.: + return (numpy.nan,) * len(context.axes) + + if context.isStructuredData(): + centerofmass = [] + for index, axis in enumerate(context.axes): + axes = tuple([i for i in range(len(context.axes)) if i != index]) + centerofmass.append( + numpy.sum(axis * numpy.sum(values, axis=axes)) / sum_) + return tuple(reversed(centerofmass)) else: - raise ValueError('kind not managed') + return tuple( + numpy.sum(axis * values) / sum_ for axis in context.axes) def getToolTip(self, kind): - if kind in ('scatter', 'image'): - return '(x, y)' - else: - return None + return "Compute the center of mass of the dataset" diff --git a/silx/gui/plot/stats/statshandler.py b/silx/gui/plot/stats/statshandler.py index 0a62b31..f69daff 100644 --- a/silx/gui/plot/stats/statshandler.py +++ b/silx/gui/plot/stats/statshandler.py @@ -1,7 +1,7 @@ # coding: utf-8 # /*########################################################################## # -# Copyright (c) 2017-2018 European Synchrotron Radiation Facility +# Copyright (c) 2017-2019 European Synchrotron Radiation Facility # # Permission is hereby granted, free of charge, to any person obtaining a copy # of this software and associated documentation files (the "Software"), to deal @@ -45,7 +45,14 @@ class _FloatItem(qt.QTableWidgetItem): qt.QTableWidgetItem.__init__(self, type=type) def __lt__(self, other): - return float(self.text()) < float(other.text()) + self_values = self.text().lstrip('(').rstrip(')').split(',') + other_values = other.text().lstrip('(').rstrip(')').split(',') + for self_value, other_value in zip(self_values, other_values): + f_self_value = float(self_value) + f_other_value = float(other_value) + if f_self_value != f_other_value: + return f_self_value < f_other_value + return False class StatFormatter(object): @@ -89,10 +96,60 @@ class StatsHandler(object): self.stats = statsmdl.Stats() self.formatters = {} for elmt in statFormatters: - helper = _StatHelper(elmt) - self.add(stat=helper.stat, formatter=helper.statFormatter) + stat, formatter = self._processStatArgument(elmt) + self.add(stat=stat, formatter=formatter) + + @staticmethod + def _processStatArgument(arg): + """Process an element of the init arguments + + :param arg: The argument to process + :return: Corresponding (StatBase, StatFormatter) + """ + stat, formatter = None, None + + if isinstance(arg, statsmdl.StatBase): + stat = arg + else: + assert len(arg) > 0 + if isinstance(arg[0], statsmdl.StatBase): + stat = arg[0] + if len(arg) > 2: + raise ValueError('To many argument with %s. At most one ' + 'argument can be associated with the ' + 'BaseStat (the `StatFormatter`') + if len(arg) == 2: + assert arg[1] is None or isinstance(arg[1], (StatFormatter, str)) + formatter = arg[1] + else: + if isinstance(arg[0], tuple): + if len(arg) > 1: + formatter = arg[1] + arg = arg[0] + + if type(arg[0]) is not str: + raise ValueError('first element of the tuple should be a string' + ' or a StatBase instance') + if len(arg) == 1: + raise ValueError('A function should be associated with the' + 'stat name') + if len(arg) > 3: + raise ValueError('Two much argument given for defining statistic.' + 'Take at most three arguments (name, function, ' + 'kinds)') + if len(arg) == 2: + stat = statsmdl.Stat(name=arg[0], fct=arg[1]) + else: + stat = statsmdl.Stat(name=arg[0], fct=arg[1], kinds=arg[2]) + + return stat, formatter def add(self, stat, formatter=None): + """Add a stat to the list. + + :param StatBase stat: + :param Union[None,StatFormatter] formatter: + """ assert isinstance(stat, statsmdl.StatBase) self.stats.add(stat) _formatter = formatter @@ -101,9 +158,9 @@ class StatsHandler(object): self.formatters[stat.name] = _formatter def format(self, name, val): - """ - Apply the format for the `name` statistic and the given value - :param name: the name of the associated statistic + """Apply the format for the `name` statistic and the given value + + :param str name: the name of the associated statistic :param val: value before formatting :return: formatted value """ @@ -123,7 +180,7 @@ class StatsHandler(object): def calculate(self, item, plot, onlimits): """ - compute all statistic registred and return the list of formatted + compute all statistic registered and return the list of formatted statistics result. :param item: item for which we want to compute statistics @@ -137,54 +194,3 @@ class StatsHandler(object): for resName, resValue in list(res.items()): res[resName] = self.format(resName, res[resName]) return res - - -class _StatHelper(object): - """ - Helper class to generated the requested StatBase instance and the - associated StatFormatter - """ - def __init__(self, arg): - self.statFormatter = None - self.stat = None - - if isinstance(arg, statsmdl.StatBase): - self.stat = arg - else: - assert len(arg) > 0 - if isinstance(arg[0], statsmdl.StatBase): - self.dealWithStatAndFormatter(arg) - else: - _arg = arg - if isinstance(arg[0], tuple): - _arg = arg[0] - if len(arg) > 1: - self.statFormatter = arg[1] - self.createStatInstanceAndFormatter(_arg) - - def dealWithStatAndFormatter(self, arg): - assert isinstance(arg[0], statsmdl.StatBase) - self.stat = arg[0] - if len(arg) > 2: - raise ValueError('To many argument with %s. At most one ' - 'argument can be associated with the ' - 'BaseStat (the `StatFormatter`') - if len(arg) is 2: - assert isinstance(arg[1], (StatFormatter, type(None), str)) - self.statFormatter = arg[1] - - def createStatInstanceAndFormatter(self, arg): - if type(arg[0]) is not str: - raise ValueError('first element of the tuple should be a string' - ' or a StatBase instance') - if len(arg) is 1: - raise ValueError('A function should be associated with the' - 'stat name') - if len(arg) > 3: - raise ValueError('Two much argument given for defining statistic.' - 'Take at most three arguments (name, function, ' - 'kinds)') - if len(arg) is 2: - self.stat = statsmdl.Stat(name=arg[0], fct=arg[1]) - else: - self.stat = statsmdl.Stat(name=arg[0], fct=arg[1], kinds=arg[2]) diff --git a/silx/gui/plot/test/testCurvesROIWidget.py b/silx/gui/plot/test/testCurvesROIWidget.py index 0704779..5bcabd8 100644 --- a/silx/gui/plot/test/testCurvesROIWidget.py +++ b/silx/gui/plot/test/testCurvesROIWidget.py @@ -36,7 +36,7 @@ from collections import OrderedDict import numpy from silx.gui import qt from silx.test.utils import temp_dir -from silx.gui.utils.testutils import TestCaseQt +from silx.gui.utils.testutils import TestCaseQt, SignalListener from silx.gui.plot import PlotWindow, CurvesROIWidget @@ -52,7 +52,8 @@ class TestCurvesROIWidget(TestCaseQt): self.plot.show() self.qWaitForWindowExposed(self.plot) - self.widget = CurvesROIWidget.CurvesROIDockWidget(plot=self.plot, name='TEST') + self.widget = self.plot.getCurvesRoiDockWidget() + self.widget.show() self.qWaitForWindowExposed(self.widget) @@ -67,10 +68,6 @@ class TestCurvesROIWidget(TestCaseQt): super(TestCurvesROIWidget, self).tearDown() - def testEmptyPlot(self): - """Empty plot, display ROI widget""" - pass - def testWithCurves(self): """Plot with curves: test all ROI widget buttons""" for offset in range(2): @@ -80,13 +77,16 @@ class TestCurvesROIWidget(TestCaseQt): # Add two ROI self.mouseClick(self.widget.roiWidget.addButton, qt.Qt.LeftButton) + self.qWait(200) self.mouseClick(self.widget.roiWidget.addButton, qt.Qt.LeftButton) + self.qWait(200) # Change active curve self.plot.setActiveCurve(str(1)) # Delete a ROI self.mouseClick(self.widget.roiWidget.delButton, qt.Qt.LeftButton) + self.qWait(200) with temp_dir() as tmpDir: self.tmpFile = os.path.join(tmpDir, 'test.ini') @@ -94,30 +94,42 @@ class TestCurvesROIWidget(TestCaseQt): # Save ROIs self.widget.roiWidget.save(self.tmpFile) self.assertTrue(os.path.isfile(self.tmpFile)) + self.assertTrue(len(self.widget.getRois()) is 2) # Reset ROIs self.mouseClick(self.widget.roiWidget.resetButton, qt.Qt.LeftButton) + self.qWait(200) + rois = self.widget.getRois() + self.assertTrue(len(rois) is 1) + print(rois) + roiID = list(rois.keys())[0] + self.assertTrue(rois[roiID].getName() == 'ICR') # Load ROIs self.widget.roiWidget.load(self.tmpFile) + self.assertTrue(len(self.widget.getRois()) is 2) del self.tmpFile def testMiddleMarker(self): """Test with middle marker enabled""" - self.widget.roiWidget.setMiddleROIMarkerFlag(True) + self.widget.roiWidget.roiTable.setMiddleROIMarkerFlag(True) # Add a ROI self.mouseClick(self.widget.roiWidget.addButton, qt.Qt.LeftButton) - xleftMarker = self.plot._getMarker(legend='ROI min').getXPosition() - xMiddleMarker = self.plot._getMarker(legend='ROI middle').getXPosition() - xRightMarker = self.plot._getMarker(legend='ROI max').getXPosition() - self.assertAlmostEqual(xMiddleMarker, - xleftMarker + (xRightMarker - xleftMarker) / 2.) - - def testCalculation(self): + for roiID in self.widget.roiWidget.roiTable._markersHandler._roiMarkerHandlers: + handler = self.widget.roiWidget.roiTable._markersHandler._roiMarkerHandlers[roiID] + assert handler.getMarker('min') + xleftMarker = handler.getMarker('min').getXPosition() + xMiddleMarker = handler.getMarker('middle').getXPosition() + xRightMarker = handler.getMarker('max').getXPosition() + thValue = xleftMarker + (xRightMarker - xleftMarker) / 2. + self.assertAlmostEqual(xMiddleMarker, thValue) + + def testAreaCalculation(self): + """Test result of area calculation""" x = numpy.arange(100.) y = numpy.arange(100.) @@ -129,30 +141,60 @@ class TestCurvesROIWidget(TestCaseQt): self.plot.setActiveCurve("positive") # Add two ROIs - ddict = {} - ddict["positive"] = {"from": 10, "to": 20, "type":"X"} - ddict["negative"] = {"from": -20, "to": -10, "type":"X"} - self.widget.roiWidget.setRois(ddict) + roi_neg = CurvesROIWidget.ROI(name='negative', fromdata=-20, + todata=-10, type_='X') + roi_pos = CurvesROIWidget.ROI(name='positive', fromdata=10, + todata=20, type_='X') + + self.widget.roiWidget.setRois((roi_pos, roi_neg)) + + posCurve = self.plot.getCurve('positive') + negCurve = self.plot.getCurve('negative') + + self.assertEqual(roi_pos.computeRawAndNetArea(posCurve), + (numpy.trapz(y=[10, 20], x=[10, 20]), + 0.0)) + self.assertEqual(roi_pos.computeRawAndNetArea(negCurve), + (0.0, 0.0)) + self.assertEqual(roi_neg.computeRawAndNetArea(posCurve), + ((0.0), 0.0)) + self.assertEqual(roi_neg.computeRawAndNetArea(negCurve), + ((-150.0), 0.0)) + + def testCountsCalculation(self): + """Test result of count calculation""" + x = numpy.arange(100.) + y = numpy.arange(100.) - # And calculate the expected output - self.widget.calculateROIs() + # Add two curves + self.plot.addCurve(x, y, legend="positive") + self.plot.addCurve(-x, y, legend="negative") + + # Make sure there is an active curve and it is the positive one + self.plot.setActiveCurve("positive") - output = self.widget.roiWidget.getRois() - self.assertEqual(output["positive"]["rawcounts"], - y[ddict["positive"]["from"]:ddict["positive"]["to"]+1].sum(), - "Calculation failed on positive X coordinates") + # Add two ROIs + roi_neg = CurvesROIWidget.ROI(name='negative', fromdata=-20, + todata=-10, type_='X') + roi_pos = CurvesROIWidget.ROI(name='positive', fromdata=10, + todata=20, type_='X') + + self.widget.roiWidget.setRois((roi_pos, roi_neg)) - # Set the curve with negative X coordinates as active - self.plot.setActiveCurve("negative") + posCurve = self.plot.getCurve('positive') + negCurve = self.plot.getCurve('negative') - # the ROIs should have been automatically updated - output = self.widget.roiWidget.getRois() - selection = numpy.nonzero((-x >= output["negative"]["from"]) & \ - (-x <= output["negative"]["to"]))[0] - self.assertEqual(output["negative"]["rawcounts"], - y[selection].sum(), "Calculation failed on negative X coordinates") + self.assertEqual(roi_pos.computeRawAndNetCounts(posCurve), + (y[10:21].sum(), 0.0)) + self.assertEqual(roi_pos.computeRawAndNetCounts(negCurve), + (0.0, 0.0)) + self.assertEqual(roi_neg.computeRawAndNetCounts(posCurve), + ((0.0), 0.0)) + self.assertEqual(roi_neg.computeRawAndNetCounts(negCurve), + (y[10:21].sum(), 0.0)) def testDeferedInit(self): + """Test behavior of the deferedInit""" x = numpy.arange(100.) y = numpy.arange(100.) self.plot.addCurve(x=x, y=y, legend="name", replace="True") @@ -164,12 +206,123 @@ class TestCurvesROIWidget(TestCaseQt): ]) roiWidget = self.plot.getCurvesRoiDockWidget().roiWidget - self.assertFalse(roiWidget._isInit) self.plot.getCurvesRoiDockWidget().setRois(roisDefs) self.assertTrue(len(roiWidget.getRois()) is len(roisDefs)) self.plot.getCurvesRoiDockWidget().setVisible(True) self.assertTrue(len(roiWidget.getRois()) is len(roisDefs)) + def testDictCompatibility(self): + """Test that ROI api is valid with dict and not information is lost""" + roiDict = {'from': 20, 'to': 200, 'type': 'energy', 'comment': 'no', + 'name': 'myROI', 'calibration': [1, 2, 3]} + roi = CurvesROIWidget.ROI._fromDict(roiDict) + self.assertTrue(roi.toDict() == roiDict) + + def testShowAllROI(self): + """Test the show allROI action""" + x = numpy.arange(100.) + y = numpy.arange(100.) + self.plot.addCurve(x=x, y=y, legend="name", replace="True") + + roisDefsDict = { + "range1": {"from": 20, "to": 200,"type": "energy"}, + "range2": {"from": 300, "to": 500, "type": "energy"} + } + + roisDefsObj = ( + CurvesROIWidget.ROI(name='range3', fromdata=20, todata=200, + type_='energy'), + CurvesROIWidget.ROI(name='range4', fromdata=300, todata=500, + type_='energy') + ) + self.widget.roiWidget.showAllMarkers(True) + roiWidget = self.plot.getCurvesRoiDockWidget().roiWidget + roiWidget.setRois(roisDefsDict) + self.assertTrue(len(self.plot._getAllMarkers()) is 2*3) + + markersHandler = self.widget.roiWidget.roiTable._markersHandler + roiWidget.showAllMarkers(True) + ICRROI = markersHandler.getVisibleRois() + self.assertTrue(len(ICRROI) is 2) + + roiWidget.showAllMarkers(False) + ICRROI = markersHandler.getVisibleRois() + self.assertTrue(len(ICRROI) is 1) + + roiWidget.setRois(roisDefsObj) + self.qapp.processEvents() + self.assertTrue(len(self.plot._getAllMarkers()) is 2*3) + + markersHandler = self.widget.roiWidget.roiTable._markersHandler + roiWidget.showAllMarkers(True) + ICRROI = markersHandler.getVisibleRois() + self.assertTrue(len(ICRROI) is 2) + + roiWidget.showAllMarkers(False) + ICRROI = markersHandler.getVisibleRois() + self.assertTrue(len(ICRROI) is 1) + + def testRoiEdition(self): + """Make sure if the ROI object is edited the ROITable will be updated + """ + roi = CurvesROIWidget.ROI(name='linear', fromdata=0, todata=5) + self.widget.roiWidget.setRois((roi, )) + + x = (0, 1, 1, 2, 2, 3) + y = (1, 1, 2, 2, 1, 1) + self.plot.addCurve(x=x, y=y, legend='linearCurve') + self.plot.setActiveCurve(legend='linearCurve') + self.widget.calculateROIs() + + roiTable = self.widget.roiWidget.roiTable + indexesColumns = CurvesROIWidget.ROITable.COLUMNS_INDEX + itemRawCounts = roiTable.item(0, indexesColumns['Raw Counts']) + itemNetCounts = roiTable.item(0, indexesColumns['Net Counts']) + + self.assertTrue(itemRawCounts.text() == '8.0') + self.assertTrue(itemNetCounts.text() == '2.0') + + itemRawArea = roiTable.item(0, indexesColumns['Raw Area']) + itemNetArea = roiTable.item(0, indexesColumns['Net Area']) + + self.assertTrue(itemRawArea.text() == '4.0') + self.assertTrue(itemNetArea.text() == '1.0') + + roi.setTo(2) + itemRawArea = roiTable.item(0, indexesColumns['Raw Area']) + self.assertTrue(itemRawArea.text() == '3.0') + roi.setFrom(1) + itemRawArea = roiTable.item(0, indexesColumns['Raw Area']) + self.assertTrue(itemRawArea.text() == '2.0') + + def testRemoveActiveROI(self): + """Test widget behavior when removing the active ROI""" + roi = CurvesROIWidget.ROI(name='linear', fromdata=0, todata=5) + self.widget.roiWidget.setRois((roi,)) + + self.widget.roiWidget.roiTable.setActiveRoi(None) + self.assertTrue(len(self.widget.roiWidget.roiTable.selectedItems()) is 0) + self.widget.roiWidget.setRois((roi,)) + self.plot.setActiveCurve(legend='linearCurve') + self.widget.calculateROIs() + + def testEmitCurrentROI(self): + """Test behavior of the CurvesROIWidget.sigROISignal""" + roi = CurvesROIWidget.ROI(name='linear', fromdata=0, todata=5) + self.widget.roiWidget.setRois((roi,)) + signalListener = SignalListener() + self.widget.roiWidget.sigROISignal.connect(signalListener.partial()) + self.widget.show() + self.qapp.processEvents() + self.assertTrue(signalListener.callCount() is 0) + self.assertTrue(self.widget.roiWidget.roiTable.activeRoi is roi) + roi.setFrom(0.0) + self.qapp.processEvents() + self.assertTrue(signalListener.callCount() is 0) + roi.setFrom(0.3) + self.qapp.processEvents() + self.assertTrue(signalListener.callCount() is 1) + def suite(): test_suite = unittest.TestSuite() diff --git a/silx/gui/plot/test/testMaskToolsWidget.py b/silx/gui/plot/test/testMaskToolsWidget.py index 6912ea3..a05c1be 100644 --- a/silx/gui/plot/test/testMaskToolsWidget.py +++ b/silx/gui/plot/test/testMaskToolsWidget.py @@ -42,10 +42,7 @@ from silx.gui.utils.testutils import getQToolButtonFromAction from silx.gui.plot import PlotWindow, MaskToolsWidget from .utils import PlotWidgetTestCase -try: - import fabio -except ImportError: - fabio = None +import fabio _logger = logging.getLogger(__name__) @@ -254,8 +251,6 @@ class TestMaskToolsWidget(PlotWidgetTestCase, ParametricTestCase): self.__loadSave("npy") def testLoadSaveFit2D(self): - if fabio is None: - self.skipTest("Fabio is missing") self.__loadSave("msk") def testSigMaskChangedEmitted(self): diff --git a/silx/gui/plot/test/testPlotWidget.py b/silx/gui/plot/test/testPlotWidget.py index 857b9bc..9d7c093 100644 --- a/silx/gui/plot/test/testPlotWidget.py +++ b/silx/gui/plot/test/testPlotWidget.py @@ -1,7 +1,7 @@ # coding: utf-8 # /*########################################################################## # -# Copyright (c) 2016-2018 European Synchrotron Radiation Facility +# Copyright (c) 2016-2019 European Synchrotron Radiation Facility # # Permission is hereby granted, free of charge, to any person obtaining a copy # of this software and associated documentation files (the "Software"), to deal @@ -26,7 +26,7 @@ __authors__ = ["T. Vincent"] __license__ = "MIT" -__date__ = "21/09/2018" +__date__ = "03/01/2019" import unittest @@ -36,8 +36,6 @@ import numpy from silx.utils.testutils import ParametricTestCase, parameterize from silx.gui.utils.testutils import SignalListener from silx.gui.utils.testutils import TestCaseQt -from silx.utils import testutils -from silx.utils import deprecation from silx.test.utils import test_options @@ -184,6 +182,39 @@ class TestPlotWidget(PlotWidgetTestCase, ParametricTestCase): self.assertTrue(numpy.all(numpy.equal(items[4].getPosition()[0], marker_x))) self.assertEqual(items[5].getType(), 'rectangle') + def testBackGroundColors(self): + self.plot.setVisible(True) + self.qWaitForWindowExposed(self.plot) + self.qapp.processEvents() + + # Custom the full background + color = self.plot.getBackgroundColor() + self.assertTrue(color.isValid()) + self.assertEqual(color, qt.QColor(255, 255, 255)) + self.plot.setBackgroundColor("red") + color = self.plot.getBackgroundColor() + self.assertTrue(color.isValid()) + self.qapp.processEvents() + + # Custom the data background + color = self.plot.getDataBackgroundColor() + self.assertFalse(color.isValid()) + self.plot.setDataBackgroundColor("red") + color = self.plot.getDataBackgroundColor() + self.assertTrue(color.isValid()) + self.qapp.processEvents() + + # Back to default + self.plot.setBackgroundColor('white') + self.plot.setDataBackgroundColor(None) + color = self.plot.getBackgroundColor() + self.assertTrue(color.isValid()) + self.assertEqual(color, qt.QColor(255, 255, 255)) + color = self.plot.getDataBackgroundColor() + self.assertFalse(color.isValid()) + self.qapp.processEvents() + + class TestPlotImage(PlotWidgetTestCase, ParametricTestCase): """Basic tests for addImage""" @@ -881,17 +912,12 @@ class TestPlotAxes(TestCaseQt, ParametricTestCase): if getter is not None: self.assertEqual(getter(), expected) - @testutils.test_logging(deprecation.depreclog.name) def testOldPlotAxis_Logarithmic(self): """Test silx API prior to silx 0.6""" x = self.plot.getXAxis() y = self.plot.getYAxis() yright = self.plot.getYAxis(axis="right") - listener = SignalListener() - self.plot.sigSetXAxisLogarithmic.connect(listener.partial("x")) - self.plot.sigSetYAxisLogarithmic.connect(listener.partial("y")) - self.assertEqual(x.getScale(), x.LINEAR) self.assertEqual(y.getScale(), x.LINEAR) self.assertEqual(yright.getScale(), x.LINEAR) @@ -902,7 +928,6 @@ class TestPlotAxes(TestCaseQt, ParametricTestCase): self.assertEqual(yright.getScale(), x.LINEAR) self.assertEqual(self.plot.isXAxisLogarithmic(), True) self.assertEqual(self.plot.isYAxisLogarithmic(), False) - self.assertEqual(listener.arguments(callIndex=-1), ("x", True)) self.plot.setYAxisLogarithmic(True) self.assertEqual(x.getScale(), x.LOGARITHMIC) @@ -910,7 +935,6 @@ class TestPlotAxes(TestCaseQt, ParametricTestCase): self.assertEqual(yright.getScale(), x.LOGARITHMIC) self.assertEqual(self.plot.isXAxisLogarithmic(), True) self.assertEqual(self.plot.isYAxisLogarithmic(), True) - self.assertEqual(listener.arguments(callIndex=-1), ("y", True)) yright.setScale(yright.LINEAR) self.assertEqual(x.getScale(), x.LOGARITHMIC) @@ -918,19 +942,13 @@ class TestPlotAxes(TestCaseQt, ParametricTestCase): self.assertEqual(yright.getScale(), x.LINEAR) self.assertEqual(self.plot.isXAxisLogarithmic(), True) self.assertEqual(self.plot.isYAxisLogarithmic(), False) - self.assertEqual(listener.arguments(callIndex=-1), ("y", False)) - @testutils.test_logging(deprecation.depreclog.name) def testOldPlotAxis_AutoScale(self): """Test silx API prior to silx 0.6""" x = self.plot.getXAxis() y = self.plot.getYAxis() yright = self.plot.getYAxis(axis="right") - listener = SignalListener() - self.plot.sigSetXAxisAutoScale.connect(listener.partial("x")) - self.plot.sigSetYAxisAutoScale.connect(listener.partial("y")) - self.assertEqual(x.isAutoScale(), True) self.assertEqual(y.isAutoScale(), True) self.assertEqual(yright.isAutoScale(), True) @@ -941,7 +959,6 @@ class TestPlotAxes(TestCaseQt, ParametricTestCase): self.assertEqual(yright.isAutoScale(), True) self.assertEqual(self.plot.isXAxisAutoScale(), False) self.assertEqual(self.plot.isYAxisAutoScale(), True) - self.assertEqual(listener.arguments(callIndex=-1), ("x", False)) self.plot.setYAxisAutoScale(False) self.assertEqual(x.isAutoScale(), False) @@ -949,7 +966,6 @@ class TestPlotAxes(TestCaseQt, ParametricTestCase): self.assertEqual(yright.isAutoScale(), False) self.assertEqual(self.plot.isXAxisAutoScale(), False) self.assertEqual(self.plot.isYAxisAutoScale(), False) - self.assertEqual(listener.arguments(callIndex=-1), ("y", False)) yright.setAutoScale(True) self.assertEqual(x.isAutoScale(), False) @@ -957,18 +973,13 @@ class TestPlotAxes(TestCaseQt, ParametricTestCase): self.assertEqual(yright.isAutoScale(), True) self.assertEqual(self.plot.isXAxisAutoScale(), False) self.assertEqual(self.plot.isYAxisAutoScale(), True) - self.assertEqual(listener.arguments(callIndex=-1), ("y", True)) - @testutils.test_logging(deprecation.depreclog.name) def testOldPlotAxis_Inverted(self): """Test silx API prior to silx 0.6""" x = self.plot.getXAxis() y = self.plot.getYAxis() yright = self.plot.getYAxis(axis="right") - listener = SignalListener() - self.plot.sigSetYAxisInverted.connect(listener.partial("y")) - self.assertEqual(x.isInverted(), False) self.assertEqual(y.isInverted(), False) self.assertEqual(yright.isInverted(), False) @@ -978,14 +989,12 @@ class TestPlotAxes(TestCaseQt, ParametricTestCase): self.assertEqual(y.isInverted(), True) self.assertEqual(yright.isInverted(), True) self.assertEqual(self.plot.isYAxisInverted(), True) - self.assertEqual(listener.arguments(callIndex=-1), ("y", True)) yright.setInverted(False) self.assertEqual(x.isInverted(), False) self.assertEqual(y.isInverted(), False) self.assertEqual(yright.isInverted(), False) self.assertEqual(self.plot.isYAxisInverted(), False) - self.assertEqual(listener.arguments(callIndex=-1), ("y", False)) def testLogXWithData(self): self.plot.setGraphTitle('Curve X: Log Y: Linear') diff --git a/silx/gui/plot/test/testSaveAction.py b/silx/gui/plot/test/testSaveAction.py index 85669bf..0eb129d 100644 --- a/silx/gui/plot/test/testSaveAction.py +++ b/silx/gui/plot/test/testSaveAction.py @@ -1,7 +1,7 @@ # coding: utf-8 # /*########################################################################## # -# Copyright (c) 2017-2018 European Synchrotron Radiation Facility +# Copyright (c) 2017-2019 European Synchrotron Radiation Facility # # Permission is hereby granted, free of charge, to any person obtaining a copy # of this software and associated documentation files (the "Software"), to deal @@ -106,12 +106,30 @@ class TestSaveActionExtension(PlotWidgetTestCase): self.assertEqual(saveAction.getFileFilters('all')[nameFilter], self._dummySaveFunction) + # Add a new file filter at a particular position + nameFilter = 'Dummy file2 (*.dummy)' + saveAction.setFileFilter('all', nameFilter, + self._dummySaveFunction, index=3) + self.assertTrue(nameFilter in saveAction.getFileFilters('all')) + filters = saveAction.getFileFilters('all') + self.assertEqual(filters[nameFilter], self._dummySaveFunction) + self.assertEqual(list(filters.keys()).index(nameFilter),3) + # Update an existing file filter nameFilter = SaveAction.IMAGE_FILTER_EDF saveAction.setFileFilter('image', nameFilter, self._dummySaveFunction) self.assertEqual(saveAction.getFileFilters('image')[nameFilter], self._dummySaveFunction) + # Change the position of an existing file filter + nameFilter = 'Dummy file2 (*.dummy)' + oldIndex = list(saveAction.getFileFilters('all')).index(nameFilter) + newIndex = oldIndex - 1 + saveAction.setFileFilter('all', nameFilter, + self._dummySaveFunction, index=newIndex) + filters = saveAction.getFileFilters('all') + self.assertEqual(filters[nameFilter], self._dummySaveFunction) + self.assertEqual(list(filters.keys()).index(nameFilter), newIndex) def suite(): test_suite = unittest.TestSuite() diff --git a/silx/gui/plot/test/testScatterMaskToolsWidget.py b/silx/gui/plot/test/testScatterMaskToolsWidget.py index a446911..171ec42 100644 --- a/silx/gui/plot/test/testScatterMaskToolsWidget.py +++ b/silx/gui/plot/test/testScatterMaskToolsWidget.py @@ -42,10 +42,7 @@ from silx.gui.utils.testutils import getQToolButtonFromAction from silx.gui.plot import PlotWindow, ScatterMaskToolsWidget from .utils import PlotWidgetTestCase -try: - import fabio -except ImportError: - fabio = None +import fabio _logger = logging.getLogger(__name__) diff --git a/silx/gui/plot/test/testStats.py b/silx/gui/plot/test/testStats.py index faedcff..7fbc247 100644 --- a/silx/gui/plot/test/testStats.py +++ b/silx/gui/plot/test/testStats.py @@ -1,7 +1,7 @@ # coding: utf-8 # /*########################################################################## # -# Copyright (c) 2016-2018 European Synchrotron Radiation Facility +# Copyright (c) 2016-2019 European Synchrotron Radiation Facility # # Permission is hereby granted, free of charge, to any person obtaining a copy # of this software and associated documentation files (the "Software"), to deal @@ -112,34 +112,34 @@ class TestStats(TestCaseQt): """Test result for simple stats on a curve""" _stats = self.getBasicStats() xData = yData = numpy.array(range(20)) - self.assertTrue(_stats['min'].calculate(self.curveContext) == 0) - self.assertTrue(_stats['max'].calculate(self.curveContext) == 19) - self.assertTrue(_stats['minCoords'].calculate(self.curveContext) == [0]) - self.assertTrue(_stats['maxCoords'].calculate(self.curveContext) == [19]) - self.assertTrue(_stats['std'].calculate(self.curveContext) == numpy.std(yData)) - self.assertTrue(_stats['mean'].calculate(self.curveContext) == numpy.mean(yData)) + self.assertEqual(_stats['min'].calculate(self.curveContext), 0) + self.assertEqual(_stats['max'].calculate(self.curveContext), 19) + self.assertEqual(_stats['minCoords'].calculate(self.curveContext), (0,)) + self.assertEqual(_stats['maxCoords'].calculate(self.curveContext), (19,)) + self.assertEqual(_stats['std'].calculate(self.curveContext), numpy.std(yData)) + self.assertEqual(_stats['mean'].calculate(self.curveContext), numpy.mean(yData)) com = numpy.sum(xData * yData) / numpy.sum(yData) - self.assertTrue(_stats['com'].calculate(self.curveContext) == com) + self.assertEqual(_stats['com'].calculate(self.curveContext), com) def testBasicStatsImage(self): """Test result for simple stats on an image""" _stats = self.getBasicStats() - self.assertTrue(_stats['min'].calculate(self.imageContext) == 0) - self.assertTrue(_stats['max'].calculate(self.imageContext) == 128 * 32 - 1) - self.assertTrue(_stats['minCoords'].calculate(self.imageContext) == (0, 0)) - self.assertTrue(_stats['maxCoords'].calculate(self.imageContext) == (127, 31)) - self.assertTrue(_stats['std'].calculate(self.imageContext) == numpy.std(self.imageData)) - self.assertTrue(_stats['mean'].calculate(self.imageContext) == numpy.mean(self.imageData)) - - yData = numpy.sum(self.imageData, axis=1) - xData = numpy.sum(self.imageData, axis=0) + self.assertEqual(_stats['min'].calculate(self.imageContext), 0) + self.assertEqual(_stats['max'].calculate(self.imageContext), 128 * 32 - 1) + self.assertEqual(_stats['minCoords'].calculate(self.imageContext), (0, 0)) + self.assertEqual(_stats['maxCoords'].calculate(self.imageContext), (127, 31)) + self.assertEqual(_stats['std'].calculate(self.imageContext), numpy.std(self.imageData)) + self.assertEqual(_stats['mean'].calculate(self.imageContext), numpy.mean(self.imageData)) + + yData = numpy.sum(self.imageData.astype(numpy.float64), axis=1) + xData = numpy.sum(self.imageData.astype(numpy.float64), axis=0) dataXRange = range(self.imageData.shape[1]) dataYRange = range(self.imageData.shape[0]) ycom = numpy.sum(yData*dataYRange) / numpy.sum(yData) xcom = numpy.sum(xData*dataXRange) / numpy.sum(xData) - self.assertTrue(_stats['com'].calculate(self.imageContext) == (xcom, ycom)) + self.assertEqual(_stats['com'].calculate(self.imageContext), (xcom, ycom)) def testStatsImageAdv(self): """Test that scale and origin are taking into account for images""" @@ -153,52 +153,46 @@ class TestStats(TestCaseQt): onlimits=False ) _stats = self.getBasicStats() - self.assertTrue(_stats['min'].calculate(image2Context) == 0) - self.assertTrue( - _stats['max'].calculate(image2Context) == 128 * 32 - 1) - self.assertTrue( - _stats['minCoords'].calculate(image2Context) == (100, 10)) - self.assertTrue( - _stats['maxCoords'].calculate(image2Context) == (127*2. + 100, - 31 * 0.5 + 10) - ) - self.assertTrue( - _stats['std'].calculate(image2Context) == numpy.std( - self.imageData)) - self.assertTrue( - _stats['mean'].calculate(image2Context) == numpy.mean( - self.imageData)) + self.assertEqual(_stats['min'].calculate(image2Context), 0) + self.assertEqual( + _stats['max'].calculate(image2Context), 128 * 32 - 1) + self.assertEqual( + _stats['minCoords'].calculate(image2Context), (100, 10)) + self.assertEqual( + _stats['maxCoords'].calculate(image2Context), (127*2. + 100, + 31 * 0.5 + 10)) + self.assertEqual(_stats['std'].calculate(image2Context), + numpy.std(self.imageData)) + self.assertEqual(_stats['mean'].calculate(image2Context), + numpy.mean(self.imageData)) yData = numpy.sum(self.imageData, axis=1) xData = numpy.sum(self.imageData, axis=0) - dataXRange = range(self.imageData.shape[1]) - dataYRange = range(self.imageData.shape[0]) + dataXRange = numpy.arange(self.imageData.shape[1], dtype=numpy.float64) + dataYRange = numpy.arange(self.imageData.shape[0], dtype=numpy.float64) ycom = numpy.sum(yData * dataYRange) / numpy.sum(yData) ycom = (ycom * 0.5) + 10 xcom = numpy.sum(xData * dataXRange) / numpy.sum(xData) xcom = (xcom * 2.) + 100 - self.assertTrue( - _stats['com'].calculate(image2Context) == (xcom, ycom)) + self.assertTrue(numpy.allclose( + _stats['com'].calculate(image2Context), (xcom, ycom))) def testBasicStatsScatter(self): """Test result for simple stats on a scatter""" _stats = self.getBasicStats() - self.assertTrue(_stats['min'].calculate(self.scatterContext) == 5) - self.assertTrue(_stats['max'].calculate(self.scatterContext) == 90) - self.assertTrue(_stats['minCoords'].calculate(self.scatterContext) == (0, 2)) - self.assertTrue(_stats['maxCoords'].calculate(self.scatterContext) == (50, 69)) - self.assertTrue(_stats['std'].calculate(self.scatterContext) == numpy.std(self.valuesScatterData)) - self.assertTrue(_stats['mean'].calculate(self.scatterContext) == numpy.mean(self.valuesScatterData)) - - comx = numpy.sum(self.xScatterData * self.valuesScatterData).astype(numpy.float32) / numpy.sum( - self.valuesScatterData).astype(numpy.float32) - comy = numpy.sum(self.yScatterData * self.valuesScatterData).astype(numpy.float32) / numpy.sum( - self.valuesScatterData).astype(numpy.float32) - self.assertTrue(numpy.all( - numpy.equal(_stats['com'].calculate(self.scatterContext), - (comx, comy))) - ) + self.assertEqual(_stats['min'].calculate(self.scatterContext), 5) + self.assertEqual(_stats['max'].calculate(self.scatterContext), 90) + self.assertEqual(_stats['minCoords'].calculate(self.scatterContext), (0, 2)) + self.assertEqual(_stats['maxCoords'].calculate(self.scatterContext), (50, 69)) + self.assertEqual(_stats['std'].calculate(self.scatterContext), numpy.std(self.valuesScatterData)) + self.assertEqual(_stats['mean'].calculate(self.scatterContext), numpy.mean(self.valuesScatterData)) + + data = self.valuesScatterData.astype(numpy.float64) + comx = numpy.sum(self.xScatterData * data) / numpy.sum(data) + comy = numpy.sum(self.yScatterData * data) / numpy.sum(data) + self.assertEqual(_stats['com'].calculate(self.scatterContext), + (comx, comy)) def testKindNotManagedByStat(self): """Make sure an exception is raised if we try to execute calculate @@ -227,21 +221,21 @@ class TestStats(TestCaseQt): item=self.plot1d.getCurve('curve0'), plot=self.plot1d, onlimits=True) - self.assertTrue(stat.calculate(curveContextOnLimits) == 2) + self.assertEqual(stat.calculate(curveContextOnLimits), 2) self.plot2d.getXAxis().setLimitsConstraints(minPos=32) imageContextOnLimits = stats._ImageContext( item=self.plot2d.getImage('test image'), plot=self.plot2d, onlimits=True) - self.assertTrue(stat.calculate(imageContextOnLimits) == 32) + self.assertEqual(stat.calculate(imageContextOnLimits), 32) self.scatterPlot.getXAxis().setLimitsConstraints(minPos=40) scatterContextOnLimits = stats._ScatterContext( item=self.scatterPlot.getScatter('scatter plot'), plot=self.scatterPlot, onlimits=True) - self.assertTrue(stat.calculate(scatterContextOnLimits) == 20) + self.assertEqual(stat.calculate(scatterContextOnLimits), 20) class TestStatsFormatter(TestCaseQt): @@ -267,15 +261,15 @@ class TestStatsFormatter(TestCaseQt): """Make sure a formatter with no formatter definition will return a simple cast to str""" emptyFormatter = statshandler.StatFormatter() - self.assertTrue( - emptyFormatter.format(self.stat.calculate(self.curveContext)) == '0.000') + self.assertEqual( + emptyFormatter.format(self.stat.calculate(self.curveContext)), '0.000') def testSettedFormatter(self): """Make sure a formatter with no formatter definition will return a simple cast to str""" formatter= statshandler.StatFormatter(formatter='{0:.3f}') - self.assertTrue( - formatter.format(self.stat.calculate(self.curveContext)) == '0.000') + self.assertEqual( + formatter.format(self.stat.calculate(self.curveContext)), '0.000') class TestStatsHandler(unittest.TestCase): @@ -309,9 +303,9 @@ class TestStatsHandler(unittest.TestCase): res = handler0.calculate(item=self.curveItem, plot=self.plot1d, onlimits=False) self.assertTrue('min' in res) - self.assertTrue(res['min'] == '0') + self.assertEqual(res['min'], '0') self.assertTrue('max' in res) - self.assertTrue(res['max'] == '19') + self.assertEqual(res['max'], '19') handler1 = statshandler.StatsHandler( ( @@ -323,9 +317,9 @@ class TestStatsHandler(unittest.TestCase): res = handler1.calculate(item=self.curveItem, plot=self.plot1d, onlimits=False) self.assertTrue('min' in res) - self.assertTrue(res['min'] == '0') + self.assertEqual(res['min'], '0') self.assertTrue('max' in res) - self.assertTrue(res['max'] == '19.000') + self.assertEqual(res['max'], '19.000') handler2 = statshandler.StatsHandler( ( @@ -336,9 +330,9 @@ class TestStatsHandler(unittest.TestCase): res = handler2.calculate(item=self.curveItem, plot=self.plot1d, onlimits=False) self.assertTrue('min' in res) - self.assertTrue(res['min'] == '0') + self.assertEqual(res['min'], '0') self.assertTrue('max' in res) - self.assertTrue(res['max'] == '19.000') + self.assertEqual(res['max'], '19.000') handler3 = statshandler.StatsHandler(( (('amin', numpy.argmin), statshandler.StatFormatter()), @@ -348,9 +342,9 @@ class TestStatsHandler(unittest.TestCase): res = handler3.calculate(item=self.curveItem, plot=self.plot1d, onlimits=False) self.assertTrue('amin' in res) - self.assertTrue(res['amin'] == '0.000') + self.assertEqual(res['amin'], '0.000') self.assertTrue('amax' in res) - self.assertTrue(res['amax'] == '19') + self.assertEqual(res['amax'], '19') with self.assertRaises(ValueError): statshandler.StatsHandler(('name')) @@ -395,47 +389,49 @@ class TestStatsWidgetWithCurves(TestCaseQt): def testInit(self): """Make sure all the curves are registred on initialization""" - self.assertTrue(self.widget.rowCount() is 3) + self.assertEqual(self.widget.rowCount(), 3) def testRemoveCurve(self): """Make sure the Curves stats take into account the curve removal from plot""" self.plot.removeCurve('curve2') - self.assertTrue(self.widget.rowCount() is 2) + self.assertEqual(self.widget.rowCount(), 2) for iRow in range(2): self.assertTrue(self.widget.item(iRow, 0).text() in ('curve0', 'curve1')) self.plot.removeCurve('curve0') - self.assertTrue(self.widget.rowCount() is 1) + self.assertEqual(self.widget.rowCount(), 1) self.plot.removeCurve('curve1') - self.assertTrue(self.widget.rowCount() is 0) + self.assertEqual(self.widget.rowCount(), 0) def testAddCurve(self): """Make sure the Curves stats take into account the add curve action""" self.plot.addCurve(legend='curve3', x=range(10), y=range(10)) - self.assertTrue(self.widget.rowCount() is 4) + self.assertEqual(self.widget.rowCount(), 4) - def testUpdateCurveFrmAddCurve(self): + def testUpdateCurveFromAddCurve(self): """Make sure the stats of the cuve will be removed after updating a curve""" self.plot.addCurve(legend='curve0', x=range(10), y=range(10)) - self.assertTrue(self.widget.rowCount() is 3) - itemMax = self.widget._getItem(name='max', legend='curve0', - kind='curve', indexTable=None) - self.assertTrue(itemMax.text() == '9') + self.qapp.processEvents() + self.assertEqual(self.widget.rowCount(), 3) + curve = self.plot._getItem(kind='curve', legend='curve0') + tableItems = self.widget._itemToTableItems(curve) + self.assertEqual(tableItems['max'].text(), '9') - def testUpdateCurveFrmCurveObj(self): + def testUpdateCurveFromCurveObj(self): self.plot.getCurve('curve0').setData(x=range(4), y=range(4)) - self.assertTrue(self.widget.rowCount() is 3) - itemMax = self.widget._getItem(name='max', legend='curve0', - kind='curve', indexTable=None) - self.assertTrue(itemMax.text() == '3') + self.qapp.processEvents() + self.assertEqual(self.widget.rowCount(), 3) + curve = self.plot._getItem(kind='curve', legend='curve0') + tableItems = self.widget._itemToTableItems(curve) + self.assertEqual(tableItems['max'].text(), '3') def testSetAnotherPlot(self): plot2 = Plot1D() plot2.addCurve(x=range(26), y=range(26), legend='new curve') self.widget.setPlot(plot2) - self.assertTrue(self.widget.rowCount() is 1) + self.assertEqual(self.widget.rowCount(), 1) self.qapp.processEvents() plot2.setAttribute(qt.Qt.WA_DeleteOnClose) plot2.close() @@ -444,12 +440,15 @@ class TestStatsWidgetWithCurves(TestCaseQt): class TestStatsWidgetWithImages(TestCaseQt): """Basic test for StatsWidget with images""" + + IMAGE_LEGEND = 'test image' + def setUp(self): TestCaseQt.setUp(self) self.plot = Plot2D() self.plot.addImage(data=numpy.arange(128*128).reshape(128, 128), - legend='test image', replace=False) + legend=self.IMAGE_LEGEND, replace=False) self.widget = StatsWidget.StatsTable(plot=self.plot) @@ -476,31 +475,30 @@ class TestStatsWidgetWithImages(TestCaseQt): TestCaseQt.tearDown(self) def test(self): - columnsIndex = self.widget._columns_index - itemLegend = self.widget._lgdAndKindToItems[('test image', 'image')]['legend'] - itemMin = self.widget.item(itemLegend.row(), columnsIndex['min']) - itemMax = self.widget.item(itemLegend.row(), columnsIndex['max']) - itemDelta = self.widget.item(itemLegend.row(), columnsIndex['delta']) - itemCoordsMin = self.widget.item(itemLegend.row(), - columnsIndex['coords min']) - itemCoordsMax = self.widget.item(itemLegend.row(), - columnsIndex['coords max']) - max = (128 * 128) - 1 - self.assertTrue(itemMin.text() == '0.000') - self.assertTrue(itemMax.text() == '{0:.3f}'.format(max)) - self.assertTrue(itemDelta.text() == '{0:.3f}'.format(max)) - self.assertTrue(itemCoordsMin.text() == '0.0, 0.0') - self.assertTrue(itemCoordsMax.text() == '127.0, 127.0') + image = self.plot._getItem( + kind='image', legend=self.IMAGE_LEGEND) + tableItems = self.widget._itemToTableItems(image) + + maxText = '{0:.3f}'.format((128 * 128) - 1) + self.assertEqual(tableItems['legend'].text(), self.IMAGE_LEGEND) + self.assertEqual(tableItems['min'].text(), '0.000') + self.assertEqual(tableItems['max'].text(), maxText) + self.assertEqual(tableItems['delta'].text(), maxText) + self.assertEqual(tableItems['coords min'].text(), '0.0, 0.0') + self.assertEqual(tableItems['coords max'].text(), '127.0, 127.0') class TestStatsWidgetWithScatters(TestCaseQt): + + SCATTER_LEGEND = 'scatter plot' + def setUp(self): TestCaseQt.setUp(self) self.scatterPlot = Plot2D() self.scatterPlot.addScatter([0, 1, 2, 20, 50, 60], [2, 3, 4, 26, 69, 6], [5, 6, 7, 10, 90, 20], - legend='scatter plot') + legend=self.SCATTER_LEGEND) self.widget = StatsWidget.StatsTable(plot=self.scatterPlot) mystats = statshandler.StatsHandler(( @@ -526,33 +524,89 @@ class TestStatsWidgetWithScatters(TestCaseQt): TestCaseQt.tearDown(self) def testStats(self): - columnsIndex = self.widget._columns_index - itemLegend = self.widget._lgdAndKindToItems[('scatter plot', 'scatter')]['legend'] - itemMin = self.widget.item(itemLegend.row(), columnsIndex['min']) - itemMax = self.widget.item(itemLegend.row(), columnsIndex['max']) - itemDelta = self.widget.item(itemLegend.row(), columnsIndex['delta']) - itemCoordsMin = self.widget.item(itemLegend.row(), - columnsIndex['coords min']) - itemCoordsMax = self.widget.item(itemLegend.row(), - columnsIndex['coords max']) - self.assertTrue(itemMin.text() == '5') - self.assertTrue(itemMax.text() == '90') - self.assertTrue(itemDelta.text() == '85') - self.assertTrue(itemCoordsMin.text() == '0, 2') - self.assertTrue(itemCoordsMax.text() == '50, 69') + scatter = self.scatterPlot._getItem( + kind='scatter', legend=self.SCATTER_LEGEND) + tableItems = self.widget._itemToTableItems(scatter) + self.assertEqual(tableItems['legend'].text(), self.SCATTER_LEGEND) + self.assertEqual(tableItems['min'].text(), '5') + self.assertEqual(tableItems['coords min'].text(), '0, 2') + self.assertEqual(tableItems['max'].text(), '90') + self.assertEqual(tableItems['coords max'].text(), '50, 69') + self.assertEqual(tableItems['delta'].text(), '85') class TestEmptyStatsWidget(TestCaseQt): def test(self): widget = StatsWidget.StatsWidget() widget.show() + self.qWaitForWindowExposed(widget) + + +# skip unit test for pyqt4 because there is some unrealised widget without +# apparent reason +@unittest.skipIf(qt.qVersion().split('.')[0] == '4', reason='PyQt4 not tested') +class TestLineWidget(TestCaseQt): + """Some test for the StatsLineWidget.""" + def setUp(self): + TestCaseQt.setUp(self) + + mystats = statshandler.StatsHandler(( + (stats.StatMin(), statshandler.StatFormatter()), + )) + + self.plot = Plot1D() + self.plot.show() + x = range(20) + y = range(20) + self.plot.addCurve(x, y, legend='curve0') + y = range(12, 32) + self.plot.addCurve(x, y, legend='curve1') + y = range(-2, 18) + self.plot.addCurve(x, y, legend='curve2') + self.widget = StatsWidget.BasicGridStatsWidget(plot=self.plot, + kind='curve', + stats=mystats) + + def tearDown(self): + self.qapp.processEvents() + self.plot.setAttribute(qt.Qt.WA_DeleteOnClose) + self.plot.close() + self.widget.setPlot(None) + self.widget._statQlineEdit.clear() + self.widget.setAttribute(qt.Qt.WA_DeleteOnClose) + self.widget.close() + self.widget = None + self.plot = None + TestCaseQt.tearDown(self) + + def test(self): + self.widget.setStatsOnVisibleData(False) + self.qapp.processEvents() + self.plot.setActiveCurve(legend='curve0') + self.assertTrue(self.widget._statQlineEdit['min'].text() == '0.000') + self.plot.setActiveCurve(legend='curve1') + self.assertTrue(self.widget._statQlineEdit['min'].text() == '12.000') + self.plot.getXAxis().setLimitsConstraints(minPos=2, maxPos=5) + self.widget.setStatsOnVisibleData(True) + self.qapp.processEvents() + self.assertTrue(self.widget._statQlineEdit['min'].text() == '14.000') + self.plot.setActiveCurve(None) + self.assertTrue(self.plot.getActiveCurve() is None) + self.widget.setStatsOnVisibleData(False) + self.qapp.processEvents() + self.assertFalse(self.widget._statQlineEdit['min'].text() == '14.000') + self.widget.setKind('image') + self.plot.addImage(numpy.arange(100*100).reshape(100, 100) + 0.312) + self.qapp.processEvents() + self.assertTrue(self.widget._statQlineEdit['min'].text() == '0.312') def suite(): test_suite = unittest.TestSuite() for TestClass in (TestStats, TestStatsHandler, TestStatsWidgetWithScatters, TestStatsWidgetWithImages, TestStatsWidgetWithCurves, - TestStatsFormatter, TestEmptyStatsWidget): + TestStatsFormatter, TestEmptyStatsWidget, + TestLineWidget): test_suite.addTest( unittest.defaultTestLoader.loadTestsFromTestCase(TestClass)) return test_suite diff --git a/silx/gui/plot/test/testUtilsAxis.py b/silx/gui/plot/test/testUtilsAxis.py index 016fafe..64373b8 100644 --- a/silx/gui/plot/test/testUtilsAxis.py +++ b/silx/gui/plot/test/testUtilsAxis.py @@ -26,7 +26,7 @@ __authors__ = ["V. Valls"] __license__ = "MIT" -__date__ = "14/02/2018" +__date__ = "20/11/2018" import unittest @@ -155,6 +155,53 @@ class TestAxisSync(TestCaseQt): self.assertEqual(self.plot2.getYAxis().isInverted(), True) self.assertEqual(self.plot3.getYAxis().isInverted(), True) + def testSyncCenter(self): + """Test direction change""" + # Not the same scale + self.plot1.getXAxis().setLimits(0, 200) + self.plot2.getXAxis().setLimits(0, 20) + self.plot3.getXAxis().setLimits(0, 2) + _sync = SyncAxes([self.plot1.getXAxis(), self.plot2.getXAxis(), self.plot3.getXAxis()], + syncLimits=False, syncCenter=True) + + self.assertEqual(self.plot1.getXAxis().getLimits(), (0, 200)) + self.assertEqual(self.plot2.getXAxis().getLimits(), (100 - 10, 100 + 10)) + self.assertEqual(self.plot3.getXAxis().getLimits(), (100 - 1, 100 + 1)) + + def testSyncCenterAndZoom(self): + """Test direction change""" + # Not the same scale + self.plot1.getXAxis().setLimits(0, 200) + self.plot2.getXAxis().setLimits(0, 20) + self.plot3.getXAxis().setLimits(0, 2) + _sync = SyncAxes([self.plot1.getXAxis(), self.plot2.getXAxis(), self.plot3.getXAxis()], + syncLimits=False, syncCenter=True, syncZoom=True) + + # Supposing all the plots use the same size + self.assertEqual(self.plot1.getXAxis().getLimits(), (0, 200)) + self.assertEqual(self.plot2.getXAxis().getLimits(), (0, 200)) + self.assertEqual(self.plot3.getXAxis().getLimits(), (0, 200)) + + def testAddAxis(self): + """Test synchronization after construction""" + sync = SyncAxes([self.plot1.getXAxis(), self.plot2.getXAxis()]) + sync.addAxis(self.plot3.getXAxis()) + + self.plot1.getXAxis().setLimits(10, 500) + self.assertEqual(self.plot1.getXAxis().getLimits(), (10, 500)) + self.assertEqual(self.plot2.getXAxis().getLimits(), (10, 500)) + self.assertEqual(self.plot3.getXAxis().getLimits(), (10, 500)) + + def testRemoveAxis(self): + """Test synchronization after construction""" + sync = SyncAxes([self.plot1.getXAxis(), self.plot2.getXAxis(), self.plot3.getXAxis()]) + sync.removeAxis(self.plot3.getXAxis()) + + self.plot1.getXAxis().setLimits(10, 500) + self.assertEqual(self.plot1.getXAxis().getLimits(), (10, 500)) + self.assertEqual(self.plot2.getXAxis().getLimits(), (10, 500)) + self.assertNotEqual(self.plot3.getXAxis().getLimits(), (10, 500)) + def suite(): test_suite = unittest.TestSuite() diff --git a/silx/gui/plot/tools/roi.py b/silx/gui/plot/tools/roi.py index d58c041..98295ba 100644 --- a/silx/gui/plot/tools/roi.py +++ b/silx/gui/plot/tools/roi.py @@ -1,7 +1,7 @@ # coding: utf-8 # /*########################################################################## # -# Copyright (c) 2018 European Synchrotron Radiation Facility +# Copyright (c) 2018-2019 European Synchrotron Radiation Facility # # Permission is hereby granted, free of charge, to any person obtaining a copy # of this software and associated documentation files (the "Software"), to deal @@ -31,6 +31,7 @@ __date__ = "28/06/2018" import collections +import enum import functools import logging import time @@ -38,7 +39,6 @@ import weakref import numpy -from ....third_party import enum from ....utils.weakref import WeakMethodProxy from ... import qt, icons from .. import PlotWidget @@ -806,11 +806,17 @@ class RegionOfInterestTableWidget(qt.QTableWidget): self.itemChanged.connect(self.__itemChanged) - @staticmethod - def __itemChanged(item): + def __itemChanged(self, item): """Handle item updates""" column = item.column() - roi = item.data(qt.Qt.UserRole) + index = item.data(qt.Qt.UserRole) + + if index is not None: + manager = self.getRegionOfInterestManager() + roi = manager.getRois()[index] + else: + roi = None + if column == 0: roi.setLabel(item.text()) elif column == 1: @@ -882,13 +888,13 @@ class RegionOfInterestTableWidget(qt.QTableWidget): label = roi.getLabel() item = qt.QTableWidgetItem(label) item.setFlags(baseFlags | qt.Qt.ItemIsEditable) - item.setData(qt.Qt.UserRole, roi) + item.setData(qt.Qt.UserRole, index) self.setItem(index, 0, item) # Editable item = qt.QTableWidgetItem() item.setFlags(baseFlags | qt.Qt.ItemIsUserCheckable) - item.setData(qt.Qt.UserRole, roi) + item.setData(qt.Qt.UserRole, index) item.setCheckState( qt.Qt.Checked if roi.isEditable() else qt.Qt.Unchecked) self.setItem(index, 1, item) diff --git a/silx/gui/plot/tools/test/testScatterProfileToolBar.py b/silx/gui/plot/tools/test/testScatterProfileToolBar.py index b99cac7..0f4b668 100644 --- a/silx/gui/plot/tools/test/testScatterProfileToolBar.py +++ b/silx/gui/plot/tools/test/testScatterProfileToolBar.py @@ -97,7 +97,7 @@ class TestScatterProfileToolBar(TestCaseQt, ParametricTestCase): self.profile._getRoiManager().addRoi(roi) # Wait for async interpolator init - for _ in range(10): + for _ in range(20): self.qWait(200) if not self.profile.hasPendingOperations(): break diff --git a/silx/gui/plot/utils/axis.py b/silx/gui/plot/utils/axis.py index bd19996..693e8eb 100644 --- a/silx/gui/plot/utils/axis.py +++ b/silx/gui/plot/utils/axis.py @@ -27,13 +27,14 @@ __authors__ = ["V. Valls"] __license__ = "MIT" -__date__ = "23/02/2018" +__date__ = "20/11/2018" import functools import logging from contextlib import contextmanager import weakref import silx.utils.weakref as silxWeakref +from silx.gui.plot.items.axis import Axis, XAxis, YAxis try: from ...qt.inspect import isValid as _isQObjectValid @@ -61,7 +62,14 @@ class SyncAxes(object): .. versionadded:: 0.6 """ - def __init__(self, axes, syncLimits=True, syncScale=True, syncDirection=True): + def __init__(self, axes, + syncLimits=True, + syncScale=True, + syncDirection=True, + syncCenter=False, + syncZoom=False, + filterHiddenPlots=False + ): """ Constructor @@ -69,17 +77,34 @@ class SyncAxes(object): :param bool syncLimits: Synchronize axes limits :param bool syncScale: Synchronize axes scale :param bool syncDirection: Synchronize axes direction + :param bool syncCenter: Synchronize the center of the axes in the center + of the plots + :param bool syncZoom: Synchronize the zoom of the plot + :param bool filterHiddenPlots: True to avoid updating hidden plots. + Default: False. """ object.__init__(self) + + def implies(x, y): return bool(y ** x) + + assert(implies(syncZoom, not syncLimits)) + assert(implies(syncCenter, not syncLimits)) + assert(implies(syncLimits, not syncCenter)) + assert(implies(syncLimits, not syncZoom)) + + self.__filterHiddenPlots = filterHiddenPlots self.__locked = False self.__axisRefs = [] self.__syncLimits = syncLimits self.__syncScale = syncScale self.__syncDirection = syncDirection + self.__syncCenter = syncCenter + self.__syncZoom = syncZoom self.__callbacks = None + self.__lastMainAxis = None for axis in axes: - self.__axisRefs.append(weakref.ref(axis)) + self.addAxis(axis) self.start() @@ -90,47 +115,131 @@ class SyncAxes(object): After that, any changes to any axes will be used to synchronize other axes. """ - if self.__callbacks is not None: + if self.isSynchronizing(): raise RuntimeError("Axes already synchronized") self.__callbacks = {} axes = self.__getAxes() - if len(axes) == 0: - raise RuntimeError('No axis to synchronize') # register callback for further sync for axis in axes: - refAxis = weakref.ref(axis) - callbacks = [] - if self.__syncLimits: - # the weakref is needed to be able ignore self references - callback = silxWeakref.WeakMethodProxy(self.__axisLimitsChanged) - callback = functools.partial(callback, refAxis) - sig = axis.sigLimitsChanged - sig.connect(callback) - callbacks.append(("sigLimitsChanged", callback)) - if self.__syncScale: - # the weakref is needed to be able ignore self references - callback = silxWeakref.WeakMethodProxy(self.__axisScaleChanged) - callback = functools.partial(callback, refAxis) - sig = axis.sigScaleChanged - sig.connect(callback) - callbacks.append(("sigScaleChanged", callback)) - if self.__syncDirection: - # the weakref is needed to be able ignore self references - callback = silxWeakref.WeakMethodProxy(self.__axisInvertedChanged) - callback = functools.partial(callback, refAxis) - sig = axis.sigInvertedChanged - sig.connect(callback) - callbacks.append(("sigInvertedChanged", callback)) - - self.__callbacks[refAxis] = callbacks + self.__connectAxes(axis) + self.synchronize() + + def isSynchronizing(self): + """Returns true if events are connected to the axes to synchronize them + all together + + :rtype: bool + """ + return self.__callbacks is not None + + def __connectAxes(self, axis): + refAxis = weakref.ref(axis) + callbacks = [] + if self.__syncLimits: + # the weakref is needed to be able ignore self references + callback = silxWeakref.WeakMethodProxy(self.__axisLimitsChanged) + callback = functools.partial(callback, refAxis) + sig = axis.sigLimitsChanged + sig.connect(callback) + callbacks.append(("sigLimitsChanged", callback)) + elif self.__syncCenter and self.__syncZoom: + # the weakref is needed to be able ignore self references + callback = silxWeakref.WeakMethodProxy(self.__axisCenterAndZoomChanged) + callback = functools.partial(callback, refAxis) + sig = axis.sigLimitsChanged + sig.connect(callback) + callbacks.append(("sigLimitsChanged", callback)) + elif self.__syncZoom: + raise NotImplementedError() + elif self.__syncCenter: + # the weakref is needed to be able ignore self references + callback = silxWeakref.WeakMethodProxy(self.__axisCenterChanged) + callback = functools.partial(callback, refAxis) + sig = axis.sigLimitsChanged + sig.connect(callback) + callbacks.append(("sigLimitsChanged", callback)) + if self.__syncScale: + # the weakref is needed to be able ignore self references + callback = silxWeakref.WeakMethodProxy(self.__axisScaleChanged) + callback = functools.partial(callback, refAxis) + sig = axis.sigScaleChanged + sig.connect(callback) + callbacks.append(("sigScaleChanged", callback)) + if self.__syncDirection: + # the weakref is needed to be able ignore self references + callback = silxWeakref.WeakMethodProxy(self.__axisInvertedChanged) + callback = functools.partial(callback, refAxis) + sig = axis.sigInvertedChanged + sig.connect(callback) + callbacks.append(("sigInvertedChanged", callback)) + + if self.__filterHiddenPlots: + # the weakref is needed to be able ignore self references + callback = silxWeakref.WeakMethodProxy(self.__axisVisibilityChanged) + callback = functools.partial(callback, refAxis) + plot = axis._getPlot() + plot.sigVisibilityChanged.connect(callback) + callbacks.append(("sigVisibilityChanged", callback)) + + self.__callbacks[refAxis] = callbacks + def __disconnectAxes(self, axis): + if axis is not None and _isQObjectValid(axis): + ref = weakref.ref(axis) + callbacks = self.__callbacks.pop(ref) + for sigName, callback in callbacks: + if sigName == "sigVisibilityChanged": + obj = axis._getPlot() + else: + obj = axis + if obj is not None: + sig = getattr(obj, sigName) + sig.disconnect(callback) + + def addAxis(self, axis): + """Add a new axes to synchronize. + + :param ~silx.gui.plot.items.Axis axis: The axis to synchronize + """ + self.__axisRefs.append(weakref.ref(axis)) + if self.isSynchronizing(): + self.__connectAxes(axis) + # This could be done faster as only this axis have to be fixed + self.synchronize() + + def removeAxis(self, axis): + """Remove an axis from the synchronized axes. + + :param ~silx.gui.plot.items.Axis axis: The axis to remove + """ + ref = weakref.ref(axis) + self.__axisRefs.remove(ref) + if self.isSynchronizing(): + self.__disconnectAxes(axis) + + def synchronize(self, mainAxis=None): + """Synchronize programatically all the axes. + + :param ~silx.gui.plot.items.Axis mainAxis: + The axis to take as reference (Default: the first axis). + """ # sync the current state - mainAxis = axes[0] + axes = self.__getAxes() + if len(axes) == 0: + return + + if mainAxis is None: + mainAxis = axes[0] + refMainAxis = weakref.ref(mainAxis) if self.__syncLimits: self.__axisLimitsChanged(refMainAxis, *mainAxis.getLimits()) + elif self.__syncCenter and self.__syncZoom: + self.__axisCenterAndZoomChanged(refMainAxis, *mainAxis.getLimits()) + elif self.__syncCenter: + self.__axisCenterChanged(refMainAxis, *mainAxis.getLimits()) if self.__syncScale: self.__axisScaleChanged(refMainAxis, mainAxis.getScale()) if self.__syncDirection: @@ -138,14 +247,11 @@ class SyncAxes(object): def stop(self): """Stop the synchronization of the axes""" - if self.__callbacks is None: + if not self.isSynchronizing(): raise RuntimeError("Axes not synchronized") - for ref, callbacks in self.__callbacks.items(): + for ref in list(self.__callbacks.keys()): axis = ref() - if axis is not None and _isQObjectValid(axis): - for sigName, callback in callbacks: - sig = getattr(axis, sigName) - sig.disconnect(callback) + self.__disconnectAxes(axis) self.__callbacks = None def __del__(self): @@ -168,32 +274,130 @@ class SyncAxes(object): yield self.__locked = False - def __otherAxes(self, changedAxis): + def __axesToUpdate(self, changedAxis): for axis in self.__getAxes(): if axis is changedAxis: continue + if self.__filterHiddenPlots: + plot = axis._getPlot() + if not plot.isVisible(): + continue yield axis + def __axisVisibilityChanged(self, changedAxis, isVisible): + if not isVisible: + return + if self.__locked: + return + changedAxis = changedAxis() + if self.__lastMainAxis is None: + self.__lastMainAxis = self.__axisRefs[0] + mainAxis = self.__lastMainAxis + mainAxis = mainAxis() + self.synchronize(mainAxis=mainAxis) + # force back the main axis + self.__lastMainAxis = weakref.ref(mainAxis) + + def __getAxesCenter(self, axis, vmin, vmax): + """Returns the value displayed in the center of this axis range. + + :rtype: float + """ + scale = axis.getScale() + if scale == Axis.LINEAR: + center = (vmin + vmax) * 0.5 + else: + raise NotImplementedError("Log scale not implemented") + return center + + def __getRangeInPixel(self, axis): + """Returns the size of the axis in pixel""" + bounds = axis._getPlot().getPlotBoundsInPixels() + # bounds: left, top, width, height + if isinstance(axis, XAxis): + return bounds[2] + elif isinstance(axis, YAxis): + return bounds[3] + else: + assert(False) + + def __getLimitsFromCenter(self, axis, pos, pixelSize=None): + """Returns the limits to apply to this axis to move the `pos` into the + center of this axis. + + :param Axis axis: + :param float pos: Position in the center of the computed limits + :param Union[None,float] pixelSize: Pixel size to apply to compute the + limits. If `None` the current pixel size is applyed. + """ + scale = axis.getScale() + if scale == Axis.LINEAR: + if pixelSize is None: + # Use the current pixel size of the axis + limits = axis.getLimits() + valueRange = limits[0] - limits[1] + a = pos - valueRange * 0.5 + b = pos + valueRange * 0.5 + else: + pixelRange = self.__getRangeInPixel(axis) + a = pos - pixelRange * 0.5 * pixelSize + b = pos + pixelRange * 0.5 * pixelSize + + else: + raise NotImplementedError("Log scale not implemented") + if a > b: + return b, a + return a, b + def __axisLimitsChanged(self, changedAxis, vmin, vmax): if self.__locked: return + self.__lastMainAxis = changedAxis + changedAxis = changedAxis() + with self.__inhibitSignals(): + for axis in self.__axesToUpdate(changedAxis): + axis.setLimits(vmin, vmax) + + def __axisCenterAndZoomChanged(self, changedAxis, vmin, vmax): + if self.__locked: + return + self.__lastMainAxis = changedAxis + changedAxis = changedAxis() + with self.__inhibitSignals(): + center = self.__getAxesCenter(changedAxis, vmin, vmax) + pixelRange = self.__getRangeInPixel(changedAxis) + if pixelRange == 0: + return + pixelSize = (vmax - vmin) / pixelRange + for axis in self.__axesToUpdate(changedAxis): + vmin, vmax = self.__getLimitsFromCenter(axis, center, pixelSize) + axis.setLimits(vmin, vmax) + + def __axisCenterChanged(self, changedAxis, vmin, vmax): + if self.__locked: + return + self.__lastMainAxis = changedAxis changedAxis = changedAxis() with self.__inhibitSignals(): - for axis in self.__otherAxes(changedAxis): + center = self.__getAxesCenter(changedAxis, vmin, vmax) + for axis in self.__axesToUpdate(changedAxis): + vmin, vmax = self.__getLimitsFromCenter(axis, center) axis.setLimits(vmin, vmax) def __axisScaleChanged(self, changedAxis, scale): if self.__locked: return + self.__lastMainAxis = changedAxis changedAxis = changedAxis() with self.__inhibitSignals(): - for axis in self.__otherAxes(changedAxis): + for axis in self.__axesToUpdate(changedAxis): axis.setScale(scale) def __axisInvertedChanged(self, changedAxis, isInverted): if self.__locked: return + self.__lastMainAxis = changedAxis changedAxis = changedAxis() with self.__inhibitSignals(): - for axis in self.__otherAxes(changedAxis): + for axis in self.__axesToUpdate(changedAxis): axis.setInverted(isInverted) diff --git a/silx/gui/plot3d/ParamTreeView.py b/silx/gui/plot3d/ParamTreeView.py index ee0c876..8cf2b90 100644 --- a/silx/gui/plot3d/ParamTreeView.py +++ b/silx/gui/plot3d/ParamTreeView.py @@ -43,7 +43,7 @@ __date__ = "05/12/2017" import numbers import sys -from silx.third_party import six +import six from .. import qt from ..widgets.FloatEdit import FloatEdit as _FloatEdit diff --git a/silx/gui/plot3d/ScalarFieldView.py b/silx/gui/plot3d/ScalarFieldView.py index e5e680c..50cba05 100644 --- a/silx/gui/plot3d/ScalarFieldView.py +++ b/silx/gui/plot3d/ScalarFieldView.py @@ -886,6 +886,8 @@ class ScalarFieldView(Plot3DWindow): self._bbox = axes.LabelledAxes() self._bbox.children = [self._group] + self._outerScale = transform.Scale(1., 1., 1.) + self._bbox.transforms = [self._outerScale] self.getPlot3DWidget().viewport.scene.children.append(self._bbox) self._selectionBox = primitives.Box() @@ -1204,6 +1206,25 @@ class ScalarFieldView(Plot3DWindow): # Transformations + def setOuterScale(self, sx=1., sy=1., sz=1.): + """Set the scale to apply to the whole scene including the axes. + + This is useful when axis lengths in data space are really different. + + :param float sx: Scale factor along the X axis + :param float sy: Scale factor along the Y axis + :param float sz: Scale factor along the Z axis + """ + self._outerScale.setScale(sx, sy, sz) + self.centerScene() + + def getOuterScale(self): + """Returns the scales provided by :meth:`setOuterScale`. + + :rtype: numpy.ndarray + """ + return self._outerScale.scale + def setScale(self, sx=1., sy=1., sz=1.): """Set the scale of the 3D scalar field (i.e., size of a voxel). diff --git a/silx/gui/plot3d/SceneWidget.py b/silx/gui/plot3d/SceneWidget.py index 4a824d7..e60dcfc 100644 --- a/silx/gui/plot3d/SceneWidget.py +++ b/silx/gui/plot3d/SceneWidget.py @@ -1,7 +1,7 @@ # coding: utf-8 # /*########################################################################## # -# Copyright (c) 2017-2018 European Synchrotron Radiation Facility +# Copyright (c) 2017-2019 European Synchrotron Radiation Facility # # Permission is hereby granted, free of charge, to any person obtaining a copy # of this software and associated documentation files (the "Software"), to deal @@ -30,10 +30,11 @@ __authors__ = ["T. Vincent"] __license__ = "MIT" __date__ = "24/04/2018" -import numpy +import enum import weakref -from silx.third_party import enum +import numpy + from .. import qt from ..colors import rgba @@ -229,6 +230,9 @@ class SceneSelection(qt.QObject): :raise ValueError: If the item is not the widget's scene """ previous = self.getCurrentItem() + if item is previous: + return # Fast path, nothing to do + if previous is not None: previous.sigItemChanged.disconnect(self.__currentChanged) @@ -252,15 +256,18 @@ class SceneSelection(qt.QObject): 'Not an Item3D: %s' % str(item)) current = self.getCurrentItem() - if current is not previous: - self.sigCurrentChanged.emit(current, previous) - self.__updateSelectionModel() + self.sigCurrentChanged.emit(current, previous) + self.__updateSelectionModel() def __currentChanged(self, event): """Handle updates of the selected item""" if event == items.Item3DChangedType.ROOT_ITEM: item = self.sender() - if item.root() != self.getSceneGroup(): + + parent = self.parent() + assert isinstance(parent, SceneWidget) + + if item.root() != parent.getSceneGroup(): self.setSelectedItem(None) # Synchronization with QItemSelectionModel @@ -488,7 +495,8 @@ class SceneWidget(Plot3DWidget): :param int index: The index at which to place the item. By default it is appended to the end of the list. :return: The newly created scalar volume item - :rtype: items.ScalarField3D + :rtype: ~silx.gui.plot3d.items.volume.ScalarField3D + """ volume = items.ScalarField3D() volume.setData(data, copy=copy) @@ -508,7 +516,7 @@ class SceneWidget(Plot3DWidget): :param int index: The index at which to place the item. By default it is appended to the end of the list. :return: The newly created 3D scatter item - :rtype: items.Scatter3D + :rtype: ~silx.gui.plot3d.items.scatter.Scatter3D """ scatter3d = items.Scatter3D() scatter3d.setData(x=x, y=y, z=z, value=value, copy=copy) @@ -528,7 +536,7 @@ class SceneWidget(Plot3DWidget): :param int index: The index at which to place the item. By default it is appended to the end of the list. :return: The newly created 2D scatter item - :rtype: items.Scatter2D + :rtype: ~silx.gui.plot3d.items.scatter.Scatter2D """ scatter2d = items.Scatter2D() scatter2d.setData(x=x, y=y, value=value, copy=copy) @@ -548,7 +556,7 @@ class SceneWidget(Plot3DWidget): :param int index: The index at which to place the item. By default it is appended to the end of the list. :return: The newly created image item - :rtype: items.ImageData or items.ImageRgba + :rtype: ~silx.gui.plot3d.items.image.ImageData or ~silx.gui.plot3d.items.image.ImageRgba :raise ValueError: For arrays of unsupported dimensions """ data = numpy.array(data, copy=False) diff --git a/silx/gui/plot3d/_model/items.py b/silx/gui/plot3d/_model/items.py index b09f29a..7e58d14 100644 --- a/silx/gui/plot3d/_model/items.py +++ b/silx/gui/plot3d/_model/items.py @@ -1,7 +1,7 @@ # coding: utf-8 # /*########################################################################## # -# Copyright (c) 2017-2018 European Synchrotron Radiation Facility +# Copyright (c) 2017-2019 European Synchrotron Radiation Facility # # Permission is hereby granted, free of charge, to any person obtaining a copy # of this software and associated documentation files (the "Software"), to deal @@ -38,8 +38,7 @@ import logging import weakref import numpy - -from silx.third_party import six +import six from ...utils.image import convertArrayToQImage from ...colors import preferredColormaps @@ -202,7 +201,7 @@ class Settings(StaticRow): super(Settings, self).__init__(('Settings', None), children=children) -class Item3DRow(StaticRow): +class Item3DRow(BaseRow): """Represents an :class:`Item3D` with checkable visibility :param Item3D item: The scene item to represent. @@ -210,9 +209,8 @@ class Item3DRow(StaticRow): """ def __init__(self, item, name=None): - if name is None: - name = item.getLabel() - super(Item3DRow, self).__init__((name, None)) + self.__name = None if name is None else six.text_type(name) + super(Item3DRow, self).__init__() self.setFlags( self.flags(0) | qt.Qt.ItemIsUserCheckable | qt.Qt.ItemIsSelectable, @@ -224,7 +222,8 @@ class Item3DRow(StaticRow): def _itemChanged(self, event): """Handle visibility change""" - if event == items.ItemChangedType.VISIBLE: + if event in (items.ItemChangedType.VISIBLE, + items.Item3DChangedType.LABEL): model = self.model() if model is not None: index = self.index(column=1) @@ -235,16 +234,25 @@ class Item3DRow(StaticRow): return self._item() def data(self, column, role): - if column == 0 and role == qt.Qt.CheckStateRole: - item = self.item() - if item is not None and item.isVisible(): - return qt.Qt.Checked - else: - return qt.Qt.Unchecked - elif column == 0 and role == qt.Qt.DecorationRole: - return icons.getQIcon('item-3dim') - else: - return super(Item3DRow, self).data(column, role) + if column == 0: + if role == qt.Qt.CheckStateRole: + item = self.item() + if item is not None and item.isVisible(): + return qt.Qt.Checked + else: + return qt.Qt.Unchecked + + elif role == qt.Qt.DecorationRole: + return icons.getQIcon('item-3dim') + + elif role == qt.Qt.DisplayRole: + if self.__name is None: + item = self.item() + return '' if item is None else item.getLabel() + else: + return self.__name + + return super(Item3DRow, self).data(column, role) def setData(self, column, value, role): if column == 0 and role == qt.Qt.CheckStateRole: @@ -256,6 +264,9 @@ class Item3DRow(StaticRow): return False return super(Item3DRow, self).setData(column, value, role) + def columnCount(self): + return 2 + class DataItem3DBoundingBoxRow(ProxyRow): """Represents :class:`DataItem3D` bounding box visibility @@ -562,7 +573,6 @@ class _ColormapBaseProxyRow(ProxyRow): """Signal used internally to notify colormap (or data) update""" def __init__(self, item, *args, **kwargs): - self._dataRange = None self._item = weakref.ref(item) self._colormap = item.getColormap() @@ -581,19 +591,11 @@ class _ColormapBaseProxyRow(ProxyRow): :return: Colormap range (min, max) """ - if self._dataRange is None: - item = self.item() - if item is not None and self._colormap is not None: - if hasattr(item, 'getDataRange'): - data = item.getDataRange() - else: - data = item.getData(copy=False) - - self._dataRange = self._colormap.getColormapRange(data) - - else: # Fallback - self._dataRange = 1, 100 - return self._dataRange + item = self.item() + if item is not None and self._colormap is not None: + return self._colormap.getColormapRange(item._getDataRange()) + else: + return 1, 100 # Fallback def _modelUpdated(self, *args, **kwargs): """Emit dataChanged in the model""" @@ -624,7 +626,6 @@ class _ColormapBaseProxyRow(ProxyRow): self._colormap = None elif event == items.ItemChangedType.DATA: - self._dataRange = None self._sigColormapChanged.emit() diff --git a/silx/gui/plot3d/items/__init__.py b/silx/gui/plot3d/items/__init__.py index b2a9dab..58eee9c 100644 --- a/silx/gui/plot3d/items/__init__.py +++ b/silx/gui/plot3d/items/__init__.py @@ -1,7 +1,7 @@ # coding: utf-8 # /*########################################################################## # -# Copyright (c) 2017-2018 European Synchrotron Radiation Facility +# Copyright (c) 2017-2019 European Synchrotron Radiation Facility # # Permission is hereby granted, free of charge, to any person obtaining a copy # of this software and associated documentation files (the "Software"), to deal @@ -38,6 +38,6 @@ from .mixins import (ColormapMixIn, InterpolationMixIn, # noqa PlaneMixIn, SymbolMixIn) # noqa from .clipplane import ClipPlane # noqa from .image import ImageData, ImageRgba # noqa -from .mesh import Mesh, Box, Cylinder, Hexagon # noqa +from .mesh import Mesh, ColormapMesh, Box, Cylinder, Hexagon # noqa from .scatter import Scatter2D, Scatter3D # noqa from .volume import ScalarField3D # noqa diff --git a/silx/gui/plot3d/items/core.py b/silx/gui/plot3d/items/core.py index 0aefced..1745b2b 100644 --- a/silx/gui/plot3d/items/core.py +++ b/silx/gui/plot3d/items/core.py @@ -32,10 +32,10 @@ __license__ = "MIT" __date__ = "15/11/2017" from collections import defaultdict +import enum import numpy - -from silx.third_party import enum, six +import six from ... import qt from ...plot.items import ItemChangedType diff --git a/silx/gui/plot3d/items/mesh.py b/silx/gui/plot3d/items/mesh.py index 21936ea..d3f5e38 100644 --- a/silx/gui/plot3d/items/mesh.py +++ b/silx/gui/plot3d/items/mesh.py @@ -1,7 +1,7 @@ # coding: utf-8 # /*########################################################################## # -# Copyright (c) 2017-2018 European Synchrotron Radiation Facility +# Copyright (c) 2017-2019 European Synchrotron Radiation Facility # # Permission is hereby granted, free of charge, to any person obtaining a copy # of this software and associated documentation files (the "Software"), to deal @@ -35,17 +35,18 @@ __date__ = "17/07/2018" import logging import numpy -from ..scene import primitives, utils +from ..scene import primitives, utils, function from ..scene.transform import Rotate from .core import DataItem3D, ItemChangedType +from .mixins import ColormapMixIn from ._pick import PickingResult _logger = logging.getLogger(__name__) -class Mesh(DataItem3D): - """Description of mesh. +class _MeshBase(DataItem3D): + """Base class for :class:`Mesh' and :class:`ColormapMesh`. :param parent: The View widget this item belongs to. """ @@ -54,48 +55,22 @@ class Mesh(DataItem3D): DataItem3D.__init__(self, parent=parent) self._mesh = None - def setData(self, - position, - color, - normal=None, - mode='triangles', - copy=True): - """Set mesh geometry data. - - Supported drawing modes are: 'triangles', 'triangle_strip', 'fan' + def _setMesh(self, mesh): + """Set mesh primitive - :param numpy.ndarray position: - Position (x, y, z) of each vertex as a (N, 3) array - :param numpy.ndarray color: Colors for each point or a single color - :param numpy.ndarray normal: Normals for each point or None (default) - :param str mode: The drawing mode. - :param bool copy: True (default) to copy the data, - False to use as is (do not modify!). + :param Union[None,Geometry] mesh: The scene primitive """ self._getScenePrimitive().children = [] # Remove any previous mesh - if position is None or len(position) == 0: - self._mesh = None - else: - self._mesh = primitives.Mesh3D( - position, color, normal, mode=mode, copy=copy) + self._mesh = mesh + if self._mesh is not None: self._getScenePrimitive().children.append(self._mesh) - self.sigItemChanged.emit(ItemChangedType.DATA) + self._updated(ItemChangedType.DATA) - def getData(self, copy=True): - """Get the mesh geometry. - - :param bool copy: - True (default) to get a copy, - False to get internal representation (do not modify!). - :return: The positions, colors, normals and mode - :rtype: tuple of numpy.ndarray - """ - return (self.getPositionData(copy=copy), - self.getColorData(copy=copy), - self.getNormalData(copy=copy), - self.getDrawMode()) + def _getMesh(self): + """Returns the underlying Mesh scene primitive""" + return self._mesh def getPositionData(self, copy=True): """Get the mesh vertex positions. @@ -106,38 +81,38 @@ class Mesh(DataItem3D): :return: The (x, y, z) positions as a (N, 3) array :rtype: numpy.ndarray """ - if self._mesh is None: + if self._getMesh() is None: return numpy.empty((0, 3), dtype=numpy.float32) else: - return self._mesh.getAttribute('position', copy=copy) + return self._getMesh().getAttribute('position', copy=copy) - def getColorData(self, copy=True): - """Get the mesh vertex colors. + def getNormalData(self, copy=True): + """Get the mesh vertex normals. :param bool copy: True (default) to get a copy, False to get internal representation (do not modify!). - :return: The RGBA colors as a (N, 4) array or a single color - :rtype: numpy.ndarray + :return: The normals as a (N, 3) array, a single normal or None + :rtype: Union[numpy.ndarray,None] """ - if self._mesh is None: - return numpy.empty((0, 4), dtype=numpy.float32) + if self._getMesh() is None: + return None else: - return self._mesh.getAttribute('color', copy=copy) + return self._getMesh().getAttribute('normal', copy=copy) - def getNormalData(self, copy=True): - """Get the mesh vertex normals. + def getIndices(self, copy=True): + """Get the vertex indices. :param bool copy: True (default) to get a copy, False to get internal representation (do not modify!). - :return: The normals as a (N, 3) array, a single normal or None - :rtype: numpy.ndarray or None + :return: The vertex indices as an array or None. + :rtype: Union[numpy.ndarray,None] """ - if self._mesh is None: + if self._getMesh() is None: return None else: - return self._mesh.getAttribute('normal', copy=copy) + return self._getMesh().getIndices(copy=copy) def getDrawMode(self): """Get mesh rendering mode. @@ -145,7 +120,7 @@ class Mesh(DataItem3D): :return: The drawing mode of this primitive :rtype: str """ - return self._mesh.drawMode + return self._getMesh().drawMode def _pickFull(self, context): """Perform precise picking in this item at given widget position. @@ -164,28 +139,34 @@ class Mesh(DataItem3D): return None mode = self.getDrawMode() - if mode == 'triangles': - triangles = positions.reshape(-1, 3, 3) - - elif mode == 'triangle_strip': - # Expand strip - triangles = numpy.empty((len(positions) - 2, 3, 3), - dtype=positions.dtype) - triangles[:, 0] = positions[:-2] - triangles[:, 1] = positions[1:-1] - triangles[:, 2] = positions[2:] - - elif mode == 'fan': - # Expand fan - triangles = numpy.empty((len(positions) - 2, 3, 3), - dtype=positions.dtype) - triangles[:, 0] = positions[0] - triangles[:, 1] = positions[1:-1] - triangles[:, 2] = positions[2:] + vertexIndices = self.getIndices(copy=False) + if vertexIndices is not None: # Expand indices + positions = utils.unindexArrays(mode, vertexIndices, positions)[0] + triangles = positions.reshape(-1, 3, 3) else: - _logger.warning("Unsupported draw mode: %s" % mode) - return None + if mode == 'triangles': + triangles = positions.reshape(-1, 3, 3) + + elif mode == 'triangle_strip': + # Expand strip + triangles = numpy.empty((len(positions) - 2, 3, 3), + dtype=positions.dtype) + triangles[:, 0] = positions[:-2] + triangles[:, 1] = positions[1:-1] + triangles[:, 2] = positions[2:] + + elif mode == 'fan': + # Expand fan + triangles = numpy.empty((len(positions) - 2, 3, 3), + dtype=positions.dtype) + triangles[:, 0] = positions[0] + triangles[:, 1] = positions[1:-1] + triangles[:, 2] = positions[2:] + + else: + _logger.warning("Unsupported draw mode: %s" % mode) + return None trianglesIndices, t, barycentric = utils.segmentTrianglesIntersection( rayObject, triangles) @@ -208,12 +189,160 @@ class Mesh(DataItem3D): indices = trianglesIndices + closest # For corners 1 and 2 indices[closest == 0] = 0 # For first corner (common) + if vertexIndices is not None: + # Convert from indices in expanded triangles to input vertices + indices = vertexIndices[indices] + return PickingResult(self, positions=points, indices=indices, fetchdata=self.getPositionData) +class Mesh(_MeshBase): + """Description of mesh. + + :param parent: The View widget this item belongs to. + """ + + def __init__(self, parent=None): + _MeshBase.__init__(self, parent=parent) + + def setData(self, + position, + color, + normal=None, + mode='triangles', + indices=None, + copy=True): + """Set mesh geometry data. + + Supported drawing modes are: 'triangles', 'triangle_strip', 'fan' + + :param numpy.ndarray position: + Position (x, y, z) of each vertex as a (N, 3) array + :param numpy.ndarray color: Colors for each point or a single color + :param Union[numpy.ndarray,None] normal: Normals for each point or None (default) + :param str mode: The drawing mode. + :param Union[List[int],None] indices: + Array of vertex indices or None to use arrays directly. + :param bool copy: True (default) to copy the data, + False to use as is (do not modify!). + """ + assert mode in ('triangles', 'triangle_strip', 'fan') + if position is None or len(position) == 0: + mesh = None + else: + mesh = primitives.Mesh3D( + position, color, normal, mode=mode, indices=indices, copy=copy) + self._setMesh(mesh) + + def getData(self, copy=True): + """Get the mesh geometry. + + :param bool copy: + True (default) to get a copy, + False to get internal representation (do not modify!). + :return: The positions, colors, normals and mode + :rtype: tuple of numpy.ndarray + """ + return (self.getPositionData(copy=copy), + self.getColorData(copy=copy), + self.getNormalData(copy=copy), + self.getDrawMode()) + + def getColorData(self, copy=True): + """Get the mesh vertex colors. + + :param bool copy: + True (default) to get a copy, + False to get internal representation (do not modify!). + :return: The RGBA colors as a (N, 4) array or a single color + :rtype: numpy.ndarray + """ + if self._getMesh() is None: + return numpy.empty((0, 4), dtype=numpy.float32) + else: + return self._getMesh().getAttribute('color', copy=copy) + + +class ColormapMesh(_MeshBase, ColormapMixIn): + """Description of mesh which color is defined by scalar and a colormap. + + :param parent: The View widget this item belongs to. + """ + + def __init__(self, parent=None): + _MeshBase.__init__(self, parent=parent) + ColormapMixIn.__init__(self, function.Colormap()) + + def setData(self, + position, + value, + normal=None, + mode='triangles', + indices=None, + copy=True): + """Set mesh geometry data. + + Supported drawing modes are: 'triangles', 'triangle_strip', 'fan' + + :param numpy.ndarray position: + Position (x, y, z) of each vertex as a (N, 3) array + :param numpy.ndarray value: Data value for each vertex. + :param Union[numpy.ndarray,None] normal: Normals for each point or None (default) + :param str mode: The drawing mode. + :param Union[List[int],None] indices: + Array of vertex indices or None to use arrays directly. + :param bool copy: True (default) to copy the data, + False to use as is (do not modify!). + """ + assert mode in ('triangles', 'triangle_strip', 'fan') + if position is None or len(position) == 0: + mesh = None + else: + mesh = primitives.ColormapMesh3D( + position=position, + value=numpy.array(value, copy=False).reshape(-1, 1), # Make it a 2D array + colormap=self._getSceneColormap(), + normal=normal, + mode=mode, + indices=indices, + copy=copy) + self._setMesh(mesh) + + # Store data range info + ColormapMixIn._setRangeFromData(self, self.getValueData(copy=False)) + + def getData(self, copy=True): + """Get the mesh geometry. + + :param bool copy: + True (default) to get a copy, + False to get internal representation (do not modify!). + :return: The positions, values, normals and mode + :rtype: tuple of numpy.ndarray + """ + return (self.getPositionData(copy=copy), + self.getValueData(copy=copy), + self.getNormalData(copy=copy), + self.getDrawMode()) + + def getValueData(self, copy=True): + """Get the mesh vertex values. + + :param bool copy: + True (default) to get a copy, + False to get internal representation (do not modify!). + :return: Array of data values + :rtype: numpy.ndarray + """ + if self._getMesh() is None: + return numpy.empty((0,), dtype=numpy.float32) + else: + return self._getMesh().getAttribute('value', copy=copy) + + class _CylindricalVolume(DataItem3D): """Class that represents a volume with a rotational symmetry along z @@ -345,7 +474,7 @@ class _CylindricalVolume(DataItem3D): vertices, color, normals, mode='triangles', copy=False) self._getScenePrimitive().children.append(self._mesh) - self.sigItemChanged.emit(ItemChangedType.DATA) + self._updated(ItemChangedType.DATA) def _pickFull(self, context): """Perform precise picking in this item at given widget position. diff --git a/silx/gui/plot3d/items/mixins.py b/silx/gui/plot3d/items/mixins.py index 8e96441..40b8438 100644 --- a/silx/gui/plot3d/items/mixins.py +++ b/silx/gui/plot3d/items/mixins.py @@ -1,7 +1,7 @@ # coding: utf-8 # /*########################################################################## # -# Copyright (c) 2017-2018 European Synchrotron Radiation Facility +# Copyright (c) 2017-2019 European Synchrotron Radiation Facility # # Permission is hereby granted, free of charge, to any person obtaining a copy # of this software and associated documentation files (the "Software"), to deal @@ -114,19 +114,17 @@ class ColormapMixIn(_ColormapMixIn): self.__sceneColormap = sceneColormap self._syncSceneColormap() - self.sigItemChanged.connect(self.__colormapUpdated) - - def __colormapUpdated(self, event): + def _colormapChanged(self): """Handle colormap updates""" - if event == ItemChangedType.COLORMAP: - self._syncSceneColormap() + self._syncSceneColormap() + super(ColormapMixIn, self)._colormapChanged() def _setRangeFromData(self, data=None): """Compute the data range the colormap should use from provided data. :param data: Data set from which to compute the range or None """ - if data is None or len(data) == 0: + if data is None or data.size == 0: dataRange = None else: dataRange = min_max(data, min_positive=True, finite=True) @@ -144,6 +142,13 @@ class ColormapMixIn(_ColormapMixIn): if self.getColormap().isAutoscale(): self._syncSceneColormap() + def _getDataRange(self): + """Returns the data range as used in the scene for colormap + + :rtype: Union[List[float],None] + """ + return self._dataRange + def _setSceneColormap(self, sceneColormap): """Set the scene colormap to sync with Colormap object. @@ -171,8 +176,6 @@ class ColormapMixIn(_ColormapMixIn): class SymbolMixIn(_SymbolMixIn): """Mix-in class for symbol and symbolSize properties for Item3D""" - _DEFAULT_SYMBOL = 'o' - _DEFAULT_SYMBOL_SIZE = 7.0 _SUPPORTED_SYMBOLS = collections.OrderedDict(( ('o', 'Circle'), ('d', 'Diamond'), diff --git a/silx/gui/plot3d/items/scatter.py b/silx/gui/plot3d/items/scatter.py index a13c3db..b7bcd09 100644 --- a/silx/gui/plot3d/items/scatter.py +++ b/silx/gui/plot3d/items/scatter.py @@ -1,7 +1,7 @@ # coding: utf-8 # /*########################################################################## # -# Copyright (c) 2017-2018 European Synchrotron Radiation Facility +# Copyright (c) 2017-2019 European Synchrotron Radiation Facility # # Permission is hereby granted, free of charge, to any person obtaining a copy # of this software and associated documentation files (the "Software"), to deal @@ -36,6 +36,7 @@ import logging import sys import numpy +from ....utils.deprecation import deprecated from ..scene import function, primitives, utils from .core import DataItem3D, Item3DChangedType, ItemChangedType @@ -43,7 +44,7 @@ from .mixins import ColormapMixIn, SymbolMixIn from ._pick import PickingResult -_logger = logging.getLevelName(__name__) +_logger = logging.getLogger(__name__) class Scatter3D(DataItem3D, ColormapMixIn, SymbolMixIn): @@ -94,7 +95,7 @@ class Scatter3D(DataItem3D, ColormapMixIn, SymbolMixIn): self._scatter.setAttribute('z', z, copy=copy) self._scatter.setAttribute('value', value, copy=copy) - ColormapMixIn._setRangeFromData(self, self.getValues(copy=False)) + ColormapMixIn._setRangeFromData(self, self.getValueData(copy=False)) self._updated(ItemChangedType.DATA) def getData(self, copy=True): @@ -107,7 +108,7 @@ class Scatter3D(DataItem3D, ColormapMixIn, SymbolMixIn): return (self.getXData(copy), self.getYData(copy), self.getZData(copy), - self.getValues(copy)) + self.getValueData(copy)) def getXData(self, copy=True): """Returns X data coordinates. @@ -139,7 +140,7 @@ class Scatter3D(DataItem3D, ColormapMixIn, SymbolMixIn): """ return self._scatter.getAttribute('z', copy=copy).reshape(-1) - def getValues(self, copy=True): + def getValueData(self, copy=True): """Returns data values. :param bool copy: True to get a copy, @@ -149,6 +150,11 @@ class Scatter3D(DataItem3D, ColormapMixIn, SymbolMixIn): """ return self._scatter.getAttribute('value', copy=copy).reshape(-1) + @deprecated(reason="Consistency with PlotWidget items", + replacement="getValueData", since_version="0.10.0") + def getValues(self, copy=True): + return self.getValueData(copy) + def _pickFull(self, context, threshold=0., sort='depth'): """Perform picking in this item at given widget position. @@ -202,7 +208,7 @@ class Scatter3D(DataItem3D, ColormapMixIn, SymbolMixIn): return PickingResult(self, positions=dataPoints[picked, :3], indices=picked, - fetchdata=self.getValues) + fetchdata=self.getValueData) else: return None @@ -269,8 +275,8 @@ class Scatter2D(DataItem3D, ColormapMixIn, SymbolMixIn): Supported visualization modes are: - 'points': For scatter plot representation - - 'lines': For Delaunay tesselation-based wireframe representation - - 'solid': For Delaunay tesselation-based solid surface representation + - 'lines': For Delaunay tessellation-based wireframe representation + - 'solid': For Delaunay tessellation-based solid surface representation :param str mode: Mode of representation to use """ @@ -384,7 +390,7 @@ class Scatter2D(DataItem3D, ColormapMixIn, SymbolMixIn): self._cachedTrianglesIndices = None # Store data range info - ColormapMixIn._setRangeFromData(self, self.getValues(copy=False)) + ColormapMixIn._setRangeFromData(self, self.getValueData(copy=False)) self._updateScene() @@ -399,7 +405,7 @@ class Scatter2D(DataItem3D, ColormapMixIn, SymbolMixIn): """ return (self.getXData(copy=copy), self.getYData(copy=copy), - self.getValues(copy=copy)) + self.getValueData(copy=copy)) def getXData(self, copy=True): """Returns X data coordinates. @@ -421,7 +427,7 @@ class Scatter2D(DataItem3D, ColormapMixIn, SymbolMixIn): """ return numpy.array(self._y, copy=copy) - def getValues(self, copy=True): + def getValueData(self, copy=True): """Returns data values. :param bool copy: True to get a copy, @@ -431,6 +437,11 @@ class Scatter2D(DataItem3D, ColormapMixIn, SymbolMixIn): """ return numpy.array(self._value, copy=copy) + @deprecated(reason="Consistency with PlotWidget items", + replacement="getValueData", since_version="0.10.0") + def getValues(self, copy=True): + return self.getValueData(copy) + def _pickPoints(self, context, points, threshold=1., sort='depth'): """Perform picking while in 'points' visualization mode @@ -472,7 +483,7 @@ class Scatter2D(DataItem3D, ColormapMixIn, SymbolMixIn): return PickingResult(self, positions=points[picked, :3], indices=picked, - fetchdata=self.getValues) + fetchdata=self.getValueData) else: return None @@ -507,7 +518,7 @@ class Scatter2D(DataItem3D, ColormapMixIn, SymbolMixIn): return PickingResult(self, positions=positions, indices=indices, - fetchdata=self.getValues) + fetchdata=self.getValueData) def _pickFull(self, context): """Perform picking in this item at given widget position. @@ -521,7 +532,7 @@ class Scatter2D(DataItem3D, ColormapMixIn, SymbolMixIn): return None if self.isHeightMap(): - zData = self.getValues(copy=False) + zData = self.getValueData(copy=False) else: zData = numpy.zeros_like(xData) diff --git a/silx/gui/plot3d/items/volume.py b/silx/gui/plot3d/items/volume.py index ca22f1f..08ad02a 100644 --- a/silx/gui/plot3d/items/volume.py +++ b/silx/gui/plot3d/items/volume.py @@ -1,7 +1,7 @@ # coding: utf-8 # /*########################################################################## # -# Copyright (c) 2017-2018 European Synchrotron Radiation Facility +# Copyright (c) 2017-2019 European Synchrotron Radiation Facility # # Permission is hereby granted, free of charge, to any person obtaining a copy # of this software and associated documentation files (the "Software"), to deal @@ -78,13 +78,15 @@ class CutPlane(Item3D, ColormapMixIn, InterpolationMixIn, PlaneMixIn): def _parentChanged(self, event): """Handle data change in the parent this plane belongs to""" if event == ItemChangedType.DATA: - self._getPlane().setData(self.sender().getData(copy=False), - copy=False) + data = self.sender().getData(copy=False) + self._getPlane().setData(data, copy=False) # Store data range info as 3-tuple of values self._dataRange = self.sender().getDataRange() + self._setRangeFromData( + None if self._dataRange is None else numpy.array(self._dataRange)) - self.sigItemChanged.emit(ItemChangedType.DATA) + self._updated(ItemChangedType.DATA) # Colormap @@ -104,7 +106,7 @@ class CutPlane(Item3D, ColormapMixIn, InterpolationMixIn, PlaneMixIn): display = bool(display) if display != self.getDisplayValuesBelowMin(): self._getPlane().colormap.displayValuesBelowMin = display - self.sigItemChanged.emit(ItemChangedType.ALPHA) + self._updated(ItemChangedType.ALPHA) def getDataRange(self): """Return the range of the data as a 3-tuple of values. diff --git a/silx/gui/plot3d/scene/primitives.py b/silx/gui/plot3d/scene/primitives.py index 474581a..ca06e30 100644 --- a/silx/gui/plot3d/scene/primitives.py +++ b/silx/gui/plot3d/scene/primitives.py @@ -1,7 +1,7 @@ # coding: utf-8 # /*########################################################################## # -# Copyright (c) 2015-2018 European Synchrotron Radiation Facility +# Copyright (c) 2015-2019 European Synchrotron Radiation Facility # # Permission is hereby granted, free of charge, to any person obtaining a copy # of this software and associated documentation files (the "Software"), to deal @@ -1878,11 +1878,13 @@ class ColormapMesh3D(Geometry): colormap=None, normal=None, mode='triangles', - indices=None): + indices=None, + copy=True): super(ColormapMesh3D, self).__init__(mode, indices, position=position, normal=normal, - value=value) + value=value, + copy=copy) self._lineWidth = 1.0 self._lineSmooth = True diff --git a/silx/gui/plot3d/test/__init__.py b/silx/gui/plot3d/test/__init__.py index c58f307..8825cf4 100644 --- a/silx/gui/plot3d/test/__init__.py +++ b/silx/gui/plot3d/test/__init__.py @@ -1,7 +1,7 @@ # coding: utf-8 # /*########################################################################## # -# Copyright (c) 2015-2018 European Synchrotron Radiation Facility +# Copyright (c) 2015-2019 European Synchrotron Radiation Facility # # Permission is hereby granted, free of charge, to any person obtaining a copy # of this software and associated documentation files (the "Software"), to deal @@ -59,6 +59,7 @@ def suite(): from .testGL import suite as testGLSuite from .testScalarFieldView import suite as testScalarFieldViewSuite from .testSceneWidgetPicking import suite as testSceneWidgetPickingSuite + from .testStatsWidget import suite as testStatsWidgetSuite testsuite = unittest.TestSuite() testsuite.addTest(testGLSuite()) @@ -66,4 +67,5 @@ def suite(): testsuite.addTest(testScalarFieldViewSuite()) testsuite.addTest(testSceneWidgetPickingSuite()) testsuite.addTest(toolsTestSuite()) + testsuite.addTest(testStatsWidgetSuite()) return testsuite diff --git a/silx/gui/plot3d/test/testSceneWidgetPicking.py b/silx/gui/plot3d/test/testSceneWidgetPicking.py index d0c6467..649fb47 100644 --- a/silx/gui/plot3d/test/testSceneWidgetPicking.py +++ b/silx/gui/plot3d/test/testSceneWidgetPicking.py @@ -1,7 +1,7 @@ # coding: utf-8 # /*########################################################################## # -# Copyright (c) 2018 European Synchrotron Radiation Facility +# Copyright (c) 2018-2019 European Synchrotron Radiation Facility # # Permission is hereby granted, free of charge, to any person obtaining a copy # of this software and associated documentation files (the "Software"), to deal @@ -122,7 +122,7 @@ class TestSceneWidgetPicking(TestCaseQt, ParametricTestCase): self.assertEqual(nbPos, len(data)) self.assertTrue(numpy.array_equal( data, - item.getValues()[picking[0].getIndices()])) + item.getValueData()[picking[0].getIndices()])) # Picking outside data picking = list(self.widget.pickItems(1, 1)) @@ -217,6 +217,55 @@ class TestSceneWidgetPicking(TestCaseQt, ParametricTestCase): picking = list(self.widget.pickItems(1, 1)) self.assertEqual(len(picking), 0) + def testPickMeshWithIndices(self): + """Test picking of Mesh items defined by indices""" + + triangles = items.Mesh() + triangles.setData( + position=((0, 0, 0), (1, 0, 0), (0, 1, 0), (1, 1, 0)), + color=(1, 0, 0, 1), + indices=numpy.array( # dummy triangles and square + (0, 0, 1, 0, 1, 2, 1, 2, 3), dtype=numpy.uint8), + mode='triangles') + triangleStrip = items.Mesh() + triangleStrip.setData( + position=((0, 0, 0), (1, 0, 0), (0, 1, 0), (1, 1, 0)), + color=(0, 1, 0, 1), + indices=numpy.array( # dummy triangles and square + (1, 0, 0, 1, 2, 3), dtype=numpy.uint8), + mode='triangle_strip') + triangleFan = items.Mesh() + triangleFan.setData( + position=((0, 0, 0), (1, 0, 0), (0, 1, 0), (1, 1, 0)), + color=(0, 0, 1, 1), + indices=numpy.array( # dummy triangle, square, dummy + (1, 1, 0, 2, 3, 3), dtype=numpy.uint8), + mode='fan') + + for item in (triangles, triangleStrip, triangleFan): + with self.subTest(mode=item.getDrawMode()): + # Add item + self.widget.clearItems() + self.widget.addItem(item) + self.widget.resetZoom('front') + self.qapp.processEvents() + + # Picking on data (at widget center) + picking = list(self.widget.pickItems(*self._widgetCenter())) + + self.assertEqual(len(picking), 1) + self.assertIs(picking[0].getItem(), item) + nbPos = len(picking[0].getPositions()) + data = picking[0].getData() + self.assertEqual(nbPos, len(data)) + self.assertTrue(numpy.array_equal( + data, + item.getPositionData()[picking[0].getIndices()])) + + # Picking outside data + picking = list(self.widget.pickItems(1, 1)) + self.assertEqual(len(picking), 0) + def testPickCylindricalMesh(self): """Test picking of Box, Cylinder and Hexagon items""" diff --git a/silx/gui/plot3d/test/testStatsWidget.py b/silx/gui/plot3d/test/testStatsWidget.py new file mode 100644 index 0000000..1157aec --- /dev/null +++ b/silx/gui/plot3d/test/testStatsWidget.py @@ -0,0 +1,213 @@ +# coding: utf-8 +# /*########################################################################## +# +# Copyright (c) 2019 European Synchrotron Radiation Facility +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. +# ###########################################################################*/ +"""Test silx.gui.plot.StatsWidget with SceneWidget and ScalarFieldView""" + +__authors__ = ["T. Vincent"] +__license__ = "MIT" +__date__ = "25/01/2019" + + +import unittest + +import numpy + +from silx.utils.testutils import ParametricTestCase +from silx.gui.utils.testutils import TestCaseQt +from silx.gui import qt + +from silx.gui.plot.StatsWidget import BasicStatsWidget + +from silx.gui.plot3d.ScalarFieldView import ScalarFieldView +from silx.gui.plot3d.SceneWidget import SceneWidget, items + + +class TestSceneWidget(TestCaseQt, ParametricTestCase): + """Tests StatsWidget combined with SceneWidget""" + + def setUp(self): + super(TestSceneWidget, self).setUp() + self.sceneWidget = SceneWidget() + self.sceneWidget.resize(300, 300) + self.sceneWidget.show() + self.statsWidget = BasicStatsWidget() + self.statsWidget.setPlot(self.sceneWidget) + # self.qWaitForWindowExposed(self.sceneWidget) + + def tearDown(self): + self.qapp.processEvents() + self.sceneWidget.setAttribute(qt.Qt.WA_DeleteOnClose) + self.sceneWidget.close() + del self.sceneWidget + self.statsWidget.setAttribute(qt.Qt.WA_DeleteOnClose) + self.statsWidget.close() + del self.statsWidget + super(TestSceneWidget, self).tearDown() + + def test(self): + """Test StatsWidget with SceneWidget""" + # Prepare scene + + # Data image + image = self.sceneWidget.addImage(numpy.arange(100).reshape(10, 10)) + image.setLabel('Image') + # RGB image + imageRGB = self.sceneWidget.addImage( + numpy.arange(300, dtype=numpy.uint8).reshape(10, 10, 3)) + imageRGB.setLabel('RGB Image') + # 2D scatter + data = numpy.arange(100) + scatter2D = self.sceneWidget.add2DScatter(x=data, y=data, value=data) + scatter2D.setLabel('2D Scatter') + # 3D scatter + scatter3D = self.sceneWidget.add3DScatter(x=data, y=data, z=data, value=data) + scatter3D.setLabel('3D Scatter') + # Add a group + group = items.GroupItem() + self.sceneWidget.addItem(group) + # 3D scalar field + data = numpy.arange(64**3).reshape(64, 64, 64) + scalarField = items.ScalarField3D() + scalarField.setData(data, copy=False) + scalarField.setLabel('3D Scalar field') + group.addItem(scalarField) + + statsTable = self.statsWidget._getStatsTable() + + # Test selection only + self.statsWidget.setDisplayOnlyActiveItem(True) + self.assertEqual(statsTable.rowCount(), 0) + + self.sceneWidget.selection().setCurrentItem(group) + self.assertEqual(statsTable.rowCount(), 0) + + for item in (image, scatter2D, scatter3D, scalarField): + with self.subTest('selection only', item=item.getLabel()): + self.sceneWidget.selection().setCurrentItem(item) + self.assertEqual(statsTable.rowCount(), 1) + self._checkItem(item) + + # Test all data + self.statsWidget.setDisplayOnlyActiveItem(False) + self.assertEqual(statsTable.rowCount(), 4) + + for item in (image, scatter2D, scatter3D, scalarField): + with self.subTest('all items', item=item.getLabel()): + self._checkItem(item) + + def _checkItem(self, item): + """Check that item is in StatsTable and that stats are OK + + :param silx.gui.plot3d.items.Item3D item: + """ + if isinstance(item, (items.Scatter2D, items.Scatter3D)): + data = item.getValueData(copy=False) + else: + data = item.getData(copy=False) + + statsTable = self.statsWidget._getStatsTable() + tableItems = statsTable._itemToTableItems(item) + self.assertTrue(len(tableItems) > 0) + self.assertEqual(tableItems['legend'].text(), item.getLabel()) + self.assertEqual(float(tableItems['min'].text()), numpy.min(data)) + self.assertEqual(float(tableItems['max'].text()), numpy.max(data)) + # TODO + + +class TestScalarFieldView(TestCaseQt): + """Tests StatsWidget combined with ScalarFieldView""" + + def setUp(self): + super(TestScalarFieldView, self).setUp() + self.scalarFieldView = ScalarFieldView() + self.scalarFieldView.resize(300, 300) + self.scalarFieldView.show() + self.statsWidget = BasicStatsWidget() + self.statsWidget.setPlot(self.scalarFieldView) + # self.qWaitForWindowExposed(self.sceneWidget) + + def tearDown(self): + self.qapp.processEvents() + self.scalarFieldView.setAttribute(qt.Qt.WA_DeleteOnClose) + self.scalarFieldView.close() + del self.scalarFieldView + self.statsWidget.setAttribute(qt.Qt.WA_DeleteOnClose) + self.statsWidget.close() + del self.statsWidget + super(TestScalarFieldView, self).tearDown() + + def _getTextFor(self, row, name): + """Returns text in table at given row for column name + + :param int row: Row number in the table + :param str name: Column id + :rtype: Union[str,None] + """ + statsTable = self.statsWidget._getStatsTable() + + for column in range(statsTable.columnCount()): + headerItem = statsTable.horizontalHeaderItem(column) + if headerItem.data(qt.Qt.UserRole) == name: + tableItem = statsTable.item(row, column) + return tableItem.text() + + return None + + def test(self): + """Test StatsWidget with ScalarFieldView""" + data = numpy.arange(64**3, dtype=numpy.float64).reshape(64, 64, 64) + self.scalarFieldView.setData(data) + + statsTable = self.statsWidget._getStatsTable() + + # Test selection only + self.statsWidget.setDisplayOnlyActiveItem(True) + self.assertEqual(statsTable.rowCount(), 1) + + # Test all data + self.statsWidget.setDisplayOnlyActiveItem(False) + self.assertEqual(statsTable.rowCount(), 1) + + for column in range(statsTable.columnCount()): + self.assertEqual(float(self._getTextFor(0, 'min')), numpy.min(data)) + self.assertEqual(float(self._getTextFor(0, 'max')), numpy.max(data)) + sum_ = numpy.sum(data) + comz = numpy.sum(numpy.arange(data.shape[0]) * numpy.sum(data, axis=(1, 2))) / sum_ + comy = numpy.sum(numpy.arange(data.shape[1]) * numpy.sum(data, axis=(0, 2))) / sum_ + comx = numpy.sum(numpy.arange(data.shape[2]) * numpy.sum(data, axis=(0, 1))) / sum_ + self.assertEqual(self._getTextFor(0, 'COM'), str((comx, comy, comz))) + + +def suite(): + testsuite = unittest.TestSuite() + testsuite.addTest( + unittest.defaultTestLoader.loadTestsFromTestCase( + TestSceneWidget)) + testsuite.addTest( + unittest.defaultTestLoader.loadTestsFromTestCase( + TestScalarFieldView)) + return testsuite + + +if __name__ == '__main__': + unittest.main(defaultTest='suite') diff --git a/silx/gui/qt/_pyside_dynamic.py b/silx/gui/qt/_pyside_dynamic.py index 13d1a9d..6013416 100644 --- a/silx/gui/qt/_pyside_dynamic.py +++ b/silx/gui/qt/_pyside_dynamic.py @@ -1,6 +1,7 @@ # -*- coding: utf-8 -*- # Taken from: https://gist.github.com/cpbotha/1b42a20c8f3eb9bb7cb8 +# Plus: https://github.com/spyder-ide/qtpy/commit/001a862c401d757feb63025f88dbb4601d353c84 # Copyright (c) 2011 Sebastian Wiesner # Modifications by Charl Botha @@ -83,7 +84,9 @@ class UiLoader(QUiLoader): QUiLoader.__init__(self, baseinstance) self.baseinstance = baseinstance - self.customWidgets = customWidgets + self.customWidgets = {} + self.uifile = None + self.customWidgets.update(customWidgets) def createWidget(self, class_name, parent=None, name=''): """ @@ -107,13 +110,15 @@ class UiLoader(QUiLoader): # this will raise KeyError if the user has not supplied the # relevant class_name in the dictionary, or TypeError, if # customWidgets is None - try: - widget = self.customWidgets[class_name](parent) - - except (TypeError, KeyError): + if class_name not in self.customWidgets: raise Exception('No custom widget ' + class_name + ' found in customWidgets param of' + - 'UiLoader __init__.') + 'UiFile %s.' % self.uifile) + try: + widget = self.customWidgets[class_name](parent) + except Exception: + _logger.error("Fail to instanciate widget %s from file %s", class_name, self.uifile) + raise if self.baseinstance: # set an attribute for the new child widget on the base @@ -126,6 +131,42 @@ class UiLoader(QUiLoader): return widget + def _parse_custom_widgets(self, ui_file): + """ + This function is used to parse a ui file and look for the + section, then automatically load all the custom widget classes. + """ + import importlib + from xml.etree.ElementTree import ElementTree + + # Parse the UI file + etree = ElementTree() + ui = etree.parse(ui_file) + + # Get the customwidgets section + custom_widgets = ui.find('customwidgets') + + if custom_widgets is None: + return + + custom_widget_classes = {} + + for custom_widget in custom_widgets.getchildren(): + + cw_class = custom_widget.find('class').text + cw_header = custom_widget.find('header').text + + module = importlib.import_module(cw_header) + + custom_widget_classes[cw_class] = getattr(module, cw_class) + + self.customWidgets.update(custom_widget_classes) + + def load(self, uifile): + self._parse_custom_widgets(uifile) + self.uifile = uifile + return QUiLoader.load(self, uifile) + if "PySide2.QtCore" in sys.modules: @@ -155,7 +196,6 @@ if "PySide2.QtCore" in sys.modules: orientation = Property("Qt::Orientation", getOrientation, setOrientation) - CUSTOM_WIDGETS = {"Line": _Line} """Default custom widgets for `loadUi`""" diff --git a/silx/gui/test/test_colors.py b/silx/gui/test/test_colors.py index e980068..2f883bc 100644 --- a/silx/gui/test/test_colors.py +++ b/silx/gui/test/test_colors.py @@ -29,14 +29,13 @@ from __future__ import absolute_import __authors__ = ["H.Payno"] __license__ = "MIT" -__date__ = "05/10/2018" +__date__ = "09/11/2018" import unittest import numpy from silx.utils.testutils import ParametricTestCase from silx.gui import colors from silx.gui.colors import Colormap -from silx.gui.colors import preferredColormaps, setPreferredColormaps from silx.utils.exceptions import NotEditableError @@ -158,12 +157,12 @@ class TestDictAPI(unittest.TestCase): self.assertFalse(colormapObject.isAutoscale() == clm_dict['autoscale']) def testMissingKeysFromDict(self): - """Make sure we can create a Colormap object from a dictionnary even if - there is missing keys excepts if those keys are 'colors' or 'name' + """Make sure we can create a Colormap object from a dictionary even if + there is missing keys except if those keys are 'colors' or 'name' """ - colormap = Colormap._fromDict({'name': 'toto'}) + colormap = Colormap._fromDict({'name': 'blue'}) self.assertTrue(colormap.getVMin() is None) - colormap = Colormap._fromDict({'colors': numpy.zeros(10)}) + colormap = Colormap._fromDict({'colors': numpy.zeros((5, 3))}) self.assertTrue(colormap.getName() is None) with self.assertRaises(ValueError): @@ -227,15 +226,17 @@ class TestObjectAPI(ParametricTestCase): def testCopy(self): """Make sure the copy function is correctly processing """ - colormapObject = Colormap(name='toto', - colors=numpy.array([12, 13, 14]), + colormapObject = Colormap(name='red', + colors=numpy.array([[1., 0., 0.], + [0., 1., 0.], + [0., 0., 1.]]), vmin=None, vmax=None, normalization=Colormap.LOGARITHM) colormapObject2 = colormapObject.copy() self.assertTrue(colormapObject == colormapObject2) - colormapObject.setColormapLUT(numpy.array([0, 1])) + colormapObject.setColormapLUT([[0, 0, 0], [255, 255, 255]]) self.assertFalse(colormapObject == colormapObject2) colormapObject2 = colormapObject.copy() @@ -361,7 +362,7 @@ class TestObjectAPI(ParametricTestCase): with self.assertRaises(NotEditableError): colormap.setName('magma') with self.assertRaises(NotEditableError): - colormap.setColormapLUT(numpy.array([0, 1])) + colormap.setColormapLUT([[0., 0., 0.], [1., 1., 1.]]) with self.assertRaises(NotEditableError): colormap._setFromDict(colormap._toDict()) state = colormap.saveState() @@ -371,7 +372,28 @@ class TestObjectAPI(ParametricTestCase): def testBadColorsType(self): """Make sure colors can't be something else than an array""" with self.assertRaises(TypeError): - Colormap(name='temperature', colors=256) + Colormap(colors=256) + + def testEqual(self): + colormap1 = Colormap() + colormap2 = Colormap() + self.assertEqual(colormap1, colormap2) + + def testCompareString(self): + colormap = Colormap() + self.assertNotEqual(colormap, "a") + + def testCompareNone(self): + colormap = Colormap() + self.assertNotEqual(colormap, None) + + def testSet(self): + colormap = Colormap() + other = Colormap(name="viridis", vmin=1, vmax=2, normalization=Colormap.LOGARITHM) + self.assertNotEqual(colormap, other) + colormap.setFromColormap(other) + self.assertIsNot(colormap, other) + self.assertEqual(colormap, other) class TestPreferredColormaps(unittest.TestCase): @@ -379,27 +401,76 @@ class TestPreferredColormaps(unittest.TestCase): def setUp(self): # Save preferred colormaps - self._colormaps = preferredColormaps() + self._colormaps = colors.preferredColormaps() def tearDown(self): # Restore saved preferred colormaps - setPreferredColormaps(self._colormaps) + colors.setPreferredColormaps(self._colormaps) def test(self): colormaps = 'viridis', 'magma' - setPreferredColormaps(colormaps) - self.assertEqual(preferredColormaps(), colormaps) + colors.setPreferredColormaps(colormaps) + self.assertEqual(colors.preferredColormaps(), colormaps) with self.assertRaises(ValueError): - setPreferredColormaps(()) + colors.setPreferredColormaps(()) with self.assertRaises(ValueError): - setPreferredColormaps(('This is not a colormap',)) + colors.setPreferredColormaps(('This is not a colormap',)) colormaps = 'red', 'green' - setPreferredColormaps(('This is not a colormap',) + colormaps) - self.assertEqual(preferredColormaps(), colormaps) + colors.setPreferredColormaps(('This is not a colormap',) + colormaps) + self.assertEqual(colors.preferredColormaps(), colormaps) + + +class TestRegisteredLut(unittest.TestCase): + """Test get|setPreferredColormaps functions""" + + def setUp(self): + # Save preferred colormaps + lut = numpy.arange(8 * 3) + lut.shape = -1, 3 + lut = lut / (8.0 * 3) + colors.registerLUT("test_8", colors=lut, cursor_color='blue') + + def testColormap(self): + colormap = Colormap("test_8") + self.assertIsNotNone(colormap) + + def testCursor(self): + color = colors.cursorColorForColormap("test_8") + self.assertEqual(color, 'blue') + + def testLut(self): + colormap = Colormap("test_8") + colors = colormap.getNColors(8) + self.assertEquals(len(colors), 8) + + def testUint8(self): + lut = numpy.array([[255, 0, 0], [200, 0, 0], [150, 0, 0]], dtype="uint") + colors.registerLUT("test_type", lut) + colormap = colors.Colormap(name="test_type") + lut = colormap.getNColors(3) + self.assertEqual(lut.shape, (3, 4)) + self.assertEqual(lut[0, 0], 255) + + def testFloatRGB(self): + lut = numpy.array([[1.0, 0, 0], [0.5, 0, 0], [0, 0, 0]], dtype="float") + colors.registerLUT("test_type", lut) + colormap = colors.Colormap(name="test_type") + lut = colormap.getNColors(3) + self.assertEqual(lut.shape, (3, 4)) + self.assertEqual(lut[0, 0], 255) + + def testFloatRGBA(self): + lut = numpy.array([[1.0, 0, 0, 128 / 256.0], [0.5, 0, 0, 1.0], [0.0, 0, 0, 1.0]], dtype="float") + colors.registerLUT("test_type", lut) + colormap = colors.Colormap(name="test_type") + lut = colormap.getNColors(3) + self.assertEqual(lut.shape, (3, 4)) + self.assertEqual(lut[0, 0], 255) + self.assertEqual(lut[0, 3], 128) def suite(): @@ -410,6 +481,7 @@ def suite(): test_suite.addTest(loadTests(TestDictAPI)) test_suite.addTest(loadTests(TestObjectAPI)) test_suite.addTest(loadTests(TestPreferredColormaps)) + test_suite.addTest(loadTests(TestRegisteredLut)) return test_suite diff --git a/silx/gui/utils/concurrent.py b/silx/gui/utils/concurrent.py index 48fff91..c27374f 100644 --- a/silx/gui/utils/concurrent.py +++ b/silx/gui/utils/concurrent.py @@ -25,12 +25,14 @@ """This module allows to run a function in Qt main thread from another thread """ +from __future__ import absolute_import + __authors__ = ["T. Vincent"] __license__ = "MIT" __date__ = "09/03/2018" -from silx.third_party.concurrent_futures import Future +from concurrent.futures import Future from .. import qt diff --git a/silx/gui/utils/projecturl.py b/silx/gui/utils/projecturl.py new file mode 100644 index 0000000..0832c2e --- /dev/null +++ b/silx/gui/utils/projecturl.py @@ -0,0 +1,77 @@ +# coding: utf-8 +# +# Project: Azimuthal integration +# https://github.com/silx-kit/silx +# +# Copyright (C) 2015-2019 European Synchrotron Radiation Facility, Grenoble, France +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. + +from __future__ import absolute_import, print_function, division + +"""Provide convenient URL for silx-kit projects.""" + +__author__ = "Valentin Valls" +__contact__ = "valentin.valls@ESRF.eu" +__license__ = "MIT" +__copyright__ = "European Synchrotron Radiation Facility, Grenoble, France" +__date__ = "15/01/2019" + + +from ... import _version as version + +BASE_DOC_URL = None +"""This could be patched by project packagers.""" + +_DEFAULT_BASE_DOC_URL = "http://www.silx.org/pub/doc/silx/{silx_doc_version}/{subpath}" +"""Identify the base URL of the project documentation. + +It supportes string replacement: + +- `{major}` the major version +- `{minor}` the minor version +- `{micro}` the micro version +- `{relev}` the status of the version (dev, final, rc). +- `{silx_doc_version}` is used to map the documentation stored at www.silx.org +- `{subpath}` is the subpart of the URL pointing to a specific page of the + documentation. It is mandatory. +""" + + +def getDocumentationUrl(subpath): + """Returns the URL to the documentation""" + + if version.RELEV == "final": + # Released verison will point to a specific documentation + silx_doc_version = "%d.%d.%d" % (version.MAJOR, version.MINOR, version.MICRO) + else: + # Dev versions will point to a single 'dev' documentation + silx_doc_version = "dev" + + keyworks = { + "silx_doc_version": silx_doc_version, + "major": version.MAJOR, + "minor": version.MINOR, + "micro": version.MICRO, + "relev": version.RELEV, + "subpath": subpath} + template = BASE_DOC_URL + if template is None: + template = _DEFAULT_BASE_DOC_URL + return template.format(**keyworks) diff --git a/silx/gui/utils/test/test_async.py b/silx/gui/utils/test/test_async.py index dabfb3c..dcfde1d 100644 --- a/silx/gui/utils/test/test_async.py +++ b/silx/gui/utils/test/test_async.py @@ -24,6 +24,8 @@ # ###########################################################################*/ """Test of async module.""" +from __future__ import absolute_import + __authors__ = ["T. Vincent"] __license__ = "MIT" __date__ = "09/03/2018" @@ -33,7 +35,7 @@ import threading import unittest -from silx.third_party.concurrent_futures import wait +from concurrent.futures import wait from silx.gui import qt from silx.gui.utils.testutils import TestCaseQt diff --git a/silx/gui/utils/testutils.py b/silx/gui/utils/testutils.py index 35085fc..6c54357 100644 --- a/silx/gui/utils/testutils.py +++ b/silx/gui/utils/testutils.py @@ -40,6 +40,8 @@ import os _logger = logging.getLogger(__name__) from silx.gui import qt +from silx.gui.qt import inspect as _inspect + if qt.BINDING == 'PySide': from PySide.QtTest import QTest @@ -139,11 +141,6 @@ class TestCaseQt(unittest.TestCase): # Makes sure a QApplication exists and do it once for all _qapp = qt.QApplication.instance() or qt.QApplication([]) - # Makes sure QDesktopWidget is init - # Otherwise it happens randomly during the tests - cls._desktopWidget = _qapp.desktop() - _qapp.processEvents() - @classmethod def tearDownClass(cls): sys.excepthook = cls._oldExceptionHook @@ -173,7 +170,8 @@ class TestCaseQt(unittest.TestCase): gc.collect() widgets = [widget for widget in self.qapp.allWidgets() - if widget not in self.__previousWidgets] + if (widget not in self.__previousWidgets and + _inspect.createdByPython(widget))] del self.__previousWidgets if qt.BINDING in ('PySide', 'PySide2'): diff --git a/silx/gui/widgets/PrintPreview.py b/silx/gui/widgets/PrintPreview.py index 94a8ed4..96af34b 100644 --- a/silx/gui/widgets/PrintPreview.py +++ b/silx/gui/widgets/PrintPreview.py @@ -285,7 +285,7 @@ class PrintPreviewDialog(qt.QDialog): def addSvgItem(self, item, title=None, comment=None, commentPosition=None, - viewBox=None): + viewBox=None, keepRatio=True): """Add a SVG item to the scene. :param QSvgRenderer item: SVG item to be added to the scene. @@ -295,6 +295,8 @@ class PrintPreviewDialog(qt.QDialog): :param QRectF viewBox: Bounding box for the item on the print page (xOffset, yOffset, width, height). If None, use original item size. + :param bool keepRatio: If True, resizing the item will preserve its + original aspect ratio. """ if not qt.HAS_SVG: raise RuntimeError("Missing QtSvg library.") @@ -331,35 +333,23 @@ class PrintPreviewDialog(qt.QDialog): svgItem.setFlag(qt.QGraphicsItem.ItemIsMovable, True) svgItem.setFlag(qt.QGraphicsItem.ItemIsFocusable, False) - rectItemResizeRect = _GraphicsResizeRectItem(svgItem, self.scene) + rectItemResizeRect = _GraphicsResizeRectItem(svgItem, self.scene, + keepratio=keepRatio) rectItemResizeRect.setZValue(2) self._svgItems.append(item) - if qt.qVersion() < '5.0': - textItem = qt.QGraphicsTextItem(title, svgItem, self.scene) - else: - textItem = qt.QGraphicsTextItem(title, svgItem) - textItem.setTextInteractionFlags(qt.Qt.TextEditorInteraction) - title_offset = 0.5 * textItem.boundingRect().width() - textItem.setZValue(1) - textItem.setFlag(qt.QGraphicsItem.ItemIsMovable, True) - + # Comment / legend dummyComment = 80 * "1" if qt.qVersion() < '5.0': commentItem = qt.QGraphicsTextItem(dummyComment, svgItem, self.scene) else: commentItem = qt.QGraphicsTextItem(dummyComment, svgItem) commentItem.setTextInteractionFlags(qt.Qt.TextEditorInteraction) + # we scale the text to have the legend box have the same width as the graph scaleCalculationRect = qt.QRectF(commentItem.boundingRect()) scale = svgItem.boundingRect().width() / scaleCalculationRect.width() - comment_offset = 0.5 * commentItem.boundingRect().width() - if commentPosition.upper() == "LEFT": - x = 1 - else: - x = 0.5 * svgItem.boundingRect().width() - comment_offset * scale # fixme: centering - commentItem.moveBy(svgItem.boundingRect().x() + x, - svgItem.boundingRect().y() + svgItem.boundingRect().height()) + commentItem.setPlainText(comment) commentItem.setZValue(1) @@ -367,17 +357,46 @@ class PrintPreviewDialog(qt.QDialog): if qt.qVersion() < "5.0": commentItem.scale(scale, scale) else: - # the correct equivalent would be: - # rectItem.setTransform(qt.QTransform.fromScale(scalex, scaley)) commentItem.setScale(scale) + + # align + if commentPosition.upper() == "CENTER": + alignment = qt.Qt.AlignCenter + elif commentPosition.upper() == "RIGHT": + alignment = qt.Qt.AlignRight + else: + alignment = qt.Qt.AlignLeft + commentItem.setTextWidth(commentItem.boundingRect().width()) + center_format = qt.QTextBlockFormat() + center_format.setAlignment(alignment) + cursor = commentItem.textCursor() + cursor.select(qt.QTextCursor.Document) + cursor.mergeBlockFormat(center_format) + cursor.clearSelection() + commentItem.setTextCursor(cursor) + if alignment == qt.Qt.AlignLeft: + deltax = 0 + else: + deltax = (svgItem.boundingRect().width() - commentItem.boundingRect().width()) / 2. + commentItem.moveBy(svgItem.boundingRect().x() + deltax, + svgItem.boundingRect().y() + svgItem.boundingRect().height()) + + # Title + if qt.qVersion() < '5.0': + textItem = qt.QGraphicsTextItem(title, svgItem, self.scene) + else: + textItem = qt.QGraphicsTextItem(title, svgItem) + textItem.setTextInteractionFlags(qt.Qt.TextEditorInteraction) + textItem.setZValue(1) + textItem.setFlag(qt.QGraphicsItem.ItemIsMovable, True) + + title_offset = 0.5 * textItem.boundingRect().width() textItem.moveBy(svgItem.boundingRect().x() + 0.5 * svgItem.boundingRect().width() - title_offset * scale, svgItem.boundingRect().y()) if qt.qVersion() < "5.0": textItem.scale(scale, scale) else: - # the correct equivalent would be: - # rectItem.setTransform(qt.QTransform.fromScale(scalex, scaley)) textItem.setScale(scale) def setup(self): @@ -601,7 +620,8 @@ class _GraphicsResizeRectItem(qt.QGraphicsRectItem): # following line prevents dragging along the previously selected # item when resizing another one scene.clearSelection() - rect = parent.rect() + + rect = parent.boundingRect() self._x = rect.x() self._y = rect.y() self._w = rect.width() @@ -655,12 +675,14 @@ class _GraphicsResizeRectItem(qt.QGraphicsRectItem): else: scalex = self._newRect.rect().width() / self._w scaley = self._newRect.rect().height() / self._h + if qt.qVersion() < "5.0": parent.scale(scalex, scaley) else: - # the correct equivalent would be: - # rectItem.setTransform(qt.QTransform.fromScale(scalex, scaley)) - parent.setScale(scalex) + # apply the scale to the previous transformation matrix + previousTransform = parent.transform() + parent.setTransform( + previousTransform.scale(scalex, scaley)) self.scene().removeItem(self._newRect) self._newRect = None diff --git a/silx/gui/widgets/RangeSlider.py b/silx/gui/widgets/RangeSlider.py index 0b72e71..0cf195c 100644 --- a/silx/gui/widgets/RangeSlider.py +++ b/silx/gui/widgets/RangeSlider.py @@ -31,7 +31,7 @@ from __future__ import absolute_import, division __authors__ = ["D. Naudet", "T. Vincent"] __license__ = "MIT" -__date__ = "02/08/2018" +__date__ = "26/11/2018" import numpy as numpy @@ -40,6 +40,17 @@ from silx.gui import qt, icons, colors from silx.gui.utils.image import convertArrayToQImage +class StyleOptionRangeSlider(qt.QStyleOption): + def __init__(self): + super(StyleOptionRangeSlider, self).__init__() + self.minimum = None + self.maximum = None + self.sliderPosition1 = None + self.sliderPosition2 = None + self.handlerRect1 = None + self.handlerRect2 = None + + class RangeSlider(qt.QWidget): """Range slider with 2 thumbs and an optional colored groove. @@ -86,6 +97,8 @@ class RangeSlider(qt.QWidget): self.__secondValue = 1. self.__minValue = 0. self.__maxValue = 1. + self.__hoverRect = qt.QRect() + self.__hoverControl = None self.__focus = None self.__moving = None @@ -100,6 +113,7 @@ class RangeSlider(qt.QWidget): super(RangeSlider, self).__init__(parent) self.setFocusPolicy(qt.Qt.ClickFocus) + self.setAttribute(qt.Qt.WA_Hover) self.setMinimumSize(qt.QSize(50, 20)) self.setMaximumHeight(20) @@ -107,6 +121,34 @@ class RangeSlider(qt.QWidget): # Broadcast value changed signal self.sigValueChanged.connect(self.__emitPositionChanged) + def event(self, event): + t = event.type() + if t == qt.QEvent.HoverEnter or t == qt.QEvent.HoverLeave or t == qt.QEvent.HoverMove: + return self.__updateHoverControl(event.pos()) + else: + return super(RangeSlider, self).event(event) + + def __updateHoverControl(self, pos): + hoverControl, hoverRect = self.__findHoverControl(pos) + if hoverControl != self.__hoverControl: + self.update(self.__hoverRect) + self.update(hoverRect) + self.__hoverControl = hoverControl + self.__hoverRect = hoverRect + return True + return hoverControl is not None + + def __findHoverControl(self, pos): + """Returns the control at the position and it's rect location""" + for name in ["first", "second"]: + rect = self.__sliderRect(name) + if rect.contains(pos): + return name, rect + rect = self.__drawArea() + if rect.contains(pos): + return "groove", rect + return None, qt.QRect() + # Position <-> Value conversion def __positionToValue(self, position): @@ -469,10 +511,12 @@ class RangeSlider(qt.QWidget): super(RangeSlider, self).mouseMoveEvent(event) if self.__moving is not None: - position = self.__xPixelToPosition(event.pos().x()) + delta = self._SLIDER_WIDTH // 2 if self.__moving == 'first': + position = self.__xPixelToPosition(event.pos().x() + delta) self.setFirstPosition(position) else: + position = self.__xPixelToPosition(event.pos().x() - delta) self.setSecondPosition(position) def mouseReleaseEvent(self, event): @@ -564,13 +608,13 @@ class RangeSlider(qt.QWidget): def __sliderAreaRect(self): return self.__drawArea().adjusted(self._SLIDER_WIDTH / 2., 0, - -self._SLIDER_WIDTH / 2., + -self._SLIDER_WIDTH / 2. + 1, 0) def __pixMapRect(self): return self.__sliderAreaRect().adjusted(0, self._PIXMAP_VOFFSET, - 0, + -1, -self._PIXMAP_VOFFSET) def paintEvent(self, event): @@ -579,33 +623,55 @@ class RangeSlider(qt.QWidget): style = qt.QApplication.style() area = self.__drawArea() - pixmapRect = self.__pixMapRect() - - option = qt.QStyleOptionProgressBar() - option.initFrom(self) - option.rect = area - option.state = ((self.isEnabled() and qt.QStyle.State_Enabled) - or qt.QStyle.State_None) - style.drawControl(qt.QStyle.CE_ProgressBarGroove, - option, - painter, - self) + if self.__pixmap is not None: + pixmapRect = self.__pixMapRect() - painter.save() - pen = painter.pen() - pen.setWidth(1) - pen.setColor(qt.Qt.black if self.isEnabled() else qt.Qt.gray) - painter.setPen(pen) - painter.drawRect(pixmapRect.adjusted(-1, -1, 1, 1)) - painter.restore() + option = qt.QStyleOptionProgressBar() + option.initFrom(self) + option.rect = area + option.state = (qt.QStyle.State_Enabled if self.isEnabled() + else qt.QStyle.State_None) + style.drawControl(qt.QStyle.CE_ProgressBarGroove, + option, + painter, + self) + + painter.save() + pen = painter.pen() + pen.setWidth(1) + pen.setColor(qt.Qt.black if self.isEnabled() else qt.Qt.gray) + painter.setPen(pen) + painter.drawRect(pixmapRect.adjusted(-1, -1, 0, 1)) + painter.restore() + + if self.isEnabled(): + rect = area.adjusted(self._SLIDER_WIDTH // 2, + self._PIXMAP_VOFFSET, + -self._SLIDER_WIDTH // 2, + -self._PIXMAP_VOFFSET + 1) + painter.drawPixmap(rect, + self.__pixmap, + self.__pixmap.rect()) + else: + option = StyleOptionRangeSlider() + option.initFrom(self) + option.rect = area + option.sliderPosition1 = self.__firstValue + option.sliderPosition2 = self.__secondValue + option.handlerRect1 = self.__sliderRect("first") + option.handlerRect2 = self.__sliderRect("second") + option.minimum = self.__minValue + option.maximum = self.__maxValue + option.state = (qt.QStyle.State_Enabled if self.isEnabled() + else qt.QStyle.State_None) + if self.__hoverControl == "groove": + option.state |= qt.QStyle.State_MouseOver + elif option.state & qt.QStyle.State_MouseOver: + option.state ^= qt.QStyle.State_MouseOver + self.drawRangeSliderBackground(painter, option, self) - if self.isEnabled() and self.__pixmap is not None: - painter.drawPixmap(area.adjusted(self._SLIDER_WIDTH / 2, - self._PIXMAP_VOFFSET, - -self._SLIDER_WIDTH / 2 + 1, - -self._PIXMAP_VOFFSET + 1), - self.__pixmap, - self.__pixmap.rect()) + # Avoid glitch when moving handles + hoverControl = self.__moving or self.__hoverControl for name in ('first', 'second'): rect = self.__sliderRect(name) @@ -613,7 +679,9 @@ class RangeSlider(qt.QWidget): option.initFrom(self) option.icon = self.__icons[name] option.iconSize = rect.size() * 0.7 - if option.state & qt.QStyle.State_MouseOver: + if hoverControl == name: + option.state |= qt.QStyle.State_MouseOver + elif option.state & qt.QStyle.State_MouseOver: option.state ^= qt.QStyle.State_MouseOver if self.__focus == name: option.state |= qt.QStyle.State_HasFocus @@ -625,3 +693,73 @@ class RangeSlider(qt.QWidget): def sizeHint(self): return qt.QSize(200, self.minimumHeight()) + + @classmethod + def drawRangeSliderBackground(cls, painter, option, widget): + """Draw the background of the RangeSlider widget into the painter. + + :param qt.QPainter painter: A painter + :param StyleOptionRangeSlider option: Options to draw the widget + :param qt.QWidget: The widget which have to be drawn + """ + painter.save() + painter.translate(0.5, 0.5) + + backgroundRect = qt.QRect(option.rect) + if backgroundRect.height() > 8: + center = backgroundRect.center() + backgroundRect.setHeight(8) + backgroundRect.moveCenter(center) + + selectedRangeRect = qt.QRect(backgroundRect) + selectedRangeRect.setLeft(option.handlerRect1.center().x()) + selectedRangeRect.setRight(option.handlerRect2.center().x()) + + highlight = option.palette.color(qt.QPalette.Highlight) + activeHighlight = highlight + selectedOutline = option.palette.color(qt.QPalette.Highlight) + + buttonColor = option.palette.button().color() + val = qt.qGray(buttonColor.rgb()) + buttonColor = buttonColor.lighter(100 + max(1, (180 - val) // 6)) + buttonColor.setHsv(buttonColor.hue(), buttonColor.saturation() * 0.75, buttonColor.value()) + + grooveColor = qt.QColor() + grooveColor.setHsv(buttonColor.hue(), + min(255, (int)(buttonColor.saturation())), + min(255, (int)(buttonColor.value() * 0.9))) + + selectedInnerContrastLine = qt.QColor(255, 255, 255, 30) + + outline = option.palette.color(qt.QPalette.Background).darker(140) + if (option.state & qt.QStyle.State_HasFocus and option.state & qt.QStyle.State_KeyboardFocusChange): + outline = highlight.darker(125) + if outline.value() > 160: + outline.setHsl(highlight.hue(), highlight.saturation(), 160) + + # Draw background groove + painter.setRenderHint(qt.QPainter.Antialiasing, True) + gradient = qt.QLinearGradient() + gradient.setStart(backgroundRect.center().x(), backgroundRect.top()) + gradient.setFinalStop(backgroundRect.center().x(), backgroundRect.bottom()) + painter.setPen(qt.QPen(outline)) + gradient.setColorAt(0, grooveColor.darker(110)) + gradient.setColorAt(1, grooveColor.lighter(110)) + painter.setBrush(gradient) + painter.drawRoundedRect(backgroundRect.adjusted(1, 1, -2, -2), 1, 1) + + # Draw slider background for the value + gradient = qt.QLinearGradient() + gradient.setStart(selectedRangeRect.center().x(), selectedRangeRect.top()) + gradient.setFinalStop(selectedRangeRect.center().x(), selectedRangeRect.bottom()) + painter.setRenderHint(qt.QPainter.Antialiasing, True) + painter.setPen(qt.QPen(selectedOutline)) + gradient.setColorAt(0, activeHighlight) + gradient.setColorAt(1, activeHighlight.lighter(130)) + painter.setBrush(gradient) + painter.drawRoundedRect(selectedRangeRect.adjusted(1, 1, -2, -2), 1, 1) + painter.setPen(selectedInnerContrastLine) + painter.setBrush(qt.Qt.NoBrush) + painter.drawRoundedRect(selectedRangeRect.adjusted(2, 2, -3, -3), 1, 1) + + painter.restore() diff --git a/silx/gui/widgets/UrlSelectionTable.py b/silx/gui/widgets/UrlSelectionTable.py new file mode 100644 index 0000000..4ac0381 --- /dev/null +++ b/silx/gui/widgets/UrlSelectionTable.py @@ -0,0 +1,164 @@ +# /*########################################################################## +# Copyright (C) 2017 European Synchrotron Radiation Facility +# +# This file is part of the PyMca X-ray Fluorescence Toolkit developed at +# the ESRF by the Software group. +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. +# +#############################################################################*/ +"""Some widget construction to check if a sample moved""" + +__author__ = ["H. Payno"] +__license__ = "MIT" +__date__ = "19/03/2018" + +from silx.gui import qt +from collections import OrderedDict +from silx.gui.widgets.TableWidget import TableWidget +from silx.io.url import DataUrl +import functools +import logging +import os + +logger = logging.getLogger(__file__) + + +class UrlSelectionTable(TableWidget): + """Table used to select the color channel to be displayed for each""" + + COLUMS_INDEX = OrderedDict([ + ('url', 0), + ('img A', 1), + ('img B', 2), + ]) + + sigImageAChanged = qt.Signal(str) + """Signal emitted when the image A change. Param is the image url path""" + + sigImageBChanged = qt.Signal(str) + """Signal emitted when the image B change. Param is the image url path""" + + def __init__(self, parent=None): + TableWidget.__init__(self, parent) + self.clear() + + def clear(self): + qt.QTableWidget.clear(self) + self.setRowCount(0) + self.setColumnCount(len(self.COLUMS_INDEX)) + self.setHorizontalHeaderLabels(list(self.COLUMS_INDEX.keys())) + self.verticalHeader().hide() + if hasattr(self.horizontalHeader(), 'setSectionResizeMode'): # Qt5 + self.horizontalHeader().setSectionResizeMode(0, + qt.QHeaderView.Stretch) + else: # Qt4 + self.horizontalHeader().setResizeMode(0, qt.QHeaderView.Stretch) + + self.setSortingEnabled(True) + self._checkBoxes = {} + + def addUrl(self, url, **kwargs): + """ + + :param url: + :param args: + :return: index of the created items row + :rtype int + """ + assert isinstance(url, DataUrl) + row = self.rowCount() + self.setRowCount(row + 1) + + _item = qt.QTableWidgetItem() + _item.setText(os.path.basename(url.path())) + _item.setFlags(qt.Qt.ItemIsEnabled | qt.Qt.ItemIsSelectable) + self.setItem(row, self.COLUMS_INDEX['url'], _item) + + widgetImgA = qt.QRadioButton(parent=self) + widgetImgA.setAutoExclusive(False) + self.setCellWidget(row, self.COLUMS_INDEX['img A'], widgetImgA) + callbackImgA = functools.partial(self._activeImgAChanged, url.path()) + widgetImgA.toggled.connect(callbackImgA) + + widgetImgB = qt.QRadioButton(parent=self) + widgetImgA.setAutoExclusive(False) + self.setCellWidget(row, self.COLUMS_INDEX['img B'], widgetImgB) + callbackImgB = functools.partial(self._activeImgBChanged, url.path()) + widgetImgB.toggled.connect(callbackImgB) + + self._checkBoxes[url.path()] = {'img A': widgetImgA, + 'img B': widgetImgB} + self.resizeColumnsToContents() + return row + + def _activeImgAChanged(self, name): + self._updatecheckBoxes('img A', name) + self.sigImageAChanged.emit(name) + + def _activeImgBChanged(self, name): + self._updatecheckBoxes('img B', name) + self.sigImageBChanged.emit(name) + + def _updatecheckBoxes(self, whichImg, name): + assert name in self._checkBoxes + assert whichImg in self._checkBoxes[name] + if self._checkBoxes[name][whichImg].isChecked(): + for radioUrl in self._checkBoxes: + if radioUrl != name: + self._checkBoxes[radioUrl][whichImg].blockSignals(True) + self._checkBoxes[radioUrl][whichImg].setChecked(False) + self._checkBoxes[radioUrl][whichImg].blockSignals(False) + + def getSelection(self): + """ + + :return: url selected for img A and img B. + """ + imgA = imgB = None + for radioUrl in self._checkBoxes: + if self._checkBoxes[radioUrl]['img A'].isChecked(): + imgA = radioUrl + if self._checkBoxes[radioUrl]['img B'].isChecked(): + imgB = radioUrl + return imgA, imgB + + def setSelection(self, url_img_a, url_img_b): + """ + + :param ddict: key: image url, values: list of active channels + """ + for radioUrl in self._checkBoxes: + for img in ('img A', 'img B'): + self._checkBoxes[radioUrl][img].blockSignals(True) + self._checkBoxes[radioUrl][img].setChecked(False) + self._checkBoxes[radioUrl][img].blockSignals(False) + + self._checkBoxes[radioUrl][img].blockSignals(True) + self._checkBoxes[url_img_a]['img A'].setChecked(True) + self._checkBoxes[radioUrl][img].blockSignals(False) + + self._checkBoxes[radioUrl][img].blockSignals(True) + self._checkBoxes[url_img_b]['img B'].setChecked(True) + self._checkBoxes[radioUrl][img].blockSignals(False) + self.sigImageAChanged.emit(url_img_a) + self.sigImageBChanged.emit(url_img_b) + + def removeUrl(self, url): + raise NotImplementedError("") -- cgit v1.2.3