summaryrefslogtreecommitdiff
path: root/silx/gui
diff options
context:
space:
mode:
authorPicca Frédéric-Emmanuel <picca@synchrotron-soleil.fr>2019-05-28 08:16:16 +0200
committerPicca Frédéric-Emmanuel <picca@synchrotron-soleil.fr>2019-05-28 08:16:16 +0200
commita763e5d1b3921b3194f3d4e94ab9de3fbe08bbdd (patch)
tree45d462ed36a5522e9f3b9fde6c4ec4918c2ae8e3 /silx/gui
parentcebdc9244c019224846cb8d2668080fe386a6adc (diff)
New upstream version 0.10.1+dfsg
Diffstat (limited to 'silx/gui')
-rw-r--r--silx/gui/colors.py454
-rw-r--r--silx/gui/data/DataViewer.py92
-rw-r--r--silx/gui/data/DataViewerFrame.py5
-rw-r--r--silx/gui/data/DataViewerSelector.py6
-rw-r--r--silx/gui/data/DataViews.py499
-rw-r--r--silx/gui/data/Hdf5TableView.py13
-rw-r--r--silx/gui/data/HexaTableView.py8
-rw-r--r--silx/gui/data/NXdataWidgets.py439
-rw-r--r--silx/gui/data/TextFormatter.py50
-rw-r--r--silx/gui/data/test/test_arraywidget.py6
-rw-r--r--silx/gui/data/test/test_dataviewer.py14
-rw-r--r--silx/gui/data/test/test_numpyaxesselector.py7
-rw-r--r--silx/gui/data/test/test_textformatter.py12
-rw-r--r--silx/gui/dialog/AbstractDataFileDialog.py56
-rw-r--r--silx/gui/dialog/ColormapDialog.py355
-rw-r--r--silx/gui/dialog/DataFileDialog.py10
-rw-r--r--silx/gui/dialog/FileTypeComboBox.py39
-rw-r--r--silx/gui/dialog/ImageFileDialog.py5
-rw-r--r--silx/gui/dialog/SafeFileSystemModel.py6
-rw-r--r--silx/gui/dialog/test/test_colormapdialog.py36
-rw-r--r--silx/gui/dialog/test/test_datafiledialog.py115
-rw-r--r--silx/gui/dialog/test/test_imagefiledialog.py117
-rw-r--r--silx/gui/dialog/utils.py6
-rw-r--r--silx/gui/hdf5/Hdf5Formatter.py17
-rw-r--r--silx/gui/hdf5/Hdf5Item.py38
-rw-r--r--silx/gui/hdf5/Hdf5TreeModel.py41
-rw-r--r--silx/gui/hdf5/NexusSortFilterProxyModel.py4
-rw-r--r--silx/gui/hdf5/_utils.py74
-rw-r--r--silx/gui/hdf5/test/test_hdf5.py59
-rw-r--r--silx/gui/icons.py8
-rw-r--r--silx/gui/plot/ColorBar.py19
-rw-r--r--silx/gui/plot/CompareImages.py2
-rw-r--r--silx/gui/plot/ComplexImageView.py11
-rw-r--r--silx/gui/plot/CurvesROIWidget.py1834
-rw-r--r--silx/gui/plot/MaskToolsWidget.py108
-rw-r--r--silx/gui/plot/PlotInteraction.py148
-rw-r--r--silx/gui/plot/PlotToolButtons.py29
-rw-r--r--silx/gui/plot/PlotWidget.py492
-rw-r--r--silx/gui/plot/PlotWindow.py87
-rw-r--r--silx/gui/plot/PrintPreviewToolButton.py61
-rw-r--r--silx/gui/plot/Profile.py109
-rw-r--r--silx/gui/plot/ScatterMaskToolsWidget.py65
-rw-r--r--silx/gui/plot/ScatterView.py12
-rw-r--r--silx/gui/plot/StackView.py10
-rw-r--r--silx/gui/plot/StatsWidget.py1594
-rw-r--r--silx/gui/plot/_BaseMaskToolsWidget.py157
-rw-r--r--silx/gui/plot/_utils/dtime_ticklayout.py4
-rw-r--r--silx/gui/plot/actions/control.py11
-rw-r--r--silx/gui/plot/actions/io.py38
-rw-r--r--silx/gui/plot/backends/BackendBase.py37
-rw-r--r--silx/gui/plot/backends/BackendMatplotlib.py223
-rw-r--r--silx/gui/plot/backends/BackendOpenGL.py364
-rw-r--r--silx/gui/plot/backends/glutils/GLPlotCurve.py119
-rw-r--r--silx/gui/plot/backends/glutils/GLPlotFrame.py124
-rw-r--r--silx/gui/plot/backends/glutils/GLSupport.py63
-rw-r--r--silx/gui/plot/items/__init__.py2
-rw-r--r--silx/gui/plot/items/axis.py6
-rw-r--r--silx/gui/plot/items/complex.py8
-rw-r--r--silx/gui/plot/items/core.py37
-rw-r--r--silx/gui/plot/items/curve.py5
-rw-r--r--silx/gui/plot/items/histogram.py6
-rw-r--r--silx/gui/plot/items/roi.py72
-rw-r--r--silx/gui/plot/items/scatter.py5
-rw-r--r--silx/gui/plot/items/shape.py45
-rw-r--r--silx/gui/plot/matplotlib/Colormap.py16
-rw-r--r--silx/gui/plot/stats/stats.py400
-rw-r--r--silx/gui/plot/stats/statshandler.py124
-rw-r--r--silx/gui/plot/test/testCurvesROIWidget.py219
-rw-r--r--silx/gui/plot/test/testMaskToolsWidget.py7
-rw-r--r--silx/gui/plot/test/testPlotWidget.py61
-rw-r--r--silx/gui/plot/test/testSaveAction.py20
-rw-r--r--silx/gui/plot/test/testScatterMaskToolsWidget.py5
-rw-r--r--silx/gui/plot/test/testStats.py284
-rw-r--r--silx/gui/plot/test/testUtilsAxis.py49
-rw-r--r--silx/gui/plot/tools/roi.py20
-rw-r--r--silx/gui/plot/tools/test/testScatterProfileToolBar.py2
-rw-r--r--silx/gui/plot/utils/axis.py288
-rw-r--r--silx/gui/plot3d/ParamTreeView.py2
-rw-r--r--silx/gui/plot3d/ScalarFieldView.py21
-rw-r--r--silx/gui/plot3d/SceneWidget.py30
-rw-r--r--silx/gui/plot3d/_model/items.py67
-rw-r--r--silx/gui/plot3d/items/__init__.py4
-rw-r--r--silx/gui/plot3d/items/core.py4
-rw-r--r--silx/gui/plot3d/items/mesh.py281
-rw-r--r--silx/gui/plot3d/items/mixins.py21
-rw-r--r--silx/gui/plot3d/items/scatter.py39
-rw-r--r--silx/gui/plot3d/items/volume.py12
-rw-r--r--silx/gui/plot3d/scene/primitives.py8
-rw-r--r--silx/gui/plot3d/test/__init__.py4
-rw-r--r--silx/gui/plot3d/test/testSceneWidgetPicking.py53
-rw-r--r--silx/gui/plot3d/test/testStatsWidget.py213
-rw-r--r--silx/gui/qt/_pyside_dynamic.py54
-rw-r--r--silx/gui/test/test_colors.py110
-rw-r--r--silx/gui/utils/concurrent.py4
-rw-r--r--silx/gui/utils/projecturl.py77
-rw-r--r--silx/gui/utils/test/test_async.py4
-rw-r--r--silx/gui/utils/testutils.py10
-rw-r--r--silx/gui/widgets/PrintPreview.py74
-rw-r--r--silx/gui/widgets/RangeSlider.py198
-rw-r--r--silx/gui/widgets/UrlSelectionTable.py164
100 files changed, 7828 insertions, 3619 deletions
diff --git a/silx/gui/colors.py b/silx/gui/colors.py
index a51bcdc..f1f34c9 100644
--- a/silx/gui/colors.py
+++ b/silx/gui/colors.py
@@ -29,18 +29,28 @@ from __future__ import absolute_import
__authors__ = ["T. Vincent", "H.Payno"]
__license__ = "MIT"
-__date__ = "05/10/2018"
+__date__ = "29/01/2019"
-from silx.gui import qt
-import copy as copy_mdl
import numpy
import logging
+import collections
+from silx.gui import qt
+from silx import config
from silx.math.combo import min_max
from silx.math.colormap import cmap as _cmap
from silx.utils.exceptions import NotEditableError
+from silx.utils import deprecation
+from silx.resources import resource_filename as _resource_filename
+
_logger = logging.getLogger(__file__)
+try:
+ from matplotlib import cm as _matplotlib_cm
+except ImportError:
+ _logger.info("matplotlib not available, only embedded colormaps available")
+ _matplotlib_cm = None
+
_COLORDICT = {}
"""Dictionary of common colors."""
@@ -67,12 +77,44 @@ _COLORDICT['darkBrown'] = '#660000'
_COLORDICT['darkCyan'] = '#008080'
_COLORDICT['darkYellow'] = '#808000'
_COLORDICT['darkMagenta'] = '#800080'
+_COLORDICT['transparent'] = '#00000000'
# FIXME: It could be nice to expose a functional API instead of that attribute
COLORDICT = _COLORDICT
+_LUT_DESCRIPTION = collections.namedtuple("_LUT_DESCRIPTION", ["source", "cursor_color", "preferred"])
+"""Description of a LUT for internal purpose."""
+
+
+_AVAILABLE_LUTS = collections.OrderedDict([
+ ('gray', _LUT_DESCRIPTION('builtin', 'pink', True)),
+ ('reversed gray', _LUT_DESCRIPTION('builtin', 'pink', True)),
+ ('temperature', _LUT_DESCRIPTION('builtin', 'pink', True)),
+ ('red', _LUT_DESCRIPTION('builtin', 'green', True)),
+ ('green', _LUT_DESCRIPTION('builtin', 'pink', True)),
+ ('blue', _LUT_DESCRIPTION('builtin', 'yellow', True)),
+ ('jet', _LUT_DESCRIPTION('matplotlib', 'pink', True)),
+ ('viridis', _LUT_DESCRIPTION('resource', 'pink', True)),
+ ('magma', _LUT_DESCRIPTION('resource', 'green', True)),
+ ('inferno', _LUT_DESCRIPTION('resource', 'green', True)),
+ ('plasma', _LUT_DESCRIPTION('resource', 'green', True)),
+ ('hsv', _LUT_DESCRIPTION('matplotlib', 'black', True)),
+])
+"""Description for internal porpose of all the default LUT provided by the library."""
+
+
+DEFAULT_MIN_LIN = 0
+"""Default min value if in linear normalization"""
+DEFAULT_MAX_LIN = 1
+"""Default max value if in linear normalization"""
+DEFAULT_MIN_LOG = 1
+"""Default min value if in log normalization"""
+DEFAULT_MAX_LOG = 10
+"""Default max value if in log normalization"""
+
+
def rgba(color, colorDict=None):
"""Convert color code '#RRGGBB' and '#RRGGBBAA' to (R, G, B, A)
@@ -121,19 +163,21 @@ def rgba(color, colorDict=None):
return r, g, b, a
-_COLORMAP_CURSOR_COLORS = {
- 'gray': 'pink',
- 'reversed gray': 'pink',
- 'temperature': 'pink',
- 'red': 'green',
- 'green': 'pink',
- 'blue': 'yellow',
- 'jet': 'pink',
- 'viridis': 'pink',
- 'magma': 'green',
- 'inferno': 'green',
- 'plasma': 'green',
-}
+def greyed(color, colorDict=None):
+ """Convert color code '#RRGGBB' and '#RRGGBBAA' to a grey color
+ (R, G, B, A).
+
+ It also convert RGB(A) values from uint8 to float in [0, 1] and
+ accept a QColor as color argument.
+
+ :param str color: The color to convert
+ :param dict colorDict: A dictionary of color name conversion to color code
+ :returns: RGBA colors as floats in [0., 1.]
+ :rtype: tuple
+ """
+ r, g, b, a = rgba(color=color, colorDict=colorDict)
+ g = 0.21 * r + 0.72 * g + 0.07 * b
+ return g, g, g, a
def cursorColorForColormap(colormapName):
@@ -143,26 +187,140 @@ def cursorColorForColormap(colormapName):
:return: Name of the color.
:rtype: str
"""
- return _COLORMAP_CURSOR_COLORS.get(colormapName, 'black')
+ description = _AVAILABLE_LUTS.get(colormapName, None)
+ if description is not None:
+ color = description.cursor_color
+ if color is not None:
+ return color
+ return 'black'
-DEFAULT_COLORMAPS = (
- 'gray', 'reversed gray', 'temperature', 'red', 'green', 'blue')
-"""Tuple of supported colormap names."""
+# Colormap loader
-DEFAULT_MIN_LIN = 0
-"""Default min value if in linear normalization"""
-DEFAULT_MAX_LIN = 1
-"""Default max value if in linear normalization"""
-DEFAULT_MIN_LOG = 1
-"""Default min value if in log normalization"""
-DEFAULT_MAX_LOG = 10
-"""Default max value if in log normalization"""
+_COLORMAP_CACHE = {}
+"""Cache already used colormaps as name: color LUT"""
+
+
+def _arrayToRgba8888(colors):
+ """Convert colors from a numpy array using float (0..1) int or uint
+ (0..255) to uint8 RGBA.
+
+ :param numpy.ndarray colors: Array of float int or uint colors to convert
+ :return: colors as uint8
+ :rtype: numpy.ndarray
+ """
+ assert len(colors.shape) == 2
+ assert colors.shape[1] in (3, 4)
+
+ if colors.dtype == numpy.uint8:
+ pass
+ elif colors.dtype.kind == 'f':
+ # Each bin is [N, N+1[ except the last one: [255, 256]
+ colors = numpy.clip(colors.astype(numpy.float64) * 256, 0., 255.)
+ colors = colors.astype(numpy.uint8)
+ elif colors.dtype.kind in 'iu':
+ colors = numpy.clip(colors, 0, 255)
+ colors = colors.astype(numpy.uint8)
+
+ if colors.shape[1] == 3:
+ tmp = numpy.empty((len(colors), 4), dtype=numpy.uint8)
+ tmp[:, 0:3] = colors
+ tmp[:, 3] = 255
+ colors = tmp
+
+ return colors
+
+
+def _createColormapLut(name):
+ """Returns the color LUT corresponding to a colormap name
+
+ :param str name: Name of the colormap to load
+ :returns: Corresponding table of colors
+ :rtype: numpy.ndarray
+ :raise ValueError: If no colormap corresponds to name
+ """
+ description = _AVAILABLE_LUTS.get(name)
+ use_mpl = False
+ if description is not None:
+ if description.source == "builtin":
+ # Build colormap LUT
+ lut = numpy.zeros((256, 4), dtype=numpy.uint8)
+ lut[:, 3] = 255
+
+ if name == 'gray':
+ lut[:, :3] = numpy.arange(256, dtype=numpy.uint8).reshape(-1, 1)
+ elif name == 'reversed gray':
+ lut[:, :3] = numpy.arange(255, -1, -1, dtype=numpy.uint8).reshape(-1, 1)
+ elif name == 'red':
+ lut[:, 0] = numpy.arange(256, dtype=numpy.uint8)
+ elif name == 'green':
+ lut[:, 1] = numpy.arange(256, dtype=numpy.uint8)
+ elif name == 'blue':
+ lut[:, 2] = numpy.arange(256, dtype=numpy.uint8)
+ elif name == 'temperature':
+ # Red
+ lut[128:192, 0] = numpy.arange(2, 255, 4, dtype=numpy.uint8)
+ lut[192:, 0] = 255
+ # Green
+ lut[:64, 1] = numpy.arange(0, 255, 4, dtype=numpy.uint8)
+ lut[64:192, 1] = 255
+ lut[192:, 1] = numpy.arange(252, -1, -4, dtype=numpy.uint8)
+ # Blue
+ lut[:64, 2] = 255
+ lut[64:128, 2] = numpy.arange(254, 0, -4, dtype=numpy.uint8)
+ else:
+ raise RuntimeError("Built-in colormap not implemented")
+ return lut
+
+ elif description.source == "resource":
+ # Load colormap LUT
+ colors = numpy.load(_resource_filename("gui/colormaps/%s.npy" % name))
+ # Convert to uint8 and add alpha channel
+ lut = _arrayToRgba8888(colors)
+ return lut
+
+ elif description.source == "matplotlib":
+ use_mpl = True
+
+ else:
+ raise RuntimeError("Internal LUT source '%s' unsupported" % description.source)
+
+ # Here it expect a matplotlib LUTs
+
+ if use_mpl:
+ # matplotlib is mandatory
+ if _matplotlib_cm is None:
+ raise ValueError("The colormap '%s' expect matplotlib, but matplotlib is not installed" % name)
+
+ if _matplotlib_cm is not None: # Try to load with matplotlib
+ colormap = _matplotlib_cm.get_cmap(name)
+ lut = colormap(numpy.linspace(0, 1, colormap.N, endpoint=True))
+ lut = _arrayToRgba8888(lut)
+ return lut
+
+ raise ValueError("Unknown colormap '%s'" % name)
+
+
+def _getColormap(name):
+ """Returns the color LUT corresponding to a colormap name
+
+ :param str name: Name of the colormap to load
+ :returns: Corresponding table of colors
+ :rtype: numpy.ndarray
+ :raise ValueError: If no colormap corresponds to name
+ """
+ name = str(name)
+ if name not in _COLORMAP_CACHE:
+ lut = _createColormapLut(name)
+ _COLORMAP_CACHE[name] = lut
+ return _COLORMAP_CACHE[name]
class Colormap(qt.QObject):
"""Description of a colormap
+ If no `name` nor `colors` are provided, a default gray LUT is used.
+
:param str name: Name of the colormap
:param tuple colors: optional, custom colormap.
Nx3 or Nx4 numpy array of RGB(A) colors,
@@ -187,10 +345,11 @@ class Colormap(qt.QObject):
sigChanged = qt.Signal()
"""Signal emitted when the colormap has changed."""
- def __init__(self, name='gray', colors=None, normalization=LINEAR, vmin=None, vmax=None):
+ def __init__(self, name=None, colors=None, normalization=LINEAR, vmin=None, vmax=None):
qt.QObject.__init__(self)
+ self._editable = True
+
assert normalization in Colormap.NORMALIZATIONS
- assert not (name is None and colors is None)
if normalization is Colormap.LOGARITHM:
if (vmin is not None and vmin < 0) or (vmax is not None and vmax < 0):
m = "Unsuported vmin (%s) and/or vmax (%s) given for a log scale."
@@ -200,78 +359,76 @@ class Colormap(qt.QObject):
vmin = None
vmax = None
- self._name = str(name) if name is not None else None
- self._setColors(colors)
+ self._name = None
+ self._colors = None
+
+ if colors is not None and name is not None:
+ deprecation.deprecated_warning("Argument",
+ name="silx.gui.plot.Colors",
+ reason="name and colors can't be used at the same time",
+ since_version="0.10.0",
+ skip_backtrace_count=1)
+
+ colors = None
+
+ if name is not None:
+ self.setName(name) # And resets colormap LUT
+ elif colors is not None:
+ self.setColormapLUT(colors)
+ else:
+ # Default colormap is grey
+ self.setName("gray")
+
self._normalization = str(normalization)
self._vmin = float(vmin) if vmin is not None else None
self._vmax = float(vmax) if vmax is not None else None
- self._editable = True
-
- def isAutoscale(self):
- """Return True if both min and max are in autoscale mode"""
- return self._vmin is None and self._vmax is None
- def getName(self):
- """Return the name of the colormap
- :rtype: str
- """
- return self._name
-
- @staticmethod
- def _convertColorsFromFloatToUint8(colors):
- """Convert colors from float in [0, 1] to uint8
+ def setFromColormap(self, other):
+ """Set this colormap using information from the `other` colormap.
- :param numpy.ndarray colors: Array of float colors to convert
- :return: colors as uint8
- :rtype: numpy.ndarray
+ :param Colormap other: Colormap to use as reference.
"""
- # Each bin is [N, N+1[ except the last one: [255, 256]
- return numpy.clip(
- colors.astype(numpy.float64) * 256, 0., 255.).astype(numpy.uint8)
-
- def _setColors(self, colors):
- if colors is None:
- self._colors = None
+ if not self.isEditable():
+ raise NotEditableError('Colormap is not editable')
+ if self == other:
+ return
+ old = self.blockSignals(True)
+ name = other.getName()
+ if name is not None:
+ self.setName(name)
else:
- colors = numpy.array(colors, copy=False)
- if colors.shape == ():
- raise TypeError("An array is expected for 'colors' argument. '%s' was found." % type(colors))
- colors.shape = -1, colors.shape[-1]
- if colors.dtype.kind == 'f':
- colors = self._convertColorsFromFloatToUint8(colors)
-
- # Makes sure it is RGBA8888
- self._colors = numpy.zeros((len(colors), 4), dtype=numpy.uint8)
- self._colors[:, 3] = 255 # Alpha channel
- self._colors[:, :colors.shape[1]] = colors # Copy colors
+ self.setColormapLUT(other.getColormapLUT())
+ self.setNormalization(other.getNormalization())
+ self.setVRange(other.getVMin(), other.getVMax())
+ self.blockSignals(old)
+ self.sigChanged.emit()
def getNColors(self, nbColors=None):
"""Returns N colors computed by sampling the colormap regularly.
:param nbColors:
The number of colors in the returned array or None for the default value.
- The default value is 256 for colormap with a name (see :meth:`setName`) and
- it is the size of the LUT for colormap defined with :meth:`setColormapLUT`.
+ The default value is the size of the colormap LUT.
:type nbColors: int or None
:return: 2D array of uint8 of shape (nbColors, 4)
:rtype: numpy.ndarray
"""
# Handle default value for nbColors
if nbColors is None:
- lut = self.getColormapLUT()
- if lut is not None: # In this case uses LUT length
- nbColors = len(lut)
- else: # Default to 256
- nbColors = 256
-
- nbColors = int(nbColors)
+ return numpy.array(self._colors, copy=True)
+ else:
+ colormap = self.copy()
+ colormap.setNormalization(Colormap.LINEAR)
+ colormap.setVRange(vmin=None, vmax=None)
+ colors = colormap.applyToData(
+ numpy.arange(int(nbColors), dtype=numpy.int))
+ return colors
- colormap = self.copy()
- colormap.setNormalization(Colormap.LINEAR)
- colormap.setVRange(vmin=None, vmax=None)
- colors = colormap.applyToData(
- numpy.arange(nbColors, dtype=numpy.int))
- return colors
+ def getName(self):
+ """Return the name of the colormap
+ :rtype: str
+ """
+ return self._name
def setName(self, name):
"""Set the name of the colormap to use.
@@ -281,23 +438,31 @@ class Colormap(qt.QObject):
'reversed gray', 'temperature', 'red', 'green', 'blue', 'jet',
'viridis', 'magma', 'inferno', 'plasma'.
"""
+ name = str(name)
+ if self._name == name:
+ return
if self.isEditable() is False:
raise NotEditableError('Colormap is not editable')
- assert name in self.getSupportedColormaps()
- self._name = str(name)
- self._colors = None
+ if name not in self.getSupportedColormaps():
+ raise ValueError("Colormap name '%s' is not supported" % name)
+ self._name = name
+ self._colors = _getColormap(self._name)
self.sigChanged.emit()
- def getColormapLUT(self):
- """Return the list of colors for the colormap or None if not set
+ def getColormapLUT(self, copy=True):
+ """Return the list of colors for the colormap or None if not set.
+ This returns None if the colormap was set with :meth:`setName`.
+ Use :meth:`getNColors` to get the colormap LUT for any colormap.
+
+ :param bool copy: If true a copy of the numpy array is provided
:return: the list of colors for the colormap or None if not set
:rtype: numpy.ndarray or None
"""
- if self._colors is None:
- return None
+ if self._name is None:
+ return numpy.array(self._colors, copy=copy)
else:
- return numpy.array(self._colors, copy=True)
+ return None
def setColormapLUT(self, colors):
"""Set the colors of the colormap.
@@ -310,10 +475,15 @@ class Colormap(qt.QObject):
"""
if self.isEditable() is False:
raise NotEditableError('Colormap is not editable')
- self._setColors(colors)
- if len(colors) is 0:
- self._colors = None
-
+ assert colors is not None
+
+ colors = numpy.array(colors, copy=False)
+ if colors.shape == ():
+ raise TypeError("An array is expected for 'colors' argument. '%s' was found." % type(colors))
+ assert len(colors) != 0
+ assert colors.ndim >= 2
+ colors.shape = -1, colors.shape[-1]
+ self._colors = _arrayToRgba8888(colors)
self._name = None
self.sigChanged.emit()
@@ -335,6 +505,10 @@ class Colormap(qt.QObject):
self._normalization = str(norm)
self.sigChanged.emit()
+ def isAutoscale(self):
+ """Return True if both min and max are in autoscale mode"""
+ return self._vmin is None and self._vmax is None
+
def getVMin(self):
"""Return the lower bound of the colormap
@@ -504,7 +678,7 @@ class Colormap(qt.QObject):
"""
return {
'name': self._name,
- 'colors': copy_mdl.copy(self._colors),
+ 'colors': self.getColormapLUT(),
'vmin': self._vmin,
'vmax': self._vmax,
'autoscale': self.isAutoscale(),
@@ -546,8 +720,10 @@ class Colormap(qt.QObject):
if dic.get('autoscale', False):
vmin, vmax = None, None
- self._name = name
- self._colors = colors
+ if name is not None:
+ self.setName(name)
+ else:
+ self.setColormapLUT(colors)
self._vmin = vmin
self._vmax = vmax
self._autoscale = True if (vmin is None and vmax is None) else False
@@ -557,7 +733,7 @@ class Colormap(qt.QObject):
@staticmethod
def _fromDict(dic):
- colormap = Colormap(name="")
+ colormap = Colormap()
colormap._setFromDict(dic)
return colormap
@@ -567,7 +743,7 @@ class Colormap(qt.QObject):
:rtype: silx.gui.colors.Colormap
"""
return Colormap(name=self._name,
- colors=copy_mdl.copy(self._colors),
+ colors=self.getColormapLUT(),
vmin=self._vmin,
vmax=self._vmax,
normalization=self._normalization)
@@ -577,34 +753,30 @@ class Colormap(qt.QObject):
:param numpy.ndarray data: The data to convert.
"""
- name = self.getName()
- if name is not None: # Get colormap definition from matplotlib
- # FIXME: If possible remove dependency to the plot
- from .plot.matplotlib import Colormap as MPLColormap
- mplColormap = MPLColormap.getColormap(name)
- colors = mplColormap(numpy.linspace(0, 1, 256, endpoint=True))
- colors = self._convertColorsFromFloatToUint8(colors)
-
- else: # Use user defined LUT
- colors = self.getColormapLUT()
-
vmin, vmax = self.getColormapRange(data)
normalization = self.getNormalization()
-
- return _cmap(data, colors, vmin, vmax, normalization)
+ return _cmap(data, self._colors, vmin, vmax, normalization)
@staticmethod
def getSupportedColormaps():
"""Get the supported colormap names as a tuple of str.
The list should at least contain and start by:
- ('gray', 'reversed gray', 'temperature', 'red', 'green', 'blue')
+
+ ('gray', 'reversed gray', 'temperature', 'red', 'green', 'blue',
+ 'viridis', 'magma', 'inferno', 'plasma')
+
:rtype: tuple
"""
- # FIXME: If possible remove dependency to the plot
- from .plot.matplotlib import Colormap as MPLColormap
- maps = MPLColormap.getSupportedColormaps()
- return DEFAULT_COLORMAPS + maps
+ colormaps = set()
+ if _matplotlib_cm is not None:
+ colormaps.update(_matplotlib_cm.cmap_d.keys())
+ colormaps.update(_AVAILABLE_LUTS.keys())
+
+ colormaps = tuple(cmap for cmap in sorted(colormaps)
+ if cmap not in _AVAILABLE_LUTS.keys())
+
+ return tuple(_AVAILABLE_LUTS.keys()) + colormaps
def __str__(self):
return str(self._toDict())
@@ -617,6 +789,10 @@ class Colormap(qt.QObject):
def __eq__(self, other):
"""Compare colormap values and not pointers"""
+ if other is None:
+ return False
+ if not isinstance(other, Colormap):
+ return False
return (self.getName() == other.getName() and
self.getNormalization() == other.getNormalization() and
self.getVMin() == other.getVMin() and
@@ -710,13 +886,10 @@ def preferredColormaps():
"""
global _PREFERRED_COLORMAPS
if _PREFERRED_COLORMAPS is None:
- _PREFERRED_COLORMAPS = DEFAULT_COLORMAPS
# Initialize preferred colormaps
- setPreferredColormaps(('gray', 'reversed gray',
- 'temperature', 'red', 'green', 'blue', 'jet',
- 'viridis', 'magma', 'inferno', 'plasma',
- 'hsv'))
- return _PREFERRED_COLORMAPS
+ default_preferred = [k for k in _AVAILABLE_LUTS.keys() if _AVAILABLE_LUTS[k].preferred]
+ setPreferredColormaps(default_preferred)
+ return tuple(_PREFERRED_COLORMAPS)
def setPreferredColormaps(colormaps):
@@ -730,10 +903,41 @@ def setPreferredColormaps(colormaps):
:raise ValueError: if the list of available preferred colormaps is empty.
"""
supportedColormaps = Colormap.getSupportedColormaps()
- colormaps = tuple(
- cmap for cmap in colormaps if cmap in supportedColormaps)
+ colormaps = [cmap for cmap in colormaps if cmap in supportedColormaps]
if len(colormaps) == 0:
raise ValueError("Cannot set preferred colormaps to an empty list")
global _PREFERRED_COLORMAPS
_PREFERRED_COLORMAPS = colormaps
+
+
+def registerLUT(name, colors, cursor_color='black', preferred=True):
+ """Register a custom LUT to be used with `Colormap` objects.
+
+ It can override existing LUT names.
+
+ :param str name: Name of the LUT as defined to configure colormaps
+ :param numpy.ndarray colors: The custom LUT to register.
+ Nx3 or Nx4 numpy array of RGB(A) colors,
+ either uint8 or float in [0, 1].
+ :param bool preferred: If true, this LUT will be displayed as part of the
+ preferred colormaps in dialogs.
+ :param str cursor_color: Color used to display overlay over images using
+ colormap with this LUT.
+ """
+ description = _LUT_DESCRIPTION('user', cursor_color, preferred=preferred)
+ colors = _arrayToRgba8888(colors)
+ _AVAILABLE_LUTS[name] = description
+
+ if preferred:
+ # Invalidate the preferred cache
+ global _PREFERRED_COLORMAPS
+ if _PREFERRED_COLORMAPS is not None:
+ if name not in _PREFERRED_COLORMAPS:
+ _PREFERRED_COLORMAPS.append(name)
+ else:
+ # The cache is not yet loaded, it's fine
+ pass
+
+ # Register the cache as the LUT was already loaded
+ _COLORMAP_CACHE[name] = colors
diff --git a/silx/gui/data/DataViewer.py b/silx/gui/data/DataViewer.py
index 4db2863..b33a931 100644
--- a/silx/gui/data/DataViewer.py
+++ b/silx/gui/data/DataViewer.py
@@ -32,12 +32,10 @@ from silx.gui.data.DataViews import _normalizeData
import logging
from silx.gui import qt
from silx.gui.data.NumpyAxesSelector import NumpyAxesSelector
-from silx.utils import deprecation
-from silx.utils.property import classproperty
__authors__ = ["V. Valls"]
__license__ = "MIT"
-__date__ = "24/04/2018"
+__date__ = "12/02/2019"
_logger = logging.getLogger(__name__)
@@ -70,66 +68,6 @@ class DataViewer(qt.QFrame):
viewer.setVisible(True)
"""
- # TODO: Can be removed for silx 0.8
- @classproperty
- @deprecation.deprecated(replacement="DataViews.EMPTY_MODE", since_version="0.7", skip_backtrace_count=2)
- def EMPTY_MODE(self):
- return DataViews.EMPTY_MODE
-
- # TODO: Can be removed for silx 0.8
- @classproperty
- @deprecation.deprecated(replacement="DataViews.PLOT1D_MODE", since_version="0.7", skip_backtrace_count=2)
- def PLOT1D_MODE(self):
- return DataViews.PLOT1D_MODE
-
- # TODO: Can be removed for silx 0.8
- @classproperty
- @deprecation.deprecated(replacement="DataViews.PLOT2D_MODE", since_version="0.7", skip_backtrace_count=2)
- def PLOT2D_MODE(self):
- return DataViews.PLOT2D_MODE
-
- # TODO: Can be removed for silx 0.8
- @classproperty
- @deprecation.deprecated(replacement="DataViews.PLOT3D_MODE", since_version="0.7", skip_backtrace_count=2)
- def PLOT3D_MODE(self):
- return DataViews.PLOT3D_MODE
-
- # TODO: Can be removed for silx 0.8
- @classproperty
- @deprecation.deprecated(replacement="DataViews.RAW_MODE", since_version="0.7", skip_backtrace_count=2)
- def RAW_MODE(self):
- return DataViews.RAW_MODE
-
- # TODO: Can be removed for silx 0.8
- @classproperty
- @deprecation.deprecated(replacement="DataViews.RAW_ARRAY_MODE", since_version="0.7", skip_backtrace_count=2)
- def RAW_ARRAY_MODE(self):
- return DataViews.RAW_ARRAY_MODE
-
- # TODO: Can be removed for silx 0.8
- @classproperty
- @deprecation.deprecated(replacement="DataViews.RAW_RECORD_MODE", since_version="0.7", skip_backtrace_count=2)
- def RAW_RECORD_MODE(self):
- return DataViews.RAW_RECORD_MODE
-
- # TODO: Can be removed for silx 0.8
- @classproperty
- @deprecation.deprecated(replacement="DataViews.RAW_SCALAR_MODE", since_version="0.7", skip_backtrace_count=2)
- def RAW_SCALAR_MODE(self):
- return DataViews.RAW_SCALAR_MODE
-
- # TODO: Can be removed for silx 0.8
- @classproperty
- @deprecation.deprecated(replacement="DataViews.STACK_MODE", since_version="0.7", skip_backtrace_count=2)
- def STACK_MODE(self):
- return DataViews.STACK_MODE
-
- # TODO: Can be removed for silx 0.8
- @classproperty
- @deprecation.deprecated(replacement="DataViews.HDF5_MODE", since_version="0.7", skip_backtrace_count=2)
- def HDF5_MODE(self):
- return DataViews.HDF5_MODE
-
displayedViewChanged = qt.Signal(object)
"""Emitted when the displayed view changes"""
@@ -288,6 +226,7 @@ class DataViewer(qt.QFrame):
else:
self.__displayedData = self.__data
+ # TODO: would be good to avoid that, it should be synchonous
qt.QTimer.singleShot(10, self.__setDataInView)
def __setDataInView(self):
@@ -405,18 +344,16 @@ class DataViewer(qt.QFrame):
data = self.__data
info = self._getInfo()
# sort available views according to priority
- priorities = [v.getDataPriority(data, info) for v in self.__views]
- views = zip(priorities, self.__views)
+ views = []
+ for v in self.__views:
+ views.extend(v.getMatchingViews(data, info))
+ views = [(v.getCachedDataPriority(data, info), v) for v in views]
views = filter(lambda t: t[0] > DataViews.DataView.UNSUPPORTED, views)
views = sorted(views, reverse=True)
+ views = [v[1] for v in views]
# store available views
- if len(views) == 0:
- self.__setCurrentAvailableViews([])
- available = []
- else:
- available = [v[1] for v in views]
- self.__setCurrentAvailableViews(available)
+ self.__setCurrentAvailableViews(views)
def __updateView(self):
"""Display the data using the widget which fit the best"""
@@ -447,7 +384,7 @@ class DataViewer(qt.QFrame):
priority to lowest.
:rtype: DataView
"""
- hdf5View = self.getViewFromModeId(DataViewer.HDF5_MODE)
+ hdf5View = self.getViewFromModeId(DataViews.HDF5_MODE)
if hdf5View in available:
return hdf5View
return self.getViewFromModeId(DataViews.EMPTY_MODE)
@@ -487,6 +424,17 @@ class DataViewer(qt.QFrame):
"""
return self.__currentAvailableViews
+ def getReachableViews(self):
+ """Returns the list of reachable views from the registred available
+ views.
+
+ :rtype: List[DataView]
+ """
+ views = []
+ for v in self.availableViews():
+ views.extend(v.getReachableViews())
+ return views
+
def availableViews(self):
"""Returns the list of registered views
diff --git a/silx/gui/data/DataViewerFrame.py b/silx/gui/data/DataViewerFrame.py
index 4e6d2e8..9bfb95b 100644
--- a/silx/gui/data/DataViewerFrame.py
+++ b/silx/gui/data/DataViewerFrame.py
@@ -27,7 +27,7 @@
__authors__ = ["V. Valls"]
__license__ = "MIT"
-__date__ = "24/04/2018"
+__date__ = "12/02/2019"
from silx.gui import qt
from .DataViewer import DataViewer
@@ -120,6 +120,9 @@ class DataViewerFrame(qt.QWidget):
"""
self.__dataViewer.setGlobalHooks(hooks)
+ def getReachableViews(self):
+ return self.__dataViewer.getReachableViews()
+
def availableViews(self):
"""Returns the list of registered views
diff --git a/silx/gui/data/DataViewerSelector.py b/silx/gui/data/DataViewerSelector.py
index 35bbe99..a1e9947 100644
--- a/silx/gui/data/DataViewerSelector.py
+++ b/silx/gui/data/DataViewerSelector.py
@@ -29,7 +29,7 @@ from __future__ import division
__authors__ = ["V. Valls"]
__license__ = "MIT"
-__date__ = "23/01/2018"
+__date__ = "12/02/2019"
import weakref
import functools
@@ -85,7 +85,7 @@ class DataViewerSelector(qt.QWidget):
iconSize = qt.QSize(16, 16)
- for view in self.__dataViewer.availableViews():
+ for view in self.__dataViewer.getReachableViews():
label = view.label()
icon = view.icon()
button = qt.QPushButton(label)
@@ -155,7 +155,7 @@ class DataViewerSelector(qt.QWidget):
self.__dataViewer.setDisplayedView(view)
def __checkAvailableButtons(self):
- views = set(self.__dataViewer.availableViews())
+ views = set(self.__dataViewer.getReachableViews())
if views == set(self.__buttons.keys()):
return
# Recreate all the buttons
diff --git a/silx/gui/data/DataViews.py b/silx/gui/data/DataViews.py
index 2291e87..6575d0d 100644
--- a/silx/gui/data/DataViews.py
+++ b/silx/gui/data/DataViews.py
@@ -1,7 +1,7 @@
# coding: utf-8
# /*##########################################################################
#
-# Copyright (c) 2016-2018 European Synchrotron Radiation Facility
+# Copyright (c) 2016-2019 European Synchrotron Radiation Facility
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
@@ -31,6 +31,7 @@ import numbers
import numpy
import silx.io
+from silx.utils import deprecation
from silx.gui import qt, icons
from silx.gui.data.TextFormatter import TextFormatter
from silx.io import nxdata
@@ -41,7 +42,7 @@ from silx.gui.dialog.ColormapDialog import ColormapDialog
__authors__ = ["V. Valls", "P. Knobel"]
__license__ = "MIT"
-__date__ = "23/05/2018"
+__date__ = "19/02/2019"
_logger = logging.getLogger(__name__)
@@ -67,6 +68,8 @@ NXDATA_CURVE_MODE = 73
NXDATA_XYVSCATTER_MODE = 74
NXDATA_IMAGE_MODE = 75
NXDATA_STACK_MODE = 76
+NXDATA_VOLUME_MODE = 77
+NXDATA_VOLUME_AS_STACK_MODE = 78
def _normalizeData(data):
@@ -100,6 +103,7 @@ class DataInfo(object):
"""Store extracted information from a data"""
def __init__(self, data):
+ self.__priorities = {}
data = self.normalizeData(data)
self.isArray = False
self.interpretation = None
@@ -131,9 +135,6 @@ class DataInfo(object):
elif nx_class == "NXdata":
# group claiming to be NXdata could not be parsed
self.isInvalidNXdata = True
- elif nx_class == "NXentry" and "default" in data.attrs:
- # entry claiming to have a default NXdata could not be parsed
- self.isInvalidNXdata = True
elif nx_class == "NXroot" or silx.io.is_file(data):
# root claiming to have a default entry
if "default" in data.attrs:
@@ -141,6 +142,9 @@ class DataInfo(object):
if def_entry in data and "default" in data[def_entry].attrs:
# and entry claims to have default NXdata
self.isInvalidNXdata = True
+ elif "default" in data.attrs:
+ # group claiming to have a default NXdata could not be parsed
+ self.isInvalidNXdata = True
if isinstance(data, numpy.ndarray):
self.isArray = True
@@ -201,6 +205,12 @@ class DataInfo(object):
Else returns the data."""
return _normalizeData(data)
+ def cachePriority(self, view, priority):
+ self.__priorities[view] = priority
+
+ def getPriority(self, view):
+ return self.__priorities[view]
+
class DataViewHooks(object):
"""A set of hooks defined to custom the behaviour of the data views."""
@@ -357,6 +367,35 @@ class DataView(object):
"""
return []
+ def getReachableViews(self):
+ """Returns the views that can be returned by `getMatchingViews`.
+
+ :param object data: Any object to be displayed
+ :param DataInfo info: Information cached about this data
+ :rtype: List[DataView]
+ """
+ return [self]
+
+ def getMatchingViews(self, data, info):
+ """Returns the views according to data and info from the data.
+
+ :param object data: Any object to be displayed
+ :param DataInfo info: Information cached about this data
+ :rtype: List[DataView]
+ """
+ priority = self.getCachedDataPriority(data, info)
+ if priority == DataView.UNSUPPORTED:
+ return []
+ return [self]
+
+ def getCachedDataPriority(self, data, info):
+ try:
+ priority = info.getPriority(self)
+ except KeyError:
+ priority = self.getDataPriority(data, info)
+ info.cachePriority(self, priority)
+ return priority
+
def getDataPriority(self, data, info):
"""
Returns the priority of using this view according to a data.
@@ -377,7 +416,53 @@ class DataView(object):
return str(self) < str(other)
-class CompositeDataView(DataView):
+class _CompositeDataView(DataView):
+ """Contains sub views"""
+
+ def getViews(self):
+ """Returns the direct sub views registered in this view.
+
+ :rtype: List[DataView]
+ """
+ raise NotImplementedError()
+
+ def getReachableViews(self):
+ """Returns all views that can be reachable at on point.
+
+ This method return any sub view provided (recursivly).
+
+ :rtype: List[DataView]
+ """
+ raise NotImplementedError()
+
+ def getMatchingViews(self, data, info):
+ """Returns sub views matching this data and info.
+
+ This method return any sub view provided (recursivly).
+
+ :param object data: Any object to be displayed
+ :param DataInfo info: Information cached about this data
+ :rtype: List[DataView]
+ """
+ raise NotImplementedError()
+
+ @deprecation.deprecated(replacement="getReachableViews", since_version="0.10")
+ def availableViews(self):
+ return self.getViews()
+
+ def isSupportedData(self, data, info):
+ """If true, the composite view allow sub views to access to this data.
+ Else this this data is considered as not supported by any of sub views
+ (incliding this composite view).
+
+ :param object data: Any object to be displayed
+ :param DataInfo info: Information cached about this data
+ :rtype: bool
+ """
+ return True
+
+
+class SelectOneDataView(_CompositeDataView):
"""Data view which can display a data using different view according to
the kind of the data."""
@@ -386,7 +471,7 @@ class CompositeDataView(DataView):
:param qt.QWidget parent: Parent of the hold widget
"""
- super(CompositeDataView, self).__init__(parent, modeId, icon, label)
+ super(SelectOneDataView, self).__init__(parent, modeId, icon, label)
self.__views = OrderedDict()
self.__currentView = None
@@ -395,7 +480,7 @@ class CompositeDataView(DataView):
:param DataViewHooks hooks: The data view hooks to use
"""
- super(CompositeDataView, self).setHooks(hooks)
+ super(SelectOneDataView, self).setHooks(hooks)
if hooks is not None:
for v in self.__views:
v.setHooks(hooks)
@@ -407,16 +492,40 @@ class CompositeDataView(DataView):
dataView.setHooks(hooks)
self.__views[dataView] = None
- def availableViews(self):
+ def getReachableViews(self):
+ views = []
+ addSelf = False
+ for v in self.__views:
+ if isinstance(v, SelectManyDataView):
+ views.extend(v.getReachableViews())
+ else:
+ addSelf = True
+ if addSelf:
+ # Single views are hidden by this view
+ views.insert(0, self)
+ return views
+
+ def getMatchingViews(self, data, info):
+ if not self.isSupportedData(data, info):
+ return []
+ view = self.__getBestView(data, info)
+ if isinstance(view, SelectManyDataView):
+ return view.getMatchingViews(data, info)
+ else:
+ return [self]
+
+ def getViews(self):
"""Returns the list of registered views
:rtype: List[DataView]
"""
return list(self.__views.keys())
- def getBestView(self, data, info):
+ def __getBestView(self, data, info):
"""Returns the best view according to priorities."""
- views = [(v.getDataPriority(data, info), v) for v in self.__views.keys()]
+ if not self.isSupportedData(data, info):
+ return None
+ views = [(v.getCachedDataPriority(data, info), v) for v in self.__views.keys()]
views = filter(lambda t: t[0] > DataView.UNSUPPORTED, views)
views = sorted(views, key=lambda t: t[0], reverse=True)
@@ -471,17 +580,17 @@ class CompositeDataView(DataView):
self.__currentView.setData(data)
def axesNames(self, data, info):
- view = self.getBestView(data, info)
+ view = self.__getBestView(data, info)
self.__currentView = view
return view.axesNames(data, info)
def getDataPriority(self, data, info):
- view = self.getBestView(data, info)
+ view = self.__getBestView(data, info)
self.__currentView = view
if view is None:
return DataView.UNSUPPORTED
else:
- return view.getDataPriority(data, info)
+ return view.getCachedDataPriority(data, info)
def replaceView(self, modeId, newView):
"""Replace a data view with a custom view.
@@ -502,7 +611,7 @@ class CompositeDataView(DataView):
if view.modeId() == modeId:
oldView = view
break
- elif isinstance(view, CompositeDataView):
+ elif isinstance(view, _CompositeDataView):
# recurse
hooks = self.getHooks()
if hooks is not None:
@@ -519,6 +628,135 @@ class CompositeDataView(DataView):
return True
+# NOTE: SelectOneDataView was introduced with silx 0.10
+CompositeDataView = SelectOneDataView
+
+
+class SelectManyDataView(_CompositeDataView):
+ """Data view which can select a set of sub views according to
+ the kind of the data.
+
+ This view itself is abstract and is not exposed.
+ """
+
+ def __init__(self, parent, views=None):
+ """Constructor
+
+ :param qt.QWidget parent: Parent of the hold widget
+ """
+ super(SelectManyDataView, self).__init__(parent, modeId=None, icon=None, label=None)
+ if views is None:
+ views = []
+ self.__views = views
+
+ def setHooks(self, hooks):
+ """Set the data context to use with this view.
+
+ :param DataViewHooks hooks: The data view hooks to use
+ """
+ super(SelectManyDataView, self).setHooks(hooks)
+ if hooks is not None:
+ for v in self.__views:
+ v.setHooks(hooks)
+
+ def addView(self, dataView):
+ """Add a new dataview to the available list."""
+ hooks = self.getHooks()
+ if hooks is not None:
+ dataView.setHooks(hooks)
+ self.__views.append(dataView)
+
+ def getViews(self):
+ """Returns the list of registered views
+
+ :rtype: List[DataView]
+ """
+ return list(self.__views)
+
+ def getReachableViews(self):
+ views = []
+ for v in self.__views:
+ views.extend(v.getReachableViews())
+ return views
+
+ def getMatchingViews(self, data, info):
+ """Returns the views according to data and info from the data.
+
+ :param object data: Any object to be displayed
+ :param DataInfo info: Information cached about this data
+ """
+ if not self.isSupportedData(data, info):
+ return []
+ views = [v for v in self.__views if v.getCachedDataPriority(data, info) != DataView.UNSUPPORTED]
+ return views
+
+ def customAxisNames(self):
+ raise RuntimeError("Abstract view")
+
+ def setCustomAxisValue(self, name, value):
+ raise RuntimeError("Abstract view")
+
+ def select(self):
+ raise RuntimeError("Abstract view")
+
+ def createWidget(self, parent):
+ raise RuntimeError("Abstract view")
+
+ def clear(self):
+ for v in self.__views:
+ v.clear()
+
+ def setData(self, data):
+ raise RuntimeError("Abstract view")
+
+ def axesNames(self, data, info):
+ raise RuntimeError("Abstract view")
+
+ def getDataPriority(self, data, info):
+ if not self.isSupportedData(data, info):
+ return DataView.UNSUPPORTED
+ priorities = [v.getCachedDataPriority(data, info) for v in self.__views]
+ priorities = [v for v in priorities if v != DataView.UNSUPPORTED]
+ priorities = sorted(priorities)
+ if len(priorities) == 0:
+ return DataView.UNSUPPORTED
+ return priorities[-1]
+
+ def replaceView(self, modeId, newView):
+ """Replace a data view with a custom view.
+ Return True in case of success, False in case of failure.
+
+ .. note::
+
+ This method must be called just after instantiation, before
+ the viewer is used.
+
+ :param int modeId: Unique mode ID identifying the DataView to
+ be replaced.
+ :param DataViews.DataView newView: New data view
+ :return: True if replacement was successful, else False
+ """
+ oldView = None
+ for iview, view in enumerate(self.__views):
+ if view.modeId() == modeId:
+ oldView = view
+ break
+ elif isinstance(view, CompositeDataView):
+ # recurse
+ hooks = self.getHooks()
+ if hooks is not None:
+ newView.setHooks(hooks)
+ if view.replaceView(modeId, newView):
+ return True
+
+ if oldView is None:
+ return False
+
+ # replace oldView with new view in dict
+ self.__views[iview] = newView
+ return True
+
+
class _EmptyView(DataView):
"""Dummy view to display nothing"""
@@ -1096,17 +1334,6 @@ class _InvalidNXdataView(DataView):
# invalid: could not even be parsed by NXdata
self._msg = "Group has @NX_class = NXdata, but could not be interpreted"
self._msg += " as valid NXdata."
- elif nx_class == "NXentry":
- self._msg = "NXentry group provides a @default attribute,"
- default_nxdata_name = data.attrs["default"]
- if default_nxdata_name not in data:
- self._msg += " but no corresponding NXdata group exists."
- elif get_attr_as_unicode(data[default_nxdata_name], "NX_class") != "NXdata":
- self._msg += " but the corresponding item is not a "
- self._msg += "NXdata group."
- else:
- self._msg += " but the corresponding NXdata seems to be"
- self._msg += " malformed."
elif nx_class == "NXroot" or silx.io.is_file(data):
default_entry = data[data.attrs["default"]]
default_nxdata_name = default_entry.attrs["default"]
@@ -1122,6 +1349,17 @@ class _InvalidNXdataView(DataView):
else:
self._msg += " but the corresponding NXdata seems to be"
self._msg += " malformed."
+ else:
+ self._msg = "Group provides a @default attribute,"
+ default_nxdata_name = data.attrs["default"]
+ if default_nxdata_name not in data:
+ self._msg += " but no corresponding NXdata group exists."
+ elif get_attr_as_unicode(data[default_nxdata_name], "NX_class") != "NXdata":
+ self._msg += " but the corresponding item is not a "
+ self._msg += "NXdata group."
+ else:
+ self._msg += " but the corresponding NXdata seems to be"
+ self._msg += " malformed."
return 100
@@ -1277,7 +1515,7 @@ class _NXdataXYVScatterView(DataView):
class _NXdataImageView(DataView):
"""DataView using a Plot2D for displaying NXdata images:
- 2-D signal or n-D signals with *@interpretation=spectrum*."""
+ 2-D signal or n-D signals with *@interpretation=image*."""
def __init__(self, parent):
DataView.__init__(self, parent,
modeId=NXDATA_IMAGE_MODE)
@@ -1323,6 +1561,53 @@ class _NXdataImageView(DataView):
return DataView.UNSUPPORTED
+class _NXdataComplexImageView(DataView):
+ """DataView using a ComplexImageView for displaying NXdata complex images:
+ 2-D signal or n-D signals with *@interpretation=image*."""
+ def __init__(self, parent):
+ DataView.__init__(self, parent,
+ modeId=NXDATA_IMAGE_MODE)
+
+ def createWidget(self, parent):
+ from silx.gui.data.NXdataWidgets import ArrayComplexImagePlot
+ widget = ArrayComplexImagePlot(parent, colormap=self.defaultColormap())
+ widget.getPlot().getColormapAction().setColorDialog(self.defaultColorDialog())
+ return widget
+
+ def clear(self):
+ self.getWidget().clear()
+
+ def setData(self, data):
+ data = self.normalizeData(data)
+ nxd = nxdata.get_default(data, validate=False)
+
+ # last two axes are Y & X
+ img_slicing = slice(-2, None)
+ y_axis, x_axis = nxd.axes[img_slicing]
+ y_label, x_label = nxd.axes_names[img_slicing]
+
+ self.getWidget().setImageData(
+ [nxd.signal] + nxd.auxiliary_signals,
+ x_axis=x_axis, y_axis=y_axis,
+ signals_names=[nxd.signal_name] + nxd.auxiliary_signals_names,
+ xlabel=x_label, ylabel=y_label,
+ title=nxd.title)
+
+ def axesNames(self, data, info):
+ # disabled (used by default axis selector widget in Hdf5Viewer)
+ return None
+
+ def getDataPriority(self, data, info):
+ data = self.normalizeData(data)
+
+ if info.hasNXdata and not info.isInvalidNXdata:
+ nxd = nxdata.get_default(data, validate=False)
+ if nxd.is_image and numpy.iscomplexobj(nxd.signal):
+ return 100
+
+ return DataView.UNSUPPORTED
+
+
class _NXdataStackView(DataView):
def __init__(self, parent):
DataView.__init__(self, parent,
@@ -1368,6 +1653,154 @@ class _NXdataStackView(DataView):
return DataView.UNSUPPORTED
+class _NXdataVolumeView(DataView):
+ def __init__(self, parent):
+ DataView.__init__(self, parent,
+ label="NXdata (3D)",
+ icon=icons.getQIcon("view-nexus"),
+ modeId=NXDATA_VOLUME_MODE)
+ try:
+ import silx.gui.plot3d # noqa
+ except ImportError:
+ _logger.warning("Plot3dView is not available")
+ _logger.debug("Backtrace", exc_info=True)
+ raise
+
+ def normalizeData(self, data):
+ data = DataView.normalizeData(self, data)
+ data = _normalizeComplex(data)
+ return data
+
+ def createWidget(self, parent):
+ from silx.gui.data.NXdataWidgets import ArrayVolumePlot
+ widget = ArrayVolumePlot(parent)
+ return widget
+
+ def axesNames(self, data, info):
+ # disabled (used by default axis selector widget in Hdf5Viewer)
+ return None
+
+ def clear(self):
+ self.getWidget().clear()
+
+ def setData(self, data):
+ data = self.normalizeData(data)
+ nxd = nxdata.get_default(data, validate=False)
+ signal_name = nxd.signal_name
+ z_axis, y_axis, x_axis = nxd.axes[-3:]
+ z_label, y_label, x_label = nxd.axes_names[-3:]
+ title = nxd.title or signal_name
+
+ widget = self.getWidget()
+ widget.setData(
+ nxd.signal, x_axis=x_axis, y_axis=y_axis, z_axis=z_axis,
+ signal_name=signal_name,
+ xlabel=x_label, ylabel=y_label, zlabel=z_label,
+ title=title)
+
+ def getDataPriority(self, data, info):
+ data = self.normalizeData(data)
+ if info.hasNXdata and not info.isInvalidNXdata:
+ if nxdata.get_default(data, validate=False).is_volume:
+ return 150
+
+ return DataView.UNSUPPORTED
+
+
+class _NXdataVolumeAsStackView(DataView):
+ def __init__(self, parent):
+ DataView.__init__(self, parent,
+ label="NXdata (2D)",
+ icon=icons.getQIcon("view-nexus"),
+ modeId=NXDATA_VOLUME_AS_STACK_MODE)
+
+ def createWidget(self, parent):
+ from silx.gui.data.NXdataWidgets import ArrayStackPlot
+ widget = ArrayStackPlot(parent)
+ widget.getStackView().setColormap(self.defaultColormap())
+ widget.getStackView().getPlot().getColormapAction().setColorDialog(self.defaultColorDialog())
+ return widget
+
+ def axesNames(self, data, info):
+ # disabled (used by default axis selector widget in Hdf5Viewer)
+ return None
+
+ def clear(self):
+ self.getWidget().clear()
+
+ def setData(self, data):
+ data = self.normalizeData(data)
+ nxd = nxdata.get_default(data, validate=False)
+ signal_name = nxd.signal_name
+ z_axis, y_axis, x_axis = nxd.axes[-3:]
+ z_label, y_label, x_label = nxd.axes_names[-3:]
+ title = nxd.title or signal_name
+
+ widget = self.getWidget()
+ widget.setStackData(
+ nxd.signal, x_axis=x_axis, y_axis=y_axis, z_axis=z_axis,
+ signal_name=signal_name,
+ xlabel=x_label, ylabel=y_label, zlabel=z_label,
+ title=title)
+ # Override the colormap, while setStack overwrite it
+ widget.getStackView().setColormap(self.defaultColormap())
+
+ def getDataPriority(self, data, info):
+ data = self.normalizeData(data)
+ if info.isComplex:
+ return DataView.UNSUPPORTED
+ if info.hasNXdata and not info.isInvalidNXdata:
+ if nxdata.get_default(data, validate=False).is_volume:
+ return 200
+
+ return DataView.UNSUPPORTED
+
+class _NXdataComplexVolumeAsStackView(DataView):
+ def __init__(self, parent):
+ DataView.__init__(self, parent,
+ label="NXdata (2D)",
+ icon=icons.getQIcon("view-nexus"),
+ modeId=NXDATA_VOLUME_AS_STACK_MODE)
+ self._is_complex_data = False
+
+ def createWidget(self, parent):
+ from silx.gui.data.NXdataWidgets import ArrayComplexImagePlot
+ widget = ArrayComplexImagePlot(parent, colormap=self.defaultColormap())
+ widget.getPlot().getColormapAction().setColorDialog(self.defaultColorDialog())
+ return widget
+
+ def axesNames(self, data, info):
+ # disabled (used by default axis selector widget in Hdf5Viewer)
+ return None
+
+ def clear(self):
+ self.getWidget().clear()
+
+ def setData(self, data):
+ data = self.normalizeData(data)
+ nxd = nxdata.get_default(data, validate=False)
+ signal_name = nxd.signal_name
+ z_axis, y_axis, x_axis = nxd.axes[-3:]
+ z_label, y_label, x_label = nxd.axes_names[-3:]
+ title = nxd.title or signal_name
+
+ self.getWidget().setImageData(
+ [nxd.signal] + nxd.auxiliary_signals,
+ x_axis=x_axis, y_axis=y_axis,
+ signals_names=[nxd.signal_name] + nxd.auxiliary_signals_names,
+ xlabel=x_label, ylabel=y_label, title=nxd.title)
+
+ def getDataPriority(self, data, info):
+ data = self.normalizeData(data)
+ if not info.isComplex:
+ return DataView.UNSUPPORTED
+ if info.hasNXdata and not info.isInvalidNXdata:
+ if nxdata.get_default(data, validate=False).is_volume:
+ return 200
+
+ return DataView.UNSUPPORTED
+
+
class _NXdataView(CompositeDataView):
"""Composite view displaying NXdata groups using the most adequate
widget depending on the dimensionality."""
@@ -1382,5 +1815,17 @@ class _NXdataView(CompositeDataView):
self.addView(_NXdataScalarView(parent))
self.addView(_NXdataCurveView(parent))
self.addView(_NXdataXYVScatterView(parent))
+ self.addView(_NXdataComplexImageView(parent))
self.addView(_NXdataImageView(parent))
self.addView(_NXdataStackView(parent))
+
+ # The 3D view can be displayed using 2 ways
+ nx3dViews = SelectManyDataView(parent)
+ nx3dViews.addView(_NXdataVolumeAsStackView(parent))
+ nx3dViews.addView(_NXdataComplexVolumeAsStackView(parent))
+ try:
+ nx3dViews.addView(_NXdataVolumeView(parent))
+ except Exception:
+ _logger.warning("NXdataVolumeView is not available")
+ _logger.debug("Backtrace", exc_info=True)
+ self.addView(nx3dViews)
diff --git a/silx/gui/data/Hdf5TableView.py b/silx/gui/data/Hdf5TableView.py
index 9e28fbf..d7c33f3 100644
--- a/silx/gui/data/Hdf5TableView.py
+++ b/silx/gui/data/Hdf5TableView.py
@@ -30,12 +30,14 @@ from __future__ import division
__authors__ = ["V. Valls"]
__license__ = "MIT"
-__date__ = "05/07/2018"
+__date__ = "12/02/2019"
import collections
import functools
import os.path
import logging
+import h5py
+
from silx.gui import qt
import silx.io
from .TextFormatter import TextFormatter
@@ -44,11 +46,6 @@ from silx.gui.widgets import HierarchicalTableView
from ..hdf5.Hdf5Formatter import Hdf5Formatter
from ..hdf5._utils import htmlFromDict
-try:
- import h5py
-except ImportError:
- h5py = None
-
_logger = logging.getLogger(__name__)
@@ -198,11 +195,9 @@ class _CellFilterAvailableData(_CellData):
}
def __init__(self, filterId):
- import h5py.version
if h5py.version.hdf5_version_tuple >= (1, 10, 2):
# Previous versions only returns True if the filter was first used
# to decode a dataset
- import h5py.h5z
self.__availability = h5py.h5z.filter_avail(filterId)
else:
self.__availability = "na"
@@ -416,7 +411,7 @@ class Hdf5TableModel(HierarchicalTableView.HierarchicalTableModel):
self.__data.addHeaderRow(headerLabel="Data info")
- if h5py is not None and hasattr(obj, "id") and hasattr(obj.id, "get_type"):
+ if hasattr(obj, "id") and hasattr(obj.id, "get_type"):
# display the HDF5 type
self.__data.addHeaderValueRow("HDF5 type", self.__formatHdf5Type)
self.__data.addHeaderValueRow("dtype", self.__formatDType)
diff --git a/silx/gui/data/HexaTableView.py b/silx/gui/data/HexaTableView.py
index c86c0af..1617f0a 100644
--- a/silx/gui/data/HexaTableView.py
+++ b/silx/gui/data/HexaTableView.py
@@ -1,7 +1,7 @@
# coding: utf-8
# /*##########################################################################
#
-# Copyright (c) 2017 European Synchrotron Radiation Facility
+# Copyright (c) 2017-2018 European Synchrotron Radiation Facility
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
@@ -28,11 +28,13 @@ hexadecimal viewer.
"""
from __future__ import division
-import numpy
import collections
+
+import numpy
+import six
+
from silx.gui import qt
import silx.io.utils
-from silx.third_party import six
from silx.gui.widgets.TableWidget import CopySelectedCellsAction
__authors__ = ["V. Valls"]
diff --git a/silx/gui/data/NXdataWidgets.py b/silx/gui/data/NXdataWidgets.py
index f7c479d..e5a2550 100644
--- a/silx/gui/data/NXdataWidgets.py
+++ b/silx/gui/data/NXdataWidgets.py
@@ -1,7 +1,7 @@
# coding: utf-8
# /*##########################################################################
#
-# Copyright (c) 2017-2018 European Synchrotron Radiation Facility
+# Copyright (c) 2017-2019 European Synchrotron Radiation Facility
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
@@ -26,19 +26,25 @@
"""
__authors__ = ["P. Knobel"]
__license__ = "MIT"
-__date__ = "10/10/2018"
+__date__ = "12/11/2018"
+import logging
+import numbers
import numpy
from silx.gui import qt
from silx.gui.data.NumpyAxesSelector import NumpyAxesSelector
from silx.gui.plot import Plot1D, Plot2D, StackView, ScatterView
+from silx.gui.plot.ComplexImageView import ComplexImageView
from silx.gui.colors import Colormap
from silx.gui.widgets.FrameBrowser import HorizontalSliderWithBrowser
from silx.math.calibration import ArrayCalibration, NoCalibration, LinearCalibration
+_logger = logging.getLogger(__name__)
+
+
class ArrayCurvePlot(qt.QWidget):
"""
Widget for plotting a curve from a multi-dimensional signal array
@@ -72,21 +78,16 @@ class ArrayCurvePlot(qt.QWidget):
self._plot = Plot1D(self)
- self.selectorDock = qt.QDockWidget("Data selector", self._plot)
- # not closable
- self.selectorDock.setFeatures(qt.QDockWidget.DockWidgetMovable |
- qt.QDockWidget.DockWidgetFloatable)
- self._selector = NumpyAxesSelector(self.selectorDock)
+ self._selector = NumpyAxesSelector(self)
self._selector.setNamedAxesSelectorVisibility(False)
self.__selector_is_connected = False
- self.selectorDock.setWidget(self._selector)
- self._plot.addTabbedDockWidget(self.selectorDock)
self._plot.sigActiveCurveChanged.connect(self._setYLabelFromActiveLegend)
- layout = qt.QGridLayout()
+ layout = qt.QVBoxLayout()
layout.setContentsMargins(0, 0, 0, 0)
- layout.addWidget(self._plot, 0, 0)
+ layout.addWidget(self._plot)
+ layout.addWidget(self._selector)
self.setLayout(layout)
@@ -130,9 +131,9 @@ class ArrayCurvePlot(qt.QWidget):
self._selector.setAxisNames(["Y"])
if len(ys[0].shape) < 2:
- self.selectorDock.hide()
+ self._selector.hide()
else:
- self.selectorDock.show()
+ self._selector.show()
self._plot.setGraphTitle(title or "")
self._updateCurve()
@@ -182,6 +183,9 @@ class ArrayCurvePlot(qt.QWidget):
break
def clear(self):
+ old = self._selector.blockSignals(True)
+ self._selector.clear()
+ self._selector.blockSignals(old)
self._plot.clear()
@@ -339,11 +343,8 @@ class ArrayImagePlot(qt.QWidget):
normalization=Colormap.LINEAR))
self._plot.getIntensityHistogramAction().setVisible(True)
- self.selectorDock = qt.QDockWidget("Data selector", self._plot)
# not closable
- self.selectorDock.setFeatures(qt.QDockWidget.DockWidgetMovable |
- qt.QDockWidget.DockWidgetFloatable)
- self._selector = NumpyAxesSelector(self.selectorDock)
+ self._selector = NumpyAxesSelector(self)
self._selector.setNamedAxesSelectorVisibility(False)
self._selector.selectionChanged.connect(self._updateImage)
@@ -355,9 +356,8 @@ class ArrayImagePlot(qt.QWidget):
layout = qt.QVBoxLayout()
layout.addWidget(self._plot)
+ layout.addWidget(self._selector)
layout.addWidget(self._auxSigSlider)
- self.selectorDock.setWidget(self._selector)
- self._plot.addTabbedDockWidget(self.selectorDock)
self.setLayout(layout)
@@ -413,9 +413,9 @@ class ArrayImagePlot(qt.QWidget):
self._selector.setData(signals[0])
if len(signals[0].shape) <= img_ndim:
- self.selectorDock.hide()
+ self._selector.hide()
else:
- self.selectorDock.show()
+ self._selector.show()
self._auxSigSlider.setMaximum(len(signals) - 1)
if len(signals) > 1:
@@ -425,6 +425,7 @@ class ArrayImagePlot(qt.QWidget):
self._auxSigSlider.setValue(0)
self._updateImage()
+ self._plot.resetZoom()
self._selector.selectionChanged.connect(self._updateImage)
self._auxSigSlider.valueChanged.connect(self._sliderIdxChanged)
@@ -492,12 +493,202 @@ class ArrayImagePlot(qt.QWidget):
self._plot.setGraphTitle(title)
self._plot.getXAxis().setLabel(self.__x_axis_name)
self._plot.getYAxis().setLabel(self.__y_axis_name)
- self._plot.resetZoom()
def clear(self):
+ old = self._selector.blockSignals(True)
+ self._selector.clear()
+ self._selector.blockSignals(old)
self._plot.clear()
+class ArrayComplexImagePlot(qt.QWidget):
+ """
+ Widget for plotting an image of complex from a multi-dimensional signal array
+ and two 1D axes array.
+
+ The signal array can have an arbitrary number of dimensions, the only
+ limitation being that the last two dimensions must have the same length as
+ the axes arrays.
+
+ Sliders are provided to select indices on the first (n - 2) dimensions of
+ the signal array, and the plot is updated to show the image corresponding
+ to the selection.
+
+ If one or both of the axes does not have regularly spaced values, the
+ the image is plotted as a coloured scatter plot.
+ """
+ def __init__(self, parent=None, colormap=None):
+ """
+
+ :param parent: Parent QWidget
+ """
+ super(ArrayComplexImagePlot, self).__init__(parent)
+
+ self.__signals = None
+ self.__signals_names = None
+ self.__x_axis = None
+ self.__x_axis_name = None
+ self.__y_axis = None
+ self.__y_axis_name = None
+
+ self._plot = ComplexImageView(self)
+ if colormap is not None:
+ for mode in (ComplexImageView.Mode.ABSOLUTE,
+ ComplexImageView.Mode.SQUARE_AMPLITUDE,
+ ComplexImageView.Mode.REAL,
+ ComplexImageView.Mode.IMAGINARY):
+ self._plot.setColormap(colormap, mode)
+
+ self._plot.getPlot().getIntensityHistogramAction().setVisible(True)
+ self._plot.setKeepDataAspectRatio(True)
+
+ # not closable
+ self._selector = NumpyAxesSelector(self)
+ self._selector.setNamedAxesSelectorVisibility(False)
+ self._selector.selectionChanged.connect(self._updateImage)
+
+ self._auxSigSlider = HorizontalSliderWithBrowser(parent=self)
+ self._auxSigSlider.setMinimum(0)
+ self._auxSigSlider.setValue(0)
+ self._auxSigSlider.valueChanged[int].connect(self._sliderIdxChanged)
+ self._auxSigSlider.setToolTip("Select auxiliary signals")
+
+ layout = qt.QVBoxLayout()
+ layout.addWidget(self._plot)
+ layout.addWidget(self._selector)
+ layout.addWidget(self._auxSigSlider)
+
+ self.setLayout(layout)
+
+ def _sliderIdxChanged(self, value):
+ self._updateImage()
+
+ def getPlot(self):
+ """Returns the plot used for the display
+
+ :rtype: PlotWidget
+ """
+ return self._plot.getPlot()
+
+ def setImageData(self, signals,
+ x_axis=None, y_axis=None,
+ signals_names=None,
+ xlabel=None, ylabel=None,
+ title=None):
+ """
+
+ :param signals: list of n-D datasets, whose last 2 dimensions are used as the
+ image's values, or list of 3D datasets interpreted as RGBA image.
+ :param x_axis: 1-D dataset used as the image's x coordinates. If
+ provided, its lengths must be equal to the length of the last
+ dimension of ``signal``.
+ :param y_axis: 1-D dataset used as the image's y. If provided,
+ its lengths must be equal to the length of the 2nd to last
+ dimension of ``signal``.
+ :param signals_names: Names for each image, used as subtitle and legend.
+ :param xlabel: Label for X axis
+ :param ylabel: Label for Y axis
+ :param title: Graph title
+ """
+ self._selector.selectionChanged.disconnect(self._updateImage)
+ self._auxSigSlider.valueChanged.disconnect(self._sliderIdxChanged)
+
+ self.__signals = signals
+ self.__signals_names = signals_names
+ self.__x_axis = x_axis
+ self.__x_axis_name = xlabel
+ self.__y_axis = y_axis
+ self.__y_axis_name = ylabel
+ self.__title = title
+
+ self._selector.clear()
+ self._selector.setAxisNames(["Y", "X"])
+ self._selector.setData(signals[0])
+
+ if len(signals[0].shape) <= 2:
+ self._selector.hide()
+ else:
+ self._selector.show()
+
+ self._auxSigSlider.setMaximum(len(signals) - 1)
+ if len(signals) > 1:
+ self._auxSigSlider.show()
+ else:
+ self._auxSigSlider.hide()
+ self._auxSigSlider.setValue(0)
+
+ self._updateImage()
+ self._plot.getPlot().resetZoom()
+
+ self._selector.selectionChanged.connect(self._updateImage)
+ self._auxSigSlider.valueChanged.connect(self._sliderIdxChanged)
+
+ def _updateImage(self):
+ selection = self._selector.selection()
+ auxSigIdx = self._auxSigSlider.value()
+
+ images = [img[selection] for img in self.__signals]
+ image = images[auxSigIdx]
+
+ x_axis = self.__x_axis
+ y_axis = self.__y_axis
+
+ if x_axis is None and y_axis is None:
+ xcalib = NoCalibration()
+ ycalib = NoCalibration()
+ else:
+ if x_axis is None:
+ # no calibration
+ x_axis = numpy.arange(image.shape[1])
+ elif numpy.isscalar(x_axis) or len(x_axis) == 1:
+ # constant axis
+ x_axis = x_axis * numpy.ones((image.shape[1], ))
+ elif len(x_axis) == 2:
+ # linear calibration
+ x_axis = x_axis[0] * numpy.arange(image.shape[1]) + x_axis[1]
+
+ if y_axis is None:
+ y_axis = numpy.arange(image.shape[0])
+ elif numpy.isscalar(y_axis) or len(y_axis) == 1:
+ y_axis = y_axis * numpy.ones((image.shape[0], ))
+ elif len(y_axis) == 2:
+ y_axis = y_axis[0] * numpy.arange(image.shape[0]) + y_axis[1]
+
+ xcalib = ArrayCalibration(x_axis)
+ ycalib = ArrayCalibration(y_axis)
+
+ self._plot.setData(image)
+ if xcalib.is_affine():
+ xorigin, xscale = xcalib(0), xcalib.get_slope()
+ else:
+ _logger.warning("Unsupported complex image X axis calibration")
+ xorigin, xscale = 0., 1.
+
+ if ycalib.is_affine():
+ yorigin, yscale = ycalib(0), ycalib.get_slope()
+ else:
+ _logger.warning("Unsupported complex image Y axis calibration")
+ yorigin, yscale = 0., 1.
+
+ self._plot.setOrigin((xorigin, yorigin))
+ self._plot.setScale((xscale, yscale))
+
+ title = ""
+ if self.__title:
+ title += self.__title
+ if not title.strip().endswith(self.__signals_names[auxSigIdx]):
+ title += "\n" + self.__signals_names[auxSigIdx]
+ self._plot.setGraphTitle(title)
+ self._plot.getXAxis().setLabel(self.__x_axis_name)
+ self._plot.getYAxis().setLabel(self.__y_axis_name)
+
+ def clear(self):
+ old = self._selector.blockSignals(True)
+ self._selector.clear()
+ self._selector.blockSignals(old)
+ self._plot.setData(None)
+
+
class ArrayStackPlot(qt.QWidget):
"""
Widget for plotting a n-D array (n >= 3) as a stack of images.
@@ -665,4 +856,208 @@ class ArrayStackPlot(qt.QWidget):
self.__x_axis_name])
def clear(self):
+ old = self._selector.blockSignals(True)
+ self._selector.clear()
+ self._selector.blockSignals(old)
self._stack_view.clear()
+
+
+class ArrayVolumePlot(qt.QWidget):
+ """
+ Widget for plotting a n-D array (n >= 3) as a 3D scalar field.
+ Three axis arrays can be provided to calibrate the axes.
+
+ The signal array can have an arbitrary number of dimensions, the only
+ limitation being that the last 3 dimensions must have the same length as
+ the axes arrays.
+
+ Sliders are provided to select indices on the first (n - 3) dimensions of
+ the signal array, and the plot is updated to load the stack corresponding
+ to the selection.
+ """
+ def __init__(self, parent=None):
+ """
+
+ :param parent: Parent QWidget
+ """
+ super(ArrayVolumePlot, self).__init__(parent)
+
+ self.__signal = None
+ self.__signal_name = None
+ # the Z, Y, X axes apply to the last three dimensions of the signal
+ # (in that order)
+ self.__z_axis = None
+ self.__z_axis_name = None
+ self.__y_axis = None
+ self.__y_axis_name = None
+ self.__x_axis = None
+ self.__x_axis_name = None
+
+ from silx.gui.plot3d.ScalarFieldView import ScalarFieldView
+ from silx.gui.plot3d import SFViewParamTree
+
+ self._view = ScalarFieldView(self)
+
+ def computeIsolevel(data):
+ data = data[numpy.isfinite(data)]
+ if len(data) == 0:
+ return 0
+ else:
+ return numpy.mean(data) + numpy.std(data)
+
+ self._view.addIsosurface(computeIsolevel, '#FF0000FF')
+
+ # Create a parameter tree for the scalar field view
+ options = SFViewParamTree.TreeView(self._view)
+ options.setSfView(self._view)
+
+ # Add the parameter tree to the main window in a dock widget
+ dock = qt.QDockWidget()
+ dock.setWidget(options)
+ self._view.addDockWidget(qt.Qt.RightDockWidgetArea, dock)
+
+ self._hline = qt.QFrame(self)
+ self._hline.setFrameStyle(qt.QFrame.HLine)
+ self._hline.setFrameShadow(qt.QFrame.Sunken)
+ self._legend = qt.QLabel(self)
+ self._selector = NumpyAxesSelector(self)
+ self._selector.setNamedAxesSelectorVisibility(False)
+ self.__selector_is_connected = False
+
+ layout = qt.QVBoxLayout()
+ layout.addWidget(self._view)
+ layout.addWidget(self._hline)
+ layout.addWidget(self._legend)
+ layout.addWidget(self._selector)
+
+ self.setLayout(layout)
+
+ def getVolumeView(self):
+ """Returns the plot used for the display
+
+ :rtype: ScalarFieldView
+ """
+ return self._view
+
+ def normalizeComplexData(self, data):
+ """
+ Converts a complex data array to its amplitude, if necessary.
+ :param data: the data to normalize
+ :return:
+ """
+ if hasattr(data, "dtype"):
+ isComplex = numpy.issubdtype(data.dtype, numpy.complexfloating)
+ else:
+ isComplex = isinstance(data, numbers.Complex)
+ if isComplex:
+ data = numpy.absolute(data)
+ return data
+
+ def setData(self, signal,
+ x_axis=None, y_axis=None, z_axis=None,
+ signal_name=None,
+ xlabel=None, ylabel=None, zlabel=None,
+ title=None):
+ """
+
+ :param signal: n-D dataset, whose last 3 dimensions are used as the
+ 3D stack values.
+ :param x_axis: 1-D dataset used as the image's x coordinates. If
+ provided, its lengths must be equal to the length of the last
+ dimension of ``signal``.
+ :param y_axis: 1-D dataset used as the image's y. If provided,
+ its lengths must be equal to the length of the 2nd to last
+ dimension of ``signal``.
+ :param z_axis: 1-D dataset used as the image's z. If provided,
+ its lengths must be equal to the length of the 3rd to last
+ dimension of ``signal``.
+ :param signal_name: Label used in the legend
+ :param xlabel: Label for X axis
+ :param ylabel: Label for Y axis
+ :param zlabel: Label for Z axis
+ :param title: Graph title
+ """
+ signal = self.normalizeComplexData(signal)
+ if self.__selector_is_connected:
+ self._selector.selectionChanged.disconnect(self._updateVolume)
+ self.__selector_is_connected = False
+
+ self.__signal = signal
+ self.__signal_name = signal_name or ""
+ self.__x_axis = x_axis
+ self.__x_axis_name = xlabel
+ self.__y_axis = y_axis
+ self.__y_axis_name = ylabel
+ self.__z_axis = z_axis
+ self.__z_axis_name = zlabel
+
+ self._selector.setData(signal)
+ self._selector.setAxisNames(["Y", "X", "Z"])
+
+ self._view.setAxesLabels(self.__x_axis_name or 'X',
+ self.__y_axis_name or 'Y',
+ self.__z_axis_name or 'Z')
+ self._updateVolume()
+
+ # the legend label shows the selection slice producing the volume
+ # (only interesting for ndim > 3)
+ if signal.ndim > 3:
+ self._selector.setVisible(True)
+ self._legend.setVisible(True)
+ self._hline.setVisible(True)
+ else:
+ self._selector.setVisible(False)
+ self._legend.setVisible(False)
+ self._hline.setVisible(False)
+
+ if not self.__selector_is_connected:
+ self._selector.selectionChanged.connect(self._updateVolume)
+ self.__selector_is_connected = True
+
+ def _updateVolume(self):
+ """Update displayed stack according to the current axes selector
+ data."""
+ data = self._selector.selectedData()
+ x_axis = self.__x_axis
+ y_axis = self.__y_axis
+ z_axis = self.__z_axis
+
+ offset = []
+ scale = []
+ for axis in [x_axis, y_axis, z_axis]:
+ if axis is None:
+ calibration = NoCalibration()
+ elif len(axis) == 2:
+ calibration = LinearCalibration(
+ y_intercept=axis[0], slope=axis[1])
+ else:
+ calibration = ArrayCalibration(axis)
+ if not calibration.is_affine():
+ _logger.warning("Axis has not linear values, ignored")
+ offset.append(0.)
+ scale.append(1.)
+ else:
+ offset.append(calibration(0))
+ scale.append(calibration.get_slope())
+
+ legend = self.__signal_name + "["
+ for sl in self._selector.selection():
+ if sl == slice(None):
+ legend += ":, "
+ else:
+ legend += str(sl) + ", "
+ legend = legend[:-2] + "]"
+ self._legend.setText("Displayed data: " + legend)
+
+ self._view.setData(data, copy=False)
+ self._view.setScale(*scale)
+ self._view.setTranslation(*offset)
+ self._view.setAxesLabels(self.__x_axis_name,
+ self.__y_axis_name,
+ self.__z_axis_name)
+
+ def clear(self):
+ old = self._selector.blockSignals(True)
+ self._selector.clear()
+ self._selector.blockSignals(old)
+ self._view.setData(None)
diff --git a/silx/gui/data/TextFormatter.py b/silx/gui/data/TextFormatter.py
index 1401634..98c37d7 100644
--- a/silx/gui/data/TextFormatter.py
+++ b/silx/gui/data/TextFormatter.py
@@ -1,7 +1,7 @@
# coding: utf-8
# /*##########################################################################
#
-# Copyright (c) 2017 European Synchrotron Radiation Facility
+# Copyright (c) 2017-2018 European Synchrotron Radiation Facility
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
@@ -29,16 +29,15 @@ __authors__ = ["V. Valls"]
__license__ = "MIT"
__date__ = "24/07/2018"
-import numpy
+import logging
import numbers
-from silx.third_party import six
+
+import numpy
+import six
+
from silx.gui import qt
-import logging
-try:
- import h5py
-except ImportError:
- h5py = None
+import h5py
_logger = logging.getLogger(__name__)
@@ -322,10 +321,9 @@ class TextFormatter(qt.QObject):
if dtype.kind == 'S':
return self.__formatCharString(data)
elif dtype.kind == 'O':
- if h5py is not None:
- text = self.__formatH5pyObject(data, dtype)
- if text is not None:
- return text
+ text = self.__formatH5pyObject(data, dtype)
+ if text is not None:
+ return text
try:
# Try ascii/utf-8
text = "%s" % data.decode("utf-8")
@@ -339,15 +337,14 @@ class TextFormatter(qt.QObject):
elif isinstance(data, (numpy.integer)):
if dtype is None:
dtype = data.dtype
- if h5py is not None:
- enumType = h5py.check_dtype(enum=dtype)
- if enumType is not None:
- for key, value in enumType.items():
- if value == data:
- result = {}
- result["name"] = key
- result["value"] = data
- return self.__enumFormat % result
+ enumType = h5py.check_dtype(enum=dtype)
+ if enumType is not None:
+ for key, value in enumType.items():
+ if value == data:
+ result = {}
+ result["name"] = key
+ result["value"] = data
+ return self.__enumFormat % result
return self.__integerFormat % data
elif isinstance(data, (numbers.Integral)):
return self.__integerFormat % data
@@ -373,21 +370,20 @@ class TextFormatter(qt.QObject):
template = self.__floatFormat
params = (data.real)
return template % params
- elif h5py is not None and isinstance(data, h5py.h5r.Reference):
+ elif isinstance(data, h5py.h5r.Reference):
dtype = h5py.special_dtype(ref=h5py.Reference)
text = self.__formatH5pyObject(data, dtype)
return text
- elif h5py is not None and isinstance(data, h5py.h5r.RegionReference):
+ elif isinstance(data, h5py.h5r.RegionReference):
dtype = h5py.special_dtype(ref=h5py.RegionReference)
text = self.__formatH5pyObject(data, dtype)
return text
elif isinstance(data, numpy.object_) or dtype is not None:
if dtype is None:
dtype = data.dtype
- if h5py is not None:
- text = self.__formatH5pyObject(data, dtype)
- if text is not None:
- return text
+ text = self.__formatH5pyObject(data, dtype)
+ if text is not None:
+ return text
# That's a numpy object
return str(data)
return str(data)
diff --git a/silx/gui/data/test/test_arraywidget.py b/silx/gui/data/test/test_arraywidget.py
index 50ffc84..6bcbbd3 100644
--- a/silx/gui/data/test/test_arraywidget.py
+++ b/silx/gui/data/test/test_arraywidget.py
@@ -36,10 +36,7 @@ from silx.gui import qt
from silx.gui.data import ArrayTableWidget
from silx.gui.utils.testutils import TestCaseQt
-try:
- import h5py
-except ImportError:
- h5py = None
+import h5py
class TestArrayWidget(TestCaseQt):
@@ -190,7 +187,6 @@ class TestArrayWidget(TestCaseQt):
self.assertIs(b0, b1)
-@unittest.skipIf(h5py is None, "Could not import h5py")
class TestH5pyArrayWidget(TestCaseQt):
"""Basic test for ArrayTableWidget with a dataset.
diff --git a/silx/gui/data/test/test_dataviewer.py b/silx/gui/data/test/test_dataviewer.py
index a681f33..dc6fee8 100644
--- a/silx/gui/data/test/test_dataviewer.py
+++ b/silx/gui/data/test/test_dataviewer.py
@@ -24,7 +24,7 @@
# ###########################################################################*/
__authors__ = ["V. Valls"]
__license__ = "MIT"
-__date__ = "23/04/2018"
+__date__ = "19/02/2019"
import os
import tempfile
@@ -42,10 +42,7 @@ from silx.gui.data.DataViewerFrame import DataViewerFrame
from silx.gui.utils.testutils import SignalListener
from silx.gui.utils.testutils import TestCaseQt
-try:
- import h5py
-except ImportError:
- h5py = None
+import h5py
class _DataViewMock(DataView):
@@ -170,8 +167,6 @@ class AbstractDataViewerTests(TestCaseQt):
self.assertEqual(DataViews.RAW_MODE, widget.displayedView().modeId())
def test_3d_h5_dataset(self):
- if h5py is None:
- self.skipTest("h5py library is not available")
with self.h5_temporary_file() as h5file:
dataset = h5file["data"]
widget = self.create_widget()
@@ -242,8 +237,9 @@ class AbstractDataViewerTests(TestCaseQt):
# replace a view that is a child of a composite view
widget = self.create_widget()
view = _DataViewMock(widget)
- widget.replaceView(DataViews.NXDATA_INVALID_MODE,
- view)
+ replaced = widget.replaceView(DataViews.NXDATA_INVALID_MODE,
+ view)
+ self.assertTrue(replaced)
nxdata_view = widget.getViewFromModeId(DataViews.NXDATA_MODE)
self.assertNotIn(DataViews.NXDATA_INVALID_MODE,
[v.modeId() for v in nxdata_view.availableViews()])
diff --git a/silx/gui/data/test/test_numpyaxesselector.py b/silx/gui/data/test/test_numpyaxesselector.py
index 6b7b58c..df11c1a 100644
--- a/silx/gui/data/test/test_numpyaxesselector.py
+++ b/silx/gui/data/test/test_numpyaxesselector.py
@@ -37,10 +37,7 @@ from silx.gui.data.NumpyAxesSelector import NumpyAxesSelector
from silx.gui.utils.testutils import SignalListener
from silx.gui.utils.testutils import TestCaseQt
-try:
- import h5py
-except ImportError:
- h5py = None
+import h5py
class TestNumpyAxesSelector(TestCaseQt):
@@ -121,8 +118,6 @@ class TestNumpyAxesSelector(TestCaseQt):
os.unlink(tmp_name)
def test_h5py_dataset(self):
- if h5py is None:
- self.skipTest("h5py library is not available")
with self.h5_temporary_file() as h5file:
dataset = h5file["data"]
expectedResult = dataset[0]
diff --git a/silx/gui/data/test/test_textformatter.py b/silx/gui/data/test/test_textformatter.py
index 850aa00..935344a 100644
--- a/silx/gui/data/test/test_textformatter.py
+++ b/silx/gui/data/test/test_textformatter.py
@@ -1,7 +1,7 @@
# coding: utf-8
# /*##########################################################################
#
-# Copyright (c) 2016-2017 European Synchrotron Radiation Facility
+# Copyright (c) 2016-2018 European Synchrotron Radiation Facility
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
@@ -29,17 +29,15 @@ __date__ = "12/12/2017"
import unittest
import shutil
import tempfile
+
import numpy
+import six
from silx.gui.utils.testutils import TestCaseQt
from silx.gui.utils.testutils import SignalListener
from ..TextFormatter import TextFormatter
-from silx.third_party import six
-try:
- import h5py
-except ImportError:
- h5py = None
+import h5py
class TestTextFormatter(TestCaseQt):
@@ -108,8 +106,6 @@ class TestTextFormatterWithH5py(TestCaseQt):
@classmethod
def setUpClass(cls):
super(TestTextFormatterWithH5py, cls).setUpClass()
- if h5py is None:
- raise unittest.SkipTest("h5py is not available")
cls.tmpDirectory = tempfile.mkdtemp()
cls.h5File = h5py.File("%s/formatter.h5" % cls.tmpDirectory, mode="w")
diff --git a/silx/gui/dialog/AbstractDataFileDialog.py b/silx/gui/dialog/AbstractDataFileDialog.py
index 40045fe..c660cd7 100644
--- a/silx/gui/dialog/AbstractDataFileDialog.py
+++ b/silx/gui/dialog/AbstractDataFileDialog.py
@@ -1,7 +1,7 @@
# coding: utf-8
# /*##########################################################################
#
-# Copyright (c) 2016 European Synchrotron Radiation Facility
+# Copyright (c) 2016-2018 European Synchrotron Radiation Facility
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
@@ -28,29 +28,36 @@ This module contains an :class:`AbstractDataFileDialog`.
__authors__ = ["V. Valls"]
__license__ = "MIT"
-__date__ = "05/03/2018"
+__date__ = "03/12/2018"
import sys
import os
import logging
-import numpy
import functools
+from distutils.version import LooseVersion
+
+import numpy
+import six
+
import silx.io.url
from silx.gui import qt
from silx.gui.hdf5.Hdf5TreeModel import Hdf5TreeModel
from . import utils
-from silx.third_party import six
from .FileTypeComboBox import FileTypeComboBox
-try:
- import fabio
-except ImportError:
- fabio = None
+
+import fabio
_logger = logging.getLogger(__name__)
+DEFAULT_SIDEBAR_URL = True
+"""Set it to false to disable initilializing of the sidebar urls with the
+default Qt list. This could allow to disable a behaviour known to segfault on
+some version of PyQt."""
+
+
class _IconProvider(object):
FileDialogToParentDir = qt.QStyle.SP_CustomBase + 1
@@ -143,14 +150,22 @@ class _SideBar(qt.QListView):
:rtype: List[str]
"""
urls = []
- if qt.qVersion().startswith("5.") and sys.platform in ["linux", "linux2"]:
+ version = LooseVersion(qt.qVersion())
+ feed_sidebar = True
+
+ if not DEFAULT_SIDEBAR_URL:
+ _logger.debug("Skip default sidebar URLs (from setted variable)")
+ feed_sidebar = False
+ elif version.version[0] == 4 and sys.platform in ["win32"]:
+ # Avoid locking the GUI 5min in case of use of network driver
+ _logger.debug("Skip default sidebar URLs (avoid lock when using network drivers)")
+ feed_sidebar = False
+ elif version < LooseVersion("5.11.2") and qt.BINDING == "PyQt5" and sys.platform in ["linux", "linux2"]:
# Avoid segfault on PyQt5 + gtk
_logger.debug("Skip default sidebar URLs (avoid PyQt5 segfault)")
- pass
- elif qt.qVersion().startswith("4.") and sys.platform in ["win32"]:
- # Avoid 5min of locked GUI relative to network driver
- _logger.debug("Skip default sidebar URLs (avoid lock when using network drivers)")
- else:
+ feed_sidebar = False
+
+ if feed_sidebar:
# Get default shortcut
# There is no other way
d = qt.QFileDialog(self)
@@ -1061,8 +1076,6 @@ class AbstractDataFileDialog(qt.QDialog):
def __openFabioFile(self, filename):
self.__closeFile()
try:
- if fabio is None:
- raise ImportError("Fabio module is not available")
self.__fabio = fabio.open(filename)
self.__openedFiles.append(self.__fabio)
self.__selectedFile = filename
@@ -1108,10 +1121,10 @@ class AbstractDataFileDialog(qt.QDialog):
if codec.is_autodetect():
if self.__isSilxHavePriority(filename):
openners.append(self.__openSilxFile)
- if fabio is not None and self._isFabioFilesSupported():
+ if self._isFabioFilesSupported():
openners.append(self.__openFabioFile)
else:
- if fabio is not None and self._isFabioFilesSupported():
+ if self._isFabioFilesSupported():
openners.append(self.__openFabioFile)
openners.append(self.__openSilxFile)
elif codec.is_silx_codec():
@@ -1159,10 +1172,9 @@ class AbstractDataFileDialog(qt.QDialog):
is_fabio_have_priority = not codec.is_silx_codec() and not self.__isSilxHavePriority(path)
if is_fabio_decoder or is_fabio_have_priority:
# Then it's flat frame container
- if fabio is not None:
- self.__openFabioFile(path)
- if self.__fabio is not None:
- selectedData = _FabioData(self.__fabio)
+ self.__openFabioFile(path)
+ if self.__fabio is not None:
+ selectedData = _FabioData(self.__fabio)
else:
assert(False)
diff --git a/silx/gui/dialog/ColormapDialog.py b/silx/gui/dialog/ColormapDialog.py
index cbbfa5a..9950ad4 100644
--- a/silx/gui/dialog/ColormapDialog.py
+++ b/silx/gui/dialog/ColormapDialog.py
@@ -63,9 +63,10 @@ from __future__ import division
__authors__ = ["V.A. Sole", "T. Vincent", "H. Payno"]
__license__ = "MIT"
-__date__ = "23/05/2018"
+__date__ = "27/11/2018"
+import enum
import logging
import numpy
@@ -73,10 +74,10 @@ import numpy
from .. import qt
from ..colors import Colormap, preferredColormaps
from ..plot import PlotWidget
+from ..plot.items.axis import Axis
from silx.gui.widgets.FloatEdit import FloatEdit
import weakref
from silx.math.combo import min_max
-from silx.third_party import enum
from silx.gui import icons
from silx.math.histogram import Histogramnd
@@ -154,39 +155,59 @@ class _ColormapNameCombox(qt.QComboBox):
qt.QComboBox.__init__(self, parent)
self.__initItems()
- ORIGINAL_NAME = qt.Qt.UserRole + 1
+ LUT_NAME = qt.Qt.UserRole + 1
+ LUT_COLORS = qt.Qt.UserRole + 2
def __initItems(self):
for colormapName in preferredColormaps():
index = self.count()
self.addItem(str.title(colormapName))
- self.setItemIcon(index, self.getIconPreview(colormapName))
- self.setItemData(index, colormapName, role=self.ORIGINAL_NAME)
+ self.setItemIcon(index, self.getIconPreview(name=colormapName))
+ self.setItemData(index, colormapName, role=self.LUT_NAME)
- def getIconPreview(self, colormapName):
+ def getIconPreview(self, name=None, colors=None):
"""Return an icon preview from a LUT name.
This icons are cached into a global structure.
- :param str colormapName: str
+ :param str name: Name of the LUT
+ :param numpy.ndarray colors: Colors identify the LUT
:rtype: qt.QIcon
"""
- if colormapName not in _colormapIconPreview:
- icon = self.createIconPreview(colormapName)
- _colormapIconPreview[colormapName] = icon
- return _colormapIconPreview[colormapName]
-
- def createIconPreview(self, colormapName):
+ if name is not None:
+ iconKey = name
+ else:
+ iconKey = tuple(colors)
+ icon = _colormapIconPreview.get(iconKey, None)
+ if icon is None:
+ icon = self.createIconPreview(name, colors)
+ _colormapIconPreview[iconKey] = icon
+ return icon
+
+ def createIconPreview(self, name=None, colors=None):
"""Create and return an icon preview from a LUT name.
This icons are cached into a global structure.
- :param str colormapName: Name of the LUT
+ :param str name: Name of the LUT
+ :param numpy.ndarray colors: Colors identify the LUT
:rtype: qt.QIcon
"""
- colormap = Colormap(colormapName)
+ colormap = Colormap(name)
size = 32
- lut = colormap.getNColors(size)
+ if name is not None:
+ lut = colormap.getNColors(size)
+ else:
+ lut = colors
+ if len(lut) > size:
+ # Down sample
+ step = int(len(lut) / size)
+ lut = lut[::step]
+ elif len(lut) < size:
+ # Over sample
+ indexes = numpy.arange(size) / float(size) * (len(lut) - 1)
+ indexes = indexes.astype("int")
+ lut = lut[indexes]
if lut is None or len(lut) == 0:
return qt.QIcon()
@@ -204,18 +225,50 @@ class _ColormapNameCombox(qt.QComboBox):
return qt.QIcon(pixmap)
def getCurrentName(self):
- return self.itemData(self.currentIndex(), self.ORIGINAL_NAME)
+ return self.itemData(self.currentIndex(), self.LUT_NAME)
+
+ def getCurrentColors(self):
+ return self.itemData(self.currentIndex(), self.LUT_COLORS)
+
+ def findLutName(self, name):
+ return self.findData(name, role=self.LUT_NAME)
+
+ def findLutColors(self, lut):
+ for index in range(self.count()):
+ if self.itemData(index, role=self.LUT_NAME) is not None:
+ continue
+ colors = self.itemData(index, role=self.LUT_COLORS)
+ if colors is None:
+ continue
+ if numpy.array_equal(colors, lut):
+ return index
+ return -1
+
+ def setCurrentLut(self, colormap):
+ name = colormap.getName()
+ if name is not None:
+ self._setCurrentName(name)
+ else:
+ lut = colormap.getColormapLUT()
+ self._setCurrentLut(lut)
- def findColormap(self, name):
- return self.findData(name, role=self.ORIGINAL_NAME)
+ def _setCurrentLut(self, lut):
+ index = self.findLutColors(lut)
+ if index == -1:
+ index = self.count()
+ self.addItem("Custom")
+ self.setItemIcon(index, self.getIconPreview(colors=lut))
+ self.setItemData(index, None, role=self.LUT_NAME)
+ self.setItemData(index, lut, role=self.LUT_COLORS)
+ self.setCurrentIndex(index)
- def setCurrentName(self, name):
- index = self.findColormap(name)
+ def _setCurrentName(self, name):
+ index = self.findLutName(name)
if index < 0:
index = self.count()
self.addItem(str.title(name))
- self.setItemIcon(index, self.getIconPreview(name))
- self.setItemData(index, name, role=self.ORIGINAL_NAME)
+ self.setItemIcon(index, self.getIconPreview(name=name))
+ self.setItemData(index, name, role=self.LUT_NAME)
self.setCurrentIndex(index)
@@ -255,6 +308,7 @@ class ColormapDialog(qt.QDialog):
the self.setcolormap is a callback)
"""
+ self.__displayInvalidated = False
self._histogramData = None
self._minMaxWasEdited = False
self._initialRange = None
@@ -276,20 +330,19 @@ class ColormapDialog(qt.QDialog):
# Colormap row
self._comboBoxColormap = _ColormapNameCombox(parent=formWidget)
- self._comboBoxColormap.currentIndexChanged[int].connect(self._updateName)
+ self._comboBoxColormap.currentIndexChanged[int].connect(self._updateLut)
formLayout.addRow('Colormap:', self._comboBoxColormap)
# Normalization row
self._normButtonLinear = qt.QRadioButton('Linear')
self._normButtonLinear.setChecked(True)
self._normButtonLog = qt.QRadioButton('Log')
- self._normButtonLog.toggled.connect(self._activeLogNorm)
normButtonGroup = qt.QButtonGroup(self)
normButtonGroup.setExclusive(True)
normButtonGroup.addButton(self._normButtonLinear)
normButtonGroup.addButton(self._normButtonLog)
- self._normButtonLinear.toggled[bool].connect(self._updateLinearNorm)
+ normButtonGroup.buttonClicked[qt.QAbstractButton].connect(self._updateNormalization)
normLayout = qt.QHBoxLayout()
normLayout.setContentsMargins(0, 0, 0, 0)
@@ -388,9 +441,17 @@ class ColormapDialog(qt.QDialog):
self.setFixedSize(self.sizeHint())
self._applyColormap()
+ def _displayLater(self):
+ self.__displayInvalidated = True
+
def showEvent(self, event):
self.visibleChanged.emit(True)
super(ColormapDialog, self).showEvent(event)
+ if self.isVisible():
+ if self.__displayInvalidated:
+ self._applyColormap()
+ self._updateDataInPlot()
+ self.__displayInvalidated = False
def closeEvent(self, event):
if not self.isModal():
@@ -434,6 +495,54 @@ class ColormapDialog(qt.QDialog):
def sizeHint(self):
return self.layout().minimumSize()
+ def _computeView(self, dataMin, dataMax):
+ """Compute the location of the view according to the bound of the data
+
+ :rtype: Tuple(float, float)
+ """
+ marginRatio = 1.0 / 6.0
+ scale = self._plot.getXAxis().getScale()
+
+ if self._dataRange is not None:
+ if scale == Axis.LOGARITHMIC:
+ minRange = self._dataRange[1]
+ else:
+ minRange = self._dataRange[0]
+ maxRange = self._dataRange[2]
+ if minRange is not None:
+ dataMin = min(dataMin, minRange)
+ dataMax = max(dataMax, maxRange)
+
+ if self._histogramData is not None:
+ info = min_max(self._histogramData[1])
+ if scale == Axis.LOGARITHMIC:
+ minHisto = info.min_positive
+ else:
+ minHisto = info.minimum
+ maxHisto = info.maximum
+ if minHisto is not None:
+ dataMin = min(dataMin, minHisto)
+ dataMax = max(dataMax, maxHisto)
+
+ if scale == Axis.LOGARITHMIC:
+ epsilon = numpy.finfo(numpy.float32).eps
+ if dataMin == 0:
+ dataMin = epsilon
+ if dataMax < dataMin:
+ dataMax = dataMin + epsilon
+ marge = marginRatio * abs(numpy.log10(dataMax) - numpy.log10(dataMin))
+ viewMin = 10**(numpy.log10(dataMin) - marge)
+ viewMax = 10**(numpy.log10(dataMax) + marge)
+ else: # scale == Axis.LINEAR:
+ marge = marginRatio * abs(dataMax - dataMin)
+ if marge < 0.0001:
+ # Smaller that the QLineEdit precision
+ marge = 0.0001
+ viewMin = dataMin - marge
+ viewMax = dataMax + marge
+
+ return viewMin, viewMax
+
def _plotUpdate(self, updateMarkers=True):
"""Update the plot content
@@ -454,27 +563,8 @@ class ColormapDialog(qt.QDialog):
if minData > maxData:
# avoid a full collapse
minData, maxData = maxData, minData
- minimum = minData
- maximum = maxData
-
- if self._dataRange is not None:
- minRange = self._dataRange[0]
- maxRange = self._dataRange[2]
- minimum = min(minimum, minRange)
- maximum = max(maximum, maxRange)
- if self._histogramData is not None:
- minHisto = self._histogramData[1][0]
- maxHisto = self._histogramData[1][-1]
- minimum = min(minimum, minHisto)
- maximum = max(maximum, maxHisto)
-
- marge = abs(maximum - minimum) / 6.0
- if marge < 0.0001:
- # Smaller that the QLineEdit precision
- marge = 0.0001
-
- minView, maxView = minimum - marge, maximum + marge
+ minView, maxView = self._computeView(minData, maxData)
if updateMarkers:
# Save the state in we are not moving the markers
@@ -483,6 +573,9 @@ class ColormapDialog(qt.QDialog):
minView = min(minView, self._initialRange[0])
maxView = max(maxView, self._initialRange[1])
+ if minView > minData:
+ # Hide the min range
+ minData = minView
x = [minView, minData, maxData, maxView]
y = [0, 0, 1, 1]
@@ -493,26 +586,37 @@ class ColormapDialog(qt.QDialog):
linestyle='-',
resetzoom=False)
+ scale = self._plot.getXAxis().getScale()
+
if updateMarkers:
- minDraggable = (self._colormap().isEditable() and
- not self._minValue.isAutoChecked())
- self._plot.addXMarker(
- self._minValue.getFiniteValue(),
- legend='Min',
- text='Min',
- draggable=minDraggable,
- color='blue',
- constraint=self._plotMinMarkerConstraint)
-
- maxDraggable = (self._colormap().isEditable() and
- not self._maxValue.isAutoChecked())
- self._plot.addXMarker(
- self._maxValue.getFiniteValue(),
- legend='Max',
- text='Max',
- draggable=maxDraggable,
- color='blue',
- constraint=self._plotMaxMarkerConstraint)
+ posMin = self._minValue.getFiniteValue()
+ posMax = self._maxValue.getFiniteValue()
+
+ def isDisplayable(pos):
+ if scale == Axis.LOGARITHMIC:
+ return pos > 0.0
+ return True
+
+ if isDisplayable(posMin):
+ minDraggable = (self._colormap().isEditable() and
+ not self._minValue.isAutoChecked())
+ self._plot.addXMarker(
+ posMin,
+ legend='Min',
+ text='Min',
+ draggable=minDraggable,
+ color='blue',
+ constraint=self._plotMinMarkerConstraint)
+ if isDisplayable(posMax):
+ maxDraggable = (self._colormap().isEditable() and
+ not self._maxValue.isAutoChecked())
+ self._plot.addXMarker(
+ posMax,
+ legend='Max',
+ text='Max',
+ draggable=maxDraggable,
+ color='blue',
+ constraint=self._plotMaxMarkerConstraint)
self._plot.resetZoom()
@@ -546,7 +650,7 @@ class ColormapDialog(qt.QDialog):
"""Compute the data range as used by :meth:`setDataRange`.
:param data: The data to process
- :rtype: Tuple(float, float, float)
+ :rtype: List[Union[None,float]]
"""
if data is None or len(data) == 0:
return None, None, None
@@ -558,8 +662,6 @@ class ColormapDialog(qt.QDialog):
if dataRange is not None:
min_positive = dataRange.min_positive
- if min_positive is None:
- min_positive = float('nan')
dataRange = dataRange.minimum, min_positive, dataRange.maximum
if dataRange is None or len(dataRange) != 3:
@@ -571,7 +673,7 @@ class ColormapDialog(qt.QDialog):
return dataRange
@staticmethod
- def computeHistogram(data):
+ def computeHistogram(data, scale=Axis.LINEAR):
"""Compute the data histogram as used by :meth:`setHistogram`.
:param data: The data to process
@@ -588,7 +690,12 @@ class ColormapDialog(qt.QDialog):
if len(_data) == 0:
return None, None
+ if scale == Axis.LOGARITHMIC:
+ _data = numpy.log10(_data)
xmin, xmax = min_max(_data, min_positive=False, finite=True)
+ if xmin is None:
+ return None, None
+
nbins = min(256, int(numpy.sqrt(_data.size)))
data_range = xmin, xmax
@@ -601,7 +708,10 @@ class ColormapDialog(qt.QDialog):
_data = _data.ravel().astype(numpy.float32)
histogram = Histogramnd(_data, n_bins=nbins, histo_range=data_range)
- return histogram.histo, histogram.edges[0]
+ bins = histogram.edges[0]
+ if scale == Axis.LOGARITHMIC:
+ bins = 10**bins
+ return histogram.histo, bins
def _getData(self):
if self._data is None:
@@ -624,7 +734,10 @@ class ColormapDialog(qt.QDialog):
else:
self._data = weakref.ref(data, self._dataAboutToFinalize)
- self._updateDataInPlot()
+ if self.isVisible():
+ self._updateDataInPlot()
+ else:
+ self._displayLater()
def _setDataInPlotMode(self, mode):
if self._dataInPlotMode == mode:
@@ -660,10 +773,15 @@ class ColormapDialog(qt.QDialog):
self.setDataRange(*result)
elif mode == _DataInPlotMode.HISTOGRAM:
# The histogram should be done in a worker thread
- result = self.computeHistogram(data)
+ result = self.computeHistogram(data, scale=self._plot.getXAxis().getScale())
self.setHistogram(*result)
self.setDataRange()
+ def _invalidateHistogram(self):
+ """Recompute the histogram if it is displayed"""
+ if self._dataInPlotMode == _DataInPlotMode.HISTOGRAM:
+ self._updateDataInPlot()
+
def _colormapAboutToFinalize(self, weakrefColormap):
"""Callback when the data weakref is about to be finalized."""
if self._colormap is weakrefColormap:
@@ -727,9 +845,9 @@ class ColormapDialog(qt.QDialog):
"""
colormap = self.getColormap()
if colormap is not None and self._colormapStoredState is not None:
- if self._colormap()._toDict() != self._colormapStoredState:
+ if colormap != self._colormapStoredState:
self._ignoreColormapChange = True
- colormap._setFromDict(self._colormapStoredState)
+ colormap.setFromColormap(self._colormapStoredState)
self._ignoreColormapChange = False
self._applyColormap()
@@ -740,12 +858,18 @@ class ColormapDialog(qt.QDialog):
:param float positiveMin: The positive minimum of the data
:param float maximum: The maximum of the data
"""
- if minimum is None or positiveMin is None or maximum is None:
+ scale = self._plot.getXAxis().getScale()
+ if scale == Axis.LOGARITHMIC:
+ dataMin, dataMax = positiveMin, maximum
+ else:
+ dataMin, dataMax = minimum, maximum
+
+ if dataMin is None or dataMax is None:
self._dataRange = None
self._plot.remove(legend='Range', kind='histogram')
else:
hist = numpy.array([1])
- bin_edges = numpy.array([minimum, maximum])
+ bin_edges = numpy.array([dataMin, dataMax])
self._plot.addHistogram(hist,
bin_edges,
legend="Range",
@@ -801,7 +925,7 @@ class ColormapDialog(qt.QDialog):
"""
colormap = self.getColormap()
if colormap is not None:
- self._colormapStoredState = colormap._toDict()
+ self._colormapStoredState = colormap.copy()
else:
self._colormapStoredState = None
@@ -830,8 +954,11 @@ class ColormapDialog(qt.QDialog):
self._colormap = colormap
self.storeCurrentState()
- self._updateResetButton()
- self._applyColormap()
+ if self.isVisible():
+ self._applyColormap()
+ else:
+ self._updateResetButton()
+ self._displayLater()
def _updateResetButton(self):
resetButton = self._buttonsNonModal.button(qt.QDialogButtonBox.Reset)
@@ -839,7 +966,7 @@ class ColormapDialog(qt.QDialog):
colormap = self.getColormap()
if colormap is not None and colormap.isEditable():
# can reset only in the case the colormap changed
- rStateEnabled = colormap._toDict() != self._colormapStoredState
+ rStateEnabled = colormap != self._colormapStoredState
resetButton.setEnabled(rStateEnabled)
def _applyColormap(self):
@@ -856,12 +983,8 @@ class ColormapDialog(qt.QDialog):
self._maxValue.setEnabled(False)
else:
self._ignoreColormapChange = True
-
- if colormap.getName() is not None:
- name = colormap.getName()
- self._comboBoxColormap.setCurrentName(name)
- self._comboBoxColormap.setEnabled(self._colormap().isEditable())
-
+ self._comboBoxColormap.setCurrentLut(colormap)
+ self._comboBoxColormap.setEnabled(colormap.isEditable())
assert colormap.getNormalization() in Colormap.NORMALIZATIONS
self._normButtonLinear.setChecked(
colormap.getNormalization() == Colormap.LINEAR)
@@ -870,12 +993,17 @@ class ColormapDialog(qt.QDialog):
vmin = colormap.getVMin()
vmax = colormap.getVMax()
dataRange = colormap.getColormapRange()
- self._normButtonLinear.setEnabled(self._colormap().isEditable())
- self._normButtonLog.setEnabled(self._colormap().isEditable())
+ self._normButtonLinear.setEnabled(colormap.isEditable())
+ self._normButtonLog.setEnabled(colormap.isEditable())
self._minValue.setValue(vmin or dataRange[0], isAuto=vmin is None)
self._maxValue.setValue(vmax or dataRange[1], isAuto=vmax is None)
- self._minValue.setEnabled(self._colormap().isEditable())
- self._maxValue.setEnabled(self._colormap().isEditable())
+ self._minValue.setEnabled(colormap.isEditable())
+ self._maxValue.setEnabled(colormap.isEditable())
+
+ axis = self._plot.getXAxis()
+ scale = axis.LINEAR if colormap.getNormalization() == Colormap.LINEAR else axis.LOGARITHMIC
+ axis.setScale(scale)
+
self._ignoreColormapChange = False
self._plotUpdate()
@@ -908,26 +1036,47 @@ class ColormapDialog(qt.QDialog):
self._plotUpdate()
self._updateResetButton()
- def _updateName(self):
+ def _updateLut(self):
if self._ignoreColormapChange is True:
return
- if self._colormap():
+ colormap = self._colormap()
+ if colormap is not None:
self._ignoreColormapChange = True
- self._colormap().setName(
- self._comboBoxColormap.getCurrentName())
+ name = self._comboBoxColormap.getCurrentName()
+ if name is not None:
+ colormap.setName(name)
+ else:
+ lut = self._comboBoxColormap.getCurrentColors()
+ colormap.setColormapLUT(lut)
self._ignoreColormapChange = False
- def _updateLinearNorm(self, isNormLinear):
+ def _updateNormalization(self, button):
if self._ignoreColormapChange is True:
return
+ if not button.isChecked():
+ return
+
+ if button is self._normButtonLinear:
+ norm = Colormap.LINEAR
+ scale = Axis.LINEAR
+ elif button is self._normButtonLog:
+ norm = Colormap.LOGARITHM
+ scale = Axis.LOGARITHMIC
+ else:
+ assert(False)
- if self._colormap():
+ colormap = self.getColormap()
+ if colormap is not None:
self._ignoreColormapChange = True
- norm = Colormap.LINEAR if isNormLinear else Colormap.LOGARITHM
- self._colormap().setNormalization(norm)
+ colormap.setNormalization(norm)
+ axis = self._plot.getXAxis()
+ axis.setScale(scale)
self._ignoreColormapChange = False
+ self._invalidateHistogram()
+ self._updateMinMaxData()
+
def _minMaxTextEdited(self, text):
"""Handle _minValue and _maxValue textEdited signal"""
self._minMaxWasEdited = True
@@ -975,13 +1124,3 @@ class ColormapDialog(qt.QDialog):
else:
# Use QDialog keyPressEvent
super(ColormapDialog, self).keyPressEvent(event)
-
- def _activeLogNorm(self, isLog):
- if self._ignoreColormapChange is True:
- return
- if self._colormap():
- self._ignoreColormapChange = True
- norm = Colormap.LOGARITHM if isLog is True else Colormap.LINEAR
- self._colormap().setNormalization(norm)
- self._ignoreColormapChange = False
- self._updateMinMaxData()
diff --git a/silx/gui/dialog/DataFileDialog.py b/silx/gui/dialog/DataFileDialog.py
index 7ff1258..d2d76a3 100644
--- a/silx/gui/dialog/DataFileDialog.py
+++ b/silx/gui/dialog/DataFileDialog.py
@@ -1,7 +1,7 @@
# coding: utf-8
# /*##########################################################################
#
-# Copyright (c) 2016 European Synchrotron Radiation Facility
+# Copyright (c) 2016-2018 European Synchrotron Radiation Facility
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
@@ -30,16 +30,14 @@ __authors__ = ["V. Valls"]
__license__ = "MIT"
__date__ = "14/02/2018"
+import enum
import logging
from silx.gui import qt
from silx.gui.hdf5.Hdf5Formatter import Hdf5Formatter
import silx.io
from .AbstractDataFileDialog import AbstractDataFileDialog
-from silx.third_party import enum
-try:
- import fabio
-except ImportError:
- fabio = None
+
+import fabio
_logger = logging.getLogger(__name__)
diff --git a/silx/gui/dialog/FileTypeComboBox.py b/silx/gui/dialog/FileTypeComboBox.py
index 07b11cf..92529bc 100644
--- a/silx/gui/dialog/FileTypeComboBox.py
+++ b/silx/gui/dialog/FileTypeComboBox.py
@@ -28,12 +28,9 @@ This module contains utilitaries used by other dialog modules.
__authors__ = ["V. Valls"]
__license__ = "MIT"
-__date__ = "06/02/2018"
+__date__ = "17/01/2019"
-try:
- import fabio
-except ImportError:
- fabio = None
+import fabio
import silx.io
from silx.gui import qt
@@ -82,7 +79,7 @@ class FileTypeComboBox(qt.QComboBox):
def __initItems(self):
self.clear()
- if fabio is not None and self.__fabioUrlSupported:
+ if self.__fabioUrlSupported:
self.__insertFabioFormats()
self.__insertSilxFormats()
self.__insertAllSupported()
@@ -138,21 +135,36 @@ class FileTypeComboBox(qt.QComboBox):
def __insertFabioFormats(self):
formats = fabio.fabioformats.get_classes(reader=True)
+ from fabio import fabioutils
+ if hasattr(fabioutils, "COMPRESSED_EXTENSIONS"):
+ compressedExtensions = fabioutils.COMPRESSED_EXTENSIONS
+ else:
+ # Support for fabio < 0.9
+ compressedExtensions = set(["gz", "bz2"])
+
extensions = []
allExtensions = set([])
+ def extensionsIterator(reader):
+ for extension in reader.DEFAULT_EXTENSIONS:
+ yield "*.%s" % extension
+ for compressedExtension in compressedExtensions:
+ for extension in reader.DEFAULT_EXTENSIONS:
+ yield "*.%s.%s" % (extension, compressedExtension)
+
for reader in formats:
if not hasattr(reader, "DESCRIPTION"):
continue
if not hasattr(reader, "DEFAULT_EXTENSIONS"):
continue
- ext = reader.DEFAULT_EXTENSIONS
- ext = ["*.%s" % e for e in ext]
+ displayext = reader.DEFAULT_EXTENSIONS
+ displayext = ["*.%s" % e for e in displayext]
+ ext = list(extensionsIterator(reader))
allExtensions.update(ext)
if ext == []:
ext = ["*"]
- extensions.append((reader.DESCRIPTION, ext, reader.codec_name()))
+ extensions.append((reader.DESCRIPTION, displayext, ext, reader.codec_name()))
extensions = list(sorted(extensions))
allExtensions = list(sorted(list(allExtensions)))
@@ -162,13 +174,14 @@ class FileTypeComboBox(qt.QComboBox):
self.setItemData(index, Codec(any_fabio=True), role=self.CODEC_ROLE)
for e in extensions:
+ description, displayExt, allExt, _codecName = e
index = self.count()
if len(e[1]) < 10:
- self.addItem("%s%s (%s)" % (self.INDENTATION, e[0], " ".join(e[1])))
+ self.addItem("%s%s (%s)" % (self.INDENTATION, description, " ".join(displayExt)))
else:
- self.addItem(e[0])
- codec = Codec(fabio_codec=e[2])
- self.setItemData(index, e[1], role=self.EXTENSIONS_ROLE)
+ self.addItem("%s%s" % (self.INDENTATION, description))
+ codec = Codec(fabio_codec=_codecName)
+ self.setItemData(index, allExt, role=self.EXTENSIONS_ROLE)
self.setItemData(index, codec, role=self.CODEC_ROLE)
def itemExtensions(self, index):
diff --git a/silx/gui/dialog/ImageFileDialog.py b/silx/gui/dialog/ImageFileDialog.py
index c324071..ef6b472 100644
--- a/silx/gui/dialog/ImageFileDialog.py
+++ b/silx/gui/dialog/ImageFileDialog.py
@@ -36,10 +36,7 @@ from silx.gui import qt
from silx.gui.plot.PlotWidget import PlotWidget
from .AbstractDataFileDialog import AbstractDataFileDialog
import silx.io
-try:
- import fabio
-except ImportError:
- fabio = None
+import fabio
_logger = logging.getLogger(__name__)
diff --git a/silx/gui/dialog/SafeFileSystemModel.py b/silx/gui/dialog/SafeFileSystemModel.py
index 198e089..26954e3 100644
--- a/silx/gui/dialog/SafeFileSystemModel.py
+++ b/silx/gui/dialog/SafeFileSystemModel.py
@@ -1,7 +1,7 @@
# coding: utf-8
# /*##########################################################################
#
-# Copyright (c) 2016 European Synchrotron Radiation Facility
+# Copyright (c) 2016-2018 European Synchrotron Radiation Facility
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
@@ -34,8 +34,10 @@ import sys
import os.path
import logging
import weakref
+
+import six
+
from silx.gui import qt
-from silx.third_party import six
from .SafeFileIconProvider import SafeFileIconProvider
_logger = logging.getLogger(__name__)
diff --git a/silx/gui/dialog/test/test_colormapdialog.py b/silx/gui/dialog/test/test_colormapdialog.py
index 6e50193..cbc9de1 100644
--- a/silx/gui/dialog/test/test_colormapdialog.py
+++ b/silx/gui/dialog/test/test_colormapdialog.py
@@ -26,13 +26,11 @@
__authors__ = ["T. Vincent"]
__license__ = "MIT"
-__date__ = "23/05/2018"
+__date__ = "09/11/2018"
-import doctest
import unittest
-from silx.gui.utils.testutils import qWaitForWindowExposedAndActivate
from silx.gui import qt
from silx.gui.dialog import ColormapDialog
from silx.gui.utils.testutils import TestCaseQt
@@ -47,23 +45,6 @@ import numpy.random
_qapp = qt.QApplication.instance() or qt.QApplication([])
-def _tearDownQt(docTest):
- """Tear down to use for test from docstring.
-
- Checks that dialog widget is displayed
- """
- dialogWidget = docTest.globs['dialog']
- qWaitForWindowExposedAndActivate(dialogWidget)
- dialogWidget.setAttribute(qt.Qt.WA_DeleteOnClose)
- dialogWidget.close()
- del dialogWidget
- _qapp.processEvents()
-
-
-cmapDocTestSuite = doctest.DocTestSuite(ColormapDialog, tearDown=_tearDownQt)
-"""Test suite of tests from the module's docstrings."""
-
-
class TestColormapDialog(TestCaseQt, ParametricTestCase):
"""Test the ColormapDialog."""
def setUp(self):
@@ -86,10 +67,12 @@ class TestColormapDialog(TestCaseQt, ParametricTestCase):
editing the same colormap"""
colormapDiag2 = ColormapDialog.ColormapDialog()
colormapDiag2.setColormap(self.colormap)
+ colormapDiag2.show()
self.colormapDiag.setColormap(self.colormap)
+ self.colormapDiag.show()
- self.colormapDiag._comboBoxColormap.setCurrentName('red')
- self.colormapDiag._normButtonLog.setChecked(True)
+ self.colormapDiag._comboBoxColormap._setCurrentName('red')
+ self.colormapDiag._normButtonLog.click()
self.assertTrue(self.colormap.getName() == 'red')
self.assertTrue(self.colormapDiag.getColormap().getName() == 'red')
self.assertTrue(self.colormap.getNormalization() == 'log')
@@ -178,6 +161,7 @@ class TestColormapDialog(TestCaseQt, ParametricTestCase):
def testSetColormapIsCorrect(self):
"""Make sure the interface fir the colormap when set a new colormap"""
self.colormap.setName('red')
+ self.colormapDiag.show()
for norm in (Colormap.NORMALIZATIONS):
for autoscale in (True, False):
if autoscale is True:
@@ -211,7 +195,7 @@ class TestColormapDialog(TestCaseQt, ParametricTestCase):
self.colormapDiag.show()
del self.colormap
self.assertTrue(self.colormapDiag.getColormap() is None)
- self.colormapDiag._comboBoxColormap.setCurrentName('blue')
+ self.colormapDiag._comboBoxColormap._setCurrentName('blue')
def testColormapEditedOutside(self):
"""Make sure the GUI is still up to date if the colormap is modified
@@ -274,7 +258,7 @@ class TestColormapDialog(TestCaseQt, ParametricTestCase):
cb = self.colormapDiag._comboBoxColormap
self.assertTrue(cb.getCurrentName() == colormapName)
cb.setCurrentIndex(0)
- index = cb.findColormap(colormapName)
+ index = cb.findLutName(colormapName)
assert index is not 0 # if 0 then the rest of the test has no sense
cb.setCurrentIndex(index)
self.assertTrue(cb.getCurrentName() == colormapName)
@@ -283,6 +267,7 @@ class TestColormapDialog(TestCaseQt, ParametricTestCase):
"""Test that the colormapDialog is correctly updated when changing the
colormap editable status"""
colormap = Colormap(normalization='linear', vmin=1.0, vmax=10.0)
+ self.colormapDiag.show()
self.colormapDiag.setColormap(colormap)
for editable in (True, False):
with self.subTest(editable=editable):
@@ -302,7 +287,7 @@ class TestColormapDialog(TestCaseQt, ParametricTestCase):
# False
self.colormapDiag.setModal(False)
colormap.setEditable(True)
- self.colormapDiag._normButtonLog.setChecked(True)
+ self.colormapDiag._normButtonLog.click()
resetButton = self.colormapDiag._buttonsNonModal.button(qt.QDialogButtonBox.Reset)
self.assertTrue(resetButton.isEnabled())
colormap.setEditable(False)
@@ -387,7 +372,6 @@ class TestColormapAction(TestCaseQt):
def suite():
test_suite = unittest.TestSuite()
- test_suite.addTest(cmapDocTestSuite)
for testClass in (TestColormapDialog, TestColormapAction):
test_suite.addTest(unittest.defaultTestLoader.loadTestsFromTestCase(
testClass))
diff --git a/silx/gui/dialog/test/test_datafiledialog.py b/silx/gui/dialog/test/test_datafiledialog.py
index aff6bc4..06f8961 100644
--- a/silx/gui/dialog/test/test_datafiledialog.py
+++ b/silx/gui/dialog/test/test_datafiledialog.py
@@ -36,16 +36,8 @@ import shutil
import os
import io
import weakref
-
-try:
- import fabio
-except ImportError:
- fabio = None
-try:
- import h5py
-except ImportError:
- h5py = None
-
+import fabio
+import h5py
import silx.io.url
from silx.gui import qt
from silx.gui.utils import testutils
@@ -62,36 +54,33 @@ def setUpModule():
data = numpy.arange(100 * 100)
data.shape = 100, 100
- if fabio is not None:
- filename = _tmpDirectory + "/singleimage.edf"
- image = fabio.edfimage.EdfImage(data=data)
- image.write(filename)
-
- if h5py is not None:
- filename = _tmpDirectory + "/data.h5"
- f = h5py.File(filename, "w")
- f["scalar"] = 10
- f["image"] = data
- f["cube"] = [data, data + 1, data + 2]
- f["complex_image"] = data * 1j
- f["group/image"] = data
- f["nxdata/foo"] = 10
- f["nxdata"].attrs["NX_class"] = u"NXdata"
- f.close()
-
- if h5py is not None:
- directory = os.path.join(_tmpDirectory, "data")
- os.mkdir(directory)
- filename = os.path.join(directory, "data.h5")
- f = h5py.File(filename, "w")
- f["scalar"] = 10
- f["image"] = data
- f["cube"] = [data, data + 1, data + 2]
- f["complex_image"] = data * 1j
- f["group/image"] = data
- f["nxdata/foo"] = 10
- f["nxdata"].attrs["NX_class"] = u"NXdata"
- f.close()
+ filename = _tmpDirectory + "/singleimage.edf"
+ image = fabio.edfimage.EdfImage(data=data)
+ image.write(filename)
+
+ filename = _tmpDirectory + "/data.h5"
+ f = h5py.File(filename, "w")
+ f["scalar"] = 10
+ f["image"] = data
+ f["cube"] = [data, data + 1, data + 2]
+ f["complex_image"] = data * 1j
+ f["group/image"] = data
+ f["nxdata/foo"] = 10
+ f["nxdata"].attrs["NX_class"] = u"NXdata"
+ f.close()
+
+ directory = os.path.join(_tmpDirectory, "data")
+ os.mkdir(directory)
+ filename = os.path.join(directory, "data.h5")
+ f = h5py.File(filename, "w")
+ f["scalar"] = 10
+ f["image"] = data
+ f["cube"] = [data, data + 1, data + 2]
+ f["complex_image"] = data * 1j
+ f["group/image"] = data
+ f["nxdata/foo"] = 10
+ f["nxdata"].attrs["NX_class"] = u"NXdata"
+ f.close()
filename = _tmpDirectory + "/badformat.h5"
with io.open(filename, "wb") as f:
@@ -185,8 +174,6 @@ class TestDataFileDialogInteraction(testutils.TestCaseQt, _UtilsMixin):
self.assertEqual(dialog.result(), qt.QDialog.Rejected)
def testSelectRoot_Activate(self):
- if fabio is None:
- self.skipTest("fabio is missing")
dialog = self.createDialog()
browser = testutils.findChildren(dialog, qt.QWidget, name="browser")[0]
dialog.show()
@@ -211,8 +198,6 @@ class TestDataFileDialogInteraction(testutils.TestCaseQt, _UtilsMixin):
self.assertEqual(dialog.result(), qt.QDialog.Accepted)
def testSelectGroup_Activate(self):
- if fabio is None:
- self.skipTest("fabio is missing")
dialog = self.createDialog()
browser = testutils.findChildren(dialog, qt.QWidget, name="browser")[0]
dialog.show()
@@ -243,8 +228,6 @@ class TestDataFileDialogInteraction(testutils.TestCaseQt, _UtilsMixin):
self.assertEqual(dialog.result(), qt.QDialog.Accepted)
def testSelectDataset_Activate(self):
- if fabio is None:
- self.skipTest("fabio is missing")
dialog = self.createDialog()
browser = testutils.findChildren(dialog, qt.QWidget, name="browser")[0]
dialog.show()
@@ -275,8 +258,6 @@ class TestDataFileDialogInteraction(testutils.TestCaseQt, _UtilsMixin):
self.assertEqual(dialog.result(), qt.QDialog.Accepted)
def testClickOnBackToParentTool(self):
- if h5py is None:
- self.skipTest("h5py is missing")
dialog = self.createDialog()
dialog.show()
self.qWaitForWindowExposed(dialog)
@@ -307,8 +288,6 @@ class TestDataFileDialogInteraction(testutils.TestCaseQt, _UtilsMixin):
self.assertSamePath(url.text(), _tmpDirectory)
def testClickOnBackToRootTool(self):
- if h5py is None:
- self.skipTest("h5py is missing")
dialog = self.createDialog()
dialog.show()
self.qWaitForWindowExposed(dialog)
@@ -332,8 +311,6 @@ class TestDataFileDialogInteraction(testutils.TestCaseQt, _UtilsMixin):
# self.assertFalse(button.isEnabled())
def testClickOnBackToDirectoryTool(self):
- if h5py is None:
- self.skipTest("h5py is missing")
dialog = self.createDialog()
dialog.show()
self.qWaitForWindowExposed(dialog)
@@ -361,8 +338,6 @@ class TestDataFileDialogInteraction(testutils.TestCaseQt, _UtilsMixin):
self.allowedLeakingWidgets = 1
def testClickOnHistoryTools(self):
- if h5py is None:
- self.skipTest("h5py is missing")
dialog = self.createDialog()
dialog.show()
self.qWaitForWindowExposed(dialog)
@@ -402,8 +377,6 @@ class TestDataFileDialogInteraction(testutils.TestCaseQt, _UtilsMixin):
self.assertSamePath(url.text(), path3)
def testSelectImageFromEdf(self):
- if fabio is None:
- self.skipTest("fabio is missing")
dialog = self.createDialog()
dialog.show()
self.qWaitForWindowExposed(dialog)
@@ -417,8 +390,6 @@ class TestDataFileDialogInteraction(testutils.TestCaseQt, _UtilsMixin):
self.assertSamePath(dialog.selectedUrl(), url.path())
def testSelectImage(self):
- if h5py is None:
- self.skipTest("h5py is missing")
dialog = self.createDialog()
dialog.show()
self.qWaitForWindowExposed(dialog)
@@ -433,8 +404,6 @@ class TestDataFileDialogInteraction(testutils.TestCaseQt, _UtilsMixin):
self.assertSamePath(dialog.selectedUrl(), path)
def testSelectScalar(self):
- if h5py is None:
- self.skipTest("h5py is missing")
dialog = self.createDialog()
dialog.show()
self.qWaitForWindowExposed(dialog)
@@ -449,8 +418,6 @@ class TestDataFileDialogInteraction(testutils.TestCaseQt, _UtilsMixin):
self.assertSamePath(dialog.selectedUrl(), path)
def testSelectGroup(self):
- if h5py is None:
- self.skipTest("h5py is missing")
dialog = self.createDialog()
dialog.show()
self.qWaitForWindowExposed(dialog)
@@ -467,8 +434,6 @@ class TestDataFileDialogInteraction(testutils.TestCaseQt, _UtilsMixin):
self.assertSamePath(uri.data_path(), "/group")
def testSelectRoot(self):
- if h5py is None:
- self.skipTest("h5py is missing")
dialog = self.createDialog()
dialog.show()
self.qWaitForWindowExposed(dialog)
@@ -485,8 +450,6 @@ class TestDataFileDialogInteraction(testutils.TestCaseQt, _UtilsMixin):
self.assertSamePath(uri.data_path(), "/")
def testSelectH5_Activate(self):
- if h5py is None:
- self.skipTest("h5py is missing")
dialog = self.createDialog()
dialog.show()
self.qWaitForWindowExposed(dialog)
@@ -533,10 +496,6 @@ class TestDataFileDialogInteraction(testutils.TestCaseQt, _UtilsMixin):
return selectable
def testFilterExtensions(self):
- if h5py is None:
- self.skipTest("h5py is missing")
- if fabio is None:
- self.skipTest("fabio is missing")
dialog = self.createDialog()
browser = testutils.findChildren(dialog, qt.QWidget, name="browser")[0]
dialog.show()
@@ -558,8 +517,6 @@ class TestDataFileDialog_FilterDataset(testutils.TestCaseQt, _UtilsMixin):
return dialog
def testSelectGroup_Activate(self):
- if fabio is None:
- self.skipTest("fabio is missing")
dialog = self.createDialog()
browser = testutils.findChildren(dialog, qt.QWidget, name="browser")[0]
dialog.show()
@@ -585,8 +542,6 @@ class TestDataFileDialog_FilterDataset(testutils.TestCaseQt, _UtilsMixin):
self.assertFalse(button.isEnabled())
def testSelectDataset_Activate(self):
- if fabio is None:
- self.skipTest("fabio is missing")
dialog = self.createDialog()
browser = testutils.findChildren(dialog, qt.QWidget, name="browser")[0]
dialog.show()
@@ -632,8 +587,6 @@ class TestDataFileDialog_FilterGroup(testutils.TestCaseQt, _UtilsMixin):
return dialog
def testSelectGroup_Activate(self):
- if fabio is None:
- self.skipTest("fabio is missing")
dialog = self.createDialog()
browser = testutils.findChildren(dialog, qt.QWidget, name="browser")[0]
dialog.show()
@@ -666,8 +619,6 @@ class TestDataFileDialog_FilterGroup(testutils.TestCaseQt, _UtilsMixin):
self.assertRaises(Exception, dialog.selectedData)
def testSelectDataset_Activate(self):
- if fabio is None:
- self.skipTest("fabio is missing")
dialog = self.createDialog()
browser = testutils.findChildren(dialog, qt.QWidget, name="browser")[0]
dialog.show()
@@ -711,8 +662,6 @@ class TestDataFileDialog_FilterNXdata(testutils.TestCaseQt, _UtilsMixin):
return dialog
def testSelectGroupRefused_Activate(self):
- if fabio is None:
- self.skipTest("fabio is missing")
dialog = self.createDialog()
browser = testutils.findChildren(dialog, qt.QWidget, name="browser")[0]
dialog.show()
@@ -740,8 +689,6 @@ class TestDataFileDialog_FilterNXdata(testutils.TestCaseQt, _UtilsMixin):
self.assertRaises(Exception, dialog.selectedData)
def testSelectNXdataAccepted_Activate(self):
- if fabio is None:
- self.skipTest("fabio is missing")
dialog = self.createDialog()
browser = testutils.findChildren(dialog, qt.QWidget, name="browser")[0]
dialog.show()
@@ -944,8 +891,6 @@ class TestDataFileDialogApi(testutils.TestCaseQt, _UtilsMixin):
self.assertIsNone(dialog._selectedData())
def testBadSubpath(self):
- if h5py is None:
- self.skipTest("h5py is missing")
dialog = self.createDialog()
self.qWaitForPendingActions(dialog)
@@ -965,8 +910,6 @@ class TestDataFileDialogApi(testutils.TestCaseQt, _UtilsMixin):
self.assertEqual(url.data_path(), "/group")
def testUnsupportedSlicingPath(self):
- if h5py is None:
- self.skipTest("h5py is missing")
dialog = self.createDialog()
self.qWaitForPendingActions(dialog)
dialog.selectUrl(_tmpDirectory + "/data.h5?path=/cube&slice=0")
diff --git a/silx/gui/dialog/test/test_imagefiledialog.py b/silx/gui/dialog/test/test_imagefiledialog.py
index 66469f3..068dcb9 100644
--- a/silx/gui/dialog/test/test_imagefiledialog.py
+++ b/silx/gui/dialog/test/test_imagefiledialog.py
@@ -36,16 +36,8 @@ import shutil
import os
import io
import weakref
-
-try:
- import fabio
-except ImportError:
- fabio = None
-try:
- import h5py
-except ImportError:
- h5py = None
-
+import fabio
+import h5py
import silx.io.url
from silx.gui import qt
from silx.gui.utils import testutils
@@ -63,42 +55,39 @@ def setUpModule():
data = numpy.arange(100 * 100)
data.shape = 100, 100
- if fabio is not None:
- filename = _tmpDirectory + "/singleimage.edf"
- image = fabio.edfimage.EdfImage(data=data)
- image.write(filename)
-
- filename = _tmpDirectory + "/multiframe.edf"
- image = fabio.edfimage.EdfImage(data=data)
- image.appendFrame(data=data + 1)
- image.appendFrame(data=data + 2)
- image.write(filename)
-
- filename = _tmpDirectory + "/singleimage.msk"
- image = fabio.fit2dmaskimage.Fit2dMaskImage(data=data % 2 == 1)
- image.write(filename)
-
- if h5py is not None:
- filename = _tmpDirectory + "/data.h5"
- f = h5py.File(filename, "w")
- f["scalar"] = 10
- f["image"] = data
- f["cube"] = [data, data + 1, data + 2]
- f["complex_image"] = data * 1j
- f["group/image"] = data
- f.close()
-
- if h5py is not None:
- directory = os.path.join(_tmpDirectory, "data")
- os.mkdir(directory)
- filename = os.path.join(directory, "data.h5")
- f = h5py.File(filename, "w")
- f["scalar"] = 10
- f["image"] = data
- f["cube"] = [data, data + 1, data + 2]
- f["complex_image"] = data * 1j
- f["group/image"] = data
- f.close()
+ filename = _tmpDirectory + "/singleimage.edf"
+ image = fabio.edfimage.EdfImage(data=data)
+ image.write(filename)
+
+ filename = _tmpDirectory + "/multiframe.edf"
+ image = fabio.edfimage.EdfImage(data=data)
+ image.appendFrame(data=data + 1)
+ image.appendFrame(data=data + 2)
+ image.write(filename)
+
+ filename = _tmpDirectory + "/singleimage.msk"
+ image = fabio.fit2dmaskimage.Fit2dMaskImage(data=data % 2 == 1)
+ image.write(filename)
+
+ filename = _tmpDirectory + "/data.h5"
+ f = h5py.File(filename, "w")
+ f["scalar"] = 10
+ f["image"] = data
+ f["cube"] = [data, data + 1, data + 2]
+ f["complex_image"] = data * 1j
+ f["group/image"] = data
+ f.close()
+
+ directory = os.path.join(_tmpDirectory, "data")
+ os.mkdir(directory)
+ filename = os.path.join(directory, "data.h5")
+ f = h5py.File(filename, "w")
+ f["scalar"] = 10
+ f["image"] = data
+ f["cube"] = [data, data + 1, data + 2]
+ f["complex_image"] = data * 1j
+ f["group/image"] = data
+ f.close()
filename = _tmpDirectory + "/badformat.edf"
with io.open(filename, "wb") as f:
@@ -192,8 +181,6 @@ class TestImageFileDialogInteraction(testutils.TestCaseQt, _UtilsMixin):
self.assertEqual(dialog.result(), qt.QDialog.Rejected)
def testDisplayAndClickOpen(self):
- if fabio is None:
- self.skipTest("fabio is missing")
dialog = self.createDialog()
dialog.show()
self.qWaitForWindowExposed(dialog)
@@ -259,8 +246,6 @@ class TestImageFileDialogInteraction(testutils.TestCaseQt, _UtilsMixin):
self.assertEqual(dialog.viewMode(), qt.QFileDialog.List)
def testClickOnBackToParentTool(self):
- if h5py is None:
- self.skipTest("h5py is missing")
dialog = self.createDialog()
dialog.show()
self.qWaitForWindowExposed(dialog)
@@ -291,8 +276,6 @@ class TestImageFileDialogInteraction(testutils.TestCaseQt, _UtilsMixin):
self.assertSamePath(url.text(), _tmpDirectory)
def testClickOnBackToRootTool(self):
- if h5py is None:
- self.skipTest("h5py is missing")
dialog = self.createDialog()
dialog.show()
self.qWaitForWindowExposed(dialog)
@@ -316,8 +299,6 @@ class TestImageFileDialogInteraction(testutils.TestCaseQt, _UtilsMixin):
# self.assertFalse(button.isEnabled())
def testClickOnBackToDirectoryTool(self):
- if h5py is None:
- self.skipTest("h5py is missing")
dialog = self.createDialog()
dialog.show()
self.qWaitForWindowExposed(dialog)
@@ -345,8 +326,6 @@ class TestImageFileDialogInteraction(testutils.TestCaseQt, _UtilsMixin):
self.allowedLeakingWidgets = 1
def testClickOnHistoryTools(self):
- if h5py is None:
- self.skipTest("h5py is missing")
dialog = self.createDialog()
dialog.show()
self.qWaitForWindowExposed(dialog)
@@ -386,8 +365,6 @@ class TestImageFileDialogInteraction(testutils.TestCaseQt, _UtilsMixin):
self.assertSamePath(url.text(), path3)
def testSelectImageFromEdf(self):
- if fabio is None:
- self.skipTest("fabio is missing")
dialog = self.createDialog()
dialog.show()
self.qWaitForWindowExposed(dialog)
@@ -402,8 +379,6 @@ class TestImageFileDialogInteraction(testutils.TestCaseQt, _UtilsMixin):
self.assertSamePath(dialog.selectedUrl(), path)
def testSelectImageFromEdf_Activate(self):
- if fabio is None:
- self.skipTest("fabio is missing")
dialog = self.createDialog()
dialog.show()
self.qWaitForWindowExposed(dialog)
@@ -426,8 +401,6 @@ class TestImageFileDialogInteraction(testutils.TestCaseQt, _UtilsMixin):
self.assertSamePath(dialog.selectedUrl(), path)
def testSelectFrameFromEdf(self):
- if fabio is None:
- self.skipTest("fabio is missing")
dialog = self.createDialog()
dialog.show()
self.qWaitForWindowExposed(dialog)
@@ -444,8 +417,6 @@ class TestImageFileDialogInteraction(testutils.TestCaseQt, _UtilsMixin):
self.assertSamePath(dialog.selectedUrl(), path)
def testSelectImageFromMsk(self):
- if fabio is None:
- self.skipTest("fabio is missing")
dialog = self.createDialog()
dialog.show()
self.qWaitForWindowExposed(dialog)
@@ -460,8 +431,6 @@ class TestImageFileDialogInteraction(testutils.TestCaseQt, _UtilsMixin):
self.assertSamePath(dialog.selectedUrl(), path)
def testSelectImageFromH5(self):
- if h5py is None:
- self.skipTest("h5py is missing")
dialog = self.createDialog()
dialog.show()
self.qWaitForWindowExposed(dialog)
@@ -476,8 +445,6 @@ class TestImageFileDialogInteraction(testutils.TestCaseQt, _UtilsMixin):
self.assertSamePath(dialog.selectedUrl(), path)
def testSelectH5_Activate(self):
- if h5py is None:
- self.skipTest("h5py is missing")
dialog = self.createDialog()
dialog.show()
self.qWaitForWindowExposed(dialog)
@@ -498,8 +465,6 @@ class TestImageFileDialogInteraction(testutils.TestCaseQt, _UtilsMixin):
self.assertSamePath(dialog.selectedUrl(), path)
def testSelectFrameFromH5(self):
- if h5py is None:
- self.skipTest("h5py is missing")
dialog = self.createDialog()
dialog.show()
self.qWaitForWindowExposed(dialog)
@@ -541,10 +506,6 @@ class TestImageFileDialogInteraction(testutils.TestCaseQt, _UtilsMixin):
return selectable
def testFilterExtensions(self):
- if h5py is None:
- self.skipTest("h5py is missing")
- if fabio is None:
- self.skipTest("fabio is missing")
dialog = self.createDialog()
browser = testutils.findChildren(dialog, qt.QWidget, name="browser")[0]
filters = testutils.findChildren(dialog, qt.QWidget, name="fileTypeCombo")[0]
@@ -745,16 +706,12 @@ class TestImageFileDialogApi(testutils.TestCaseQt, _UtilsMixin):
self.assertSamePath(dialog.directory(), _tmpDirectory)
def testBadDataType(self):
- if h5py is None:
- self.skipTest("h5py is missing")
dialog = self.createDialog()
dialog.selectUrl(_tmpDirectory + "/data.h5::/complex_image")
self.qWaitForPendingActions(dialog)
self.assertIsNone(dialog._selectedData())
def testBadDataShape(self):
- if h5py is None:
- self.skipTest("h5py is missing")
dialog = self.createDialog()
dialog.selectUrl(_tmpDirectory + "/data.h5::/unknown")
self.qWaitForPendingActions(dialog)
@@ -773,8 +730,6 @@ class TestImageFileDialogApi(testutils.TestCaseQt, _UtilsMixin):
self.assertIsNone(dialog._selectedData())
def testBadSubpath(self):
- if h5py is None:
- self.skipTest("h5py is missing")
dialog = self.createDialog()
self.qWaitForPendingActions(dialog)
@@ -794,8 +749,6 @@ class TestImageFileDialogApi(testutils.TestCaseQt, _UtilsMixin):
self.assertEqual(url.data_path(), "/group")
def testBadSlicingPath(self):
- if h5py is None:
- self.skipTest("h5py is missing")
dialog = self.createDialog()
self.qWaitForPendingActions(dialog)
dialog.selectUrl(_tmpDirectory + "/data.h5::/cube[a;45,-90]")
diff --git a/silx/gui/dialog/utils.py b/silx/gui/dialog/utils.py
index 1c16b44..e2334f9 100644
--- a/silx/gui/dialog/utils.py
+++ b/silx/gui/dialog/utils.py
@@ -1,7 +1,7 @@
# coding: utf-8
# /*##########################################################################
#
-# Copyright (c) 2016 European Synchrotron Radiation Facility
+# Copyright (c) 2016-2018 European Synchrotron Radiation Facility
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
@@ -33,8 +33,10 @@ __date__ = "25/10/2017"
import os
import sys
import types
+
+import six
+
from silx.gui import qt
-from silx.third_party import six
def samefile(path1, path2):
diff --git a/silx/gui/hdf5/Hdf5Formatter.py b/silx/gui/hdf5/Hdf5Formatter.py
index 6802142..5754fe8 100644
--- a/silx/gui/hdf5/Hdf5Formatter.py
+++ b/silx/gui/hdf5/Hdf5Formatter.py
@@ -1,7 +1,7 @@
# coding: utf-8
# /*##########################################################################
#
-# Copyright (c) 2017 European Synchrotron Radiation Facility
+# Copyright (c) 2017-2018 European Synchrotron Radiation Facility
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
@@ -30,14 +30,12 @@ __license__ = "MIT"
__date__ = "06/06/2018"
import numpy
-from silx.third_party import six
+import six
+
from silx.gui import qt
from silx.gui.data.TextFormatter import TextFormatter
-try:
- import h5py
-except ImportError:
- h5py = None
+import h5py
class Hdf5Formatter(qt.QObject):
@@ -162,10 +160,9 @@ class Hdf5Formatter(qt.QObject):
compound = [self.humanReadableDType(d) for d in compound]
return "compound(%s)" % ", ".join(compound)
elif numpy.issubdtype(dtype, numpy.integer):
- if h5py is not None:
- enumType = h5py.check_dtype(enum=dtype)
- if enumType is not None:
- return "enum"
+ enumType = h5py.check_dtype(enum=dtype)
+ if enumType is not None:
+ return "enum"
text = str(dtype.newbyteorder('N'))
if numpy.issubdtype(dtype, numpy.floating):
diff --git a/silx/gui/hdf5/Hdf5Item.py b/silx/gui/hdf5/Hdf5Item.py
index b3c313e..6ea870f 100644
--- a/silx/gui/hdf5/Hdf5Item.py
+++ b/silx/gui/hdf5/Hdf5Item.py
@@ -25,11 +25,12 @@
__authors__ = ["V. Valls"]
__license__ = "MIT"
-__date__ = "03/09/2018"
+__date__ = "17/01/2019"
import logging
import collections
+
from .. import qt
from .. import icons
from . import _utils
@@ -37,7 +38,6 @@ from .Hdf5Node import Hdf5Node
import silx.io.utils
from silx.gui.data.TextFormatter import TextFormatter
from ..hdf5.Hdf5Formatter import Hdf5Formatter
-from ...third_party import six
_logger = logging.getLogger(__name__)
_formatter = TextFormatter()
_hdf5Formatter = Hdf5Formatter(textFormatter=_formatter)
@@ -217,14 +217,32 @@ class Hdf5Item(Hdf5Node):
def _populateChild(self, populateAll=False):
if self.isGroupObj():
- for name in self.obj:
+ keys = []
+ try:
+ for name in self.obj:
+ keys.append(name)
+ except Exception:
+ lib_name = self.obj.__class__.__module__.split(".")[0]
+ _logger.error("Internal %s error. The file is corrupted.", lib_name)
+ _logger.debug("Backtrace", exc_info=True)
+ if keys == []:
+ # If the file was open in READ_ONLY we still can reach something
+ # https://github.com/silx-kit/silx/issues/2262
+ try:
+ for name in self.obj:
+ keys.append(name)
+ except Exception:
+ lib_name = self.obj.__class__.__module__.split(".")[0]
+ _logger.error("Internal %s error (second time). The file is corrupted.", lib_name)
+ _logger.debug("Backtrace", exc_info=True)
+ for name in keys:
try:
class_ = self.obj.get(name, getclass=True)
link = self.obj.get(name, getclass=True, getlink=True)
link = silx.io.utils.get_h5_class(class_=link)
except Exception:
lib_name = self.obj.__class__.__module__.split(".")[0]
- _logger.warning("Internal %s error", lib_name, exc_info=True)
+ _logger.error("Internal %s error", lib_name)
_logger.debug("Backtrace", exc_info=True)
class_ = None
try:
@@ -344,14 +362,12 @@ class Hdf5Item(Hdf5Node):
def nexusClassName(self):
"""Returns the Nexus class name"""
if self.__nx_class is None:
- self.__nx_class = self.obj.attrs.get("NX_class", None)
- if self.__nx_class is None:
- self.__nx_class = ""
+ obj = self.obj.attrs.get("NX_class", None)
+ if obj is None:
+ text = ""
else:
- if six.PY2:
- self.__nx_class = self.__nx_class.decode()
- elif not isinstance(self.__nx_class, str):
- self.__nx_class = str(self.__nx_class, "UTF-8")
+ text = self._getFormatter().textFormatter().toString(obj)
+ self.__nx_class = text.strip('"')
return self.__nx_class
def dataName(self, role):
diff --git a/silx/gui/hdf5/Hdf5TreeModel.py b/silx/gui/hdf5/Hdf5TreeModel.py
index 438200b..152f3e5 100644
--- a/silx/gui/hdf5/Hdf5TreeModel.py
+++ b/silx/gui/hdf5/Hdf5TreeModel.py
@@ -25,7 +25,7 @@
__authors__ = ["V. Valls"]
__license__ = "MIT"
-__date__ = "08/10/2018"
+__date__ = "12/03/2019"
import os
@@ -360,9 +360,7 @@ class Hdf5TreeModel(qt.QAbstractItemModel):
def mimeTypes(self):
types = []
- if self.__fileMoveEnabled:
- types.append(_utils.Hdf5NodeMimeData.MIME_TYPE)
- if self.__datasetDragEnabled:
+ if self.__fileMoveEnabled or self.__datasetDragEnabled:
types.append(_utils.Hdf5DatasetMimeData.MIME_TYPE)
return types
@@ -386,7 +384,7 @@ class Hdf5TreeModel(qt.QAbstractItemModel):
node = self.nodeFromIndex(indexes[0])
if self.__fileMoveEnabled and node.parent is self.__root:
- mimeData = _utils.Hdf5NodeMimeData(node=node)
+ mimeData = _utils.Hdf5DatasetMimeData(node=node, isRoot=True)
elif self.__datasetDragEnabled:
mimeData = _utils.Hdf5DatasetMimeData(node=node)
else:
@@ -413,23 +411,24 @@ class Hdf5TreeModel(qt.QAbstractItemModel):
if action == qt.Qt.IgnoreAction:
return True
- if self.__fileMoveEnabled and mimedata.hasFormat(_utils.Hdf5NodeMimeData.MIME_TYPE):
- dragNode = mimedata.node()
- parentNode = self.nodeFromIndex(parentIndex)
- if parentNode is not dragNode.parent:
- return False
+ if self.__fileMoveEnabled and mimedata.hasFormat(_utils.Hdf5DatasetMimeData.MIME_TYPE):
+ if mimedata.isRoot():
+ dragNode = mimedata.node()
+ parentNode = self.nodeFromIndex(parentIndex)
+ if parentNode is not dragNode.parent:
+ return False
- if row == -1:
- # append to the parent
- row = parentNode.childCount()
- else:
- # insert at row
- pass
+ if row == -1:
+ # append to the parent
+ row = parentNode.childCount()
+ else:
+ # insert at row
+ pass
- dragNodeParent = dragNode.parent
- sourceRow = dragNodeParent.indexOfChild(dragNode)
- self.moveRow(parentIndex, sourceRow, parentIndex, row)
- return True
+ dragNodeParent = dragNode.parent
+ sourceRow = dragNodeParent.indexOfChild(dragNode)
+ self.moveRow(parentIndex, sourceRow, parentIndex, row)
+ return True
if self.__fileDropEnabled and mimedata.hasFormat("text/uri-list"):
@@ -571,7 +570,7 @@ class Hdf5TreeModel(qt.QAbstractItemModel):
drag-and-drop"""
obj = node.obj
for f in self.__openedFiles:
- if f in obj:
+ if f is obj:
_logger.debug("Close file %s", obj.filename)
obj.close()
self.__openedFiles.remove(obj)
diff --git a/silx/gui/hdf5/NexusSortFilterProxyModel.py b/silx/gui/hdf5/NexusSortFilterProxyModel.py
index 216e992..9c3533f 100644
--- a/silx/gui/hdf5/NexusSortFilterProxyModel.py
+++ b/silx/gui/hdf5/NexusSortFilterProxyModel.py
@@ -25,7 +25,7 @@
__authors__ = ["V. Valls"]
__license__ = "MIT"
-__date__ = "24/07/2018"
+__date__ = "29/11/2018"
import logging
@@ -108,6 +108,8 @@ class NexusSortFilterProxyModel(qt.QSortFilterProxyModel):
def __isNXnode(self, node):
"""Returns true if the node is an NX concept"""
+ if not hasattr(node, "h5Class"):
+ return False
class_ = node.h5Class
if class_ is None or class_ != silx.io.utils.H5Type.GROUP:
return False
diff --git a/silx/gui/hdf5/_utils.py b/silx/gui/hdf5/_utils.py
index 6a34933..aaab228 100644
--- a/silx/gui/hdf5/_utils.py
+++ b/silx/gui/hdf5/_utils.py
@@ -28,12 +28,15 @@ package `silx.gui.hdf5` package.
__authors__ = ["V. Valls"]
__license__ = "MIT"
-__date__ = "04/05/2018"
+__date__ = "17/01/2019"
import logging
-from .. import qt
+import os.path
+
import silx.io.utils
+import silx.io.url
+from .. import qt
from silx.utils.html import escape
_logger = logging.getLogger(__name__)
@@ -107,11 +110,22 @@ class Hdf5DatasetMimeData(qt.QMimeData):
MIME_TYPE = "application/x-internal-h5py-dataset"
- def __init__(self, node=None, dataset=None):
+ SILX_URI_TYPE = "application/x-silx-uri"
+
+ def __init__(self, node=None, dataset=None, isRoot=False):
qt.QMimeData.__init__(self)
self.__dataset = dataset
self.__node = node
+ self.__isRoot = isRoot
self.setData(self.MIME_TYPE, "".encode(encoding='utf-8'))
+ if node is not None:
+ h5Node = H5Node(node)
+ silxUrl = h5Node.url
+ self.setText(silxUrl)
+ self.setData(self.SILX_URI_TYPE, silxUrl.encode(encoding='utf-8'))
+
+ def isRoot(self):
+ return self.__isRoot
def node(self):
return self.__node
@@ -122,20 +136,6 @@ class Hdf5DatasetMimeData(qt.QMimeData):
return self.__dataset
-class Hdf5NodeMimeData(qt.QMimeData):
- """Mimedata class to identify an internal drag and drop of a Hdf5Node."""
-
- MIME_TYPE = "application/x-internal-h5py-node"
-
- def __init__(self, node=None):
- qt.QMimeData.__init__(self)
- self.__node = node
- self.setData(self.MIME_TYPE, "".encode(encoding='utf-8'))
-
- def node(self):
- return self.__node
-
-
class H5Node(object):
"""Adapter over an h5py object to provide missing informations from h5py
nodes, like internal node path and filename (which are not provided by
@@ -419,3 +419,43 @@ class H5Node(object):
:rtype: str
"""
return self.physical_name.split("/")[-1]
+
+ @property
+ def data_url(self):
+ """Returns a :class:`silx.io.url.DataUrl` object identify this node in the file
+ system.
+
+ :rtype: ~silx.io.url.DataUrl
+ """
+ absolute_filename = os.path.abspath(self.local_filename)
+ return silx.io.url.DataUrl(scheme="silx",
+ file_path=absolute_filename,
+ data_path=self.local_name)
+
+ @property
+ def url(self):
+ """Returns an URL object identifying this node in the file
+ system.
+
+ This URL can be used in different ways.
+
+ .. code-block:: python
+
+ # Parsing the URL
+ import silx.io.url
+ dataurl = silx.io.url.DataUrl(item.url)
+ # dataurl provides access to URL fields
+
+ # Open a numpy array
+ import silx.io
+ dataset = silx.io.get_data(item.url)
+
+ # Open an hdf5 object (URL targetting a file or a group)
+ import silx.io
+ with silx.io.open(item.url) as h5:
+ ...your stuff...
+
+ :rtype: str
+ """
+ data_url = self.data_url
+ return data_url.path()
diff --git a/silx/gui/hdf5/test/test_hdf5.py b/silx/gui/hdf5/test/test_hdf5.py
index 1751a21..f22d4ae 100644
--- a/silx/gui/hdf5/test/test_hdf5.py
+++ b/silx/gui/hdf5/test/test_hdf5.py
@@ -26,7 +26,7 @@
__authors__ = ["V. Valls"]
__license__ = "MIT"
-__date__ = "03/05/2018"
+__date__ = "12/03/2019"
import time
@@ -43,10 +43,7 @@ from silx.gui.utils.testutils import SignalListener
from silx.io import commonh5
import weakref
-try:
- import h5py
-except ImportError:
- h5py = None
+import h5py
_tmpDirectory = None
@@ -56,14 +53,13 @@ def setUpModule():
global _tmpDirectory
_tmpDirectory = tempfile.mkdtemp(prefix=__name__)
- if h5py is not None:
- filename = _tmpDirectory + "/data.h5"
+ filename = _tmpDirectory + "/data.h5"
- # create h5 data
- f = h5py.File(filename, "w")
- g = f.create_group("arrays")
- g.create_dataset("scalar", data=10)
- f.close()
+ # create h5 data
+ f = h5py.File(filename, "w")
+ g = f.create_group("arrays")
+ g.create_dataset("scalar", data=10)
+ f.close()
def tearDownModule():
@@ -91,8 +87,6 @@ class TestHdf5TreeModel(TestCaseQt):
def setUp(self):
super(TestHdf5TreeModel, self).setUp()
- if h5py is None:
- self.skipTest("h5py is not available")
def waitForPendingOperations(self, model):
for _ in range(10):
@@ -127,8 +121,6 @@ class TestHdf5TreeModel(TestCaseQt):
model.appendFile(filename)
self.assertEqual(model.rowCount(qt.QModelIndex()), 1)
# clean up
- index = model.index(0, 0, qt.QModelIndex())
- h5File = model.data(index, hdf5.Hdf5TreeModel.H5PY_OBJECT_ROLE)
ref = weakref.ref(model)
model = None
self.qWaitForDestroy(ref)
@@ -246,6 +238,37 @@ class TestHdf5TreeModel(TestCaseQt):
model.setFileDropEnabled(False)
self.assertNotEquals(model.supportedDropActions(), 0)
+ def testCloseFile(self):
+ """A file inserted as a filename is open and closed internally."""
+ filename = _tmpDirectory + "/data.h5"
+ model = hdf5.Hdf5TreeModel()
+ self.assertEqual(model.rowCount(qt.QModelIndex()), 0)
+ model.insertFile(filename)
+ self.assertEqual(model.rowCount(qt.QModelIndex()), 1)
+ index = model.index(0, 0)
+ h5File = model.data(index, role=hdf5.Hdf5TreeModel.H5PY_OBJECT_ROLE)
+ model.removeIndex(index)
+ self.assertEqual(model.rowCount(qt.QModelIndex()), 0)
+ self.assertFalse(bool(h5File.id.valid), "The HDF5 file was not closed")
+
+ def testNotCloseFile(self):
+ """A file inserted as an h5py object is not open (then not closed)
+ internally."""
+ filename = _tmpDirectory + "/data.h5"
+ try:
+ h5File = h5py.File(filename)
+ model = hdf5.Hdf5TreeModel()
+ self.assertEqual(model.rowCount(qt.QModelIndex()), 0)
+ model.insertH5pyObject(h5File)
+ self.assertEqual(model.rowCount(qt.QModelIndex()), 1)
+ index = model.index(0, 0)
+ h5File = model.data(index, role=hdf5.Hdf5TreeModel.H5PY_OBJECT_ROLE)
+ model.removeIndex(index)
+ self.assertEqual(model.rowCount(qt.QModelIndex()), 0)
+ self.assertTrue(bool(h5File.id.valid), "The HDF5 file was unexpetedly closed")
+ finally:
+ h5File.close()
+
def testDropExternalFile(self):
filename = _tmpDirectory + "/data.h5"
model = hdf5.Hdf5TreeModel()
@@ -571,8 +594,6 @@ class TestH5Node(TestCaseQt):
@classmethod
def setUpClass(cls):
super(TestH5Node, cls).setUpClass()
- if h5py is None:
- raise unittest.SkipTest("h5py is not available")
cls.tmpDirectory = tempfile.mkdtemp()
cls.h5Filename = cls.createResource(cls.tmpDirectory)
@@ -809,8 +830,6 @@ class TestHdf5TreeView(TestCaseQt):
def setUp(self):
super(TestHdf5TreeView, self).setUp()
- if h5py is None:
- self.skipTest("h5py is not available")
def testCreate(self):
view = hdf5.Hdf5TreeView()
diff --git a/silx/gui/icons.py b/silx/gui/icons.py
index ef99591..1493b92 100644
--- a/silx/gui/icons.py
+++ b/silx/gui/icons.py
@@ -29,7 +29,7 @@ Use :func:`getQIcon` to create Qt QIcon from the name identifying an icon.
__authors__ = ["T. Vincent"]
__license__ = "MIT"
-__date__ = "05/10/2018"
+__date__ = "07/01/2019"
import os
@@ -213,11 +213,12 @@ class MultiImageAnimatedIcon(AbstractAnimatedIcon):
self.__frames = []
for i in range(100):
try:
- filename = getQFile("%s/%02d" % (filename, i))
+ frame_filename = os.sep.join((filename, ("%02d" %i)))
+ frame_file = getQFile(frame_filename)
except ValueError:
break
try:
- icon = qt.QIcon(filename.fileName())
+ icon = qt.QIcon(frame_file.fileName())
except ValueError:
break
self.__frames.append(icon)
@@ -420,4 +421,5 @@ def getQFile(name):
qfile = qt.QFile(filename)
if qfile.exists():
return qfile
+ _logger.debug("File '%s' not found.", filename)
raise ValueError('Not an icon name: %s' % name)
diff --git a/silx/gui/plot/ColorBar.py b/silx/gui/plot/ColorBar.py
index fd4d34e..9798123 100644
--- a/silx/gui/plot/ColorBar.py
+++ b/silx/gui/plot/ColorBar.py
@@ -1,7 +1,7 @@
# coding: utf-8
# /*##########################################################################
#
-# Copyright (c) 2016-2018 European Synchrotron Radiation Facility
+# Copyright (c) 2016-2019 European Synchrotron Radiation Facility
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
@@ -251,10 +251,13 @@ class ColorBarWidget(qt.QWidget):
def _defaultColormapChanged(self, event):
"""Handle plot default colormap changed"""
- if (event['event'] == 'defaultColormapChanged' and
- self.getPlot().getActiveImage() is None):
- # No active image, take default colormap update into account
- self._syncWithDefaultColormap()
+ if event['event'] == 'defaultColormapChanged':
+ plot = self.getPlot()
+ if (plot is not None and
+ plot.getActiveImage() is None and
+ plot._getActiveItem(kind='scatter') is None):
+ # No active item, take default colormap update into account
+ self._syncWithDefaultColormap()
def _syncWithDefaultColormap(self, data=None):
"""Update colorbar according to plot default colormap"""
@@ -801,7 +804,7 @@ class _TickBar(qt.QWidget):
if self._norm == colors.Colormap.LINEAR:
return 1 - (val - self._vmin) / (self._vmax - self._vmin)
elif self._norm == colors.Colormap.LOGARITHM:
- return 1 - (numpy.log10(val) - numpy.log10(self._vmin)) / (numpy.log10(self._vmax) - numpy.log(self._vmin))
+ return 1 - (numpy.log10(val) - numpy.log10(self._vmin)) / (numpy.log10(self._vmax) - numpy.log10(self._vmin))
else:
raise ValueError('Norm is not recognized')
@@ -864,7 +867,7 @@ class _TickBar(qt.QWidget):
def _guessType(self, font):
"""Try fo find the better format to display the tick's labels
- :param QFont font: the font we want want to use durint the painting
+ :param QFont font: the font we want to use during the painting
"""
form = self._getStandardFormat()
@@ -873,7 +876,7 @@ class _TickBar(qt.QWidget):
for tick in self.ticks:
width = max(fm.width(form.format(tick)), width)
- # if the length of the string are too long we are mooving to scientific
+ # if the length of the string are too long we are moving to scientific
# display
if width > _TickBar._WIDTH_DISP_VAL - _TickBar._LINE_WIDTH:
return self._getScientificForm()
diff --git a/silx/gui/plot/CompareImages.py b/silx/gui/plot/CompareImages.py
index 88b257d..f7c4899 100644
--- a/silx/gui/plot/CompareImages.py
+++ b/silx/gui/plot/CompareImages.py
@@ -30,6 +30,7 @@ __license__ = "MIT"
__date__ = "23/07/2018"
+import enum
import logging
import numpy
import weakref
@@ -42,7 +43,6 @@ from silx.gui import plot
from silx.gui import icons
from silx.gui.colors import Colormap
from silx.gui.plot import tools
-from silx.third_party import enum
_logger = logging.getLogger(__name__)
diff --git a/silx/gui/plot/ComplexImageView.py b/silx/gui/plot/ComplexImageView.py
index bbcb0a5..2523cde 100644
--- a/silx/gui/plot/ComplexImageView.py
+++ b/silx/gui/plot/ComplexImageView.py
@@ -1,7 +1,7 @@
# coding: utf-8
# /*##########################################################################
#
-# Copyright (c) 2017-2018 European Synchrotron Radiation Facility
+# Copyright (c) 2017-2019 European Synchrotron Radiation Facility
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
@@ -365,7 +365,7 @@ class ComplexImageView(qt.QWidget):
- log10_amplitude_phase:
Color-coded phase with log10(amplitude) as alpha.
- :rtype: tuple of str
+ :rtype: List[Mode]
"""
return tuple(ImageComplexData.Mode)
@@ -375,7 +375,12 @@ class ComplexImageView(qt.QWidget):
See :meth:`getSupportedVisualizationModes` for the list of
supported modes.
- :param str mode: The mode to use.
+ How-to change visualization mode::
+
+ widget = ComplexImageView()
+ widget.setVisualizationMode(ComplexImageView.Mode.PHASE)
+
+ :param Mode mode: The mode to use.
"""
self._plotImage.setVisualizationMode(mode)
diff --git a/silx/gui/plot/CurvesROIWidget.py b/silx/gui/plot/CurvesROIWidget.py
index 81e684e..b426a23 100644
--- a/silx/gui/plot/CurvesROIWidget.py
+++ b/silx/gui/plot/CurvesROIWidget.py
@@ -1,7 +1,7 @@
# coding: utf-8
# /*##########################################################################
#
-# Copyright (c) 2004-2018 European Synchrotron Radiation Facility
+# Copyright (c) 2004-2019 European Synchrotron Radiation Facility
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
@@ -22,50 +22,43 @@
# THE SOFTWARE.
#
# ###########################################################################*/
-"""Widget to handle regions of interest (ROI) on curves displayed in a PlotWindow.
+"""
+Widget to handle regions of interest (:class:`ROI`) on curves displayed in a
+:class:`PlotWindow`.
This widget is meant to work with :class:`PlotWindow`.
-
-ROI are defined by :
-
-- A name (`ROI` column)
-- A type. The type is the label of the x axis.
- This can be used to apply or not some ROI to a curve and do some post processing.
-- The x coordinate of the left limit (`from` column)
-- The x coordinate of the right limit (`to` column)
-- Raw counts: Sum of the curve's values in the defined Region Of Intereset.
-
- .. image:: img/rawCounts.png
-
-- Net counts: Raw counts minus background
-
- .. image:: img/netCounts.png
"""
-__authors__ = ["V.A. Sole", "T. Vincent"]
+__authors__ = ["V.A. Sole", "T. Vincent", "H. Payno"]
__license__ = "MIT"
-__date__ = "13/11/2017"
+__date__ = "13/03/2018"
from collections import OrderedDict
-
import logging
import os
import sys
-import weakref
-
+import functools
import numpy
-
from silx.io import dictdump
from silx.utils import deprecation
-
+from silx.utils.weakref import WeakMethodProxy
from .. import icons, qt
+from silx.gui.plot.items.curve import Curve
+from silx.math.combo import min_max
+import weakref
+from silx.gui.widgets.TableWidget import TableWidget
_logger = logging.getLogger(__name__)
class CurvesROIWidget(qt.QWidget):
- """Widget displaying a table of ROI information.
+ """
+ Widget displaying a table of ROI information.
+
+ Implements also the following behavior:
+
+ * if the roiTable has no ROI when showing create the default ICR one
:param parent: See :class:`QWidget`
:param str name: The title of this widget
@@ -73,19 +66,18 @@ class CurvesROIWidget(qt.QWidget):
sigROIWidgetSignal = qt.Signal(object)
"""Signal of ROIs modifications.
-
- Modification information if given as a dict with an 'event' key
- providing the type of events.
-
- Type of events:
-
- - AddROI, DelROI, LoadROI and ResetROI with keys: 'roilist', 'roidict'
-
- - selectionChanged with keys: 'row', 'col' 'roi', 'key', 'colheader',
- 'rowheader'
+ Modification information if given as a dict with an 'event' key
+ providing the type of events.
+ Type of events:
+ - AddROI, DelROI, LoadROI and ResetROI with keys: 'roilist', 'roidict'
+ - selectionChanged with keys: 'row', 'col' 'roi', 'key', 'colheader',
+ 'rowheader'
"""
sigROISignal = qt.Signal(object)
+ """Deprecated signal for backward compatibility with silx < 0.7.
+ Prefer connecting directly to :attr:`CurvesRoiWidget.sigRoiSignal`
+ """
def __init__(self, parent=None, name=None, plot=None):
super(CurvesROIWidget, self).__init__(parent)
@@ -93,6 +85,8 @@ class CurvesROIWidget(qt.QWidget):
self.setWindowTitle(name)
assert plot is not None
self._plotRef = weakref.ref(plot)
+ self._showAllMarkers = False
+ self.currentROI = None
layout = qt.QVBoxLayout(self)
layout.setContentsMargins(0, 0, 0, 0)
@@ -103,13 +97,22 @@ class CurvesROIWidget(qt.QWidget):
self.setHeader()
layout.addWidget(self.headerLabel)
##############
- self.roiTable = ROITable(self)
+ widgetAllCheckbox = qt.QWidget(parent=self)
+ self._showAllCheckBox = qt.QCheckBox("show all ROI",
+ parent=widgetAllCheckbox)
+ widgetAllCheckbox.setLayout(qt.QHBoxLayout())
+ spacer = qt.QWidget(parent=widgetAllCheckbox)
+ spacer.setSizePolicy(qt.QSizePolicy.Expanding, qt.QSizePolicy.Fixed)
+ widgetAllCheckbox.layout().addWidget(spacer)
+ widgetAllCheckbox.layout().addWidget(self._showAllCheckBox)
+ layout.addWidget(widgetAllCheckbox)
+ ##############
+ self.roiTable = ROITable(self, plot=plot)
rheight = self.roiTable.horizontalHeader().sizeHint().height()
self.roiTable.setMinimumHeight(4 * rheight)
- self.fillFromROIDict = self.roiTable.fillFromROIDict
- self.getROIListAndDict = self.roiTable.getROIListAndDict
layout.addWidget(self.roiTable)
self._roiFileDir = qt.QDir.home().absolutePath()
+ self._showAllCheckBox.toggled.connect(self.roiTable.showAllMarkers)
#################
hbox = qt.QWidget(self)
@@ -127,7 +130,8 @@ class CurvesROIWidget(qt.QWidget):
self.addButton.setToolTip('Remove the selected ROI')
self.resetButton = qt.QPushButton(hbox)
self.resetButton.setText("Reset")
- self.addButton.setToolTip('Clear all created ROIs. We only let the default ROI')
+ self.addButton.setToolTip('Clear all created ROIs. We only let the '
+ 'default ROI')
hboxlayout.addWidget(self.addButton)
hboxlayout.addWidget(self.delButton)
@@ -149,19 +153,22 @@ class CurvesROIWidget(qt.QWidget):
layout.addWidget(hbox)
+ # Signal / Slot connections
self.addButton.clicked.connect(self._add)
self.delButton.clicked.connect(self._del)
self.resetButton.clicked.connect(self._reset)
self.loadButton.clicked.connect(self._load)
self.saveButton.clicked.connect(self._save)
- self.roiTable.sigROITableSignal.connect(self._forward)
- self.currentROI = None
- self._middleROIMarkerFlag = False
+ self.roiTable.activeROIChanged.connect(self._emitCurrentROISignal)
+
self._isConnected = False # True if connected to plot signals
self._isInit = False
+ # expose API
+ self.getROIListAndDict = self.roiTable.getROIListAndDict
+
def getPlotWidget(self):
"""Returns the associated PlotWidget or None
@@ -173,10 +180,6 @@ class CurvesROIWidget(qt.QWidget):
self._visibilityChangedHandler(visible=True)
qt.QWidget.showEvent(self, event)
- def hideEvent(self, event):
- self._visibilityChangedHandler(visible=False)
- qt.QWidget.hideEvent(self, event)
-
@property
def roiFileDir(self):
"""The directory from which to load/save ROI from/to files."""
@@ -188,135 +191,81 @@ class CurvesROIWidget(qt.QWidget):
def roiFileDir(self, roiFileDir):
self._roiFileDir = str(roiFileDir)
- def setRois(self, roidict, order=None):
- """Set the ROIs by providing a dictionary of ROI information.
-
- The dictionary keys are the ROI names.
- Each value is a sub-dictionary of ROI info with the following fields:
-
- - ``"from"``: x coordinate of the left limit, as a float
- - ``"to"``: x coordinate of the right limit, as a float
- - ``"type"``: type of ROI, as a string (e.g "channels", "energy")
-
-
- :param roidict: Dictionary of ROIs
- :param str order: Field used for ordering the ROIs.
- One of "from", "to", "type".
- None (default) for no ordering, or same order as specified
- in parameter ``roidict`` if provided as an OrderedDict.
- """
- if order is None or order.lower() == "none":
- roilist = list(roidict.keys())
- else:
- assert order in ["from", "to", "type"]
- roilist = sorted(roidict.keys(),
- key=lambda roi_name: roidict[roi_name].get(order))
-
- return self.roiTable.fillFromROIDict(roilist, roidict)
+ def setRois(self, rois, order=None):
+ return self.roiTable.setRois(rois, order)
def getRois(self, order=None):
- """Return the currently defined ROIs, as an ordered dict.
+ return self.roiTable.getRois(order)
- The dictionary keys are the ROI names.
- Each value is a sub-dictionary of ROI info with the following fields:
+ def setMiddleROIMarkerFlag(self, flag=True):
+ return self.roiTable.setMiddleROIMarkerFlag(flag)
- - ``"from"``: x coordinate of the left limit, as a float
- - ``"to"``: x coordinate of the right limit, as a float
- - ``"type"``: type of ROI, as a string (e.g "channels", "energy")
+ def _add(self):
+ """Add button clicked handler"""
+ def getNextRoiName():
+ rois = self.roiTable.getRois(order=None)
+ roisNames = []
+ [roisNames.append(roiName) for roiName in rois]
+ nrois = len(rois)
+ if nrois == 0:
+ return "ICR"
+ else:
+ i = 1
+ newroi = "newroi %d" % i
+ while newroi in roisNames:
+ i += 1
+ newroi = "newroi %d" % i
+ return newroi
+ roi = ROI(name=getNextRoiName())
- :param order: Field used for ordering the ROIs.
- One of "from", "to", "type", "netcounts", "rawcounts".
- None (default) to get the same order as displayed in the widget.
- :return: Ordered dictionary of ROI information
- """
- roilist, roidict = self.roiTable.getROIListAndDict()
- if order is None or order.lower() == "none":
- ordered_roilist = roilist
+ if roi.getName() == "ICR":
+ roi.setType("Default")
else:
- assert order in ["from", "to", "type", "netcounts", "rawcounts"]
- ordered_roilist = sorted(roidict.keys(),
- key=lambda roi_name: roidict[roi_name].get(order))
-
- return OrderedDict([(name, roidict[name]) for name in ordered_roilist])
+ roi.setType(self.getPlotWidget().getXAxis().getLabel())
- def setMiddleROIMarkerFlag(self, flag=True):
- """Activate or deactivate middle marker.
+ xmin, xmax = self.getPlotWidget().getXAxis().getLimits()
+ fromdata = xmin + 0.25 * (xmax - xmin)
+ todata = xmin + 0.75 * (xmax - xmin)
+ if roi.isICR():
+ fromdata, dummy0, todata, dummy1 = self._getAllLimits()
+ roi.setFrom(fromdata)
+ roi.setTo(todata)
- This allows shifting both min and max limits at once, by dragging
- a marker located in the middle.
-
- :param bool flag: True to activate middle ROI marker
- """
- if flag:
- self._middleROIMarkerFlag = True
- else:
- self._middleROIMarkerFlag = False
+ self.roiTable.addRoi(roi)
- def _add(self):
- """Add button clicked handler"""
+ # back compatibility pymca roi signals
ddict = {}
ddict['event'] = "AddROI"
- roilist, roidict = self.roiTable.getROIListAndDict()
- ddict['roilist'] = roilist
- ddict['roidict'] = roidict
+ ddict['roilist'] = self.roiTable.roidict.values()
+ ddict['roidict'] = self.roiTable.roidict
self.sigROIWidgetSignal.emit(ddict)
+ # end back compatibility pymca roi signals
def _del(self):
"""Delete button clicked handler"""
- row = self.roiTable.currentRow()
- if row >= 0:
- index = self.roiTable.labels.index('Type')
- text = str(self.roiTable.item(row, index).text())
- if text.upper() != 'DEFAULT':
- index = self.roiTable.labels.index('ROI')
- key = str(self.roiTable.item(row, index).text())
- else:
- # This is to prevent deleting ICR ROI, that is
- # usually initialized as "Default" type.
- return
- roilist, roidict = self.roiTable.getROIListAndDict()
- row = roilist.index(key)
- del roilist[row]
- del roidict[key]
- if len(roilist) > 0:
- currentroi = roilist[0]
- else:
- currentroi = None
-
- self.roiTable.fillFromROIDict(roilist=roilist,
- roidict=roidict,
- currentroi=currentroi)
- ddict = {}
- ddict['event'] = "DelROI"
- ddict['roilist'] = roilist
- ddict['roidict'] = roidict
- self.sigROIWidgetSignal.emit(ddict)
-
- def _forward(self, ddict):
- """Broadcast events from ROITable signal"""
+ self.roiTable.deleteActiveRoi()
+
+ # back compatibility pymca roi signals
+ ddict = {}
+ ddict['event'] = "DelROI"
+ ddict['roilist'] = self.roiTable.roidict.values()
+ ddict['roidict'] = self.roiTable.roidict
self.sigROIWidgetSignal.emit(ddict)
+ # end back compatibility pymca roi signals
def _reset(self):
"""Reset button clicked handler"""
+ self.roiTable.clear()
+ self._add()
+
+ # back compatibility pymca roi signals
ddict = {}
ddict['event'] = "ResetROI"
- roilist0, roidict0 = self.roiTable.getROIListAndDict()
- index = 0
- for key in roilist0:
- if roidict0[key]['type'].upper() == 'DEFAULT':
- index = roilist0.index(key)
- break
- roilist = []
- roidict = {}
- if len(roilist0):
- roilist.append(roilist0[index])
- roidict[roilist[0]] = {}
- roidict[roilist[0]].update(roidict0[roilist[0]])
- self.roiTable.fillFromROIDict(roilist=roilist, roidict=roidict)
- ddict['roilist'] = roilist
- ddict['roidict'] = roidict
+ ddict['roilist'] = self.roiTable.roidict.values()
+ ddict['roidict'] = self.roiTable.roidict
self.sigROIWidgetSignal.emit(ddict)
+ # end back compatibility pymca roi signals
def _load(self):
"""Load button clicked handler"""
@@ -334,32 +283,22 @@ class CurvesROIWidget(qt.QWidget):
dialog.close()
self.roiFileDir = os.path.dirname(outputFile)
- self.load(outputFile)
+ self.roiTable.load(outputFile)
+
+ # back compatibility pymca roi signals
+ ddict = {}
+ ddict['event'] = "LoadROI"
+ ddict['roilist'] = self.roiTable.roidict.values()
+ ddict['roidict'] = self.roiTable.roidict
+ self.sigROIWidgetSignal.emit(ddict)
+ # end back compatibility pymca roi signals
def load(self, filename):
"""Load ROI widget information from a file storing a dict of ROI.
:param str filename: The file from which to load ROI
"""
- rois = dictdump.load(filename)
- currentROI = None
- if self.roiTable.rowCount():
- item = self.roiTable.item(self.roiTable.currentRow(), 0)
- if item is not None:
- currentROI = str(item.text())
-
- # Remove rawcounts and netcounts from ROIs
- for roi in rois['ROI']['roidict'].values():
- roi.pop('rawcounts', None)
- roi.pop('netcounts', None)
-
- self.roiTable.fillFromROIDict(roilist=rois['ROI']['roilist'],
- roidict=rois['ROI']['roidict'],
- currentroi=currentROI)
-
- roilist, roidict = self.roiTable.getROIListAndDict()
- event = {'event': 'LoadROI', 'roilist': roilist, 'roidict': roidict}
- self.sigROIWidgetSignal.emit(event)
+ self.roiTable.load(filename)
def _save(self):
"""Save button clicked handler"""
@@ -396,142 +335,24 @@ class CurvesROIWidget(qt.QWidget):
:param str filename: The file to which to save the ROIs
"""
- roilist, roidict = self.roiTable.getROIListAndDict()
- datadict = {'ROI': {'roilist': roilist, 'roidict': roidict}}
- dictdump.dump(datadict, filename)
+ self.roiTable.save(filename)
def setHeader(self, text='ROIs'):
"""Set the header text of this widget"""
self.headerLabel.setText("<b>%s<\b>" % text)
- def _roiSignal(self, ddict):
- """Handle ROI widget signal"""
- _logger.debug("CurvesROIWidget._roiSignal %s", str(ddict))
- plot = self.getPlotWidget()
- if plot is None:
- return
-
- if ddict['event'] == "AddROI":
- xmin, xmax = plot.getXAxis().getLimits()
- fromdata = xmin + 0.25 * (xmax - xmin)
- todata = xmin + 0.75 * (xmax - xmin)
- plot.remove('ROI min', kind='marker')
- plot.remove('ROI max', kind='marker')
- if self._middleROIMarkerFlag:
- plot.remove('ROI middle', kind='marker')
- roiList, roiDict = self.roiTable.getROIListAndDict()
- nrois = len(roiList)
- if nrois == 0:
- newroi = "ICR"
- fromdata, dummy0, todata, dummy1 = self._getAllLimits()
- draggable = False
- color = 'black'
- else:
- # find the next index free for newroi.
- for i in range(nrois):
- i += 1
- newroi = "newroi %d" % i
- if newroi not in roiList:
- break
- color = 'blue'
- draggable = True
- plot.addXMarker(fromdata,
- legend='ROI min',
- text='ROI min',
- color=color,
- draggable=draggable)
- plot.addXMarker(todata,
- legend='ROI max',
- text='ROI max',
- color=color,
- draggable=draggable)
- if draggable and self._middleROIMarkerFlag:
- pos = 0.5 * (fromdata + todata)
- plot.addXMarker(pos,
- legend='ROI middle',
- text="",
- color='yellow',
- draggable=draggable)
- roiList.append(newroi)
- roiDict[newroi] = {}
- if newroi == "ICR":
- roiDict[newroi]['type'] = "Default"
- else:
- roiDict[newroi]['type'] = plot.getXAxis().getLabel()
- roiDict[newroi]['from'] = fromdata
- roiDict[newroi]['to'] = todata
- self.roiTable.fillFromROIDict(roilist=roiList,
- roidict=roiDict,
- currentroi=newroi)
- self.currentROI = newroi
- self.calculateRois()
- elif ddict['event'] in ['DelROI', "ResetROI"]:
- plot.remove('ROI min', kind='marker')
- plot.remove('ROI max', kind='marker')
- if self._middleROIMarkerFlag:
- plot.remove('ROI middle', kind='marker')
- roiList, roiDict = self.roiTable.getROIListAndDict()
- roiDictKeys = list(roiDict.keys())
- if len(roiDictKeys):
- currentroi = roiDictKeys[0]
- else:
- # create again the ICR
- ddict = {"event": "AddROI"}
- return self._roiSignal(ddict)
-
- self.roiTable.fillFromROIDict(roilist=roiList,
- roidict=roiDict,
- currentroi=currentroi)
- self.currentROI = currentroi
-
- elif ddict['event'] == 'LoadROI':
- self.calculateRois()
+ @deprecation.deprecated(replacement="calculateRois",
+ reason="CamelCase convention",
+ since_version="0.7")
+ def calculateROIs(self, *args, **kw):
+ self.calculateRois(*args, **kw)
- elif ddict['event'] == 'selectionChanged':
- _logger.debug("Selection changed")
- self.roilist, self.roidict = self.roiTable.getROIListAndDict()
- fromdata = ddict['roi']['from']
- todata = ddict['roi']['to']
- plot.remove('ROI min', kind='marker')
- plot.remove('ROI max', kind='marker')
- if ddict['key'] == 'ICR':
- draggable = False
- color = 'black'
- else:
- draggable = True
- color = 'blue'
- plot.addXMarker(fromdata,
- legend='ROI min',
- text='ROI min',
- color=color,
- draggable=draggable)
- plot.addXMarker(todata,
- legend='ROI max',
- text='ROI max',
- color=color,
- draggable=draggable)
- if draggable and self._middleROIMarkerFlag:
- pos = 0.5 * (fromdata + todata)
- plot.addXMarker(pos,
- legend='ROI middle',
- text="",
- color='yellow',
- draggable=True)
- self.currentROI = ddict['key']
- if ddict['colheader'] in ['From', 'To']:
- dict0 = {}
- dict0['event'] = "SetActiveCurveEvent"
- dict0['legend'] = plot.getActiveCurve(just_legend=1)
- plot.setActiveCurve(dict0['legend'])
- elif ddict['colheader'] == 'Raw Counts':
- pass
- elif ddict['colheader'] == 'Net Counts':
- pass
- else:
- self._emitCurrentROISignal()
+ def calculateRois(self, roiList=None, roiDict=None):
+ """Compute ROI information"""
+ return self.roiTable.calculateRois()
- else:
- _logger.debug("Unknown or ignored event %s", ddict['event'])
+ def showAllMarkers(self, _show=True):
+ self.roiTable.showAllMarkers(_show)
def _getAllLimits(self):
"""Retrieve the limits based on the curves."""
@@ -565,429 +386,1121 @@ class CurvesROIWidget(qt.QWidget):
return xmin, ymin, xmax, ymax
- @deprecation.deprecated(replacement="calculateRois",
- reason="CamelCase convention")
- def calculateROIs(self, *args, **kw):
- self.calculateRois(*args, **kw)
+ def showEvent(self, event):
+ self._visibilityChangedHandler(visible=True)
+ qt.QWidget.showEvent(self, event)
- def calculateRois(self, roiList=None, roiDict=None):
- """Compute ROI information"""
- if roiList is None or roiDict is None:
- roiList, roiDict = self.roiTable.getROIListAndDict()
+ def hideEvent(self, event):
+ self._visibilityChangedHandler(visible=False)
+ qt.QWidget.hideEvent(self, event)
- plot = self.getPlotWidget()
- if plot is None:
- activeCurve = None
- else:
- activeCurve = plot.getActiveCurve(just_legend=False)
+ def _visibilityChangedHandler(self, visible):
+ """Handle widget's visibility updates.
- if activeCurve is None:
- xproc = None
- yproc = None
- self.setHeader()
- else:
- x = activeCurve.getXData(copy=False)
- y = activeCurve.getYData(copy=False)
- legend = activeCurve.getLegend()
- idx = numpy.argsort(x, kind='mergesort')
- xproc = numpy.take(x, idx)
- yproc = numpy.take(y, idx)
- self.setHeader('ROIs of %s' % legend)
-
- for key in roiList:
- if key == 'ICR':
- if xproc is not None:
- roiDict[key]['from'] = xproc.min()
- roiDict[key]['to'] = xproc.max()
- else:
- roiDict[key]['from'] = 0
- roiDict[key]['to'] = -1
- fromData = roiDict[key]['from']
- toData = roiDict[key]['to']
- if xproc is not None:
- idx = numpy.nonzero((fromData <= xproc) &
- (xproc <= toData))[0]
- if len(idx):
- xw = xproc[idx]
- yw = yproc[idx]
- rawCounts = yw.sum(dtype=numpy.float)
- deltaX = xw[-1] - xw[0]
- deltaY = yw[-1] - yw[0]
- if deltaX > 0.0:
- slope = (deltaY / deltaX)
- background = yw[0] + slope * (xw - xw[0])
- netCounts = (rawCounts -
- background.sum(dtype=numpy.float))
- else:
- netCounts = 0.0
- else:
- rawCounts = 0.0
- netCounts = 0.0
- roiDict[key]['rawcounts'] = rawCounts
- roiDict[key]['netcounts'] = netCounts
- else:
- roiDict[key].pop('rawcounts', None)
- roiDict[key].pop('netcounts', None)
+ It is connected to plot signals only when visible.
+ """
+ if visible:
+ # if no ROI existing yet, add the default one
+ if self.roiTable.rowCount() is 0:
+ self._add()
+ self.calculateRois()
- self.roiTable.fillFromROIDict(
- roilist=roiList,
- roidict=roiDict,
- currentroi=self.currentROI if self.currentROI in roiList else None)
+ def fillFromROIDict(self, *args, **kwargs):
+ self.roiTable.fillFromROIDict(*args, **kwargs)
def _emitCurrentROISignal(self):
ddict = {}
ddict['event'] = "currentROISignal"
- _roiList, roiDict = self.roiTable.getROIListAndDict()
- if self.currentROI in roiDict:
- ddict['ROI'] = roiDict[self.currentROI]
+ if self.roiTable.activeRoi is not None:
+ ddict['ROI'] = self.roiTable.activeRoi.toDict()
+ ddict['current'] = self.roiTable.activeRoi.getName()
else:
- self.currentROI = None
- ddict['current'] = self.currentROI
+ ddict['current'] = None
self.sigROISignal.emit(ddict)
- def _handleROIMarkerEvent(self, ddict):
- """Handle plot signals related to marker events."""
- if ddict['event'] == 'markerMoved':
+ @property
+ def currentRoi(self):
+ return self.roiTable.activeRoi
- label = ddict['label']
- if label not in ['ROI min', 'ROI max', 'ROI middle']:
- return
- roiList, roiDict = self.roiTable.getROIListAndDict()
- if self.currentROI is None:
- return
- if self.currentROI not in roiDict:
- return
+class _FloatItem(qt.QTableWidgetItem):
+ """
+ Simple QTableWidgetItem overloading the < operator to deal with ordering
+ """
+ def __init__(self):
+ qt.QTableWidgetItem.__init__(self, type=qt.QTableWidgetItem.Type)
- plot = self.getPlotWidget()
- if plot is None:
- return
+ def __lt__(self, other):
+ if self.text() in ('', ROITable.INFO_NOT_FOUND):
+ return False
+ if other.text() in ('', ROITable.INFO_NOT_FOUND):
+ return True
+ return float(self.text()) < float(other.text())
+
+
+class ROITable(TableWidget):
+ """Table widget displaying ROI information.
+
+ See :class:`QTableWidget` for constructor arguments.
- x = ddict['x']
-
- if label == 'ROI min':
- roiDict[self.currentROI]['from'] = x
- if self._middleROIMarkerFlag:
- pos = 0.5 * (roiDict[self.currentROI]['to'] +
- roiDict[self.currentROI]['from'])
- plot.addXMarker(pos,
- legend='ROI middle',
- text='',
- color='yellow',
- draggable=True)
- elif label == 'ROI max':
- roiDict[self.currentROI]['to'] = x
- if self._middleROIMarkerFlag:
- pos = 0.5 * (roiDict[self.currentROI]['to'] +
- roiDict[self.currentROI]['from'])
- plot.addXMarker(pos,
- legend='ROI middle',
- text='',
- color='yellow',
- draggable=True)
- elif label == 'ROI middle':
- delta = x - 0.5 * (roiDict[self.currentROI]['from'] +
- roiDict[self.currentROI]['to'])
- roiDict[self.currentROI]['from'] += delta
- roiDict[self.currentROI]['to'] += delta
- plot.addXMarker(roiDict[self.currentROI]['from'],
- legend='ROI min',
- text='ROI min',
- color='blue',
- draggable=True)
- plot.addXMarker(roiDict[self.currentROI]['to'],
- legend='ROI max',
- text='ROI max',
- color='blue',
- draggable=True)
+ Behavior: listen at the active curve changed only when the widget is
+ visible. Otherwise won't compute the row and net counts...
+ """
+
+ activeROIChanged = qt.Signal()
+ """Signal emitted when the active roi changed or when the value of the
+ active roi are changing"""
+
+ COLUMNS_INDEX = OrderedDict([
+ ('ID', 0),
+ ('ROI', 1),
+ ('Type', 2),
+ ('From', 3),
+ ('To', 4),
+ ('Raw Counts', 5),
+ ('Net Counts', 6),
+ ('Raw Area', 7),
+ ('Net Area', 8),
+ ])
+
+ COLUMNS = list(COLUMNS_INDEX.keys())
+
+ INFO_NOT_FOUND = '????????'
+
+ def __init__(self, parent=None, plot=None, rois=None):
+ super(ROITable, self).__init__(parent)
+ self._showAllMarkers = False
+ self._userIsEditingRoi = False
+ """bool used to avoid conflict when editing the ROI object"""
+ self._isConnected = False
+ self._roiToItems = {}
+ self._roiDict = {}
+ """dict of ROI object. Key is ROi id, value is the ROI object"""
+ self._markersHandler = _RoiMarkerManager()
+
+ """
+ Associate for each marker legend used when the `_showAllMarkers` option
+ is active a roi.
+ """
+ self.setColumnCount(len(self.COLUMNS))
+ self.setPlot(plot)
+ self.__setTooltip()
+ self.setSortingEnabled(True)
+ self.itemChanged.connect(self._itemChanged)
+
+ @property
+ def roidict(self):
+ return self._getRoiDict()
+
+ @property
+ def activeRoi(self):
+ return self._markersHandler._activeRoi
+
+ def _getRoiDict(self):
+ ddict = {}
+ for id in self._roiDict:
+ ddict[self._roiDict[id].getName()] = self._roiDict[id]
+ return ddict
+
+ def clear(self):
+ """
+ .. note:: clear the interface only. keep the roidict...
+ """
+ self._markersHandler.clear()
+ self._roiToItems = {}
+ self._roiDict = {}
+
+ qt.QTableWidget.clear(self)
+ self.setRowCount(0)
+ self.setHorizontalHeaderLabels(self.COLUMNS)
+ header = self.horizontalHeader()
+ if hasattr(header, 'setSectionResizeMode'): # Qt5
+ header.setSectionResizeMode(qt.QHeaderView.ResizeToContents)
+ else: # Qt4
+ header.setResizeMode(qt.QHeaderView.ResizeToContents)
+ self.sortByColumn(0, qt.Qt.AscendingOrder)
+ self.hideColumn(self.COLUMNS_INDEX['ID'])
+
+ def setPlot(self, plot):
+ self.clear()
+ self.plot = plot
+
+ def __setTooltip(self):
+ self.horizontalHeaderItem(self.COLUMNS_INDEX['ROI']).setToolTip(
+ 'Region of interest identifier')
+ self.horizontalHeaderItem(self.COLUMNS_INDEX['Type']).setToolTip(
+ 'Type of the ROI')
+ self.horizontalHeaderItem(self.COLUMNS_INDEX['From']).setToolTip(
+ 'X-value of the min point')
+ self.horizontalHeaderItem(self.COLUMNS_INDEX['To']).setToolTip(
+ 'X-value of the max point')
+ self.horizontalHeaderItem(self.COLUMNS_INDEX['Raw Counts']).setToolTip(
+ 'Estimation of the integral between y=0 and the selected curve')
+ self.horizontalHeaderItem(self.COLUMNS_INDEX['Net Counts']).setToolTip(
+ 'Estimation of the integral between the segment [maxPt, minPt] '
+ 'and the selected curve')
+
+ def setRois(self, rois, order=None):
+ """Set the ROIs by providing a dictionary of ROI information.
+
+ The dictionary keys are the ROI names.
+ Each value is a sub-dictionary of ROI info with the following fields:
+
+ - ``"from"``: x coordinate of the left limit, as a float
+ - ``"to"``: x coordinate of the right limit, as a float
+ - ``"type"``: type of ROI, as a string (e.g "channels", "energy")
+
+
+ :param roidict: Dictionary of ROIs
+ :param str order: Field used for ordering the ROIs.
+ One of "from", "to", "type".
+ None (default) for no ordering, or same order as specified
+ in parameter ``roidict`` if provided as an OrderedDict.
+ """
+ assert order in [None, "from", "to", "type"]
+ self.clear()
+
+ # backward compatibility since 0.10.0
+ if isinstance(rois, dict):
+ for roiName, roi in rois.items():
+ roi['name'] = roiName
+ _roi = ROI._fromDict(roi)
+ self.addRoi(_roi)
+ else:
+ for roi in rois:
+ assert isinstance(roi, ROI)
+ self.addRoi(roi)
+ self._updateMarkers()
+
+ def addRoi(self, roi):
+ """
+
+ :param :class:`ROI` roi: roi to add to the table
+ """
+ assert isinstance(roi, ROI)
+ self._getItem(name='ID', row=None, roi=roi)
+ self._roiDict[roi.getID()] = roi
+ self._markersHandler.add(roi, _RoiMarkerHandler(roi, self.plot))
+ self._updateRoiInfo(roi.getID())
+ callback = functools.partial(WeakMethodProxy(self._updateRoiInfo),
+ roi.getID())
+ roi.sigChanged.connect(callback)
+ # set it as the active one
+ self.setActiveRoi(roi)
+
+ def _getItem(self, name, row, roi):
+ if row:
+ item = self.item(row, self.COLUMNS_INDEX[name])
+ else:
+ item = None
+ if item:
+ return item
+ else:
+ if name == 'ID':
+ assert roi
+ if roi.getID() in self._roiToItems:
+ return self._roiToItems[roi.getID()]
+ else:
+ # create a new row
+ row = self.rowCount()
+ self.setRowCount(self.rowCount() + 1)
+ item = qt.QTableWidgetItem(str(roi.getID()),
+ type=qt.QTableWidgetItem.Type)
+ self._roiToItems[roi.getID()] = item
+ elif name == 'ROI':
+ item = qt.QTableWidgetItem(roi.getName() if roi else '',
+ type=qt.QTableWidgetItem.Type)
+ if roi.getName().upper() in ('ICR', 'DEFAULT'):
+ item.setFlags(qt.Qt.ItemIsSelectable | qt.Qt.ItemIsEnabled)
+ else:
+ item.setFlags(qt.Qt.ItemIsSelectable |
+ qt.Qt.ItemIsEnabled |
+ qt.Qt.ItemIsEditable)
+ elif name == 'Type':
+ item = qt.QTableWidgetItem(type=qt.QTableWidgetItem.Type)
+ item.setFlags((qt.Qt.ItemIsSelectable | qt.Qt.ItemIsEnabled))
+ elif name in ('To', 'From'):
+ item = _FloatItem()
+ if roi.getName().upper() in ('ICR', 'DEFAULT'):
+ item.setFlags(qt.Qt.ItemIsSelectable | qt.Qt.ItemIsEnabled)
+ else:
+ item.setFlags(qt.Qt.ItemIsSelectable |
+ qt.Qt.ItemIsEnabled |
+ qt.Qt.ItemIsEditable)
+ elif name in ('Raw Counts', 'Net Counts', 'Raw Area', 'Net Area'):
+ item = _FloatItem()
+ item.setFlags((qt.Qt.ItemIsSelectable | qt.Qt.ItemIsEnabled))
else:
- return
- self.calculateRois(roiList, roiDict)
- self._emitCurrentROISignal()
+ raise ValueError('item type not recognized')
+
+ self.setItem(row, self.COLUMNS_INDEX[name], item)
+ return item
+
+ def _itemChanged(self, item):
+ def getRoi():
+ IDItem = self.item(item.row(), self.COLUMNS_INDEX['ID'])
+ assert IDItem
+ id = int(IDItem.text())
+ assert id in self._roiDict
+ roi = self._roiDict[id]
+ return roi
+
+ def signalChanged(roi):
+ if self.activeRoi and roi.getID() == self.activeRoi.getID():
+ self.activeROIChanged.emit()
+
+ self._userIsEditingRoi = True
+ if item.column() in (self.COLUMNS_INDEX['To'], self.COLUMNS_INDEX['From']):
+ roi = getRoi()
+
+ if item.text() not in ('', self.INFO_NOT_FOUND):
+ try:
+ value = float(item.text())
+ except ValueError:
+ value = 0
+ changed = False
+ if item.column() == self.COLUMNS_INDEX['To']:
+ if value != roi.getTo():
+ roi.setTo(value)
+ changed = True
+ else:
+ assert(item.column() == self.COLUMNS_INDEX['From'])
+ if value != roi.getFrom():
+ roi.setFrom(value)
+ changed = True
+ if changed:
+ self._updateMarker(roi.getName())
+ signalChanged(roi)
+
+ if item.column() is self.COLUMNS_INDEX['ROI']:
+ roi = getRoi()
+ if roi.getName() != item.text():
+ roi.setName(item.text())
+ self._markersHandler.getMarkerHandler(roi.getID()).updateTexts()
+ signalChanged(roi)
+
+ self._userIsEditingRoi = False
+
+ def deleteActiveRoi(self):
+ """
+ remove the current active roi
+ """
+ activeItems = self.selectedItems()
+ if len(activeItems) is 0:
+ return
+ roiToRm = set()
+ for item in activeItems:
+ row = item.row()
+ itemID = self.item(row, self.COLUMNS_INDEX['ID'])
+ roiToRm.add(self._roiDict[int(itemID.text())])
+ [self.removeROI(roi) for roi in roiToRm]
+ self.setActiveRoi(None)
+
+ def removeROI(self, roi):
+ """
+ remove the requested roi
- def _visibilityChangedHandler(self, visible):
- """Handle widget's visibility updates.
+ :param str name: the name of the roi to remove from the table
+ """
+ if roi and roi.getID() in self._roiToItems:
+ item = self._roiToItems[roi.getID()]
+ self.removeRow(item.row())
+ del self._roiToItems[roi.getID()]
- It is connected to plot signals only when visible.
+ assert roi.getID() in self._roiDict
+ del self._roiDict[roi.getID()]
+ self._markersHandler.remove(roi)
+
+ callback = functools.partial(WeakMethodProxy(self._updateRoiInfo),
+ roi.getID())
+ roi.sigChanged.connect(callback)
+
+ def setActiveRoi(self, roi):
"""
- plot = self.getPlotWidget()
+ Define the given roi as the active one.
- if visible:
- if not self._isInit:
- # Deferred ROI widget init finalization
- self._finalizeInit()
-
- if not self._isConnected and plot is not None:
- plot.sigPlotSignal.connect(self._handleROIMarkerEvent)
- plot.sigActiveCurveChanged.connect(
- self._activeCurveChanged)
- self._isConnected = True
+ .. warning:: this roi should already be registred / added to the table
- self.calculateRois()
+ :param :class:`ROI` roi: the roi to defined as active
+ """
+ if roi is None:
+ self.clearSelection()
+ self._markersHandler.setActiveRoi(None)
+ self.activeROIChanged.emit()
else:
- if self._isConnected:
- if plot is not None:
- plot.sigPlotSignal.disconnect(self._handleROIMarkerEvent)
- plot.sigActiveCurveChanged.disconnect(
- self._activeCurveChanged)
- self._isConnected = False
+ assert isinstance(roi, ROI)
+ if roi and roi.getID() in self._roiToItems.keys():
+ self.selectRow(self._roiToItems[roi.getID()].row())
+ self._markersHandler.setActiveRoi(roi)
+ self.activeROIChanged.emit()
+
+ def _updateRoiInfo(self, roiID):
+ if self._userIsEditingRoi is True:
+ return
+ if roiID not in self._roiDict:
+ return
+ roi = self._roiDict[roiID]
+ if roi.isICR():
+ activeCurve = self.plot.getActiveCurve()
+ if activeCurve:
+ xData = activeCurve.getXData()
+ if len(xData) > 0:
+ min, max = min_max(xData)
+ roi.blockSignals(True)
+ roi.setFrom(min)
+ roi.setTo(max)
+ roi.blockSignals(False)
+
+ itemID = self._getItem(name='ID', roi=roi, row=None)
+ itemName = self._getItem(name='ROI', row=itemID.row(), roi=roi)
+ itemName.setText(roi.getName())
+
+ itemType = self._getItem(name='Type', row=itemID.row(), roi=roi)
+ itemType.setText(roi.getType() or self.INFO_NOT_FOUND)
+
+ itemFrom = self._getItem(name='From', row=itemID.row(), roi=roi)
+ fromdata = str(roi.getFrom()) if roi.getFrom() is not None else self.INFO_NOT_FOUND
+ itemFrom.setText(fromdata)
+
+ itemTo = self._getItem(name='To', row=itemID.row(), roi=roi)
+ todata = str(roi.getTo()) if roi.getTo() is not None else self.INFO_NOT_FOUND
+ itemTo.setText(todata)
+
+ rawCounts, netCounts = roi.computeRawAndNetCounts(
+ curve=self.plot.getActiveCurve(just_legend=False))
+ itemRawCounts = self._getItem(name='Raw Counts', row=itemID.row(),
+ roi=roi)
+ rawCounts = str(rawCounts) if rawCounts is not None else self.INFO_NOT_FOUND
+ itemRawCounts.setText(rawCounts)
+
+ itemNetCounts = self._getItem(name='Net Counts', row=itemID.row(),
+ roi=roi)
+ netCounts = str(netCounts) if netCounts is not None else self.INFO_NOT_FOUND
+ itemNetCounts.setText(netCounts)
+
+ rawArea, netArea = roi.computeRawAndNetArea(
+ curve=self.plot.getActiveCurve(just_legend=False))
+ itemRawArea = self._getItem(name='Raw Area', row=itemID.row(),
+ roi=roi)
+ rawArea = str(rawArea) if rawArea is not None else self.INFO_NOT_FOUND
+ itemRawArea.setText(rawArea)
+
+ itemNetArea = self._getItem(name='Net Area', row=itemID.row(),
+ roi=roi)
+ netArea = str(netArea) if netArea is not None else self.INFO_NOT_FOUND
+ itemNetArea.setText(netArea)
+
+ if self.activeRoi and roi.getID() == self.activeRoi.getID():
+ self.activeROIChanged.emit()
+
+ def currentChanged(self, current, previous):
+ if previous and current.row() != previous.row() and current.row() >= 0:
+ roiItem = self.item(current.row(),
+ self.COLUMNS_INDEX['ID'])
+
+ assert roiItem
+ self.setActiveRoi(self._roiDict[int(roiItem.text())])
+ self._markersHandler.updateAllMarkers()
+ qt.QTableWidget.currentChanged(self, current, previous)
+
+ @deprecation.deprecated(reason="Removed",
+ replacement="roidict and roidict.values()",
+ since_version="0.10.0")
+ def getROIListAndDict(self):
+ """
- def _activeCurveChanged(self, *args):
- """Recompute ROIs when active curve changed."""
- self.calculateRois()
+ :return: the list of roi objects and the dictionary of roi name to roi
+ object.
+ """
+ roidict = self._roiDict
+ return list(roidict.values()), roidict
- def _finalizeInit(self):
- self._isInit = True
- self.sigROIWidgetSignal.connect(self._roiSignal)
- # initialize with the ICR if no ROi existing yet
- if len(self.getRois()) is 0:
- self._roiSignal({'event': "AddROI"})
+ def calculateRois(self, roiList=None, roiDict=None):
+ """
+ Update values of all registred rois (raw and net counts in particular)
+ :param roiList: deprecated parameter
+ :param roiDict: deprecated parameter
+ """
+ if roiDict:
+ deprecation.deprecated_warning(name='roiDict', type_='Parameter',
+ reason='Unused parameter',
+ since_version="0.10.0")
+ if roiList:
+ deprecation.deprecated_warning(name='roiList', type_='Parameter',
+ reason='Unused parameter',
+ since_version="0.10.0")
+
+ for roiID in self._roiDict:
+ self._updateRoiInfo(roiID)
+
+ def _updateMarker(self, roiID):
+ """Make sure the marker of the given roi name is updated"""
+ if self._showAllMarkers or (self.activeRoi
+ and self.activeRoi.getName() == roiID):
+ self._updateMarkers()
+
+ def _updateMarkers(self):
+ if self._showAllMarkers is True:
+ self._markersHandler.updateMarkers()
+ else:
+ if not self.activeRoi or not self.plot:
+ return
+ assert isinstance(self.activeRoi, ROI)
+ markerHandler = self._markersHandler.getMarkerHandler(self.activeRoi.getID())
+ if markerHandler is not None:
+ markerHandler.updateMarkers()
-class ROITable(qt.QTableWidget):
- """Table widget displaying ROI information.
+ def getRois(self, order):
+ """
+ Return the currently defined ROIs, as an ordered dict.
- See :class:`QTableWidget` for constructor arguments.
- """
+ The dictionary keys are the ROI names.
+ Each value is a :class:`ROI` object..
- sigROITableSignal = qt.Signal(object)
- """Signal of ROI table modifications.
- """
+ :param order: Field used for ordering the ROIs.
+ One of "from", "to", "type", "netcounts", "rawcounts".
+ None (default) to get the same order as displayed in the widget.
+ :return: Ordered dictionary of ROI information
+ """
- def __init__(self, *args, **kwargs):
- super(ROITable, self).__init__(*args, **kwargs)
- self.setRowCount(1)
- self.labels = 'ROI', 'Type', 'From', 'To', 'Raw Counts', 'Net Counts'
- self.setColumnCount(len(self.labels))
- self.setSortingEnabled(False)
+ if order is None or order.lower() == "none":
+ ordered_roilist = list(self._roiDict.values())
+ res = OrderedDict([(roi.getName(), self._roiDict[roi.getID()]) for roi in ordered_roilist])
+ else:
+ assert order in ["from", "to", "type", "netcounts", "rawcounts"]
+ ordered_roilist = sorted(self._roiDict.keys(),
+ key=lambda roi_id: self._roiDict[roi_id].get(order))
+ res = OrderedDict([(roi.getName(), self._roiDict[id]) for id in ordered_roilist])
+
+ return res
+
+ def save(self, filename):
+ """
+ Save current ROIs of the widget as a dict of ROI to a file.
+
+ :param str filename: The file to which to save the ROIs
+ """
+ roilist = []
+ roidict = {}
+ for roiID, roi in self._roiDict.items():
+ roilist.append(roi.toDict())
+ roidict[roi.getName()] = roi.toDict()
+ datadict = {'ROI': {'roilist': roilist, 'roidict': roidict}}
+ dictdump.dump(datadict, filename)
- for index, label in enumerate(self.labels):
- item = self.horizontalHeaderItem(index)
- if item is None:
- item = qt.QTableWidgetItem(label,
- qt.QTableWidgetItem.Type)
- item.setText(label)
- self.setHorizontalHeaderItem(index, item)
+ def load(self, filename):
+ """
+ Load ROI widget information from a file storing a dict of ROI.
- self.roidict = {}
- self.roilist = []
+ :param str filename: The file from which to load ROI
+ """
+ roisDict = dictdump.load(filename)
+ rois = []
- self.building = False
- self.fillFromROIDict(roilist=self.roilist, roidict=self.roidict)
+ # Remove rawcounts and netcounts from ROIs
+ for roiDict in roisDict['ROI']['roidict'].values():
+ roiDict.pop('rawcounts', None)
+ roiDict.pop('netcounts', None)
+ rois.append(ROI._fromDict(roiDict))
- self.cellClicked[(int, int)].connect(self._cellClickedSlot)
- self.cellChanged[(int, int)].connect(self._cellChangedSlot)
- verticalHeader = self.verticalHeader()
- verticalHeader.sectionClicked[int].connect(self._rowChangedSlot)
+ self.setRois(rois)
- self.__setTooltip()
+ def showAllMarkers(self, _show=True):
+ """
- def __setTooltip(self):
- assert(self.labels[0] == 'ROI')
- self.horizontalHeaderItem(0).setToolTip('Region of interest identifier')
- assert(self.labels[1] == 'Type')
- self.horizontalHeaderItem(1).setToolTip('Type of the ROI')
- assert(self.labels[2] == 'From')
- self.horizontalHeaderItem(2).setToolTip('X-value of the min point')
- assert(self.labels[3] == 'To')
- self.horizontalHeaderItem(3).setToolTip('X-value of the max point')
- assert(self.labels[4] == 'Raw Counts')
- self.horizontalHeaderItem(4).setToolTip('Estimation of the integral \
- between y=0 and the selected curve')
- assert(self.labels[5] == 'Net Counts')
- self.horizontalHeaderItem(5).setToolTip('Estimation of the integral \
- between the segment [maxPt, minPt] and the selected curve')
+ :param bool _show: if true show all the markers of all the ROIs
+ boundaries otherwise will only show the one of
+ the active ROI.
+ """
+ self._markersHandler.setShowAllMarkers(_show)
+
+ def setMiddleROIMarkerFlag(self, flag=True):
+ """
+ Activate or deactivate middle marker.
+
+ This allows shifting both min and max limits at once, by dragging
+ a marker located in the middle.
+
+ :param bool flag: True to activate middle ROI marker
+ """
+ self._markersHandler._middleROIMarkerFlag = flag
+
+ def _handleROIMarkerEvent(self, ddict):
+ """Handle plot signals related to marker events."""
+ if ddict['event'] == 'markerMoved':
+ label = ddict['label']
+ roiID = self._markersHandler.getRoiID(markerID=label)
+ if roiID:
+ self._markersHandler.changePosition(markerID=label,
+ x=ddict['x'])
+ self._updateRoiInfo(roiID)
+
+ def showEvent(self, event):
+ self._visibilityChangedHandler(visible=True)
+ qt.QWidget.showEvent(self, event)
+
+ def hideEvent(self, event):
+ self._visibilityChangedHandler(visible=False)
+ qt.QWidget.hideEvent(self, event)
+
+ def _visibilityChangedHandler(self, visible):
+ """Handle widget's visibility updates.
+
+ It is connected to plot signals only when visible.
+ """
+ if visible:
+ assert self.plot
+ if self._isConnected is False:
+ self.plot.sigPlotSignal.connect(self._handleROIMarkerEvent)
+ self.plot.sigActiveCurveChanged.connect(self._activeCurveChanged)
+ self._isConnected = True
+ self.calculateRois()
+ else:
+ if self._isConnected:
+ self.plot.sigPlotSignal.disconnect(self._handleROIMarkerEvent)
+ self.plot.sigActiveCurveChanged.disconnect(self._activeCurveChanged)
+ self._isConnected = False
+
+ def _activeCurveChanged(self, curve):
+ self.calculateRois()
+
+ def setCountsVisible(self, visible):
+ """
+ Display the columns relative to areas or not
+
+ :param bool visible: True if the columns 'Raw Area' and 'Net Area'
+ should be visible.
+ """
+ if visible is True:
+ self.showColumn(self.COLUMNS_INDEX['Raw Counts'])
+ self.showColumn(self.COLUMNS_INDEX['Net Counts'])
+ else:
+ self.hideColumn(self.COLUMNS_INDEX['Raw Counts'])
+ self.hideColumn(self.COLUMNS_INDEX['Net Counts'])
+
+ def setAreaVisible(self, visible):
+ """
+ Display the columns relative to areas or not
+
+ :param bool visible: True if the columns 'Raw Area' and 'Net Area'
+ should be visible.
+ """
+ if visible is True:
+ self.showColumn(self.COLUMNS_INDEX['Raw Area'])
+ self.showColumn(self.COLUMNS_INDEX['Net Area'])
+ else:
+ self.hideColumn(self.COLUMNS_INDEX['Raw Area'])
+ self.hideColumn(self.COLUMNS_INDEX['Net Area'])
def fillFromROIDict(self, roilist=(), roidict=None, currentroi=None):
- """Set the ROIs by providing a list of ROI names and a dictionary
- of ROI information for each ROI.
+ """
+ This function API is kept for compatibility.
+ But `setRois` should be preferred.
+ Set the ROIs by providing a list of ROI names and a dictionary
+ of ROI information for each ROI.
The ROI names must match an existing dictionary key.
The name list is used to provide an order for the ROIs.
-
The dictionary's values are sub-dictionaries containing 3
mandatory fields:
- - ``"from"``: x coordinate of the left limit, as a float
- - ``"to"``: x coordinate of the right limit, as a float
- - ``"type"``: type of ROI, as a string (e.g "channels", "energy")
+ - ``"from"``: x coordinate of the left limit, as a float
+ - ``"to"``: x coordinate of the right limit, as a float
+ - ``"type"``: type of ROI, as a string (e.g "channels", "energy")
:param roilist: List of ROI names (keys of roidict)
:type roilist: List
:param dict roidict: Dict of ROI information
:param currentroi: Name of the selected ROI or None (no selection)
"""
- if roidict is None:
- roidict = {}
-
- self.building = True
- line0 = 0
- self.roilist = []
- self.roidict = {}
- for key in roilist:
- if key in roidict.keys():
- roi = roidict[key]
- self.roilist.append(key)
- self.roidict[key] = {}
- self.roidict[key].update(roi)
- line0 = line0 + 1
- nlines = self.rowCount()
- if (line0 > nlines):
- self.setRowCount(line0)
- line = line0 - 1
- self.roidict[key]['line'] = line
- ROI = key
- roitype = "%s" % roi['type']
- fromdata = "%6g" % (roi['from'])
- todata = "%6g" % (roi['to'])
- if 'rawcounts' in roi:
- rawcounts = "%6g" % (roi['rawcounts'])
- else:
- rawcounts = " ?????? "
- if 'netcounts' in roi:
- netcounts = "%6g" % (roi['netcounts'])
- else:
- netcounts = " ?????? "
- fields = [ROI, roitype, fromdata, todata, rawcounts, netcounts]
- col = 0
- for field in fields:
- key2 = self.item(line, col)
- if key2 is None:
- key2 = qt.QTableWidgetItem(field,
- qt.QTableWidgetItem.Type)
- self.setItem(line, col, key2)
- else:
- key2.setText(field)
- if (ROI.upper() == 'ICR') or (ROI.upper() == 'DEFAULT'):
- key2.setFlags(qt.Qt.ItemIsSelectable |
- qt.Qt.ItemIsEnabled)
- else:
- if col in [0, 2, 3]:
- key2.setFlags(qt.Qt.ItemIsSelectable |
- qt.Qt.ItemIsEnabled |
- qt.Qt.ItemIsEditable)
- else:
- key2.setFlags(qt.Qt.ItemIsSelectable |
- qt.Qt.ItemIsEnabled)
- col = col + 1
- self.setRowCount(line0)
- i = 0
- for _label in self.labels:
- self.resizeColumnToContents(i)
- i = i + 1
- self.sortByColumn(2, qt.Qt.AscendingOrder)
- for i in range(len(self.roilist)):
- key = str(self.item(i, 0).text())
- self.roilist[i] = key
- self.roidict[key]['line'] = i
- if len(self.roilist) == 1:
- self.selectRow(0)
+ if roidict is not None:
+ self.setRois(roidict)
else:
- if currentroi in self.roidict.keys():
- self.selectRow(self.roidict[currentroi]['line'])
- _logger.debug("Qt4 ensureCellVisible to be implemented")
- self.building = False
+ self.setRois(roilist)
+ if currentroi:
+ self.setActiveRoi(currentroi)
- def getROIListAndDict(self):
- """Return the currently defined ROIs, as a 2-tuple
- ``(roiList, roiDict)``
- ``roiList`` is a list of ROI names.
- ``roiDict`` is a dictionary of ROI info.
+_indexNextROI = 0
- The ROI names must match an existing dictionary key.
- The name list is used to provide an order for the ROIs.
- The dictionary's values are sub-dictionaries containing 3
- fields:
+class ROI(qt.QObject):
+ """The Region Of Interest is defined by:
- - ``"from"``: x coordinate of the left limit, as a float
- - ``"to"``: x coordinate of the right limit, as a float
- - ``"type"``: type of ROI, as a string (e.g "channels", "energy")
+ - A name
+ - A type. The type is the label of the x axis. This can be used to apply or
+ not some ROI to a curve and do some post processing.
+ - The x coordinate of the left limit (fromdata)
+ - The x coordinate of the right limit (todata)
+ :param str: name of the ROI
+ :param fromdata: left limit of the roi
+ :param todata: right limit of the roi
+ :param type: type of the ROI
+ """
+
+ sigChanged = qt.Signal()
+ """Signal emitted when the ROI is edited"""
+
+ def __init__(self, name, fromdata=None, todata=None, type_=None):
+ qt.QObject.__init__(self)
+ assert type(name) is str
+ global _indexNextROI
+ self._id = _indexNextROI
+ _indexNextROI += 1
+
+ self._name = name
+ self._fromdata = fromdata
+ self._todata = todata
+ self._type = type_ or 'Default'
- :return: ordered dict as a tuple of (list of ROI names, dict of info)
+ def getID(self):
"""
- return self.roilist, self.roidict
- def _cellClickedSlot(self, *var, **kw):
- # selection changed event, get the current selection
- row = self.currentRow()
- col = self.currentColumn()
- if row >= 0 and row < len(self.roilist):
- item = self.item(row, 0)
- text = '' if item is None else str(item.text())
- self.roilist[row] = text
- self._emitSelectionChangedSignal(row, col)
+ :return int: the unique ID of the ROI
+ """
+ return self._id
- def _rowChangedSlot(self, row):
- self._emitSelectionChangedSignal(row, 0)
+ def setType(self, type_):
+ """
- def _cellChangedSlot(self, row, col):
- _logger.debug("_cellChangedSlot(%d, %d)", row, col)
- if self.building:
- return
- if col == 0:
- self.nameSlot(row, col)
+ :param str type_:
+ """
+ if self._type != type_:
+ self._type = type_
+ self.sigChanged.emit()
+
+ def getType(self):
+ """
+
+ :return str: the type of the ROI.
+ """
+ return self._type
+
+ def setName(self, name):
+ """
+ Set the name of the :class:`ROI`
+
+ :param str name:
+ """
+ if self._name != name:
+ self._name = name
+ self.sigChanged.emit()
+
+ def getName(self):
+ """
+
+ :return str: name of the :class:`ROI`
+ """
+ return self._name
+
+ def setFrom(self, frm):
+ """
+
+ :param frm: set x coordinate of the left limit
+ """
+ if self._fromdata != frm:
+ self._fromdata = frm
+ self.sigChanged.emit()
+
+ def getFrom(self):
+ """
+
+ :return: x coordinate of the left limit
+ """
+ return self._fromdata
+
+ def setTo(self, to):
+ """
+
+ :param to: x coordinate of the right limit
+ """
+ if self._todata != to:
+ self._todata = to
+ self.sigChanged.emit()
+
+ def getTo(self):
+ """
+
+ :return: x coordinate of the right limit
+ """
+ return self._todata
+
+ def getMiddle(self):
+ """
+
+ :return: middle position between 'from' and 'to' values
+ """
+ return 0.5 * (self.getFrom() + self.getTo())
+
+ def toDict(self):
+ """
+
+ :return: dict containing the roi parameters
+ """
+ ddict = {
+ 'type': self._type,
+ 'name': self._name,
+ 'from': self._fromdata,
+ 'to': self._todata,
+ }
+ if hasattr(self, '_extraInfo'):
+ ddict.update(self._extraInfo)
+ return ddict
+
+ @staticmethod
+ def _fromDict(dic):
+ assert 'name' in dic
+ roi = ROI(name=dic['name'])
+ roi._extraInfo = {}
+ for key in dic:
+ if key == 'from':
+ roi.setFrom(dic['from'])
+ elif key == 'to':
+ roi.setTo(dic['to'])
+ elif key == 'type':
+ roi.setType(dic['type'])
+ else:
+ roi._extraInfo[key] = dic[key]
+
+ return roi
+
+ def isICR(self):
+ """
+
+ :return: True if the ROI is the `ICR`
+ """
+ return self._name == 'ICR'
+
+ def computeRawAndNetCounts(self, curve):
+ """Compute the Raw and net counts in the ROI for the given curve.
+
+ - Raw count: Points values sum of the curve in the defined Region Of
+ Interest.
+
+ .. image:: img/rawCounts.png
+
+ - Net count: Raw counts minus background
+
+ .. image:: img/netCounts.png
+
+ :param CurveItem curve:
+ :return tuple: rawCount, netCount
+ """
+ assert isinstance(curve, Curve) or curve is None
+
+ if curve is None:
+ return None, None
+
+ x = curve.getXData(copy=False)
+ y = curve.getYData(copy=False)
+
+ idx = numpy.nonzero((self._fromdata <= x) &
+ (x <= self._todata))[0]
+ if len(idx):
+ xw = x[idx]
+ yw = y[idx]
+ rawCounts = yw.sum(dtype=numpy.float)
+ deltaX = xw[-1] - xw[0]
+ deltaY = yw[-1] - yw[0]
+ if deltaX > 0.0:
+ slope = (deltaY / deltaX)
+ background = yw[0] + slope * (xw - xw[0])
+ netCounts = (rawCounts -
+ background.sum(dtype=numpy.float))
+ else:
+ netCounts = 0.0
else:
- self._valueChanged(row, col)
+ rawCounts = 0.0
+ netCounts = 0.0
+ return rawCounts, netCounts
+
+ def computeRawAndNetArea(self, curve):
+ """Compute the Raw and net counts in the ROI for the given curve.
+
+ - Raw area: integral of the curve between the min ROI point and the
+ max ROI point to the y = 0 line.
- def _valueChanged(self, row, col):
- if col not in [2, 3]:
+ .. image:: img/rawArea.png
+
+ - Net area: Raw counts minus background
+
+ .. image:: img/netArea.png
+
+ :param CurveItem curve:
+ :return tuple: rawArea, netArea
+ """
+ assert isinstance(curve, Curve) or curve is None
+
+ if curve is None:
+ return None, None
+
+ x = curve.getXData(copy=False)
+ y = curve.getYData(copy=False)
+
+ y = y[(x >= self._fromdata) & (x <= self._todata)]
+ x = x[(x >= self._fromdata) & (x <= self._todata)]
+
+ if x.size is 0:
+ return 0.0, 0.0
+
+ rawArea = numpy.trapz(y, x=x)
+ # to speed up and avoid an intersection calculation we are taking the
+ # closest index to the ROI
+ closestXLeftIndex = (numpy.abs(x - self.getFrom())).argmin()
+ closestXRightIndex = (numpy.abs(x - self.getTo())).argmin()
+ yBackground = y[closestXLeftIndex], y[closestXRightIndex]
+ background = numpy.trapz(yBackground, x=x)
+ netArea = rawArea - background
+ return rawArea, netArea
+
+
+class _RoiMarkerManager(object):
+ """
+ Deal with all the ROI markers
+ """
+ def __init__(self):
+ self._roiMarkerHandlers = {}
+ self._middleROIMarkerFlag = False
+ self._showAllMarkers = False
+ self._activeRoi = None
+
+ def setActiveRoi(self, roi):
+ self._activeRoi = roi
+ self.updateAllMarkers()
+
+ def setShowAllMarkers(self, show):
+ if show != self._showAllMarkers:
+ self._showAllMarkers = show
+ self.updateAllMarkers()
+
+ def add(self, roi, markersHandler):
+ assert isinstance(roi, ROI)
+ assert isinstance(markersHandler, _RoiMarkerHandler)
+ if roi.getID() in self._roiMarkerHandlers:
+ raise ValueError('roi with the same ID already existing')
+ else:
+ self._roiMarkerHandlers[roi.getID()] = markersHandler
+
+ def getMarkerHandler(self, roiID):
+ if roiID in self._roiMarkerHandlers:
+ return self._roiMarkerHandlers[roiID]
+ else:
+ return None
+
+ def clear(self):
+ roisHandler = list(self._roiMarkerHandlers.values())
+ for roiHandler in roisHandler:
+ self.remove(roiHandler.roi)
+
+ def remove(self, roi):
+ if roi is None:
return
- item = self.item(row, col)
- if item is None:
+ assert isinstance(roi, ROI)
+ if roi.getID() in self._roiMarkerHandlers:
+ self._roiMarkerHandlers[roi.getID()].clear()
+ del self._roiMarkerHandlers[roi.getID()]
+
+ def hasMarker(self, markerID):
+ assert type(markerID) is str
+ return self.getMarker(markerID) is not None
+
+ def changePosition(self, markerID, x):
+ markerHandler = self.getMarker(markerID)
+ if markerHandler is None:
+ raise ValueError('Marker %s not register' % markerID)
+ markerHandler.changePosition(markerID=markerID, x=x)
+
+ def updateMarker(self, markerID):
+ markerHandler = self.getMarker(markerID)
+ if markerHandler is None:
+ raise ValueError('Marker %s not register' % markerID)
+ roiID = self.getRoiID(markerID)
+ visible = (self._activeRoi and self._activeRoi.getID() == roiID) or self._showAllMarkers is True
+ markerHandler.setVisible(visible)
+ markerHandler.updateAllMarkers()
+
+ def updateRoiMarkers(self, roiID):
+ if roiID in self._roiMarkerHandlers:
+ visible = ((self._activeRoi and self._activeRoi.getID() == roiID)
+ or self._showAllMarkers is True)
+ _roi = self._roiMarkerHandlers[roiID]._roi()
+ if _roi and not _roi.isICR():
+ self._roiMarkerHandlers[roiID].showMiddleMarker(self._middleROIMarkerFlag)
+ self._roiMarkerHandlers[roiID].setVisible(visible)
+ self._roiMarkerHandlers[roiID].updateMarkers()
+
+ def getMarker(self, markerID):
+ assert type(markerID) is str
+ for marker in list(self._roiMarkerHandlers.values()):
+ if marker.hasMarker(markerID):
+ return marker
+
+ def updateMarkers(self):
+ for markerHandler in list(self._roiMarkerHandlers.values()):
+ markerHandler.updateMarkers()
+
+ def getRoiID(self, markerID):
+ for roiID, markerHandler in self._roiMarkerHandlers.items():
+ if markerHandler.hasMarker(markerID):
+ return roiID
+ return None
+
+ def setShowMiddleMarkers(self, show):
+ self._middleROIMarkerFlag = show
+ self._roiMarkerHandlers.updateAllMarkers()
+
+ def updateAllMarkers(self):
+ for roiID in self._roiMarkerHandlers:
+ self.updateRoiMarkers(roiID)
+
+ def getVisibleRois(self):
+ res = {}
+ for roiID, roiHandler in self._roiMarkerHandlers.items():
+ markers = (roiHandler.getMarker('min'), roiHandler.getMarker('max'),
+ roiHandler.getMarker('middle'))
+ for marker in markers:
+ if marker.isVisible():
+ if roiID not in res:
+ res[roiID] = []
+ res[roiID].append(marker)
+ return res
+
+
+class _RoiMarkerHandler(object):
+ """Used to deal with ROI markers used in ROITable"""
+ def __init__(self, roi, plot):
+ assert roi and isinstance(roi, ROI)
+ assert plot
+
+ self._roi = weakref.ref(roi)
+ self._plot = weakref.ref(plot)
+ self._draggable = False if roi.isICR() else True
+ self._color = 'black' if roi.isICR() else 'blue'
+ self._displayMidMarker = False
+ self._visible = True
+
+ @property
+ def draggable(self):
+ return self._draggable
+
+ @property
+ def plot(self):
+ return self._plot()
+
+ def clear(self):
+ if self.plot and self.roi:
+ self.plot.removeMarker(self._markerID('min'))
+ self.plot.removeMarker(self._markerID('max'))
+ self.plot.removeMarker(self._markerID('middle'))
+
+ @property
+ def roi(self):
+ return self._roi()
+
+ def setVisible(self, visible):
+ if visible != self._visible:
+ self._visible = visible
+ self.updateMarkers()
+
+ def showMiddleMarker(self, visible):
+ if self.draggable is False and visible is True:
+ _logger.warning("ROI is not draggable. Won't display middle marker")
return
- text = str(item.text())
- try:
- value = float(text)
- except:
+ self._displayMidMarker = visible
+ self.getMarker('middle').setVisible(self._displayMidMarker)
+
+ def updateMarkers(self):
+ if self.roi is None:
return
- if row >= len(self.roilist):
- _logger.debug("deleting???")
+ self._updateMinMarkerPos()
+ self._updateMaxMarkerPos()
+ self._updateMiddleMarkerPos()
+
+ def _updateMinMarkerPos(self):
+ self.getMarker('min').setPosition(x=self.roi.getFrom(), y=None)
+ self.getMarker('min').setVisible(self._visible)
+
+ def _updateMaxMarkerPos(self):
+ self.getMarker('max').setPosition(x=self.roi.getTo(), y=None)
+ self.getMarker('max').setVisible(self._visible)
+
+ def _updateMiddleMarkerPos(self):
+ self.getMarker('middle').setPosition(x=self.roi.getMiddle(), y=None)
+ self.getMarker('middle').setVisible(self._displayMidMarker and self._visible)
+
+ def getMarker(self, markerType):
+ if self.plot is None:
+ return None
+ assert markerType in ('min', 'max', 'middle')
+ if self.plot._getMarker(self._markerID(markerType)) is None:
+ assert self.roi
+ if markerType == 'min':
+ val = self.roi.getFrom()
+ elif markerType == 'max':
+ val = self.roi.getTo()
+ else:
+ val = self.roi.getMiddle()
+
+ _color = self._color
+ if markerType == 'middle':
+ _color = 'yellow'
+ self.plot.addXMarker(val,
+ legend=self._markerID(markerType),
+ text=self.getMarkerName(markerType),
+ color=_color,
+ draggable=self.draggable)
+ return self.plot._getMarker(self._markerID(markerType))
+
+ def _markerID(self, markerType):
+ assert markerType in ('min', 'max', 'middle')
+ assert self.roi
+ return '_'.join((str(self.roi.getID()), markerType))
+
+ def getMarkerName(self, markerType):
+ assert markerType in ('min', 'max', 'middle')
+ assert self.roi
+ return ' '.join((self.roi.getName(), markerType))
+
+ def updateTexts(self):
+ self.getMarker('min').setText(self.getMarkerName('min'))
+ self.getMarker('max').setText(self.getMarkerName('max'))
+ self.getMarker('middle').setText(self.getMarkerName('middle'))
+
+ def changePosition(self, markerID, x):
+ assert self.hasMarker(markerID)
+ markerType = self._getMarkerType(markerID)
+ assert markerType is not None
+ if self.roi is None:
return
- item = self.item(row, 0)
- if item is None:
- text = ""
+ if markerType == 'min':
+ self.roi.setFrom(x)
+ self._updateMiddleMarkerPos()
+ elif markerType == 'max':
+ self.roi.setTo(x)
+ self._updateMiddleMarkerPos()
else:
- text = str(item.text())
- if not len(text):
- return
- if col == 2:
- self.roidict[text]['from'] = value
- elif col == 3:
- self.roidict[text]['to'] = value
- self._emitSelectionChangedSignal(row, col)
-
- def nameSlot(self, row, col):
- if col != 0:
- return
- if row >= len(self.roilist):
- _logger.debug("deleting???")
- return
- item = self.item(row, col)
- if item is None:
- text = ""
+ delta = x - 0.5 * (self.roi.getFrom() + self.roi.getTo())
+ self.roi.setFrom(self.roi.getFrom() + delta)
+ self.roi.setTo(self.roi.getTo() + delta)
+ self._updateMinMarkerPos()
+ self._updateMaxMarkerPos()
+
+ def hasMarker(self, marker):
+ return marker in (self._markerID('min'),
+ self._markerID('max'),
+ self._markerID('middle'))
+
+ def _getMarkerType(self, markerID):
+ if markerID.endswith('_min'):
+ return 'min'
+ elif markerID.endswith('_max'):
+ return 'max'
+ elif markerID.endswith('_middle'):
+ return 'middle'
else:
- text = str(item.text())
- if len(text) and (text not in self.roilist):
- old = self.roilist[row]
- self.roilist[row] = text
- self.roidict[text] = {}
- self.roidict[text].update(self.roidict[old])
- del self.roidict[old]
- self._emitSelectionChangedSignal(row, col)
-
- def _emitSelectionChangedSignal(self, row, col):
- ddict = {}
- ddict['event'] = "selectionChanged"
- ddict['row'] = row
- ddict['col'] = col
- ddict['roi'] = self.roidict[self.roilist[row]]
- ddict['key'] = self.roilist[row]
- ddict['colheader'] = self.labels[col]
- ddict['rowheader'] = "%d" % row
- self.sigROITableSignal.emit(ddict)
+ return None
class CurvesROIDockWidget(qt.QDockWidget):
@@ -1007,6 +1520,8 @@ class CurvesROIDockWidget(qt.QDockWidget):
def __init__(self, parent=None, plot=None, name=None):
super(CurvesROIDockWidget, self).__init__(name, parent)
+ assert plot is not None
+ self.plot = plot
self.roiWidget = CurvesROIWidget(self, name, plot=plot)
"""Main widget of type :class:`CurvesROIWidget`"""
@@ -1016,12 +1531,15 @@ class CurvesROIDockWidget(qt.QDockWidget):
self.calculateROIs = self.calculateRois = self.roiWidget.calculateRois
self.setRois = self.roiWidget.setRois
self.getRois = self.roiWidget.getRois
+
self.roiWidget.sigROISignal.connect(self._forwardSigROISignal)
- self.currentROI = self.roiWidget.currentROI
self.layout().setContentsMargins(0, 0, 0, 0)
self.setWidget(self.roiWidget)
+ self.setAreaVisible = self.roiWidget.roiTable.setAreaVisible
+ self.setCountsVisible = self.roiWidget.roiTable.setCountsVisible
+
def _forwardSigROISignal(self, ddict):
# emit deprecated signal for backward compatibility (silx < 0.7)
self.sigROISignal.emit(ddict)
@@ -1042,3 +1560,7 @@ class CurvesROIDockWidget(qt.QDockWidget):
"""
self.raise_()
qt.QDockWidget.showEvent(self, event)
+
+ @property
+ def currentROI(self):
+ return self.roiWidget.currentRoi
diff --git a/silx/gui/plot/MaskToolsWidget.py b/silx/gui/plot/MaskToolsWidget.py
index 990e479..9d727e7 100644
--- a/silx/gui/plot/MaskToolsWidget.py
+++ b/silx/gui/plot/MaskToolsWidget.py
@@ -35,7 +35,7 @@ from __future__ import division
__authors__ = ["T. Vincent", "P. Knobel"]
__license__ = "MIT"
-__date__ = "29/08/2018"
+__date__ = "15/02/2019"
import os
@@ -57,10 +57,7 @@ from .. import qt
from silx.third_party.EdfFile import EdfFile
from silx.third_party.TiffIO import TiffIO
-try:
- import fabio
-except ImportError:
- fabio = None
+import fabio
_logger = logging.getLogger(__name__)
@@ -135,8 +132,6 @@ class ImageMask(BaseMask):
self._saveToHdf5(filename, self.getMask(copy=False))
elif kind == 'msk':
- if fabio is None:
- raise ImportError("Fit2d mask files can't be written: Fabio module is not available")
try:
data = self.getMask(copy=False)
image = fabio.fabioimage.FabioImage(data=data)
@@ -250,6 +245,19 @@ class ImageMask(BaseMask):
rows, cols = shapes.circle_fill(crow, ccol, radius)
self.updatePoints(level, rows, cols, mask)
+ def updateEllipse(self, level, crow, ccol, radius_r, radius_c, mask=True):
+ """Mask/Unmask an ellipse of the given mask level.
+
+ :param int level: Mask level to update.
+ :param int crow: Row of the center of the ellipse
+ :param int ccol: Column of the center of the ellipse
+ :param float radius_r: Radius of the ellipse in the row
+ :param float radius_c: Radius of the ellipse in the column
+ :param bool mask: True to mask (default), False to unmask.
+ """
+ rows, cols = shapes.ellipse_fill(crow, ccol, radius_r, radius_c)
+ self.updatePoints(level, rows, cols, mask)
+
def updateLine(self, level, row0, col0, row1, col1, width, mask=True):
"""Mask/Unmask a line of the given mask level.
@@ -300,6 +308,10 @@ class MaskToolsWidget(BaseMaskToolsWidget):
_logger.error('Not an image, shape: %d', len(mask.shape))
return None
+ # Handle mask with single level
+ if self.multipleMasks() == 'single':
+ mask = numpy.array(mask != 0, dtype=numpy.uint8)
+
# if mask has not changed, do nothing
if numpy.array_equal(mask, self.getSelectionMask()):
return mask.shape
@@ -501,8 +513,6 @@ class MaskToolsWidget(BaseMaskToolsWidget):
_logger.debug("Backtrace", exc_info=True)
raise e
elif extension == "msk":
- if fabio is None:
- raise ImportError("Fit2d mask files can't be read: Fabio module is not available")
try:
mask = fabio.open(filename).data
except Exception as e:
@@ -682,41 +692,51 @@ class MaskToolsWidget(BaseMaskToolsWidget):
level = self.levelSpinBox.value()
- if (self._drawingMode == 'rectangle' and
- event['event'] == 'drawingFinished'):
- # Convert from plot to array coords
- doMask = self._isMasking()
- ox, oy = self._origin
- sx, sy = self._scale
-
- height = int(abs(event['height'] / sy))
- width = int(abs(event['width'] / sx))
-
- row = int((event['y'] - oy) / sy)
- if sy < 0:
- row -= height
-
- col = int((event['x'] - ox) / sx)
- if sx < 0:
- col -= width
-
- self._mask.updateRectangle(
- level,
- row=row,
- col=col,
- height=height,
- width=width,
- mask=doMask)
- self._mask.commit()
+ if self._drawingMode == 'rectangle':
+ if event['event'] == 'drawingFinished':
+ # Convert from plot to array coords
+ doMask = self._isMasking()
+ ox, oy = self._origin
+ sx, sy = self._scale
+
+ height = int(abs(event['height'] / sy))
+ width = int(abs(event['width'] / sx))
+
+ row = int((event['y'] - oy) / sy)
+ if sy < 0:
+ row -= height
+
+ col = int((event['x'] - ox) / sx)
+ if sx < 0:
+ col -= width
+
+ self._mask.updateRectangle(
+ level,
+ row=row,
+ col=col,
+ height=height,
+ width=width,
+ mask=doMask)
+ self._mask.commit()
- elif (self._drawingMode == 'polygon' and
- event['event'] == 'drawingFinished'):
- doMask = self._isMasking()
- # Convert from plot to array coords
- vertices = (event['points'] - self._origin) / self._scale
- vertices = vertices.astype(numpy.int)[:, (1, 0)] # (row, col)
- self._mask.updatePolygon(level, vertices, doMask)
- self._mask.commit()
+ elif self._drawingMode == 'ellipse':
+ if event['event'] == 'drawingFinished':
+ doMask = self._isMasking()
+ # Convert from plot to array coords
+ center = (event['points'][0] - self._origin) / self._scale
+ size = event['points'][1] / self._scale
+ center = center.astype(numpy.int) # (row, col)
+ self._mask.updateEllipse(level, center[1], center[0], size[1], size[0], doMask)
+ self._mask.commit()
+
+ elif self._drawingMode == 'polygon':
+ if event['event'] == 'drawingFinished':
+ doMask = self._isMasking()
+ # Convert from plot to array coords
+ vertices = (event['points'] - self._origin) / self._scale
+ vertices = vertices.astype(numpy.int)[:, (1, 0)] # (row, col)
+ self._mask.updatePolygon(level, vertices, doMask)
+ self._mask.commit()
elif self._drawingMode == 'pencil':
doMask = self._isMasking()
@@ -743,6 +763,8 @@ class MaskToolsWidget(BaseMaskToolsWidget):
self._lastPencilPos = None
else:
self._lastPencilPos = row, col
+ else:
+ _logger.error("Drawing mode %s unsupported", self._drawingMode)
def _loadRangeFromColormapTriggered(self):
"""Set range from active image colormap range"""
diff --git a/silx/gui/plot/PlotInteraction.py b/silx/gui/plot/PlotInteraction.py
index 356bda6..27abd10 100644
--- a/silx/gui/plot/PlotInteraction.py
+++ b/silx/gui/plot/PlotInteraction.py
@@ -26,7 +26,7 @@
__authors__ = ["T. Vincent"]
__license__ = "MIT"
-__date__ = "24/04/2018"
+__date__ = "15/02/2019"
import math
@@ -96,10 +96,18 @@ class _PlotInteraction(object):
fill = fill != 'none' # TODO not very nice either
+ greyed = colors.greyed(color)[0]
+ if greyed < 0.5:
+ color2 = "white"
+ else:
+ color2 = "black"
+
self.plot.addItem(points[:, 0], points[:, 1], legend=legend,
replace=False,
- shape=shape, color=color, fill=fill,
+ shape=shape, fill=fill,
+ color=color, linebgcolor=color2, linestyle="--",
overlay=True)
+
self._selectionAreas.add(legend)
def resetSelectionArea(self):
@@ -274,6 +282,8 @@ class Zoom(_ZoomOnWheel):
and zoom on mouse wheel.
"""
+ SURFACE_THRESHOLD = 5
+
def __init__(self, plot, color):
self.color = color
@@ -347,35 +357,44 @@ class Zoom(_ZoomOnWheel):
self.setSelectionArea(corners, fill='none', color=self.color)
- def endDrag(self, startPos, endPos):
- x0, y0 = startPos
- x1, y1 = endPos
+ def _zoom(self, x0, y0, x1, y1):
+ """Zoom to the rectangle view x0,y0 x1,y1.
+ """
+ startPos = x0, y0
+ endPos = x1, y1
+
+ # Store current zoom state in stack
+ self.plot.getLimitsHistory().push()
- if x0 != x1 or y0 != y1: # Avoid empty zoom area
- # Store current zoom state in stack
- self.plot.getLimitsHistory().push()
+ if self.plot.isKeepDataAspectRatio():
+ x0, y0, x1, y1 = self._areaWithAspectRatio(x0, y0, x1, y1)
+
+ # Convert to data space and set limits
+ x0, y0 = self.plot.pixelToData(x0, y0, check=False)
- if self.plot.isKeepDataAspectRatio():
- x0, y0, x1, y1 = self._areaWithAspectRatio(x0, y0, x1, y1)
+ dataPos = self.plot.pixelToData(
+ startPos[0], startPos[1], axis="right", check=False)
+ y2_0 = dataPos[1]
- # Convert to data space and set limits
- x0, y0 = self.plot.pixelToData(x0, y0, check=False)
+ x1, y1 = self.plot.pixelToData(x1, y1, check=False)
- dataPos = self.plot.pixelToData(
- startPos[0], startPos[1], axis="right", check=False)
- y2_0 = dataPos[1]
+ dataPos = self.plot.pixelToData(
+ endPos[0], endPos[1], axis="right", check=False)
+ y2_1 = dataPos[1]
- x1, y1 = self.plot.pixelToData(x1, y1, check=False)
+ xMin, xMax = min(x0, x1), max(x0, x1)
+ yMin, yMax = min(y0, y1), max(y0, y1)
+ y2Min, y2Max = min(y2_0, y2_1), max(y2_0, y2_1)
- dataPos = self.plot.pixelToData(
- endPos[0], endPos[1], axis="right", check=False)
- y2_1 = dataPos[1]
+ self.plot.setLimits(xMin, xMax, yMin, yMax, y2Min, y2Max)
- xMin, xMax = min(x0, x1), max(x0, x1)
- yMin, yMax = min(y0, y1), max(y0, y1)
- y2Min, y2Max = min(y2_0, y2_1), max(y2_0, y2_1)
+ def endDrag(self, startPos, endPos):
+ x0, y0 = startPos
+ x1, y1 = endPos
- self.plot.setLimits(xMin, xMax, yMin, yMax, y2Min, y2Max)
+ if abs(x0 - x1) * abs(y0 - y1) >= self.SURFACE_THRESHOLD:
+ # Avoid empty zoom area
+ self._zoom(x0, y0, x1, y1)
self.resetSelectionArea()
@@ -544,7 +563,6 @@ class SelectPolygon(Select):
return self.DRAG_THRESHOLD_DIST * ratio
-
class Select2Points(Select):
"""Base class for drawing selection based on 2 input points."""
class Idle(State):
@@ -603,6 +621,87 @@ class Select2Points(Select):
self.cancelSelect()
+class SelectEllipse(Select2Points):
+ """Drawing ellipse selection area state machine."""
+ def beginSelect(self, x, y):
+ self.center = self.plot.pixelToData(x, y)
+ assert self.center is not None
+
+ def _getEllipseSize(self, pointInEllipse):
+ """
+ Returns the size from the center to the bounding box of the ellipse.
+
+ :param Tuple[float,float] pointInEllipse: A point of the ellipse
+ :rtype: Tuple[float,float]
+ """
+ x = abs(self.center[0] - pointInEllipse[0])
+ y = abs(self.center[1] - pointInEllipse[1])
+ if x == 0 or y == 0:
+ return x, y
+ # Ellipse definitions
+ # e: eccentricity
+ # a: length fron center to bounding box width
+ # b: length fron center to bounding box height
+ # Equations
+ # (1) b < a
+ # (2) For x,y a point in the ellipse: x^2/a^2 + y^2/b^2 = 1
+ # (3) b = a * sqrt(1-e^2)
+ # (4) e = sqrt(a^2 - b^2) / a
+
+ # The eccentricity of the ellipse defined by a,b=x,y is the same
+ # as the one we are searching for.
+ swap = x < y
+ if swap:
+ x, y = y, x
+ e = math.sqrt(x**2 - y**2) / x
+ # From (2) using (3) to replace b
+ # a^2 = x^2 + y^2 / (1-e^2)
+ a = math.sqrt(x**2 + y**2 / (1.0 - e**2))
+ b = a * math.sqrt(1 - e**2)
+ if swap:
+ a, b = b, a
+ return a, b
+
+ def select(self, x, y):
+ dataPos = self.plot.pixelToData(x, y)
+ assert dataPos is not None
+ width, height = self._getEllipseSize(dataPos)
+
+ # Circle used for circle preview
+ nbpoints = 27.
+ angles = numpy.arange(nbpoints) * numpy.pi * 2.0 / nbpoints
+ circleShape = numpy.array((numpy.cos(angles) * width,
+ numpy.sin(angles) * height)).T
+ circleShape += numpy.array(self.center)
+
+ self.setSelectionArea(circleShape,
+ shape="polygon",
+ fill='hatch',
+ color=self.color)
+
+ eventDict = prepareDrawingSignal('drawingProgress',
+ 'ellipse',
+ (self.center, (width, height)),
+ self.parameters)
+ self.plot.notify(**eventDict)
+
+ def endSelect(self, x, y):
+ self.resetSelectionArea()
+
+ dataPos = self.plot.pixelToData(x, y)
+ assert dataPos is not None
+ width, height = self._getEllipseSize(dataPos)
+
+ eventDict = prepareDrawingSignal('drawingFinished',
+ 'ellipse',
+ (self.center, (width, height)),
+ self.parameters)
+ self.plot.notify(**eventDict)
+
+ def cancelSelect(self):
+ self.resetSelectionArea()
+
+
class SelectRectangle(Select2Points):
"""Drawing rectangle selection area state machine."""
def beginSelect(self, x, y):
@@ -1488,6 +1587,7 @@ class PlotInteraction(object):
_DRAW_MODES = {
'polygon': SelectPolygon,
'rectangle': SelectRectangle,
+ 'ellipse': SelectEllipse,
'line': SelectLine,
'vline': SelectVLine,
'hline': SelectHLine,
diff --git a/silx/gui/plot/PlotToolButtons.py b/silx/gui/plot/PlotToolButtons.py
index f6291b5..bf6b8ce 100644
--- a/silx/gui/plot/PlotToolButtons.py
+++ b/silx/gui/plot/PlotToolButtons.py
@@ -1,7 +1,7 @@
# coding: utf-8
# /*##########################################################################
#
-# Copyright (c) 2004-2018 European Synchrotron Radiation Facility
+# Copyright (c) 2004-2019 European Synchrotron Radiation Facility
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
@@ -45,6 +45,7 @@ import weakref
from .. import icons
from .. import qt
+from ... import config
from .items import SymbolMixIn
@@ -250,24 +251,24 @@ class ProfileOptionToolButton(PlotToolButton):
self.STATE = {}
# is down
self.STATE['sum', "icon"] = icons.getQIcon('math-sigma')
- self.STATE['sum', "state"] = "compute profile sum"
- self.STATE['sum', "action"] = "compute profile sum"
+ self.STATE['sum', "state"] = "Compute profile sum"
+ self.STATE['sum', "action"] = "Compute profile sum"
# keep ration
self.STATE['mean', "icon"] = icons.getQIcon('math-mean')
- self.STATE['mean', "state"] = "compute profile mean"
- self.STATE['mean', "action"] = "compute profile mean"
+ self.STATE['mean', "state"] = "Compute profile mean"
+ self.STATE['mean', "action"] = "Compute profile mean"
- sumAction = self._createAction('sum')
- sumAction.triggered.connect(self.setSum)
- sumAction.setIconVisibleInMenu(True)
+ self.sumAction = self._createAction('sum')
+ self.sumAction.triggered.connect(self.setSum)
+ self.sumAction.setIconVisibleInMenu(True)
- meanAction = self._createAction('mean')
- meanAction.triggered.connect(self.setMean)
- meanAction.setIconVisibleInMenu(True)
+ self.meanAction = self._createAction('mean')
+ self.meanAction.triggered.connect(self.setMean)
+ self.meanAction.setIconVisibleInMenu(True)
menu = qt.QMenu(self)
- menu.addAction(sumAction)
- menu.addAction(meanAction)
+ menu.addAction(self.sumAction)
+ menu.addAction(self.meanAction)
self.setMenu(menu)
self.setPopupMode(qt.QToolButton.InstantPopup)
self.setMean()
@@ -370,7 +371,7 @@ class SymbolToolButton(PlotToolButton):
slider = qt.QSlider(qt.Qt.Horizontal)
slider.setRange(1, 20)
- slider.setValue(SymbolMixIn._DEFAULT_SYMBOL_SIZE)
+ slider.setValue(config.DEFAULT_PLOT_SYMBOL_SIZE)
slider.setTracking(False)
slider.valueChanged.connect(self._sizeChanged)
widgetAction = qt.QWidgetAction(menu)
diff --git a/silx/gui/plot/PlotWidget.py b/silx/gui/plot/PlotWidget.py
index e023a21..cfe39fa 100644
--- a/silx/gui/plot/PlotWidget.py
+++ b/silx/gui/plot/PlotWidget.py
@@ -1,7 +1,7 @@
# coding: utf-8
# /*##########################################################################
#
-# Copyright (c) 2004-2018 European Synchrotron Radiation Facility
+# Copyright (c) 2004-2019 European Synchrotron Radiation Facility
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
@@ -31,7 +31,7 @@ from __future__ import division
__authors__ = ["V.A. Sole", "T. Vincent"]
__license__ = "MIT"
-__date__ = "12/10/2018"
+__date__ = "21/12/2018"
from collections import OrderedDict, namedtuple
@@ -44,7 +44,6 @@ import numpy
import silx
from silx.utils.weakref import WeakMethodProxy
-from silx.utils import deprecation
from silx.utils.property import classproperty
from silx.utils.deprecation import deprecated
# Import matplotlib backend here to init matplotlib our way
@@ -99,7 +98,7 @@ class PlotWidget(qt.QMainWindow):
# TODO: Can be removed for silx 0.10
@classproperty
- @deprecation.deprecated(replacement="silx.config.DEFAULT_PLOT_BACKEND", since_version="0.8", skip_backtrace_count=2)
+ @deprecated(replacement="silx.config.DEFAULT_PLOT_BACKEND", since_version="0.8", skip_backtrace_count=2)
def DEFAULT_BACKEND(self):
"""Class attribute setting the default backend for all instances."""
return silx.config.DEFAULT_PLOT_BACKEND
@@ -193,21 +192,12 @@ class PlotWidget(qt.QMainWindow):
It provides the visible state.
"""
- def __init__(self, parent=None, backend=None,
- legends=False, callback=None, **kw):
+ def __init__(self, parent=None, backend=None):
self._autoreplot = False
self._dirty = False
self._cursorInPlot = False
self.__muteActiveItemChanged = False
- if kw:
- _logger.warning(
- 'deprecated: __init__ extra arguments: %s', str(kw))
- if legends:
- _logger.warning('deprecated: __init__ legend argument')
- if callback:
- _logger.warning('deprecated: __init__ callback argument')
-
self._panWithArrowKeys = True
self._viewConstrains = None
@@ -218,27 +208,8 @@ class PlotWidget(qt.QMainWindow):
else:
self.setWindowTitle('PlotWidget')
- if backend is None:
- backend = silx.config.DEFAULT_PLOT_BACKEND
-
- if hasattr(backend, "__call__"):
- self._backend = backend(self, parent)
-
- elif hasattr(backend, "lower"):
- lowerCaseString = backend.lower()
- if lowerCaseString in ("matplotlib", "mpl"):
- backendClass = BackendMatplotlibQt
- elif lowerCaseString in ('gl', 'opengl'):
- from .backends.BackendOpenGL import BackendOpenGL
- backendClass = BackendOpenGL
- elif lowerCaseString == 'none':
- from .backends.BackendBase import BackendBase as backendClass
- else:
- raise ValueError("Backend not supported %s" % backend)
- self._backend = backendClass(self, parent)
-
- else:
- raise ValueError("Backend not supported %s" % str(backend))
+ self._backend = None
+ self._setBackend(backend)
self.setCallback() # set _callback
@@ -258,6 +229,12 @@ class PlotWidget(qt.QMainWindow):
self._activeLegend = {'curve': None, 'image': None,
'scatter': None}
+ # plot colors (updated later to sync backend)
+ self._foregroundColor = 0., 0., 0., 1.
+ self._gridColor = .7, .7, .7, 1.
+ self._backgroundColor = 1., 1., 1., 1.
+ self._dataBackgroundColor = None
+
# default properties
self._cursorConfiguration = None
@@ -275,7 +252,7 @@ class PlotWidget(qt.QMainWindow):
self.setDefaultColormap() # Init default colormap
- self.setDefaultPlotPoints(False)
+ self.setDefaultPlotPoints(silx.config.DEFAULT_PLOT_CURVE_SYMBOL_MODE)
self.setDefaultPlotLines(True)
self._limitsHistory = LimitsHistory(self)
@@ -306,9 +283,41 @@ class PlotWidget(qt.QMainWindow):
self.setGraphYLimits(0., 100., axis='right')
self.setGraphYLimits(0., 100., axis='left')
+ # Sync backend colors with default ones
+ self._foregroundColorsUpdated()
+ self._backgroundColorsUpdated()
+
+ def _setBackend(self, backend):
+ """Setup a new backend"""
+ assert(self._backend is None)
+
+ if backend is None:
+ backend = silx.config.DEFAULT_PLOT_BACKEND
+
+ if hasattr(backend, "__call__"):
+ backend = backend(self, self)
+
+ elif hasattr(backend, "lower"):
+ lowerCaseString = backend.lower()
+ if lowerCaseString in ("matplotlib", "mpl"):
+ backendClass = BackendMatplotlibQt
+ elif lowerCaseString in ('gl', 'opengl'):
+ from .backends.BackendOpenGL import BackendOpenGL
+ backendClass = BackendOpenGL
+ elif lowerCaseString == 'none':
+ from .backends.BackendBase import BackendBase as backendClass
+ else:
+ raise ValueError("Backend not supported %s" % backend)
+ backend = backendClass(self, self)
+
+ else:
+ raise ValueError("Backend not supported %s" % str(backend))
+
+ self._backend = backend
+
# TODO: Can be removed for silx 0.10
@staticmethod
- @deprecation.deprecated(replacement="silx.config.DEFAULT_PLOT_BACKEND", since_version="0.8", skip_backtrace_count=2)
+ @deprecated(replacement="silx.config.DEFAULT_PLOT_BACKEND", since_version="0.8", skip_backtrace_count=2)
def setDefaultBackend(backend):
"""Set system wide default plot backend.
@@ -349,6 +358,119 @@ class PlotWidget(qt.QMainWindow):
if self._autoreplot and not wasDirty and self.isVisible():
self._backend.postRedisplay()
+ def _foregroundColorsUpdated(self):
+ """Handle change of foreground/grid color"""
+ if self._gridColor is None:
+ gridColor = self._foregroundColor
+ else:
+ gridColor = self._gridColor
+ self._backend.setForegroundColors(
+ self._foregroundColor, gridColor)
+ self._setDirtyPlot()
+
+ def getForegroundColor(self):
+ """Returns the RGBA colors used to display the foreground of this widget
+
+ :rtype: qt.QColor
+ """
+ return qt.QColor.fromRgbF(*self._foregroundColor)
+
+ def setForegroundColor(self, color):
+ """Set the foreground color of this widget.
+
+ :param Union[List[int],List[float],QColor] color:
+ The new RGB(A) color.
+ """
+ color = colors.rgba(color)
+ if self._foregroundColor != color:
+ self._foregroundColor = color
+ self._foregroundColorsUpdated()
+
+ def getGridColor(self):
+ """Returns the RGBA colors used to display the grid lines
+
+ An invalid QColor is returned if there is no grid color,
+ in which case the foreground color is used.
+
+ :rtype: qt.QColor
+ """
+ if self._gridColor is None:
+ return qt.QColor() # An invalid color
+ else:
+ return qt.QColor.fromRgbF(*self._gridColor)
+
+ def setGridColor(self, color):
+ """Set the grid lines color
+
+ :param Union[List[int],List[float],QColor,None] color:
+ The new RGB(A) color.
+ """
+ if isinstance(color, qt.QColor) and not color.isValid():
+ color = None
+ if color is not None:
+ color = colors.rgba(color)
+ if self._gridColor != color:
+ self._gridColor = color
+ self._foregroundColorsUpdated()
+
+ def _backgroundColorsUpdated(self):
+ """Handle change of background/data background color"""
+ if self._dataBackgroundColor is None:
+ dataBGColor = self._backgroundColor
+ else:
+ dataBGColor = self._dataBackgroundColor
+ self._backend.setBackgroundColors(
+ self._backgroundColor, dataBGColor)
+ self._setDirtyPlot()
+
+ def getBackgroundColor(self):
+ """Returns the RGBA colors used to display the background of this widget.
+
+ :rtype: qt.QColor
+ """
+ return qt.QColor.fromRgbF(*self._backgroundColor)
+
+ def setBackgroundColor(self, color):
+ """Set the background color of this widget.
+
+ :param Union[List[int],List[float],QColor] color:
+ The new RGB(A) color.
+ """
+ color = colors.rgba(color)
+ if self._backgroundColor != color:
+ self._backgroundColor = color
+ self._backgroundColorsUpdated()
+
+ def getDataBackgroundColor(self):
+ """Returns the RGBA colors used to display the background of the plot
+ view displaying the data.
+
+ An invalid QColor is returned if there is no data background color.
+
+ :rtype: qt.QColor
+ """
+ if self._dataBackgroundColor is None:
+ # An invalid color
+ return qt.QColor()
+ else:
+ return qt.QColor.fromRgbF(*self._dataBackgroundColor)
+
+ def setDataBackgroundColor(self, color):
+ """Set the background color of this widget.
+
+ Set to None or an invalid QColor to use the background color.
+
+ :param Union[List[int],List[float],QColor,None] color:
+ The new RGB(A) color.
+ """
+ if isinstance(color, qt.QColor) and not color.isValid():
+ color = None
+ if color is not None:
+ color = colors.rgba(color)
+ if self._dataBackgroundColor != color:
+ self._dataBackgroundColor = color
+ self._backgroundColorsUpdated()
+
def showEvent(self, event):
if self._autoreplot and self._dirty:
self._backend.postRedisplay()
@@ -528,13 +650,13 @@ class PlotWidget(qt.QMainWindow):
# This value is used when curve is updated either internally or by user.
def addCurve(self, x, y, legend=None, info=None,
- replace=False, replot=None,
+ replace=False,
color=None, symbol=None,
linewidth=None, linestyle=None,
xlabel=None, ylabel=None, yaxis=None,
xerror=None, yerror=None, z=None, selectable=None,
fill=None, resetzoom=True,
- histogram=None, copy=True, **kw):
+ histogram=None, copy=True):
"""Add a 1D curve given by x an y to the graph.
Curves are uniquely identified by their legend.
@@ -617,15 +739,6 @@ class PlotWidget(qt.QMainWindow):
False to use provided arrays.
:returns: The key string identify this curve
"""
- # Deprecation warnings
- if replot is not None:
- _logger.warning(
- 'addCurve deprecated replot argument, use resetzoom instead')
- resetzoom = replot and resetzoom
-
- if kw:
- _logger.warning('addCurve: deprecated extra arguments')
-
# This is an histogram, use addHistogram
if histogram is not None:
histoLegend = self.addHistogram(histogram=y,
@@ -825,13 +938,13 @@ class PlotWidget(qt.QMainWindow):
return legend
def addImage(self, data, legend=None, info=None,
- replace=False, replot=None,
- xScale=None, yScale=None, z=None,
+ replace=False,
+ z=None,
selectable=None, draggable=None,
colormap=None, pixmap=None,
xlabel=None, ylabel=None,
origin=None, scale=None,
- resetzoom=True, copy=True, **kw):
+ resetzoom=True, copy=True):
"""Add a 2D dataset or an image to the plot.
It displays either an array of data using a colormap or a RGB(A) image.
@@ -883,28 +996,6 @@ class PlotWidget(qt.QMainWindow):
False to use provided arrays.
:returns: The key string identify this image
"""
- # Deprecation warnings
- if xScale is not None or yScale is not None:
- _logger.warning(
- 'addImage deprecated xScale and yScale arguments,'
- 'use origin, scale arguments instead.')
- if origin is None and scale is None:
- origin = xScale[0], yScale[0]
- scale = xScale[1], yScale[1]
- else:
- _logger.warning(
- 'addCurve: xScale, yScale and origin, scale arguments'
- ' are conflicting. xScale and yScale are ignored.'
- ' Use only origin, scale arguments.')
-
- if replot is not None:
- _logger.warning(
- 'addImage deprecated replot argument, use resetzoom instead')
- resetzoom = replot and resetzoom
-
- if kw:
- _logger.warning('addImage: deprecated extra arguments')
-
legend = "Unnamed Image 1.1" if legend is None else str(legend)
# Check if image was previously active
@@ -1090,7 +1181,8 @@ class PlotWidget(qt.QMainWindow):
def addItem(self, xdata, ydata, legend=None, info=None,
replace=False,
shape="polygon", color='black', fill=True,
- overlay=False, z=None, **kw):
+ overlay=False, z=None, linestyle="-", linewidth=1.0,
+ linebgcolor=None):
"""Add an item (i.e. a shape) to the plot.
Items are uniquely identified by their legend.
@@ -1114,13 +1206,23 @@ class PlotWidget(qt.QMainWindow):
This allows for rendering optimization if this
item is changed often.
:param int z: Layer on which to draw the item (default: 2)
+ :param str linestyle: Style of the line.
+ Only relevant for line markers where X or Y is None.
+ Value in:
+
+ - ' ' no line
+ - '-' solid line
+ - '--' dashed line
+ - '-.' dash-dot line
+ - ':' dotted line
+ :param float linewidth: Width of the line.
+ Only relevant for line markers where X or Y is None.
+ :param str linebgcolor: Background color of the line, e.g., 'blue', 'b',
+ '#FF0000'. It is used to draw dotted line using a second color.
:returns: The key string identify this item
"""
# expected to receive the same parameters as the signal
- if kw:
- _logger.warning('addItem deprecated parameters: %s', str(kw))
-
legend = "Unnamed Item 1.1" if legend is None else str(legend)
z = int(z) if z is not None else 2
@@ -1138,6 +1240,9 @@ class PlotWidget(qt.QMainWindow):
item.setOverlay(overlay)
item.setZValue(z)
item.setPoints(numpy.array((xdata, ydata)).T)
+ item.setLineStyle(linestyle)
+ item.setLineWidth(linewidth)
+ item.setLineBgColor(linebgcolor)
self._add(item)
@@ -1148,8 +1253,7 @@ class PlotWidget(qt.QMainWindow):
color=None,
selectable=False,
draggable=False,
- constraint=None,
- **kw):
+ constraint=None):
"""Add a vertical line marker to the plot.
Markers are uniquely identified by their legend.
@@ -1177,10 +1281,6 @@ class PlotWidget(qt.QMainWindow):
and that returns the filtered coordinates.
:return: The key string identify this marker
"""
- if kw:
- _logger.warning(
- 'addXMarker deprecated extra parameters: %s', str(kw))
-
return self._addMarker(x=x, y=None, legend=legend,
text=text, color=color,
selectable=selectable, draggable=draggable,
@@ -1192,8 +1292,7 @@ class PlotWidget(qt.QMainWindow):
color=None,
selectable=False,
draggable=False,
- constraint=None,
- **kw):
+ constraint=None):
"""Add a horizontal line marker to the plot.
Markers are uniquely identified by their legend.
@@ -1221,10 +1320,6 @@ class PlotWidget(qt.QMainWindow):
and that returns the filtered coordinates.
:return: The key string identify this marker
"""
- if kw:
- _logger.warning(
- 'addYMarker deprecated extra parameters: %s', str(kw))
-
return self._addMarker(x=None, y=y, legend=legend,
text=text, color=color,
selectable=selectable, draggable=draggable,
@@ -1236,8 +1331,7 @@ class PlotWidget(qt.QMainWindow):
selectable=False,
draggable=False,
symbol='+',
- constraint=None,
- **kw):
+ constraint=None):
"""Add a point marker to the plot.
Markers are uniquely identified by their legend.
@@ -1277,10 +1371,6 @@ class PlotWidget(qt.QMainWindow):
and that returns the filtered coordinates.
:return: The key string identify this marker
"""
- if kw:
- _logger.warning(
- 'addMarker deprecated extra parameters: %s', str(kw))
-
if x is None:
xmin, xmax = self._xAxis.getLimits()
x = 0.5 * (xmax + xmin)
@@ -1368,7 +1458,7 @@ class PlotWidget(qt.QMainWindow):
curve = self._getItem('curve', legend)
return curve is not None and not curve.isVisible()
- def hideCurve(self, legend, flag=True, replot=None):
+ def hideCurve(self, legend, flag=True):
"""Show/Hide the curve associated to legend.
Even when hidden, the curve is kept in the list of curves.
@@ -1376,9 +1466,6 @@ class PlotWidget(qt.QMainWindow):
:param str legend: The legend associated to the curve to be hidden
:param bool flag: True (default) to hide the curve, False to show it
"""
- if replot is not None:
- _logger.warning('hideCurve deprecated replot parameter')
-
curve = self._getItem('curve', legend)
if curve is None:
_logger.warning('Curve not in plot: %s', legend)
@@ -1660,16 +1747,13 @@ class PlotWidget(qt.QMainWindow):
return self._getActiveItem(kind='curve', just_legend=just_legend)
- def setActiveCurve(self, legend, replot=None):
+ def setActiveCurve(self, legend):
"""Make the curve associated to legend the active curve.
:param legend: The legend associated to the curve
or None to have no active curve.
:type legend: str or None
"""
- if replot is not None:
- _logger.warning('setActiveCurve deprecated replot parameter')
-
if not self.isActiveCurveHandling():
return
if legend is None and self.getActiveCurveSelectionMode() == "legacy":
@@ -1723,15 +1807,12 @@ class PlotWidget(qt.QMainWindow):
"""
return self._getActiveItem(kind='image', just_legend=just_legend)
- def setActiveImage(self, legend, replot=None):
+ def setActiveImage(self, legend):
"""Make the image associated to legend the active image.
:param str legend: The legend associated to the image
or None to have no active image.
"""
- if replot is not None:
- _logger.warning('setActiveImage deprecated replot parameter')
-
return self._setActiveItem(kind='image', legend=legend)
def _getActiveItem(self, kind, just_legend=False):
@@ -2028,14 +2109,12 @@ class PlotWidget(qt.QMainWindow):
"""
return self._backend.getGraphXLimits()
- def setGraphXLimits(self, xmin, xmax, replot=None):
+ def setGraphXLimits(self, xmin, xmax):
"""Set the graph X (bottom) limits.
:param float xmin: minimum bottom axis value
:param float xmax: maximum bottom axis value
"""
- if replot is not None:
- _logger.warning('setGraphXLimits deprecated replot parameter')
self._xAxis.setLimits(xmin, xmax)
def getGraphYLimits(self, axis='left'):
@@ -2049,7 +2128,7 @@ class PlotWidget(qt.QMainWindow):
yAxis = self._yAxis if axis == 'left' else self._yRightAxis
return yAxis.getLimits()
- def setGraphYLimits(self, ymin, ymax, axis='left', replot=None):
+ def setGraphYLimits(self, ymin, ymax, axis='left'):
"""Set the graph Y limits.
:param float ymin: minimum bottom axis value
@@ -2057,8 +2136,6 @@ class PlotWidget(qt.QMainWindow):
:param str axis: The axis for which to get the limits:
Either 'left' or 'right'
"""
- if replot is not None:
- _logger.warning('setGraphYLimits deprecated replot parameter')
assert axis in ('left', 'right')
yAxis = self._yAxis if axis == 'left' else self._yRightAxis
return yAxis.setLimits(ymin, ymax)
@@ -2192,36 +2269,6 @@ class PlotWidget(qt.QMainWindow):
def _isAxesDisplayed(self):
return self._backend.isAxesDisplayed()
- @property
- @deprecated(since_version='0.6')
- def sigSetYAxisInverted(self):
- """Signal emitted when Y axis orientation has changed"""
- return self._yAxis.sigInvertedChanged
-
- @property
- @deprecated(since_version='0.6')
- def sigSetXAxisLogarithmic(self):
- """Signal emitted when X axis scale has changed"""
- return self._xAxis._sigLogarithmicChanged
-
- @property
- @deprecated(since_version='0.6')
- def sigSetYAxisLogarithmic(self):
- """Signal emitted when Y axis scale has changed"""
- return self._yAxis._sigLogarithmicChanged
-
- @property
- @deprecated(since_version='0.6')
- def sigSetXAxisAutoScale(self):
- """Signal emitted when X axis autoscale has changed"""
- return self._xAxis.sigAutoScaleChanged
-
- @property
- @deprecated(since_version='0.6')
- def sigSetYAxisAutoScale(self):
- """Signal emitted when Y axis autoscale has changed"""
- return self._yAxis.sigAutoScaleChanged
-
def setYAxisInverted(self, flag=True):
"""Set the Y axis orientation.
@@ -2290,6 +2337,8 @@ class PlotWidget(qt.QMainWindow):
:param bool flag: True to respect data aspect ratio
"""
flag = bool(flag)
+ if flag == self.isKeepDataAspectRatio():
+ return
self._backend.setKeepDataAspectRatio(flag=flag)
self._setDirtyPlot()
self._forceResetZoom()
@@ -2323,8 +2372,8 @@ class PlotWidget(qt.QMainWindow):
# Defaults
def isDefaultPlotPoints(self):
- """Return True if default Curve symbol is 'o', False for no symbol."""
- return self._defaultPlotPoints == 'o'
+ """Return True if the default Curve symbol is set and False if not."""
+ return self._defaultPlotPoints == silx.config.DEFAULT_PLOT_SYMBOL
def setDefaultPlotPoints(self, flag):
"""Set the default symbol of all curves.
@@ -2334,7 +2383,7 @@ class PlotWidget(qt.QMainWindow):
:param bool flag: True to use 'o' as the default curve symbol,
False to use no symbol.
"""
- self._defaultPlotPoints = 'o' if flag else ''
+ self._defaultPlotPoints = silx.config.DEFAULT_PLOT_SYMBOL if flag else ''
# Reset symbol of all curves
curves = self.getAllCurves(just_legend=False, withhidden=True)
@@ -2510,7 +2559,7 @@ class PlotWidget(qt.QMainWindow):
elif ddict['event'] == 'mouseClicked' and ddict['button'] == 'left':
self.setActiveCurve(None)
- def saveGraph(self, filename, fileFormat=None, dpi=None, **kw):
+ def saveGraph(self, filename, fileFormat=None, dpi=None):
"""Save a snapshot of the plot.
Supported file formats depends on the backend in use.
@@ -2523,9 +2572,6 @@ class PlotWidget(qt.QMainWindow):
:param str fileFormat: String specifying the format
:return: False if cannot save the plot, True otherwise
"""
- if kw:
- _logger.warning('Extra parameters ignored: %s', str(kw))
-
if fileFormat is None:
if not hasattr(filename, 'lower'):
_logger.warning(
@@ -3080,149 +3126,3 @@ class PlotWidget(qt.QMainWindow):
# Only call base class implementation when key is not handled.
# See QWidget.keyPressEvent for details.
super(PlotWidget, self).keyPressEvent(event)
-
- # Deprecated #
-
- def isDrawModeEnabled(self):
- """Deprecated, use :meth:`getInteractiveMode` instead.
-
- Return True if the current interactive state is drawing."""
- _logger.warning(
- 'isDrawModeEnabled deprecated, use getInteractiveMode instead')
- return self.getInteractiveMode()['mode'] == 'draw'
-
- def setDrawModeEnabled(self, flag=True, shape='polygon', label=None,
- color=None, **kwargs):
- """Deprecated, use :meth:`setInteractiveMode` instead.
-
- Set the drawing mode if flag is True and its parameters.
-
- If flag is False, only item selection is enabled.
-
- Warning: Zoom and drawing are not compatible and cannot be enabled
- simultaneously.
-
- :param bool flag: True to enable drawing and disable zoom and select.
- :param str shape: Type of item to be drawn in:
- hline, vline, rectangle, polygon (default)
- :param str label: Associated text for identifying draw signals
- :param color: The color to use to draw the selection area
- :type color: string ("#RRGGBB") or 4 column unsigned byte array or
- one of the predefined color names defined in colors.py
- """
- _logger.warning(
- 'setDrawModeEnabled deprecated, use setInteractiveMode instead')
-
- if kwargs:
- _logger.warning('setDrawModeEnabled ignores additional parameters')
-
- if color is None:
- color = 'black'
-
- if flag:
- self.setInteractiveMode('draw', shape=shape,
- label=label, color=color)
- elif self.getInteractiveMode()['mode'] == 'draw':
- self.setInteractiveMode('select')
-
- def getDrawMode(self):
- """Deprecated, use :meth:`getInteractiveMode` instead.
-
- Return the draw mode parameters as a dict of None.
-
- It returns None if the interactive mode is not a drawing mode,
- otherwise, it returns a dict containing the drawing mode parameters
- as provided to :meth:`setDrawModeEnabled`.
- """
- _logger.warning(
- 'getDrawMode deprecated, use getInteractiveMode instead')
- mode = self.getInteractiveMode()
- return mode if mode['mode'] == 'draw' else None
-
- def isZoomModeEnabled(self):
- """Deprecated, use :meth:`getInteractiveMode` instead.
-
- Return True if the current interactive state is zooming."""
- _logger.warning(
- 'isZoomModeEnabled deprecated, use getInteractiveMode instead')
- return self.getInteractiveMode()['mode'] == 'zoom'
-
- def setZoomModeEnabled(self, flag=True, color=None):
- """Deprecated, use :meth:`setInteractiveMode` instead.
-
- Set the zoom mode if flag is True, else item selection is enabled.
-
- Warning: Zoom and drawing are not compatible and cannot be enabled
- simultaneously
-
- :param bool flag: If True, enable zoom and select mode.
- :param color: The color to use to draw the selection area.
- (Default: 'black')
- :param color: The color to use to draw the selection area
- :type color: string ("#RRGGBB") or 4 column unsigned byte array or
- one of the predefined color names defined in colors.py
- """
- _logger.warning(
- 'setZoomModeEnabled deprecated, use setInteractiveMode instead')
- if color is None:
- color = 'black'
-
- if flag:
- self.setInteractiveMode('zoom', color=color)
- elif self.getInteractiveMode()['mode'] == 'zoom':
- self.setInteractiveMode('select')
-
- def insertMarker(self, *args, **kwargs):
- """Deprecated, use :meth:`addMarker` instead."""
- _logger.warning(
- 'insertMarker deprecated, use addMarker instead.')
- return self.addMarker(*args, **kwargs)
-
- def insertXMarker(self, *args, **kwargs):
- """Deprecated, use :meth:`addXMarker` instead."""
- _logger.warning(
- 'insertXMarker deprecated, use addXMarker instead.')
- return self.addXMarker(*args, **kwargs)
-
- def insertYMarker(self, *args, **kwargs):
- """Deprecated, use :meth:`addYMarker` instead."""
- _logger.warning(
- 'insertYMarker deprecated, use addYMarker instead.')
- return self.addYMarker(*args, **kwargs)
-
- def isActiveCurveHandlingEnabled(self):
- """Deprecated, use :meth:`isActiveCurveHandling` instead."""
- _logger.warning(
- 'isActiveCurveHandlingEnabled deprecated, '
- 'use isActiveCurveHandling instead.')
- return self.isActiveCurveHandling()
-
- def enableActiveCurveHandling(self, *args, **kwargs):
- """Deprecated, use :meth:`setActiveCurveHandling` instead."""
- _logger.warning(
- 'enableActiveCurveHandling deprecated, '
- 'use setActiveCurveHandling instead.')
- return self.setActiveCurveHandling(*args, **kwargs)
-
- def invertYAxis(self, *args, **kwargs):
- """Deprecated, use :meth:`Axis.setInverted` instead."""
- _logger.warning('invertYAxis deprecated, '
- 'use getYAxis().setInverted instead.')
- return self.getYAxis().setInverted(*args, **kwargs)
-
- def showGrid(self, flag=True):
- """Deprecated, use :meth:`setGraphGrid` instead."""
- _logger.warning("showGrid deprecated, use setGraphGrid instead")
- if flag in (0, False):
- flag = None
- elif flag in (1, True):
- flag = 'major'
- else:
- flag = 'both'
- return self.setGraphGrid(flag)
-
- def keepDataAspectRatio(self, *args, **kwargs):
- """Deprecated, use :meth:`setKeepDataAspectRatio`."""
- _logger.warning('keepDataAspectRatio deprecated,'
- 'use setKeepDataAspectRatio instead')
- return self.setKeepDataAspectRatio(*args, **kwargs)
diff --git a/silx/gui/plot/PlotWindow.py b/silx/gui/plot/PlotWindow.py
index 23ea399..b44a512 100644
--- a/silx/gui/plot/PlotWindow.py
+++ b/silx/gui/plot/PlotWindow.py
@@ -1,7 +1,7 @@
# coding: utf-8
# /*##########################################################################
#
-# Copyright (c) 2004-2018 European Synchrotron Radiation Facility
+# Copyright (c) 2004-2019 European Synchrotron Radiation Facility
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
@@ -29,7 +29,7 @@ The :class:`PlotWindow` is a subclass of :class:`.PlotWidget`.
__authors__ = ["V.A. Sole", "T. Vincent"]
__license__ = "MIT"
-__date__ = "12/10/2018"
+__date__ = "21/12/2018"
import collections
import logging
@@ -217,10 +217,8 @@ class PlotWindow(PlotWidget):
# Make colorbar background white
self._colorbar.setAutoFillBackground(True)
- palette = self._colorbar.palette()
- palette.setColor(qt.QPalette.Background, qt.Qt.white)
- palette.setColor(qt.QPalette.Window, qt.Qt.white)
- self._colorbar.setPalette(palette)
+ self._sigAxesVisibilityChanged.connect(self._updateColorBarBackground)
+ self._updateColorBarBackground()
gridLayout = qt.QGridLayout()
gridLayout.setSpacing(0)
@@ -294,6 +292,43 @@ class PlotWindow(PlotWidget):
for action in toolbar.actions():
self.addAction(action)
+ def setBackgroundColor(self, color):
+ super(PlotWindow, self).setBackgroundColor(color)
+ self._updateColorBarBackground()
+
+ setBackgroundColor.__doc__ = PlotWidget.setBackgroundColor.__doc__
+
+ def setDataBackgroundColor(self, color):
+ super(PlotWindow, self).setDataBackgroundColor(color)
+ self._updateColorBarBackground()
+
+ setDataBackgroundColor.__doc__ = PlotWidget.setDataBackgroundColor.__doc__
+
+ def setForegroundColor(self, color):
+ super(PlotWindow, self).setForegroundColor(color)
+ self._updateColorBarBackground()
+
+ setForegroundColor.__doc__ = PlotWidget.setForegroundColor.__doc__
+
+ def _updateColorBarBackground(self):
+ """Update the colorbar background according to the state of the plot"""
+ if self._isAxesDisplayed():
+ color = self.getBackgroundColor()
+ else:
+ color = self.getDataBackgroundColor()
+ if not color.isValid():
+ # If no color defined, use the background one
+ color = self.getBackgroundColor()
+
+ foreground = self.getForegroundColor()
+
+ palette = self._colorbar.palette()
+ palette.setColor(qt.QPalette.Background, color)
+ palette.setColor(qt.QPalette.Window, color)
+ palette.setColor(qt.QPalette.WindowText, foreground)
+ palette.setColor(qt.QPalette.Text, foreground)
+ self._colorbar.setPalette(palette)
+
def getInteractiveModeToolBar(self):
"""Returns QToolBar controlling interactive mode.
@@ -457,10 +492,6 @@ class PlotWindow(PlotWidget):
return self._colorbar
# getters for dock widgets
- @property
- @deprecated(replacement="getLegendsDockWidget()", since_version="0.4.0")
- def legendsDockWidget(self):
- return self.getLegendsDockWidget()
def getLegendsDockWidget(self):
"""DockWidget with Legend panel"""
@@ -470,11 +501,6 @@ class PlotWindow(PlotWidget):
self.addTabbedDockWidget(self._legendsDockWidget)
return self._legendsDockWidget
- @property
- @deprecated(replacement="getCurvesRoiWidget()", since_version="0.4.0")
- def curvesROIDockWidget(self):
- return self.getCurvesRoiDockWidget()
-
def getCurvesRoiDockWidget(self):
# Undocumented for a "soft deprecation" in version 0.7.0
# (still used internally for lazy loading)
@@ -496,11 +522,6 @@ class PlotWindow(PlotWidget):
"""
return self.getCurvesRoiDockWidget().roiWidget
- @property
- @deprecated(replacement="getMaskToolsDockWidget()", since_version="0.4.0")
- def maskToolsDockWidget(self):
- return self.getMaskToolsDockWidget()
-
def getMaskToolsDockWidget(self):
"""DockWidget with image mask panel (lazy-loaded)."""
if self._maskToolsDockWidget is None:
@@ -539,11 +560,6 @@ class PlotWindow(PlotWidget):
def panModeAction(self):
return self.getInteractiveModeToolBar().getPanModeAction()
- @property
- @deprecated(replacement="getConsoleAction()", since_version="0.4.0")
- def consoleAction(self):
- return self.getConsoleAction()
-
def getConsoleAction(self):
"""QAction handling the IPython console activation.
@@ -563,11 +579,6 @@ class PlotWindow(PlotWidget):
self._consoleAction.setEnabled(False)
return self._consoleAction
- @property
- @deprecated(replacement="getCrosshairAction()", since_version="0.4.0")
- def crosshairAction(self):
- return self.getCrosshairAction()
-
def getCrosshairAction(self):
"""Action toggling crosshair cursor mode.
@@ -577,11 +588,6 @@ class PlotWindow(PlotWidget):
self._crosshairAction = actions.control.CrosshairAction(self, color='red')
return self._crosshairAction
- @property
- @deprecated(replacement="getMaskAction()", since_version="0.4.0")
- def maskAction(self):
- return self.getMaskAction()
-
def getMaskAction(self):
"""QAction toggling image mask dock widget
@@ -589,12 +595,6 @@ class PlotWindow(PlotWidget):
"""
return self.getMaskToolsDockWidget().toggleViewAction()
- @property
- @deprecated(replacement="getPanWithArrowKeysAction()",
- since_version="0.4.0")
- def panWithArrowKeysAction(self):
- return self.getPanWithArrowKeysAction()
-
def getPanWithArrowKeysAction(self):
"""Action toggling pan with arrow keys.
@@ -604,11 +604,6 @@ class PlotWindow(PlotWidget):
self._panWithArrowKeysAction = actions.control.PanWithArrowKeysAction(self)
return self._panWithArrowKeysAction
- @property
- @deprecated(replacement="getRoiAction()", since_version="0.4.0")
- def roiAction(self):
- return self.getRoiAction()
-
def getStatsAction(self):
if self._statsAction is None:
self._statsAction = qt.QAction('Curves stats', self)
diff --git a/silx/gui/plot/PrintPreviewToolButton.py b/silx/gui/plot/PrintPreviewToolButton.py
index b48505d..d857c18 100644
--- a/silx/gui/plot/PrintPreviewToolButton.py
+++ b/silx/gui/plot/PrintPreviewToolButton.py
@@ -111,10 +111,11 @@ from .. import icons
from . import PlotWidget
from ..widgets.PrintPreview import PrintPreviewDialog, SingletonPrintPreviewDialog
from ..widgets.PrintGeometryDialog import PrintGeometryDialog
+from silx.utils.deprecation import deprecated
__authors__ = ["P. Knobel"]
__license__ = "MIT"
-__date__ = "18/07/2017"
+__date__ = "20/12/2018"
_logger = logging.getLogger(__name__)
# _logger.setLevel(logging.DEBUG)
@@ -132,19 +133,19 @@ class PrintPreviewToolButton(qt.QToolButton):
if not isinstance(plot, PlotWidget):
raise TypeError("plot parameter must be a PlotWidget")
- self.plot = plot
+ self._plot = plot
self.setIcon(icons.getQIcon('document-print'))
printGeomAction = qt.QAction("Print geometry", self)
printGeomAction.setToolTip("Define a print geometry prior to sending "
"the plot to the print preview dialog")
- printGeomAction.setIcon(icons.getQIcon('shape-rectangle')) # fixme: icon not displayed in menu
+ printGeomAction.setIcon(icons.getQIcon('shape-rectangle'))
printGeomAction.triggered.connect(self._setPrintConfiguration)
printPreviewAction = qt.QAction("Print preview", self)
printPreviewAction.setToolTip("Send plot to the print preview dialog")
- printPreviewAction.setIcon(icons.getQIcon('document-print')) # fixme: icon not displayed
+ printPreviewAction.setIcon(icons.getQIcon('document-print'))
printPreviewAction.triggered.connect(self._plotToPrintPreview)
menu = qt.QMenu(self)
@@ -172,24 +173,64 @@ class PrintPreviewToolButton(qt.QToolButton):
self._printPreviewDialog = PrintPreviewDialog(self.parent())
return self._printPreviewDialog
+ def getTitle(self):
+ """Implement this method to fetch the title in the plot.
+
+ :return: Title to be printed above the plot, or None (no title added)
+ :rtype: str or None
+ """
+ return None
+
+ def getCommentAndPosition(self):
+ """Implement this method to fetch the legend to be printed below the
+ figure and its position.
+
+ :return: Legend to be printed below the figure and its position:
+ "CENTER", "LEFT" or "RIGHT"
+ :rtype: (str, str) or (None, None)
+ """
+ return None, None
+
+ @property
+ @deprecated(since_version="0.10",
+ replacement="getPlot()")
+ def plot(self):
+ return self._plot
+
+ def getPlot(self):
+ """Return the :class:`.PlotWidget` associated with this tool button.
+
+ :rtype: :class:`.PlotWidget`
+ """
+ return self._plot
+
def _plotToPrintPreview(self):
"""Grab the plot widget and send it to the print preview dialog.
Make sure the print preview dialog is shown and raised."""
if not self.printPreviewDialog.ensurePrinterIsSet():
return
+ comment, commentPosition = self.getCommentAndPosition()
+
if qt.HAS_SVG:
svgRenderer, viewBox = self._getSvgRendererAndViewbox()
self.printPreviewDialog.addSvgItem(svgRenderer,
- viewBox=viewBox)
+ title=self.getTitle(),
+ comment=comment,
+ commentPosition=commentPosition,
+ viewBox=viewBox,
+ keepRatio=self._printGeometry["keepAspectRatio"])
else:
_logger.warning("Missing QtSvg library, using a raster image")
if qt.BINDING in ["PyQt4", "PySide"]:
- pixmap = qt.QPixmap.grabWidget(self.plot.centralWidget())
+ pixmap = qt.QPixmap.grabWidget(self._plot.centralWidget())
else:
# PyQt5 and hopefully PyQt6+
- pixmap = self.plot.centralWidget().grab()
- self.printPreviewDialog.addPixmap(pixmap)
+ pixmap = self._plot.centralWidget().grab()
+ self.printPreviewDialog.addPixmap(pixmap,
+ title=self.getTitle(),
+ comment=comment,
+ commentPosition=commentPosition)
self.printPreviewDialog.show()
self.printPreviewDialog.raise_()
@@ -201,7 +242,7 @@ class PrintPreviewToolButton(qt.QToolButton):
and to the geometry configuration (width, height, ratio) specified
by the user."""
imgData = StringIO()
- assert self.plot.saveGraph(imgData, fileFormat="svg"), \
+ assert self._plot.saveGraph(imgData, fileFormat="svg"), \
"Unable to save graph"
imgData.flush()
imgData.seek(0)
@@ -310,7 +351,7 @@ class PrintPreviewToolButton(qt.QToolButton):
self._printGeometry = self._printConfigurationDialog.getPrintGeometry()
def _getPlotAspectRatio(self):
- widget = self.plot.centralWidget()
+ widget = self._plot.centralWidget()
graphWidth = float(widget.width())
graphHeight = float(widget.height())
return graphHeight / graphWidth
diff --git a/silx/gui/plot/Profile.py b/silx/gui/plot/Profile.py
index 182cf60..46e4523 100644
--- a/silx/gui/plot/Profile.py
+++ b/silx/gui/plot/Profile.py
@@ -1,7 +1,7 @@
# coding: utf-8
# /*##########################################################################
#
-# Copyright (c) 2004-2018 European Synchrotron Radiation Facility
+# Copyright (c) 2004-2019 European Synchrotron Radiation Facility
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
@@ -180,7 +180,8 @@ def createProfile(roiInfo, currentData, origin, scale, lineWidth, method):
:type scale: 2-tuple of float
:param int lineWidth: width of the profile line
:param str method: method to compute the profile. Can be 'mean' or 'sum'
- :return: `profile, area, profileName, xLabel`, where:
+ :return: `coords, profile, area, profileName, xLabel`, where:
+ - coords is the X coordinate to use to display the profile
- profile is a 2D array of the profiles of the stack of images.
For a single image, the profile is a curve, so this parameter
has a shape *(1, len(curve))*
@@ -188,10 +189,9 @@ def createProfile(roiInfo, currentData, origin, scale, lineWidth, method):
the effective ROI area corners in plot coords.
- profileName is a string describing the ROI, meant to be used as
title of the profile plot
- - xLabel is a string describing the meaning of the X axis on the
- profile plot ("rows", "columns", "distance")
+ - xLabel the label for X in the profile window
- :rtype: tuple(ndarray, (ndarray, ndarray), str, str)
+ :rtype: tuple(ndarray,ndarray,(ndarray,ndarray),str)
"""
if currentData is None or roiInfo is None or lineWidth is None:
raise ValueError("createProfile called with invalide arguments")
@@ -212,12 +212,15 @@ def createProfile(roiInfo, currentData, origin, scale, lineWidth, method):
axis=0,
method=method)
+ coords = numpy.arange(len(profile[0]), dtype=numpy.float32)
+ coords = coords * scale[0] + origin[0]
+
yMin, yMax = min(area[1]), max(area[1]) - 1
if roiWidth <= 1:
profileName = 'Y = %g' % yMin
else:
profileName = 'Y = [%g, %g]' % (yMin, yMax)
- xLabel = 'Columns'
+ xLabel = 'X'
elif lineProjectionMode == 'Y': # Vertical profile on the whole image
profile, area = _alignedFullProfile(currentData3D,
@@ -226,12 +229,15 @@ def createProfile(roiInfo, currentData, origin, scale, lineWidth, method):
axis=1,
method=method)
+ coords = numpy.arange(len(profile[0]), dtype=numpy.float32)
+ coords = coords * scale[1] + origin[1]
+
xMin, xMax = min(area[0]), max(area[0]) - 1
if roiWidth <= 1:
profileName = 'X = %g' % xMin
else:
profileName = 'X = [%g, %g]' % (xMin, xMax)
- xLabel = 'Rows'
+ xLabel = 'Y'
else: # Free line profile
@@ -306,35 +312,52 @@ def createProfile(roiInfo, currentData, origin, scale, lineWidth, method):
dCol = (endPt[1] - startPt[1]) / length
# Extend ROI with half a pixel on each end
- startPt = startPt[0] - 0.5 * dRow, startPt[1] - 0.5 * dCol
- endPt = endPt[0] + 0.5 * dRow, endPt[1] + 0.5 * dCol
+ roiStartPt = startPt[0] - 0.5 * dRow, startPt[1] - 0.5 * dCol
+ roiEndPt = endPt[0] + 0.5 * dRow, endPt[1] + 0.5 * dCol
# Rotate deltas by 90 degrees to apply line width
dRow, dCol = dCol, -dRow
area = (
- numpy.array((startPt[1] - 0.5 * roiWidth * dCol,
- startPt[1] + 0.5 * roiWidth * dCol,
- endPt[1] + 0.5 * roiWidth * dCol,
- endPt[1] - 0.5 * roiWidth * dCol),
+ numpy.array((roiStartPt[1] - 0.5 * roiWidth * dCol,
+ roiStartPt[1] + 0.5 * roiWidth * dCol,
+ roiEndPt[1] + 0.5 * roiWidth * dCol,
+ roiEndPt[1] - 0.5 * roiWidth * dCol),
dtype=numpy.float32) * scale[0] + origin[0],
- numpy.array((startPt[0] - 0.5 * roiWidth * dRow,
- startPt[0] + 0.5 * roiWidth * dRow,
- endPt[0] + 0.5 * roiWidth * dRow,
- endPt[0] - 0.5 * roiWidth * dRow),
+ numpy.array((roiStartPt[0] - 0.5 * roiWidth * dRow,
+ roiStartPt[0] + 0.5 * roiWidth * dRow,
+ roiEndPt[0] + 0.5 * roiWidth * dRow,
+ roiEndPt[0] - 0.5 * roiWidth * dRow),
dtype=numpy.float32) * scale[1] + origin[1])
- y0, x0 = startPt
- y1, x1 = endPt
- if x1 == x0 or y1 == y0:
- profileName = 'From (%g, %g) to (%g, %g)' % (x0, y0, x1, y1)
+ # Convert start and end points back to plot coords
+ y0 = startPt[0] * scale[1] + origin[1]
+ x0 = startPt[1] * scale[0] + origin[0]
+ y1 = endPt[0] * scale[1] + origin[1]
+ x1 = endPt[1] * scale[0] + origin[0]
+
+ if startPt[1] == endPt[1]:
+ profileName = 'X = %g; Y = [%g, %g]' % (x0, y0, y1)
+ coords = numpy.arange(len(profile[0]), dtype=numpy.float32)
+ coords = coords * scale[1] + y0
+ xLabel = 'Y'
+
+ elif startPt[0] == endPt[0]:
+ profileName = 'Y = %g; X = [%g, %g]' % (y0, x0, x1)
+ coords = numpy.arange(len(profile[0]), dtype=numpy.float32)
+ coords = coords * scale[0] + x0
+ xLabel = 'X'
+
else:
m = (y1 - y0) / (x1 - x0)
b = y0 - m * x0
profileName = 'y = %g * x %+g ; width=%d' % (m, b, roiWidth)
- xLabel = 'Distance'
+ coords = numpy.linspace(x0, x1, len(profile[0]),
+ endpoint=True,
+ dtype=numpy.float32)
+ xLabel = 'X'
- return profile, area, profileName, xLabel
+ return coords, profile, area, profileName, xLabel
# ProfileToolBar ##############################################################
@@ -458,7 +481,7 @@ class ProfileToolBar(qt.QToolBar):
self.addWidget(self.lineWidthSpinBox)
self.methodsButton = ProfileOptionToolButton(parent=self, plot=self)
- self.addWidget(self.methodsButton)
+ self.__profileOptionToolAction = self.addWidget(self.methodsButton)
# TODO: add connection with the signal
self.methodsButton.sigMethodChanged.connect(self.setProfileMethod)
@@ -650,7 +673,7 @@ class ProfileToolBar(qt.QToolBar):
if self._roiInfo is None:
return
- profile, area, profileName, xLabel = createProfile(
+ coords, profile, area, profileName, xLabel = createProfile(
roiInfo=self._roiInfo,
currentData=currentData,
origin=origin,
@@ -658,28 +681,25 @@ class ProfileToolBar(qt.QToolBar):
lineWidth=self.lineWidthSpinBox.value(),
method=method)
- self.getProfilePlot().setGraphTitle(profileName)
+ profilePlot = self.getProfilePlot()
+
+ profilePlot.setGraphTitle(profileName)
+ profilePlot.getXAxis().setLabel(xLabel)
dataIs3D = len(currentData.shape) > 2
if dataIs3D:
- self.getProfilePlot().addImage(profile,
- legend=profileName,
- xlabel=xLabel,
- ylabel="Frame index (depth)",
- colormap=colormap)
+ profileScale = (coords[-1] - coords[0]) / profile.shape[1], 1
+ profilePlot.addImage(profile,
+ legend=profileName,
+ colormap=colormap,
+ origin=(coords[0], 0),
+ scale=profileScale)
+ profilePlot.getYAxis().setLabel("Frame index (depth)")
else:
- coords = numpy.arange(len(profile[0]), dtype=numpy.float32)
- # Scale horizontal and vertical profile coordinates
- if self._roiInfo[2] == 'X':
- coords = coords * scale[0] + origin[0]
- elif self._roiInfo[2] == 'Y':
- coords = coords * scale[1] + origin[1]
-
- self.getProfilePlot().addCurve(coords,
- profile[0],
- legend=profileName,
- xlabel=xLabel,
- color=self.overlayColor)
+ profilePlot.addCurve(coords,
+ profile[0],
+ legend=profileName,
+ color=self.overlayColor)
self.plot.addItem(area[0], area[1],
legend=self._POLYGON_LEGEND,
@@ -732,6 +752,9 @@ class ProfileToolBar(qt.QToolBar):
def getProfileMethod(self):
return self._method
+ def getProfileOptionToolAction(self):
+ return self.__profileOptionToolAction
+
class Profile3DToolBar(ProfileToolBar):
def __init__(self, parent=None, stackview=None,
diff --git a/silx/gui/plot/ScatterMaskToolsWidget.py b/silx/gui/plot/ScatterMaskToolsWidget.py
index de645be..0c6797f 100644
--- a/silx/gui/plot/ScatterMaskToolsWidget.py
+++ b/silx/gui/plot/ScatterMaskToolsWidget.py
@@ -35,7 +35,7 @@ from __future__ import division
__authors__ = ["P. Knobel"]
__license__ = "MIT"
-__date__ = "24/04/2018"
+__date__ = "15/02/2019"
import math
@@ -152,6 +152,22 @@ class ScatterMask(BaseMask):
stencil = (y - cy)**2 + (x - cx)**2 < radius**2
self.updateStencil(level, stencil, mask)
+ def updateEllipse(self, level, crow, ccol, radius_r, radius_c, mask=True):
+ """Mask/Unmask an ellipse of the given mask level.
+
+ :param int level: Mask level to update.
+ :param int crow: Row of the center of the ellipse
+ :param int ccol: Column of the center of the ellipse
+ :param float radius_r: Radius of the ellipse in the row
+ :param float radius_c: Radius of the ellipse in the column
+ :param bool mask: True to mask (default), False to unmask.
+ """
+ def is_inside(px, py):
+ return (px - ccol)**2 / radius_c**2 + (py - crow)**2 / radius_r**2 <= 1.0
+ x, y = self._getXY()
+ indices_inside = [idx for idx in range(len(x)) if is_inside(x[idx], y[idx])]
+ self.updatePoints(level, indices_inside, mask)
+
def updateLine(self, level, y0, x0, y1, x1, width, mask=True):
"""Mask/Unmask points inside a rectangle defined by a line (two
end points) and a width.
@@ -490,26 +506,35 @@ class ScatterMaskToolsWidget(BaseMaskToolsWidget):
level = self.levelSpinBox.value()
- if (self._drawingMode == 'rectangle' and
- event['event'] == 'drawingFinished'):
- doMask = self._isMasking()
+ if self._drawingMode == 'rectangle':
+ if event['event'] == 'drawingFinished':
+ doMask = self._isMasking()
+
+ self._mask.updateRectangle(
+ level,
+ y=event['y'],
+ x=event['x'],
+ height=abs(event['height']),
+ width=abs(event['width']),
+ mask=doMask)
+ self._mask.commit()
- self._mask.updateRectangle(
- level,
- y=event['y'],
- x=event['x'],
- height=abs(event['height']),
- width=abs(event['width']),
- mask=doMask)
- self._mask.commit()
+ elif self._drawingMode == 'ellipse':
+ if event['event'] == 'drawingFinished':
+ doMask = self._isMasking()
+ center = event['points'][0]
+ size = event['points'][1]
+ self._mask.updateEllipse(level, center[1], center[0],
+ size[1], size[0], doMask)
+ self._mask.commit()
- elif (self._drawingMode == 'polygon' and
- event['event'] == 'drawingFinished'):
- doMask = self._isMasking()
- vertices = event['points']
- vertices = vertices[:, (1, 0)] # (y, x)
- self._mask.updatePolygon(level, vertices, doMask)
- self._mask.commit()
+ elif self._drawingMode == 'polygon':
+ if event['event'] == 'drawingFinished':
+ doMask = self._isMasking()
+ vertices = event['points']
+ vertices = vertices[:, (1, 0)] # (y, x)
+ self._mask.updatePolygon(level, vertices, doMask)
+ self._mask.commit()
elif self._drawingMode == 'pencil':
doMask = self._isMasking()
@@ -536,6 +561,8 @@ class ScatterMaskToolsWidget(BaseMaskToolsWidget):
self._lastPencilPos = None
else:
self._lastPencilPos = y, x
+ else:
+ _logger.error("Drawing mode %s unsupported", self._drawingMode)
def _loadRangeFromColormapTriggered(self):
"""Set range from active scatter colormap range"""
diff --git a/silx/gui/plot/ScatterView.py b/silx/gui/plot/ScatterView.py
index ae79cf9..5fc66ef 100644
--- a/silx/gui/plot/ScatterView.py
+++ b/silx/gui/plot/ScatterView.py
@@ -1,7 +1,7 @@
# coding: utf-8
# /*##########################################################################
#
-# Copyright (c) 2018 European Synchrotron Radiation Facility
+# Copyright (c) 2018-2019 European Synchrotron Radiation Facility
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
@@ -353,3 +353,13 @@ class ScatterView(qt.QMainWindow):
return self.getPlotWidget().resetZoom(*args, **kwargs)
resetZoom.__doc__ = PlotWidget.resetZoom.__doc__
+
+ def getSelectionMask(self, *args, **kwargs):
+ return self.getMaskToolsWidget().getSelectionMask(*args, **kwargs)
+
+ getSelectionMask.__doc__ = ScatterMaskToolsWidget.getSelectionMask.__doc__
+
+ def setSelectionMask(self, *args, **kwargs):
+ return self.getMaskToolsWidget().setSelectionMask(*args, **kwargs)
+
+ setSelectionMask.__doc__ = ScatterMaskToolsWidget.setSelectionMask.__doc__
diff --git a/silx/gui/plot/StackView.py b/silx/gui/plot/StackView.py
index 72b6cd4..2a3d7e8 100644
--- a/silx/gui/plot/StackView.py
+++ b/silx/gui/plot/StackView.py
@@ -89,14 +89,8 @@ from silx.utils.array_like import DatasetView, ListOfImages
from silx.math import calibration
from silx.utils.deprecation import deprecated_warning
-try:
- import h5py
-except ImportError:
- def is_dataset(obj):
- return False
- h5py = None
-else:
- from silx.io.utils import is_dataset
+import h5py
+from silx.io.utils import is_dataset
_logger = logging.getLogger(__name__)
diff --git a/silx/gui/plot/StatsWidget.py b/silx/gui/plot/StatsWidget.py
index bb66613..4ba4fab 100644
--- a/silx/gui/plot/StatsWidget.py
+++ b/silx/gui/plot/StatsWidget.py
@@ -1,7 +1,7 @@
# coding: utf-8
# /*##########################################################################
#
-# Copyright (c) 2017-2018 European Synchrotron Radiation Facility
+# Copyright (c) 2017-2019 European Synchrotron Radiation Facility
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
@@ -31,552 +31,1266 @@ __license__ = "MIT"
__date__ = "24/07/2018"
-import functools
+from collections import OrderedDict
+from contextlib import contextmanager
import logging
+import weakref
+
import numpy
-from collections import OrderedDict
-import silx.utils.weakref
from silx.gui import qt
from silx.gui import icons
-from silx.gui.plot.items.curve import Curve as CurveItem
-from silx.gui.plot.items.histogram import Histogram as HistogramItem
-from silx.gui.plot.items.image import ImageBase as ImageItem
-from silx.gui.plot.items.scatter import Scatter as ScatterItem
from silx.gui.plot import stats as statsmdl
from silx.gui.widgets.TableWidget import TableWidget
from silx.gui.plot.stats.statshandler import StatsHandler, StatFormatter
+from silx.gui.plot.items.core import ItemChangedType
+from silx.gui.widgets.FlowLayout import FlowLayout
+from . import PlotWidget
+from . import items as plotitems
-logger = logging.getLogger(__name__)
+_logger = logging.getLogger(__name__)
-class StatsWidget(qt.QWidget):
+
+# Helper class to handle specific calls to PlotWidget and SceneWidget
+
+class _Wrapper(qt.QObject):
+ """Base class for connection with PlotWidget and SceneWidget.
+
+ This class is used when no PlotWidget or SceneWidget is connected.
+
+ :param plot: The plot to be used
"""
- Widget displaying a set of :class:`Stat` to be displayed on a
- :class:`StatsTable` and to be apply on items contained in the :class:`Plot`
- Also contains options to:
- * compute statistics on all the data or on visible data only
- * show statistics of all items or only the active one
+ sigItemAdded = qt.Signal(object)
+ """Signal emitted when a new item is added.
- :param parent: Qt parent
- :param plot: the plot containing items on which we want statistics.
+ It provides the added item.
"""
- sigVisibilityChanged = qt.Signal(bool)
+ sigItemRemoved = qt.Signal(object)
+ """Signal emitted when an item is (about to be) removed.
- NUMBER_FORMAT = '{0:.3f}'
+ It provides the removed item.
+ """
- class OptionsWidget(qt.QToolBar):
-
- def __init__(self, parent=None):
- qt.QToolBar.__init__(self, parent)
- self.setIconSize(qt.QSize(16, 16))
-
- action = qt.QAction(self)
- action.setIcon(icons.getQIcon("stats-active-items"))
- action.setText("Active items only")
- action.setToolTip("Display stats for active items only.")
- action.setCheckable(True)
- action.setChecked(True)
- self.__displayActiveItems = action
-
- action = qt.QAction(self)
- action.setIcon(icons.getQIcon("stats-whole-items"))
- action.setText("All items")
- action.setToolTip("Display stats for all available items.")
- action.setCheckable(True)
- self.__displayWholeItems = action
-
- action = qt.QAction(self)
- action.setIcon(icons.getQIcon("stats-visible-data"))
- action.setText("Use the visible data range")
- action.setToolTip("Use the visible data range.<br/>"
- "If activated the data is filtered to only use"
- "visible data of the plot."
- "The filtering is a data sub-sampling."
- "No interpolation is made to fit data to"
- "boundaries.")
- action.setCheckable(True)
- self.__useVisibleData = action
-
- action = qt.QAction(self)
- action.setIcon(icons.getQIcon("stats-whole-data"))
- action.setText("Use the full data range")
- action.setToolTip("Use the full data range.")
- action.setCheckable(True)
- action.setChecked(True)
- self.__useWholeData = action
-
- self.addAction(self.__displayWholeItems)
- self.addAction(self.__displayActiveItems)
- self.addSeparator()
- self.addAction(self.__useVisibleData)
- self.addAction(self.__useWholeData)
-
- self.itemSelection = qt.QActionGroup(self)
- self.itemSelection.setExclusive(True)
- self.itemSelection.addAction(self.__displayActiveItems)
- self.itemSelection.addAction(self.__displayWholeItems)
-
- self.dataRangeSelection = qt.QActionGroup(self)
- self.dataRangeSelection.setExclusive(True)
- self.dataRangeSelection.addAction(self.__useWholeData)
- self.dataRangeSelection.addAction(self.__useVisibleData)
-
- def isActiveItemMode(self):
- return self.itemSelection.checkedAction() is self.__displayActiveItems
-
- def isVisibleDataRangeMode(self):
- return self.dataRangeSelection.checkedAction() is self.__useVisibleData
+ sigCurrentChanged = qt.Signal(object)
+ """Signal emitted when the current item has changed.
- def __init__(self, parent=None, plot=None, stats=None):
- qt.QWidget.__init__(self, parent)
- self.setLayout(qt.QVBoxLayout())
- self.layout().setContentsMargins(0, 0, 0, 0)
- self._options = self.OptionsWidget(parent=self)
- self.layout().addWidget(self._options)
- self._statsTable = StatsTable(parent=self, plot=plot)
- self.setStats = self._statsTable.setStats
- self.setStats(stats)
+ It provides the current item.
+ """
- self.layout().addWidget(self._statsTable)
- self.setPlot = self._statsTable.setPlot
+ sigVisibleDataChanged = qt.Signal()
+ """Signal emitted when the visible data area has changed"""
- self._options.itemSelection.triggered.connect(
- self._optSelectionChanged)
- self._options.dataRangeSelection.triggered.connect(
- self._optDataRangeChanged)
- self._optSelectionChanged()
- self._optDataRangeChanged()
+ def __init__(self, plot=None):
+ super(_Wrapper, self).__init__(parent=None)
+ self._plotRef = None if plot is None else weakref.ref(plot)
- self.setDisplayOnlyActiveItem = self._statsTable.setDisplayOnlyActiveItem
- self.setStatsOnVisibleData = self._statsTable.setStatsOnVisibleData
+ def getPlot(self):
+ """Returns the plot attached to this widget"""
+ return None if self._plotRef is None else self._plotRef()
- def showEvent(self, event):
- self.sigVisibilityChanged.emit(True)
- qt.QWidget.showEvent(self, event)
+ def getItems(self):
+ """Returns the list of items in the plot
- def hideEvent(self, event):
- self.sigVisibilityChanged.emit(False)
- qt.QWidget.hideEvent(self, event)
+ :rtype: List[object]
+ """
+ return ()
- def _optSelectionChanged(self, action=None):
- self._statsTable.setDisplayOnlyActiveItem(self._options.isActiveItemMode())
+ def getSelectedItems(self):
+ """Returns the list of selected items in the plot
- def _optDataRangeChanged(self, action=None):
- self._statsTable.setStatsOnVisibleData(self._options.isVisibleDataRangeMode())
+ :rtype: List[object]
+ """
+ return ()
+ def setCurrentItem(self, item):
+ """Set the current/active item in the plot
-class BasicStatsWidget(StatsWidget):
+ :param item: The plot item to set as active/current
+ """
+ pass
+
+ def getLabel(self, item):
+ """Returns the label of the given item.
+
+ :param item:
+ :rtype: str
+ """
+ return ''
+
+ def getKind(self, item):
+ """Returns the kind of an item or None if not supported
+
+ :param item:
+ :rtype: Union[str,None]
+ """
+ return None
+
+
+class _PlotWidgetWrapper(_Wrapper):
+ """Class handling PlotWidget specific calls and signal connections
+
+ See :class:`._Wrapper` for documentation
+
+ :param PlotWidget plot:
"""
- Widget defining a simple set of :class:`Stat` to be displayed on a
- :class:`StatsWidget`.
- :param parent: Qt parent
- :param plot: the plot containing items on which we want statistics.
+ def __init__(self, plot):
+ assert isinstance(plot, PlotWidget)
+ super(_PlotWidgetWrapper, self).__init__(plot)
+ plot.sigItemAdded.connect(self.sigItemAdded.emit)
+ plot.sigItemAboutToBeRemoved.connect(self.sigItemRemoved.emit)
+ plot.sigActiveCurveChanged.connect(self._activeCurveChanged)
+ plot.sigActiveImageChanged.connect(self._activeImageChanged)
+ plot.sigActiveScatterChanged.connect(self._activeScatterChanged)
+ plot.sigPlotSignal.connect(self._limitsChanged)
+
+ def _activeChanged(self, kind):
+ """Handle change of active curve/image/scatter"""
+ plot = self.getPlot()
+ if plot is not None:
+ item = plot._getActiveItem(kind=kind)
+ if item is None or self.getKind(item) is not None:
+ self.sigCurrentChanged.emit(item)
+
+ def _activeCurveChanged(self, previous, current):
+ self._activeChanged(kind='curve')
+
+ def _activeImageChanged(self, previous, current):
+ self._activeChanged(kind='image')
+
+ def _activeScatterChanged(self, previous, current):
+ self._activeChanged(kind='scatter')
+
+ def _limitsChanged(self, event):
+ """Handle change of plot area limits."""
+ if event['event'] == 'limitsChanged':
+ self.sigVisibleDataChanged.emit()
+
+ def getItems(self):
+ plot = self.getPlot()
+ return () if plot is None else plot._getItems()
+
+ def getSelectedItems(self):
+ plot = self.getPlot()
+ items = []
+ if plot is not None:
+ for kind in plot._ACTIVE_ITEM_KINDS:
+ item = plot._getActiveItem(kind=kind)
+ if item is not None:
+ items.append(item)
+ return tuple(items)
+
+ def setCurrentItem(self, item):
+ plot = self.getPlot()
+ if plot is not None:
+ kind = self.getKind(item)
+ if kind in plot._ACTIVE_ITEM_KINDS:
+ if plot._getActiveItem(kind) != item:
+ plot._setActiveItem(kind, item.getLegend())
+
+ def getLabel(self, item):
+ return item.getLegend()
+
+ def getKind(self, item):
+ if isinstance(item, plotitems.Curve):
+ return 'curve'
+ elif isinstance(item, plotitems.ImageData):
+ return 'image'
+ elif isinstance(item, plotitems.Scatter):
+ return 'scatter'
+ elif isinstance(item, plotitems.Histogram):
+ return 'histogram'
+ else:
+ return None
+
+
+class _SceneWidgetWrapper(_Wrapper):
+ """Class handling SceneWidget specific calls and signal connections
+
+ See :class:`._Wrapper` for documentation
+
+ :param SceneWidget plot:
"""
- STATS = StatsHandler((
- (statsmdl.StatMin(), StatFormatter()),
- statsmdl.StatCoordMin(),
- (statsmdl.StatMax(), StatFormatter()),
- statsmdl.StatCoordMax(),
- (('std', numpy.std), StatFormatter()),
- (('mean', numpy.mean), StatFormatter()),
- statsmdl.StatCOM()
- ))
+ def __init__(self, plot):
+ # Lazy-import to avoid circular imports
+ from ..plot3d.SceneWidget import SceneWidget
- def __init__(self, parent=None, plot=None):
- StatsWidget.__init__(self, parent=parent, plot=plot, stats=self.STATS)
+ assert isinstance(plot, SceneWidget)
+ super(_SceneWidgetWrapper, self).__init__(plot)
+ plot.getSceneGroup().sigItemAdded.connect(self.sigItemAdded)
+ plot.getSceneGroup().sigItemRemoved.connect(self.sigItemRemoved)
+ plot.selection().sigCurrentChanged.connect(self._currentChanged)
+ # sigVisibleDataChanged is never emitted
+
+ def _currentChanged(self, current, previous):
+ self.sigCurrentChanged.emit(current)
+
+ def getItems(self):
+ plot = self.getPlot()
+ return () if plot is None else tuple(plot.getSceneGroup().visit())
+
+ def getSelectedItems(self):
+ plot = self.getPlot()
+ return () if plot is None else (plot.selection().getCurrentItem(),)
+ def setCurrentItem(self, item):
+ plot = self.getPlot()
+ if plot is not None:
+ plot.selection().setCurrentItem(item)
-class StatsTable(TableWidget):
+ def getLabel(self, item):
+ return item.getLabel()
+
+ def getKind(self, item):
+ from ..plot3d import items as plot3ditems
+
+ if isinstance(item, (plot3ditems.ImageData,
+ plot3ditems.ScalarField3D)):
+ return 'image'
+ elif isinstance(item, (plot3ditems.Scatter2D,
+ plot3ditems.Scatter3D)):
+ return 'scatter'
+ else:
+ return None
+
+
+class _ScalarFieldViewWrapper(_Wrapper):
+ """Class handling ScalarFieldView specific calls and signal connections
+
+ See :class:`._Wrapper` for documentation
+
+ :param SceneWidget plot:
"""
- TableWidget displaying for each curves contained by the Plot some
- information:
- * legend
- * minimal value
- * maximal value
- * standard deviation (std)
+ def __init__(self, plot):
+ # Lazy-import to avoid circular imports
+ from ..plot3d.ScalarFieldView import ScalarFieldView
+ from ..plot3d.items import ScalarField3D
+
+ assert isinstance(plot, ScalarFieldView)
+ super(_ScalarFieldViewWrapper, self).__init__(plot)
+ self._item = ScalarField3D()
+ self._dataChanged()
+ plot.sigDataChanged.connect(self._dataChanged)
+ # sigItemAdded, sigItemRemoved, sigVisibleDataChanged are never emitted
+
+ def _dataChanged(self):
+ plot = self.getPlot()
+ if plot is not None:
+ self._item.setData(plot.getData(copy=False), copy=False)
+ self.sigCurrentChanged.emit(self._item)
- :param parent: The widget's parent.
- :param plot: :class:`.PlotWidget` instance on which to operate
+ def getItems(self):
+ plot = self.getPlot()
+ return () if plot is None else (self._item,)
+
+ def getSelectedItems(self):
+ return self.getItems()
+
+ def setCurrentItem(self, item):
+ pass
+
+ def getLabel(self, item):
+ return 'Data'
+
+ def getKind(self, item):
+ return 'image'
+
+
+class _Container(object):
+ """Class to contain a plot item.
+
+ This is apparently needed for compatibility with PySide2,
+
+ :param QObject obj:
"""
+ def __init__(self, obj):
+ self._obj = obj
- COMPATIBLE_KINDS = {
- 'curve': CurveItem,
- 'image': ImageItem,
- 'scatter': ScatterItem,
- 'histogram': HistogramItem
- }
+ def __call__(self):
+ return self._obj
- COMPATIBLE_ITEMS = tuple(COMPATIBLE_KINDS.values())
- def __init__(self, parent=None, plot=None):
- TableWidget.__init__(self, parent)
- """Next freeID for the curve"""
- self.plot = None
- self._displayOnlyActItem = False
- self._statsOnVisibleData = False
- self._lgdAndKindToItems = {}
- """Associate to a tuple(legend, kind) the items legend"""
- self.callbackImage = None
- self.callbackScatter = None
- self.callbackCurve = None
- """Associate the curve legend to his first item"""
+class _StatsWidgetBase(object):
+ """
+ Base class for all widgets which want to display statistics
+ """
+ def __init__(self, statsOnVisibleData, displayOnlyActItem):
+ self._displayOnlyActItem = displayOnlyActItem
+ self._statsOnVisibleData = statsOnVisibleData
self._statsHandler = None
- self._legendsSet = []
- """list of legends actually displayed"""
- self._resetColumns()
- self.setColumnCount(len(self._columns))
- self.setSelectionBehavior(qt.QAbstractItemView.SelectRows)
- self.setPlot(plot)
- self.setSortingEnabled(True)
+ self.__default_skipped_events = (
+ ItemChangedType.ALPHA,
+ ItemChangedType.COLOR,
+ ItemChangedType.COLORMAP,
+ ItemChangedType.SYMBOL,
+ ItemChangedType.SYMBOL_SIZE,
+ ItemChangedType.LINE_WIDTH,
+ ItemChangedType.LINE_STYLE,
+ ItemChangedType.LINE_BG_COLOR,
+ ItemChangedType.FILL,
+ ItemChangedType.HIGHLIGHTED_COLOR,
+ ItemChangedType.HIGHLIGHTED_STYLE,
+ ItemChangedType.TEXT,
+ ItemChangedType.OVERLAY,
+ ItemChangedType.VISUALIZATION_MODE,
+ )
+
+ self._plotWrapper = _Wrapper()
+ self._dealWithPlotConnection(create=True)
- def _resetColumns(self):
- self._columns_index = OrderedDict([('legend', 0), ('kind', 1)])
- self._columns = self._columns_index.keys()
- self.setColumnCount(len(self._columns))
+ def setPlot(self, plot):
+ """Define the plot to interact with
- def setStats(self, statsHandler):
+ :param Union[PlotWidget,SceneWidget,None] plot:
+ The plot containing the items on which statistics are applied
"""
+ try:
+ import OpenGL
+ except ImportError:
+ has_opengl = False
+ else:
+ has_opengl = True
+ from ..plot3d.SceneWidget import SceneWidget # Lazy import
+ self._dealWithPlotConnection(create=False)
+ self.clear()
+ if plot is None:
+ self._plotWrapper = _Wrapper()
+ elif isinstance(plot, PlotWidget):
+ self._plotWrapper = _PlotWidgetWrapper(plot)
+ else:
+ if has_opengl is True:
+ if isinstance(plot, SceneWidget):
+ self._plotWrapper = _SceneWidgetWrapper(plot)
+ else: # Expect a ScalarFieldView
+ self._plotWrapper = _ScalarFieldViewWrapper(plot)
+ else:
+ _logger.warning('OpenGL not installed, %s not managed' % ('SceneWidget qnd ScalarFieldView'))
+ self._dealWithPlotConnection(create=True)
+
+ def setStats(self, statsHandler):
+ """Set which stats to display and the associated formatting.
- :param statsHandler: Set the statistics to be displayed and how to
- format them using
- :rtype: :class:`StatsHandler`
+ :param StatsHandler statsHandler:
+ Set the statistics to be displayed and how to format them using
"""
- _statsHandler = statsHandler
if statsHandler is None:
- _statsHandler = StatsHandler(statFormatters=())
- if isinstance(_statsHandler, (list, tuple)):
- _statsHandler = StatsHandler(_statsHandler)
- assert isinstance(_statsHandler, StatsHandler)
- self._resetColumns()
- self.clear()
-
- for statName, stat in list(_statsHandler.stats.items()):
- assert isinstance(stat, statsmdl.StatBase)
- self._columns_index[statName] = len(self._columns_index)
- self._statsHandler = _statsHandler
- self._columns = self._columns_index.keys()
- self.setColumnCount(len(self._columns))
+ statsHandler = StatsHandler(statFormatters=())
+ elif isinstance(statsHandler, (list, tuple)):
+ statsHandler = StatsHandler(statsHandler)
+ assert isinstance(statsHandler, StatsHandler)
- self._updateItemObserve()
- self._updateAllStats()
+ self._statsHandler = statsHandler
def getStatsHandler(self):
+ """Returns the :class:`StatsHandler` in use.
+
+ :rtype: StatsHandler
+ """
return self._statsHandler
- def _updateAllStats(self):
- for (legend, kind) in self._lgdAndKindToItems:
- self._updateStats(legend, kind)
+ def getPlot(self):
+ """Returns the plot attached to this widget
- @staticmethod
- def _getKind(myItem):
- if isinstance(myItem, CurveItem):
- return 'curve'
- elif isinstance(myItem, ImageItem):
- return 'image'
- elif isinstance(myItem, ScatterItem):
- return 'scatter'
- elif isinstance(myItem, HistogramItem):
- return 'histogram'
+ :rtype: Union[PlotWidget,SceneWidget,None]
+ """
+ return self._plotWrapper.getPlot()
+
+ def _dealWithPlotConnection(self, create=True):
+ """Manage connection to plot signals
+
+ Note: connection on Item are managed by _addItem and _removeItem methods
+ """
+ connections = [] # List of (signal, slot) to connect/disconnect
+ if self._statsOnVisibleData:
+ connections.append(
+ (self._plotWrapper.sigVisibleDataChanged, self._updateAllStats))
+
+ if self._displayOnlyActItem:
+ connections.append(
+ (self._plotWrapper.sigCurrentChanged, self._updateItemObserve))
else:
- return None
+ connections += [
+ (self._plotWrapper.sigItemAdded, self._addItem),
+ (self._plotWrapper.sigItemRemoved, self._removeItem),
+ (self._plotWrapper.sigCurrentChanged, self._plotCurrentChanged)]
+
+ for signal, slot in connections:
+ if create:
+ signal.connect(slot)
+ else:
+ signal.disconnect(slot)
- def setPlot(self, plot):
+ def _updateItemObserve(self, *args):
+ """Reload table depending on mode"""
+ raise NotImplementedError('Base class')
+
+ def _updateStats(self, item):
+ """Update displayed information for given plot item
+
+ :param item: The plot item
+ """
+ raise NotImplementedError('Base class')
+
+ def _updateAllStats(self):
+ """Update stats for all rows in the table"""
+ raise NotImplementedError('Base class')
+
+ def setDisplayOnlyActiveItem(self, displayOnlyActItem):
+ """Toggle display off all items or only the active/selected one
+
+ :param bool displayOnlyActItem:
+ True if we want to only show active item
"""
- Define the plot to interact with
+ self._displayOnlyActItem = displayOnlyActItem
+
+ def setStatsOnVisibleData(self, b):
+ """Toggle computation of statistics on whole data or only visible ones.
+
+ .. warning:: When visible data is activated we will process to a simple
+ filtering of visible data by the user. The filtering is a
+ simple data sub-sampling. No interpolation is made to fit
+ data to boundaries.
- :param plot: the plot containing the items on which statistics are
- applied
- :rtype: :class:`.PlotWidget`
+ :param bool b: True if we want to apply statistics only on visible data
"""
- if self.plot:
+ if self._statsOnVisibleData != b:
self._dealWithPlotConnection(create=False)
- self.plot = plot
- self.clear()
- if self.plot:
+ self._statsOnVisibleData = b
self._dealWithPlotConnection(create=True)
- self._updateItemObserve()
+ self._updateAllStats()
- def _updateItemObserve(self):
- if self.plot:
- self.clear()
- if self._displayOnlyActItem is True:
- activeCurve = self.plot.getActiveCurve(just_legend=False)
- activeScatter = self.plot._getActiveItem(kind='scatter',
- just_legend=False)
- activeImage = self.plot.getActiveImage(just_legend=False)
- if activeCurve:
- self._addItem(activeCurve)
- if activeImage:
- self._addItem(activeImage)
- if activeScatter:
- self._addItem(activeScatter)
- else:
- [self._addItem(curve) for curve in self.plot.getAllCurves()]
- [self._addItem(image) for image in self.plot.getAllImages()]
- scatters = self.plot._getItems(kind='scatter',
- just_legend=False,
- withhidden=True)
- [self._addItem(scatter) for scatter in scatters]
- histograms = self.plot._getItems(kind='histogram',
- just_legend=False,
- withhidden=True)
- [self._addItem(histogram) for histogram in histograms]
+ def _addItem(self, item):
+ """Add a plot item to the table
- def _dealWithPlotConnection(self, create=True):
+ If item is not supported, it is ignored.
+
+ :param item: The plot item
+ :returns: True if the item is added to the widget.
+ :rtype: bool
"""
- Manage connection to plot signals
+ raise NotImplementedError('Base class')
- Note: connection on Item are managed by the _removeItem function
+ def _removeItem(self, item):
+ """Remove table items corresponding to given plot item from the table.
+
+ :param item: The plot item
"""
- if self.plot is None:
- return
- if self._displayOnlyActItem:
- if create is True:
- if self.callbackImage is None:
- self.callbackImage = functools.partial(self._activeItemChanged, 'image')
- self.callbackScatter = functools.partial(self._activeItemChanged, 'scatter')
- self.callbackCurve = functools.partial(self._activeItemChanged, 'curve')
- self.plot.sigActiveImageChanged.connect(self.callbackImage)
- self.plot.sigActiveScatterChanged.connect(self.callbackScatter)
- self.plot.sigActiveCurveChanged.connect(self.callbackCurve)
- else:
- if self.callbackImage is not None:
- self.plot.sigActiveImageChanged.disconnect(self.callbackImage)
- self.plot.sigActiveScatterChanged.disconnect(self.callbackScatter)
- self.plot.sigActiveCurveChanged.disconnect(self.callbackCurve)
- self.callbackImage = None
- self.callbackScatter = None
- self.callbackCurve = None
- else:
- if create is True:
- self.plot.sigContentChanged.connect(self._plotContentChanged)
- else:
- self.plot.sigContentChanged.disconnect(self._plotContentChanged)
- if create is True:
- self.plot.sigPlotSignal.connect(self._zoomPlotChanged)
- else:
- self.plot.sigPlotSignal.disconnect(self._zoomPlotChanged)
+ raise NotImplementedError('Base class')
+
+ def _plotCurrentChanged(self, current):
+ """Handle change of current item and update selection in table
+
+ :param current:
+ """
+ raise NotImplementedError('Base class')
def clear(self):
+ """clear GUI"""
+ pass
+
+ def _skipPlotItemChangedEvent(self, event):
"""
- Clear all existing items
+
+ :param ItemChangedtype event: event to filter or not
+ :return: True if we want to ignore this ItemChangedtype
+ :rtype: bool
"""
- lgdsAndKinds = list(self._lgdAndKindToItems.keys())
- for lgdAndKind in lgdsAndKinds:
- self._removeItem(legend=lgdAndKind[0], kind=lgdAndKind[1])
- self._lgdAndKindToItems = {}
- qt.QTableWidget.clear(self)
+ return event in self.__default_skipped_events
+
+
+class StatsTable(_StatsWidgetBase, TableWidget):
+ """
+ TableWidget displaying for each curves contained by the Plot some
+ information:
+
+ * legend
+ * minimal value
+ * maximal value
+ * standard deviation (std)
+
+ :param QWidget parent: The widget's parent.
+ :param Union[PlotWidget,SceneWidget] plot:
+ :class:`PlotWidget` or :class:`SceneWidget` instance on which to operate
+ """
+
+ _LEGEND_HEADER_DATA = 'legend'
+ _KIND_HEADER_DATA = 'kind'
+
+ def __init__(self, parent=None, plot=None):
+ TableWidget.__init__(self, parent)
+ _StatsWidgetBase.__init__(self, statsOnVisibleData=False,
+ displayOnlyActItem=False)
+
+ # Init for _displayOnlyActItem == False
+ assert self._displayOnlyActItem is False
+ self.setSelectionBehavior(qt.QAbstractItemView.SelectRows)
+ self.setSelectionMode(qt.QAbstractItemView.SingleSelection)
+ self.currentItemChanged.connect(self._currentItemChanged)
+
self.setRowCount(0)
+ self.setColumnCount(2)
- # It have to called befor3e accessing to the header items
- self.setHorizontalHeaderLabels(list(self._columns))
-
- if self._statsHandler is not None:
- for columnId, name in enumerate(self._columns):
- item = self.horizontalHeaderItem(columnId)
- if name in self._statsHandler.stats:
- stat = self._statsHandler.stats[name]
- text = stat.name[0].upper() + stat.name[1:]
- if stat.description is not None:
- tooltip = stat.description
- else:
- tooltip = ""
- else:
- text = name[0].upper() + name[1:]
- tooltip = ""
- item.setToolTip(tooltip)
- item.setText(text)
+ # Init headers
+ headerItem = qt.QTableWidgetItem('Legend')
+ headerItem.setData(qt.Qt.UserRole, self._LEGEND_HEADER_DATA)
+ self.setHorizontalHeaderItem(0, headerItem)
+ headerItem = qt.QTableWidgetItem('Kind')
+ headerItem.setData(qt.Qt.UserRole, self._KIND_HEADER_DATA)
+ self.setHorizontalHeaderItem(1, headerItem)
+
+ self.setSortingEnabled(True)
+ self.setPlot(plot)
- if hasattr(self.horizontalHeader(), 'setSectionResizeMode'): # Qt5
- self.horizontalHeader().setSectionResizeMode(qt.QHeaderView.ResizeToContents)
+ @contextmanager
+ def _disableSorting(self):
+ """Context manager that disables table sorting
+
+ Previous state is restored when leaving
+ """
+ sorting = self.isSortingEnabled()
+ if sorting:
+ self.setSortingEnabled(False)
+ yield
+ if sorting:
+ self.setSortingEnabled(sorting)
+
+ def setStats(self, statsHandler):
+ """Set which stats to display and the associated formatting.
+
+ :param StatsHandler statsHandler:
+ Set the statistics to be displayed and how to format them using
+ """
+ self._removeAllItems()
+ _StatsWidgetBase.setStats(self, statsHandler)
+
+ self.setRowCount(0)
+ self.setColumnCount(len(self._statsHandler.stats) + 2) # + legend and kind
+
+ for index, stat in enumerate(self._statsHandler.stats.values()):
+ headerItem = qt.QTableWidgetItem(stat.name.capitalize())
+ headerItem.setData(qt.Qt.UserRole, stat.name)
+ if stat.description is not None:
+ headerItem.setToolTip(stat.description)
+ self.setHorizontalHeaderItem(2 + index, headerItem)
+
+ horizontalHeader = self.horizontalHeader()
+ if hasattr(horizontalHeader, 'setSectionResizeMode'): # Qt5
+ horizontalHeader.setSectionResizeMode(qt.QHeaderView.ResizeToContents)
else: # Qt4
- self.horizontalHeader().setResizeMode(qt.QHeaderView.ResizeToContents)
- self.setColumnHidden(self._columns_index['kind'], True)
+ horizontalHeader.setResizeMode(qt.QHeaderView.ResizeToContents)
- def _addItem(self, item):
- assert isinstance(item, self.COMPATIBLE_ITEMS)
- if (item.getLegend(), self._getKind(item)) in self._lgdAndKindToItems:
- self._updateStats(item.getLegend(), self._getKind(item))
- return
+ self._updateItemObserve()
+
+ def setPlot(self, plot):
+ """Define the plot to interact with
+
+ :param Union[PlotWidget,SceneWidget,None] plot:
+ The plot containing the items on which statistics are applied
+ """
+ _StatsWidgetBase.setPlot(self, plot)
+ self._updateItemObserve()
+
+ def clear(self):
+ """Define the plot to interact with
+
+ :param Union[PlotWidget,SceneWidget,None] plot:
+ The plot containing the items on which statistics are applied
+ """
+ self._removeAllItems()
+
+ def _updateItemObserve(self, *args):
+ """Reload table depending on mode"""
+ self._removeAllItems()
+
+ # Get selected or all items from the plot
+ if self._displayOnlyActItem: # Only selected
+ items = self._plotWrapper.getSelectedItems()
+ else: # All items
+ items = self._plotWrapper.getItems()
+
+ # Add items to the plot
+ for item in items:
+ self._addItem(item)
+
+ def _plotCurrentChanged(self, current):
+ """Handle change of current item and update selection in table
- self.setRowCount(self.rowCount() + 1)
- indexTable = self.rowCount() - 1
- kind = self._getKind(item)
-
- self._lgdAndKindToItems[(item.getLegend(), kind)] = {}
-
- # the get item will manage the item creation of not existing
- _createItem = self._getItem
- for itemName in self._columns:
- _createItem(name=itemName, legend=item.getLegend(), kind=kind,
- indexTable=indexTable)
-
- self._updateStats(legend=item.getLegend(), kind=kind)
-
- callback = functools.partial(
- silx.utils.weakref.WeakMethodProxy(self._updateStats),
- item.getLegend(), kind)
- item.sigItemChanged.connect(callback)
- self.setColumnHidden(self._columns_index['kind'],
- item.getLegend() not in self._legendsSet)
- self._legendsSet.append(item.getLegend())
-
- def _getItem(self, name, legend, kind, indexTable):
- if (legend, kind) not in self._lgdAndKindToItems:
- self._lgdAndKindToItems[(legend, kind)] = {}
- if not (name in self._lgdAndKindToItems[(legend, kind)] and
- self._lgdAndKindToItems[(legend, kind)]):
- if name in ('legend', 'kind'):
- _item = qt.QTableWidgetItem(type=qt.QTableWidgetItem.Type)
- if name == 'legend':
- _item.setText(legend)
+ :param current:
+ """
+ row = self._itemToRow(current)
+ if row is None:
+ if self.currentRow() >= 0:
+ self.setCurrentCell(-1, -1)
+ elif row != self.currentRow():
+ self.setCurrentCell(row, 0)
+
+ def _tableItemToItem(self, tableItem):
+ """Find the plot item corresponding to a table item
+
+ :param QTableWidgetItem tableItem:
+ :rtype: QObject
+ """
+ container = tableItem.data(qt.Qt.UserRole)
+ return container()
+
+ def _itemToRow(self, item):
+ """Find the row corresponding to a plot item
+
+ :param item: The plot item
+ :return: The corresponding row index
+ :rtype: Union[int,None]
+ """
+ for row in range(self.rowCount()):
+ tableItem = self.item(row, 0)
+ if self._tableItemToItem(tableItem) == item:
+ return row
+ return None
+
+ def _itemToTableItems(self, item):
+ """Find all table items corresponding to a plot item
+
+ :param item: The plot item
+ :return: An ordered dict of column name to QTableWidgetItem mapping
+ for the given plot item.
+ :rtype: OrderedDict
+ """
+ result = OrderedDict()
+ row = self._itemToRow(item)
+ if row is not None:
+ for column in range(self.columnCount()):
+ tableItem = self.item(row, column)
+ if self._tableItemToItem(tableItem) != item:
+ _logger.error("Table item/plot item mismatch")
else:
- assert name == 'kind'
- _item.setText(kind)
+ header = self.horizontalHeaderItem(column)
+ name = header.data(qt.Qt.UserRole)
+ result[name] = tableItem
+ return result
+
+ def _plotItemChanged(self, event):
+ """Handle modifications of the items.
+
+ :param event:
+ """
+ if self._skipPlotItemChangedEvent(event) is True:
+ return
+ else:
+ item = self.sender()
+ self._updateStats(item)
+
+ def _addItem(self, item):
+ """Add a plot item to the table
+
+ If item is not supported, it is ignored.
+
+ :param item: The plot item
+ :returns: True if the item is added to the widget.
+ :rtype: bool
+ """
+ if self._itemToRow(item) is not None:
+ _logger.info("Item already present in the table")
+ self._updateStats(item)
+ return True
+
+ kind = self._plotWrapper.getKind(item)
+ if kind not in statsmdl.BASIC_COMPATIBLE_KINDS:
+ _logger.info("Item has not a supported type: %s", item)
+ return False
+
+ # Prepare table items
+ tableItems = [
+ qt.QTableWidgetItem(), # Legend
+ qt.QTableWidgetItem()] # Kind
+
+ for column in range(2, self.columnCount()):
+ header = self.horizontalHeaderItem(column)
+ name = header.data(qt.Qt.UserRole)
+
+ formatter = self._statsHandler.formatters[name]
+ if formatter:
+ tableItem = formatter.tabWidgetItemClass()
else:
- if self._statsHandler.formatters[name]:
- _item = self._statsHandler.formatters[name].tabWidgetItemClass()
- else:
- _item = qt.QTableWidgetItem()
- tooltip = self._statsHandler.stats[name].getToolTip(kind=kind)
- if tooltip is not None:
- _item.setToolTip(tooltip)
+ tableItem = qt.QTableWidgetItem()
- _item.setFlags(qt.Qt.ItemIsEnabled | qt.Qt.ItemIsSelectable)
- self.setItem(indexTable, self._columns_index[name], _item)
- self._lgdAndKindToItems[(legend, kind)][name] = _item
+ tooltip = self._statsHandler.stats[name].getToolTip(kind=kind)
+ if tooltip is not None:
+ tableItem.setToolTip(tooltip)
- return self._lgdAndKindToItems[(legend, kind)][name]
+ tableItems.append(tableItem)
- def _removeItem(self, legend, kind):
- if (legend, kind) not in self._lgdAndKindToItems or not self.plot:
- return
+ # Disable sorting while adding table items
+ with self._disableSorting():
+ # Add a row to the table
+ self.setRowCount(self.rowCount() + 1)
- self.firstItem = self._lgdAndKindToItems[(legend, kind)]['legend']
- del self._lgdAndKindToItems[(legend, kind)]
- self.removeRow(self.firstItem.row())
- self._legendsSet.remove(legend)
- self.setColumnHidden(self._columns_index['kind'],
- legend not in self._legendsSet)
+ # Add table items to the last row
+ row = self.rowCount() - 1
+ for column, tableItem in enumerate(tableItems):
+ tableItem.setData(qt.Qt.UserRole, _Container(item))
+ tableItem.setFlags(
+ qt.Qt.ItemIsEnabled | qt.Qt.ItemIsSelectable)
+ self.setItem(row, column, tableItem)
- def _updateCurrentStats(self):
- for lgdAndKind in self._lgdAndKindToItems:
- self._updateStats(lgdAndKind[0], lgdAndKind[1])
+ # Update table items content
+ self._updateStats(item)
- def _updateStats(self, legend, kind, event=None):
- if self._statsHandler is None:
+ # Listen for item changes
+ # Using queued connection to avoid issue with sender
+ # being that of the signal calling the signal
+ item.sigItemChanged.connect(self._plotItemChanged,
+ qt.Qt.QueuedConnection)
+
+ return True
+
+ def _removeItem(self, item):
+ """Remove table items corresponding to given plot item from the table.
+
+ :param item: The plot item
+ """
+ row = self._itemToRow(item)
+ if row is None:
+ kind = self._plotWrapper.getKind(item)
+ if kind in statsmdl.BASIC_COMPATIBLE_KINDS:
+ _logger.error("Removing item that is not in table: %s", str(item))
return
+ item.sigItemChanged.disconnect(self._plotItemChanged)
+ self.removeRow(row)
+
+ def _removeAllItems(self):
+ """Remove content of the table"""
+ for row in range(self.rowCount()):
+ tableItem = self.item(row, 0)
+ item = self._tableItemToItem(tableItem)
+ item.sigItemChanged.disconnect(self._plotItemChanged)
+ self.clearContents()
+ self.setRowCount(0)
- assert kind in ('curve', 'image', 'scatter', 'histogram')
- if kind == 'curve':
- item = self.plot.getCurve(legend)
- elif kind == 'image':
- item = self.plot.getImage(legend)
- elif kind == 'scatter':
- item = self.plot.getScatter(legend)
- elif kind == 'histogram':
- item = self.plot.getHistogram(legend)
- else:
- raise ValueError('kind not managed')
+ def _updateStats(self, item):
+ """Update displayed information for given plot item
- if not item or (item.getLegend(), kind) not in self._lgdAndKindToItems:
+ :param item: The plot item
+ """
+ if item is None:
+ return
+ plot = self.getPlot()
+ if plot is None:
+ _logger.info("Plot not available")
return
- assert isinstance(item, self.COMPATIBLE_ITEMS)
-
- statsValDict = self._statsHandler.calculate(item, self.plot,
- self._statsOnVisibleData)
-
- lgdItem = self._lgdAndKindToItems[(item.getLegend(), kind)]['legend']
- assert lgdItem
- rowStat = lgdItem.row()
-
- for statName, statVal in list(statsValDict.items()):
- assert statName in self._lgdAndKindToItems[(item.getLegend(), kind)]
- tableItem = self._getItem(name=statName, legend=item.getLegend(),
- kind=kind, indexTable=rowStat)
- tableItem.setText(str(statVal))
-
- def currentChanged(self, current, previous):
- if current.row() >= 0:
- legendItem = self.item(current.row(), self._columns_index['legend'])
- assert legendItem
- kindItem = self.item(current.row(), self._columns_index['kind'])
- kind = kindItem.text()
- if kind == 'curve':
- self.plot.setActiveCurve(legendItem.text())
- elif kind == 'image':
- self.plot.setActiveImage(legendItem.text())
- elif kind == 'scatter':
- self.plot._setActiveItem('scatter', legendItem.text())
- elif kind == 'histogram':
- # active histogram not managed by the plot actually
- pass
- else:
- raise ValueError('kind not managed')
- qt.QTableWidget.currentChanged(self, current, previous)
+ row = self._itemToRow(item)
+ if row is None:
+ _logger.error("This item is not in the table: %s", str(item))
+ return
- def setDisplayOnlyActiveItem(self, displayOnlyActItem):
+ statsHandler = self.getStatsHandler()
+ if statsHandler is not None:
+ stats = statsHandler.calculate(
+ item, plot, self._statsOnVisibleData)
+ else:
+ stats = {}
+
+ with self._disableSorting():
+ for name, tableItem in self._itemToTableItems(item).items():
+ if name == self._LEGEND_HEADER_DATA:
+ text = self._plotWrapper.getLabel(item)
+ tableItem.setText(text)
+ elif name == self._KIND_HEADER_DATA:
+ tableItem.setText(self._plotWrapper.getKind(item))
+ else:
+ value = stats.get(name)
+ if value is None:
+ _logger.error("Value not found for: %s", name)
+ tableItem.setText('-')
+ else:
+ tableItem.setText(str(value))
+
+ def _updateAllStats(self):
+ """Update stats for all rows in the table"""
+ with self._disableSorting():
+ for row in range(self.rowCount()):
+ tableItem = self.item(row, 0)
+ item = self._tableItemToItem(tableItem)
+ self._updateStats(item)
+
+ def _currentItemChanged(self, current, previous):
+ """Handle change of selection in table and sync plot selection
+
+ :param QTableWidgetItem current:
+ :param QTableWidgetItem previous:
"""
+ if current and current.row() >= 0:
+ item = self._tableItemToItem(current)
+ self._plotWrapper.setCurrentItem(item)
- :param bool displayOnlyActItem: True if we want to only show active
- item
+ def setDisplayOnlyActiveItem(self, displayOnlyActItem):
+ """Toggle display off all items or only the active/selected one
+
+ :param bool displayOnlyActItem:
+ True if we want to only show active item
"""
if self._displayOnlyActItem == displayOnlyActItem:
return
- self._displayOnlyActItem = displayOnlyActItem
self._dealWithPlotConnection(create=False)
+ if not self._displayOnlyActItem:
+ self.currentItemChanged.disconnect(self._currentItemChanged)
+
+ _StatsWidgetBase.setDisplayOnlyActiveItem(self, displayOnlyActItem)
+
self._updateItemObserve()
self._dealWithPlotConnection(create=True)
+ if not self._displayOnlyActItem:
+ self.currentItemChanged.connect(self._currentItemChanged)
+ self.setSelectionMode(qt.QAbstractItemView.SingleSelection)
+ else:
+ self.setSelectionMode(qt.QAbstractItemView.NoSelection)
+
+
+class _OptionsWidget(qt.QToolBar):
+
+ def __init__(self, parent=None):
+ qt.QToolBar.__init__(self, parent)
+ self.setIconSize(qt.QSize(16, 16))
+
+ action = qt.QAction(self)
+ action.setIcon(icons.getQIcon("stats-active-items"))
+ action.setText("Active items only")
+ action.setToolTip("Display stats for active items only.")
+ action.setCheckable(True)
+ action.setChecked(True)
+ self.__displayActiveItems = action
+
+ action = qt.QAction(self)
+ action.setIcon(icons.getQIcon("stats-whole-items"))
+ action.setText("All items")
+ action.setToolTip("Display stats for all available items.")
+ action.setCheckable(True)
+ self.__displayWholeItems = action
+
+ action = qt.QAction(self)
+ action.setIcon(icons.getQIcon("stats-visible-data"))
+ action.setText("Use the visible data range")
+ action.setToolTip("Use the visible data range.<br/>"
+ "If activated the data is filtered to only use"
+ "visible data of the plot."
+ "The filtering is a data sub-sampling."
+ "No interpolation is made to fit data to"
+ "boundaries.")
+ action.setCheckable(True)
+ self.__useVisibleData = action
+
+ action = qt.QAction(self)
+ action.setIcon(icons.getQIcon("stats-whole-data"))
+ action.setText("Use the full data range")
+ action.setToolTip("Use the full data range.")
+ action.setCheckable(True)
+ action.setChecked(True)
+ self.__useWholeData = action
+
+ self.addAction(self.__displayWholeItems)
+ self.addAction(self.__displayActiveItems)
+ self.addSeparator()
+ self.addAction(self.__useVisibleData)
+ self.addAction(self.__useWholeData)
+
+ self.itemSelection = qt.QActionGroup(self)
+ self.itemSelection.setExclusive(True)
+ self.itemSelection.addAction(self.__displayActiveItems)
+ self.itemSelection.addAction(self.__displayWholeItems)
+
+ self.dataRangeSelection = qt.QActionGroup(self)
+ self.dataRangeSelection.setExclusive(True)
+ self.dataRangeSelection.addAction(self.__useWholeData)
+ self.dataRangeSelection.addAction(self.__useVisibleData)
+
+ def isActiveItemMode(self):
+ return self.itemSelection.checkedAction() is self.__displayActiveItems
+
+ def isVisibleDataRangeMode(self):
+ return self.dataRangeSelection.checkedAction() is self.__useVisibleData
+
+ def setVisibleDataRangeModeEnabled(self, enabled):
+ """Enable/Disable the visible data range mode
+
+ :param bool enabled: True to allow user to choose
+ stats on visible data
+ """
+ self.__useVisibleData.setEnabled(enabled)
+ if not enabled:
+ self.__useWholeData.setChecked(True)
+
+
+class StatsWidget(qt.QWidget):
+ """
+ Widget displaying a set of :class:`Stat` to be displayed on a
+ :class:`StatsTable` and to be apply on items contained in the :class:`Plot`
+ Also contains options to:
+
+ * compute statistics on all the data or on visible data only
+ * show statistics of all items or only the active one
+
+ :param QWidget parent: Qt parent
+ :param Union[PlotWidget,SceneWidget] plot:
+ The plot containing items on which we want statistics.
+ :param StatsHandler stats:
+ Set the statistics to be displayed and how to format them using
+ """
+
+ sigVisibilityChanged = qt.Signal(bool)
+ """Signal emitted when the visibility of this widget changes.
+
+ It Provides the visibility of the widget.
+ """
+
+ NUMBER_FORMAT = '{0:.3f}'
+
+ def __init__(self, parent=None, plot=None, stats=None):
+ qt.QWidget.__init__(self, parent)
+ self.setLayout(qt.QVBoxLayout())
+ self.layout().setContentsMargins(0, 0, 0, 0)
+ self._options = _OptionsWidget(parent=self)
+ self.layout().addWidget(self._options)
+ self._statsTable = StatsTable(parent=self, plot=plot)
+ self.setStats(stats)
+
+ self.layout().addWidget(self._statsTable)
+
+ self._options.itemSelection.triggered.connect(
+ self._optSelectionChanged)
+ self._options.dataRangeSelection.triggered.connect(
+ self._optDataRangeChanged)
+ self._optSelectionChanged()
+ self._optDataRangeChanged()
+
+ def _getStatsTable(self):
+ """Returns the :class:`StatsTable` used by this widget.
+
+ :rtype: StatsTable
+ """
+ return self._statsTable
+
+ def showEvent(self, event):
+ self.sigVisibilityChanged.emit(True)
+ qt.QWidget.showEvent(self, event)
+
+ def hideEvent(self, event):
+ self.sigVisibilityChanged.emit(False)
+ qt.QWidget.hideEvent(self, event)
+
+ def _optSelectionChanged(self, action=None):
+ self._getStatsTable().setDisplayOnlyActiveItem(
+ self._options.isActiveItemMode())
+
+ def _optDataRangeChanged(self, action=None):
+ self._getStatsTable().setStatsOnVisibleData(
+ self._options.isVisibleDataRangeMode())
+
+ # Proxy methods
+
+ def setStats(self, statsHandler):
+ return self._getStatsTable().setStats(statsHandler=statsHandler)
+
+ setStats.__doc__ = StatsTable.setStats.__doc__
+
+ def setPlot(self, plot):
+ self._options.setVisibleDataRangeModeEnabled(
+ plot is None or isinstance(plot, PlotWidget))
+ return self._getStatsTable().setPlot(plot=plot)
+
+ setPlot.__doc__ = StatsTable.setPlot.__doc__
+
+ def getPlot(self):
+ return self._getStatsTable().getPlot()
+
+ getPlot.__doc__ = StatsTable.getPlot.__doc__
+
+ def setDisplayOnlyActiveItem(self, displayOnlyActItem):
+ return self._getStatsTable().setDisplayOnlyActiveItem(
+ displayOnlyActItem=displayOnlyActItem)
+
+ setDisplayOnlyActiveItem.__doc__ = StatsTable.setDisplayOnlyActiveItem.__doc__
+
def setStatsOnVisibleData(self, b):
+ return self._getStatsTable().setStatsOnVisibleData(b=b)
+
+ setStatsOnVisibleData.__doc__ = StatsTable.setStatsOnVisibleData.__doc__
+
+
+DEFAULT_STATS = StatsHandler((
+ (statsmdl.StatMin(), StatFormatter()),
+ statsmdl.StatCoordMin(),
+ (statsmdl.StatMax(), StatFormatter()),
+ statsmdl.StatCoordMax(),
+ statsmdl.StatCOM(),
+ (('mean', numpy.mean), StatFormatter()),
+ (('std', numpy.std), StatFormatter()),
+))
+
+
+class BasicStatsWidget(StatsWidget):
+ """
+ Widget defining a simple set of :class:`Stat` to be displayed on a
+ :class:`StatsWidget`.
+
+ :param QWidget parent: Qt parent
+ :param PlotWidget plot:
+ The plot containing items on which we want statistics.
+ :param StatsHandler stats:
+ Set the statistics to be displayed and how to format them using
+
+ .. snapshotqt:: img/BasicStatsWidget.png
+ :width: 300px
+ :align: center
+
+ from silx.gui.plot import Plot1D
+ from silx.gui.plot.StatsWidget import BasicStatsWidget
+
+ plot = Plot1D()
+ x = range(100)
+ y = x
+ plot.addCurve(x, y, legend='curve_0')
+ plot.setActiveCurve('curve_0')
+
+ widget = BasicStatsWidget(plot=plot)
+ widget.show()
+ """
+ def __init__(self, parent=None, plot=None):
+ StatsWidget.__init__(self, parent=parent, plot=plot,
+ stats=DEFAULT_STATS)
+
+
+class _BaseLineStatsWidget(_StatsWidgetBase, qt.QWidget):
+ """
+ Widget made to display stats into a QLayout with for all stat a couple
+ (QLabel, QLineEdit) created.
+ The the layout can be defined prior of adding any statistic.
+
+ :param QWidget parent: Qt parent
+ :param Union[PlotWidget,SceneWidget] plot:
+ The plot containing items on which we want statistics.
+ :param str kind: the kind of plotitems we want to display
+ :param StatsHandler stats:
+ Set the statistics to be displayed and how to format them using
+ :param bool statsOnVisibleData: compute statistics for the whole data or
+ only visible ones.
+ """
+
+ def __init__(self, parent=None, plot=None, kind='curve', stats=None,
+ statsOnVisibleData=False):
+ self._item_kind = kind
+ """The item displayed"""
+ self._statQlineEdit = {}
+ """list of legends actually displayed"""
+ self._n_statistics_per_line = 4
+ """number of statistics displayed per line in the grid layout"""
+ qt.QWidget.__init__(self, parent)
+ _StatsWidgetBase.__init__(self,
+ statsOnVisibleData=statsOnVisibleData,
+ displayOnlyActItem=True)
+ self.setLayout(self._createLayout())
+ self.setPlot(plot)
+ if stats is not None:
+ self.setStats(stats)
+
+ def _addItemForStatistic(self, statistic):
+ assert isinstance(statistic, statsmdl.StatBase)
+ assert statistic.name in self._statsHandler.stats
+
+ self.layout().setSpacing(2)
+ self.layout().setContentsMargins(2, 2, 2, 2)
+
+ if isinstance(self.layout(), qt.QGridLayout):
+ parent = self
+ else:
+ widget = qt.QWidget(parent=self)
+ parent = widget
+
+ qLabel = qt.QLabel(statistic.name + ':', parent=parent)
+ qLineEdit = qt.QLineEdit('', parent=parent)
+ qLineEdit.setReadOnly(True)
+
+ self._addStatsWidgetsToLayout(qLabel=qLabel, qLineEdit=qLineEdit)
+ self._statQlineEdit[statistic.name] = qLineEdit
+
+ def setPlot(self, plot):
+ """Define the plot to interact with
+
+ :param Union[PlotWidget,SceneWidget,None] plot:
+ The plot containing the items on which statistics are applied
"""
- .. warning:: When visible data is activated we will process to a simple
- filtering of visible data by the user. The filtering is a
- simple data sub-sampling. No interpolation is made to fit
- data to boundaries.
+ _StatsWidgetBase.setPlot(self, plot)
+ self._updateAllStats()
- :param bool b: True if we want to apply statistics only on visible data
+ def _addStatsWidgetsToLayout(self, qLabel, qLineEdit):
+ raise NotImplementedError('Base class')
+
+ def setStats(self, statsHandler):
+ """Set which stats to display and the associated formatting.
+ :param StatsHandler statsHandler:
+ Set the statistics to be displayed and how to format them using
"""
- if self._statsOnVisibleData != b:
- self._statsOnVisibleData = b
- self._updateCurrentStats()
+ _StatsWidgetBase.setStats(self, statsHandler)
+ for statName, stat in list(self._statsHandler.stats.items()):
+ self._addItemForStatistic(stat)
+ self._updateAllStats()
def _activeItemChanged(self, kind, previous, current):
- """Callback used when plotting only the active item"""
- assert kind in ('curve', 'image', 'scatter', 'histogram')
- self._updateItemObserve()
+ if kind == self._item_kind:
+ self._updateAllStats()
- def _plotContentChanged(self, action, kind, legend):
- """Callback used when plotting all the plot items"""
- if kind not in ('curve', 'image', 'scatter', 'histogram'):
- return
- if kind == 'curve':
- item = self.plot.getCurve(legend)
- elif kind == 'image':
- item = self.plot.getImage(legend)
- elif kind == 'scatter':
- item = self.plot.getScatter(legend)
- elif kind == 'histogram':
- item = self.plot.getHistogram(legend)
- else:
- raise ValueError('kind not managed')
+ def _updateAllStats(self):
+ plot = self.getPlot()
+ if plot is not None:
+ _items = self._plotWrapper.getSelectedItems()
+ def kind_filter(_item):
+ return self._plotWrapper.getKind(_item) == self.getKind()
+
+ items = list(filter(kind_filter, _items))
+ assert len(items) in (0, 1)
+ if len(items) is 1:
+ self._setItem(items[0])
+
+ def setKind(self, kind):
+ """Change the kind of active item to display
+ :param str kind: kind of item to display information for ('curve' ...)
+ """
+ if self._item_kind != kind:
+ self._item_kind = kind
+ self._updateItemObserve()
- if action == 'add':
- if item is None:
- raise ValueError('Item from legend "%s" do not exists' % legend)
- self._addItem(item)
- elif action == 'remove':
- self._removeItem(legend, kind)
+ def getKind(self):
+ """
+ :return: kind of item we want to compute statistic for
+ :rtype: str
+ """
+ return self._item_kind
+
+ def _setItem(self, item):
+ if item is None:
+ for stat_name, stat_widget in self._statQlineEdit.items():
+ stat_widget.setText('')
+ elif (self._statsHandler is not None and len(
+ self._statsHandler.stats) > 0):
+ plot = self.getPlot()
+ if plot is not None:
+ statsValDict = self._statsHandler.calculate(item,
+ plot,
+ self._statsOnVisibleData)
+ for statName, statVal in list(statsValDict.items()):
+ self._statQlineEdit[statName].setText(statVal)
+
+ def _updateItemObserve(self, *argv):
+ assert self._displayOnlyActItem
+ _items = self._plotWrapper.getSelectedItems()
+ def kind_filter(_item):
+ return self._plotWrapper.getKind(_item) == self.getKind()
+ items = list(filter(kind_filter, _items))
+ assert len(items) in (0, 1)
+ _item = items[0] if len(items) is 1 else None
+ self._setItem(_item)
+
+ def _createLayout(self):
+ """create an instance of the main QLayout"""
+ raise NotImplementedError('Base class')
+
+ def _addItem(self, item):
+ raise NotImplementedError('Display only the active item')
+
+ def _removeItem(self, item):
+ raise NotImplementedError('Display only the active item')
+
+ def _plotCurrentChanged(selfself, current):
+ raise NotImplementedError('Display only the active item')
+
+
+class BasicLineStatsWidget(_BaseLineStatsWidget):
+ """
+ Widget defining a simple set of :class:`Stat` to be displayed on a
+ :class:`LineStatsWidget`.
+
+ :param QWidget parent: Qt parent
+ :param Union[PlotWidget,SceneWidget] plot:
+ The plot containing items on which we want statistics.
+ :param str kind: the kind of plotitems we want to display
+ :param StatsHandler stats:
+ Set the statistics to be displayed and how to format them using
+ :param bool statsOnVisibleData: compute statistics for the whole data or
+ only visible ones.
+ """
+
+ def __init__(self, parent=None, plot=None, kind='curve',
+ stats=DEFAULT_STATS, statsOnVisibleData=False):
+ _BaseLineStatsWidget.__init__(self, parent=parent, kind=kind,
+ plot=plot, stats=stats,
+ statsOnVisibleData=statsOnVisibleData)
+
+ def _createLayout(self):
+ return FlowLayout()
+
+ def _addStatsWidgetsToLayout(self, qLabel, qLineEdit):
+ # create a mother widget to make sure both qLabel & qLineEdit will
+ # always be displayed side by side
+ widget = qt.QWidget(parent=self)
+ widget.setLayout(qt.QHBoxLayout())
+ widget.layout().setSpacing(0)
+ widget.layout().setContentsMargins(0, 0, 0, 0)
+
+ widget.layout().addWidget(qLabel)
+ widget.layout().addWidget(qLineEdit)
+
+ self.layout().addWidget(widget)
+
+
+class BasicGridStatsWidget(_BaseLineStatsWidget):
+ """
+ pymca design like widget
+
+ :param QWidget parent: Qt parent
+ :param Union[PlotWidget,SceneWidget] plot:
+ The plot containing items on which we want statistics.
+ :param StatsHandler stats:
+ Set the statistics to be displayed and how to format them using
+ :param str kind: the kind of plotitems we want to display
+ :param bool statsOnVisibleData: compute statistics for the whole data or
+ only visible ones.
+ :param int statsPerLine: number of statistic to be displayed per line
+
+ .. snapshotqt:: img/BasicGridStatsWidget.png
+ :width: 600px
+ :align: center
+
+ from silx.gui.plot import Plot1D
+ from silx.gui.plot.StatsWidget import BasicGridStatsWidget
+
+ plot = Plot1D()
+ x = range(100)
+ y = x
+ plot.addCurve(x, y, legend='curve_0')
+ plot.setActiveCurve('curve_0')
+
+ widget = BasicGridStatsWidget(plot=plot, kind='curve')
+ widget.show()
+ """
- def _zoomPlotChanged(self, event):
- if self._statsOnVisibleData is True:
- if 'event' in event and event['event'] == 'limitsChanged':
- self._updateCurrentStats()
+ def __init__(self, parent=None, plot=None, kind='curve',
+ stats=DEFAULT_STATS, statsOnVisibleData=False,
+ statsPerLine=4):
+ _BaseLineStatsWidget.__init__(self, parent=parent, kind=kind,
+ plot=plot, stats=stats,
+ statsOnVisibleData=statsOnVisibleData)
+ self._n_statistics_per_line = statsPerLine
+
+ def _addStatsWidgetsToLayout(self, qLabel, qLineEdit):
+ column = len(self._statQlineEdit) % self._n_statistics_per_line
+ row = len(self._statQlineEdit) // self._n_statistics_per_line
+ self.layout().addWidget(qLabel, row, column * 2)
+ self.layout().addWidget(qLineEdit, row, column * 2 + 1)
+
+ def _createLayout(self):
+ return qt.QGridLayout()
diff --git a/silx/gui/plot/_BaseMaskToolsWidget.py b/silx/gui/plot/_BaseMaskToolsWidget.py
index e087354..0d11f17 100644
--- a/silx/gui/plot/_BaseMaskToolsWidget.py
+++ b/silx/gui/plot/_BaseMaskToolsWidget.py
@@ -29,7 +29,7 @@ from __future__ import division
__authors__ = ["T. Vincent", "P. Knobel"]
__license__ = "MIT"
-__date__ = "29/08/2018"
+__date__ = "15/02/2019"
import os
import weakref
@@ -141,7 +141,7 @@ class BaseMask(qt.QObject):
def commit(self):
"""Append the current mask to history if changed"""
if (not self._history or self._redo or
- not numpy.all(numpy.equal(self._mask, self._history[-1]))):
+ not numpy.array_equal(self._mask, self._history[-1])):
if self._redo:
self._redo = [] # Reset redo as a new action as been performed
self.sigRedoable[bool].emit(False)
@@ -325,7 +325,7 @@ class BaseMask(qt.QObject):
raise NotImplementedError("To be implemented in subclass")
def updateDisk(self, level, crow, ccol, radius, mask=True):
- """Mask/Unmask data located inside a disk of the given mask level.
+ """Mask/Unmask data located inside a dick of the given mask level.
:param int level: Mask level to update.
:param crow: Disk center row/ordinate (y).
@@ -335,6 +335,18 @@ class BaseMask(qt.QObject):
"""
raise NotImplementedError("To be implemented in subclass")
+ def updateEllipse(self, level, crow, ccol, radius_r, radius_c, mask=True):
+ """Mask/Unmask a disk of the given mask level.
+
+ :param int level: Mask level to update.
+ :param int crow: Row of the center of the ellipse
+ :param int ccol: Column of the center of the ellipse
+ :param float radius_r: Radius of the ellipse in the row
+ :param float radius_c: Radius of the ellipse in the column
+ :param bool mask: True to mask (default), False to unmask.
+ """
+ raise NotImplementedError("To be implemented in subclass")
+
def updateLine(self, level, row0, col0, row1, col1, width, mask=True):
"""Mask/Unmask a line of the given mask level.
@@ -376,13 +388,11 @@ class BaseMaskToolsWidget(qt.QWidget):
self._plotRef = weakref.ref(plot)
self._maskName = '__MASK_TOOLS_%d' % id(self) # Legend of the mask
- self._colormap = Colormap(name="",
- normalization='linear',
+ self._colormap = Colormap(normalization='linear',
vmin=0,
- vmax=self._maxLevelNumber,
- colors=None)
+ vmax=self._maxLevelNumber)
self._defaultOverlayColor = rgba('gray') # Color of the mask
- self._setMaskColors(1, 0.5)
+ self._setMaskColors(1, 0.5) # Set the colormap LUT
if not isinstance(mask, BaseMask):
raise TypeError("mask is not an instance of BaseMask")
@@ -482,6 +492,7 @@ class BaseMaskToolsWidget(qt.QWidget):
layout.addWidget(self._initMaskGroupBox())
layout.addWidget(self._initDrawGroupBox())
layout.addWidget(self._initThresholdGroupBox())
+ layout.addWidget(self._initOtherToolsGroupBox())
layout.addStretch(1)
self.setLayout(layout)
@@ -617,6 +628,15 @@ class BaseMaskToolsWidget(qt.QWidget):
self.rectAction.triggered.connect(self._activeRectMode)
self.addAction(self.rectAction)
+ self.ellipseAction = qt.QAction(
+ icons.getQIcon('shape-ellipse'), 'Circle selection', None)
+ self.ellipseAction.setToolTip(
+ 'Rectangle selection tool: (Un)Mask a circle region <b>R</b>')
+ self.ellipseAction.setShortcut(qt.QKeySequence(qt.Qt.Key_R))
+ self.ellipseAction.setCheckable(True)
+ self.ellipseAction.triggered.connect(self._activeEllipseMode)
+ self.addAction(self.ellipseAction)
+
self.polygonAction = qt.QAction(
icons.getQIcon('shape-polygon'), 'Polygon selection', None)
self.polygonAction.setShortcut(qt.QKeySequence(qt.Qt.Key_S))
@@ -640,10 +660,11 @@ class BaseMaskToolsWidget(qt.QWidget):
self.drawActionGroup = qt.QActionGroup(self)
self.drawActionGroup.setExclusive(True)
self.drawActionGroup.addAction(self.rectAction)
+ self.drawActionGroup.addAction(self.ellipseAction)
self.drawActionGroup.addAction(self.polygonAction)
self.drawActionGroup.addAction(self.pencilAction)
- actions = (self.browseAction, self.rectAction,
+ actions = (self.browseAction, self.rectAction, self.ellipseAction,
self.polygonAction, self.pencilAction)
drawButtons = []
for action in actions:
@@ -711,36 +732,28 @@ class BaseMaskToolsWidget(qt.QWidget):
def _initThresholdGroupBox(self):
"""Init thresholding widgets"""
- layout = qt.QVBoxLayout()
-
- # Thresholing
self.belowThresholdAction = qt.QAction(
icons.getQIcon('plot-roi-below'), 'Mask below threshold', None)
self.belowThresholdAction.setToolTip(
'Mask image where values are below given threshold')
self.belowThresholdAction.setCheckable(True)
- self.belowThresholdAction.triggered[bool].connect(
- self._belowThresholdActionTriggered)
+ self.belowThresholdAction.setChecked(True)
self.betweenThresholdAction = qt.QAction(
icons.getQIcon('plot-roi-between'), 'Mask within range', None)
self.betweenThresholdAction.setToolTip(
'Mask image where values are within given range')
self.betweenThresholdAction.setCheckable(True)
- self.betweenThresholdAction.triggered[bool].connect(
- self._betweenThresholdActionTriggered)
self.aboveThresholdAction = qt.QAction(
icons.getQIcon('plot-roi-above'), 'Mask above threshold', None)
self.aboveThresholdAction.setToolTip(
'Mask image where values are above given threshold')
self.aboveThresholdAction.setCheckable(True)
- self.aboveThresholdAction.triggered[bool].connect(
- self._aboveThresholdActionTriggered)
self.thresholdActionGroup = qt.QActionGroup(self)
- self.thresholdActionGroup.setExclusive(False)
+ self.thresholdActionGroup.setExclusive(True)
self.thresholdActionGroup.addAction(self.belowThresholdAction)
self.thresholdActionGroup.addAction(self.betweenThresholdAction)
self.thresholdActionGroup.addAction(self.aboveThresholdAction)
@@ -770,41 +783,50 @@ class BaseMaskToolsWidget(qt.QWidget):
loadColormapRangeBtn.setDefaultAction(self.loadColormapRangeAction)
widgets.append(loadColormapRangeBtn)
- container = self._hboxWidget(*widgets, stretch=False)
- layout.addWidget(container)
+ toolBar = self._hboxWidget(*widgets, stretch=False)
- form = qt.QFormLayout()
+ config = qt.QGridLayout()
+ config.setContentsMargins(0, 0, 0, 0)
+ self.minLineLabel = qt.QLabel("Min:", self)
self.minLineEdit = FloatEdit(self, value=0)
- self.minLineEdit.setEnabled(False)
- form.addRow('Min:', self.minLineEdit)
+ config.addWidget(self.minLineLabel, 0, 0)
+ config.addWidget(self.minLineEdit, 0, 1)
+ self.maxLineLabel = qt.QLabel("Max:", self)
self.maxLineEdit = FloatEdit(self, value=0)
- self.maxLineEdit.setEnabled(False)
- form.addRow('Max:', self.maxLineEdit)
+ config.addWidget(self.maxLineLabel, 1, 0)
+ config.addWidget(self.maxLineEdit, 1, 1)
self.applyMaskBtn = qt.QPushButton('Apply mask')
self.applyMaskBtn.clicked.connect(self._maskBtnClicked)
- self.applyMaskBtn.setEnabled(False)
- form.addRow(self.applyMaskBtn)
-
- self.maskNanBtn = qt.QPushButton('Mask not finite values')
- self.maskNanBtn.setToolTip('Mask Not a Number and infinite values')
- self.maskNanBtn.clicked.connect(self._maskNotFiniteBtnClicked)
- form.addRow(self.maskNanBtn)
- thresholdWidget = qt.QWidget()
- thresholdWidget.setLayout(form)
- layout.addWidget(thresholdWidget)
-
- layout.addStretch(1)
+ layout = qt.QVBoxLayout()
+ layout.addWidget(toolBar)
+ layout.addLayout(config)
+ layout.addWidget(self.applyMaskBtn)
self.thresholdGroup = qt.QGroupBox('Threshold')
self.thresholdGroup.setLayout(layout)
+
+ # Init widget state
+ self._thresholdActionGroupTriggered(self.belowThresholdAction)
return self.thresholdGroup
# track widget visibility and plot active image changes
+ def _initOtherToolsGroupBox(self):
+ layout = qt.QVBoxLayout()
+
+ self.maskNanBtn = qt.QPushButton('Mask not finite values')
+ self.maskNanBtn.setToolTip('Mask Not a Number and infinite values')
+ self.maskNanBtn.clicked.connect(self._maskNotFiniteBtnClicked)
+ layout.addWidget(self.maskNanBtn)
+
+ self.otherToolGroup = qt.QGroupBox('Other tools')
+ self.otherToolGroup.setLayout(layout)
+ return self.otherToolGroup
+
def changeEvent(self, event):
"""Reset drawing action when disabling widget"""
if (event.type() == qt.QEvent.EnabledChange and
@@ -883,6 +905,7 @@ class BaseMaskToolsWidget(qt.QWidget):
The index of the mask for which we want to change the color.
If none set this color for all the masks
"""
+ rgb = rgba(rgb)[0:3]
if level is None:
self._overlayColors[:] = rgb
self._defaultColors[:] = False
@@ -925,6 +948,8 @@ class BaseMaskToolsWidget(qt.QWidget):
"""
if self._drawingMode == 'rectangle':
self._activeRectMode()
+ elif self._drawingMode == 'ellipse':
+ self._activeEllipseMode()
elif self._drawingMode == 'polygon':
self._activePolygonMode()
elif self._drawingMode == 'pencil':
@@ -971,6 +996,16 @@ class BaseMaskToolsWidget(qt.QWidget):
'draw', shape='rectangle', source=self, color=color)
self._updateDrawingModeWidgets()
+ def _activeEllipseMode(self):
+ """Handle circle action mode triggering"""
+ self._releaseDrawingMode()
+ self._drawingMode = 'ellipse'
+ self.plot.sigPlotSignal.connect(self._plotDrawEvent)
+ color = self.getCurrentMaskColor()
+ self.plot.setInteractiveMode(
+ 'draw', shape='ellipse', source=self, color=color)
+ self._updateDrawingModeWidgets()
+
def _activePolygonMode(self):
"""Handle polygon action mode triggering"""
self._releaseDrawingMode()
@@ -1016,36 +1051,28 @@ class BaseMaskToolsWidget(qt.QWidget):
return doMask
# Handle threshold UI events
- def _belowThresholdActionTriggered(self, triggered):
- if triggered:
- self.minLineEdit.setEnabled(True)
- self.maxLineEdit.setEnabled(False)
- self.applyMaskBtn.setEnabled(True)
-
- def _betweenThresholdActionTriggered(self, triggered):
- if triggered:
- self.minLineEdit.setEnabled(True)
- self.maxLineEdit.setEnabled(True)
- self.applyMaskBtn.setEnabled(True)
-
- def _aboveThresholdActionTriggered(self, triggered):
- if triggered:
- self.minLineEdit.setEnabled(False)
- self.maxLineEdit.setEnabled(True)
- self.applyMaskBtn.setEnabled(True)
def _thresholdActionGroupTriggered(self, triggeredAction):
"""Threshold action group listener."""
- if triggeredAction.isChecked():
- # Uncheck other actions
- for action in self.thresholdActionGroup.actions():
- if action is not triggeredAction and action.isChecked():
- action.setChecked(False)
- else:
- # Disable min/max edit
- self.minLineEdit.setEnabled(False)
- self.maxLineEdit.setEnabled(False)
- self.applyMaskBtn.setEnabled(False)
+ if triggeredAction is self.belowThresholdAction:
+ self.minLineLabel.setVisible(True)
+ self.maxLineLabel.setVisible(False)
+ self.minLineEdit.setVisible(True)
+ self.maxLineEdit.setVisible(False)
+ self.applyMaskBtn.setText("Mask bellow")
+ elif triggeredAction is self.betweenThresholdAction:
+ self.minLineLabel.setVisible(True)
+ self.maxLineLabel.setVisible(True)
+ self.minLineEdit.setVisible(True)
+ self.maxLineEdit.setVisible(True)
+ self.applyMaskBtn.setText("Mask between")
+ elif triggeredAction is self.aboveThresholdAction:
+ self.minLineLabel.setVisible(False)
+ self.maxLineLabel.setVisible(True)
+ self.minLineEdit.setVisible(False)
+ self.maxLineEdit.setVisible(True)
+ self.applyMaskBtn.setText("Mask above")
+ self.applyMaskBtn.setToolTip(triggeredAction.toolTip())
def _maskBtnClicked(self):
if self.belowThresholdAction.isChecked():
diff --git a/silx/gui/plot/_utils/dtime_ticklayout.py b/silx/gui/plot/_utils/dtime_ticklayout.py
index 95fc235..23c9dce 100644
--- a/silx/gui/plot/_utils/dtime_ticklayout.py
+++ b/silx/gui/plot/_utils/dtime_ticklayout.py
@@ -1,7 +1,7 @@
# coding: utf-8
# /*##########################################################################
#
-# Copyright (c) 2014-2017 European Synchrotron Radiation Facility
+# Copyright (c) 2014-2018 European Synchrotron Radiation Facility
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
@@ -32,6 +32,7 @@ __date__ = "04/04/2018"
import datetime as dt
+import enum
import logging
import math
import time
@@ -40,7 +41,6 @@ import dateutil.tz
from dateutil.relativedelta import relativedelta
-from silx.third_party import enum
from .ticklayout import niceNumGeneric
_logger = logging.getLogger(__name__)
diff --git a/silx/gui/plot/actions/control.py b/silx/gui/plot/actions/control.py
index 10df130..2d01ef1 100644
--- a/silx/gui/plot/actions/control.py
+++ b/silx/gui/plot/actions/control.py
@@ -1,7 +1,7 @@
# coding: utf-8
# /*##########################################################################
#
-# Copyright (c) 2004-2018 European Synchrotron Radiation Facility
+# Copyright (c) 2004-2019 European Synchrotron Radiation Facility
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
@@ -303,9 +303,12 @@ class CurveStyleAction(PlotAction):
currentState = (self.plot.isDefaultPlotLines(),
self.plot.isDefaultPlotPoints())
- # line only, line and symbol, symbol only
- states = (True, False), (True, True), (False, True)
- newState = states[(states.index(currentState) + 1) % 3]
+ if currentState == (False, False):
+ newState = True, False
+ else:
+ # line only, line and symbol, symbol only
+ states = (True, False), (True, True), (False, True)
+ newState = states[(states.index(currentState) + 1) % 3]
self.plot.setDefaultPlotLines(newState[0])
self.plot.setDefaultPlotPoints(newState[1])
diff --git a/silx/gui/plot/actions/io.py b/silx/gui/plot/actions/io.py
index 97de527..09e4a99 100644
--- a/silx/gui/plot/actions/io.py
+++ b/silx/gui/plot/actions/io.py
@@ -1,7 +1,7 @@
# coding: utf-8
# /*##########################################################################
#
-# Copyright (c) 2004-2018 European Synchrotron Radiation Facility
+# Copyright (c) 2004-2019 European Synchrotron Radiation Facility
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
@@ -502,7 +502,7 @@ class SaveAction(PlotAction):
axes_errors=[xerror, yerror],
title=plot.getGraphTitle())
- def setFileFilter(self, dataKind, nameFilter, func):
+ def setFileFilter(self, dataKind, nameFilter, func, index=None):
"""Set a name filter to add/replace a file format support
:param str dataKind:
@@ -513,10 +513,44 @@ class SaveAction(PlotAction):
:param callable func: The function to call to perform saving.
Expected signature is:
bool func(PlotWidget plot, str filename, str nameFilter)
+ :param integer index: Index of the filter in the final list (or None)
"""
assert dataKind in ('all', 'curve', 'curves', 'image', 'scatter')
+ # first append or replace the new filter to prevent colissions
self._filters[dataKind][nameFilter] = func
+ if index is None:
+ # we are already done
+ return
+
+ # get the current ordered list of keys
+ keyList = list(self._filters[dataKind].keys())
+
+ # deal with negative indices
+ if index < 0:
+ index = len(keyList) + index
+ if index < 0:
+ index = 0
+
+ if index >= len(keyList):
+ # nothing to be done, already at the end
+ txt = 'Requested index %d impossible, already at the end' % index
+ _logger.info(txt)
+ return
+
+ # get the new ordered list
+ oldIndex = keyList.index(nameFilter)
+ del keyList[oldIndex]
+ keyList.insert(index, nameFilter)
+
+ # build the new filters
+ newFilters = OrderedDict()
+ for key in keyList:
+ newFilters[key] = self._filters[dataKind][key]
+
+ # and update the filters
+ self._filters[dataKind] = newFilters
+ return
def getFileFilters(self, dataKind):
"""Returns the nameFilter and associated function for a kind of data.
diff --git a/silx/gui/plot/backends/BackendBase.py b/silx/gui/plot/backends/BackendBase.py
index 7fb8be0..0514c85 100644
--- a/silx/gui/plot/backends/BackendBase.py
+++ b/silx/gui/plot/backends/BackendBase.py
@@ -1,7 +1,7 @@
# coding: utf-8
# /*##########################################################################
#
-# Copyright (c) 2004-2018 European Synchrotron Radiation Facility
+# Copyright (c) 2004-2019 European Synchrotron Radiation Facility
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
@@ -31,7 +31,7 @@ This API is a simplified version of PyMca PlotBackend API.
__authors__ = ["V.A. Sole", "T. Vincent"]
__license__ = "MIT"
-__date__ = "24/04/2018"
+__date__ = "21/12/2018"
import weakref
from ... import qt
@@ -170,7 +170,8 @@ class BackendBase(object):
"""
return legend
- def addItem(self, x, y, legend, shape, color, fill, overlay, z):
+ def addItem(self, x, y, legend, shape, color, fill, overlay, z,
+ linestyle, linewidth, linebgcolor):
"""Add an item (i.e. a shape) to the plot.
:param numpy.ndarray x: The X coords of the points of the shape
@@ -182,6 +183,19 @@ class BackendBase(object):
:param bool fill: True to fill the shape
:param bool overlay: True if item is an overlay, False otherwise
:param int z: Layer on which to draw the item
+ :param str linestyle: Style of the line.
+ Only relevant for line markers where X or Y is None.
+ Value in:
+
+ - ' ' no line
+ - '-' solid line
+ - '--' dashed line
+ - '-.' dash-dot line
+ - ':' dotted line
+ :param float linewidth: Width of the line.
+ Only relevant for line markers where X or Y is None.
+ :param str linebgcolor: Background color of the line, e.g., 'blue', 'b',
+ '#FF0000'. It is used to draw dotted line using a second color.
:returns: The handle used by the backend to univocally access the item
"""
return legend
@@ -546,3 +560,20 @@ class BackendBase(object):
This only check status set to axes from the public API
"""
return self._axesDisplayed
+
+ def setForegroundColors(self, foregroundColor, gridColor):
+ """Set foreground and grid colors used to display this widget.
+
+ :param List[float] foregroundColor: RGBA foreground color of the widget
+ :param List[float] gridColor: RGBA grid color of the data view
+ """
+ pass
+
+ def setBackgroundColors(self, backgroundColor, dataBackgroundColor):
+ """Set background colors used to display this widget.
+
+ :param List[float] backgroundColor: RGBA background color of the widget
+ :param Union[Tuple[float],None] dataBackgroundColor:
+ RGBA background color of the data view
+ """
+ pass
diff --git a/silx/gui/plot/backends/BackendMatplotlib.py b/silx/gui/plot/backends/BackendMatplotlib.py
index 3b1d6dd..726a839 100644
--- a/silx/gui/plot/backends/BackendMatplotlib.py
+++ b/silx/gui/plot/backends/BackendMatplotlib.py
@@ -1,7 +1,7 @@
# coding: utf-8
# /*##########################################################################
#
-# Copyright (c) 2004-2018 European Synchrotron Radiation Facility
+# Copyright (c) 2004-2019 European Synchrotron Radiation Facility
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
@@ -28,7 +28,7 @@ from __future__ import division
__authors__ = ["V.A. Sole", "T. Vincent, H. Payno"]
__license__ = "MIT"
-__date__ = "01/08/2018"
+__date__ = "21/12/2018"
import logging
@@ -56,12 +56,26 @@ from matplotlib.collections import PathCollection, LineCollection
from matplotlib.ticker import Formatter, ScalarFormatter, Locator
-from ....third_party.modest_image import ModestImage
from . import BackendBase
from .._utils import FLOAT32_MINPOS
from .._utils.dtime_ticklayout import calcTicks, bestFormatString, timestamp
+_PATCH_LINESTYLE = {
+ "-": 'solid',
+ "--": 'dashed',
+ '-.': 'dashdot',
+ ':': 'dotted',
+ '': "solid",
+ None: "solid",
+}
+"""Patches do not uses the same matplotlib syntax"""
+
+
+def normalize_linestyle(linestyle):
+ """Normalize known old-style linestyle, else return the provided value."""
+ return _PATCH_LINESTYLE.get(linestyle, linestyle)
+
class NiceDateLocator(Locator):
"""
@@ -115,7 +129,6 @@ class NiceDateLocator(Locator):
return ticks
-
class NiceAutoDateFormatter(Formatter):
"""
Matplotlib FuncFormatter that is linked to a NiceDateLocator and gives the
@@ -139,7 +152,6 @@ class NiceAutoDateFormatter(Formatter):
else:
return bestFormatString(self.locator.spacing, self.locator.unit)
-
def __call__(self, x, pos=None):
"""Return the format for tick val *x* at position *pos*
Expects x to be a POSIX timestamp (seconds since 1 Jan 1970)
@@ -149,8 +161,6 @@ class NiceAutoDateFormatter(Formatter):
return tickStr
-
-
class _MarkerContainer(Container):
"""Marker artists container supporting draw/remove and text position update
@@ -204,6 +214,57 @@ class _MarkerContainer(Container):
self.text.set_x(xmax)
+class _DoubleColoredLinePatch(matplotlib.patches.Patch):
+ """Matplotlib patch to display any patch using double color."""
+
+ def __init__(self, patch):
+ super(_DoubleColoredLinePatch, self).__init__()
+ self.__patch = patch
+ self.linebgcolor = None
+
+ def __getattr__(self, name):
+ return getattr(self.__patch, name)
+
+ def draw(self, renderer):
+ oldLineStype = self.__patch.get_linestyle()
+ if self.linebgcolor is not None and oldLineStype != "solid":
+ oldLineColor = self.__patch.get_edgecolor()
+ oldHatch = self.__patch.get_hatch()
+ self.__patch.set_linestyle("solid")
+ self.__patch.set_edgecolor(self.linebgcolor)
+ self.__patch.set_hatch(None)
+ self.__patch.draw(renderer)
+ self.__patch.set_linestyle(oldLineStype)
+ self.__patch.set_edgecolor(oldLineColor)
+ self.__patch.set_hatch(oldHatch)
+ self.__patch.draw(renderer)
+
+ def set_transform(self, transform):
+ self.__patch.set_transform(transform)
+
+ def get_path(self):
+ return self.__patch.get_path()
+
+ def contains(self, mouseevent, radius=None):
+ return self.__patch.contains(mouseevent, radius)
+
+ def contains_point(self, point, radius=None):
+ return self.__patch.contains_point(point, radius)
+
+
+class Image(AxesImage):
+ """An AxesImage with a fast path for uint8 RGBA images"""
+
+ def set_data(self, A):
+ A = numpy.array(A, copy=False)
+ if A.ndim != 3 or A.shape[2] != 4 or A.dtype != numpy.uint8:
+ super(Image, self).set_data(A)
+ else:
+ # Call AxesImage.set_data with small data to set attributes
+ super(Image, self).set_data(numpy.zeros((2, 2, 4), dtype=A.dtype))
+ self._A = A # Override stored data
+
+
class BackendMatplotlib(BackendBase.BackendBase):
"""Base class for Matplotlib backend without a FigureCanvas.
@@ -231,6 +292,8 @@ class BackendMatplotlib(BackendBase.BackendBase):
self.ax = self.fig.add_axes([.15, .15, .75, .75], label="left")
self.ax2 = self.ax.twinx()
self.ax2.set_label("right")
+ # Make sure background of Axes is displayed
+ self.ax2.patch.set_visible(True)
# disable the use of offsets
try:
@@ -239,9 +302,9 @@ class BackendMatplotlib(BackendBase.BackendBase):
self.ax2.get_yaxis().get_major_formatter().set_useOffset(False)
self.ax2.get_xaxis().get_major_formatter().set_useOffset(False)
except:
- _logger.warning('Cannot disabled axes offsets in %s ' \
+ _logger.warning('Cannot disabled axes offsets in %s '
% matplotlib.__version__)
-
+
# critical for picking!!!!
self.ax2.set_zorder(0)
self.ax2.set_autoscaley_on(True)
@@ -376,44 +439,13 @@ class BackendMatplotlib(BackendBase.BackendBase):
picker = (selectable or draggable)
- # Debian 7 specific support
- # No transparent colormap with matplotlib < 1.2.0
- # Add support for transparent colormap for uint8 data with
- # colormap with 256 colors, linear norm, [0, 255] range
- if self._matplotlibVersion < _parse_version('1.2.0'):
- if (len(data.shape) == 2 and colormap.getName() is None and
- colormap.getColormapLUT() is not None):
- colors = colormap.getColormapLUT()
- if (colors.shape[-1] == 4 and
- not numpy.all(numpy.equal(colors[3], 255))):
- # This is a transparent colormap
- if (colors.shape == (256, 4) and
- colormap.getNormalization() == 'linear' and
- not colormap.isAutoscale() and
- colormap.getVMin() == 0 and
- colormap.getVMax() == 255 and
- data.dtype == numpy.uint8):
- # Supported case, convert data to RGBA
- data = colors[data.reshape(-1)].reshape(
- data.shape + (4,))
- else:
- _logger.warning(
- 'matplotlib %s does not support transparent '
- 'colormap.', matplotlib.__version__)
-
- if ((height * width) > 5.0e5 and
- origin == (0., 0.) and scale == (1., 1.)):
- imageClass = ModestImage
- else:
- imageClass = AxesImage
-
# All image are shown as RGBA image
- image = imageClass(self.ax,
- label="__IMAGE__" + legend,
- interpolation='nearest',
- picker=picker,
- zorder=z,
- origin='lower')
+ image = Image(self.ax,
+ label="__IMAGE__" + legend,
+ interpolation='nearest',
+ picker=picker,
+ zorder=z,
+ origin='lower')
if alpha < 1:
image.set_alpha(alpha)
@@ -438,40 +470,41 @@ class BackendMatplotlib(BackendBase.BackendBase):
ystep = 1 if scale[1] >= 0. else -1
data = data[::ystep, ::xstep]
- if self._matplotlibVersion < _parse_version('2.1'):
- # matplotlib 1.4.2 do not support float128
- dtype = data.dtype
- if dtype.kind == "f" and dtype.itemsize >= 16:
- _logger.warning("Your matplotlib version do not support "
- "float128. Data converted to float64.")
- data = data.astype(numpy.float64)
-
if data.ndim == 2: # Data image, convert to RGBA image
data = colormap.applyToData(data)
image.set_data(data)
-
self.ax.add_artist(image)
-
return image
- def addItem(self, x, y, legend, shape, color, fill, overlay, z):
+ def addItem(self, x, y, legend, shape, color, fill, overlay, z,
+ linestyle, linewidth, linebgcolor):
+ if (linebgcolor is not None and
+ shape not in ('rectangle', 'polygon', 'polylines')):
+ _logger.warning(
+ 'linebgcolor not implemented for %s with matplotlib backend',
+ shape)
xView = numpy.array(x, copy=False)
yView = numpy.array(y, copy=False)
+ linestyle = normalize_linestyle(linestyle)
+
if shape == "line":
item = self.ax.plot(x, y, label=legend, color=color,
- linestyle='-', marker=None)[0]
+ linestyle=linestyle, linewidth=linewidth,
+ marker=None)[0]
elif shape == "hline":
if hasattr(y, "__len__"):
y = y[-1]
- item = self.ax.axhline(y, label=legend, color=color)
+ item = self.ax.axhline(y, label=legend, color=color,
+ linestyle=linestyle, linewidth=linewidth)
elif shape == "vline":
if hasattr(x, "__len__"):
x = x[-1]
- item = self.ax.axvline(x, label=legend, color=color)
+ item = self.ax.axvline(x, label=legend, color=color,
+ linestyle=linestyle, linewidth=linewidth)
elif shape == 'rectangle':
xMin = numpy.nanmin(xView)
@@ -484,10 +517,16 @@ class BackendMatplotlib(BackendBase.BackendBase):
width=w,
height=h,
fill=False,
- color=color)
+ color=color,
+ linestyle=linestyle,
+ linewidth=linewidth)
if fill:
item.set_hatch('.')
+ if linestyle != "solid" and linebgcolor is not None:
+ item = _DoubleColoredLinePatch(item)
+ item.linebgcolor = linebgcolor
+
self.ax.add_patch(item)
elif shape in ('polygon', 'polylines'):
@@ -500,10 +539,16 @@ class BackendMatplotlib(BackendBase.BackendBase):
closed=closed,
fill=False,
label=legend,
- color=color)
+ color=color,
+ linestyle=linestyle,
+ linewidth=linewidth)
if fill and shape == 'polygon':
item.set_hatch('/')
+ if linestyle != "solid" and linebgcolor is not None:
+ item = _DoubleColoredLinePatch(item)
+ item.linebgcolor = linebgcolor
+
self.ax.add_patch(item)
else:
@@ -908,8 +953,56 @@ class BackendMatplotlib(BackendBase.BackendBase):
# remove external margins
self.ax.set_position([0, 0, 1, 1])
self.ax2.set_position([0, 0, 1, 1])
+ self._synchronizeBackgroundColors()
+ self._synchronizeForegroundColors()
self._plot._setDirtyPlot()
+ def _synchronizeBackgroundColors(self):
+ backgroundColor = self._plot.getBackgroundColor().getRgbF()
+
+ dataBackgroundColor = self._plot.getDataBackgroundColor()
+ if dataBackgroundColor.isValid():
+ dataBackgroundColor = dataBackgroundColor.getRgbF()
+ else:
+ dataBackgroundColor = backgroundColor
+
+ if self.ax2.axison:
+ self.fig.patch.set_facecolor(backgroundColor)
+ if self._matplotlibVersion < _parse_version('2'):
+ self.ax2.set_axis_bgcolor(dataBackgroundColor)
+ else:
+ self.ax2.set_facecolor(dataBackgroundColor)
+ else:
+ self.fig.patch.set_facecolor(dataBackgroundColor)
+
+ def _synchronizeForegroundColors(self):
+ foregroundColor = self._plot.getForegroundColor().getRgbF()
+
+ gridColor = self._plot.getGridColor()
+ if gridColor.isValid():
+ gridColor = gridColor.getRgbF()
+ else:
+ gridColor = foregroundColor
+
+ for axes in (self.ax, self.ax2):
+ if axes.axison:
+ axes.spines['bottom'].set_color(foregroundColor)
+ axes.spines['top'].set_color(foregroundColor)
+ axes.spines['right'].set_color(foregroundColor)
+ axes.spines['left'].set_color(foregroundColor)
+ axes.tick_params(axis='x', colors=foregroundColor)
+ axes.tick_params(axis='y', colors=foregroundColor)
+ axes.yaxis.label.set_color(foregroundColor)
+ axes.xaxis.label.set_color(foregroundColor)
+ axes.title.set_color(foregroundColor)
+
+ for line in axes.get_xgridlines():
+ line.set_color(gridColor)
+
+ for line in axes.get_ygridlines():
+ line.set_color(gridColor)
+ # axes.grid().set_markeredgecolor(gridColor)
+
class BackendMatplotlibQt(FigureCanvasQTAgg, BackendMatplotlib):
"""QWidget matplotlib backend using a QtAgg canvas.
@@ -1137,3 +1230,9 @@ class BackendMatplotlibQt(FigureCanvasQTAgg, BackendMatplotlib):
else:
cursor = self._QT_CURSORS[cursor]
FigureCanvasQTAgg.setCursor(self, qt.QCursor(cursor))
+
+ def setBackgroundColors(self, backgroundColor, dataBackgroundColor):
+ self._synchronizeBackgroundColors()
+
+ def setForegroundColors(self, foregroundColor, gridColor):
+ self._synchronizeForegroundColors()
diff --git a/silx/gui/plot/backends/BackendOpenGL.py b/silx/gui/plot/backends/BackendOpenGL.py
index 9e2cb73..e33d03c 100644
--- a/silx/gui/plot/backends/BackendOpenGL.py
+++ b/silx/gui/plot/backends/BackendOpenGL.py
@@ -1,7 +1,7 @@
# coding: utf-8
# /*##########################################################################
#
-# Copyright (c) 2014-2018 European Synchrotron Radiation Facility
+# Copyright (c) 2014-2019 European Synchrotron Radiation Facility
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
@@ -28,7 +28,7 @@ from __future__ import division
__authors__ = ["T. Vincent"]
__license__ = "MIT"
-__date__ = "01/08/2018"
+__date__ = "21/12/2018"
from collections import OrderedDict, namedtuple
from ctypes import c_void_p
@@ -44,10 +44,11 @@ from ... import qt
from ..._glutils import gl
from ... import _glutils as glu
from .glutils import (
+ GLLines2D,
GLPlotCurve2D, GLPlotColormap, GLPlotRGBAImage, GLPlotFrame2D,
mat4Ortho, mat4Identity,
LEFT, RIGHT, BOTTOM, TOP,
- Text2D, Shape2D)
+ Text2D, FilledShape2D)
from .glutils.PlotImageFile import saveImageToFile
_logger = logging.getLogger(__name__)
@@ -338,6 +339,9 @@ class BackendOpenGL(BackendBase.BackendBase, glu.OpenGLWidget):
f=f)
BackendBase.BackendBase.__init__(self, plot, parent)
+ self._backgroundColor = 1., 1., 1., 1.
+ self._dataBackgroundColor = 1., 1., 1., 1.
+
self.matScreenProj = mat4Identity()
self._progBase = glu.Program(
@@ -357,6 +361,8 @@ class BackendOpenGL(BackendBase.BackendBase, glu.OpenGLWidget):
self._glGarbageCollector = []
self._plotFrame = GLPlotFrame2D(
+ foregroundColor=(0., 0., 0., 1.),
+ gridColor=(.7, .7, .7, 1.),
margins={'left': 100, 'right': 50, 'top': 50, 'bottom': 50})
# Make postRedisplay asynchronous using Qt signal
@@ -432,7 +438,6 @@ class BackendOpenGL(BackendBase.BackendBase, glu.OpenGLWidget):
def initializeGL(self):
gl.testGL()
- gl.glClearColor(1., 1., 1., 1.)
gl.glClearStencil(0)
gl.glEnable(gl.GL_BLEND)
@@ -482,6 +487,7 @@ class BackendOpenGL(BackendBase.BackendBase, glu.OpenGLWidget):
self._plotFBOs[context] = plotFBOTex
with plotFBOTex:
+ gl.glClearColor(*self._backgroundColor)
gl.glClear(gl.GL_COLOR_BUFFER_BIT | gl.GL_STENCIL_BUFFER_BIT)
self._renderPlotAreaGL()
self._plotFrame.render()
@@ -530,6 +536,7 @@ class BackendOpenGL(BackendBase.BackendBase, glu.OpenGLWidget):
item.discard()
self._glGarbageCollector = []
+ gl.glClearColor(*self._backgroundColor)
gl.glClear(gl.GL_COLOR_BUFFER_BIT | gl.GL_STENCIL_BUFFER_BIT)
# Check if window is large enough
@@ -543,100 +550,6 @@ class BackendOpenGL(BackendBase.BackendBase, glu.OpenGLWidget):
glu.setGLContextGetter()
_current_context = None
- def _nonOrthoAxesLineMarkerPrimitives(self, marker, pixelOffset):
- """Generates the vertices and label for a line marker.
-
- :param dict marker: Description of a line marker
- :param int pixelOffset: Offset of text from borders in pixels
- :return: Line vertices and Text label or None
- :rtype: 2-tuple (2x2 numpy.array of float, Text2D)
- """
- label, vertices = None, None
-
- xCoord, yCoord = marker['x'], marker['y']
- assert xCoord is None or yCoord is None # Specific to line markers
-
- # Get plot corners in data coords
- plotLeft, plotTop, plotWidth, plotHeight = self.getPlotBoundsInPixels()
-
- corners = [(plotLeft, plotTop),
- (plotLeft, plotTop + plotHeight),
- (plotLeft + plotWidth, plotTop + plotHeight),
- (plotLeft + plotWidth, plotTop)]
- corners = numpy.array([self.pixelToData(x, y, axis='left', check=False)
- for (x, y) in corners])
-
- borders = {
- 'right': (corners[3], corners[2]),
- 'top': (corners[0], corners[3]),
- 'bottom': (corners[2], corners[1]),
- 'left': (corners[1], corners[0])
- }
-
- textLayouts = { # align, valign, offsets
- 'right': (RIGHT, BOTTOM, (-1., -1.)),
- 'top': (LEFT, TOP, (1., 1.)),
- 'bottom': (LEFT, BOTTOM, (1., -1.)),
- 'left': (LEFT, BOTTOM, (1., -1.))
- }
-
- if xCoord is None: # Horizontal line in data space
- if marker['text'] is not None:
- # Find intersection of hline with borders in data
- # Order is important as it stops at first intersection
- for border_name in ('right', 'top', 'bottom', 'left'):
- (x0, y0), (x1, y1) = borders[border_name]
-
- if min(y0, y1) <= yCoord < max(y0, y1):
- xIntersect = (yCoord - y0) * (x1 - x0) / (y1 - y0) + x0
-
- # Add text label
- pixelPos = self.dataToPixel(
- xIntersect, yCoord, axis='left', check=False)
-
- align, valign, offsets = textLayouts[border_name]
-
- x = pixelPos[0] + offsets[0] * pixelOffset
- y = pixelPos[1] + offsets[1] * pixelOffset
- label = Text2D(marker['text'], x, y,
- color=marker['color'],
- bgColor=(1., 1., 1., 0.5),
- align=align, valign=valign)
- break # Stop at first intersection
-
- xMin, xMax = corners[:, 0].min(), corners[:, 0].max()
- vertices = numpy.array(
- ((xMin, yCoord), (xMax, yCoord)), dtype=numpy.float32)
-
- else: # yCoord is None: vertical line in data space
- if marker['text'] is not None:
- # Find intersection of hline with borders in data
- # Order is important as it stops at first intersection
- for border_name in ('top', 'bottom', 'right', 'left'):
- (x0, y0), (x1, y1) = borders[border_name]
- if min(x0, x1) <= xCoord < max(x0, x1):
- yIntersect = (xCoord - x0) * (y1 - y0) / (x1 - x0) + y0
-
- # Add text label
- pixelPos = self.dataToPixel(
- xCoord, yIntersect, axis='left', check=False)
-
- align, valign, offsets = textLayouts[border_name]
-
- x = pixelPos[0] + offsets[0] * pixelOffset
- y = pixelPos[1] + offsets[1] * pixelOffset
- label = Text2D(marker['text'], x, y,
- color=marker['color'],
- bgColor=(1., 1., 1., 0.5),
- align=align, valign=valign)
- break # Stop at first intersection
-
- yMin, yMax = corners[:, 1].min(), corners[:, 1].max()
- vertices = numpy.array(
- ((xCoord, yMin), (xCoord, yMax)), dtype=numpy.float32)
-
- return vertices, label
-
def _renderMarkersGL(self):
if len(self._markers) == 0:
return
@@ -651,16 +564,6 @@ class BackendOpenGL(BackendBase.BackendBase, glu.OpenGLWidget):
gl.glViewport(0, 0, self._plotFrame.size[0], self._plotFrame.size[1])
- # Prepare vertical and horizontal markers rendering
- self._progBase.use()
- gl.glUniformMatrix4fv(
- self._progBase.uniforms['matrix'], 1, gl.GL_TRUE,
- self.matScreenProj.astype(numpy.float32))
- gl.glUniform2i(self._progBase.uniforms['isLog'], False, False)
- gl.glUniform1i(self._progBase.uniforms['hatchStep'], 0)
- gl.glUniform1f(self._progBase.uniforms['tickLen'], 0.)
- posAttrib = self._progBase.attributes['position']
-
labels = []
pixelOffset = 3
@@ -677,59 +580,43 @@ class BackendOpenGL(BackendBase.BackendBase, glu.OpenGLWidget):
continue
if xCoord is None or yCoord is None:
- if not self.isDefaultBaseVectors(): # Non-orthogonal axes
- vertices, label = self._nonOrthoAxesLineMarkerPrimitives(
- marker, pixelOffset)
- if label is not None:
- labels.append(label)
+ pixelPos = self.dataToPixel(
+ xCoord, yCoord, axis='left', check=False)
- else: # Orthogonal axes
- pixelPos = self.dataToPixel(
- xCoord, yCoord, axis='left', check=False)
-
- if xCoord is None: # Horizontal line in data space
- if marker['text'] is not None:
- x = self._plotFrame.size[0] - \
- self._plotFrame.margins.right - pixelOffset
- y = pixelPos[1] - pixelOffset
- label = Text2D(marker['text'], x, y,
- color=marker['color'],
- bgColor=(1., 1., 1., 0.5),
- align=RIGHT, valign=BOTTOM)
- labels.append(label)
-
- width = self._plotFrame.size[0]
- vertices = numpy.array(((0, pixelPos[1]),
- (width, pixelPos[1])),
- dtype=numpy.float32)
-
- else: # yCoord is None: vertical line in data space
- if marker['text'] is not None:
- x = pixelPos[0] + pixelOffset
- y = self._plotFrame.margins.top + pixelOffset
- label = Text2D(marker['text'], x, y,
- color=marker['color'],
- bgColor=(1., 1., 1., 0.5),
- align=LEFT, valign=TOP)
- labels.append(label)
-
- height = self._plotFrame.size[1]
- vertices = numpy.array(((pixelPos[0], 0),
- (pixelPos[0], height)),
- dtype=numpy.float32)
+ if xCoord is None: # Horizontal line in data space
+ if marker['text'] is not None:
+ x = self._plotFrame.size[0] - \
+ self._plotFrame.margins.right - pixelOffset
+ y = pixelPos[1] - pixelOffset
+ label = Text2D(marker['text'], x, y,
+ color=marker['color'],
+ bgColor=(1., 1., 1., 0.5),
+ align=RIGHT, valign=BOTTOM)
+ labels.append(label)
- self._progBase.use()
- gl.glUniform4f(self._progBase.uniforms['color'],
- *marker['color'])
+ width = self._plotFrame.size[0]
+ lines = GLLines2D((0, width), (pixelPos[1], pixelPos[1]),
+ style=marker['linestyle'],
+ color=marker['color'],
+ width=marker['linewidth'])
+ lines.render(self.matScreenProj)
+
+ else: # yCoord is None: vertical line in data space
+ if marker['text'] is not None:
+ x = pixelPos[0] + pixelOffset
+ y = self._plotFrame.margins.top + pixelOffset
+ label = Text2D(marker['text'], x, y,
+ color=marker['color'],
+ bgColor=(1., 1., 1., 0.5),
+ align=LEFT, valign=TOP)
+ labels.append(label)
- gl.glEnableVertexAttribArray(posAttrib)
- gl.glVertexAttribPointer(posAttrib,
- 2,
- gl.GL_FLOAT,
- gl.GL_FALSE,
- 0, vertices)
- gl.glLineWidth(1)
- gl.glDrawArrays(gl.GL_LINES, 0, len(vertices))
+ height = self._plotFrame.size[1]
+ lines = GLLines2D((pixelPos[0], pixelPos[0]), (0, height),
+ style=marker['linestyle'],
+ color=marker['color'],
+ width=marker['linewidth'])
+ lines.render(self.matScreenProj)
else:
pixelPos = self.dataToPixel(
@@ -820,13 +707,17 @@ class BackendOpenGL(BackendBase.BackendBase, glu.OpenGLWidget):
def _renderPlotAreaGL(self):
plotWidth, plotHeight = self.getPlotBoundsInPixels()[2:]
- self._plotFrame.renderGrid()
-
gl.glScissor(self._plotFrame.margins.left,
self._plotFrame.margins.bottom,
plotWidth, plotHeight)
gl.glEnable(gl.GL_SCISSOR_TEST)
+ if self._dataBackgroundColor != self._backgroundColor:
+ gl.glClearColor(*self._dataBackgroundColor)
+ gl.glClear(gl.GL_COLOR_BUFFER_BIT)
+
+ self._plotFrame.renderGrid()
+
# Matrix
trBounds = self._plotFrame.transformedDataRanges
if trBounds.x[0] == trBounds.x[1] or \
@@ -853,32 +744,61 @@ class BackendOpenGL(BackendBase.BackendBase, glu.OpenGLWidget):
# Render Items
gl.glViewport(0, 0, self._plotFrame.size[0], self._plotFrame.size[1])
- self._progBase.use()
- gl.glUniformMatrix4fv(self._progBase.uniforms['matrix'], 1, gl.GL_TRUE,
- self.matScreenProj.astype(numpy.float32))
- gl.glUniform2i(self._progBase.uniforms['isLog'], False, False)
- gl.glUniform1f(self._progBase.uniforms['tickLen'], 0.)
-
for item in self._items.values():
if ((isXLog and numpy.min(item['x']) < FLOAT32_MINPOS) or
(isYLog and numpy.min(item['y']) < FLOAT32_MINPOS)):
# Ignore items <= 0. on log axes
continue
- closed = item['shape'] != 'polylines'
- points = [self.dataToPixel(x, y, axis='left', check=False)
- for (x, y) in zip(item['x'], item['y'])]
- shape2D = Shape2D(points,
- fill=item['fill'],
- fillColor=item['color'],
- stroke=True,
- strokeColor=item['color'],
- strokeClosed=closed)
+ if item['shape'] == 'hline':
+ width = self._plotFrame.size[0]
+ _, yPixel = self.dataToPixel(
+ None, item['y'], axis='left', check=False)
+ points = numpy.array(((0., yPixel), (width, yPixel)),
+ dtype=numpy.float32)
- posAttrib = self._progBase.attributes['position']
- colorUnif = self._progBase.uniforms['color']
- hatchStepUnif = self._progBase.uniforms['hatchStep']
- shape2D.render(posAttrib, colorUnif, hatchStepUnif)
+ elif item['shape'] == 'vline':
+ xPixel, _ = self.dataToPixel(
+ item['x'], None, axis='left', check=False)
+ height = self._plotFrame.size[1]
+ points = numpy.array(((xPixel, 0), (xPixel, height)),
+ dtype=numpy.float32)
+
+ else:
+ points = numpy.array([
+ self.dataToPixel(x, y, axis='left', check=False)
+ for (x, y) in zip(item['x'], item['y'])])
+
+ # Draw the fill
+ if (item['fill'] is not None and
+ item['shape'] not in ('hline', 'vline')):
+ self._progBase.use()
+ gl.glUniformMatrix4fv(
+ self._progBase.uniforms['matrix'], 1, gl.GL_TRUE,
+ self.matScreenProj.astype(numpy.float32))
+ gl.glUniform2i(self._progBase.uniforms['isLog'], False, False)
+ gl.glUniform1f(self._progBase.uniforms['tickLen'], 0.)
+
+ shape2D = FilledShape2D(
+ points, style=item['fill'], color=item['color'])
+ shape2D.render(
+ posAttrib=self._progBase.attributes['position'],
+ colorUnif=self._progBase.uniforms['color'],
+ hatchStepUnif=self._progBase.uniforms['hatchStep'])
+
+ # Draw the stroke
+ if item['linestyle'] not in ('', ' ', None):
+ if item['shape'] != 'polylines':
+ # close the polyline
+ points = numpy.append(points,
+ numpy.atleast_2d(points[0]), axis=0)
+
+ lines = GLLines2D(points[:, 0], points[:, 1],
+ style=item['linestyle'],
+ color=item['color'],
+ dash2ndColor=item['linebgcolor'],
+ width=item['linewidth'])
+ lines.render(self.matScreenProj)
gl.glDisable(gl.GL_SCISSOR_TEST)
@@ -1123,7 +1043,8 @@ class BackendOpenGL(BackendBase.BackendBase, glu.OpenGLWidget):
return legend, 'image'
- def addItem(self, x, y, legend, shape, color, fill, overlay, z):
+ def addItem(self, x, y, legend, shape, color, fill, overlay, z,
+ linestyle, linewidth, linebgcolor):
# TODO handle overlay
if shape not in ('polygon', 'rectangle', 'line',
'vline', 'hline', 'polylines'):
@@ -1154,7 +1075,10 @@ class BackendOpenGL(BackendBase.BackendBase, glu.OpenGLWidget):
'color': colors.rgba(color),
'fill': 'hatch' if fill else None,
'x': x,
- 'y': y
+ 'y': y,
+ 'linestyle': linestyle,
+ 'linewidth': linewidth,
+ 'linebgcolor': linebgcolor,
}
return legend, 'item'
@@ -1166,10 +1090,6 @@ class BackendOpenGL(BackendBase.BackendBase, glu.OpenGLWidget):
if symbol is None:
symbol = '+'
- if linestyle != '-' or linewidth != 1:
- _logger.warning(
- 'OpenGL backend does not support marker line style and width.')
-
behaviors = set()
if selectable:
behaviors.add('selectable')
@@ -1191,6 +1111,8 @@ class BackendOpenGL(BackendBase.BackendBase, glu.OpenGLWidget):
'behaviors': behaviors,
'constraint': constraint if isConstraint else None,
'symbol': symbol,
+ 'linestyle': linestyle,
+ 'linewidth': linewidth,
}
return legend, 'marker'
@@ -1441,37 +1363,6 @@ class BackendOpenGL(BackendBase.BackendBase, glu.OpenGLWidget):
if label:
_logger.warning('Right axis label not implemented')
- # Non orthogonal axes
-
- def setBaseVectors(self, x=(1., 0.), y=(0., 1.)):
- """Set base vectors.
-
- Useful for non-orthogonal axes.
- If an axis is in log scale, skew is applied to log transformed values.
-
- Base vector does not work well with log axes, to investi
- """
- if x != (1., 0.) and y != (0., 1.):
- if self._plotFrame.xAxis.isLog:
- _logger.warning("setBaseVectors disables X axis logarithmic.")
- self.setXAxisLogarithmic(False)
- if self._plotFrame.yAxis.isLog:
- _logger.warning("setBaseVectors disables Y axis logarithmic.")
- self.setYAxisLogarithmic(False)
-
- if self.isKeepDataAspectRatio():
- _logger.warning("setBaseVectors disables keepDataAspectRatio.")
- self.keepDataAspectRatio(False)
-
- self._plotFrame.baseVectors = x, y
-
- def getBaseVectors(self):
- return self._plotFrame.baseVectors
-
- def isDefaultBaseVectors(self):
- return self._plotFrame.baseVectors == \
- self._plotFrame.DEFAULT_BASE_VECTORS
-
# Graph limits
def _setDataRanges(self, xlim=None, ylim=None, y2lim=None):
@@ -1486,26 +1377,6 @@ class BackendOpenGL(BackendBase.BackendBase, glu.OpenGLWidget):
# Update axes range with a clipped range if too wide
self._plotFrame.setDataRanges(xlim, ylim, y2lim)
- if not self.isDefaultBaseVectors():
- # Update axes range with axes bounds in data coords
- plotLeft, plotTop, plotWidth, plotHeight = \
- self.getPlotBoundsInPixels()
-
- self._plotFrame.xAxis.dataRange = sorted([
- self.pixelToData(x, y, axis='left', check=False)[0]
- for (x, y) in ((plotLeft, plotTop + plotHeight),
- (plotLeft + plotWidth, plotTop + plotHeight))])
-
- self._plotFrame.yAxis.dataRange = sorted([
- self.pixelToData(x, y, axis='left', check=False)[1]
- for (x, y) in ((plotLeft, plotTop + plotHeight),
- (plotLeft, plotTop))])
-
- self._plotFrame.y2Axis.dataRange = sorted([
- self.pixelToData(x, y, axis='right', check=False)[1]
- for (x, y) in ((plotLeft + plotWidth, plotTop + plotHeight),
- (plotLeft + plotWidth, plotTop))])
-
def _ensureAspectRatio(self, keepDim=None):
"""Update plot bounds in order to keep aspect ratio.
@@ -1619,11 +1490,6 @@ class BackendOpenGL(BackendBase.BackendBase, glu.OpenGLWidget):
_logger.warning(
"KeepDataAspectRatio is ignored with log axes")
- if flag and not self.isDefaultBaseVectors():
- _logger.warning(
- "setXAxisLogarithmic ignored because baseVectors are set")
- return
-
self._plotFrame.xAxis.isLog = flag
def setYAxisLogarithmic(self, flag):
@@ -1633,11 +1499,6 @@ class BackendOpenGL(BackendBase.BackendBase, glu.OpenGLWidget):
_logger.warning(
"KeepDataAspectRatio is ignored with log axes")
- if flag and not self.isDefaultBaseVectors():
- _logger.warning(
- "setYAxisLogarithmic ignored because baseVectors are set")
- return
-
self._plotFrame.yAxis.isLog = flag
self._plotFrame.y2Axis.isLog = flag
@@ -1658,9 +1519,6 @@ class BackendOpenGL(BackendBase.BackendBase, glu.OpenGLWidget):
if flag and (self._plotFrame.xAxis.isLog or
self._plotFrame.yAxis.isLog):
_logger.warning("KeepDataAspectRatio is ignored with log axes")
- if flag and not self.isDefaultBaseVectors():
- _logger.warning(
- "keepDataAspectRatio ignored because baseVectors are set")
self._keepDataAspectRatio = flag
@@ -1723,3 +1581,11 @@ class BackendOpenGL(BackendBase.BackendBase, glu.OpenGLWidget):
def setAxesDisplayed(self, displayed):
BackendBase.BackendBase.setAxesDisplayed(self, displayed)
self._plotFrame.displayed = displayed
+
+ def setForegroundColors(self, foregroundColor, gridColor):
+ self._plotFrame.foregroundColor = foregroundColor
+ self._plotFrame.gridColor = gridColor
+
+ def setBackgroundColors(self, backgroundColor, dataBackgroundColor):
+ self._backgroundColor = backgroundColor
+ self._dataBackgroundColor = dataBackgroundColor
diff --git a/silx/gui/plot/backends/glutils/GLPlotCurve.py b/silx/gui/plot/backends/glutils/GLPlotCurve.py
index 12b6bbe..5f8d652 100644
--- a/silx/gui/plot/backends/glutils/GLPlotCurve.py
+++ b/silx/gui/plot/backends/glutils/GLPlotCurve.py
@@ -1,7 +1,7 @@
# coding: utf-8
# /*##########################################################################
#
-# Copyright (c) 2014-2018 European Synchrotron Radiation Facility
+# Copyright (c) 2014-2019 European Synchrotron Radiation Facility
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
@@ -42,7 +42,7 @@ import numpy
from silx.math.combo import min_max
from ...._glutils import gl
-from ...._glutils import Program, vertexBuffer
+from ...._glutils import Program, vertexBuffer, VertexBufferAttrib
from .GLSupport import buildFillMaskIndices, mat4Identity, mat4Translate
@@ -245,7 +245,7 @@ class _Fill2D(object):
SOLID, DASHED, DASHDOT, DOTTED = '-', '--', '-.', ':'
-class _Lines2D(object):
+class GLLines2D(object):
"""Object rendering curve as a polyline
:param xVboData: X coordinates VBO
@@ -323,6 +323,7 @@ class _Lines2D(object):
/* Dashes: [0, x], [y, z]
Dash period: w */
uniform vec4 dash;
+ uniform vec4 dash2ndColor;
varying float vDist;
varying vec4 vColor;
@@ -330,25 +331,52 @@ class _Lines2D(object):
void main(void) {
float dist = mod(vDist, dash.w);
if ((dist > dash.x && dist < dash.y) || dist > dash.z) {
- discard;
+ if (dash2ndColor.a == 0.) {
+ discard; // Discard full transparent bg color
+ } else {
+ gl_FragColor = dash2ndColor;
+ }
+ } else {
+ gl_FragColor = vColor;
}
- gl_FragColor = vColor;
}
""",
attrib0='xPos')
def __init__(self, xVboData=None, yVboData=None,
colorVboData=None, distVboData=None,
- style=SOLID, color=(0., 0., 0., 1.),
- width=1, dashPeriod=20, drawMode=None,
+ style=SOLID, color=(0., 0., 0., 1.), dash2ndColor=None,
+ width=1, dashPeriod=10., drawMode=None,
offset=(0., 0.)):
+ if (xVboData is not None and
+ not isinstance(xVboData, VertexBufferAttrib)):
+ xVboData = numpy.array(xVboData, copy=False, dtype=numpy.float32)
self.xVboData = xVboData
+
+ if (yVboData is not None and
+ not isinstance(yVboData, VertexBufferAttrib)):
+ yVboData = numpy.array(yVboData, copy=False, dtype=numpy.float32)
self.yVboData = yVboData
+
+ # Compute distances if not given while providing numpy array coordinates
+ if (isinstance(self.xVboData, numpy.ndarray) and
+ isinstance(self.yVboData, numpy.ndarray) and
+ distVboData is None):
+ distVboData = distancesFromArrays(self.xVboData, self.yVboData)
+
+ if (distVboData is not None and
+ not isinstance(distVboData, VertexBufferAttrib)):
+ distVboData = numpy.array(
+ distVboData, copy=False, dtype=numpy.float32)
self.distVboData = distVboData
+
+ if colorVboData is not None:
+ assert isinstance(colorVboData, VertexBufferAttrib)
self.colorVboData = colorVboData
self.useColorVboData = colorVboData is not None
self.color = color
+ self.dash2ndColor = dash2ndColor
self.width = width
self._style = None
self.style = style
@@ -396,29 +424,46 @@ class _Lines2D(object):
gl.glUniform2f(program.uniforms['halfViewportSize'],
0.5 * viewWidth, 0.5 * viewHeight)
+ dashPeriod = self.dashPeriod * self.width
if self.style == DOTTED:
- dash = (0.1 * self.dashPeriod,
- 0.6 * self.dashPeriod,
- 0.7 * self.dashPeriod,
- self.dashPeriod)
+ dash = (0.2 * dashPeriod,
+ 0.5 * dashPeriod,
+ 0.7 * dashPeriod,
+ dashPeriod)
elif self.style == DASHDOT:
- dash = (0.3 * self.dashPeriod,
- 0.5 * self.dashPeriod,
- 0.6 * self.dashPeriod,
- self.dashPeriod)
+ dash = (0.3 * dashPeriod,
+ 0.5 * dashPeriod,
+ 0.6 * dashPeriod,
+ dashPeriod)
else:
- dash = (0.5 * self.dashPeriod,
- self.dashPeriod,
- self.dashPeriod,
- self.dashPeriod)
+ dash = (0.5 * dashPeriod,
+ dashPeriod,
+ dashPeriod,
+ dashPeriod)
gl.glUniform4f(program.uniforms['dash'], *dash)
+ if self.dash2ndColor is None:
+ # Use fully transparent color which gets discarded in shader
+ dash2ndColor = (0., 0., 0., 0.)
+ else:
+ dash2ndColor = self.dash2ndColor
+ gl.glUniform4f(program.uniforms['dash2ndColor'], *dash2ndColor)
+
distAttrib = program.attributes['distance']
gl.glEnableVertexAttribArray(distAttrib)
- self.distVboData.setVertexAttrib(distAttrib)
+ if isinstance(self.distVboData, VertexBufferAttrib):
+ self.distVboData.setVertexAttrib(distAttrib)
+ else:
+ gl.glVertexAttribPointer(distAttrib,
+ 1,
+ gl.GL_FLOAT,
+ False,
+ 0,
+ self.distVboData)
- gl.glEnable(gl.GL_LINE_SMOOTH)
+ if self.width != 1:
+ gl.glEnable(gl.GL_LINE_SMOOTH)
matrix = numpy.dot(matrix,
mat4Translate(*self.offset)).astype(numpy.float32)
@@ -435,11 +480,27 @@ class _Lines2D(object):
xPosAttrib = program.attributes['xPos']
gl.glEnableVertexAttribArray(xPosAttrib)
- self.xVboData.setVertexAttrib(xPosAttrib)
+ if isinstance(self.xVboData, VertexBufferAttrib):
+ self.xVboData.setVertexAttrib(xPosAttrib)
+ else:
+ gl.glVertexAttribPointer(xPosAttrib,
+ 1,
+ gl.GL_FLOAT,
+ False,
+ 0,
+ self.xVboData)
yPosAttrib = program.attributes['yPos']
gl.glEnableVertexAttribArray(yPosAttrib)
- self.yVboData.setVertexAttrib(yPosAttrib)
+ if isinstance(self.yVboData, VertexBufferAttrib):
+ self.yVboData.setVertexAttrib(yPosAttrib)
+ else:
+ gl.glVertexAttribPointer(yPosAttrib,
+ 1,
+ gl.GL_FLOAT,
+ False,
+ 0,
+ self.yVboData)
gl.glLineWidth(self.width)
gl.glDrawArrays(self._drawMode, 0, self.xVboData.size)
@@ -447,7 +508,7 @@ class _Lines2D(object):
gl.glDisable(gl.GL_LINE_SMOOTH)
-def _distancesFromArrays(xData, yData):
+def distancesFromArrays(xData, yData):
"""Returns distances between each points
:param numpy.ndarray xData: X coordinate of points
@@ -711,7 +772,7 @@ class _ErrorBars(object):
This is using its own VBO as opposed to fill/points/lines.
There is no picking on error bars.
- It uses 2 vertices per error bars and uses :class:`_Lines2D` to
+ It uses 2 vertices per error bars and uses :class:`GLLines2D` to
render error bars and :class:`_Points2D` to render the ends.
:param numpy.ndarray xData: X coordinates of the data.
@@ -753,7 +814,7 @@ class _ErrorBars(object):
self._xData, self._yData = None, None
self._xError, self._yError = None, None
- self._lines = _Lines2D(
+ self._lines = GLLines2D(
None, None, color=color, drawMode=gl.GL_LINES, offset=offset)
self._xErrPoints = _Points2D(
None, None, color=color, marker=V_LINE, offset=offset)
@@ -957,7 +1018,7 @@ class GLPlotCurve2D(object):
self.xMin, self.yMin,
offset=self.offset)
- self.lines = _Lines2D()
+ self.lines = GLLines2D()
self.lines.style = lineStyle
self.lines.color = lineColor
self.lines.width = lineWidth
@@ -999,7 +1060,7 @@ class GLPlotCurve2D(object):
@classmethod
def init(cls):
"""OpenGL context initialization"""
- _Lines2D.init()
+ GLLines2D.init()
_Points2D.init()
def prepare(self):
@@ -1007,7 +1068,7 @@ class GLPlotCurve2D(object):
if self.xVboData is None:
xAttrib, yAttrib, cAttrib, dAttrib = None, None, None, None
if self.lineStyle in (DASHED, DASHDOT, DOTTED):
- dists = _distancesFromArrays(self.xData, self.yData)
+ dists = distancesFromArrays(self.xData, self.yData)
if self.colorData is None:
xAttrib, yAttrib, dAttrib = vertexBuffer(
(self.xData, self.yData, dists))
diff --git a/silx/gui/plot/backends/glutils/GLPlotFrame.py b/silx/gui/plot/backends/glutils/GLPlotFrame.py
index 4ad1547..43f6e10 100644
--- a/silx/gui/plot/backends/glutils/GLPlotFrame.py
+++ b/silx/gui/plot/backends/glutils/GLPlotFrame.py
@@ -63,6 +63,7 @@ class PlotAxis(object):
def __init__(self, plot,
tickLength=(0., 0.),
+ foregroundColor=(0., 0., 0., 1.0),
labelAlign=CENTER, labelVAlign=CENTER,
titleAlign=CENTER, titleVAlign=CENTER,
titleRotate=0, titleOffset=(0., 0.)):
@@ -78,6 +79,7 @@ class PlotAxis(object):
self._title = ''
self._tickLength = tickLength
+ self._foregroundColor = foregroundColor
self._labelAlign = labelAlign
self._labelVAlign = labelVAlign
self._titleAlign = titleAlign
@@ -169,6 +171,20 @@ class PlotAxis(object):
plot._dirty()
@property
+ def foregroundColor(self):
+ """Color used for frame and labels"""
+ return self._foregroundColor
+
+ @foregroundColor.setter
+ def foregroundColor(self, color):
+ """Color used for frame and labels"""
+ assert len(color) == 4, \
+ "foregroundColor must have length 4, got {}".format(len(self._foregroundColor))
+ if self._foregroundColor != color:
+ self._foregroundColor = color
+ self._dirtyTicks()
+
+ @property
def ticks(self):
"""Ticks as tuples: ((x, y) in display, dataPos, textLabel)."""
if self._ticks is None:
@@ -192,6 +208,7 @@ class PlotAxis(object):
tickScale = 1.
label = Text2D(text=text,
+ color=self._foregroundColor,
x=xPixel - xTickLength,
y=yPixel - yTickLength,
align=self._labelAlign,
@@ -223,6 +240,7 @@ class PlotAxis(object):
# yOffset -= 3 * yTickLength
axisTitle = Text2D(text=self.title,
+ color=self._foregroundColor,
x=xAxisCenter + xOffset,
y=yAxisCenter + yOffset,
align=self._titleAlign,
@@ -373,15 +391,21 @@ class GLPlotFrame(object):
# Margins used when plot frame is not displayed
_NoDisplayMargins = _Margins(0, 0, 0, 0)
- def __init__(self, margins):
+ def __init__(self, margins, foregroundColor, gridColor):
"""
:param margins: The margins around plot area for axis and labels.
:type margins: dict with 'left', 'right', 'top', 'bottom' keys and
values as ints.
+ :param foregroundColor: color used for the frame and labels.
+ :type foregroundColor: tuple with RGBA values ranging from 0.0 to 1.0
+ :param gridColor: color used for grid lines.
+ :type gridColor: tuple RGBA with RGBA values ranging from 0.0 to 1.0
"""
self._renderResources = None
self._margins = self._Margins(**margins)
+ self._foregroundColor = foregroundColor
+ self._gridColor = gridColor
self.axes = [] # List of PlotAxis to be updated by subclasses
@@ -401,6 +425,36 @@ class GLPlotFrame(object):
GRID_ALL_TICKS = (GRID_MAIN_TICKS + GRID_SUB_TICKS)
@property
+ def foregroundColor(self):
+ """Color used for frame and labels"""
+ return self._foregroundColor
+
+ @foregroundColor.setter
+ def foregroundColor(self, color):
+ """Color used for frame and labels"""
+ assert len(color) == 4, \
+ "foregroundColor must have length 4, got {}".format(len(self._foregroundColor))
+ if self._foregroundColor != color:
+ self._foregroundColor = color
+ for axis in self.axes:
+ axis.foregroundColor = color
+ self._dirty()
+
+ @property
+ def gridColor(self):
+ """Color used for frame and labels"""
+ return self._gridColor
+
+ @gridColor.setter
+ def gridColor(self, color):
+ """Color used for frame and labels"""
+ assert len(color) == 4, \
+ "gridColor must have length 4, got {}".format(len(self._gridColor))
+ if self._gridColor != color:
+ self._gridColor = color
+ self._dirty()
+
+ @property
def displayed(self):
"""Whether axes and their labels are displayed or not (bool)"""
return self._displayed
@@ -522,6 +576,7 @@ class GLPlotFrame(object):
self.margins.right) // 2
yTitle = self.margins.top - self._TICK_LENGTH_IN_PIXELS
labels.append(Text2D(text=self.title,
+ color=self._foregroundColor,
x=xTitle,
y=yTitle,
align=CENTER,
@@ -556,7 +611,7 @@ class GLPlotFrame(object):
gl.glUniformMatrix4fv(prog.uniforms['matrix'], 1, gl.GL_TRUE,
matProj.astype(numpy.float32))
- gl.glUniform4f(prog.uniforms['color'], 0., 0., 0., 1.)
+ gl.glUniform4f(prog.uniforms['color'], *self._foregroundColor)
gl.glUniform1f(prog.uniforms['tickFactor'], 0.)
gl.glEnableVertexAttribArray(prog.attributes['position'])
@@ -590,7 +645,7 @@ class GLPlotFrame(object):
gl.glLineWidth(self._LINE_WIDTH)
gl.glUniformMatrix4fv(prog.uniforms['matrix'], 1, gl.GL_TRUE,
matProj.astype(numpy.float32))
- gl.glUniform4f(prog.uniforms['color'], 0.7, 0.7, 0.7, 1.)
+ gl.glUniform4f(prog.uniforms['color'], *self._gridColor)
gl.glUniform1f(prog.uniforms['tickFactor'], 0.) # 1/2.) # 1/tickLen
gl.glEnableVertexAttribArray(prog.attributes['position'])
@@ -606,15 +661,21 @@ class GLPlotFrame(object):
# GLPlotFrame2D ###############################################################
class GLPlotFrame2D(GLPlotFrame):
- def __init__(self, margins):
+ def __init__(self, margins, foregroundColor, gridColor):
"""
:param margins: The margins around plot area for axis and labels.
:type margins: dict with 'left', 'right', 'top', 'bottom' keys and
values as ints.
+ :param foregroundColor: color used for the frame and labels.
+ :type foregroundColor: tuple with RGBA values ranging from 0.0 to 1.0
+ :param gridColor: color used for grid lines.
+ :type gridColor: tuple RGBA with RGBA values ranging from 0.0 to 1.0
+
"""
- super(GLPlotFrame2D, self).__init__(margins)
+ super(GLPlotFrame2D, self).__init__(margins, foregroundColor, gridColor)
self.axes.append(PlotAxis(self,
tickLength=(0., -5.),
+ foregroundColor=self._foregroundColor,
labelAlign=CENTER, labelVAlign=TOP,
titleAlign=CENTER, titleVAlign=TOP,
titleRotate=0,
@@ -624,6 +685,7 @@ class GLPlotFrame2D(GLPlotFrame):
self.axes.append(PlotAxis(self,
tickLength=(5., 0.),
+ foregroundColor=self._foregroundColor,
labelAlign=RIGHT, labelVAlign=CENTER,
titleAlign=CENTER, titleVAlign=BOTTOM,
titleRotate=ROTATE_270,
@@ -632,6 +694,7 @@ class GLPlotFrame2D(GLPlotFrame):
self._y2Axis = PlotAxis(self,
tickLength=(-5., 0.),
+ foregroundColor=self._foregroundColor,
labelAlign=LEFT, labelVAlign=CENTER,
titleAlign=CENTER, titleVAlign=TOP,
titleRotate=ROTATE_270,
@@ -825,23 +888,6 @@ class GLPlotFrame2D(GLPlotFrame):
_logger.info('yMax: warning log10(%f)', y2Max)
y2Max = 0.
- # Non-orthogonal axes
- if self.baseVectors != self.DEFAULT_BASE_VECTORS:
- (xx, xy), (yx, yy) = self.baseVectors
- skew_mat = numpy.array(((xx, yx), (xy, yy)))
-
- corners = [(xMin, yMin), (xMin, yMax),
- (xMax, yMin), (xMax, yMax),
- (xMin, y2Min), (xMin, y2Max),
- (xMax, y2Min), (xMax, y2Max)]
-
- corners = numpy.array(
- [numpy.dot(skew_mat, corner) for corner in corners],
- dtype=numpy.float32)
- xMin, xMax = corners[:, 0].min(), corners[:, 0].max()
- yMin, yMax = corners[0:4, 1].min(), corners[0:4, 1].max()
- y2Min, y2Max = corners[4:, 1].min(), corners[4:, 1].max()
-
self._transformedDataRanges = self._DataRanges(
(xMin, xMax), (yMin, yMax), (y2Min, y2Max))
@@ -861,16 +907,6 @@ class GLPlotFrame2D(GLPlotFrame):
mat = mat4Ortho(xMin, xMax, yMax, yMin, 1, -1)
else:
mat = mat4Ortho(xMin, xMax, yMin, yMax, 1, -1)
-
- # Non-orthogonal axes
- if self.baseVectors != self.DEFAULT_BASE_VECTORS:
- (xx, xy), (yx, yy) = self.baseVectors
- mat = numpy.dot(mat, numpy.array((
- (xx, yx, 0., 0.),
- (xy, yy, 0., 0.),
- (0., 0., 1., 0.),
- (0., 0., 0., 1.)), dtype=numpy.float64))
-
self._transformedDataProjMat = mat
return self._transformedDataProjMat
@@ -890,16 +926,6 @@ class GLPlotFrame2D(GLPlotFrame):
mat = mat4Ortho(xMin, xMax, y2Max, y2Min, 1, -1)
else:
mat = mat4Ortho(xMin, xMax, y2Min, y2Max, 1, -1)
-
- # Non-orthogonal axes
- if self.baseVectors != self.DEFAULT_BASE_VECTORS:
- (xx, xy), (yx, yy) = self.baseVectors
- mat = numpy.dot(mat, numpy.matrix((
- (xx, yx, 0., 0.),
- (xy, yy, 0., 0.),
- (0., 0., 1., 0.),
- (0., 0., 0., 1.)), dtype=numpy.float64))
-
self._transformedDataY2ProjMat = mat
return self._transformedDataY2ProjMat
@@ -1114,3 +1140,17 @@ class GLPlotFrame2D(GLPlotFrame):
vertices = numpy.append(vertices, extraVertices, axis=0)
self._renderResources = (vertices, gridVertices, labels)
+
+ @property
+ def foregroundColor(self):
+ """Color used for frame and labels"""
+ return self._foregroundColor
+
+ @foregroundColor.setter
+ def foregroundColor(self, color):
+ """Color used for frame and labels"""
+ assert len(color) == 4, \
+ "foregroundColor must have length 4, got {}".format(len(self._foregroundColor))
+ if self._foregroundColor != color:
+ self._y2Axis.foregroundColor = color
+ GLPlotFrame.foregroundColor.fset(self, color) # call parent property
diff --git a/silx/gui/plot/backends/glutils/GLSupport.py b/silx/gui/plot/backends/glutils/GLSupport.py
index 18c5eb7..da6dffa 100644
--- a/silx/gui/plot/backends/glutils/GLSupport.py
+++ b/silx/gui/plot/backends/glutils/GLSupport.py
@@ -60,16 +60,12 @@ def buildFillMaskIndices(nIndices, dtype=None):
return indices
-class Shape2D(object):
+class FilledShape2D(object):
_NO_HATCH = 0
_HATCH_STEP = 20
- def __init__(self, points, fill='solid', stroke=True,
- fillColor=(0., 0., 0., 1.), strokeColor=(0., 0., 0., 1.),
- strokeClosed=True):
+ def __init__(self, points, style='solid', color=(0., 0., 0., 1.)):
self.vertices = numpy.array(points, dtype=numpy.float32, copy=False)
- self.strokeClosed = strokeClosed
-
self._indices = buildFillMaskIndices(len(self.vertices))
tVertex = numpy.transpose(self.vertices)
@@ -81,28 +77,16 @@ class Shape2D(object):
self._xMin, self._xMax = xMin, xMax
self._yMin, self._yMax = yMin, yMax
- self.fill = fill
- self.fillColor = fillColor
- self.stroke = stroke
- self.strokeColor = strokeColor
-
- @property
- def xMin(self):
- return self._xMin
-
- @property
- def xMax(self):
- return self._xMax
-
- @property
- def yMin(self):
- return self._yMin
+ self.style = style
+ self.color = color
- @property
- def yMax(self):
- return self._yMax
+ def render(self, posAttrib, colorUnif, hatchStepUnif):
+ assert self.style in ('hatch', 'solid')
+ gl.glUniform4f(colorUnif, *self.color)
+ step = self._HATCH_STEP if self.style == 'hatch' else self._NO_HATCH
+ gl.glUniform1i(hatchStepUnif, step)
- def prepareFillMask(self, posAttrib):
+ # Prepare fill mask
gl.glEnableVertexAttribArray(posAttrib)
gl.glVertexAttribPointer(posAttrib,
2,
@@ -126,9 +110,6 @@ class Shape2D(object):
gl.glColorMask(gl.GL_TRUE, gl.GL_TRUE, gl.GL_TRUE, gl.GL_TRUE)
gl.glDepthMask(gl.GL_TRUE)
- def renderFill(self, posAttrib):
- self.prepareFillMask(posAttrib)
-
gl.glVertexAttribPointer(posAttrib,
2,
gl.GL_FLOAT,
@@ -138,30 +119,6 @@ class Shape2D(object):
gl.glDisable(gl.GL_STENCIL_TEST)
- def renderStroke(self, posAttrib):
- gl.glEnableVertexAttribArray(posAttrib)
- gl.glVertexAttribPointer(posAttrib,
- 2,
- gl.GL_FLOAT,
- gl.GL_FALSE,
- 0, self.vertices)
- gl.glLineWidth(1)
- drawMode = gl.GL_LINE_LOOP if self.strokeClosed else gl.GL_LINE_STRIP
- gl.glDrawArrays(drawMode, 0, len(self.vertices))
-
- def render(self, posAttrib, colorUnif, hatchStepUnif):
- assert self.fill in ['hatch', 'solid', None]
- if self.fill is not None:
- gl.glUniform4f(colorUnif, *self.fillColor)
- step = self._HATCH_STEP if self.fill == 'hatch' else self._NO_HATCH
- gl.glUniform1i(hatchStepUnif, step)
- self.renderFill(posAttrib)
-
- if self.stroke:
- gl.glUniform4f(colorUnif, *self.strokeColor)
- gl.glUniform1i(hatchStepUnif, self._NO_HATCH)
- self.renderStroke(posAttrib)
-
# matrix ######################################################################
diff --git a/silx/gui/plot/items/__init__.py b/silx/gui/plot/items/__init__.py
index e7957ac..f829f78 100644
--- a/silx/gui/plot/items/__init__.py
+++ b/silx/gui/plot/items/__init__.py
@@ -36,7 +36,7 @@ from .core import (Item, LabelsMixIn, DraggableMixIn, ColormapMixIn, # noqa
SymbolMixIn, ColorMixIn, YAxisMixIn, FillMixIn, # noqa
AlphaMixIn, LineMixIn, ItemChangedType) # noqa
from .complex import ImageComplexData # noqa
-from .curve import Curve # noqa
+from .curve import Curve, CurveStyle # noqa
from .histogram import Histogram # noqa
from .image import ImageBase, ImageData, ImageRgba, MaskImageData # noqa
from .shape import Shape # noqa
diff --git a/silx/gui/plot/items/axis.py b/silx/gui/plot/items/axis.py
index 3d9fe14..8ea5c7a 100644
--- a/silx/gui/plot/items/axis.py
+++ b/silx/gui/plot/items/axis.py
@@ -27,16 +27,16 @@
__authors__ = ["V. Valls"]
__license__ = "MIT"
-__date__ = "06/12/2017"
+__date__ = "22/11/2018"
import datetime as dt
+import enum
import logging
import dateutil.tz
from ... import qt
-from silx.third_party import enum
_logger = logging.getLogger(__name__)
@@ -448,6 +448,8 @@ class YAxis(Axis):
False for Y axis going from bottom to top
"""
flag = bool(flag)
+ if self.isInverted() == flag:
+ return
self._getBackend().setYAxisInverted(flag)
self._getPlot()._setDirtyPlot()
self.sigInvertedChanged.emit(flag)
diff --git a/silx/gui/plot/items/complex.py b/silx/gui/plot/items/complex.py
index 535b0a9..7fffd77 100644
--- a/silx/gui/plot/items/complex.py
+++ b/silx/gui/plot/items/complex.py
@@ -1,7 +1,7 @@
# coding: utf-8
# /*##########################################################################
#
-# Copyright (c) 2017-2018 European Synchrotron Radiation Facility
+# Copyright (c) 2017-2019 European Synchrotron Radiation Facility
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
@@ -33,9 +33,9 @@ __date__ = "14/06/2018"
import logging
-import numpy
+import enum
-from silx.third_party import enum
+import numpy
from ...colors import Colormap
from .core import ColormapMixIn, ItemChangedType
@@ -137,7 +137,6 @@ class ImageComplexData(ImageBase, ColormapMixIn):
name='hsv',
vmin=-numpy.pi,
vmax=numpy.pi)
- phaseColormap.setEditable(False)
self._colormaps = { # Default colormaps for all modes
self.Mode.ABSOLUTE: colormap,
@@ -180,7 +179,6 @@ class ImageComplexData(ImageBase, ColormapMixIn):
colormap=colormap,
alpha=self.getAlpha())
-
def setVisualizationMode(self, mode):
"""Set the visualization mode to use.
diff --git a/silx/gui/plot/items/core.py b/silx/gui/plot/items/core.py
index e000751..bf3b719 100644
--- a/silx/gui/plot/items/core.py
+++ b/silx/gui/plot/items/core.py
@@ -1,7 +1,7 @@
# coding: utf-8
# /*##########################################################################
#
-# Copyright (c) 2017-2018 European Synchrotron Radiation Facility
+# Copyright (c) 2017-2019 European Synchrotron Radiation Facility
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
@@ -27,20 +27,23 @@
__authors__ = ["T. Vincent"]
__license__ = "MIT"
-__date__ = "14/06/2018"
+__date__ = "29/01/2019"
import collections
from copy import deepcopy
import logging
+import enum
import warnings
import weakref
+
import numpy
-from silx.third_party import six, enum
+import six
from ... import qt
from ... import colors
from ...colors import Colormap
+from silx import config
_logger = logging.getLogger(__name__)
@@ -82,6 +85,9 @@ class ItemChangedType(enum.Enum):
COLOR = 'colorChanged'
"""Item's color changed flag."""
+ LINE_BG_COLOR = 'lineBgColorChanged'
+ """Item's line background color changed flag."""
+
YAXIS = 'yAxisChanged'
"""Item's Y axis binding changed flag."""
@@ -411,10 +417,12 @@ class ColormapMixIn(ItemMixInBase):
return self._colormap
def setColormap(self, colormap):
- """Set the colormap of this image
+ """Set the colormap of this item
:param silx.gui.colors.Colormap colormap: colormap description
"""
+ if self._colormap is colormap:
+ return
if isinstance(colormap, dict):
colormap = Colormap._fromDict(colormap)
@@ -433,10 +441,10 @@ class ColormapMixIn(ItemMixInBase):
class SymbolMixIn(ItemMixInBase):
"""Mix-in class for items with symbol type"""
- _DEFAULT_SYMBOL = ''
+ _DEFAULT_SYMBOL = None
"""Default marker of the item"""
- _DEFAULT_SYMBOL_SIZE = 6.0
+ _DEFAULT_SYMBOL_SIZE = config.DEFAULT_PLOT_SYMBOL_SIZE
"""Default marker size of the item"""
_SUPPORTED_SYMBOLS = collections.OrderedDict((
@@ -451,8 +459,15 @@ class SymbolMixIn(ItemMixInBase):
"""Dict of supported symbols"""
def __init__(self):
- self._symbol = self._DEFAULT_SYMBOL
- self._symbol_size = self._DEFAULT_SYMBOL_SIZE
+ if self._DEFAULT_SYMBOL is None: # Use default from config
+ self._symbol = config.DEFAULT_PLOT_SYMBOL
+ else:
+ self._symbol = self._DEFAULT_SYMBOL
+
+ if self._DEFAULT_SYMBOL_SIZE is None: # Use default from config
+ self._symbol_size = config.DEFAULT_PLOT_SYMBOL_SIZE
+ else:
+ self._symbol_size = self._DEFAULT_SYMBOL_SIZE
@classmethod
def getSupportedSymbols(cls):
@@ -892,14 +907,14 @@ class Points(Item, SymbolMixIn, AlphaMixIn):
# use the getData class method because instance method can be
# overloaded to return additional arrays
data = Points.getData(self, copy=False,
- displayed=True)
+ displayed=True)
if len(data) == 5:
# hack to avoid duplicating caching mechanism in Scatter
# (happens when cached data is used, caching done using
# Scatter._logFilterData)
- x, y, xerror, yerror = data[0], data[1], data[3], data[4]
+ x, y, _xerror, _yerror = data[0], data[1], data[3], data[4]
else:
- x, y, xerror, yerror = data
+ x, y, _xerror, _yerror = data
self._boundsCache[(xPositive, yPositive)] = (
numpy.nanmin(x),
diff --git a/silx/gui/plot/items/curve.py b/silx/gui/plot/items/curve.py
index 80d9dea..79def55 100644
--- a/silx/gui/plot/items/curve.py
+++ b/silx/gui/plot/items/curve.py
@@ -1,7 +1,7 @@
# coding: utf-8
# /*##########################################################################
#
-# Copyright (c) 2017-2018 European Synchrotron Radiation Facility
+# Copyright (c) 2017-2019 European Synchrotron Radiation Facility
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
@@ -31,9 +31,10 @@ __date__ = "24/04/2018"
import logging
+
import numpy
+import six
-from silx.third_party import six
from ....utils.deprecation import deprecated
from ... import colors
from .core import (Points, LabelsMixIn, ColorMixIn, YAxisMixIn,
diff --git a/silx/gui/plot/items/histogram.py b/silx/gui/plot/items/histogram.py
index 389e8a6..a1d6586 100644
--- a/silx/gui/plot/items/histogram.py
+++ b/silx/gui/plot/items/histogram.py
@@ -1,7 +1,7 @@
# coding: utf-8
# /*##########################################################################
#
-# Copyright (c) 2017 European Synchrotron Radiation Facility
+# Copyright (c) 2017-2019 European Synchrotron Radiation Facility
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
@@ -197,13 +197,13 @@ class Histogram(Item, AlphaMixIn, ColorMixIn, FillMixIn,
values[clipped_values] = numpy.nan
- if xPositive or yPositive:
+ if yPositive:
return (numpy.nanmin(edges),
numpy.nanmax(edges),
numpy.nanmin(values),
numpy.nanmax(values))
- else: # No log scale, include 0 in bounds
+ else: # No log scale on y axis, include 0 in bounds
return (numpy.nanmin(edges),
numpy.nanmax(edges),
min(0, numpy.nanmin(values)),
diff --git a/silx/gui/plot/items/roi.py b/silx/gui/plot/items/roi.py
index f55ef91..0169439 100644
--- a/silx/gui/plot/items/roi.py
+++ b/silx/gui/plot/items/roi.py
@@ -65,7 +65,7 @@ class RegionOfInterest(qt.QObject):
# Avoid circular dependancy
from ..tools import roi as roi_tools
assert parent is None or isinstance(parent, roi_tools.RegionOfInterestManager)
- super(RegionOfInterest, self).__init__(parent)
+ qt.QObject.__init__(self, parent)
self._color = rgba('red')
self._items = WeakList()
self._editAnchors = WeakList()
@@ -108,7 +108,7 @@ class RegionOfInterest(qt.QObject):
return qt.QColor.fromRgbF(*self._color)
def _getAnchorColor(self, color):
- """Returns the anchor color from the base ROI color
+ """Returns the anchor color from the base ROI color
:param Union[numpy.array,Tuple,List]: color
:rtype: Union[numpy.array,Tuple,List]
@@ -209,7 +209,7 @@ class RegionOfInterest(qt.QObject):
def setFirstShapePoints(self, points):
""""Initialize the ROI using the points from the first interaction.
- This interaction is constains by the plot API and only supports few
+ This interaction is constrained by the plot API and only supports few
shapes.
"""
points = self._createControlPointsFromFirstShape(points)
@@ -410,6 +410,13 @@ class RegionOfInterest(qt.QObject):
plot._remove(item)
self._labelItem = None
+ def _updated(self, event=None, checkVisibility=True):
+ """Implement Item mix-in update method by updating the plot items
+
+ See :class:`~silx.gui.plot.items.Item._updated`
+ """
+ self._createPlotItems()
+
def __str__(self):
"""Returns parameters of the ROI as a string."""
points = self._getControlPoints()
@@ -417,7 +424,7 @@ class RegionOfInterest(qt.QObject):
return "%s(%s)" % (self.__class__.__name__, params)
-class PointROI(RegionOfInterest):
+class PointROI(RegionOfInterest, items.SymbolMixIn):
"""A ROI identifying a point in a 2D plot."""
_kind = "Point"
@@ -426,6 +433,10 @@ class PointROI(RegionOfInterest):
_plotShape = "point"
"""Plot shape which is used for the first interaction"""
+ def __init__(self, parent=None):
+ items.SymbolMixIn.__init__(self)
+ RegionOfInterest.__init__(self, parent=parent)
+
def getPosition(self):
"""Returns the position of this ROI
@@ -458,6 +469,8 @@ class PointROI(RegionOfInterest):
marker.setPosition(points[0][0], points[0][1])
marker.setText(self.getLabel())
marker.setColor(rgba(self.getColor()))
+ marker.setSymbol(self.getSymbol())
+ marker.setSymbolSize(self.getSymbolSize())
marker._setDraggable(False)
return [marker]
@@ -466,6 +479,8 @@ class PointROI(RegionOfInterest):
marker.setPosition(points[0][0], points[0][1])
marker.setText(self.getLabel())
marker._setDraggable(self.isEditable())
+ marker.setSymbol(self.getSymbol())
+ marker.setSymbolSize(self.getSymbolSize())
return [marker]
def __str__(self):
@@ -474,7 +489,7 @@ class PointROI(RegionOfInterest):
return "%s(%s)" % (self.__class__.__name__, params)
-class LineROI(RegionOfInterest):
+class LineROI(RegionOfInterest, items.LineMixIn):
"""A ROI identifying a line in a 2D plot.
This ROI provides 1 anchor for each boundary of the line, plus an center
@@ -487,6 +502,10 @@ class LineROI(RegionOfInterest):
_plotShape = "line"
"""Plot shape which is used for the first interaction"""
+ def __init__(self, parent=None):
+ items.LineMixIn.__init__(self)
+ RegionOfInterest.__init__(self, parent=parent)
+
def _createControlPointsFromFirstShape(self, points):
center = numpy.mean(points, axis=0)
controlPoints = numpy.array([points[0], points[1], center])
@@ -535,6 +554,8 @@ class LineROI(RegionOfInterest):
item.setColor(rgba(self.getColor()))
item.setFill(False)
item.setOverlay(True)
+ item.setLineStyle(self.getLineStyle())
+ item.setLineWidth(self.getLineWidth())
return [item]
def _createAnchorItems(self, points):
@@ -582,7 +603,7 @@ class LineROI(RegionOfInterest):
return "%s(%s)" % (self.__class__.__name__, params)
-class HorizontalLineROI(RegionOfInterest):
+class HorizontalLineROI(RegionOfInterest, items.LineMixIn):
"""A ROI identifying an horizontal line in a 2D plot."""
_kind = "HLine"
@@ -591,6 +612,10 @@ class HorizontalLineROI(RegionOfInterest):
_plotShape = "hline"
"""Plot shape which is used for the first interaction"""
+ def __init__(self, parent=None):
+ items.LineMixIn.__init__(self)
+ RegionOfInterest.__init__(self, parent=parent)
+
def _createControlPointsFromFirstShape(self, points):
points = numpy.array([(float('nan'), points[0, 1])],
dtype=numpy.float64)
@@ -636,6 +661,8 @@ class HorizontalLineROI(RegionOfInterest):
marker.setText(self.getLabel())
marker.setColor(rgba(self.getColor()))
marker._setDraggable(False)
+ marker.setLineWidth(self.getLineWidth())
+ marker.setLineStyle(self.getLineStyle())
return [marker]
def _createAnchorItems(self, points):
@@ -643,6 +670,8 @@ class HorizontalLineROI(RegionOfInterest):
marker.setPosition(points[0][0], points[0][1])
marker.setText(self.getLabel())
marker._setDraggable(self.isEditable())
+ marker.setLineWidth(self.getLineWidth())
+ marker.setLineStyle(self.getLineStyle())
return [marker]
def __str__(self):
@@ -651,7 +680,7 @@ class HorizontalLineROI(RegionOfInterest):
return "%s(%s)" % (self.__class__.__name__, params)
-class VerticalLineROI(RegionOfInterest):
+class VerticalLineROI(RegionOfInterest, items.LineMixIn):
"""A ROI identifying a vertical line in a 2D plot."""
_kind = "VLine"
@@ -660,6 +689,10 @@ class VerticalLineROI(RegionOfInterest):
_plotShape = "vline"
"""Plot shape which is used for the first interaction"""
+ def __init__(self, parent=None):
+ items.LineMixIn.__init__(self)
+ RegionOfInterest.__init__(self, parent=parent)
+
def _createControlPointsFromFirstShape(self, points):
points = numpy.array([(points[0, 0], float('nan'))],
dtype=numpy.float64)
@@ -705,6 +738,8 @@ class VerticalLineROI(RegionOfInterest):
marker.setText(self.getLabel())
marker.setColor(rgba(self.getColor()))
marker._setDraggable(False)
+ marker.setLineWidth(self.getLineWidth())
+ marker.setLineStyle(self.getLineStyle())
return [marker]
def _createAnchorItems(self, points):
@@ -712,6 +747,8 @@ class VerticalLineROI(RegionOfInterest):
marker.setPosition(points[0][0], points[0][1])
marker.setText(self.getLabel())
marker._setDraggable(self.isEditable())
+ marker.setLineWidth(self.getLineWidth())
+ marker.setLineStyle(self.getLineStyle())
return [marker]
def __str__(self):
@@ -720,7 +757,7 @@ class VerticalLineROI(RegionOfInterest):
return "%s(%s)" % (self.__class__.__name__, params)
-class RectangleROI(RegionOfInterest):
+class RectangleROI(RegionOfInterest, items.LineMixIn):
"""A ROI identifying a rectangle in a 2D plot.
This ROI provides 1 anchor for each corner, plus an anchor in the
@@ -733,6 +770,10 @@ class RectangleROI(RegionOfInterest):
_plotShape = "rectangle"
"""Plot shape which is used for the first interaction"""
+ def __init__(self, parent=None):
+ items.LineMixIn.__init__(self)
+ RegionOfInterest.__init__(self, parent=parent)
+
def _createControlPointsFromFirstShape(self, points):
point0 = points[0]
point1 = points[1]
@@ -838,6 +879,8 @@ class RectangleROI(RegionOfInterest):
item.setColor(rgba(self.getColor()))
item.setFill(False)
item.setOverlay(True)
+ item.setLineStyle(self.getLineStyle())
+ item.setLineWidth(self.getLineWidth())
return [item]
def _createAnchorItems(self, points):
@@ -894,7 +937,7 @@ class RectangleROI(RegionOfInterest):
return "%s(%s)" % (self.__class__.__name__, params)
-class PolygonROI(RegionOfInterest):
+class PolygonROI(RegionOfInterest, items.LineMixIn):
"""A ROI identifying a closed polygon in a 2D plot.
This ROI provides 1 anchor for each point of the polygon.
@@ -906,6 +949,10 @@ class PolygonROI(RegionOfInterest):
_plotShape = "polygon"
"""Plot shape which is used for the first interaction"""
+ def __init__(self, parent=None):
+ items.LineMixIn.__init__(self)
+ RegionOfInterest.__init__(self, parent=parent)
+
def getPoints(self):
"""Returns the list of the points of this polygon.
@@ -948,6 +995,8 @@ class PolygonROI(RegionOfInterest):
item.setColor(rgba(self.getColor()))
item.setFill(False)
item.setOverlay(True)
+ item.setLineStyle(self.getLineStyle())
+ item.setLineWidth(self.getLineWidth())
return [item]
def _createAnchorItems(self, points):
@@ -967,7 +1016,7 @@ class PolygonROI(RegionOfInterest):
return "%s(%s)" % (self.__class__.__name__, params)
-class ArcROI(RegionOfInterest):
+class ArcROI(RegionOfInterest, items.LineMixIn):
"""A ROI identifying an arc of a circle with a width.
This ROI provides 3 anchors to control the curvature, 1 anchor to control
@@ -986,6 +1035,7 @@ class ArcROI(RegionOfInterest):
'startAngle', 'endAngle'])
def __init__(self, parent=None):
+ items.LineMixIn.__init__(self)
RegionOfInterest.__init__(self, parent=parent)
self._geometry = None
@@ -1357,6 +1407,8 @@ class ArcROI(RegionOfInterest):
item.setColor(rgba(self.getColor()))
item.setFill(False)
item.setOverlay(True)
+ item.setLineStyle(self.getLineStyle())
+ item.setLineWidth(self.getLineWidth())
return [item]
def _createAnchorItems(self, points):
diff --git a/silx/gui/plot/items/scatter.py b/silx/gui/plot/items/scatter.py
index acc74b4..707dd3d 100644
--- a/silx/gui/plot/items/scatter.py
+++ b/silx/gui/plot/items/scatter.py
@@ -1,7 +1,7 @@
# coding: utf-8
# /*##########################################################################
#
-# Copyright (c) 2017-2018 European Synchrotron Radiation Facility
+# Copyright (c) 2017-2019 European Synchrotron Radiation Facility
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
@@ -46,9 +46,6 @@ class Scatter(Points, ColormapMixIn):
_DEFAULT_SELECTABLE = True
"""Default selectable state for scatter plots"""
- _DEFAULT_SYMBOL = 'o'
- """Default symbol of the scatter plots"""
-
def __init__(self):
Points.__init__(self)
ColormapMixIn.__init__(self)
diff --git a/silx/gui/plot/items/shape.py b/silx/gui/plot/items/shape.py
index 65b26a1..9fc1306 100644
--- a/silx/gui/plot/items/shape.py
+++ b/silx/gui/plot/items/shape.py
@@ -1,7 +1,7 @@
# coding: utf-8
# /*##########################################################################
#
-# Copyright (c) 2017 European Synchrotron Radiation Facility
+# Copyright (c) 2017-2019 European Synchrotron Radiation Facility
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
@@ -27,14 +27,16 @@
__authors__ = ["T. Vincent"]
__license__ = "MIT"
-__date__ = "17/05/2017"
+__date__ = "21/12/2018"
import logging
import numpy
+import six
-from .core import (Item, ColorMixIn, FillMixIn, ItemChangedType)
+from ... import colors
+from .core import Item, ColorMixIn, FillMixIn, ItemChangedType, LineMixIn
_logger = logging.getLogger(__name__)
@@ -42,7 +44,7 @@ _logger = logging.getLogger(__name__)
# TODO probably make one class for each kind of shape
# TODO check fill:polygon/polyline + fill = duplicated
-class Shape(Item, ColorMixIn, FillMixIn):
+class Shape(Item, ColorMixIn, FillMixIn, LineMixIn):
"""Description of a shape item
:param str type_: The type of shape in:
@@ -53,10 +55,12 @@ class Shape(Item, ColorMixIn, FillMixIn):
Item.__init__(self)
ColorMixIn.__init__(self)
FillMixIn.__init__(self)
+ LineMixIn.__init__(self)
self._overlay = False
assert type_ in ('hline', 'polygon', 'rectangle', 'vline', 'polylines')
self._type = type_
self._points = ()
+ self._lineBgColor = None
self._handle = None
@@ -71,7 +75,10 @@ class Shape(Item, ColorMixIn, FillMixIn):
color=self.getColor(),
fill=self.isFill(),
overlay=self.isOverlay(),
- z=self.getZValue())
+ z=self.getZValue(),
+ linestyle=self.getLineStyle(),
+ linewidth=self.getLineWidth(),
+ linebgcolor=self.getLineBgColor())
def isOverlay(self):
"""Return true if shape is drawn as an overlay
@@ -119,3 +126,31 @@ class Shape(Item, ColorMixIn, FillMixIn):
"""
self._points = numpy.array(points, copy=copy)
self._updated(ItemChangedType.DATA)
+
+ def getLineBgColor(self):
+ """Returns the RGBA color of the item
+ :rtype: 4-tuple of float in [0, 1] or array of colors
+ """
+ return self._lineBgColor
+
+ def setLineBgColor(self, color, copy=True):
+ """Set item color
+ :param color: color(s) to be used
+ :type color: str ("#RRGGBB") or (npoints, 4) unsigned byte array or
+ one of the predefined color names defined in colors.py
+ :param bool copy: True (Default) to get a copy,
+ False to use internal representation (do not modify!)
+ """
+ if color is not None:
+ if isinstance(color, six.string_types):
+ color = colors.rgba(color)
+ else:
+ color = numpy.array(color, copy=copy)
+ # TODO more checks + improve color array support
+ if color.ndim == 1: # Single RGBA color
+ color = colors.rgba(color)
+ else: # Array of colors
+ assert color.ndim == 2
+
+ self._lineBgColor = color
+ self._updated(ItemChangedType.LINE_BG_COLOR)
diff --git a/silx/gui/plot/matplotlib/Colormap.py b/silx/gui/plot/matplotlib/Colormap.py
index 772a473..38f3b55 100644
--- a/silx/gui/plot/matplotlib/Colormap.py
+++ b/silx/gui/plot/matplotlib/Colormap.py
@@ -29,7 +29,13 @@ from matplotlib.colors import ListedColormap
import matplotlib.colors
import matplotlib.cm
import silx.resources
-from silx.utils.deprecation import deprecated
+from silx.utils.deprecation import deprecated, deprecated_warning
+
+
+deprecated_warning(type_='module',
+ name=__file__,
+ replacement='silx.gui.colors.Colormap',
+ since_version='0.10.0')
_logger = logging.getLogger(__name__)
@@ -46,25 +52,30 @@ _CMAPS = {}
@property
+@deprecated(since_version='0.10.0')
def magma():
return getColormap('magma')
@property
+@deprecated(since_version='0.10.0')
def inferno():
return getColormap('inferno')
@property
+@deprecated(since_version='0.10.0')
def plasma():
return getColormap('plasma')
@property
+@deprecated(since_version='0.10.0')
def viridis():
return getColormap('viridis')
+@deprecated(since_version='0.10.0')
def getColormap(name):
"""Returns matplotlib colormap corresponding to given name
@@ -143,6 +154,7 @@ def getColormap(name):
return matplotlib.cm.get_cmap(name)
+@deprecated(since_version='0.10.0')
def getScalarMappable(colormap, data=None):
"""Returns matplotlib ScalarMappable corresponding to colormap
@@ -223,6 +235,8 @@ def applyColormapToData(data, colormap):
return rgbaImage
+@deprecated(replacement='silx.colors.Colormap.getSupportedColormaps',
+ since_version='0.10.0')
def getSupportedColormaps():
"""Get the supported colormap names as a tuple of str.
"""
diff --git a/silx/gui/plot/stats/stats.py b/silx/gui/plot/stats/stats.py
index a753989..ad61536 100644
--- a/silx/gui/plot/stats/stats.py
+++ b/silx/gui/plot/stats/stats.py
@@ -1,7 +1,7 @@
# coding: utf-8
# /*##########################################################################
#
-# Copyright (c) 2017-2018 European Synchrotron Radiation Facility
+# Copyright (c) 2017-2019 European Synchrotron Radiation Facility
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
@@ -30,15 +30,15 @@ __license__ = "MIT"
__date__ = "06/06/2018"
-import numpy
-from silx.gui.plot.items.curve import Curve as CurveItem
-from silx.gui.plot.items.image import ImageBase as ImageItem
-from silx.gui.plot.items.scatter import Scatter as ScatterItem
-from silx.gui.plot.items.histogram import Histogram as HistogramItem
-from silx.math.combo import min_max
from collections import OrderedDict
import logging
+import numpy
+
+from .. import items
+from ....math.combo import min_max
+
+
logger = logging.getLogger(__name__)
@@ -62,7 +62,7 @@ class Stats(OrderedDict):
def calculate(self, item, plot, onlimits):
"""
- Call all :class:`Stat` object registred and return the result of the
+ Call all :class:`Stat` object registered and return the result of the
computation.
:param item: the item for which we want statistics
@@ -72,17 +72,29 @@ class Stats(OrderedDict):
:return dict: dictionary with :class:`Stat` name as ket and result
of the calculation as value
"""
- res = {}
- if isinstance(item, CurveItem):
+ context = None
+ # Check for PlotWidget items
+ if isinstance(item, items.Curve):
context = _CurveContext(item, plot, onlimits)
- elif isinstance(item, ImageItem):
+ elif isinstance(item, items.ImageData):
context = _ImageContext(item, plot, onlimits)
- elif isinstance(item, ScatterItem):
+ elif isinstance(item, items.Scatter):
context = _ScatterContext(item, plot, onlimits)
- elif isinstance(item, HistogramItem):
+ elif isinstance(item, items.Histogram):
context = _HistogramContext(item, plot, onlimits)
else:
- raise ValueError('Item type not managed')
+ # Check for SceneWidget items
+ from ...plot3d import items as items3d # Lazy import
+
+ if isinstance(item, (items3d.Scatter2D, items3d.Scatter3D)):
+ context = _plot3DScatterContext(item, plot, onlimits)
+ elif isinstance(item, (items3d.ImageData, items3d.ScalarField3D)):
+ context = _plot3DArrayContext(item, plot, onlimits)
+
+ if context is None:
+ raise ValueError('Item type not managed')
+
+ res = {}
for statName, stat in list(self.items()):
if context.kind not in stat.compatibleKinds:
logger.debug('kind %s not managed by statistic %s'
@@ -124,12 +136,54 @@ class _StatsContext(object):
self.min = None
self.max = None
self.data = None
+
self.values = None
+ """The array of data"""
+
+ self.axes = None
+ """A list of array of position on each axis.
+
+ If the signal is an array,
+ then each axis has the length of that dimension,
+ and the order is (z, y, x) (i.e., as the array shape).
+ If the signal is not an array,
+ then each axis has the same length as the signal,
+ and the order is (x, y, z).
+ """
+
self.createContext(item, plot, onlimits)
def createContext(self, item, plot, onlimits):
raise NotImplementedError("Base class")
+ def isStructuredData(self):
+ """Returns True if data as an array-like structure.
+
+ :rtype: bool
+ """
+ if self.values is None or self.axes is None:
+ return False
+
+ if numpy.prod([len(axis) for axis in self.axes]) == self.values.size:
+ return True
+ else:
+ # Make sure there is the right number of value in axes
+ for axis in self.axes:
+ assert len(axis) == self.values.size
+ return False
+
+ def isScalarData(self):
+ """Returns True if data is a scalar.
+
+ :rtype: bool
+ """
+ if self.values is None or self.axes is None:
+ return False
+ if self.isStructuredData():
+ return len(self.axes) == self.values.ndim
+ else:
+ return self.values.ndim == 1
+
class _CurveContext(_StatsContext):
"""
@@ -149,8 +203,9 @@ class _CurveContext(_StatsContext):
if onlimits:
minX, maxX = plot.getXAxis().getLimits()
- yData = yData[(minX <= xData) & (xData <= maxX)]
- xData = xData[(minX <= xData) & (xData <= maxX)]
+ mask = (minX <= xData) & (xData <= maxX)
+ yData = yData[mask]
+ xData = xData[mask]
self.xData = xData
self.yData = yData
@@ -160,11 +215,12 @@ class _CurveContext(_StatsContext):
self.min, self.max = None, None
self.data = (xData, yData)
self.values = yData
+ self.axes = (xData,)
class _HistogramContext(_StatsContext):
"""
- StatsContext for :class:`Curve`
+ StatsContext for :class:`Histogram`
:param item: the item for which we want to compute the context
:param plot: the plot containing the item
@@ -176,12 +232,13 @@ class _HistogramContext(_StatsContext):
plot=plot, onlimits=onlimits)
def createContext(self, item, plot, onlimits):
- xData, edges = item.getData(copy=True)[0:2]
- yData = item._revertComputeEdges(x=edges, histogramType=item.getAlignment())
+ yData, edges = item.getData(copy=True)[0:2]
+ xData = item._revertComputeEdges(x=edges, histogramType=item.getAlignment())
if onlimits:
minX, maxX = plot.getXAxis().getLimits()
- yData = yData[(minX <= xData) & (xData <= maxX)]
- xData = xData[(minX <= xData) & (xData <= maxX)]
+ mask = (minX <= xData) & (xData <= maxX)
+ yData = yData[mask]
+ xData = xData[mask]
self.xData = xData
self.yData = yData
@@ -191,11 +248,13 @@ class _HistogramContext(_StatsContext):
self.min, self.max = None, None
self.data = (xData, yData)
self.values = yData
+ self.axes = (xData,)
class _ScatterContext(_StatsContext):
- """
- StatsContext for :class:`Scatter`
+ """StatsContext scatter plots.
+
+ It supports :class:`~silx.gui.plot.items.Scatter`.
:param item: the item for which we want to compute the context
:param plot: the plot containing the item
@@ -207,11 +266,14 @@ class _ScatterContext(_StatsContext):
onlimits=onlimits)
def createContext(self, item, plot, onlimits):
- xData, yData, valueData, xerror, yerror = item.getData(copy=True)
- assert plot
+ valueData = item.getValueData(copy=True)
+ xData = item.getXData(copy=True)
+ yData = item.getYData(copy=True)
+
if onlimits:
minX, maxX = plot.getXAxis().getLimits()
minY, maxY = plot.getYAxis().getLimits()
+
# filter on X axis
valueData = valueData[(minX <= xData) & (xData <= maxX)]
yData = yData[(minX <= xData) & (xData <= maxX)]
@@ -220,17 +282,20 @@ class _ScatterContext(_StatsContext):
valueData = valueData[(minY <= yData) & (yData <= maxY)]
xData = xData[(minY <= yData) & (yData <= maxY)]
yData = yData[(minY <= yData) & (yData <= maxY)]
+
if len(valueData) > 0:
self.min, self.max = min_max(valueData)
else:
self.min, self.max = None, None
self.data = (xData, yData, valueData)
self.values = valueData
+ self.axes = (xData, yData)
class _ImageContext(_StatsContext):
- """
- StatsContext for :class:`ImageBase`
+ """StatsContext for images.
+
+ It supports :class:`~silx.gui.plot.items.ImageData`.
:param item: the item for which we want to compute the context
:param plot: the plot containing the item
@@ -244,7 +309,8 @@ class _ImageContext(_StatsContext):
def createContext(self, item, plot, onlimits):
self.origin = item.getOrigin()
self.scale = item.getScale()
- self.data = item.getData()
+
+ self.data = item.getData(copy=True)
if onlimits:
minX, maxX = plot.getXAxis().getLimits()
@@ -259,25 +325,88 @@ class _ImageContext(_StatsContext):
YMinBound = max(YMinBound, 0)
if XMaxBound <= XMinBound or YMaxBound <= YMinBound:
- return self.noDataSelected()
- data = item.getData()
- self.data = data[YMinBound:YMaxBound + 1, XMinBound:XMaxBound + 1]
- else:
- self.data = item.getData()
-
+ self.data = None
+ else:
+ self.data = self.data[YMinBound:YMaxBound + 1,
+ XMinBound:XMaxBound + 1]
if self.data.size > 0:
self.min, self.max = min_max(self.data)
else:
self.min, self.max = None, None
self.values = self.data
+ if self.values is not None:
+ self.axes = (self.origin[1] + self.scale[1] * numpy.arange(self.data.shape[0]),
+ self.origin[0] + self.scale[0] * numpy.arange(self.data.shape[1]))
+
+
+class _plot3DScatterContext(_StatsContext):
+ """StatsContext for 3D scatter plots.
+
+ It supports :class:`~silx.gui.plot3d.items.Scatter2D` and
+ :class:`~silx.gui.plot3d.items.Scatter3D`.
+
+ :param item: the item for which we want to compute the context
+ :param plot: the plot containing the item
+ :param bool onlimits: True if we want to apply statistic only on
+ visible data.
+ """
+ def __init__(self, item, plot, onlimits):
+ _StatsContext.__init__(self, kind='scatter', item=item, plot=plot,
+ onlimits=onlimits)
+
+ def createContext(self, item, plot, onlimits):
+ if onlimits:
+ raise RuntimeError("Unsupported plot %s" % str(plot))
+
+ values = item.getValueData(copy=False)
+
+ if values is not None and len(values) > 0:
+ self.values = values
+ axes = [item.getXData(copy=False), item.getYData(copy=False)]
+ if self.values.ndim == 3:
+ axes.append(item.getZData(copy=False))
+ self.axes = tuple(axes)
+
+ self.min, self.max = min_max(self.values)
+ else:
+ self.values = None
+ self.axes = None
+ self.min, self.max = None, None
+
+
+class _plot3DArrayContext(_StatsContext):
+ """StatsContext for 3D scalar field and data image.
+
+ It supports :class:`~silx.gui.plot3d.items.ScalarField3D` and
+ :class:`~silx.gui.plot3d.items.ImageData`.
+
+ :param item: the item for which we want to compute the context
+ :param plot: the plot containing the item
+ :param bool onlimits: True if we want to apply statistic only on
+ visible data.
+ """
+ def __init__(self, item, plot, onlimits):
+ _StatsContext.__init__(self, kind='image', item=item, plot=plot,
+ onlimits=onlimits)
+
+ def createContext(self, item, plot, onlimits):
+ if onlimits:
+ raise RuntimeError("Unsupported plot %s" % str(plot))
+
+ values = item.getData(copy=False)
+
+ if values is not None and len(values) > 0:
+ self.values = values
+ self.axes = tuple([numpy.arange(size) for size in self.values.shape])
+ self.min, self.max = min_max(self.values)
+ else:
+ self.values = None
+ self.axes = None
+ self.min, self.max = None, None
+
-BASIC_COMPATIBLE_KINDS = {
- 'curve': CurveItem,
- 'image': ImageItem,
- 'scatter': ScatterItem,
- 'histogram': HistogramItem,
-}
+BASIC_COMPATIBLE_KINDS = 'curve', 'image', 'scatter', 'histogram'
class StatBase(object):
@@ -285,9 +414,8 @@ class StatBase(object):
Base class for defining a statistic.
:param str name: the name of the statistic. Must be unique.
- :param compatibleKinds: the kind of items (curve, scatter...) for which
- the statistic apply.
- :rtype: List or tuple
+ :param List[str] compatibleKinds:
+ The kind of items (curve, scatter...) for which the statistic apply.
"""
def __init__(self, name, compatibleKinds=BASIC_COMPATIBLE_KINDS, description=None):
self.name = name
@@ -298,7 +426,7 @@ class StatBase(object):
"""
compute the statistic for the given :class:`StatsContext`
- :param context:
+ :param _StatsContext context:
:return dict: key is stat name, statistic computed is the dict value
"""
raise NotImplementedError('Base class')
@@ -307,7 +435,7 @@ class StatBase(object):
"""
If necessary add a tooltip for a stat kind
- :param str kinf: the kind of item the statistic is compute for.
+ :param str kind: the kind of item the statistic is compute for.
:return: tooltip or None if no tooltip
"""
return None
@@ -329,17 +457,18 @@ class Stat(StatBase):
self._fct = fct
def calculate(self, context):
- if context.kind in self.compatibleKinds:
- return self._fct(context.values)
+ if context.values is not None:
+ if context.kind in self.compatibleKinds:
+ return self._fct(context.values)
+ else:
+ raise ValueError('Kind %s not managed by %s'
+ '' % (context.kind, self.name))
else:
- raise ValueError('Kind %s not managed by %s'
- '' % (context.kind, self.name))
+ return None
class StatMin(StatBase):
- """
- Compute the minimal value on data
- """
+ """Compute the minimal value on data"""
def __init__(self):
StatBase.__init__(self, name='min')
@@ -348,9 +477,7 @@ class StatMin(StatBase):
class StatMax(StatBase):
- """
- Compute the maximal value on data
- """
+ """Compute the maximal value on data"""
def __init__(self):
StatBase.__init__(self, name='max')
@@ -359,9 +486,7 @@ class StatMax(StatBase):
class StatDelta(StatBase):
- """
- Compute the delta between minimal and maximal on data
- """
+ """Compute the delta between minimal and maximal on data"""
def __init__(self):
StatBase.__init__(self, name='delta')
@@ -369,123 +494,84 @@ class StatDelta(StatBase):
return context.max - context.min
-class StatCoordMin(StatBase):
- """
- Compute the first coordinates of the data minimal value
- """
+class _StatCoord(StatBase):
+ """Base class for argmin and argmax stats"""
+
+ def _indexToCoordinates(self, context, index):
+ """Returns the coordinates of data point at given index
+
+ If data is an array, coordinates are in reverse order from data shape.
+
+ :param _StatsContext context:
+ :param int index: Index in the flattened data array
+ :rtype: List[int]
+ """
+ if context.isStructuredData():
+ coordinates = []
+ for axis in reversed(context.axes):
+ coordinates.append(axis[index % len(axis)])
+ index = index // len(axis)
+ return tuple(coordinates)
+ else:
+ return tuple(axis[index] for axis in context.axes)
+
+
+class StatCoordMin(_StatCoord):
+ """Compute the coordinates of the first minimum value of the data"""
def __init__(self):
- StatBase.__init__(self, name='coords min')
+ _StatCoord.__init__(self, name='coords min')
def calculate(self, context):
- if context.kind in ('curve', 'histogram'):
- return context.xData[numpy.argmin(context.yData)]
- elif context.kind == 'scatter':
- xData, yData, valueData = context.data
- return (xData[numpy.argmin(valueData)],
- yData[numpy.argmin(valueData)])
- elif context.kind == 'image':
- scaleX, scaleY = context.scale
- originX, originY = context.origin
- index1D = numpy.argmin(context.data)
- ySize = (context.data.shape[1])
- x = index1D % context.data.shape[1]
- y = (index1D - x) / ySize
- x = x * scaleX + originX
- y = y * scaleY + originY
- return (x, y)
- else:
- raise ValueError('kind not managed')
+ if context.values is None or not context.isScalarData():
+ return None
+
+ index = numpy.argmin(context.values)
+ return self._indexToCoordinates(context, index)
def getToolTip(self, kind):
- if kind in ('scatter', 'image'):
- return '(x, y)'
- else:
- return None
+ return "Coordinates of the first minimum value of the data"
-class StatCoordMax(StatBase):
- """
- Compute the first coordinates of the data minimal value
- """
+
+class StatCoordMax(_StatCoord):
+ """Compute the coordinates of the first maximum value of the data"""
def __init__(self):
- StatBase.__init__(self, name='coords max')
+ _StatCoord.__init__(self, name='coords max')
def calculate(self, context):
- if context.kind in ('curve', 'histogram'):
- return context.xData[numpy.argmax(context.yData)]
- elif context.kind == 'scatter':
- xData, yData, valueData = context.data
- return (xData[numpy.argmax(valueData)],
- yData[numpy.argmax(valueData)])
- elif context.kind == 'image':
- scaleX, scaleY = context.scale
- originX, originY = context.origin
- index1D = numpy.argmax(context.data)
- ySize = (context.data.shape[1])
- x = index1D % context.data.shape[1]
- y = (index1D - x) / ySize
- x = x * scaleX + originX
- y = y * scaleY + originY
- return (x, y)
- else:
- raise ValueError('kind not managed')
+ if context.values is None or not context.isScalarData():
+ return None
+
+ index = numpy.argmax(context.values)
+ return self._indexToCoordinates(context, index)
def getToolTip(self, kind):
- if kind in ('scatter', 'image'):
- return '(x, y)'
- else:
- return None
+ return "Coordinates of the first maximum value of the data"
+
class StatCOM(StatBase):
- """
- Compute data center of mass
- """
+ """Compute data center of mass"""
def __init__(self):
StatBase.__init__(self, name='COM', description='Center of mass')
def calculate(self, context):
- if context.kind in ('curve', 'histogram'):
- xData, yData = context.data
- deno = numpy.sum(yData).astype(numpy.float32)
- if deno == 0.:
- return numpy.nan
- else:
- return numpy.sum(xData * yData).astype(numpy.float32) / deno
- elif context.kind == 'scatter':
- xData, yData, values = context.data
- deno = numpy.sum(values).astype(numpy.float32)
- if deno == 0.:
- return numpy.nan, numpy.nan
- else:
- xcom = numpy.sum(xData * values).astype(numpy.float32) / deno
- ycom = numpy.sum(yData * values).astype(numpy.float32) / deno
- return (xcom, ycom)
- elif context.kind == 'image':
- yData = numpy.sum(context.data, axis=1)
- xData = numpy.sum(context.data, axis=0)
- dataXRange = range(context.data.shape[1])
- dataYRange = range(context.data.shape[0])
- xScale, yScale = context.scale
- xOrigin, yOrigin = context.origin
-
- denoY = numpy.sum(yData)
- if denoY == 0.:
- ycom = numpy.nan
- else:
- ycom = numpy.sum(yData * dataYRange) / denoY
- ycom = ycom * yScale + yOrigin
+ if context.values is None or not context.isScalarData():
+ return None
- denoX = numpy.sum(xData)
- if denoX == 0.:
- xcom = numpy.nan
- else:
- xcom = numpy.sum(xData * dataXRange) / denoX
- xcom = xcom * xScale + xOrigin
- return (xcom, ycom)
+ values = numpy.array(context.values, dtype=numpy.float64)
+ sum_ = numpy.sum(values)
+ if sum_ == 0.:
+ return (numpy.nan,) * len(context.axes)
+
+ if context.isStructuredData():
+ centerofmass = []
+ for index, axis in enumerate(context.axes):
+ axes = tuple([i for i in range(len(context.axes)) if i != index])
+ centerofmass.append(
+ numpy.sum(axis * numpy.sum(values, axis=axes)) / sum_)
+ return tuple(reversed(centerofmass))
else:
- raise ValueError('kind not managed')
+ return tuple(
+ numpy.sum(axis * values) / sum_ for axis in context.axes)
def getToolTip(self, kind):
- if kind in ('scatter', 'image'):
- return '(x, y)'
- else:
- return None
+ return "Compute the center of mass of the dataset"
diff --git a/silx/gui/plot/stats/statshandler.py b/silx/gui/plot/stats/statshandler.py
index 0a62b31..f69daff 100644
--- a/silx/gui/plot/stats/statshandler.py
+++ b/silx/gui/plot/stats/statshandler.py
@@ -1,7 +1,7 @@
# coding: utf-8
# /*##########################################################################
#
-# Copyright (c) 2017-2018 European Synchrotron Radiation Facility
+# Copyright (c) 2017-2019 European Synchrotron Radiation Facility
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
@@ -45,7 +45,14 @@ class _FloatItem(qt.QTableWidgetItem):
qt.QTableWidgetItem.__init__(self, type=type)
def __lt__(self, other):
- return float(self.text()) < float(other.text())
+ self_values = self.text().lstrip('(').rstrip(')').split(',')
+ other_values = other.text().lstrip('(').rstrip(')').split(',')
+ for self_value, other_value in zip(self_values, other_values):
+ f_self_value = float(self_value)
+ f_other_value = float(other_value)
+ if f_self_value != f_other_value:
+ return f_self_value < f_other_value
+ return False
class StatFormatter(object):
@@ -89,10 +96,60 @@ class StatsHandler(object):
self.stats = statsmdl.Stats()
self.formatters = {}
for elmt in statFormatters:
- helper = _StatHelper(elmt)
- self.add(stat=helper.stat, formatter=helper.statFormatter)
+ stat, formatter = self._processStatArgument(elmt)
+ self.add(stat=stat, formatter=formatter)
+
+ @staticmethod
+ def _processStatArgument(arg):
+ """Process an element of the init arguments
+
+ :param arg: The argument to process
+ :return: Corresponding (StatBase, StatFormatter)
+ """
+ stat, formatter = None, None
+
+ if isinstance(arg, statsmdl.StatBase):
+ stat = arg
+ else:
+ assert len(arg) > 0
+ if isinstance(arg[0], statsmdl.StatBase):
+ stat = arg[0]
+ if len(arg) > 2:
+ raise ValueError('To many argument with %s. At most one '
+ 'argument can be associated with the '
+ 'BaseStat (the `StatFormatter`')
+ if len(arg) == 2:
+ assert arg[1] is None or isinstance(arg[1], (StatFormatter, str))
+ formatter = arg[1]
+ else:
+ if isinstance(arg[0], tuple):
+ if len(arg) > 1:
+ formatter = arg[1]
+ arg = arg[0]
+
+ if type(arg[0]) is not str:
+ raise ValueError('first element of the tuple should be a string'
+ ' or a StatBase instance')
+ if len(arg) == 1:
+ raise ValueError('A function should be associated with the'
+ 'stat name')
+ if len(arg) > 3:
+ raise ValueError('Two much argument given for defining statistic.'
+ 'Take at most three arguments (name, function, '
+ 'kinds)')
+ if len(arg) == 2:
+ stat = statsmdl.Stat(name=arg[0], fct=arg[1])
+ else:
+ stat = statsmdl.Stat(name=arg[0], fct=arg[1], kinds=arg[2])
+
+ return stat, formatter
def add(self, stat, formatter=None):
+ """Add a stat to the list.
+
+ :param StatBase stat:
+ :param Union[None,StatFormatter] formatter:
+ """
assert isinstance(stat, statsmdl.StatBase)
self.stats.add(stat)
_formatter = formatter
@@ -101,9 +158,9 @@ class StatsHandler(object):
self.formatters[stat.name] = _formatter
def format(self, name, val):
- """
- Apply the format for the `name` statistic and the given value
- :param name: the name of the associated statistic
+ """Apply the format for the `name` statistic and the given value
+
+ :param str name: the name of the associated statistic
:param val: value before formatting
:return: formatted value
"""
@@ -123,7 +180,7 @@ class StatsHandler(object):
def calculate(self, item, plot, onlimits):
"""
- compute all statistic registred and return the list of formatted
+ compute all statistic registered and return the list of formatted
statistics result.
:param item: item for which we want to compute statistics
@@ -137,54 +194,3 @@ class StatsHandler(object):
for resName, resValue in list(res.items()):
res[resName] = self.format(resName, res[resName])
return res
-
-
-class _StatHelper(object):
- """
- Helper class to generated the requested StatBase instance and the
- associated StatFormatter
- """
- def __init__(self, arg):
- self.statFormatter = None
- self.stat = None
-
- if isinstance(arg, statsmdl.StatBase):
- self.stat = arg
- else:
- assert len(arg) > 0
- if isinstance(arg[0], statsmdl.StatBase):
- self.dealWithStatAndFormatter(arg)
- else:
- _arg = arg
- if isinstance(arg[0], tuple):
- _arg = arg[0]
- if len(arg) > 1:
- self.statFormatter = arg[1]
- self.createStatInstanceAndFormatter(_arg)
-
- def dealWithStatAndFormatter(self, arg):
- assert isinstance(arg[0], statsmdl.StatBase)
- self.stat = arg[0]
- if len(arg) > 2:
- raise ValueError('To many argument with %s. At most one '
- 'argument can be associated with the '
- 'BaseStat (the `StatFormatter`')
- if len(arg) is 2:
- assert isinstance(arg[1], (StatFormatter, type(None), str))
- self.statFormatter = arg[1]
-
- def createStatInstanceAndFormatter(self, arg):
- if type(arg[0]) is not str:
- raise ValueError('first element of the tuple should be a string'
- ' or a StatBase instance')
- if len(arg) is 1:
- raise ValueError('A function should be associated with the'
- 'stat name')
- if len(arg) > 3:
- raise ValueError('Two much argument given for defining statistic.'
- 'Take at most three arguments (name, function, '
- 'kinds)')
- if len(arg) is 2:
- self.stat = statsmdl.Stat(name=arg[0], fct=arg[1])
- else:
- self.stat = statsmdl.Stat(name=arg[0], fct=arg[1], kinds=arg[2])
diff --git a/silx/gui/plot/test/testCurvesROIWidget.py b/silx/gui/plot/test/testCurvesROIWidget.py
index 0704779..5bcabd8 100644
--- a/silx/gui/plot/test/testCurvesROIWidget.py
+++ b/silx/gui/plot/test/testCurvesROIWidget.py
@@ -36,7 +36,7 @@ from collections import OrderedDict
import numpy
from silx.gui import qt
from silx.test.utils import temp_dir
-from silx.gui.utils.testutils import TestCaseQt
+from silx.gui.utils.testutils import TestCaseQt, SignalListener
from silx.gui.plot import PlotWindow, CurvesROIWidget
@@ -52,7 +52,8 @@ class TestCurvesROIWidget(TestCaseQt):
self.plot.show()
self.qWaitForWindowExposed(self.plot)
- self.widget = CurvesROIWidget.CurvesROIDockWidget(plot=self.plot, name='TEST')
+ self.widget = self.plot.getCurvesRoiDockWidget()
+
self.widget.show()
self.qWaitForWindowExposed(self.widget)
@@ -67,10 +68,6 @@ class TestCurvesROIWidget(TestCaseQt):
super(TestCurvesROIWidget, self).tearDown()
- def testEmptyPlot(self):
- """Empty plot, display ROI widget"""
- pass
-
def testWithCurves(self):
"""Plot with curves: test all ROI widget buttons"""
for offset in range(2):
@@ -80,13 +77,16 @@ class TestCurvesROIWidget(TestCaseQt):
# Add two ROI
self.mouseClick(self.widget.roiWidget.addButton, qt.Qt.LeftButton)
+ self.qWait(200)
self.mouseClick(self.widget.roiWidget.addButton, qt.Qt.LeftButton)
+ self.qWait(200)
# Change active curve
self.plot.setActiveCurve(str(1))
# Delete a ROI
self.mouseClick(self.widget.roiWidget.delButton, qt.Qt.LeftButton)
+ self.qWait(200)
with temp_dir() as tmpDir:
self.tmpFile = os.path.join(tmpDir, 'test.ini')
@@ -94,30 +94,42 @@ class TestCurvesROIWidget(TestCaseQt):
# Save ROIs
self.widget.roiWidget.save(self.tmpFile)
self.assertTrue(os.path.isfile(self.tmpFile))
+ self.assertTrue(len(self.widget.getRois()) is 2)
# Reset ROIs
self.mouseClick(self.widget.roiWidget.resetButton,
qt.Qt.LeftButton)
+ self.qWait(200)
+ rois = self.widget.getRois()
+ self.assertTrue(len(rois) is 1)
+ print(rois)
+ roiID = list(rois.keys())[0]
+ self.assertTrue(rois[roiID].getName() == 'ICR')
# Load ROIs
self.widget.roiWidget.load(self.tmpFile)
+ self.assertTrue(len(self.widget.getRois()) is 2)
del self.tmpFile
def testMiddleMarker(self):
"""Test with middle marker enabled"""
- self.widget.roiWidget.setMiddleROIMarkerFlag(True)
+ self.widget.roiWidget.roiTable.setMiddleROIMarkerFlag(True)
# Add a ROI
self.mouseClick(self.widget.roiWidget.addButton, qt.Qt.LeftButton)
- xleftMarker = self.plot._getMarker(legend='ROI min').getXPosition()
- xMiddleMarker = self.plot._getMarker(legend='ROI middle').getXPosition()
- xRightMarker = self.plot._getMarker(legend='ROI max').getXPosition()
- self.assertAlmostEqual(xMiddleMarker,
- xleftMarker + (xRightMarker - xleftMarker) / 2.)
-
- def testCalculation(self):
+ for roiID in self.widget.roiWidget.roiTable._markersHandler._roiMarkerHandlers:
+ handler = self.widget.roiWidget.roiTable._markersHandler._roiMarkerHandlers[roiID]
+ assert handler.getMarker('min')
+ xleftMarker = handler.getMarker('min').getXPosition()
+ xMiddleMarker = handler.getMarker('middle').getXPosition()
+ xRightMarker = handler.getMarker('max').getXPosition()
+ thValue = xleftMarker + (xRightMarker - xleftMarker) / 2.
+ self.assertAlmostEqual(xMiddleMarker, thValue)
+
+ def testAreaCalculation(self):
+ """Test result of area calculation"""
x = numpy.arange(100.)
y = numpy.arange(100.)
@@ -129,30 +141,60 @@ class TestCurvesROIWidget(TestCaseQt):
self.plot.setActiveCurve("positive")
# Add two ROIs
- ddict = {}
- ddict["positive"] = {"from": 10, "to": 20, "type":"X"}
- ddict["negative"] = {"from": -20, "to": -10, "type":"X"}
- self.widget.roiWidget.setRois(ddict)
+ roi_neg = CurvesROIWidget.ROI(name='negative', fromdata=-20,
+ todata=-10, type_='X')
+ roi_pos = CurvesROIWidget.ROI(name='positive', fromdata=10,
+ todata=20, type_='X')
+
+ self.widget.roiWidget.setRois((roi_pos, roi_neg))
+
+ posCurve = self.plot.getCurve('positive')
+ negCurve = self.plot.getCurve('negative')
+
+ self.assertEqual(roi_pos.computeRawAndNetArea(posCurve),
+ (numpy.trapz(y=[10, 20], x=[10, 20]),
+ 0.0))
+ self.assertEqual(roi_pos.computeRawAndNetArea(negCurve),
+ (0.0, 0.0))
+ self.assertEqual(roi_neg.computeRawAndNetArea(posCurve),
+ ((0.0), 0.0))
+ self.assertEqual(roi_neg.computeRawAndNetArea(negCurve),
+ ((-150.0), 0.0))
+
+ def testCountsCalculation(self):
+ """Test result of count calculation"""
+ x = numpy.arange(100.)
+ y = numpy.arange(100.)
- # And calculate the expected output
- self.widget.calculateROIs()
+ # Add two curves
+ self.plot.addCurve(x, y, legend="positive")
+ self.plot.addCurve(-x, y, legend="negative")
+
+ # Make sure there is an active curve and it is the positive one
+ self.plot.setActiveCurve("positive")
- output = self.widget.roiWidget.getRois()
- self.assertEqual(output["positive"]["rawcounts"],
- y[ddict["positive"]["from"]:ddict["positive"]["to"]+1].sum(),
- "Calculation failed on positive X coordinates")
+ # Add two ROIs
+ roi_neg = CurvesROIWidget.ROI(name='negative', fromdata=-20,
+ todata=-10, type_='X')
+ roi_pos = CurvesROIWidget.ROI(name='positive', fromdata=10,
+ todata=20, type_='X')
+
+ self.widget.roiWidget.setRois((roi_pos, roi_neg))
- # Set the curve with negative X coordinates as active
- self.plot.setActiveCurve("negative")
+ posCurve = self.plot.getCurve('positive')
+ negCurve = self.plot.getCurve('negative')
- # the ROIs should have been automatically updated
- output = self.widget.roiWidget.getRois()
- selection = numpy.nonzero((-x >= output["negative"]["from"]) & \
- (-x <= output["negative"]["to"]))[0]
- self.assertEqual(output["negative"]["rawcounts"],
- y[selection].sum(), "Calculation failed on negative X coordinates")
+ self.assertEqual(roi_pos.computeRawAndNetCounts(posCurve),
+ (y[10:21].sum(), 0.0))
+ self.assertEqual(roi_pos.computeRawAndNetCounts(negCurve),
+ (0.0, 0.0))
+ self.assertEqual(roi_neg.computeRawAndNetCounts(posCurve),
+ ((0.0), 0.0))
+ self.assertEqual(roi_neg.computeRawAndNetCounts(negCurve),
+ (y[10:21].sum(), 0.0))
def testDeferedInit(self):
+ """Test behavior of the deferedInit"""
x = numpy.arange(100.)
y = numpy.arange(100.)
self.plot.addCurve(x=x, y=y, legend="name", replace="True")
@@ -164,12 +206,123 @@ class TestCurvesROIWidget(TestCaseQt):
])
roiWidget = self.plot.getCurvesRoiDockWidget().roiWidget
- self.assertFalse(roiWidget._isInit)
self.plot.getCurvesRoiDockWidget().setRois(roisDefs)
self.assertTrue(len(roiWidget.getRois()) is len(roisDefs))
self.plot.getCurvesRoiDockWidget().setVisible(True)
self.assertTrue(len(roiWidget.getRois()) is len(roisDefs))
+ def testDictCompatibility(self):
+ """Test that ROI api is valid with dict and not information is lost"""
+ roiDict = {'from': 20, 'to': 200, 'type': 'energy', 'comment': 'no',
+ 'name': 'myROI', 'calibration': [1, 2, 3]}
+ roi = CurvesROIWidget.ROI._fromDict(roiDict)
+ self.assertTrue(roi.toDict() == roiDict)
+
+ def testShowAllROI(self):
+ """Test the show allROI action"""
+ x = numpy.arange(100.)
+ y = numpy.arange(100.)
+ self.plot.addCurve(x=x, y=y, legend="name", replace="True")
+
+ roisDefsDict = {
+ "range1": {"from": 20, "to": 200,"type": "energy"},
+ "range2": {"from": 300, "to": 500, "type": "energy"}
+ }
+
+ roisDefsObj = (
+ CurvesROIWidget.ROI(name='range3', fromdata=20, todata=200,
+ type_='energy'),
+ CurvesROIWidget.ROI(name='range4', fromdata=300, todata=500,
+ type_='energy')
+ )
+ self.widget.roiWidget.showAllMarkers(True)
+ roiWidget = self.plot.getCurvesRoiDockWidget().roiWidget
+ roiWidget.setRois(roisDefsDict)
+ self.assertTrue(len(self.plot._getAllMarkers()) is 2*3)
+
+ markersHandler = self.widget.roiWidget.roiTable._markersHandler
+ roiWidget.showAllMarkers(True)
+ ICRROI = markersHandler.getVisibleRois()
+ self.assertTrue(len(ICRROI) is 2)
+
+ roiWidget.showAllMarkers(False)
+ ICRROI = markersHandler.getVisibleRois()
+ self.assertTrue(len(ICRROI) is 1)
+
+ roiWidget.setRois(roisDefsObj)
+ self.qapp.processEvents()
+ self.assertTrue(len(self.plot._getAllMarkers()) is 2*3)
+
+ markersHandler = self.widget.roiWidget.roiTable._markersHandler
+ roiWidget.showAllMarkers(True)
+ ICRROI = markersHandler.getVisibleRois()
+ self.assertTrue(len(ICRROI) is 2)
+
+ roiWidget.showAllMarkers(False)
+ ICRROI = markersHandler.getVisibleRois()
+ self.assertTrue(len(ICRROI) is 1)
+
+ def testRoiEdition(self):
+ """Make sure if the ROI object is edited the ROITable will be updated
+ """
+ roi = CurvesROIWidget.ROI(name='linear', fromdata=0, todata=5)
+ self.widget.roiWidget.setRois((roi, ))
+
+ x = (0, 1, 1, 2, 2, 3)
+ y = (1, 1, 2, 2, 1, 1)
+ self.plot.addCurve(x=x, y=y, legend='linearCurve')
+ self.plot.setActiveCurve(legend='linearCurve')
+ self.widget.calculateROIs()
+
+ roiTable = self.widget.roiWidget.roiTable
+ indexesColumns = CurvesROIWidget.ROITable.COLUMNS_INDEX
+ itemRawCounts = roiTable.item(0, indexesColumns['Raw Counts'])
+ itemNetCounts = roiTable.item(0, indexesColumns['Net Counts'])
+
+ self.assertTrue(itemRawCounts.text() == '8.0')
+ self.assertTrue(itemNetCounts.text() == '2.0')
+
+ itemRawArea = roiTable.item(0, indexesColumns['Raw Area'])
+ itemNetArea = roiTable.item(0, indexesColumns['Net Area'])
+
+ self.assertTrue(itemRawArea.text() == '4.0')
+ self.assertTrue(itemNetArea.text() == '1.0')
+
+ roi.setTo(2)
+ itemRawArea = roiTable.item(0, indexesColumns['Raw Area'])
+ self.assertTrue(itemRawArea.text() == '3.0')
+ roi.setFrom(1)
+ itemRawArea = roiTable.item(0, indexesColumns['Raw Area'])
+ self.assertTrue(itemRawArea.text() == '2.0')
+
+ def testRemoveActiveROI(self):
+ """Test widget behavior when removing the active ROI"""
+ roi = CurvesROIWidget.ROI(name='linear', fromdata=0, todata=5)
+ self.widget.roiWidget.setRois((roi,))
+
+ self.widget.roiWidget.roiTable.setActiveRoi(None)
+ self.assertTrue(len(self.widget.roiWidget.roiTable.selectedItems()) is 0)
+ self.widget.roiWidget.setRois((roi,))
+ self.plot.setActiveCurve(legend='linearCurve')
+ self.widget.calculateROIs()
+
+ def testEmitCurrentROI(self):
+ """Test behavior of the CurvesROIWidget.sigROISignal"""
+ roi = CurvesROIWidget.ROI(name='linear', fromdata=0, todata=5)
+ self.widget.roiWidget.setRois((roi,))
+ signalListener = SignalListener()
+ self.widget.roiWidget.sigROISignal.connect(signalListener.partial())
+ self.widget.show()
+ self.qapp.processEvents()
+ self.assertTrue(signalListener.callCount() is 0)
+ self.assertTrue(self.widget.roiWidget.roiTable.activeRoi is roi)
+ roi.setFrom(0.0)
+ self.qapp.processEvents()
+ self.assertTrue(signalListener.callCount() is 0)
+ roi.setFrom(0.3)
+ self.qapp.processEvents()
+ self.assertTrue(signalListener.callCount() is 1)
+
def suite():
test_suite = unittest.TestSuite()
diff --git a/silx/gui/plot/test/testMaskToolsWidget.py b/silx/gui/plot/test/testMaskToolsWidget.py
index 6912ea3..a05c1be 100644
--- a/silx/gui/plot/test/testMaskToolsWidget.py
+++ b/silx/gui/plot/test/testMaskToolsWidget.py
@@ -42,10 +42,7 @@ from silx.gui.utils.testutils import getQToolButtonFromAction
from silx.gui.plot import PlotWindow, MaskToolsWidget
from .utils import PlotWidgetTestCase
-try:
- import fabio
-except ImportError:
- fabio = None
+import fabio
_logger = logging.getLogger(__name__)
@@ -254,8 +251,6 @@ class TestMaskToolsWidget(PlotWidgetTestCase, ParametricTestCase):
self.__loadSave("npy")
def testLoadSaveFit2D(self):
- if fabio is None:
- self.skipTest("Fabio is missing")
self.__loadSave("msk")
def testSigMaskChangedEmitted(self):
diff --git a/silx/gui/plot/test/testPlotWidget.py b/silx/gui/plot/test/testPlotWidget.py
index 857b9bc..9d7c093 100644
--- a/silx/gui/plot/test/testPlotWidget.py
+++ b/silx/gui/plot/test/testPlotWidget.py
@@ -1,7 +1,7 @@
# coding: utf-8
# /*##########################################################################
#
-# Copyright (c) 2016-2018 European Synchrotron Radiation Facility
+# Copyright (c) 2016-2019 European Synchrotron Radiation Facility
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
@@ -26,7 +26,7 @@
__authors__ = ["T. Vincent"]
__license__ = "MIT"
-__date__ = "21/09/2018"
+__date__ = "03/01/2019"
import unittest
@@ -36,8 +36,6 @@ import numpy
from silx.utils.testutils import ParametricTestCase, parameterize
from silx.gui.utils.testutils import SignalListener
from silx.gui.utils.testutils import TestCaseQt
-from silx.utils import testutils
-from silx.utils import deprecation
from silx.test.utils import test_options
@@ -184,6 +182,39 @@ class TestPlotWidget(PlotWidgetTestCase, ParametricTestCase):
self.assertTrue(numpy.all(numpy.equal(items[4].getPosition()[0], marker_x)))
self.assertEqual(items[5].getType(), 'rectangle')
+ def testBackGroundColors(self):
+ self.plot.setVisible(True)
+ self.qWaitForWindowExposed(self.plot)
+ self.qapp.processEvents()
+
+ # Custom the full background
+ color = self.plot.getBackgroundColor()
+ self.assertTrue(color.isValid())
+ self.assertEqual(color, qt.QColor(255, 255, 255))
+ self.plot.setBackgroundColor("red")
+ color = self.plot.getBackgroundColor()
+ self.assertTrue(color.isValid())
+ self.qapp.processEvents()
+
+ # Custom the data background
+ color = self.plot.getDataBackgroundColor()
+ self.assertFalse(color.isValid())
+ self.plot.setDataBackgroundColor("red")
+ color = self.plot.getDataBackgroundColor()
+ self.assertTrue(color.isValid())
+ self.qapp.processEvents()
+
+ # Back to default
+ self.plot.setBackgroundColor('white')
+ self.plot.setDataBackgroundColor(None)
+ color = self.plot.getBackgroundColor()
+ self.assertTrue(color.isValid())
+ self.assertEqual(color, qt.QColor(255, 255, 255))
+ color = self.plot.getDataBackgroundColor()
+ self.assertFalse(color.isValid())
+ self.qapp.processEvents()
+
+
class TestPlotImage(PlotWidgetTestCase, ParametricTestCase):
"""Basic tests for addImage"""
@@ -881,17 +912,12 @@ class TestPlotAxes(TestCaseQt, ParametricTestCase):
if getter is not None:
self.assertEqual(getter(), expected)
- @testutils.test_logging(deprecation.depreclog.name)
def testOldPlotAxis_Logarithmic(self):
"""Test silx API prior to silx 0.6"""
x = self.plot.getXAxis()
y = self.plot.getYAxis()
yright = self.plot.getYAxis(axis="right")
- listener = SignalListener()
- self.plot.sigSetXAxisLogarithmic.connect(listener.partial("x"))
- self.plot.sigSetYAxisLogarithmic.connect(listener.partial("y"))
-
self.assertEqual(x.getScale(), x.LINEAR)
self.assertEqual(y.getScale(), x.LINEAR)
self.assertEqual(yright.getScale(), x.LINEAR)
@@ -902,7 +928,6 @@ class TestPlotAxes(TestCaseQt, ParametricTestCase):
self.assertEqual(yright.getScale(), x.LINEAR)
self.assertEqual(self.plot.isXAxisLogarithmic(), True)
self.assertEqual(self.plot.isYAxisLogarithmic(), False)
- self.assertEqual(listener.arguments(callIndex=-1), ("x", True))
self.plot.setYAxisLogarithmic(True)
self.assertEqual(x.getScale(), x.LOGARITHMIC)
@@ -910,7 +935,6 @@ class TestPlotAxes(TestCaseQt, ParametricTestCase):
self.assertEqual(yright.getScale(), x.LOGARITHMIC)
self.assertEqual(self.plot.isXAxisLogarithmic(), True)
self.assertEqual(self.plot.isYAxisLogarithmic(), True)
- self.assertEqual(listener.arguments(callIndex=-1), ("y", True))
yright.setScale(yright.LINEAR)
self.assertEqual(x.getScale(), x.LOGARITHMIC)
@@ -918,19 +942,13 @@ class TestPlotAxes(TestCaseQt, ParametricTestCase):
self.assertEqual(yright.getScale(), x.LINEAR)
self.assertEqual(self.plot.isXAxisLogarithmic(), True)
self.assertEqual(self.plot.isYAxisLogarithmic(), False)
- self.assertEqual(listener.arguments(callIndex=-1), ("y", False))
- @testutils.test_logging(deprecation.depreclog.name)
def testOldPlotAxis_AutoScale(self):
"""Test silx API prior to silx 0.6"""
x = self.plot.getXAxis()
y = self.plot.getYAxis()
yright = self.plot.getYAxis(axis="right")
- listener = SignalListener()
- self.plot.sigSetXAxisAutoScale.connect(listener.partial("x"))
- self.plot.sigSetYAxisAutoScale.connect(listener.partial("y"))
-
self.assertEqual(x.isAutoScale(), True)
self.assertEqual(y.isAutoScale(), True)
self.assertEqual(yright.isAutoScale(), True)
@@ -941,7 +959,6 @@ class TestPlotAxes(TestCaseQt, ParametricTestCase):
self.assertEqual(yright.isAutoScale(), True)
self.assertEqual(self.plot.isXAxisAutoScale(), False)
self.assertEqual(self.plot.isYAxisAutoScale(), True)
- self.assertEqual(listener.arguments(callIndex=-1), ("x", False))
self.plot.setYAxisAutoScale(False)
self.assertEqual(x.isAutoScale(), False)
@@ -949,7 +966,6 @@ class TestPlotAxes(TestCaseQt, ParametricTestCase):
self.assertEqual(yright.isAutoScale(), False)
self.assertEqual(self.plot.isXAxisAutoScale(), False)
self.assertEqual(self.plot.isYAxisAutoScale(), False)
- self.assertEqual(listener.arguments(callIndex=-1), ("y", False))
yright.setAutoScale(True)
self.assertEqual(x.isAutoScale(), False)
@@ -957,18 +973,13 @@ class TestPlotAxes(TestCaseQt, ParametricTestCase):
self.assertEqual(yright.isAutoScale(), True)
self.assertEqual(self.plot.isXAxisAutoScale(), False)
self.assertEqual(self.plot.isYAxisAutoScale(), True)
- self.assertEqual(listener.arguments(callIndex=-1), ("y", True))
- @testutils.test_logging(deprecation.depreclog.name)
def testOldPlotAxis_Inverted(self):
"""Test silx API prior to silx 0.6"""
x = self.plot.getXAxis()
y = self.plot.getYAxis()
yright = self.plot.getYAxis(axis="right")
- listener = SignalListener()
- self.plot.sigSetYAxisInverted.connect(listener.partial("y"))
-
self.assertEqual(x.isInverted(), False)
self.assertEqual(y.isInverted(), False)
self.assertEqual(yright.isInverted(), False)
@@ -978,14 +989,12 @@ class TestPlotAxes(TestCaseQt, ParametricTestCase):
self.assertEqual(y.isInverted(), True)
self.assertEqual(yright.isInverted(), True)
self.assertEqual(self.plot.isYAxisInverted(), True)
- self.assertEqual(listener.arguments(callIndex=-1), ("y", True))
yright.setInverted(False)
self.assertEqual(x.isInverted(), False)
self.assertEqual(y.isInverted(), False)
self.assertEqual(yright.isInverted(), False)
self.assertEqual(self.plot.isYAxisInverted(), False)
- self.assertEqual(listener.arguments(callIndex=-1), ("y", False))
def testLogXWithData(self):
self.plot.setGraphTitle('Curve X: Log Y: Linear')
diff --git a/silx/gui/plot/test/testSaveAction.py b/silx/gui/plot/test/testSaveAction.py
index 85669bf..0eb129d 100644
--- a/silx/gui/plot/test/testSaveAction.py
+++ b/silx/gui/plot/test/testSaveAction.py
@@ -1,7 +1,7 @@
# coding: utf-8
# /*##########################################################################
#
-# Copyright (c) 2017-2018 European Synchrotron Radiation Facility
+# Copyright (c) 2017-2019 European Synchrotron Radiation Facility
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
@@ -106,12 +106,30 @@ class TestSaveActionExtension(PlotWidgetTestCase):
self.assertEqual(saveAction.getFileFilters('all')[nameFilter],
self._dummySaveFunction)
+ # Add a new file filter at a particular position
+ nameFilter = 'Dummy file2 (*.dummy)'
+ saveAction.setFileFilter('all', nameFilter,
+ self._dummySaveFunction, index=3)
+ self.assertTrue(nameFilter in saveAction.getFileFilters('all'))
+ filters = saveAction.getFileFilters('all')
+ self.assertEqual(filters[nameFilter], self._dummySaveFunction)
+ self.assertEqual(list(filters.keys()).index(nameFilter),3)
+
# Update an existing file filter
nameFilter = SaveAction.IMAGE_FILTER_EDF
saveAction.setFileFilter('image', nameFilter, self._dummySaveFunction)
self.assertEqual(saveAction.getFileFilters('image')[nameFilter],
self._dummySaveFunction)
+ # Change the position of an existing file filter
+ nameFilter = 'Dummy file2 (*.dummy)'
+ oldIndex = list(saveAction.getFileFilters('all')).index(nameFilter)
+ newIndex = oldIndex - 1
+ saveAction.setFileFilter('all', nameFilter,
+ self._dummySaveFunction, index=newIndex)
+ filters = saveAction.getFileFilters('all')
+ self.assertEqual(filters[nameFilter], self._dummySaveFunction)
+ self.assertEqual(list(filters.keys()).index(nameFilter), newIndex)
def suite():
test_suite = unittest.TestSuite()
diff --git a/silx/gui/plot/test/testScatterMaskToolsWidget.py b/silx/gui/plot/test/testScatterMaskToolsWidget.py
index a446911..171ec42 100644
--- a/silx/gui/plot/test/testScatterMaskToolsWidget.py
+++ b/silx/gui/plot/test/testScatterMaskToolsWidget.py
@@ -42,10 +42,7 @@ from silx.gui.utils.testutils import getQToolButtonFromAction
from silx.gui.plot import PlotWindow, ScatterMaskToolsWidget
from .utils import PlotWidgetTestCase
-try:
- import fabio
-except ImportError:
- fabio = None
+import fabio
_logger = logging.getLogger(__name__)
diff --git a/silx/gui/plot/test/testStats.py b/silx/gui/plot/test/testStats.py
index faedcff..7fbc247 100644
--- a/silx/gui/plot/test/testStats.py
+++ b/silx/gui/plot/test/testStats.py
@@ -1,7 +1,7 @@
# coding: utf-8
# /*##########################################################################
#
-# Copyright (c) 2016-2018 European Synchrotron Radiation Facility
+# Copyright (c) 2016-2019 European Synchrotron Radiation Facility
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
@@ -112,34 +112,34 @@ class TestStats(TestCaseQt):
"""Test result for simple stats on a curve"""
_stats = self.getBasicStats()
xData = yData = numpy.array(range(20))
- self.assertTrue(_stats['min'].calculate(self.curveContext) == 0)
- self.assertTrue(_stats['max'].calculate(self.curveContext) == 19)
- self.assertTrue(_stats['minCoords'].calculate(self.curveContext) == [0])
- self.assertTrue(_stats['maxCoords'].calculate(self.curveContext) == [19])
- self.assertTrue(_stats['std'].calculate(self.curveContext) == numpy.std(yData))
- self.assertTrue(_stats['mean'].calculate(self.curveContext) == numpy.mean(yData))
+ self.assertEqual(_stats['min'].calculate(self.curveContext), 0)
+ self.assertEqual(_stats['max'].calculate(self.curveContext), 19)
+ self.assertEqual(_stats['minCoords'].calculate(self.curveContext), (0,))
+ self.assertEqual(_stats['maxCoords'].calculate(self.curveContext), (19,))
+ self.assertEqual(_stats['std'].calculate(self.curveContext), numpy.std(yData))
+ self.assertEqual(_stats['mean'].calculate(self.curveContext), numpy.mean(yData))
com = numpy.sum(xData * yData) / numpy.sum(yData)
- self.assertTrue(_stats['com'].calculate(self.curveContext) == com)
+ self.assertEqual(_stats['com'].calculate(self.curveContext), com)
def testBasicStatsImage(self):
"""Test result for simple stats on an image"""
_stats = self.getBasicStats()
- self.assertTrue(_stats['min'].calculate(self.imageContext) == 0)
- self.assertTrue(_stats['max'].calculate(self.imageContext) == 128 * 32 - 1)
- self.assertTrue(_stats['minCoords'].calculate(self.imageContext) == (0, 0))
- self.assertTrue(_stats['maxCoords'].calculate(self.imageContext) == (127, 31))
- self.assertTrue(_stats['std'].calculate(self.imageContext) == numpy.std(self.imageData))
- self.assertTrue(_stats['mean'].calculate(self.imageContext) == numpy.mean(self.imageData))
-
- yData = numpy.sum(self.imageData, axis=1)
- xData = numpy.sum(self.imageData, axis=0)
+ self.assertEqual(_stats['min'].calculate(self.imageContext), 0)
+ self.assertEqual(_stats['max'].calculate(self.imageContext), 128 * 32 - 1)
+ self.assertEqual(_stats['minCoords'].calculate(self.imageContext), (0, 0))
+ self.assertEqual(_stats['maxCoords'].calculate(self.imageContext), (127, 31))
+ self.assertEqual(_stats['std'].calculate(self.imageContext), numpy.std(self.imageData))
+ self.assertEqual(_stats['mean'].calculate(self.imageContext), numpy.mean(self.imageData))
+
+ yData = numpy.sum(self.imageData.astype(numpy.float64), axis=1)
+ xData = numpy.sum(self.imageData.astype(numpy.float64), axis=0)
dataXRange = range(self.imageData.shape[1])
dataYRange = range(self.imageData.shape[0])
ycom = numpy.sum(yData*dataYRange) / numpy.sum(yData)
xcom = numpy.sum(xData*dataXRange) / numpy.sum(xData)
- self.assertTrue(_stats['com'].calculate(self.imageContext) == (xcom, ycom))
+ self.assertEqual(_stats['com'].calculate(self.imageContext), (xcom, ycom))
def testStatsImageAdv(self):
"""Test that scale and origin are taking into account for images"""
@@ -153,52 +153,46 @@ class TestStats(TestCaseQt):
onlimits=False
)
_stats = self.getBasicStats()
- self.assertTrue(_stats['min'].calculate(image2Context) == 0)
- self.assertTrue(
- _stats['max'].calculate(image2Context) == 128 * 32 - 1)
- self.assertTrue(
- _stats['minCoords'].calculate(image2Context) == (100, 10))
- self.assertTrue(
- _stats['maxCoords'].calculate(image2Context) == (127*2. + 100,
- 31 * 0.5 + 10)
- )
- self.assertTrue(
- _stats['std'].calculate(image2Context) == numpy.std(
- self.imageData))
- self.assertTrue(
- _stats['mean'].calculate(image2Context) == numpy.mean(
- self.imageData))
+ self.assertEqual(_stats['min'].calculate(image2Context), 0)
+ self.assertEqual(
+ _stats['max'].calculate(image2Context), 128 * 32 - 1)
+ self.assertEqual(
+ _stats['minCoords'].calculate(image2Context), (100, 10))
+ self.assertEqual(
+ _stats['maxCoords'].calculate(image2Context), (127*2. + 100,
+ 31 * 0.5 + 10))
+ self.assertEqual(_stats['std'].calculate(image2Context),
+ numpy.std(self.imageData))
+ self.assertEqual(_stats['mean'].calculate(image2Context),
+ numpy.mean(self.imageData))
yData = numpy.sum(self.imageData, axis=1)
xData = numpy.sum(self.imageData, axis=0)
- dataXRange = range(self.imageData.shape[1])
- dataYRange = range(self.imageData.shape[0])
+ dataXRange = numpy.arange(self.imageData.shape[1], dtype=numpy.float64)
+ dataYRange = numpy.arange(self.imageData.shape[0], dtype=numpy.float64)
ycom = numpy.sum(yData * dataYRange) / numpy.sum(yData)
ycom = (ycom * 0.5) + 10
xcom = numpy.sum(xData * dataXRange) / numpy.sum(xData)
xcom = (xcom * 2.) + 100
- self.assertTrue(
- _stats['com'].calculate(image2Context) == (xcom, ycom))
+ self.assertTrue(numpy.allclose(
+ _stats['com'].calculate(image2Context), (xcom, ycom)))
def testBasicStatsScatter(self):
"""Test result for simple stats on a scatter"""
_stats = self.getBasicStats()
- self.assertTrue(_stats['min'].calculate(self.scatterContext) == 5)
- self.assertTrue(_stats['max'].calculate(self.scatterContext) == 90)
- self.assertTrue(_stats['minCoords'].calculate(self.scatterContext) == (0, 2))
- self.assertTrue(_stats['maxCoords'].calculate(self.scatterContext) == (50, 69))
- self.assertTrue(_stats['std'].calculate(self.scatterContext) == numpy.std(self.valuesScatterData))
- self.assertTrue(_stats['mean'].calculate(self.scatterContext) == numpy.mean(self.valuesScatterData))
-
- comx = numpy.sum(self.xScatterData * self.valuesScatterData).astype(numpy.float32) / numpy.sum(
- self.valuesScatterData).astype(numpy.float32)
- comy = numpy.sum(self.yScatterData * self.valuesScatterData).astype(numpy.float32) / numpy.sum(
- self.valuesScatterData).astype(numpy.float32)
- self.assertTrue(numpy.all(
- numpy.equal(_stats['com'].calculate(self.scatterContext),
- (comx, comy)))
- )
+ self.assertEqual(_stats['min'].calculate(self.scatterContext), 5)
+ self.assertEqual(_stats['max'].calculate(self.scatterContext), 90)
+ self.assertEqual(_stats['minCoords'].calculate(self.scatterContext), (0, 2))
+ self.assertEqual(_stats['maxCoords'].calculate(self.scatterContext), (50, 69))
+ self.assertEqual(_stats['std'].calculate(self.scatterContext), numpy.std(self.valuesScatterData))
+ self.assertEqual(_stats['mean'].calculate(self.scatterContext), numpy.mean(self.valuesScatterData))
+
+ data = self.valuesScatterData.astype(numpy.float64)
+ comx = numpy.sum(self.xScatterData * data) / numpy.sum(data)
+ comy = numpy.sum(self.yScatterData * data) / numpy.sum(data)
+ self.assertEqual(_stats['com'].calculate(self.scatterContext),
+ (comx, comy))
def testKindNotManagedByStat(self):
"""Make sure an exception is raised if we try to execute calculate
@@ -227,21 +221,21 @@ class TestStats(TestCaseQt):
item=self.plot1d.getCurve('curve0'),
plot=self.plot1d,
onlimits=True)
- self.assertTrue(stat.calculate(curveContextOnLimits) == 2)
+ self.assertEqual(stat.calculate(curveContextOnLimits), 2)
self.plot2d.getXAxis().setLimitsConstraints(minPos=32)
imageContextOnLimits = stats._ImageContext(
item=self.plot2d.getImage('test image'),
plot=self.plot2d,
onlimits=True)
- self.assertTrue(stat.calculate(imageContextOnLimits) == 32)
+ self.assertEqual(stat.calculate(imageContextOnLimits), 32)
self.scatterPlot.getXAxis().setLimitsConstraints(minPos=40)
scatterContextOnLimits = stats._ScatterContext(
item=self.scatterPlot.getScatter('scatter plot'),
plot=self.scatterPlot,
onlimits=True)
- self.assertTrue(stat.calculate(scatterContextOnLimits) == 20)
+ self.assertEqual(stat.calculate(scatterContextOnLimits), 20)
class TestStatsFormatter(TestCaseQt):
@@ -267,15 +261,15 @@ class TestStatsFormatter(TestCaseQt):
"""Make sure a formatter with no formatter definition will return a
simple cast to str"""
emptyFormatter = statshandler.StatFormatter()
- self.assertTrue(
- emptyFormatter.format(self.stat.calculate(self.curveContext)) == '0.000')
+ self.assertEqual(
+ emptyFormatter.format(self.stat.calculate(self.curveContext)), '0.000')
def testSettedFormatter(self):
"""Make sure a formatter with no formatter definition will return a
simple cast to str"""
formatter= statshandler.StatFormatter(formatter='{0:.3f}')
- self.assertTrue(
- formatter.format(self.stat.calculate(self.curveContext)) == '0.000')
+ self.assertEqual(
+ formatter.format(self.stat.calculate(self.curveContext)), '0.000')
class TestStatsHandler(unittest.TestCase):
@@ -309,9 +303,9 @@ class TestStatsHandler(unittest.TestCase):
res = handler0.calculate(item=self.curveItem, plot=self.plot1d,
onlimits=False)
self.assertTrue('min' in res)
- self.assertTrue(res['min'] == '0')
+ self.assertEqual(res['min'], '0')
self.assertTrue('max' in res)
- self.assertTrue(res['max'] == '19')
+ self.assertEqual(res['max'], '19')
handler1 = statshandler.StatsHandler(
(
@@ -323,9 +317,9 @@ class TestStatsHandler(unittest.TestCase):
res = handler1.calculate(item=self.curveItem, plot=self.plot1d,
onlimits=False)
self.assertTrue('min' in res)
- self.assertTrue(res['min'] == '0')
+ self.assertEqual(res['min'], '0')
self.assertTrue('max' in res)
- self.assertTrue(res['max'] == '19.000')
+ self.assertEqual(res['max'], '19.000')
handler2 = statshandler.StatsHandler(
(
@@ -336,9 +330,9 @@ class TestStatsHandler(unittest.TestCase):
res = handler2.calculate(item=self.curveItem, plot=self.plot1d,
onlimits=False)
self.assertTrue('min' in res)
- self.assertTrue(res['min'] == '0')
+ self.assertEqual(res['min'], '0')
self.assertTrue('max' in res)
- self.assertTrue(res['max'] == '19.000')
+ self.assertEqual(res['max'], '19.000')
handler3 = statshandler.StatsHandler((
(('amin', numpy.argmin), statshandler.StatFormatter()),
@@ -348,9 +342,9 @@ class TestStatsHandler(unittest.TestCase):
res = handler3.calculate(item=self.curveItem, plot=self.plot1d,
onlimits=False)
self.assertTrue('amin' in res)
- self.assertTrue(res['amin'] == '0.000')
+ self.assertEqual(res['amin'], '0.000')
self.assertTrue('amax' in res)
- self.assertTrue(res['amax'] == '19')
+ self.assertEqual(res['amax'], '19')
with self.assertRaises(ValueError):
statshandler.StatsHandler(('name'))
@@ -395,47 +389,49 @@ class TestStatsWidgetWithCurves(TestCaseQt):
def testInit(self):
"""Make sure all the curves are registred on initialization"""
- self.assertTrue(self.widget.rowCount() is 3)
+ self.assertEqual(self.widget.rowCount(), 3)
def testRemoveCurve(self):
"""Make sure the Curves stats take into account the curve removal from
plot"""
self.plot.removeCurve('curve2')
- self.assertTrue(self.widget.rowCount() is 2)
+ self.assertEqual(self.widget.rowCount(), 2)
for iRow in range(2):
self.assertTrue(self.widget.item(iRow, 0).text() in ('curve0', 'curve1'))
self.plot.removeCurve('curve0')
- self.assertTrue(self.widget.rowCount() is 1)
+ self.assertEqual(self.widget.rowCount(), 1)
self.plot.removeCurve('curve1')
- self.assertTrue(self.widget.rowCount() is 0)
+ self.assertEqual(self.widget.rowCount(), 0)
def testAddCurve(self):
"""Make sure the Curves stats take into account the add curve action"""
self.plot.addCurve(legend='curve3', x=range(10), y=range(10))
- self.assertTrue(self.widget.rowCount() is 4)
+ self.assertEqual(self.widget.rowCount(), 4)
- def testUpdateCurveFrmAddCurve(self):
+ def testUpdateCurveFromAddCurve(self):
"""Make sure the stats of the cuve will be removed after updating a
curve"""
self.plot.addCurve(legend='curve0', x=range(10), y=range(10))
- self.assertTrue(self.widget.rowCount() is 3)
- itemMax = self.widget._getItem(name='max', legend='curve0',
- kind='curve', indexTable=None)
- self.assertTrue(itemMax.text() == '9')
+ self.qapp.processEvents()
+ self.assertEqual(self.widget.rowCount(), 3)
+ curve = self.plot._getItem(kind='curve', legend='curve0')
+ tableItems = self.widget._itemToTableItems(curve)
+ self.assertEqual(tableItems['max'].text(), '9')
- def testUpdateCurveFrmCurveObj(self):
+ def testUpdateCurveFromCurveObj(self):
self.plot.getCurve('curve0').setData(x=range(4), y=range(4))
- self.assertTrue(self.widget.rowCount() is 3)
- itemMax = self.widget._getItem(name='max', legend='curve0',
- kind='curve', indexTable=None)
- self.assertTrue(itemMax.text() == '3')
+ self.qapp.processEvents()
+ self.assertEqual(self.widget.rowCount(), 3)
+ curve = self.plot._getItem(kind='curve', legend='curve0')
+ tableItems = self.widget._itemToTableItems(curve)
+ self.assertEqual(tableItems['max'].text(), '3')
def testSetAnotherPlot(self):
plot2 = Plot1D()
plot2.addCurve(x=range(26), y=range(26), legend='new curve')
self.widget.setPlot(plot2)
- self.assertTrue(self.widget.rowCount() is 1)
+ self.assertEqual(self.widget.rowCount(), 1)
self.qapp.processEvents()
plot2.setAttribute(qt.Qt.WA_DeleteOnClose)
plot2.close()
@@ -444,12 +440,15 @@ class TestStatsWidgetWithCurves(TestCaseQt):
class TestStatsWidgetWithImages(TestCaseQt):
"""Basic test for StatsWidget with images"""
+
+ IMAGE_LEGEND = 'test image'
+
def setUp(self):
TestCaseQt.setUp(self)
self.plot = Plot2D()
self.plot.addImage(data=numpy.arange(128*128).reshape(128, 128),
- legend='test image', replace=False)
+ legend=self.IMAGE_LEGEND, replace=False)
self.widget = StatsWidget.StatsTable(plot=self.plot)
@@ -476,31 +475,30 @@ class TestStatsWidgetWithImages(TestCaseQt):
TestCaseQt.tearDown(self)
def test(self):
- columnsIndex = self.widget._columns_index
- itemLegend = self.widget._lgdAndKindToItems[('test image', 'image')]['legend']
- itemMin = self.widget.item(itemLegend.row(), columnsIndex['min'])
- itemMax = self.widget.item(itemLegend.row(), columnsIndex['max'])
- itemDelta = self.widget.item(itemLegend.row(), columnsIndex['delta'])
- itemCoordsMin = self.widget.item(itemLegend.row(),
- columnsIndex['coords min'])
- itemCoordsMax = self.widget.item(itemLegend.row(),
- columnsIndex['coords max'])
- max = (128 * 128) - 1
- self.assertTrue(itemMin.text() == '0.000')
- self.assertTrue(itemMax.text() == '{0:.3f}'.format(max))
- self.assertTrue(itemDelta.text() == '{0:.3f}'.format(max))
- self.assertTrue(itemCoordsMin.text() == '0.0, 0.0')
- self.assertTrue(itemCoordsMax.text() == '127.0, 127.0')
+ image = self.plot._getItem(
+ kind='image', legend=self.IMAGE_LEGEND)
+ tableItems = self.widget._itemToTableItems(image)
+
+ maxText = '{0:.3f}'.format((128 * 128) - 1)
+ self.assertEqual(tableItems['legend'].text(), self.IMAGE_LEGEND)
+ self.assertEqual(tableItems['min'].text(), '0.000')
+ self.assertEqual(tableItems['max'].text(), maxText)
+ self.assertEqual(tableItems['delta'].text(), maxText)
+ self.assertEqual(tableItems['coords min'].text(), '0.0, 0.0')
+ self.assertEqual(tableItems['coords max'].text(), '127.0, 127.0')
class TestStatsWidgetWithScatters(TestCaseQt):
+
+ SCATTER_LEGEND = 'scatter plot'
+
def setUp(self):
TestCaseQt.setUp(self)
self.scatterPlot = Plot2D()
self.scatterPlot.addScatter([0, 1, 2, 20, 50, 60],
[2, 3, 4, 26, 69, 6],
[5, 6, 7, 10, 90, 20],
- legend='scatter plot')
+ legend=self.SCATTER_LEGEND)
self.widget = StatsWidget.StatsTable(plot=self.scatterPlot)
mystats = statshandler.StatsHandler((
@@ -526,33 +524,89 @@ class TestStatsWidgetWithScatters(TestCaseQt):
TestCaseQt.tearDown(self)
def testStats(self):
- columnsIndex = self.widget._columns_index
- itemLegend = self.widget._lgdAndKindToItems[('scatter plot', 'scatter')]['legend']
- itemMin = self.widget.item(itemLegend.row(), columnsIndex['min'])
- itemMax = self.widget.item(itemLegend.row(), columnsIndex['max'])
- itemDelta = self.widget.item(itemLegend.row(), columnsIndex['delta'])
- itemCoordsMin = self.widget.item(itemLegend.row(),
- columnsIndex['coords min'])
- itemCoordsMax = self.widget.item(itemLegend.row(),
- columnsIndex['coords max'])
- self.assertTrue(itemMin.text() == '5')
- self.assertTrue(itemMax.text() == '90')
- self.assertTrue(itemDelta.text() == '85')
- self.assertTrue(itemCoordsMin.text() == '0, 2')
- self.assertTrue(itemCoordsMax.text() == '50, 69')
+ scatter = self.scatterPlot._getItem(
+ kind='scatter', legend=self.SCATTER_LEGEND)
+ tableItems = self.widget._itemToTableItems(scatter)
+ self.assertEqual(tableItems['legend'].text(), self.SCATTER_LEGEND)
+ self.assertEqual(tableItems['min'].text(), '5')
+ self.assertEqual(tableItems['coords min'].text(), '0, 2')
+ self.assertEqual(tableItems['max'].text(), '90')
+ self.assertEqual(tableItems['coords max'].text(), '50, 69')
+ self.assertEqual(tableItems['delta'].text(), '85')
class TestEmptyStatsWidget(TestCaseQt):
def test(self):
widget = StatsWidget.StatsWidget()
widget.show()
+ self.qWaitForWindowExposed(widget)
+
+
+# skip unit test for pyqt4 because there is some unrealised widget without
+# apparent reason
+@unittest.skipIf(qt.qVersion().split('.')[0] == '4', reason='PyQt4 not tested')
+class TestLineWidget(TestCaseQt):
+ """Some test for the StatsLineWidget."""
+ def setUp(self):
+ TestCaseQt.setUp(self)
+
+ mystats = statshandler.StatsHandler((
+ (stats.StatMin(), statshandler.StatFormatter()),
+ ))
+
+ self.plot = Plot1D()
+ self.plot.show()
+ x = range(20)
+ y = range(20)
+ self.plot.addCurve(x, y, legend='curve0')
+ y = range(12, 32)
+ self.plot.addCurve(x, y, legend='curve1')
+ y = range(-2, 18)
+ self.plot.addCurve(x, y, legend='curve2')
+ self.widget = StatsWidget.BasicGridStatsWidget(plot=self.plot,
+ kind='curve',
+ stats=mystats)
+
+ def tearDown(self):
+ self.qapp.processEvents()
+ self.plot.setAttribute(qt.Qt.WA_DeleteOnClose)
+ self.plot.close()
+ self.widget.setPlot(None)
+ self.widget._statQlineEdit.clear()
+ self.widget.setAttribute(qt.Qt.WA_DeleteOnClose)
+ self.widget.close()
+ self.widget = None
+ self.plot = None
+ TestCaseQt.tearDown(self)
+
+ def test(self):
+ self.widget.setStatsOnVisibleData(False)
+ self.qapp.processEvents()
+ self.plot.setActiveCurve(legend='curve0')
+ self.assertTrue(self.widget._statQlineEdit['min'].text() == '0.000')
+ self.plot.setActiveCurve(legend='curve1')
+ self.assertTrue(self.widget._statQlineEdit['min'].text() == '12.000')
+ self.plot.getXAxis().setLimitsConstraints(minPos=2, maxPos=5)
+ self.widget.setStatsOnVisibleData(True)
+ self.qapp.processEvents()
+ self.assertTrue(self.widget._statQlineEdit['min'].text() == '14.000')
+ self.plot.setActiveCurve(None)
+ self.assertTrue(self.plot.getActiveCurve() is None)
+ self.widget.setStatsOnVisibleData(False)
+ self.qapp.processEvents()
+ self.assertFalse(self.widget._statQlineEdit['min'].text() == '14.000')
+ self.widget.setKind('image')
+ self.plot.addImage(numpy.arange(100*100).reshape(100, 100) + 0.312)
+ self.qapp.processEvents()
+ self.assertTrue(self.widget._statQlineEdit['min'].text() == '0.312')
def suite():
test_suite = unittest.TestSuite()
for TestClass in (TestStats, TestStatsHandler, TestStatsWidgetWithScatters,
TestStatsWidgetWithImages, TestStatsWidgetWithCurves,
- TestStatsFormatter, TestEmptyStatsWidget):
+ TestStatsFormatter, TestEmptyStatsWidget,
+ TestLineWidget):
test_suite.addTest(
unittest.defaultTestLoader.loadTestsFromTestCase(TestClass))
return test_suite
diff --git a/silx/gui/plot/test/testUtilsAxis.py b/silx/gui/plot/test/testUtilsAxis.py
index 016fafe..64373b8 100644
--- a/silx/gui/plot/test/testUtilsAxis.py
+++ b/silx/gui/plot/test/testUtilsAxis.py
@@ -26,7 +26,7 @@
__authors__ = ["V. Valls"]
__license__ = "MIT"
-__date__ = "14/02/2018"
+__date__ = "20/11/2018"
import unittest
@@ -155,6 +155,53 @@ class TestAxisSync(TestCaseQt):
self.assertEqual(self.plot2.getYAxis().isInverted(), True)
self.assertEqual(self.plot3.getYAxis().isInverted(), True)
+ def testSyncCenter(self):
+ """Test direction change"""
+ # Not the same scale
+ self.plot1.getXAxis().setLimits(0, 200)
+ self.plot2.getXAxis().setLimits(0, 20)
+ self.plot3.getXAxis().setLimits(0, 2)
+ _sync = SyncAxes([self.plot1.getXAxis(), self.plot2.getXAxis(), self.plot3.getXAxis()],
+ syncLimits=False, syncCenter=True)
+
+ self.assertEqual(self.plot1.getXAxis().getLimits(), (0, 200))
+ self.assertEqual(self.plot2.getXAxis().getLimits(), (100 - 10, 100 + 10))
+ self.assertEqual(self.plot3.getXAxis().getLimits(), (100 - 1, 100 + 1))
+
+ def testSyncCenterAndZoom(self):
+ """Test direction change"""
+ # Not the same scale
+ self.plot1.getXAxis().setLimits(0, 200)
+ self.plot2.getXAxis().setLimits(0, 20)
+ self.plot3.getXAxis().setLimits(0, 2)
+ _sync = SyncAxes([self.plot1.getXAxis(), self.plot2.getXAxis(), self.plot3.getXAxis()],
+ syncLimits=False, syncCenter=True, syncZoom=True)
+
+ # Supposing all the plots use the same size
+ self.assertEqual(self.plot1.getXAxis().getLimits(), (0, 200))
+ self.assertEqual(self.plot2.getXAxis().getLimits(), (0, 200))
+ self.assertEqual(self.plot3.getXAxis().getLimits(), (0, 200))
+
+ def testAddAxis(self):
+ """Test synchronization after construction"""
+ sync = SyncAxes([self.plot1.getXAxis(), self.plot2.getXAxis()])
+ sync.addAxis(self.plot3.getXAxis())
+
+ self.plot1.getXAxis().setLimits(10, 500)
+ self.assertEqual(self.plot1.getXAxis().getLimits(), (10, 500))
+ self.assertEqual(self.plot2.getXAxis().getLimits(), (10, 500))
+ self.assertEqual(self.plot3.getXAxis().getLimits(), (10, 500))
+
+ def testRemoveAxis(self):
+ """Test synchronization after construction"""
+ sync = SyncAxes([self.plot1.getXAxis(), self.plot2.getXAxis(), self.plot3.getXAxis()])
+ sync.removeAxis(self.plot3.getXAxis())
+
+ self.plot1.getXAxis().setLimits(10, 500)
+ self.assertEqual(self.plot1.getXAxis().getLimits(), (10, 500))
+ self.assertEqual(self.plot2.getXAxis().getLimits(), (10, 500))
+ self.assertNotEqual(self.plot3.getXAxis().getLimits(), (10, 500))
+
def suite():
test_suite = unittest.TestSuite()
diff --git a/silx/gui/plot/tools/roi.py b/silx/gui/plot/tools/roi.py
index d58c041..98295ba 100644
--- a/silx/gui/plot/tools/roi.py
+++ b/silx/gui/plot/tools/roi.py
@@ -1,7 +1,7 @@
# coding: utf-8
# /*##########################################################################
#
-# Copyright (c) 2018 European Synchrotron Radiation Facility
+# Copyright (c) 2018-2019 European Synchrotron Radiation Facility
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
@@ -31,6 +31,7 @@ __date__ = "28/06/2018"
import collections
+import enum
import functools
import logging
import time
@@ -38,7 +39,6 @@ import weakref
import numpy
-from ....third_party import enum
from ....utils.weakref import WeakMethodProxy
from ... import qt, icons
from .. import PlotWidget
@@ -806,11 +806,17 @@ class RegionOfInterestTableWidget(qt.QTableWidget):
self.itemChanged.connect(self.__itemChanged)
- @staticmethod
- def __itemChanged(item):
+ def __itemChanged(self, item):
"""Handle item updates"""
column = item.column()
- roi = item.data(qt.Qt.UserRole)
+ index = item.data(qt.Qt.UserRole)
+
+ if index is not None:
+ manager = self.getRegionOfInterestManager()
+ roi = manager.getRois()[index]
+ else:
+ roi = None
+
if column == 0:
roi.setLabel(item.text())
elif column == 1:
@@ -882,13 +888,13 @@ class RegionOfInterestTableWidget(qt.QTableWidget):
label = roi.getLabel()
item = qt.QTableWidgetItem(label)
item.setFlags(baseFlags | qt.Qt.ItemIsEditable)
- item.setData(qt.Qt.UserRole, roi)
+ item.setData(qt.Qt.UserRole, index)
self.setItem(index, 0, item)
# Editable
item = qt.QTableWidgetItem()
item.setFlags(baseFlags | qt.Qt.ItemIsUserCheckable)
- item.setData(qt.Qt.UserRole, roi)
+ item.setData(qt.Qt.UserRole, index)
item.setCheckState(
qt.Qt.Checked if roi.isEditable() else qt.Qt.Unchecked)
self.setItem(index, 1, item)
diff --git a/silx/gui/plot/tools/test/testScatterProfileToolBar.py b/silx/gui/plot/tools/test/testScatterProfileToolBar.py
index b99cac7..0f4b668 100644
--- a/silx/gui/plot/tools/test/testScatterProfileToolBar.py
+++ b/silx/gui/plot/tools/test/testScatterProfileToolBar.py
@@ -97,7 +97,7 @@ class TestScatterProfileToolBar(TestCaseQt, ParametricTestCase):
self.profile._getRoiManager().addRoi(roi)
# Wait for async interpolator init
- for _ in range(10):
+ for _ in range(20):
self.qWait(200)
if not self.profile.hasPendingOperations():
break
diff --git a/silx/gui/plot/utils/axis.py b/silx/gui/plot/utils/axis.py
index bd19996..693e8eb 100644
--- a/silx/gui/plot/utils/axis.py
+++ b/silx/gui/plot/utils/axis.py
@@ -27,13 +27,14 @@
__authors__ = ["V. Valls"]
__license__ = "MIT"
-__date__ = "23/02/2018"
+__date__ = "20/11/2018"
import functools
import logging
from contextlib import contextmanager
import weakref
import silx.utils.weakref as silxWeakref
+from silx.gui.plot.items.axis import Axis, XAxis, YAxis
try:
from ...qt.inspect import isValid as _isQObjectValid
@@ -61,7 +62,14 @@ class SyncAxes(object):
.. versionadded:: 0.6
"""
- def __init__(self, axes, syncLimits=True, syncScale=True, syncDirection=True):
+ def __init__(self, axes,
+ syncLimits=True,
+ syncScale=True,
+ syncDirection=True,
+ syncCenter=False,
+ syncZoom=False,
+ filterHiddenPlots=False
+ ):
"""
Constructor
@@ -69,17 +77,34 @@ class SyncAxes(object):
:param bool syncLimits: Synchronize axes limits
:param bool syncScale: Synchronize axes scale
:param bool syncDirection: Synchronize axes direction
+ :param bool syncCenter: Synchronize the center of the axes in the center
+ of the plots
+ :param bool syncZoom: Synchronize the zoom of the plot
+ :param bool filterHiddenPlots: True to avoid updating hidden plots.
+ Default: False.
"""
object.__init__(self)
+
+ def implies(x, y): return bool(y ** x)
+
+ assert(implies(syncZoom, not syncLimits))
+ assert(implies(syncCenter, not syncLimits))
+ assert(implies(syncLimits, not syncCenter))
+ assert(implies(syncLimits, not syncZoom))
+
+ self.__filterHiddenPlots = filterHiddenPlots
self.__locked = False
self.__axisRefs = []
self.__syncLimits = syncLimits
self.__syncScale = syncScale
self.__syncDirection = syncDirection
+ self.__syncCenter = syncCenter
+ self.__syncZoom = syncZoom
self.__callbacks = None
+ self.__lastMainAxis = None
for axis in axes:
- self.__axisRefs.append(weakref.ref(axis))
+ self.addAxis(axis)
self.start()
@@ -90,47 +115,131 @@ class SyncAxes(object):
After that, any changes to any axes will be used to synchronize other
axes.
"""
- if self.__callbacks is not None:
+ if self.isSynchronizing():
raise RuntimeError("Axes already synchronized")
self.__callbacks = {}
axes = self.__getAxes()
- if len(axes) == 0:
- raise RuntimeError('No axis to synchronize')
# register callback for further sync
for axis in axes:
- refAxis = weakref.ref(axis)
- callbacks = []
- if self.__syncLimits:
- # the weakref is needed to be able ignore self references
- callback = silxWeakref.WeakMethodProxy(self.__axisLimitsChanged)
- callback = functools.partial(callback, refAxis)
- sig = axis.sigLimitsChanged
- sig.connect(callback)
- callbacks.append(("sigLimitsChanged", callback))
- if self.__syncScale:
- # the weakref is needed to be able ignore self references
- callback = silxWeakref.WeakMethodProxy(self.__axisScaleChanged)
- callback = functools.partial(callback, refAxis)
- sig = axis.sigScaleChanged
- sig.connect(callback)
- callbacks.append(("sigScaleChanged", callback))
- if self.__syncDirection:
- # the weakref is needed to be able ignore self references
- callback = silxWeakref.WeakMethodProxy(self.__axisInvertedChanged)
- callback = functools.partial(callback, refAxis)
- sig = axis.sigInvertedChanged
- sig.connect(callback)
- callbacks.append(("sigInvertedChanged", callback))
-
- self.__callbacks[refAxis] = callbacks
+ self.__connectAxes(axis)
+ self.synchronize()
+
+ def isSynchronizing(self):
+ """Returns true if events are connected to the axes to synchronize them
+ all together
+
+ :rtype: bool
+ """
+ return self.__callbacks is not None
+
+ def __connectAxes(self, axis):
+ refAxis = weakref.ref(axis)
+ callbacks = []
+ if self.__syncLimits:
+ # the weakref is needed to be able ignore self references
+ callback = silxWeakref.WeakMethodProxy(self.__axisLimitsChanged)
+ callback = functools.partial(callback, refAxis)
+ sig = axis.sigLimitsChanged
+ sig.connect(callback)
+ callbacks.append(("sigLimitsChanged", callback))
+ elif self.__syncCenter and self.__syncZoom:
+ # the weakref is needed to be able ignore self references
+ callback = silxWeakref.WeakMethodProxy(self.__axisCenterAndZoomChanged)
+ callback = functools.partial(callback, refAxis)
+ sig = axis.sigLimitsChanged
+ sig.connect(callback)
+ callbacks.append(("sigLimitsChanged", callback))
+ elif self.__syncZoom:
+ raise NotImplementedError()
+ elif self.__syncCenter:
+ # the weakref is needed to be able ignore self references
+ callback = silxWeakref.WeakMethodProxy(self.__axisCenterChanged)
+ callback = functools.partial(callback, refAxis)
+ sig = axis.sigLimitsChanged
+ sig.connect(callback)
+ callbacks.append(("sigLimitsChanged", callback))
+ if self.__syncScale:
+ # the weakref is needed to be able ignore self references
+ callback = silxWeakref.WeakMethodProxy(self.__axisScaleChanged)
+ callback = functools.partial(callback, refAxis)
+ sig = axis.sigScaleChanged
+ sig.connect(callback)
+ callbacks.append(("sigScaleChanged", callback))
+ if self.__syncDirection:
+ # the weakref is needed to be able ignore self references
+ callback = silxWeakref.WeakMethodProxy(self.__axisInvertedChanged)
+ callback = functools.partial(callback, refAxis)
+ sig = axis.sigInvertedChanged
+ sig.connect(callback)
+ callbacks.append(("sigInvertedChanged", callback))
+
+ if self.__filterHiddenPlots:
+ # the weakref is needed to be able ignore self references
+ callback = silxWeakref.WeakMethodProxy(self.__axisVisibilityChanged)
+ callback = functools.partial(callback, refAxis)
+ plot = axis._getPlot()
+ plot.sigVisibilityChanged.connect(callback)
+ callbacks.append(("sigVisibilityChanged", callback))
+
+ self.__callbacks[refAxis] = callbacks
+ def __disconnectAxes(self, axis):
+ if axis is not None and _isQObjectValid(axis):
+ ref = weakref.ref(axis)
+ callbacks = self.__callbacks.pop(ref)
+ for sigName, callback in callbacks:
+ if sigName == "sigVisibilityChanged":
+ obj = axis._getPlot()
+ else:
+ obj = axis
+ if obj is not None:
+ sig = getattr(obj, sigName)
+ sig.disconnect(callback)
+
+ def addAxis(self, axis):
+ """Add a new axes to synchronize.
+
+ :param ~silx.gui.plot.items.Axis axis: The axis to synchronize
+ """
+ self.__axisRefs.append(weakref.ref(axis))
+ if self.isSynchronizing():
+ self.__connectAxes(axis)
+ # This could be done faster as only this axis have to be fixed
+ self.synchronize()
+
+ def removeAxis(self, axis):
+ """Remove an axis from the synchronized axes.
+
+ :param ~silx.gui.plot.items.Axis axis: The axis to remove
+ """
+ ref = weakref.ref(axis)
+ self.__axisRefs.remove(ref)
+ if self.isSynchronizing():
+ self.__disconnectAxes(axis)
+
+ def synchronize(self, mainAxis=None):
+ """Synchronize programatically all the axes.
+
+ :param ~silx.gui.plot.items.Axis mainAxis:
+ The axis to take as reference (Default: the first axis).
+ """
# sync the current state
- mainAxis = axes[0]
+ axes = self.__getAxes()
+ if len(axes) == 0:
+ return
+
+ if mainAxis is None:
+ mainAxis = axes[0]
+
refMainAxis = weakref.ref(mainAxis)
if self.__syncLimits:
self.__axisLimitsChanged(refMainAxis, *mainAxis.getLimits())
+ elif self.__syncCenter and self.__syncZoom:
+ self.__axisCenterAndZoomChanged(refMainAxis, *mainAxis.getLimits())
+ elif self.__syncCenter:
+ self.__axisCenterChanged(refMainAxis, *mainAxis.getLimits())
if self.__syncScale:
self.__axisScaleChanged(refMainAxis, mainAxis.getScale())
if self.__syncDirection:
@@ -138,14 +247,11 @@ class SyncAxes(object):
def stop(self):
"""Stop the synchronization of the axes"""
- if self.__callbacks is None:
+ if not self.isSynchronizing():
raise RuntimeError("Axes not synchronized")
- for ref, callbacks in self.__callbacks.items():
+ for ref in list(self.__callbacks.keys()):
axis = ref()
- if axis is not None and _isQObjectValid(axis):
- for sigName, callback in callbacks:
- sig = getattr(axis, sigName)
- sig.disconnect(callback)
+ self.__disconnectAxes(axis)
self.__callbacks = None
def __del__(self):
@@ -168,32 +274,130 @@ class SyncAxes(object):
yield
self.__locked = False
- def __otherAxes(self, changedAxis):
+ def __axesToUpdate(self, changedAxis):
for axis in self.__getAxes():
if axis is changedAxis:
continue
+ if self.__filterHiddenPlots:
+ plot = axis._getPlot()
+ if not plot.isVisible():
+ continue
yield axis
+ def __axisVisibilityChanged(self, changedAxis, isVisible):
+ if not isVisible:
+ return
+ if self.__locked:
+ return
+ changedAxis = changedAxis()
+ if self.__lastMainAxis is None:
+ self.__lastMainAxis = self.__axisRefs[0]
+ mainAxis = self.__lastMainAxis
+ mainAxis = mainAxis()
+ self.synchronize(mainAxis=mainAxis)
+ # force back the main axis
+ self.__lastMainAxis = weakref.ref(mainAxis)
+
+ def __getAxesCenter(self, axis, vmin, vmax):
+ """Returns the value displayed in the center of this axis range.
+
+ :rtype: float
+ """
+ scale = axis.getScale()
+ if scale == Axis.LINEAR:
+ center = (vmin + vmax) * 0.5
+ else:
+ raise NotImplementedError("Log scale not implemented")
+ return center
+
+ def __getRangeInPixel(self, axis):
+ """Returns the size of the axis in pixel"""
+ bounds = axis._getPlot().getPlotBoundsInPixels()
+ # bounds: left, top, width, height
+ if isinstance(axis, XAxis):
+ return bounds[2]
+ elif isinstance(axis, YAxis):
+ return bounds[3]
+ else:
+ assert(False)
+
+ def __getLimitsFromCenter(self, axis, pos, pixelSize=None):
+ """Returns the limits to apply to this axis to move the `pos` into the
+ center of this axis.
+
+ :param Axis axis:
+ :param float pos: Position in the center of the computed limits
+ :param Union[None,float] pixelSize: Pixel size to apply to compute the
+ limits. If `None` the current pixel size is applyed.
+ """
+ scale = axis.getScale()
+ if scale == Axis.LINEAR:
+ if pixelSize is None:
+ # Use the current pixel size of the axis
+ limits = axis.getLimits()
+ valueRange = limits[0] - limits[1]
+ a = pos - valueRange * 0.5
+ b = pos + valueRange * 0.5
+ else:
+ pixelRange = self.__getRangeInPixel(axis)
+ a = pos - pixelRange * 0.5 * pixelSize
+ b = pos + pixelRange * 0.5 * pixelSize
+
+ else:
+ raise NotImplementedError("Log scale not implemented")
+ if a > b:
+ return b, a
+ return a, b
+
def __axisLimitsChanged(self, changedAxis, vmin, vmax):
if self.__locked:
return
+ self.__lastMainAxis = changedAxis
+ changedAxis = changedAxis()
+ with self.__inhibitSignals():
+ for axis in self.__axesToUpdate(changedAxis):
+ axis.setLimits(vmin, vmax)
+
+ def __axisCenterAndZoomChanged(self, changedAxis, vmin, vmax):
+ if self.__locked:
+ return
+ self.__lastMainAxis = changedAxis
+ changedAxis = changedAxis()
+ with self.__inhibitSignals():
+ center = self.__getAxesCenter(changedAxis, vmin, vmax)
+ pixelRange = self.__getRangeInPixel(changedAxis)
+ if pixelRange == 0:
+ return
+ pixelSize = (vmax - vmin) / pixelRange
+ for axis in self.__axesToUpdate(changedAxis):
+ vmin, vmax = self.__getLimitsFromCenter(axis, center, pixelSize)
+ axis.setLimits(vmin, vmax)
+
+ def __axisCenterChanged(self, changedAxis, vmin, vmax):
+ if self.__locked:
+ return
+ self.__lastMainAxis = changedAxis
changedAxis = changedAxis()
with self.__inhibitSignals():
- for axis in self.__otherAxes(changedAxis):
+ center = self.__getAxesCenter(changedAxis, vmin, vmax)
+ for axis in self.__axesToUpdate(changedAxis):
+ vmin, vmax = self.__getLimitsFromCenter(axis, center)
axis.setLimits(vmin, vmax)
def __axisScaleChanged(self, changedAxis, scale):
if self.__locked:
return
+ self.__lastMainAxis = changedAxis
changedAxis = changedAxis()
with self.__inhibitSignals():
- for axis in self.__otherAxes(changedAxis):
+ for axis in self.__axesToUpdate(changedAxis):
axis.setScale(scale)
def __axisInvertedChanged(self, changedAxis, isInverted):
if self.__locked:
return
+ self.__lastMainAxis = changedAxis
changedAxis = changedAxis()
with self.__inhibitSignals():
- for axis in self.__otherAxes(changedAxis):
+ for axis in self.__axesToUpdate(changedAxis):
axis.setInverted(isInverted)
diff --git a/silx/gui/plot3d/ParamTreeView.py b/silx/gui/plot3d/ParamTreeView.py
index ee0c876..8cf2b90 100644
--- a/silx/gui/plot3d/ParamTreeView.py
+++ b/silx/gui/plot3d/ParamTreeView.py
@@ -43,7 +43,7 @@ __date__ = "05/12/2017"
import numbers
import sys
-from silx.third_party import six
+import six
from .. import qt
from ..widgets.FloatEdit import FloatEdit as _FloatEdit
diff --git a/silx/gui/plot3d/ScalarFieldView.py b/silx/gui/plot3d/ScalarFieldView.py
index e5e680c..50cba05 100644
--- a/silx/gui/plot3d/ScalarFieldView.py
+++ b/silx/gui/plot3d/ScalarFieldView.py
@@ -886,6 +886,8 @@ class ScalarFieldView(Plot3DWindow):
self._bbox = axes.LabelledAxes()
self._bbox.children = [self._group]
+ self._outerScale = transform.Scale(1., 1., 1.)
+ self._bbox.transforms = [self._outerScale]
self.getPlot3DWidget().viewport.scene.children.append(self._bbox)
self._selectionBox = primitives.Box()
@@ -1204,6 +1206,25 @@ class ScalarFieldView(Plot3DWindow):
# Transformations
+ def setOuterScale(self, sx=1., sy=1., sz=1.):
+ """Set the scale to apply to the whole scene including the axes.
+
+ This is useful when axis lengths in data space are really different.
+
+ :param float sx: Scale factor along the X axis
+ :param float sy: Scale factor along the Y axis
+ :param float sz: Scale factor along the Z axis
+ """
+ self._outerScale.setScale(sx, sy, sz)
+ self.centerScene()
+
+ def getOuterScale(self):
+ """Returns the scales provided by :meth:`setOuterScale`.
+
+ :rtype: numpy.ndarray
+ """
+ return self._outerScale.scale
+
def setScale(self, sx=1., sy=1., sz=1.):
"""Set the scale of the 3D scalar field (i.e., size of a voxel).
diff --git a/silx/gui/plot3d/SceneWidget.py b/silx/gui/plot3d/SceneWidget.py
index 4a824d7..e60dcfc 100644
--- a/silx/gui/plot3d/SceneWidget.py
+++ b/silx/gui/plot3d/SceneWidget.py
@@ -1,7 +1,7 @@
# coding: utf-8
# /*##########################################################################
#
-# Copyright (c) 2017-2018 European Synchrotron Radiation Facility
+# Copyright (c) 2017-2019 European Synchrotron Radiation Facility
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
@@ -30,10 +30,11 @@ __authors__ = ["T. Vincent"]
__license__ = "MIT"
__date__ = "24/04/2018"
-import numpy
+import enum
import weakref
-from silx.third_party import enum
+import numpy
+
from .. import qt
from ..colors import rgba
@@ -229,6 +230,9 @@ class SceneSelection(qt.QObject):
:raise ValueError: If the item is not the widget's scene
"""
previous = self.getCurrentItem()
+ if item is previous:
+ return # Fast path, nothing to do
+
if previous is not None:
previous.sigItemChanged.disconnect(self.__currentChanged)
@@ -252,15 +256,18 @@ class SceneSelection(qt.QObject):
'Not an Item3D: %s' % str(item))
current = self.getCurrentItem()
- if current is not previous:
- self.sigCurrentChanged.emit(current, previous)
- self.__updateSelectionModel()
+ self.sigCurrentChanged.emit(current, previous)
+ self.__updateSelectionModel()
def __currentChanged(self, event):
"""Handle updates of the selected item"""
if event == items.Item3DChangedType.ROOT_ITEM:
item = self.sender()
- if item.root() != self.getSceneGroup():
+
+ parent = self.parent()
+ assert isinstance(parent, SceneWidget)
+
+ if item.root() != parent.getSceneGroup():
self.setSelectedItem(None)
# Synchronization with QItemSelectionModel
@@ -488,7 +495,8 @@ class SceneWidget(Plot3DWidget):
:param int index: The index at which to place the item.
By default it is appended to the end of the list.
:return: The newly created scalar volume item
- :rtype: items.ScalarField3D
+ :rtype: ~silx.gui.plot3d.items.volume.ScalarField3D
+
"""
volume = items.ScalarField3D()
volume.setData(data, copy=copy)
@@ -508,7 +516,7 @@ class SceneWidget(Plot3DWidget):
:param int index: The index at which to place the item.
By default it is appended to the end of the list.
:return: The newly created 3D scatter item
- :rtype: items.Scatter3D
+ :rtype: ~silx.gui.plot3d.items.scatter.Scatter3D
"""
scatter3d = items.Scatter3D()
scatter3d.setData(x=x, y=y, z=z, value=value, copy=copy)
@@ -528,7 +536,7 @@ class SceneWidget(Plot3DWidget):
:param int index: The index at which to place the item.
By default it is appended to the end of the list.
:return: The newly created 2D scatter item
- :rtype: items.Scatter2D
+ :rtype: ~silx.gui.plot3d.items.scatter.Scatter2D
"""
scatter2d = items.Scatter2D()
scatter2d.setData(x=x, y=y, value=value, copy=copy)
@@ -548,7 +556,7 @@ class SceneWidget(Plot3DWidget):
:param int index: The index at which to place the item.
By default it is appended to the end of the list.
:return: The newly created image item
- :rtype: items.ImageData or items.ImageRgba
+ :rtype: ~silx.gui.plot3d.items.image.ImageData or ~silx.gui.plot3d.items.image.ImageRgba
:raise ValueError: For arrays of unsupported dimensions
"""
data = numpy.array(data, copy=False)
diff --git a/silx/gui/plot3d/_model/items.py b/silx/gui/plot3d/_model/items.py
index b09f29a..7e58d14 100644
--- a/silx/gui/plot3d/_model/items.py
+++ b/silx/gui/plot3d/_model/items.py
@@ -1,7 +1,7 @@
# coding: utf-8
# /*##########################################################################
#
-# Copyright (c) 2017-2018 European Synchrotron Radiation Facility
+# Copyright (c) 2017-2019 European Synchrotron Radiation Facility
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
@@ -38,8 +38,7 @@ import logging
import weakref
import numpy
-
-from silx.third_party import six
+import six
from ...utils.image import convertArrayToQImage
from ...colors import preferredColormaps
@@ -202,7 +201,7 @@ class Settings(StaticRow):
super(Settings, self).__init__(('Settings', None), children=children)
-class Item3DRow(StaticRow):
+class Item3DRow(BaseRow):
"""Represents an :class:`Item3D` with checkable visibility
:param Item3D item: The scene item to represent.
@@ -210,9 +209,8 @@ class Item3DRow(StaticRow):
"""
def __init__(self, item, name=None):
- if name is None:
- name = item.getLabel()
- super(Item3DRow, self).__init__((name, None))
+ self.__name = None if name is None else six.text_type(name)
+ super(Item3DRow, self).__init__()
self.setFlags(
self.flags(0) | qt.Qt.ItemIsUserCheckable | qt.Qt.ItemIsSelectable,
@@ -224,7 +222,8 @@ class Item3DRow(StaticRow):
def _itemChanged(self, event):
"""Handle visibility change"""
- if event == items.ItemChangedType.VISIBLE:
+ if event in (items.ItemChangedType.VISIBLE,
+ items.Item3DChangedType.LABEL):
model = self.model()
if model is not None:
index = self.index(column=1)
@@ -235,16 +234,25 @@ class Item3DRow(StaticRow):
return self._item()
def data(self, column, role):
- if column == 0 and role == qt.Qt.CheckStateRole:
- item = self.item()
- if item is not None and item.isVisible():
- return qt.Qt.Checked
- else:
- return qt.Qt.Unchecked
- elif column == 0 and role == qt.Qt.DecorationRole:
- return icons.getQIcon('item-3dim')
- else:
- return super(Item3DRow, self).data(column, role)
+ if column == 0:
+ if role == qt.Qt.CheckStateRole:
+ item = self.item()
+ if item is not None and item.isVisible():
+ return qt.Qt.Checked
+ else:
+ return qt.Qt.Unchecked
+
+ elif role == qt.Qt.DecorationRole:
+ return icons.getQIcon('item-3dim')
+
+ elif role == qt.Qt.DisplayRole:
+ if self.__name is None:
+ item = self.item()
+ return '' if item is None else item.getLabel()
+ else:
+ return self.__name
+
+ return super(Item3DRow, self).data(column, role)
def setData(self, column, value, role):
if column == 0 and role == qt.Qt.CheckStateRole:
@@ -256,6 +264,9 @@ class Item3DRow(StaticRow):
return False
return super(Item3DRow, self).setData(column, value, role)
+ def columnCount(self):
+ return 2
+
class DataItem3DBoundingBoxRow(ProxyRow):
"""Represents :class:`DataItem3D` bounding box visibility
@@ -562,7 +573,6 @@ class _ColormapBaseProxyRow(ProxyRow):
"""Signal used internally to notify colormap (or data) update"""
def __init__(self, item, *args, **kwargs):
- self._dataRange = None
self._item = weakref.ref(item)
self._colormap = item.getColormap()
@@ -581,19 +591,11 @@ class _ColormapBaseProxyRow(ProxyRow):
:return: Colormap range (min, max)
"""
- if self._dataRange is None:
- item = self.item()
- if item is not None and self._colormap is not None:
- if hasattr(item, 'getDataRange'):
- data = item.getDataRange()
- else:
- data = item.getData(copy=False)
-
- self._dataRange = self._colormap.getColormapRange(data)
-
- else: # Fallback
- self._dataRange = 1, 100
- return self._dataRange
+ item = self.item()
+ if item is not None and self._colormap is not None:
+ return self._colormap.getColormapRange(item._getDataRange())
+ else:
+ return 1, 100 # Fallback
def _modelUpdated(self, *args, **kwargs):
"""Emit dataChanged in the model"""
@@ -624,7 +626,6 @@ class _ColormapBaseProxyRow(ProxyRow):
self._colormap = None
elif event == items.ItemChangedType.DATA:
- self._dataRange = None
self._sigColormapChanged.emit()
diff --git a/silx/gui/plot3d/items/__init__.py b/silx/gui/plot3d/items/__init__.py
index b2a9dab..58eee9c 100644
--- a/silx/gui/plot3d/items/__init__.py
+++ b/silx/gui/plot3d/items/__init__.py
@@ -1,7 +1,7 @@
# coding: utf-8
# /*##########################################################################
#
-# Copyright (c) 2017-2018 European Synchrotron Radiation Facility
+# Copyright (c) 2017-2019 European Synchrotron Radiation Facility
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
@@ -38,6 +38,6 @@ from .mixins import (ColormapMixIn, InterpolationMixIn, # noqa
PlaneMixIn, SymbolMixIn) # noqa
from .clipplane import ClipPlane # noqa
from .image import ImageData, ImageRgba # noqa
-from .mesh import Mesh, Box, Cylinder, Hexagon # noqa
+from .mesh import Mesh, ColormapMesh, Box, Cylinder, Hexagon # noqa
from .scatter import Scatter2D, Scatter3D # noqa
from .volume import ScalarField3D # noqa
diff --git a/silx/gui/plot3d/items/core.py b/silx/gui/plot3d/items/core.py
index 0aefced..1745b2b 100644
--- a/silx/gui/plot3d/items/core.py
+++ b/silx/gui/plot3d/items/core.py
@@ -32,10 +32,10 @@ __license__ = "MIT"
__date__ = "15/11/2017"
from collections import defaultdict
+import enum
import numpy
-
-from silx.third_party import enum, six
+import six
from ... import qt
from ...plot.items import ItemChangedType
diff --git a/silx/gui/plot3d/items/mesh.py b/silx/gui/plot3d/items/mesh.py
index 21936ea..d3f5e38 100644
--- a/silx/gui/plot3d/items/mesh.py
+++ b/silx/gui/plot3d/items/mesh.py
@@ -1,7 +1,7 @@
# coding: utf-8
# /*##########################################################################
#
-# Copyright (c) 2017-2018 European Synchrotron Radiation Facility
+# Copyright (c) 2017-2019 European Synchrotron Radiation Facility
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
@@ -35,17 +35,18 @@ __date__ = "17/07/2018"
import logging
import numpy
-from ..scene import primitives, utils
+from ..scene import primitives, utils, function
from ..scene.transform import Rotate
from .core import DataItem3D, ItemChangedType
+from .mixins import ColormapMixIn
from ._pick import PickingResult
_logger = logging.getLogger(__name__)
-class Mesh(DataItem3D):
- """Description of mesh.
+class _MeshBase(DataItem3D):
+ """Base class for :class:`Mesh' and :class:`ColormapMesh`.
:param parent: The View widget this item belongs to.
"""
@@ -54,48 +55,22 @@ class Mesh(DataItem3D):
DataItem3D.__init__(self, parent=parent)
self._mesh = None
- def setData(self,
- position,
- color,
- normal=None,
- mode='triangles',
- copy=True):
- """Set mesh geometry data.
-
- Supported drawing modes are: 'triangles', 'triangle_strip', 'fan'
+ def _setMesh(self, mesh):
+ """Set mesh primitive
- :param numpy.ndarray position:
- Position (x, y, z) of each vertex as a (N, 3) array
- :param numpy.ndarray color: Colors for each point or a single color
- :param numpy.ndarray normal: Normals for each point or None (default)
- :param str mode: The drawing mode.
- :param bool copy: True (default) to copy the data,
- False to use as is (do not modify!).
+ :param Union[None,Geometry] mesh: The scene primitive
"""
self._getScenePrimitive().children = [] # Remove any previous mesh
- if position is None or len(position) == 0:
- self._mesh = None
- else:
- self._mesh = primitives.Mesh3D(
- position, color, normal, mode=mode, copy=copy)
+ self._mesh = mesh
+ if self._mesh is not None:
self._getScenePrimitive().children.append(self._mesh)
- self.sigItemChanged.emit(ItemChangedType.DATA)
+ self._updated(ItemChangedType.DATA)
- def getData(self, copy=True):
- """Get the mesh geometry.
-
- :param bool copy:
- True (default) to get a copy,
- False to get internal representation (do not modify!).
- :return: The positions, colors, normals and mode
- :rtype: tuple of numpy.ndarray
- """
- return (self.getPositionData(copy=copy),
- self.getColorData(copy=copy),
- self.getNormalData(copy=copy),
- self.getDrawMode())
+ def _getMesh(self):
+ """Returns the underlying Mesh scene primitive"""
+ return self._mesh
def getPositionData(self, copy=True):
"""Get the mesh vertex positions.
@@ -106,38 +81,38 @@ class Mesh(DataItem3D):
:return: The (x, y, z) positions as a (N, 3) array
:rtype: numpy.ndarray
"""
- if self._mesh is None:
+ if self._getMesh() is None:
return numpy.empty((0, 3), dtype=numpy.float32)
else:
- return self._mesh.getAttribute('position', copy=copy)
+ return self._getMesh().getAttribute('position', copy=copy)
- def getColorData(self, copy=True):
- """Get the mesh vertex colors.
+ def getNormalData(self, copy=True):
+ """Get the mesh vertex normals.
:param bool copy:
True (default) to get a copy,
False to get internal representation (do not modify!).
- :return: The RGBA colors as a (N, 4) array or a single color
- :rtype: numpy.ndarray
+ :return: The normals as a (N, 3) array, a single normal or None
+ :rtype: Union[numpy.ndarray,None]
"""
- if self._mesh is None:
- return numpy.empty((0, 4), dtype=numpy.float32)
+ if self._getMesh() is None:
+ return None
else:
- return self._mesh.getAttribute('color', copy=copy)
+ return self._getMesh().getAttribute('normal', copy=copy)
- def getNormalData(self, copy=True):
- """Get the mesh vertex normals.
+ def getIndices(self, copy=True):
+ """Get the vertex indices.
:param bool copy:
True (default) to get a copy,
False to get internal representation (do not modify!).
- :return: The normals as a (N, 3) array, a single normal or None
- :rtype: numpy.ndarray or None
+ :return: The vertex indices as an array or None.
+ :rtype: Union[numpy.ndarray,None]
"""
- if self._mesh is None:
+ if self._getMesh() is None:
return None
else:
- return self._mesh.getAttribute('normal', copy=copy)
+ return self._getMesh().getIndices(copy=copy)
def getDrawMode(self):
"""Get mesh rendering mode.
@@ -145,7 +120,7 @@ class Mesh(DataItem3D):
:return: The drawing mode of this primitive
:rtype: str
"""
- return self._mesh.drawMode
+ return self._getMesh().drawMode
def _pickFull(self, context):
"""Perform precise picking in this item at given widget position.
@@ -164,28 +139,34 @@ class Mesh(DataItem3D):
return None
mode = self.getDrawMode()
- if mode == 'triangles':
- triangles = positions.reshape(-1, 3, 3)
-
- elif mode == 'triangle_strip':
- # Expand strip
- triangles = numpy.empty((len(positions) - 2, 3, 3),
- dtype=positions.dtype)
- triangles[:, 0] = positions[:-2]
- triangles[:, 1] = positions[1:-1]
- triangles[:, 2] = positions[2:]
-
- elif mode == 'fan':
- # Expand fan
- triangles = numpy.empty((len(positions) - 2, 3, 3),
- dtype=positions.dtype)
- triangles[:, 0] = positions[0]
- triangles[:, 1] = positions[1:-1]
- triangles[:, 2] = positions[2:]
+ vertexIndices = self.getIndices(copy=False)
+ if vertexIndices is not None: # Expand indices
+ positions = utils.unindexArrays(mode, vertexIndices, positions)[0]
+ triangles = positions.reshape(-1, 3, 3)
else:
- _logger.warning("Unsupported draw mode: %s" % mode)
- return None
+ if mode == 'triangles':
+ triangles = positions.reshape(-1, 3, 3)
+
+ elif mode == 'triangle_strip':
+ # Expand strip
+ triangles = numpy.empty((len(positions) - 2, 3, 3),
+ dtype=positions.dtype)
+ triangles[:, 0] = positions[:-2]
+ triangles[:, 1] = positions[1:-1]
+ triangles[:, 2] = positions[2:]
+
+ elif mode == 'fan':
+ # Expand fan
+ triangles = numpy.empty((len(positions) - 2, 3, 3),
+ dtype=positions.dtype)
+ triangles[:, 0] = positions[0]
+ triangles[:, 1] = positions[1:-1]
+ triangles[:, 2] = positions[2:]
+
+ else:
+ _logger.warning("Unsupported draw mode: %s" % mode)
+ return None
trianglesIndices, t, barycentric = utils.segmentTrianglesIntersection(
rayObject, triangles)
@@ -208,12 +189,160 @@ class Mesh(DataItem3D):
indices = trianglesIndices + closest # For corners 1 and 2
indices[closest == 0] = 0 # For first corner (common)
+ if vertexIndices is not None:
+ # Convert from indices in expanded triangles to input vertices
+ indices = vertexIndices[indices]
+
return PickingResult(self,
positions=points,
indices=indices,
fetchdata=self.getPositionData)
+class Mesh(_MeshBase):
+ """Description of mesh.
+
+ :param parent: The View widget this item belongs to.
+ """
+
+ def __init__(self, parent=None):
+ _MeshBase.__init__(self, parent=parent)
+
+ def setData(self,
+ position,
+ color,
+ normal=None,
+ mode='triangles',
+ indices=None,
+ copy=True):
+ """Set mesh geometry data.
+
+ Supported drawing modes are: 'triangles', 'triangle_strip', 'fan'
+
+ :param numpy.ndarray position:
+ Position (x, y, z) of each vertex as a (N, 3) array
+ :param numpy.ndarray color: Colors for each point or a single color
+ :param Union[numpy.ndarray,None] normal: Normals for each point or None (default)
+ :param str mode: The drawing mode.
+ :param Union[List[int],None] indices:
+ Array of vertex indices or None to use arrays directly.
+ :param bool copy: True (default) to copy the data,
+ False to use as is (do not modify!).
+ """
+ assert mode in ('triangles', 'triangle_strip', 'fan')
+ if position is None or len(position) == 0:
+ mesh = None
+ else:
+ mesh = primitives.Mesh3D(
+ position, color, normal, mode=mode, indices=indices, copy=copy)
+ self._setMesh(mesh)
+
+ def getData(self, copy=True):
+ """Get the mesh geometry.
+
+ :param bool copy:
+ True (default) to get a copy,
+ False to get internal representation (do not modify!).
+ :return: The positions, colors, normals and mode
+ :rtype: tuple of numpy.ndarray
+ """
+ return (self.getPositionData(copy=copy),
+ self.getColorData(copy=copy),
+ self.getNormalData(copy=copy),
+ self.getDrawMode())
+
+ def getColorData(self, copy=True):
+ """Get the mesh vertex colors.
+
+ :param bool copy:
+ True (default) to get a copy,
+ False to get internal representation (do not modify!).
+ :return: The RGBA colors as a (N, 4) array or a single color
+ :rtype: numpy.ndarray
+ """
+ if self._getMesh() is None:
+ return numpy.empty((0, 4), dtype=numpy.float32)
+ else:
+ return self._getMesh().getAttribute('color', copy=copy)
+
+
+class ColormapMesh(_MeshBase, ColormapMixIn):
+ """Description of mesh which color is defined by scalar and a colormap.
+
+ :param parent: The View widget this item belongs to.
+ """
+
+ def __init__(self, parent=None):
+ _MeshBase.__init__(self, parent=parent)
+ ColormapMixIn.__init__(self, function.Colormap())
+
+ def setData(self,
+ position,
+ value,
+ normal=None,
+ mode='triangles',
+ indices=None,
+ copy=True):
+ """Set mesh geometry data.
+
+ Supported drawing modes are: 'triangles', 'triangle_strip', 'fan'
+
+ :param numpy.ndarray position:
+ Position (x, y, z) of each vertex as a (N, 3) array
+ :param numpy.ndarray value: Data value for each vertex.
+ :param Union[numpy.ndarray,None] normal: Normals for each point or None (default)
+ :param str mode: The drawing mode.
+ :param Union[List[int],None] indices:
+ Array of vertex indices or None to use arrays directly.
+ :param bool copy: True (default) to copy the data,
+ False to use as is (do not modify!).
+ """
+ assert mode in ('triangles', 'triangle_strip', 'fan')
+ if position is None or len(position) == 0:
+ mesh = None
+ else:
+ mesh = primitives.ColormapMesh3D(
+ position=position,
+ value=numpy.array(value, copy=False).reshape(-1, 1), # Make it a 2D array
+ colormap=self._getSceneColormap(),
+ normal=normal,
+ mode=mode,
+ indices=indices,
+ copy=copy)
+ self._setMesh(mesh)
+
+ # Store data range info
+ ColormapMixIn._setRangeFromData(self, self.getValueData(copy=False))
+
+ def getData(self, copy=True):
+ """Get the mesh geometry.
+
+ :param bool copy:
+ True (default) to get a copy,
+ False to get internal representation (do not modify!).
+ :return: The positions, values, normals and mode
+ :rtype: tuple of numpy.ndarray
+ """
+ return (self.getPositionData(copy=copy),
+ self.getValueData(copy=copy),
+ self.getNormalData(copy=copy),
+ self.getDrawMode())
+
+ def getValueData(self, copy=True):
+ """Get the mesh vertex values.
+
+ :param bool copy:
+ True (default) to get a copy,
+ False to get internal representation (do not modify!).
+ :return: Array of data values
+ :rtype: numpy.ndarray
+ """
+ if self._getMesh() is None:
+ return numpy.empty((0,), dtype=numpy.float32)
+ else:
+ return self._getMesh().getAttribute('value', copy=copy)
+
+
class _CylindricalVolume(DataItem3D):
"""Class that represents a volume with a rotational symmetry along z
@@ -345,7 +474,7 @@ class _CylindricalVolume(DataItem3D):
vertices, color, normals, mode='triangles', copy=False)
self._getScenePrimitive().children.append(self._mesh)
- self.sigItemChanged.emit(ItemChangedType.DATA)
+ self._updated(ItemChangedType.DATA)
def _pickFull(self, context):
"""Perform precise picking in this item at given widget position.
diff --git a/silx/gui/plot3d/items/mixins.py b/silx/gui/plot3d/items/mixins.py
index 8e96441..40b8438 100644
--- a/silx/gui/plot3d/items/mixins.py
+++ b/silx/gui/plot3d/items/mixins.py
@@ -1,7 +1,7 @@
# coding: utf-8
# /*##########################################################################
#
-# Copyright (c) 2017-2018 European Synchrotron Radiation Facility
+# Copyright (c) 2017-2019 European Synchrotron Radiation Facility
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
@@ -114,19 +114,17 @@ class ColormapMixIn(_ColormapMixIn):
self.__sceneColormap = sceneColormap
self._syncSceneColormap()
- self.sigItemChanged.connect(self.__colormapUpdated)
-
- def __colormapUpdated(self, event):
+ def _colormapChanged(self):
"""Handle colormap updates"""
- if event == ItemChangedType.COLORMAP:
- self._syncSceneColormap()
+ self._syncSceneColormap()
+ super(ColormapMixIn, self)._colormapChanged()
def _setRangeFromData(self, data=None):
"""Compute the data range the colormap should use from provided data.
:param data: Data set from which to compute the range or None
"""
- if data is None or len(data) == 0:
+ if data is None or data.size == 0:
dataRange = None
else:
dataRange = min_max(data, min_positive=True, finite=True)
@@ -144,6 +142,13 @@ class ColormapMixIn(_ColormapMixIn):
if self.getColormap().isAutoscale():
self._syncSceneColormap()
+ def _getDataRange(self):
+ """Returns the data range as used in the scene for colormap
+
+ :rtype: Union[List[float],None]
+ """
+ return self._dataRange
+
def _setSceneColormap(self, sceneColormap):
"""Set the scene colormap to sync with Colormap object.
@@ -171,8 +176,6 @@ class ColormapMixIn(_ColormapMixIn):
class SymbolMixIn(_SymbolMixIn):
"""Mix-in class for symbol and symbolSize properties for Item3D"""
- _DEFAULT_SYMBOL = 'o'
- _DEFAULT_SYMBOL_SIZE = 7.0
_SUPPORTED_SYMBOLS = collections.OrderedDict((
('o', 'Circle'),
('d', 'Diamond'),
diff --git a/silx/gui/plot3d/items/scatter.py b/silx/gui/plot3d/items/scatter.py
index a13c3db..b7bcd09 100644
--- a/silx/gui/plot3d/items/scatter.py
+++ b/silx/gui/plot3d/items/scatter.py
@@ -1,7 +1,7 @@
# coding: utf-8
# /*##########################################################################
#
-# Copyright (c) 2017-2018 European Synchrotron Radiation Facility
+# Copyright (c) 2017-2019 European Synchrotron Radiation Facility
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
@@ -36,6 +36,7 @@ import logging
import sys
import numpy
+from ....utils.deprecation import deprecated
from ..scene import function, primitives, utils
from .core import DataItem3D, Item3DChangedType, ItemChangedType
@@ -43,7 +44,7 @@ from .mixins import ColormapMixIn, SymbolMixIn
from ._pick import PickingResult
-_logger = logging.getLevelName(__name__)
+_logger = logging.getLogger(__name__)
class Scatter3D(DataItem3D, ColormapMixIn, SymbolMixIn):
@@ -94,7 +95,7 @@ class Scatter3D(DataItem3D, ColormapMixIn, SymbolMixIn):
self._scatter.setAttribute('z', z, copy=copy)
self._scatter.setAttribute('value', value, copy=copy)
- ColormapMixIn._setRangeFromData(self, self.getValues(copy=False))
+ ColormapMixIn._setRangeFromData(self, self.getValueData(copy=False))
self._updated(ItemChangedType.DATA)
def getData(self, copy=True):
@@ -107,7 +108,7 @@ class Scatter3D(DataItem3D, ColormapMixIn, SymbolMixIn):
return (self.getXData(copy),
self.getYData(copy),
self.getZData(copy),
- self.getValues(copy))
+ self.getValueData(copy))
def getXData(self, copy=True):
"""Returns X data coordinates.
@@ -139,7 +140,7 @@ class Scatter3D(DataItem3D, ColormapMixIn, SymbolMixIn):
"""
return self._scatter.getAttribute('z', copy=copy).reshape(-1)
- def getValues(self, copy=True):
+ def getValueData(self, copy=True):
"""Returns data values.
:param bool copy: True to get a copy,
@@ -149,6 +150,