diff options
Diffstat (limited to 'silx/gui')
140 files changed, 15211 insertions, 4100 deletions
diff --git a/silx/gui/__init__.py b/silx/gui/__init__.py index 6baf238..b796e20 100644 --- a/silx/gui/__init__.py +++ b/silx/gui/__init__.py @@ -1,7 +1,7 @@ # coding: utf-8 # /*########################################################################## # -# Copyright (c) 2016 European Synchrotron Radiation Facility +# Copyright (c) 2016-2018 European Synchrotron Radiation Facility # # Permission is hereby granted, free of charge, to any person obtaining a copy # of this software and associated documentation files (the "Software"), to deal @@ -22,7 +22,27 @@ # THE SOFTWARE. # # ###########################################################################*/ -"""Set of Qt widgets""" +"""This package provides a set of Qt widgets. + +It contains the following sub-packages and modules: + +- silx.gui.colors: Functions to handle colors and colormap +- silx.gui.console: IPython console widget +- silx.gui.data: + Widgets for displaying data arrays using table views and plot widgets +- silx.gui.dialog: Specific dialog widgets +- silx.gui.fit: Widgets for controlling curve fitting +- silx.gui.hdf5: Widgets for displaying content relative to HDF5 format +- silx.gui.icons: Functions to access embedded icons +- silx.gui.plot: Widgets for 1D and 2D plotting and related tools +- silx.gui.plot3d: Widgets for visualizing data in 3D based on OpenGL +- silx.gui.printer: Shared printer used by the library +- silx.gui.qt: Common wrapper over different Python Qt binding +- silx.gui.utils: Miscellaneous helpers for Qt +- silx.gui.widgets: Miscellaneous standalone widgets + +See silx documentation: http://www.silx.org/doc/silx/latest/ +""" __authors__ = ["T. Vincent"] __license__ = "MIT" diff --git a/silx/gui/_glutils/font.py b/silx/gui/_glutils/font.py index 2be2c04..b5bd6b5 100644 --- a/silx/gui/_glutils/font.py +++ b/silx/gui/_glutils/font.py @@ -1,7 +1,7 @@ # coding: utf-8 # /*########################################################################## # -# Copyright (c) 2016-2017 European Synchrotron Radiation Facility +# Copyright (c) 2016-2018 European Synchrotron Radiation Facility # # Permission is hereby granted, free of charge, to any person obtaining a copy # of this software and associated documentation files (the "Software"), to deal @@ -30,11 +30,10 @@ __date__ = "13/10/2016" import logging -import sys import numpy -from .. import qt -from .._utils import convertQImageToArray +from ..utils._image import convertQImageToArray +from .. import qt _logger = logging.getLogger(__name__) diff --git a/silx/gui/colors.py b/silx/gui/colors.py new file mode 100644 index 0000000..028609b --- /dev/null +++ b/silx/gui/colors.py @@ -0,0 +1,732 @@ +# coding: utf-8 +# /*########################################################################## +# +# Copyright (c) 2015-2018 European Synchrotron Radiation Facility +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. +# +# ###########################################################################*/ +"""This module provides API to manage colors. +""" + +from __future__ import absolute_import + +__authors__ = ["T. Vincent", "H.Payno"] +__license__ = "MIT" +__date__ = "14/06/2018" + +from silx.gui import qt +import copy as copy_mdl +import numpy +import logging +from silx.math.combo import min_max +from silx.math.colormap import cmap as _cmap +from silx.utils.exceptions import NotEditableError + +_logger = logging.getLogger(__file__) + + +_COLORDICT = {} +"""Dictionary of common colors.""" + +_COLORDICT['b'] = _COLORDICT['blue'] = '#0000ff' +_COLORDICT['r'] = _COLORDICT['red'] = '#ff0000' +_COLORDICT['g'] = _COLORDICT['green'] = '#00ff00' +_COLORDICT['k'] = _COLORDICT['black'] = '#000000' +_COLORDICT['w'] = _COLORDICT['white'] = '#ffffff' +_COLORDICT['pink'] = '#ff66ff' +_COLORDICT['brown'] = '#a52a2a' +_COLORDICT['orange'] = '#ff9900' +_COLORDICT['violet'] = '#6600ff' +_COLORDICT['gray'] = _COLORDICT['grey'] = '#a0a0a4' +# _COLORDICT['darkGray'] = _COLORDICT['darkGrey'] = '#808080' +# _COLORDICT['lightGray'] = _COLORDICT['lightGrey'] = '#c0c0c0' +_COLORDICT['y'] = _COLORDICT['yellow'] = '#ffff00' +_COLORDICT['m'] = _COLORDICT['magenta'] = '#ff00ff' +_COLORDICT['c'] = _COLORDICT['cyan'] = '#00ffff' +_COLORDICT['darkBlue'] = '#000080' +_COLORDICT['darkRed'] = '#800000' +_COLORDICT['darkGreen'] = '#008000' +_COLORDICT['darkBrown'] = '#660000' +_COLORDICT['darkCyan'] = '#008080' +_COLORDICT['darkYellow'] = '#808000' +_COLORDICT['darkMagenta'] = '#800080' + + +# FIXME: It could be nice to expose a functional API instead of that attribute +COLORDICT = _COLORDICT + + +def rgba(color, colorDict=None): + """Convert color code '#RRGGBB' and '#RRGGBBAA' to (R, G, B, A) + + It also convert RGB(A) values from uint8 to float in [0, 1] and + accept a QColor as color argument. + + :param str color: The color to convert + :param dict colorDict: A dictionary of color name conversion to color code + :returns: RGBA colors as floats in [0., 1.] + :rtype: tuple + """ + if colorDict is None: + colorDict = _COLORDICT + + if hasattr(color, 'getRgbF'): # QColor support + color = color.getRgbF() + + values = numpy.asarray(color).ravel() + + if values.dtype.kind in 'iuf': # integer or float + # Color is an array + assert len(values) in (3, 4) + + # Convert from integers in [0, 255] to float in [0, 1] + if values.dtype.kind in 'iu': + values = values / 255. + + # Clip to [0, 1] + values[values < 0.] = 0. + values[values > 1.] = 1. + + if len(values) == 3: + return values[0], values[1], values[2], 1. + else: + return tuple(values) + + # We assume color is a string + if not color.startswith('#'): + color = colorDict[color] + + assert len(color) in (7, 9) and color[0] == '#' + r = int(color[1:3], 16) / 255. + g = int(color[3:5], 16) / 255. + b = int(color[5:7], 16) / 255. + a = int(color[7:9], 16) / 255. if len(color) == 9 else 1. + return r, g, b, a + + +_COLORMAP_CURSOR_COLORS = { + 'gray': 'pink', + 'reversed gray': 'pink', + 'temperature': 'pink', + 'red': 'green', + 'green': 'pink', + 'blue': 'yellow', + 'jet': 'pink', + 'viridis': 'pink', + 'magma': 'green', + 'inferno': 'green', + 'plasma': 'green', +} + + +def cursorColorForColormap(colormapName): + """Get a color suitable for overlay over a colormap. + + :param str colormapName: The name of the colormap. + :return: Name of the color. + :rtype: str + """ + return _COLORMAP_CURSOR_COLORS.get(colormapName, 'black') + + +DEFAULT_COLORMAPS = ( + 'gray', 'reversed gray', 'temperature', 'red', 'green', 'blue') +"""Tuple of supported colormap names.""" + +DEFAULT_MIN_LIN = 0 +"""Default min value if in linear normalization""" +DEFAULT_MAX_LIN = 1 +"""Default max value if in linear normalization""" +DEFAULT_MIN_LOG = 1 +"""Default min value if in log normalization""" +DEFAULT_MAX_LOG = 10 +"""Default max value if in log normalization""" + + +class Colormap(qt.QObject): + """Description of a colormap + + :param str name: Name of the colormap + :param tuple colors: optional, custom colormap. + Nx3 or Nx4 numpy array of RGB(A) colors, + either uint8 or float in [0, 1]. + If 'name' is None, then this array is used as the colormap. + :param str normalization: Normalization: 'linear' (default) or 'log' + :param float vmin: + Lower bound of the colormap or None for autoscale (default) + :param float vmax: + Upper bounds of the colormap or None for autoscale (default) + """ + + LINEAR = 'linear' + """constant for linear normalization""" + + LOGARITHM = 'log' + """constant for logarithmic normalization""" + + NORMALIZATIONS = (LINEAR, LOGARITHM) + """Tuple of managed normalizations""" + + sigChanged = qt.Signal() + """Signal emitted when the colormap has changed.""" + + def __init__(self, name='gray', colors=None, normalization=LINEAR, vmin=None, vmax=None): + qt.QObject.__init__(self) + assert normalization in Colormap.NORMALIZATIONS + assert not (name is None and colors is None) + if normalization is Colormap.LOGARITHM: + if (vmin is not None and vmin < 0) or (vmax is not None and vmax < 0): + m = "Unsuported vmin (%s) and/or vmax (%s) given for a log scale." + m += ' Autoscale will be performed.' + m = m % (vmin, vmax) + _logger.warning(m) + vmin = None + vmax = None + + self._name = str(name) if name is not None else None + self._setColors(colors) + self._normalization = str(normalization) + self._vmin = float(vmin) if vmin is not None else None + self._vmax = float(vmax) if vmax is not None else None + self._editable = True + + def isAutoscale(self): + """Return True if both min and max are in autoscale mode""" + return self._vmin is None and self._vmax is None + + def getName(self): + """Return the name of the colormap + :rtype: str + """ + return self._name + + @staticmethod + def _convertColorsFromFloatToUint8(colors): + """Convert colors from float in [0, 1] to uint8 + + :param numpy.ndarray colors: Array of float colors to convert + :return: colors as uint8 + :rtype: numpy.ndarray + """ + # Each bin is [N, N+1[ except the last one: [255, 256] + return numpy.clip( + colors.astype(numpy.float64) * 256, 0., 255.).astype(numpy.uint8) + + def _setColors(self, colors): + if colors is None: + self._colors = None + else: + colors = numpy.array(colors, copy=False) + colors.shape = -1, colors.shape[-1] + if colors.dtype.kind == 'f': + colors = self._convertColorsFromFloatToUint8(colors) + + # Makes sure it is RGBA8888 + self._colors = numpy.zeros((len(colors), 4), dtype=numpy.uint8) + self._colors[:, 3] = 255 # Alpha channel + self._colors[:, :colors.shape[1]] = colors # Copy colors + + def getNColors(self, nbColors=None): + """Returns N colors computed by sampling the colormap regularly. + + :param nbColors: + The number of colors in the returned array or None for the default value. + The default value is 256 for colormap with a name (see :meth:`setName`) and + it is the size of the LUT for colormap defined with :meth:`setColormapLUT`. + :type nbColors: int or None + :return: 2D array of uint8 of shape (nbColors, 4) + :rtype: numpy.ndarray + """ + # Handle default value for nbColors + if nbColors is None: + lut = self.getColormapLUT() + if lut is not None: # In this case uses LUT length + nbColors = len(lut) + else: # Default to 256 + nbColors = 256 + + nbColors = int(nbColors) + + colormap = self.copy() + colormap.setNormalization(Colormap.LINEAR) + colormap.setVRange(vmin=None, vmax=None) + colors = colormap.applyToData( + numpy.arange(nbColors, dtype=numpy.int)) + return colors + + def setName(self, name): + """Set the name of the colormap to use. + + :param str name: The name of the colormap. + At least the following names are supported: 'gray', + 'reversed gray', 'temperature', 'red', 'green', 'blue', 'jet', + 'viridis', 'magma', 'inferno', 'plasma'. + """ + if self.isEditable() is False: + raise NotEditableError('Colormap is not editable') + assert name in self.getSupportedColormaps() + self._name = str(name) + self._colors = None + self.sigChanged.emit() + + def getColormapLUT(self): + """Return the list of colors for the colormap or None if not set + + :return: the list of colors for the colormap or None if not set + :rtype: numpy.ndarray or None + """ + if self._colors is None: + return None + else: + return numpy.array(self._colors, copy=True) + + def setColormapLUT(self, colors): + """Set the colors of the colormap. + + :param numpy.ndarray colors: the colors of the LUT. + If float, it is converted from [0, 1] to uint8 range. + Otherwise it is casted to uint8. + + .. warning: this will set the value of name to None + """ + if self.isEditable() is False: + raise NotEditableError('Colormap is not editable') + self._setColors(colors) + if len(colors) is 0: + self._colors = None + + self._name = None + self.sigChanged.emit() + + def getNormalization(self): + """Return the normalization of the colormap ('log' or 'linear') + + :return: the normalization of the colormap + :rtype: str + """ + return self._normalization + + def setNormalization(self, norm): + """Set the norm ('log', 'linear') + + :param str norm: the norm to set + """ + if self.isEditable() is False: + raise NotEditableError('Colormap is not editable') + self._normalization = str(norm) + self.sigChanged.emit() + + def getVMin(self): + """Return the lower bound of the colormap + + :return: the lower bound of the colormap + :rtype: float or None + """ + return self._vmin + + def setVMin(self, vmin): + """Set the minimal value of the colormap + + :param float vmin: Lower bound of the colormap or None for autoscale + (default) + value) + """ + if self.isEditable() is False: + raise NotEditableError('Colormap is not editable') + if vmin is not None: + if self._vmax is not None and vmin > self._vmax: + err = "Can't set vmin because vmin >= vmax. " \ + "vmin = %s, vmax = %s" % (vmin, self._vmax) + raise ValueError(err) + + self._vmin = vmin + self.sigChanged.emit() + + def getVMax(self): + """Return the upper bounds of the colormap or None + + :return: the upper bounds of the colormap or None + :rtype: float or None + """ + return self._vmax + + def setVMax(self, vmax): + """Set the maximal value of the colormap + + :param float vmax: Upper bounds of the colormap or None for autoscale + (default) + """ + if self.isEditable() is False: + raise NotEditableError('Colormap is not editable') + if vmax is not None: + if self._vmin is not None and vmax < self._vmin: + err = "Can't set vmax because vmax <= vmin. " \ + "vmin = %s, vmax = %s" % (self._vmin, vmax) + raise ValueError(err) + + self._vmax = vmax + self.sigChanged.emit() + + def isEditable(self): + """ Return if the colormap is editable or not + + :return: editable state of the colormap + :rtype: bool + """ + return self._editable + + def setEditable(self, editable): + """ + Set the editable state of the colormap + + :param bool editable: is the colormap editable + """ + assert type(editable) is bool + self._editable = editable + self.sigChanged.emit() + + def getColormapRange(self, data=None): + """Return (vmin, vmax) + + :return: the tuple vmin, vmax fitting vmin, vmax, normalization and + data if any given + :rtype: tuple + """ + vmin = self._vmin + vmax = self._vmax + assert vmin is None or vmax is None or vmin <= vmax # TODO handle this in setters + + if self.getNormalization() == self.LOGARITHM: + # Handle negative bounds as autoscale + if vmin is not None and (vmin is not None and vmin <= 0.): + mess = 'negative vmin, moving to autoscale for lower bound' + _logger.warning(mess) + vmin = None + if vmax is not None and (vmax is not None and vmax <= 0.): + mess = 'negative vmax, moving to autoscale for upper bound' + _logger.warning(mess) + vmax = None + + if vmin is None or vmax is None: # Handle autoscale + # Get min/max from data + if data is not None: + data = numpy.array(data, copy=False) + if data.size == 0: # Fallback an array but no data + min_, max_ = self._getDefaultMin(), self._getDefaultMax() + else: + if self.getNormalization() == self.LOGARITHM: + result = min_max(data, min_positive=True, finite=True) + min_ = result.min_positive # >0 or None + max_ = result.maximum # can be <= 0 + else: + min_, max_ = min_max(data, min_positive=False, finite=True) + + # Handle fallback + if min_ is None or not numpy.isfinite(min_): + min_ = self._getDefaultMin() + if max_ is None or not numpy.isfinite(max_): + max_ = self._getDefaultMax() + else: # Fallback if no data is provided + min_, max_ = self._getDefaultMin(), self._getDefaultMax() + + if vmin is None: # Set vmin respecting provided vmax + vmin = min_ if vmax is None else min(min_, vmax) + + if vmax is None: + vmax = max(max_, vmin) # Handle max_ <= 0 for log scale + + return vmin, vmax + + def setVRange(self, vmin, vmax): + """Set the bounds of the colormap + + :param vmin: Lower bound of the colormap or None for autoscale + (default) + :param vmax: Upper bounds of the colormap or None for autoscale + (default) + """ + if self.isEditable() is False: + raise NotEditableError('Colormap is not editable') + if vmin is not None and vmax is not None: + if vmin > vmax: + err = "Can't set vmin and vmax because vmin >= vmax " \ + "vmin = %s, vmax = %s" % (vmin, vmax) + raise ValueError(err) + + if self._vmin == vmin and self._vmax == vmax: + return + + self._vmin = vmin + self._vmax = vmax + self.sigChanged.emit() + + def __getitem__(self, item): + if item == 'autoscale': + return self.isAutoscale() + elif item == 'name': + return self.getName() + elif item == 'normalization': + return self.getNormalization() + elif item == 'vmin': + return self.getVMin() + elif item == 'vmax': + return self.getVMax() + elif item == 'colors': + return self.getColormapLUT() + else: + raise KeyError(item) + + def _toDict(self): + """Return the equivalent colormap as a dictionary + (old colormap representation) + + :return: the representation of the Colormap as a dictionary + :rtype: dict + """ + return { + 'name': self._name, + 'colors': copy_mdl.copy(self._colors), + 'vmin': self._vmin, + 'vmax': self._vmax, + 'autoscale': self.isAutoscale(), + 'normalization': self._normalization + } + + def _setFromDict(self, dic): + """Set values to the colormap from a dictionary + + :param dict dic: the colormap as a dictionary + """ + if self.isEditable() is False: + raise NotEditableError('Colormap is not editable') + name = dic['name'] if 'name' in dic else None + colors = dic['colors'] if 'colors' in dic else None + vmin = dic['vmin'] if 'vmin' in dic else None + vmax = dic['vmax'] if 'vmax' in dic else None + if 'normalization' in dic: + normalization = dic['normalization'] + else: + warn = 'Normalization not given in the dictionary, ' + warn += 'set by default to ' + Colormap.LINEAR + _logger.warning(warn) + normalization = Colormap.LINEAR + + if name is None and colors is None: + err = 'The colormap should have a name defined or a tuple of colors' + raise ValueError(err) + if normalization not in Colormap.NORMALIZATIONS: + err = 'Given normalization is not recoginized (%s)' % normalization + raise ValueError(err) + + # If autoscale, then set boundaries to None + if dic.get('autoscale', False): + vmin, vmax = None, None + + self._name = name + self._colors = colors + self._vmin = vmin + self._vmax = vmax + self._autoscale = True if (vmin is None and vmax is None) else False + self._normalization = normalization + + self.sigChanged.emit() + + @staticmethod + def _fromDict(dic): + colormap = Colormap(name="") + colormap._setFromDict(dic) + return colormap + + def copy(self): + """Return a copy of the Colormap. + + :rtype: silx.gui.colors.Colormap + """ + return Colormap(name=self._name, + colors=copy_mdl.copy(self._colors), + vmin=self._vmin, + vmax=self._vmax, + normalization=self._normalization) + + def applyToData(self, data): + """Apply the colormap to the data + + :param numpy.ndarray data: The data to convert. + """ + name = self.getName() + if name is not None: # Get colormap definition from matplotlib + # FIXME: If possible remove dependency to the plot + from .plot.matplotlib import Colormap as MPLColormap + mplColormap = MPLColormap.getColormap(name) + colors = mplColormap(numpy.linspace(0, 1, 256, endpoint=True)) + colors = self._convertColorsFromFloatToUint8(colors) + + else: # Use user defined LUT + colors = self.getColormapLUT() + + vmin, vmax = self.getColormapRange(data) + normalization = self.getNormalization() + + return _cmap(data, colors, vmin, vmax, normalization) + + @staticmethod + def getSupportedColormaps(): + """Get the supported colormap names as a tuple of str. + + The list should at least contain and start by: + ('gray', 'reversed gray', 'temperature', 'red', 'green', 'blue') + :rtype: tuple + """ + # FIXME: If possible remove dependency to the plot + from .plot.matplotlib import Colormap as MPLColormap + maps = MPLColormap.getSupportedColormaps() + return DEFAULT_COLORMAPS + maps + + def __str__(self): + return str(self._toDict()) + + def _getDefaultMin(self): + return DEFAULT_MIN_LIN if self._normalization == Colormap.LINEAR else DEFAULT_MIN_LOG + + def _getDefaultMax(self): + return DEFAULT_MAX_LIN if self._normalization == Colormap.LINEAR else DEFAULT_MAX_LOG + + def __eq__(self, other): + """Compare colormap values and not pointers""" + return (self.getName() == other.getName() and + self.getNormalization() == other.getNormalization() and + self.getVMin() == other.getVMin() and + self.getVMax() == other.getVMax() and + numpy.array_equal(self.getColormapLUT(), other.getColormapLUT()) + ) + + _SERIAL_VERSION = 1 + + def restoreState(self, byteArray): + """ + Read the colormap state from a QByteArray. + + :param qt.QByteArray byteArray: Stream containing the state + :return: True if the restoration sussseed + :rtype: bool + """ + if self.isEditable() is False: + raise NotEditableError('Colormap is not editable') + stream = qt.QDataStream(byteArray, qt.QIODevice.ReadOnly) + + className = stream.readQString() + if className != self.__class__.__name__: + _logger.warning("Classname mismatch. Found %s." % className) + return False + + version = stream.readUInt32() + if version != self._SERIAL_VERSION: + _logger.warning("Serial version mismatch. Found %d." % version) + return False + + name = stream.readQString() + isNull = stream.readBool() + if not isNull: + vmin = stream.readQVariant() + else: + vmin = None + isNull = stream.readBool() + if not isNull: + vmax = stream.readQVariant() + else: + vmax = None + normalization = stream.readQString() + + # emit change event only once + old = self.blockSignals(True) + try: + self.setName(name) + self.setNormalization(normalization) + self.setVRange(vmin, vmax) + finally: + self.blockSignals(old) + self.sigChanged.emit() + return True + + def saveState(self): + """ + Save state of the colomap into a QDataStream. + + :rtype: qt.QByteArray + """ + data = qt.QByteArray() + stream = qt.QDataStream(data, qt.QIODevice.WriteOnly) + + stream.writeQString(self.__class__.__name__) + stream.writeUInt32(self._SERIAL_VERSION) + stream.writeQString(self.getName()) + stream.writeBool(self.getVMin() is None) + if self.getVMin() is not None: + stream.writeQVariant(self.getVMin()) + stream.writeBool(self.getVMax() is None) + if self.getVMax() is not None: + stream.writeQVariant(self.getVMax()) + stream.writeQString(self.getNormalization()) + return data + + +_PREFERRED_COLORMAPS = None +""" +Tuple of preferred colormap names accessed with :meth:`preferredColormaps`. +""" + + +def preferredColormaps(): + """Returns the name of the preferred colormaps. + + This list is used by widgets allowing to change the colormap + like the :class:`ColormapDialog` as a subset of colormap choices. + + :rtype: tuple of str + """ + global _PREFERRED_COLORMAPS + if _PREFERRED_COLORMAPS is None: + _PREFERRED_COLORMAPS = DEFAULT_COLORMAPS + # Initialize preferred colormaps + setPreferredColormaps(('gray', 'reversed gray', + 'temperature', 'red', 'green', 'blue', 'jet', + 'viridis', 'magma', 'inferno', 'plasma', + 'hsv')) + return _PREFERRED_COLORMAPS + + +def setPreferredColormaps(colormaps): + """Set the list of preferred colormap names. + + Warning: If a colormap name is not available + it will be removed from the list. + + :param colormaps: Not empty list of colormap names + :type colormaps: iterable of str + :raise ValueError: if the list of available preferred colormaps is empty. + """ + supportedColormaps = Colormap.getSupportedColormaps() + colormaps = tuple( + cmap for cmap in colormaps if cmap in supportedColormaps) + if len(colormaps) == 0: + raise ValueError("Cannot set preferred colormaps to an empty list") + + global _PREFERRED_COLORMAPS + _PREFERRED_COLORMAPS = colormaps diff --git a/silx/gui/console.py b/silx/gui/console.py index 3c69419..b6341ef 100644 --- a/silx/gui/console.py +++ b/silx/gui/console.py @@ -1,7 +1,7 @@ # coding: utf-8 # /*########################################################################## # -# Copyright (c) 2004-2017 European Synchrotron Radiation Facility +# Copyright (c) 2004-2018 European Synchrotron Radiation Facility # # Permission is hereby granted, free of charge, to any person obtaining a copy # of this software and associated documentation files (the "Software"), to deal @@ -34,9 +34,8 @@ the widgets' methods from the console. .. note:: This module has a dependency on - `IPython <https://pypi.python.org/pypi/ipython>`_ and - `qtconsole <https://pypi.python.org/pypi/qtconsole>`_ (or *ipython.qt* for - older versions of *IPython*). An ``ImportError`` will be raised if it is + `qtconsole <https://pypi.org/project/qtconsole/>`_. + An ``ImportError`` will be raised if it is imported while the dependencies are not satisfied. Basic usage example:: @@ -76,11 +75,7 @@ from . import qt _logger = logging.getLogger(__name__) -try: - import IPython -except ImportError as e: - raise ImportError("Failed to import IPython, required by " + __name__) - + # This widget cannot be used inside an interactive IPython shell. # It would raise MultipleInstanceError("Multiple incompatible subclass # instances of InProcessInteractiveShell are being created"). @@ -92,48 +87,14 @@ else: msg = "Module " + __name__ + " cannot be used within an IPython shell" raise ImportError(msg) -# qtconsole is a separate module in recent versions of IPython/Jupyter -# http://blog.jupyter.org/2015/04/15/the-big-split/ -if IPython.__version__.startswith("2"): - qtconsole = None -else: - try: - import qtconsole - except ImportError: - qtconsole = None - -if qtconsole is not None: - try: - from qtconsole.rich_ipython_widget import RichJupyterWidget as \ - RichIPythonWidget - except ImportError: - try: - from qtconsole.rich_ipython_widget import RichIPythonWidget - except ImportError as e: - qtconsole = None - else: - from qtconsole.inprocess import QtInProcessKernelManager - else: - from qtconsole.inprocess import QtInProcessKernelManager - - -if qtconsole is None: - # Import the console machinery from ipython - - # The `has_binding` test of IPython does not find the Qt bindings - # in case silx is used in a frozen binary - import IPython.external.qt_loaders - - def has_binding(*var, **kw): - return True - - IPython.external.qt_loaders.has_binding = has_binding - - try: - from IPython.qtconsole.rich_ipython_widget import RichIPythonWidget - except ImportError: - from IPython.qt.console.rich_ipython_widget import RichIPythonWidget - from IPython.qt.inprocess import QtInProcessKernelManager + +try: + from qtconsole.rich_ipython_widget import RichJupyterWidget as \ + RichIPythonWidget +except ImportError: + from qtconsole.rich_ipython_widget import RichIPythonWidget + +from qtconsole.inprocess import QtInProcessKernelManager class IPythonWidget(RichIPythonWidget): diff --git a/silx/gui/data/DataViewer.py b/silx/gui/data/DataViewer.py index 5e0b25e..4db2863 100644 --- a/silx/gui/data/DataViewer.py +++ b/silx/gui/data/DataViewer.py @@ -37,7 +37,7 @@ from silx.utils.property import classproperty __authors__ = ["V. Valls"] __license__ = "MIT" -__date__ = "26/02/2018" +__date__ = "24/04/2018" _logger = logging.getLogger(__name__) @@ -167,8 +167,10 @@ class DataViewer(qt.QFrame): self.__currentAvailableViews = [] self.__currentView = None self.__data = None + self.__info = None self.__useAxisSelection = False self.__userSelectedView = None + self.__hooks = None self.__views = [] self.__index = {} @@ -182,6 +184,15 @@ class DataViewer(qt.QFrame): self.__views = list(views) self.setDisplayMode(DataViews.EMPTY_MODE) + def setGlobalHooks(self, hooks): + """Set a data view hooks for all the views + + :param DataViewHooks context: The hooks to use + """ + self.__hooks = hooks + for v in self.__views: + v.setHooks(hooks) + def createDefaultViews(self, parent=None): """Create and returns available views which can be displayed by default by the data viewer. It is called internally by the widget. It can be @@ -250,7 +261,7 @@ class DataViewer(qt.QFrame): """ previous = self.__numpySelection.blockSignals(True) self.__numpySelection.clear() - info = DataViews.DataInfo(self.__data) + info = self._getInfo() axisNames = self.__currentView.axesNames(self.__data, info) if info.isArray and info.size != 0 and self.__data is not None and axisNames is not None: self.__useAxisSelection = True @@ -359,6 +370,8 @@ class DataViewer(qt.QFrame): :param DataView view: A dataview """ + if self.__hooks is not None: + view.setHooks(self.__hooks) self.__views.append(view) # TODO It can be skipped if the view do not support the data self.__updateAvailableViews() @@ -390,8 +403,8 @@ class DataViewer(qt.QFrame): Update available views from the current data. """ data = self.__data + info = self._getInfo() # sort available views according to priority - info = DataViews.DataInfo(data) priorities = [v.getDataPriority(data, info) for v in self.__views] views = zip(priorities, self.__views) views = filter(lambda t: t[0] > DataViews.DataView.UNSUPPORTED, views) @@ -490,6 +503,7 @@ class DataViewer(qt.QFrame): :param numpy.ndarray data: The data. """ self.__data = data + self._invalidateInfo() self.__displayedData = None self.__updateView() self.__updateNumpySelectionAxis() @@ -512,6 +526,21 @@ class DataViewer(qt.QFrame): """Returns the data""" return self.__data + def _invalidateInfo(self): + """Invalidate DataInfo cache.""" + self.__info = None + + def _getInfo(self): + """Returns the DataInfo of the current selected data. + + This value is cached. + + :rtype: DataInfo + """ + if self.__info is None: + self.__info = DataViews.DataInfo(self.__data) + return self.__info + def displayMode(self): """Returns the current display mode""" return self.__currentView.modeId() @@ -552,6 +581,8 @@ class DataViewer(qt.QFrame): isReplaced = False for idx, view in enumerate(self.__views): if view.modeId() == modeId: + if self.__hooks is not None: + newView.setHooks(self.__hooks) self.__views[idx] = newView isReplaced = True break diff --git a/silx/gui/data/DataViewerFrame.py b/silx/gui/data/DataViewerFrame.py index 89a9992..4e6d2e8 100644 --- a/silx/gui/data/DataViewerFrame.py +++ b/silx/gui/data/DataViewerFrame.py @@ -27,7 +27,7 @@ __authors__ = ["V. Valls"] __license__ = "MIT" -__date__ = "21/09/2017" +__date__ = "24/04/2018" from silx.gui import qt from .DataViewer import DataViewer @@ -113,6 +113,13 @@ class DataViewerFrame(qt.QWidget): """Called when the displayed view changes""" self.displayedViewChanged.emit(view) + def setGlobalHooks(self, hooks): + """Set a data view hooks for all the views + + :param DataViewHooks context: The hooks to use + """ + self.__dataViewer.setGlobalHooks(hooks) + def availableViews(self): """Returns the list of registered views diff --git a/silx/gui/data/DataViews.py b/silx/gui/data/DataViews.py index ef69441..2291e87 100644 --- a/silx/gui/data/DataViews.py +++ b/silx/gui/data/DataViews.py @@ -35,13 +35,13 @@ from silx.gui import qt, icons from silx.gui.data.TextFormatter import TextFormatter from silx.io import nxdata from silx.gui.hdf5 import H5Node -from silx.io.nxdata import get_attr_as_string -from silx.gui.plot.Colormap import Colormap -from silx.gui.plot.actions.control import ColormapAction +from silx.io.nxdata import get_attr_as_unicode +from silx.gui.colors import Colormap +from silx.gui.dialog.ColormapDialog import ColormapDialog __authors__ = ["V. Valls", "P. Knobel"] __license__ = "MIT" -__date__ = "23/01/2018" +__date__ = "23/05/2018" _logger = logging.getLogger(__name__) @@ -109,6 +109,7 @@ class DataInfo(object): self.isBoolean = False self.isRecord = False self.hasNXdata = False + self.isInvalidNXdata = False self.shape = tuple() self.dim = 0 self.size = 0 @@ -118,8 +119,28 @@ class DataInfo(object): if silx.io.is_group(data): nxd = nxdata.get_default(data) + nx_class = get_attr_as_unicode(data, "NX_class") if nxd is not None: self.hasNXdata = True + # can we plot it? + is_scalar = nxd.signal_is_0d or nxd.interpretation in ["scalar", "scaler"] + if not (is_scalar or nxd.is_curve or nxd.is_x_y_value_scatter or + nxd.is_image or nxd.is_stack): + # invalid: cannot be plotted by any widget + self.isInvalidNXdata = True + elif nx_class == "NXdata": + # group claiming to be NXdata could not be parsed + self.isInvalidNXdata = True + elif nx_class == "NXentry" and "default" in data.attrs: + # entry claiming to have a default NXdata could not be parsed + self.isInvalidNXdata = True + elif nx_class == "NXroot" or silx.io.is_file(data): + # root claiming to have a default entry + if "default" in data.attrs: + def_entry = data.attrs["default"] + if def_entry in data and "default" in data[def_entry].attrs: + # and entry claims to have default NXdata + self.isInvalidNXdata = True if isinstance(data, numpy.ndarray): self.isArray = True @@ -130,7 +151,7 @@ class DataInfo(object): if silx.io.is_dataset(data): if "interpretation" in data.attrs: - self.interpretation = get_attr_as_string(data, "interpretation") + self.interpretation = get_attr_as_unicode(data, "interpretation") else: self.interpretation = None elif self.hasNXdata: @@ -166,7 +187,11 @@ class DataInfo(object): if self.shape is not None: self.dim = len(self.shape) - if hasattr(data, "size"): + if hasattr(data, "shape") and data.shape is None: + # This test is expected to avoid to fall done on the h5py issue + # https://github.com/h5py/h5py/issues/1044 + self.size = 0 + elif hasattr(data, "size"): self.size = int(data.size) else: self.size = 1 @@ -177,6 +202,18 @@ class DataInfo(object): return _normalizeData(data) +class DataViewHooks(object): + """A set of hooks defined to custom the behaviour of the data views.""" + + def getColormap(self, view): + """Returns a colormap for this view.""" + return None + + def getColormapDialog(self, view): + """Returns a color dialog for this view.""" + return None + + class DataView(object): """Holder for the data view.""" @@ -184,12 +221,6 @@ class DataView(object): """Priority returned when the requested data can't be displayed by the view.""" - _defaultColormap = None - """Store a default colormap shared with all the views""" - - _defaultColorDialog = None - """Store a default color dialog shared with all the views""" - def __init__(self, parent, modeId=None, icon=None, label=None): """Constructor @@ -204,32 +235,46 @@ class DataView(object): if icon is None: icon = qt.QIcon() self.__icon = icon + self.__hooks = None - @staticmethod - def defaultColormap(): - """Returns a shared colormap as default for all the views. + def getHooks(self): + """Returns the data viewer hooks used by this view. - :rtype: Colormap + :rtype: DataViewHooks """ - if DataView._defaultColormap is None: - DataView._defaultColormap = Colormap(name="viridis") - return DataView._defaultColormap + return self.__hooks - @staticmethod - def defaultColorDialog(): - """Returns a shared color dialog as default for all the views. + def setHooks(self, hooks): + """Set the data view hooks to use with this view. - :rtype: ColorDialog + :param DataViewHooks hooks: The data view hooks to use """ - if DataView._defaultColorDialog is None: - DataView._defaultColorDialog = ColormapAction._createDialog(qt.QApplication.instance().activeWindow()) - return DataView._defaultColorDialog + self.__hooks = hooks - @staticmethod - def _cleanUpCache(): - """Clean up the cache. Needed for tests""" - DataView._defaultColormap = None - DataView._defaultColorDialog = None + def defaultColormap(self): + """Returns a default colormap. + + :rtype: Colormap + """ + colormap = None + if self.__hooks is not None: + colormap = self.__hooks.getColormap(self) + if colormap is None: + colormap = Colormap(name="viridis") + return colormap + + def defaultColorDialog(self): + """Returns a default color dialog. + + :rtype: ColormapDialog + """ + dialog = None + if self.__hooks is not None: + dialog = self.__hooks.getColormapDialog(self) + if dialog is None: + dialog = ColormapDialog() + dialog.setModal(False) + return dialog def icon(self): """Returns the default icon""" @@ -345,8 +390,21 @@ class CompositeDataView(DataView): self.__views = OrderedDict() self.__currentView = None + def setHooks(self, hooks): + """Set the data context to use with this view. + + :param DataViewHooks hooks: The data view hooks to use + """ + super(CompositeDataView, self).setHooks(hooks) + if hooks is not None: + for v in self.__views: + v.setHooks(hooks) + def addView(self, dataView): """Add a new dataview to the available list.""" + hooks = self.getHooks() + if hooks is not None: + dataView.setHooks(hooks) self.__views[dataView] = None def availableViews(self): @@ -446,6 +504,9 @@ class CompositeDataView(DataView): break elif isinstance(view, CompositeDataView): # recurse + hooks = self.getHooks() + if hooks is not None: + newView.setHooks(hooks) if view.replaceView(modeId, newView): return True if oldView is None: @@ -1022,70 +1083,46 @@ class _InvalidNXdataView(DataView): def getDataPriority(self, data, info): data = self.normalizeData(data) - if silx.io.is_group(data): - nxd = nxdata.get_default(data) - nx_class = get_attr_as_string(data, "NX_class") - - if nxd is None: - if nx_class == "NXdata": - # invalid: could not even be parsed by NXdata - self._msg = "Group has @NX_class = NXdata, but could not be interpreted" - self._msg += " as valid NXdata." - return 100 - elif nx_class == "NXentry": - if "default" not in data.attrs: - # no link to NXdata, no problem - return DataView.UNSUPPORTED - self._msg = "NXentry group provides a @default attribute," - default_nxdata_name = data.attrs["default"] - if default_nxdata_name not in data: - self._msg += " but no corresponding NXdata group exists." - elif get_attr_as_string(data[default_nxdata_name], "NX_class") != "NXdata": - self._msg += " but the corresponding item is not a " - self._msg += "NXdata group." - else: - self._msg += " but the corresponding NXdata seems to be" - self._msg += " malformed." - return 100 - elif nx_class == "NXroot" or silx.io.is_file(data): - if "default" not in data.attrs: - # no link to NXentry, no problem - return DataView.UNSUPPORTED - default_entry_name = data.attrs["default"] - if default_entry_name not in data: - # this is a problem, but not NXdata related - return DataView.UNSUPPORTED - default_entry = data[default_entry_name] - if "default" not in default_entry.attrs: - # no NXdata specified, no problemo - return DataView.UNSUPPORTED - default_nxdata_name = default_entry.attrs["default"] - self._msg = "NXroot group provides a @default attribute " - self._msg += "pointing to a NXentry which defines its own " - self._msg += "@default attribute, " - if default_nxdata_name not in default_entry: - self._msg += " but no corresponding NXdata group exists." - elif get_attr_as_string(default_entry[default_nxdata_name], - "NX_class") != "NXdata": - self._msg += " but the corresponding item is not a " - self._msg += "NXdata group." - else: - self._msg += " but the corresponding NXdata seems to be" - self._msg += " malformed." - return 100 - else: - # Not pretending to be NXdata, no problem - return DataView.UNSUPPORTED - - is_scalar = nxd.signal_is_0d or nxd.interpretation in ["scalar", "scaler"] - if not (is_scalar or nxd.is_curve or nxd.is_x_y_value_scatter or - nxd.is_image or nxd.is_stack): - # invalid: cannot be plotted by any widget (I cannot imagine a case) - self._msg = "NXdata seems valid, but cannot be displayed " - self._msg += "by any existing plot widget." - return 100 - return DataView.UNSUPPORTED + if not info.isInvalidNXdata: + return DataView.UNSUPPORTED + + if info.hasNXdata: + self._msg = "NXdata seems valid, but cannot be displayed " + self._msg += "by any existing plot widget." + else: + nx_class = get_attr_as_unicode(data, "NX_class") + if nx_class == "NXdata": + # invalid: could not even be parsed by NXdata + self._msg = "Group has @NX_class = NXdata, but could not be interpreted" + self._msg += " as valid NXdata." + elif nx_class == "NXentry": + self._msg = "NXentry group provides a @default attribute," + default_nxdata_name = data.attrs["default"] + if default_nxdata_name not in data: + self._msg += " but no corresponding NXdata group exists." + elif get_attr_as_unicode(data[default_nxdata_name], "NX_class") != "NXdata": + self._msg += " but the corresponding item is not a " + self._msg += "NXdata group." + else: + self._msg += " but the corresponding NXdata seems to be" + self._msg += " malformed." + elif nx_class == "NXroot" or silx.io.is_file(data): + default_entry = data[data.attrs["default"]] + default_nxdata_name = default_entry.attrs["default"] + self._msg = "NXroot group provides a @default attribute " + self._msg += "pointing to a NXentry which defines its own " + self._msg += "@default attribute, " + if default_nxdata_name not in default_entry: + self._msg += " but no corresponding NXdata group exists." + elif get_attr_as_unicode(default_entry[default_nxdata_name], + "NX_class") != "NXdata": + self._msg += " but the corresponding item is not a " + self._msg += "NXdata group." + else: + self._msg += " but the corresponding NXdata seems to be" + self._msg += " malformed." + return 100 class _NXdataScalarView(DataView): @@ -1111,7 +1148,7 @@ class _NXdataScalarView(DataView): def setData(self, data): data = self.normalizeData(data) # data could be a NXdata or an NXentry - nxd = nxdata.get_default(data) + nxd = nxdata.get_default(data, validate=False) signal = nxd.signal self.getWidget().setArrayData(signal, labels=True) @@ -1119,8 +1156,8 @@ class _NXdataScalarView(DataView): def getDataPriority(self, data, info): data = self.normalizeData(data) - if info.hasNXdata: - nxd = nxdata.get_default(data) + if info.hasNXdata and not info.isInvalidNXdata: + nxd = nxdata.get_default(data, validate=False) if nxd.signal_is_0d or nxd.interpretation in ["scalar", "scaler"]: return 100 return DataView.UNSUPPORTED @@ -1151,7 +1188,7 @@ class _NXdataCurveView(DataView): def setData(self, data): data = self.normalizeData(data) - nxd = nxdata.get_default(data) + nxd = nxdata.get_default(data, validate=False) signals_names = [nxd.signal_name] + nxd.auxiliary_signals_names if nxd.axes_dataset_names[-1] is not None: x_errors = nxd.get_axis_errors(nxd.axes_dataset_names[-1]) @@ -1177,8 +1214,8 @@ class _NXdataCurveView(DataView): def getDataPriority(self, data, info): data = self.normalizeData(data) - if info.hasNXdata: - if nxdata.get_default(data).is_curve: + if info.hasNXdata and not info.isInvalidNXdata: + if nxdata.get_default(data, validate=False).is_curve: return 100 return DataView.UNSUPPORTED @@ -1204,8 +1241,13 @@ class _NXdataXYVScatterView(DataView): def setData(self, data): data = self.normalizeData(data) - nxd = nxdata.get_default(data) + nxd = nxdata.get_default(data, validate=False) + x_axis, y_axis = nxd.axes[-2:] + if x_axis is None: + x_axis = numpy.arange(nxd.signal.size) + if y_axis is None: + y_axis = numpy.arange(nxd.signal.size) x_label, y_label = nxd.axes_names[-2:] if x_label is not None: @@ -1226,8 +1268,8 @@ class _NXdataXYVScatterView(DataView): def getDataPriority(self, data, info): data = self.normalizeData(data) - if info.hasNXdata: - if nxdata.get_default(data).is_x_y_value_scatter: + if info.hasNXdata and not info.isInvalidNXdata: + if nxdata.get_default(data, validate=False).is_x_y_value_scatter: return 100 return DataView.UNSUPPORTED @@ -1256,7 +1298,7 @@ class _NXdataImageView(DataView): def setData(self, data): data = self.normalizeData(data) - nxd = nxdata.get_default(data) + nxd = nxdata.get_default(data, validate=False) isRgba = nxd.interpretation == "rgba-image" # last two axes are Y & X @@ -1274,8 +1316,8 @@ class _NXdataImageView(DataView): def getDataPriority(self, data, info): data = self.normalizeData(data) - if info.hasNXdata: - if nxdata.get_default(data).is_image: + if info.hasNXdata and not info.isInvalidNXdata: + if nxdata.get_default(data, validate=False).is_image: return 100 return DataView.UNSUPPORTED @@ -1302,7 +1344,7 @@ class _NXdataStackView(DataView): def setData(self, data): data = self.normalizeData(data) - nxd = nxdata.get_default(data) + nxd = nxdata.get_default(data, validate=False) signal_name = nxd.signal_name z_axis, y_axis, x_axis = nxd.axes[-3:] z_label, y_label, x_label = nxd.axes_names[-3:] @@ -1319,8 +1361,8 @@ class _NXdataStackView(DataView): def getDataPriority(self, data, info): data = self.normalizeData(data) - if info.hasNXdata: - if nxdata.get_default(data).is_stack: + if info.hasNXdata and not info.isInvalidNXdata: + if nxdata.get_default(data, validate=False).is_stack: return 100 return DataView.UNSUPPORTED diff --git a/silx/gui/data/Hdf5TableView.py b/silx/gui/data/Hdf5TableView.py index e4a0747..04199b2 100644 --- a/silx/gui/data/Hdf5TableView.py +++ b/silx/gui/data/Hdf5TableView.py @@ -30,8 +30,9 @@ from __future__ import division __authors__ = ["V. Valls"] __license__ = "MIT" -__date__ = "10/10/2017" +__date__ = "23/05/2018" +import collections import functools import os.path import logging @@ -41,6 +42,7 @@ from .TextFormatter import TextFormatter import silx.gui.hdf5 from silx.gui.widgets import HierarchicalTableView from ..hdf5.Hdf5Formatter import Hdf5Formatter +from ..hdf5._utils import htmlFromDict try: import h5py @@ -54,7 +56,7 @@ _logger = logging.getLogger(__name__) class _CellData(object): """Store a table item """ - def __init__(self, value=None, isHeader=False, span=None): + def __init__(self, value=None, isHeader=False, span=None, tooltip=None): """ Constructor @@ -65,6 +67,7 @@ class _CellData(object): self.__value = value self.__isHeader = isHeader self.__span = span + self.__tooltip = tooltip def isHeader(self): """Returns true if the property is a sub-header title. @@ -85,6 +88,19 @@ class _CellData(object): """ return self.__span + def tooltip(self): + """Returns the tooltip of the item. + + :rtype: tuple + """ + return self.__tooltip + + def invalidateValue(self): + self.__value = None + + def invalidateToolTip(self): + self.__tooltip = None + class _TableData(object): """Modelize a table with header, row and column span. @@ -143,7 +159,7 @@ class _TableData(object): item = _CellData(value=headerLabel, isHeader=True, span=(1, self.__colCount)) self.__data.append([item]) - def addHeaderValueRow(self, headerLabel, value): + def addHeaderValueRow(self, headerLabel, value, tooltip=None): """Append the table with a row using the first column as an header and other cells as a single cell for the value. @@ -151,7 +167,7 @@ class _TableData(object): :param object value: value to store. """ header = _CellData(value=headerLabel, isHeader=True) - value = _CellData(value=value, span=(1, self.__colCount)) + value = _CellData(value=value, span=(1, self.__colCount), tooltip=tooltip) self.__data.append([header, value]) def addRow(self, *args): @@ -214,7 +230,20 @@ class Hdf5TableModel(HierarchicalTableView.HierarchicalTableModel): elif role == qt.Qt.DisplayRole: value = cell.value() if callable(value): - value = value(self.__obj) + try: + value = value(self.__obj) + except Exception: + cell.invalidateValue() + raise + return value + elif role == qt.Qt.ToolTipRole: + value = cell.tooltip() + if callable(value): + try: + value = value(self.__obj) + except Exception: + cell.invalidateToolTip() + raise return value return None @@ -260,6 +289,14 @@ class Hdf5TableModel(HierarchicalTableView.HierarchicalTableModel): """Format the HDF5 type""" return self.__hdf5Formatter.humanReadableHdf5Type(dataset) + def __attributeTooltip(self, attribute): + attributeDict = collections.OrderedDict() + if hasattr(attribute, "shape"): + attributeDict["Shape"] = self.__hdf5Formatter.humanReadableShape(attribute) + attributeDict["Data type"] = self.__hdf5Formatter.humanReadableType(attribute, full=True) + html = htmlFromDict(attributeDict, title="HDF5 Attribute") + return html + def __formatDType(self, dataset): """Format the numpy dtype""" return self.__hdf5Formatter.humanReadableType(dataset, full=True) @@ -310,7 +347,8 @@ class Hdf5TableModel(HierarchicalTableView.HierarchicalTableModel): # it's a real H5py object self.__data.addHeaderValueRow("Basename", lambda x: os.path.basename(x.name)) self.__data.addHeaderValueRow("Name", lambda x: x.name) - self.__data.addHeaderValueRow("File", lambda x: x.file.filename) + if obj.file is not None: + self.__data.addHeaderValueRow("File", lambda x: x.file.filename) if hasattr(obj, "path"): # That's a link @@ -322,8 +360,11 @@ class Hdf5TableModel(HierarchicalTableView.HierarchicalTableModel): else: if silx.io.is_file(obj): physical = lambda x: x.filename + SEPARATOR + x.name + elif obj.file is not None: + physical = lambda x: x.file.filename + SEPARATOR + x.name else: - physical = lambda x: x.file.filename + SEPARATOR + x.name + # Guess it is a virtual node + physical = "No physical location" self.__data.addHeaderValueRow("Physical", physical) if hasattr(obj, "dtype"): @@ -367,7 +408,10 @@ class Hdf5TableModel(HierarchicalTableView.HierarchicalTableModel): self.__data.addHeaderRow(headerLabel="Attributes") for key in sorted(obj.attrs.keys()): callback = lambda key, x: self.__formatter.toString(x.attrs[key]) - self.__data.addHeaderValueRow(headerLabel=key, value=functools.partial(callback, key)) + callbackTooltip = lambda key, x: self.__attributeTooltip(x.attrs[key]) + self.__data.addHeaderValueRow(headerLabel=key, + value=functools.partial(callback, key), + tooltip=functools.partial(callbackTooltip, key)) def __get_filter_info(self, dataset, filterIndex): """Get a tuple of readable info from dataset filters @@ -447,7 +491,7 @@ class Hdf5TableView(HierarchicalTableView.HierarchicalTableView): def setData(self, data): """Set the h5py-like object exposed by the model - :param h5pyObject: A h5py-like object. It can be a `h5py.Dataset`, + :param data: A h5py-like object. It can be a `h5py.Dataset`, a `h5py.File`, a `h5py.Group`. It also can be a, `silx.gui.hdf5.H5Node` which is needed to display some local path information. diff --git a/silx/gui/data/HexaTableView.py b/silx/gui/data/HexaTableView.py index 1b2a7e9..c86c0af 100644 --- a/silx/gui/data/HexaTableView.py +++ b/silx/gui/data/HexaTableView.py @@ -37,7 +37,7 @@ from silx.gui.widgets.TableWidget import CopySelectedCellsAction __authors__ = ["V. Valls"] __license__ = "MIT" -__date__ = "27/09/2017" +__date__ = "23/05/2018" class _VoidConnector(object): @@ -54,7 +54,13 @@ class _VoidConnector(object): def __getBuffer(self, bufferId): if bufferId not in self.__cache: pos = bufferId << 10 - data = self.__data.tobytes()[pos:pos + 1024] + data = self.__data + if hasattr(data, "tobytes"): + data = data.tobytes()[pos:pos + 1024] + else: + # Old fashion + data = data.data[pos:pos + 1024] + self.__cache[bufferId] = data if len(self.__cache) > 32: self.__cache.popitem() diff --git a/silx/gui/data/NXdataWidgets.py b/silx/gui/data/NXdataWidgets.py index ae2911d..1bf5425 100644 --- a/silx/gui/data/NXdataWidgets.py +++ b/silx/gui/data/NXdataWidgets.py @@ -26,14 +26,14 @@ """ __authors__ = ["P. Knobel"] __license__ = "MIT" -__date__ = "20/12/2017" +__date__ = "24/04/2018" import numpy from silx.gui import qt from silx.gui.data.NumpyAxesSelector import NumpyAxesSelector -from silx.gui.plot import Plot1D, Plot2D, StackView -from silx.gui.plot.Colormap import Colormap +from silx.gui.plot import Plot1D, Plot2D, StackView, ScatterView +from silx.gui.colors import Colormap from silx.gui.widgets.FrameBrowser import HorizontalSliderWithBrowser from silx.math.calibration import ArrayCalibration, NoCalibration, LinearCalibration @@ -211,10 +211,10 @@ class XYVScatterPlot(qt.QWidget): self.__y_axis_name = None self.__y_axis_errors = None - self._plot = Plot1D(self) - self._plot.setDefaultColormap(Colormap(name="viridis", - vmin=None, vmax=None, - normalization=Colormap.LINEAR)) + self._plot = ScatterView(self) + self._plot.setColormap(Colormap(name="viridis", + vmin=None, vmax=None, + normalization=Colormap.LINEAR)) self._slider = HorizontalSliderWithBrowser(parent=self) self._slider.setMinimum(0) @@ -235,9 +235,9 @@ class XYVScatterPlot(qt.QWidget): def getPlot(self): """Returns the plot used for the display - :rtype: Plot1D + :rtype: PlotWidget """ - return self._plot + return self._plot.getPlotWidget() def setScattersData(self, y, x, values, yerror=None, xerror=None, @@ -284,8 +284,6 @@ class XYVScatterPlot(qt.QWidget): x = self.__x_axis y = self.__y_axis - self._plot.remove(kind=("scatter", )) - idx = self._slider.value() title = "" @@ -294,16 +292,15 @@ class XYVScatterPlot(qt.QWidget): title += self.__scatter_titles[idx] # scatter dataset name self._plot.setGraphTitle(title) - self._plot.addScatter(x, y, self.__values[idx], - legend="scatter%d" % idx, - xerror=self.__x_axis_errors, - yerror=self.__y_axis_errors) + self._plot.setData(x, y, self.__values[idx], + xerror=self.__x_axis_errors, + yerror=self.__y_axis_errors) self._plot.resetZoom() self._plot.getXAxis().setLabel(self.__x_axis_name) self._plot.getYAxis().setLabel(self.__y_axis_name) def clear(self): - self._plot.clear() + self._plot.getPlotWidget().clear() class ArrayImagePlot(qt.QWidget): @@ -476,7 +473,8 @@ class ArrayImagePlot(qt.QWidget): scale = (xscale, yscale) self._plot.addImage(image, legend=legend, - origin=origin, scale=scale) + origin=origin, scale=scale, + replace=True) else: scatterx, scattery = numpy.meshgrid(x_axis, y_axis) # fixme: i don't think this can handle "irregular" RGBA images diff --git a/silx/gui/data/TextFormatter.py b/silx/gui/data/TextFormatter.py index 332625c..8440509 100644 --- a/silx/gui/data/TextFormatter.py +++ b/silx/gui/data/TextFormatter.py @@ -27,7 +27,7 @@ data module to format data as text in the same way.""" __authors__ = ["V. Valls"] __license__ = "MIT" -__date__ = "13/12/2017" +__date__ = "25/06/2018" import numpy import numbers @@ -204,7 +204,7 @@ class TextFormatter(qt.QObject): def __formatBinary(self, data): if isinstance(data, numpy.void): if six.PY2: - data = [ord(d) for d in data.item()] + data = [ord(d) for d in data.data] else: data = data.item().astype(numpy.uint8) elif six.PY2: @@ -266,6 +266,8 @@ class TextFormatter(qt.QObject): elif vlen == six.binary_type: # HDF5 ASCII return self.__formatCharString(data) + elif isinstance(vlen, numpy.dtype): + return self.toString(data, vlen) return None def toString(self, data, dtype=None): @@ -291,11 +293,17 @@ class TextFormatter(qt.QObject): else: text = [self.toString(d, dtype) for d in data] return "[" + " ".join(text) + "]" + if dtype is not None and dtype.kind == 'O': + text = self.__formatH5pyObject(data, dtype) + if text is not None: + return text elif isinstance(data, numpy.void): if dtype is None: dtype = data.dtype - if data.dtype.fields is not None: - text = [self.toString(data[f], dtype) for f in dtype.fields] + if dtype.fields is not None: + text = [] + for index, field in enumerate(dtype.fields.items()): + text.append(field[0] + ":" + self.toString(data[index], field[1][0])) return "(" + " ".join(text) + ")" return self.__formatBinary(data) elif isinstance(data, (numpy.unicode_, six.text_type)): @@ -340,7 +348,7 @@ class TextFormatter(qt.QObject): elif isinstance(data, (numbers.Real, numpy.floating)): # It have to be done before complex checking return self.__floatFormat % data - elif isinstance(data, (numpy.complex_, numbers.Complex)): + elif isinstance(data, (numpy.complexfloating, numbers.Complex)): text = "" if data.real != 0: text += self.__floatFormat % data.real diff --git a/silx/gui/data/test/test_dataviewer.py b/silx/gui/data/test/test_dataviewer.py index 274df92..f3c2808 100644 --- a/silx/gui/data/test/test_dataviewer.py +++ b/silx/gui/data/test/test_dataviewer.py @@ -24,7 +24,7 @@ # ###########################################################################*/ __authors__ = ["V. Valls"] __license__ = "MIT" -__date__ = "22/02/2018" +__date__ = "23/04/2018" import os import tempfile @@ -208,7 +208,6 @@ class AbstractDataViewerTests(TestCaseQt): self.assertEquals(widget.displayedView().modeId(), DataViews.RAW_MODE) widget.setDisplayMode(DataViews.EMPTY_MODE) self.assertEquals(widget.displayedView().modeId(), DataViews.EMPTY_MODE) - DataView._cleanUpCache() def test_create_default_views(self): widget = self.create_widget() @@ -287,7 +286,6 @@ class TestDataView(TestCaseQt): dataViewClass = DataViews._Plot2dView widget = self.createDataViewWithData(dataViewClass, data[0]) self.qWaitForWindowExposed(widget) - DataView._cleanUpCache() def testCubeWithComplex(self): self.skipTest("OpenGL widget not yet tested") @@ -299,14 +297,12 @@ class TestDataView(TestCaseQt): dataViewClass = DataViews._Plot3dView widget = self.createDataViewWithData(dataViewClass, data) self.qWaitForWindowExposed(widget) - DataView._cleanUpCache() def testImageStackWithComplex(self): data = self.createComplexData() dataViewClass = DataViews._StackView widget = self.createDataViewWithData(dataViewClass, data) self.qWaitForWindowExposed(widget) - DataView._cleanUpCache() def suite(): diff --git a/silx/gui/dialog/AbstractDataFileDialog.py b/silx/gui/dialog/AbstractDataFileDialog.py index 1bd52bb..cb6711c 100644 --- a/silx/gui/dialog/AbstractDataFileDialog.py +++ b/silx/gui/dialog/AbstractDataFileDialog.py @@ -28,7 +28,7 @@ This module contains an :class:`AbstractDataFileDialog`. __authors__ = ["V. Valls"] __license__ = "MIT" -__date__ = "12/02/2018" +__date__ = "05/03/2018" import sys @@ -494,7 +494,9 @@ class _CatchResizeEvent(qt.QObject): class AbstractDataFileDialog(qt.QDialog): """The `AbstractFileDialog` provides a generic GUI to create a custom dialog - allowing to access to file resources like HDF5 files or HDF5 datasets + allowing to access to file resources like HDF5 files or HDF5 datasets. + + .. image:: img/abstractdatafiledialog.png The dialog contains: diff --git a/silx/gui/dialog/ColormapDialog.py b/silx/gui/dialog/ColormapDialog.py new file mode 100644 index 0000000..ed10728 --- /dev/null +++ b/silx/gui/dialog/ColormapDialog.py @@ -0,0 +1,986 @@ +# coding: utf-8 +# /*########################################################################## +# +# Copyright (c) 2004-2018 European Synchrotron Radiation Facility +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. +# +# ###########################################################################*/ +"""A QDialog widget to set-up the colormap. + +It uses a description of colormaps as dict compatible with :class:`Plot`. + +To run the following sample code, a QApplication must be initialized. + +Create the colormap dialog and set the colormap description and data range: + +>>> from silx.gui.dialog.ColormapDialog import ColormapDialog +>>> from silx.gui.colors import Colormap + +>>> dialog = ColormapDialog() +>>> colormap = Colormap(name='red', normalization='log', +... vmin=1., vmax=2.) + +>>> dialog.setColormap(colormap) +>>> colormap.setVRange(1., 100.) # This scale the width of the plot area +>>> dialog.show() + +Get the colormap description (compatible with :class:`Plot`) from the dialog: + +>>> cmap = dialog.getColormap() +>>> cmap.getName() +'red' + +It is also possible to display an histogram of the image in the dialog. +This updates the data range with the range of the bins. + +>>> import numpy +>>> image = numpy.random.normal(size=512 * 512).reshape(512, -1) +>>> hist, bin_edges = numpy.histogram(image, bins=10) +>>> dialog.setHistogram(hist, bin_edges) + +The updates of the colormap description are also available through the signal: +:attr:`ColormapDialog.sigColormapChanged`. +""" # noqa + +from __future__ import division + +__authors__ = ["V.A. Sole", "T. Vincent", "H. Payno"] +__license__ = "MIT" +__date__ = "23/05/2018" + + +import logging + +import numpy + +from .. import qt +from ..colors import Colormap, preferredColormaps +from ..plot import PlotWidget +from silx.gui.widgets.FloatEdit import FloatEdit +import weakref +from silx.math.combo import min_max +from silx.third_party import enum +from silx.gui import icons +from silx.math.histogram import Histogramnd + +_logger = logging.getLogger(__name__) + + +_colormapIconPreview = {} + + +class _BoundaryWidget(qt.QWidget): + """Widget to edit a boundary of the colormap (vmin, vmax)""" + sigValueChanged = qt.Signal(object) + """Signal emitted when value is changed""" + + def __init__(self, parent=None, value=0.0): + qt.QWidget.__init__(self, parent=None) + self.setLayout(qt.QHBoxLayout()) + self.layout().setContentsMargins(0, 0, 0, 0) + self._numVal = FloatEdit(parent=self, value=value) + self.layout().addWidget(self._numVal) + self._autoCB = qt.QCheckBox('auto', parent=self) + self.layout().addWidget(self._autoCB) + self._autoCB.setChecked(False) + + self._autoCB.toggled.connect(self._autoToggled) + self.sigValueChanged = self._autoCB.toggled + self.textEdited = self._numVal.textEdited + self.editingFinished = self._numVal.editingFinished + self._dataValue = None + + def isAutoChecked(self): + return self._autoCB.isChecked() + + def getValue(self): + return None if self._autoCB.isChecked() else self._numVal.value() + + def getFiniteValue(self): + if not self._autoCB.isChecked(): + return self._numVal.value() + elif self._dataValue is None: + return self._numVal.value() + else: + return self._dataValue + + def _autoToggled(self, enabled): + self._numVal.setEnabled(not enabled) + self._updateDisplayedText() + + def _updateDisplayedText(self): + # if dataValue is finite + if self._autoCB.isChecked() and self._dataValue is not None: + old = self._numVal.blockSignals(True) + self._numVal.setValue(self._dataValue) + self._numVal.blockSignals(old) + + def setDataValue(self, dataValue): + self._dataValue = dataValue + self._updateDisplayedText() + + def setFiniteValue(self, value): + assert(value is not None) + old = self._numVal.blockSignals(True) + self._numVal.setValue(value) + self._numVal.blockSignals(old) + + def setValue(self, value, isAuto=False): + self._autoCB.setChecked(isAuto or value is None) + if value is not None: + self._numVal.setValue(value) + self._updateDisplayedText() + + +class _ColormapNameCombox(qt.QComboBox): + def __init__(self, parent=None): + qt.QComboBox.__init__(self, parent) + self.__initItems() + + ORIGINAL_NAME = qt.Qt.UserRole + 1 + + def __initItems(self): + for colormapName in preferredColormaps(): + index = self.count() + self.addItem(str.title(colormapName)) + self.setItemIcon(index, self.getIconPreview(colormapName)) + self.setItemData(index, colormapName, role=self.ORIGINAL_NAME) + + def getIconPreview(self, colormapName): + """Return an icon preview from a LUT name. + + This icons are cached into a global structure. + + :param str colormapName: str + :rtype: qt.QIcon + """ + if colormapName not in _colormapIconPreview: + icon = self.createIconPreview(colormapName) + _colormapIconPreview[colormapName] = icon + return _colormapIconPreview[colormapName] + + def createIconPreview(self, colormapName): + """Create and return an icon preview from a LUT name. + + This icons are cached into a global structure. + + :param str colormapName: Name of the LUT + :rtype: qt.QIcon + """ + colormap = Colormap(colormapName) + size = 32 + lut = colormap.getNColors(size) + if lut is None or len(lut) == 0: + return qt.QIcon() + + pixmap = qt.QPixmap(size, size) + painter = qt.QPainter(pixmap) + for i in range(size): + rgb = lut[i] + r, g, b = rgb[0], rgb[1], rgb[2] + painter.setPen(qt.QColor(r, g, b)) + painter.drawPoint(qt.QPoint(i, 0)) + + painter.drawPixmap(0, 1, size, size - 1, pixmap, 0, 0, size, 1) + painter.end() + + return qt.QIcon(pixmap) + + def getCurrentName(self): + return self.itemData(self.currentIndex(), self.ORIGINAL_NAME) + + def findColormap(self, name): + return self.findData(name, role=self.ORIGINAL_NAME) + + def setCurrentName(self, name): + index = self.findColormap(name) + if index < 0: + index = self.count() + self.addItem(str.title(name)) + self.setItemIcon(index, self.getIconPreview(name)) + self.setItemData(index, name, role=self.ORIGINAL_NAME) + self.setCurrentIndex(index) + + +@enum.unique +class _DataInPlotMode(enum.Enum): + """Enum for each mode of display of the data in the plot.""" + NONE = 'none' + RANGE = 'range' + HISTOGRAM = 'histogram' + + +class ColormapDialog(qt.QDialog): + """A QDialog widget to set the colormap. + + :param parent: See :class:`QDialog` + :param str title: The QDialog title + """ + + visibleChanged = qt.Signal(bool) + """This event is sent when the dialog visibility change""" + + def __init__(self, parent=None, title="Colormap Dialog"): + qt.QDialog.__init__(self, parent) + self.setWindowTitle(title) + + self._colormap = None + self._data = None + self._dataInPlotMode = _DataInPlotMode.RANGE + + self._ignoreColormapChange = False + """Used as a semaphore to avoid editing the colormap object when we are + only attempt to display it. + Used instead of n connect and disconnect of the sigChanged. The + disconnection to sigChanged was also limiting when this colormapdialog + is used in the colormapaction and associated to the activeImageChanged. + (because the activeImageChanged is send when the colormap changed and + the self.setcolormap is a callback) + """ + + self._histogramData = None + self._minMaxWasEdited = False + self._initialRange = None + + self._dataRange = None + """If defined 3-tuple containing information from a data: + minimum, positive minimum, maximum""" + + self._colormapStoredState = None + + # Make the GUI + vLayout = qt.QVBoxLayout(self) + + formWidget = qt.QWidget(parent=self) + vLayout.addWidget(formWidget) + formLayout = qt.QFormLayout(formWidget) + formLayout.setContentsMargins(10, 10, 10, 10) + formLayout.setSpacing(0) + + # Colormap row + self._comboBoxColormap = _ColormapNameCombox(parent=formWidget) + self._comboBoxColormap.currentIndexChanged[int].connect(self._updateName) + formLayout.addRow('Colormap:', self._comboBoxColormap) + + # Normalization row + self._normButtonLinear = qt.QRadioButton('Linear') + self._normButtonLinear.setChecked(True) + self._normButtonLog = qt.QRadioButton('Log') + self._normButtonLog.toggled.connect(self._activeLogNorm) + + normButtonGroup = qt.QButtonGroup(self) + normButtonGroup.setExclusive(True) + normButtonGroup.addButton(self._normButtonLinear) + normButtonGroup.addButton(self._normButtonLog) + self._normButtonLinear.toggled[bool].connect(self._updateLinearNorm) + + normLayout = qt.QHBoxLayout() + normLayout.setContentsMargins(0, 0, 0, 0) + normLayout.setSpacing(10) + normLayout.addWidget(self._normButtonLinear) + normLayout.addWidget(self._normButtonLog) + + formLayout.addRow('Normalization:', normLayout) + + # Min row + self._minValue = _BoundaryWidget(parent=self, value=1.0) + self._minValue.textEdited.connect(self._minMaxTextEdited) + self._minValue.editingFinished.connect(self._minEditingFinished) + self._minValue.sigValueChanged.connect(self._updateMinMax) + formLayout.addRow('\tMin:', self._minValue) + + # Max row + self._maxValue = _BoundaryWidget(parent=self, value=10.0) + self._maxValue.textEdited.connect(self._minMaxTextEdited) + self._maxValue.sigValueChanged.connect(self._updateMinMax) + self._maxValue.editingFinished.connect(self._maxEditingFinished) + formLayout.addRow('\tMax:', self._maxValue) + + # Add plot for histogram + self._plotToolbar = qt.QToolBar(self) + self._plotToolbar.setFloatable(False) + self._plotToolbar.setMovable(False) + self._plotToolbar.setIconSize(qt.QSize(8, 8)) + self._plotToolbar.setStyleSheet("QToolBar { border: 0px }") + self._plotToolbar.setOrientation(qt.Qt.Vertical) + + group = qt.QActionGroup(self._plotToolbar) + group.setExclusive(True) + + action = qt.QAction("Nothing", self) + action.setToolTip("No range nor histogram are displayed. No extra computation have to be done.") + action.setIcon(icons.getQIcon('colormap-none')) + action.setCheckable(True) + action.setData(_DataInPlotMode.NONE) + action.setChecked(action.data() == self._dataInPlotMode) + self._plotToolbar.addAction(action) + group.addAction(action) + action = qt.QAction("Data range", self) + action.setToolTip("Display the data range within the colormap range. A fast data processing have to be done.") + action.setIcon(icons.getQIcon('colormap-range')) + action.setCheckable(True) + action.setData(_DataInPlotMode.RANGE) + action.setChecked(action.data() == self._dataInPlotMode) + self._plotToolbar.addAction(action) + group.addAction(action) + action = qt.QAction("Histogram", self) + action.setToolTip("Display the data histogram within the colormap range. A slow data processing have to be done. ") + action.setIcon(icons.getQIcon('colormap-histogram')) + action.setCheckable(True) + action.setData(_DataInPlotMode.HISTOGRAM) + action.setChecked(action.data() == self._dataInPlotMode) + self._plotToolbar.addAction(action) + group.addAction(action) + group.triggered.connect(self._displayDataInPlotModeChanged) + + self._plotBox = qt.QWidget(self) + self._plotInit() + + plotBoxLayout = qt.QHBoxLayout() + plotBoxLayout.setContentsMargins(0, 0, 0, 0) + plotBoxLayout.setSpacing(2) + plotBoxLayout.addWidget(self._plotToolbar) + plotBoxLayout.addWidget(self._plot) + plotBoxLayout.setSizeConstraint(qt.QLayout.SetMinimumSize) + self._plotBox.setLayout(plotBoxLayout) + vLayout.addWidget(self._plotBox) + + # define modal buttons + types = qt.QDialogButtonBox.Ok | qt.QDialogButtonBox.Cancel + self._buttonsModal = qt.QDialogButtonBox(parent=self) + self._buttonsModal.setStandardButtons(types) + self.layout().addWidget(self._buttonsModal) + self._buttonsModal.accepted.connect(self.accept) + self._buttonsModal.rejected.connect(self.reject) + + # define non modal buttons + types = qt.QDialogButtonBox.Close | qt.QDialogButtonBox.Reset + self._buttonsNonModal = qt.QDialogButtonBox(parent=self) + self._buttonsNonModal.setStandardButtons(types) + self.layout().addWidget(self._buttonsNonModal) + self._buttonsNonModal.button(qt.QDialogButtonBox.Close).clicked.connect(self.accept) + self._buttonsNonModal.button(qt.QDialogButtonBox.Reset).clicked.connect(self.resetColormap) + + # Set the colormap to default values + self.setColormap(Colormap(name='gray', normalization='linear', + vmin=None, vmax=None)) + + self.setModal(self.isModal()) + + vLayout.setSizeConstraint(qt.QLayout.SetMinimumSize) + self.setFixedSize(self.sizeHint()) + self._applyColormap() + + def showEvent(self, event): + self.visibleChanged.emit(True) + super(ColormapDialog, self).showEvent(event) + + def closeEvent(self, event): + if not self.isModal(): + self.accept() + super(ColormapDialog, self).closeEvent(event) + + def hideEvent(self, event): + self.visibleChanged.emit(False) + super(ColormapDialog, self).hideEvent(event) + + def close(self): + self.accept() + qt.QDialog.close(self) + + def setModal(self, modal): + assert type(modal) is bool + self._buttonsNonModal.setVisible(not modal) + self._buttonsModal.setVisible(modal) + qt.QDialog.setModal(self, modal) + + def exec_(self): + wasModal = self.isModal() + self.setModal(True) + result = super(ColormapDialog, self).exec_() + self.setModal(wasModal) + return result + + def _plotInit(self): + """Init the plot to display the range and the values""" + self._plot = PlotWidget() + self._plot.setDataMargins(yMinMargin=0.125, yMaxMargin=0.125) + self._plot.getXAxis().setLabel("Data Values") + self._plot.getYAxis().setLabel("") + self._plot.setInteractiveMode('select', zoomOnWheel=False) + self._plot.setActiveCurveHandling(False) + self._plot.setMinimumSize(qt.QSize(250, 200)) + self._plot.sigPlotSignal.connect(self._plotSlot) + + self._plotUpdate() + + def sizeHint(self): + return self.layout().minimumSize() + + def _plotUpdate(self, updateMarkers=True): + """Update the plot content + + :param bool updateMarkers: True to update markers, False otherwith + """ + colormap = self.getColormap() + if colormap is None: + if self._plotBox.isVisibleTo(self): + self._plotBox.setVisible(False) + self.setFixedSize(self.sizeHint()) + return + + if not self._plotBox.isVisibleTo(self): + self._plotBox.setVisible(True) + self.setFixedSize(self.sizeHint()) + + minData, maxData = self._minValue.getFiniteValue(), self._maxValue.getFiniteValue() + if minData > maxData: + # avoid a full collapse + minData, maxData = maxData, minData + minimum = minData + maximum = maxData + + if self._dataRange is not None: + minRange = self._dataRange[0] + maxRange = self._dataRange[2] + minimum = min(minimum, minRange) + maximum = max(maximum, maxRange) + + if self._histogramData is not None: + minHisto = self._histogramData[1][0] + maxHisto = self._histogramData[1][-1] + minimum = min(minimum, minHisto) + maximum = max(maximum, maxHisto) + + marge = abs(maximum - minimum) / 6.0 + if marge < 0.0001: + # Smaller that the QLineEdit precision + marge = 0.0001 + + minView, maxView = minimum - marge, maximum + marge + + if updateMarkers: + # Save the state in we are not moving the markers + self._initialRange = minView, maxView + elif self._initialRange is not None: + minView = min(minView, self._initialRange[0]) + maxView = max(maxView, self._initialRange[1]) + + x = [minView, minData, maxData, maxView] + y = [0, 0, 1, 1] + + self._plot.addCurve(x, y, + legend="ConstrainedCurve", + color='black', + symbol='o', + linestyle='-', + resetzoom=False) + + if updateMarkers: + minDraggable = (self._colormap().isEditable() and + not self._minValue.isAutoChecked()) + self._plot.addXMarker( + self._minValue.getFiniteValue(), + legend='Min', + text='Min', + draggable=minDraggable, + color='blue', + constraint=self._plotMinMarkerConstraint) + + maxDraggable = (self._colormap().isEditable() and + not self._maxValue.isAutoChecked()) + self._plot.addXMarker( + self._maxValue.getFiniteValue(), + legend='Max', + text='Max', + draggable=maxDraggable, + color='blue', + constraint=self._plotMaxMarkerConstraint) + + self._plot.resetZoom() + + def _plotMinMarkerConstraint(self, x, y): + """Constraint of the min marker""" + return min(x, self._maxValue.getFiniteValue()), y + + def _plotMaxMarkerConstraint(self, x, y): + """Constraint of the max marker""" + return max(x, self._minValue.getFiniteValue()), y + + def _plotSlot(self, event): + """Handle events from the plot""" + if event['event'] in ('markerMoving', 'markerMoved'): + value = float(str(event['xdata'])) + if event['label'] == 'Min': + self._minValue.setValue(value) + elif event['label'] == 'Max': + self._maxValue.setValue(value) + + # This will recreate the markers while interacting... + # It might break if marker interaction is changed + if event['event'] == 'markerMoved': + self._initialRange = None + self._updateMinMax() + else: + self._plotUpdate(updateMarkers=False) + + @staticmethod + def computeDataRange(data): + """Compute the data range as used by :meth:`setDataRange`. + + :param data: The data to process + :rtype: Tuple(float, float, float) + """ + if data is None or len(data) == 0: + return None, None, None + + dataRange = min_max(data, min_positive=True, finite=True) + if dataRange.minimum is None: + # Only non-finite data + dataRange = None + + if dataRange is not None: + min_positive = dataRange.min_positive + if min_positive is None: + min_positive = float('nan') + dataRange = dataRange.minimum, min_positive, dataRange.maximum + + if dataRange is None or len(dataRange) != 3: + qt.QMessageBox.warning( + None, "No Data", + "Image data does not contain any real value") + dataRange = 1., 1., 10. + + return dataRange + + @staticmethod + def computeHistogram(data): + """Compute the data histogram as used by :meth:`setHistogram`. + + :param data: The data to process + :rtype: Tuple(List(float),List(float) + """ + _data = data + if _data.ndim == 3: # RGB(A) images + _logger.info('Converting current image from RGB(A) to grayscale\ + in order to compute the intensity distribution') + _data = (_data[:, :, 0] * 0.299 + + _data[:, :, 1] * 0.587 + + _data[:, :, 2] * 0.114) + + if len(_data) == 0: + return None, None + + xmin, xmax = min_max(_data, min_positive=False, finite=True) + nbins = min(256, int(numpy.sqrt(_data.size))) + data_range = xmin, xmax + + # bad hack: get 256 bins in the case we have a B&W + if numpy.issubdtype(_data.dtype, numpy.integer): + if nbins > xmax - xmin: + nbins = xmax - xmin + + nbins = max(2, nbins) + _data = _data.ravel().astype(numpy.float32) + + histogram = Histogramnd(_data, n_bins=nbins, histo_range=data_range) + return histogram.histo, histogram.edges[0] + + def _getData(self): + if self._data is None: + return None + return self._data() + + def setData(self, data): + """Store the data as a weakref. + + According to the state of the dialog, the data will be used to display + the data range or the histogram of the data using :meth:`setDataRange` + and :meth:`setHistogram` + """ + oldData = self._getData() + if oldData is data: + return + + if data is None: + self._data = None + else: + self._data = weakref.ref(data, self._dataAboutToFinalize) + + self._updateDataInPlot() + + def _setDataInPlotMode(self, mode): + if self._dataInPlotMode == mode: + return + self._dataInPlotMode = mode + self._updateDataInPlot() + + def _displayDataInPlotModeChanged(self, action): + mode = action.data() + self._setDataInPlotMode(mode) + + def _updateDataInPlot(self): + data = self._getData() + if data is None: + self.setDataRange() + self.setHistogram() + return + + if data.size == 0: + # One or more dimensions are equal to 0 + self.setHistogram() + self.setDataRange() + return + + mode = self._dataInPlotMode + + if mode == _DataInPlotMode.NONE: + self.setHistogram() + self.setDataRange() + elif mode == _DataInPlotMode.RANGE: + result = self.computeDataRange(data) + self.setHistogram() + self.setDataRange(*result) + elif mode == _DataInPlotMode.HISTOGRAM: + # The histogram should be done in a worker thread + result = self.computeHistogram(data) + self.setHistogram(*result) + self.setDataRange() + + def _colormapAboutToFinalize(self, weakrefColormap): + """Callback when the data weakref is about to be finalized.""" + if self._colormap is weakrefColormap: + self.setColormap(None) + + def _dataAboutToFinalize(self, weakrefData): + """Callback when the data weakref is about to be finalized.""" + if self._data is weakrefData: + self.setData(None) + + def getHistogram(self): + """Returns the counts and bin edges of the displayed histogram. + + :return: (hist, bin_edges) + :rtype: 2-tuple of numpy arrays""" + if self._histogramData is None: + return None + else: + bins, counts = self._histogramData + return numpy.array(bins, copy=True), numpy.array(counts, copy=True) + + def setHistogram(self, hist=None, bin_edges=None): + """Set the histogram to display. + + This update the data range with the bounds of the bins. + + :param hist: array-like of counts or None to hide histogram + :param bin_edges: array-like of bins edges or None to hide histogram + """ + if hist is None or bin_edges is None: + self._histogramData = None + self._plot.remove(legend='Histogram', kind='histogram') + else: + hist = numpy.array(hist, copy=True) + bin_edges = numpy.array(bin_edges, copy=True) + self._histogramData = hist, bin_edges + norm_hist = hist / max(hist) + self._plot.addHistogram(norm_hist, + bin_edges, + legend="Histogram", + color='gray', + align='center', + fill=True) + self._updateMinMaxData() + + def getColormap(self): + """Return the colormap description as a :class:`.Colormap`. + + """ + if self._colormap is None: + return None + return self._colormap() + + def resetColormap(self): + """ + Reset the colormap state before modification. + + ..note :: the colormap reference state is the state when set or the + state when validated + """ + colormap = self.getColormap() + if colormap is not None and self._colormapStoredState is not None: + if self._colormap()._toDict() != self._colormapStoredState: + self._ignoreColormapChange = True + colormap._setFromDict(self._colormapStoredState) + self._ignoreColormapChange = False + self._applyColormap() + + def setDataRange(self, minimum=None, positiveMin=None, maximum=None): + """Set the range of data to use for the range of the histogram area. + + :param float minimum: The minimum of the data + :param float positiveMin: The positive minimum of the data + :param float maximum: The maximum of the data + """ + if minimum is None or positiveMin is None or maximum is None: + self._dataRange = None + self._plot.remove(legend='Range', kind='histogram') + else: + hist = numpy.array([1]) + bin_edges = numpy.array([minimum, maximum]) + self._plot.addHistogram(hist, + bin_edges, + legend="Range", + color='gray', + align='center', + fill=True) + self._dataRange = minimum, positiveMin, maximum + self._updateMinMaxData() + + def _updateMinMaxData(self): + """Update the min and max of the data according to the data range and + the histogram preset.""" + colormap = self.getColormap() + + minimum = float("+inf") + maximum = float("-inf") + + if colormap is not None and colormap.getNormalization() == colormap.LOGARITHM: + # find a range in the positive part of the data + if self._dataRange is not None: + minimum = min(minimum, self._dataRange[1]) + maximum = max(maximum, self._dataRange[2]) + if self._histogramData is not None: + positives = list(filter(lambda x: x > 0, self._histogramData[1])) + if len(positives) > 0: + minimum = min(minimum, positives[0]) + maximum = max(maximum, positives[-1]) + else: + if self._dataRange is not None: + minimum = min(minimum, self._dataRange[0]) + maximum = max(maximum, self._dataRange[2]) + if self._histogramData is not None: + minimum = min(minimum, self._histogramData[1][0]) + maximum = max(maximum, self._histogramData[1][-1]) + + if not numpy.isfinite(minimum): + minimum = None + if not numpy.isfinite(maximum): + maximum = None + + self._minValue.setDataValue(minimum) + self._maxValue.setDataValue(maximum) + self._plotUpdate() + + def accept(self): + self.storeCurrentState() + qt.QDialog.accept(self) + + def storeCurrentState(self): + """ + save the current value sof the colormap if the user want to undo is + modifications + """ + colormap = self.getColormap() + if colormap is not None: + self._colormapStoredState = colormap._toDict() + else: + self._colormapStoredState = None + + def reject(self): + self.resetColormap() + qt.QDialog.reject(self) + + def setColormap(self, colormap): + """Set the colormap description + + :param :class:`Colormap` colormap: the colormap to edit + """ + assert colormap is None or isinstance(colormap, Colormap) + if self._ignoreColormapChange is True: + return + + oldColormap = self.getColormap() + if oldColormap is colormap: + return + if oldColormap is not None: + oldColormap.sigChanged.disconnect(self._applyColormap) + + if colormap is not None: + colormap.sigChanged.connect(self._applyColormap) + colormap = weakref.ref(colormap, self._colormapAboutToFinalize) + + self._colormap = colormap + self.storeCurrentState() + self._updateResetButton() + self._applyColormap() + + def _updateResetButton(self): + resetButton = self._buttonsNonModal.button(qt.QDialogButtonBox.Reset) + rStateEnabled = False + colormap = self.getColormap() + if colormap is not None and colormap.isEditable(): + # can reset only in the case the colormap changed + rStateEnabled = colormap._toDict() != self._colormapStoredState + resetButton.setEnabled(rStateEnabled) + + def _applyColormap(self): + self._updateResetButton() + if self._ignoreColormapChange is True: + return + + colormap = self.getColormap() + if colormap is None: + self._comboBoxColormap.setEnabled(False) + self._normButtonLinear.setEnabled(False) + self._normButtonLog.setEnabled(False) + self._minValue.setEnabled(False) + self._maxValue.setEnabled(False) + else: + self._ignoreColormapChange = True + + if colormap.getName() is not None: + name = colormap.getName() + self._comboBoxColormap.setCurrentName(name) + self._comboBoxColormap.setEnabled(self._colormap().isEditable()) + + assert colormap.getNormalization() in Colormap.NORMALIZATIONS + self._normButtonLinear.setChecked( + colormap.getNormalization() == Colormap.LINEAR) + self._normButtonLog.setChecked( + colormap.getNormalization() == Colormap.LOGARITHM) + vmin = colormap.getVMin() + vmax = colormap.getVMax() + dataRange = colormap.getColormapRange() + self._normButtonLinear.setEnabled(self._colormap().isEditable()) + self._normButtonLog.setEnabled(self._colormap().isEditable()) + self._minValue.setValue(vmin or dataRange[0], isAuto=vmin is None) + self._maxValue.setValue(vmax or dataRange[1], isAuto=vmax is None) + self._minValue.setEnabled(self._colormap().isEditable()) + self._maxValue.setEnabled(self._colormap().isEditable()) + self._ignoreColormapChange = False + + self._plotUpdate() + + def _updateMinMax(self): + if self._ignoreColormapChange is True: + return + + vmin = self._minValue.getFiniteValue() + vmax = self._maxValue.getFiniteValue() + if vmax is not None and vmin is not None and vmax < vmin: + # If only one autoscale is checked constraints are too strong + # We have to edit a user value anyway it is not requested + # TODO: It would be better IMO to disable the auto checkbox before + # this case occur (valls) + cmin = self._minValue.isAutoChecked() + cmax = self._maxValue.isAutoChecked() + if cmin is False: + self._minValue.setFiniteValue(vmax) + if cmax is False: + self._maxValue.setFiniteValue(vmin) + + vmin = self._minValue.getValue() + vmax = self._maxValue.getValue() + self._ignoreColormapChange = True + colormap = self._colormap() + if colormap is not None: + colormap.setVRange(vmin, vmax) + self._ignoreColormapChange = False + self._plotUpdate() + self._updateResetButton() + + def _updateName(self): + if self._ignoreColormapChange is True: + return + + if self._colormap(): + self._ignoreColormapChange = True + self._colormap().setName( + self._comboBoxColormap.getCurrentName()) + self._ignoreColormapChange = False + + def _updateLinearNorm(self, isNormLinear): + if self._ignoreColormapChange is True: + return + + if self._colormap(): + self._ignoreColormapChange = True + norm = Colormap.LINEAR if isNormLinear else Colormap.LOGARITHM + self._colormap().setNormalization(norm) + self._ignoreColormapChange = False + + def _minMaxTextEdited(self, text): + """Handle _minValue and _maxValue textEdited signal""" + self._minMaxWasEdited = True + + def _minEditingFinished(self): + """Handle _minValue editingFinished signal + + Together with :meth:`_minMaxTextEdited`, this avoids to notify + colormap change when the min and max value where not edited. + """ + if self._minMaxWasEdited: + self._minMaxWasEdited = False + + # Fix start value + if (self._maxValue.getValue() is not None and + self._minValue.getValue() > self._maxValue.getValue()): + self._minValue.setValue(self._maxValue.getValue()) + self._updateMinMax() + + def _maxEditingFinished(self): + """Handle _maxValue editingFinished signal + + Together with :meth:`_minMaxTextEdited`, this avoids to notify + colormap change when the min and max value where not edited. + """ + if self._minMaxWasEdited: + self._minMaxWasEdited = False + + # Fix end value + if (self._minValue.getValue() is not None and + self._minValue.getValue() > self._maxValue.getValue()): + self._maxValue.setValue(self._minValue.getValue()) + self._updateMinMax() + + def keyPressEvent(self, event): + """Override key handling. + + It disables leaving the dialog when editing a text field. + """ + if event.key() == qt.Qt.Key_Enter and (self._minValue.hasFocus() or + self._maxValue.hasFocus()): + # Bypass QDialog keyPressEvent + # To avoid leaving the dialog when pressing enter on a text field + super(qt.QDialog, self).keyPressEvent(event) + else: + # Use QDialog keyPressEvent + super(ColormapDialog, self).keyPressEvent(event) + + def _activeLogNorm(self, isLog): + if self._ignoreColormapChange is True: + return + if self._colormap(): + self._ignoreColormapChange = True + norm = Colormap.LOGARITHM if isLog is True else Colormap.LINEAR + self._colormap().setNormalization(norm) + self._ignoreColormapChange = False + self._updateMinMaxData() diff --git a/silx/gui/dialog/GroupDialog.py b/silx/gui/dialog/GroupDialog.py new file mode 100644 index 0000000..71235d2 --- /dev/null +++ b/silx/gui/dialog/GroupDialog.py @@ -0,0 +1,177 @@ +# coding: utf-8 +# /*########################################################################## +# +# Copyright (c) 2018 European Synchrotron Radiation Facility +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. +# +# ###########################################################################*/ +"""This module provides a dialog widget to select a HDF5 group in a +tree. + +.. autoclass:: GroupDialog + :show-inheritance: + :members: + + +""" +from silx.gui import qt +from silx.gui.hdf5.Hdf5TreeView import Hdf5TreeView +import silx.io +from silx.io.url import DataUrl + +__authors__ = ["P. Knobel"] +__license__ = "MIT" +__date__ = "22/03/2018" + + +class GroupDialog(qt.QDialog): + """This :class:`QDialog` uses a :class:`silx.gui.hdf5.Hdf5TreeView` to + provide a HDF5 group selection dialog. + + The information identifying the selected node is provided as a + :class:`silx.io.url.DataUrl`. + + Example: + + .. code-block:: python + + dialog = GroupDialog() + dialog.addFile(filepath1) + dialog.addFile(filepath2) + + if dialog.exec_(): + print("File path: %s" % dialog.getSelectedDataUrl().file_path()) + print("HDF5 group path : %s " % dialog.getSelectedDataUrl().data_path()) + else: + print("Operation cancelled :(") + + """ + def __init__(self, parent=None): + qt.QDialog.__init__(self, parent) + self.setWindowTitle("HDF5 group selection") + + self._tree = Hdf5TreeView(self) + self._tree.setSelectionMode(qt.QAbstractItemView.SingleSelection) + self._tree.activated.connect(self._onActivation) + self._tree.selectionModel().selectionChanged.connect( + self._onSelectionChange) + + self._model = self._tree.findHdf5TreeModel() + + self._header = self._tree.header() + self._header.setSections([self._model.NAME_COLUMN, + self._model.NODE_COLUMN, + self._model.LINK_COLUMN]) + + _labelSubgroup = qt.QLabel(self) + _labelSubgroup.setText("Subgroup name (optional)") + self._lineEditSubgroup = qt.QLineEdit(self) + self._lineEditSubgroup.setToolTip( + "Specify the name of a new subgroup " + "to be created in the selected group.") + self._lineEditSubgroup.textChanged.connect( + self._onSubgroupNameChange) + + _labelSelectionTitle = qt.QLabel(self) + _labelSelectionTitle.setText("Current selection") + self._labelSelection = qt.QLabel(self) + self._labelSelection.setStyleSheet("color: gray") + self._labelSelection.setWordWrap(True) + self._labelSelection.setText("Select a group") + + buttonBox = qt.QDialogButtonBox() + self._okButton = buttonBox.addButton(qt.QDialogButtonBox.Ok) + self._okButton.setEnabled(False) + buttonBox.addButton(qt.QDialogButtonBox.Cancel) + + buttonBox.accepted.connect(self.accept) + buttonBox.rejected.connect(self.reject) + + vlayout = qt.QVBoxLayout(self) + vlayout.addWidget(self._tree) + vlayout.addWidget(_labelSubgroup) + vlayout.addWidget(self._lineEditSubgroup) + vlayout.addWidget(_labelSelectionTitle) + vlayout.addWidget(self._labelSelection) + vlayout.addWidget(buttonBox) + self.setLayout(vlayout) + + self.setMinimumWidth(400) + + self._selectedUrl = None + + def addFile(self, path): + """Add a HDF5 file to the tree. + All groups it contains will be selectable in the dialog. + + :param str path: File path + """ + self._model.insertFile(path) + + def addGroup(self, group): + """Add a HDF5 group to the tree. This group and all its subgroups + will be selectable in the dialog. + + :param h5py.Group group: HDF5 group + """ + self._model.insertH5pyObject(group) + + def _onActivation(self, idx): + # double-click or enter press + nodes = list(self._tree.selectedH5Nodes()) + node = nodes[0] + if silx.io.is_group(node.h5py_object): + self.accept() + + def _onSelectionChange(self, old, new): + self._updateUrl() + + def _onSubgroupNameChange(self, text): + self._updateUrl() + + def _updateUrl(self): + nodes = list(self._tree.selectedH5Nodes()) + subgroupName = self._lineEditSubgroup.text() + if nodes: + node = nodes[0] + if silx.io.is_group(node.h5py_object): + data_path = node.local_name + if subgroupName.lstrip("/"): + if not data_path.endswith("/"): + data_path += "/" + data_path += subgroupName.lstrip("/") + self._selectedUrl = DataUrl(file_path=node.local_filename, + data_path=data_path) + self._okButton.setEnabled(True) + self._labelSelection.setText( + self._selectedUrl.path()) + else: + self._selectedUrl = None + self._okButton.setEnabled(False) + self._labelSelection.setText("Select a group") + + def getSelectedDataUrl(self): + """Return a :class:`DataUrl` with a file path and a data path. + Return None if the dialog was cancelled. + + :return: :class:`silx.io.url.DataUrl` object pointing to the + selected group. + """ + return self._selectedUrl diff --git a/silx/gui/dialog/test/__init__.py b/silx/gui/dialog/test/__init__.py index eee8aea..f43a37a 100644 --- a/silx/gui/dialog/test/__init__.py +++ b/silx/gui/dialog/test/__init__.py @@ -26,7 +26,7 @@ __authors__ = ["V. Valls"] __license__ = "MIT" -__date__ = "07/02/2018" +__date__ = "24/04/2018" import logging @@ -42,6 +42,8 @@ def suite(): test_suite = unittest.TestSuite() from . import test_imagefiledialog from . import test_datafiledialog + from . import test_colormapdialog test_suite.addTest(test_imagefiledialog.suite()) test_suite.addTest(test_datafiledialog.suite()) + test_suite.addTest(test_colormapdialog.suite()) return test_suite diff --git a/silx/gui/plot/test/testColormapDialog.py b/silx/gui/dialog/test/test_colormapdialog.py index 8087369..6f0ceea 100644 --- a/silx/gui/plot/test/testColormapDialog.py +++ b/silx/gui/dialog/test/test_colormapdialog.py @@ -26,7 +26,7 @@ __authors__ = ["T. Vincent"] __license__ = "MIT" -__date__ = "17/01/2018" +__date__ = "23/05/2018" import doctest @@ -34,9 +34,9 @@ import unittest from silx.gui.test.utils import qWaitForWindowExposedAndActivate from silx.gui import qt -from silx.gui.plot import ColormapDialog +from silx.gui.dialog import ColormapDialog from silx.gui.test.utils import TestCaseQt -from silx.gui.plot.Colormap import Colormap, preferredColormaps +from silx.gui.colors import Colormap, preferredColormaps from silx.utils.testutils import ParametricTestCase from silx.gui.plot.PlotWindow import PlotWindow @@ -119,7 +119,7 @@ class TestColormapDialog(TestCaseQt, ParametricTestCase): self.assertTrue(self.colormap.getVMin() is None) self.assertTrue(self.colormap.getVMax() is None) self.assertTrue(self.colormap.isAutoscale() is True) - + def testGUIModalCancel(self): """Make sure the colormap is not modified if gone through reject""" assert self.colormap.isAutoscale() is False @@ -308,6 +308,19 @@ class TestColormapDialog(TestCaseQt, ParametricTestCase): colormap.setEditable(False) self.assertFalse(resetButton.isEnabled()) + def testImageData(self): + data = numpy.random.rand(5, 5) + self.colormapDiag.setData(data) + + def testEmptyData(self): + data = numpy.empty((10, 0)) + self.colormapDiag.setData(data) + + def testNoneData(self): + data = numpy.random.rand(5, 5) + self.colormapDiag.setData(data) + self.colormapDiag.setData(None) + class TestColormapAction(TestCaseQt): def setUp(self): @@ -336,16 +349,16 @@ class TestColormapAction(TestCaseQt): self.assertTrue(self.colormapDialog.getColormap() is self.defaultColormap) self.plot.addImage(data=numpy.random.rand(10, 10), legend='img1', - replace=False, origin=(0, 0), + origin=(0, 0), colormap=self.colormap1) self.plot.setActiveImage('img1') self.assertTrue(self.colormapDialog.getColormap() is self.colormap1) self.plot.addImage(data=numpy.random.rand(10, 10), legend='img2', - replace=False, origin=(0, 0), + origin=(0, 0), colormap=self.colormap2) self.plot.addImage(data=numpy.random.rand(10, 10), legend='img3', - replace=False, origin=(0, 0)) + origin=(0, 0)) self.plot.setActiveImage('img3') self.assertTrue(self.colormapDialog.getColormap() is self.defaultColormap) @@ -363,7 +376,7 @@ class TestColormapAction(TestCaseQt): self.plot.getColormapAction()._actionTriggered(checked=True) self.assertTrue(self.plot.getColormapAction().isChecked()) self.plot.addImage(data=numpy.random.rand(10, 10), legend='img1', - replace=False, origin=(0, 0), + origin=(0, 0), colormap=self.colormap1) self.colormap1.setName('red') self.plot.getColormapAction()._actionTriggered() diff --git a/silx/gui/dialog/test/test_datafiledialog.py b/silx/gui/dialog/test/test_datafiledialog.py index bdda810..38fa03b 100644 --- a/silx/gui/dialog/test/test_datafiledialog.py +++ b/silx/gui/dialog/test/test_datafiledialog.py @@ -26,7 +26,7 @@ __authors__ = ["V. Valls"] __license__ = "MIT" -__date__ = "14/02/2018" +__date__ = "03/07/2018" import unittest @@ -79,6 +79,20 @@ def setUpModule(): f["nxdata"].attrs["NX_class"] = u"NXdata" f.close() + if h5py is not None: + directory = os.path.join(_tmpDirectory, "data") + os.mkdir(directory) + filename = os.path.join(directory, "data.h5") + f = h5py.File(filename, "w") + f["scalar"] = 10 + f["image"] = data + f["cube"] = [data, data + 1, data + 2] + f["complex_image"] = data * 1j + f["group/image"] = data + f["nxdata/foo"] = 10 + f["nxdata"].attrs["NX_class"] = u"NXdata" + f.close() + filename = _tmpDirectory + "/badformat.h5" with io.open(filename, "wb") as f: f.write(b"{\nHello Nurse!") @@ -270,7 +284,7 @@ class TestDataFileDialogInteraction(utils.TestCaseQt, _UtilsMixin): url = utils.findChildren(dialog, qt.QLineEdit, name="url")[0] action = utils.findChildren(dialog, qt.QAction, name="toParentAction")[0] toParentButton = utils.getQToolButtonFromAction(action) - filename = _tmpDirectory + "/data.h5" + filename = _tmpDirectory + "/data/data.h5" # init state path = silx.io.url.DataUrl(file_path=filename, data_path="/group/image").path() @@ -286,11 +300,11 @@ class TestDataFileDialogInteraction(utils.TestCaseQt, _UtilsMixin): self.mouseClick(toParentButton, qt.Qt.LeftButton) self.qWaitForPendingActions(dialog) - self.assertSamePath(url.text(), _tmpDirectory) + self.assertSamePath(url.text(), _tmpDirectory + "/data") self.mouseClick(toParentButton, qt.Qt.LeftButton) self.qWaitForPendingActions(dialog) - self.assertSamePath(url.text(), os.path.dirname(_tmpDirectory)) + self.assertSamePath(url.text(), _tmpDirectory) def testClickOnBackToRootTool(self): if h5py is None: @@ -529,7 +543,7 @@ class TestDataFileDialogInteraction(utils.TestCaseQt, _UtilsMixin): self.qWaitForWindowExposed(dialog) dialog.selectUrl(_tmpDirectory) self.qWaitForPendingActions(dialog) - self.assertEqual(self._countSelectableItems(browser.model(), browser.rootIndex()), 3) + self.assertEqual(self._countSelectableItems(browser.model(), browser.rootIndex()), 4) class TestDataFileDialog_FilterDataset(utils.TestCaseQt, _UtilsMixin): diff --git a/silx/gui/dialog/test/test_imagefiledialog.py b/silx/gui/dialog/test/test_imagefiledialog.py index 7909f10..8fef3c5 100644 --- a/silx/gui/dialog/test/test_imagefiledialog.py +++ b/silx/gui/dialog/test/test_imagefiledialog.py @@ -26,7 +26,7 @@ __authors__ = ["V. Valls"] __license__ = "MIT" -__date__ = "12/02/2018" +__date__ = "03/07/2018" import unittest @@ -50,7 +50,7 @@ import silx.io.url from silx.gui import qt from silx.gui.test import utils from ..ImageFileDialog import ImageFileDialog -from silx.gui.plot.Colormap import Colormap +from silx.gui.colors import Colormap from silx.gui.hdf5 import Hdf5TreeModel _tmpDirectory = None @@ -88,6 +88,18 @@ def setUpModule(): f["group/image"] = data f.close() + if h5py is not None: + directory = os.path.join(_tmpDirectory, "data") + os.mkdir(directory) + filename = os.path.join(directory, "data.h5") + f = h5py.File(filename, "w") + f["scalar"] = 10 + f["image"] = data + f["cube"] = [data, data + 1, data + 2] + f["complex_image"] = data * 1j + f["group/image"] = data + f.close() + filename = _tmpDirectory + "/badformat.edf" with io.open(filename, "wb") as f: f.write(b"{\nHello Nurse!") @@ -256,27 +268,31 @@ class TestImageFileDialogInteraction(utils.TestCaseQt, _UtilsMixin): url = utils.findChildren(dialog, qt.QLineEdit, name="url")[0] action = utils.findChildren(dialog, qt.QAction, name="toParentAction")[0] toParentButton = utils.getQToolButtonFromAction(action) - filename = _tmpDirectory + "/data.h5" + filename = _tmpDirectory + "/data/data.h5" # init state path = silx.io.url.DataUrl(file_path=filename, data_path="/group/image").path() dialog.selectUrl(path) self.qWaitForPendingActions(dialog) path = silx.io.url.DataUrl(scheme="silx", file_path=filename, data_path="/group/image").path() + print(url.text()) self.assertSamePath(url.text(), path) # test self.mouseClick(toParentButton, qt.Qt.LeftButton) self.qWaitForPendingActions(dialog) path = silx.io.url.DataUrl(scheme="silx", file_path=filename, data_path="/").path() + print(url.text()) self.assertSamePath(url.text(), path) self.mouseClick(toParentButton, qt.Qt.LeftButton) self.qWaitForPendingActions(dialog) - self.assertSamePath(url.text(), _tmpDirectory) + print(url.text()) + self.assertSamePath(url.text(), _tmpDirectory + "/data") self.mouseClick(toParentButton, qt.Qt.LeftButton) self.qWaitForPendingActions(dialog) - self.assertSamePath(url.text(), os.path.dirname(_tmpDirectory)) + print(url.text()) + self.assertSamePath(url.text(), _tmpDirectory) def testClickOnBackToRootTool(self): if h5py is None: @@ -540,21 +556,21 @@ class TestImageFileDialogInteraction(utils.TestCaseQt, _UtilsMixin): self.qWaitForWindowExposed(dialog) dialog.selectUrl(_tmpDirectory) self.qWaitForPendingActions(dialog) - self.assertEqual(self._countSelectableItems(browser.model(), browser.rootIndex()), 5) + self.assertEqual(self._countSelectableItems(browser.model(), browser.rootIndex()), 6) codecName = fabio.edfimage.EdfImage.codec_name() index = filters.indexFromCodec(codecName) filters.setCurrentIndex(index) filters.activated[int].emit(index) self.qWait(50) - self.assertEqual(self._countSelectableItems(browser.model(), browser.rootIndex()), 3) + self.assertEqual(self._countSelectableItems(browser.model(), browser.rootIndex()), 4) codecName = fabio.fit2dmaskimage.Fit2dMaskImage.codec_name() index = filters.indexFromCodec(codecName) filters.setCurrentIndex(index) filters.activated[int].emit(index) self.qWait(50) - self.assertEqual(self._countSelectableItems(browser.model(), browser.rootIndex()), 1) + self.assertEqual(self._countSelectableItems(browser.model(), browser.rootIndex()), 2) class TestImageFileDialogApi(utils.TestCaseQt, _UtilsMixin): diff --git a/silx/gui/hdf5/Hdf5Formatter.py b/silx/gui/hdf5/Hdf5Formatter.py index 0e3697f..6802142 100644 --- a/silx/gui/hdf5/Hdf5Formatter.py +++ b/silx/gui/hdf5/Hdf5Formatter.py @@ -27,7 +27,7 @@ text.""" __authors__ = ["V. Valls"] __license__ = "MIT" -__date__ = "23/01/2018" +__date__ = "06/06/2018" import numpy from silx.third_party import six @@ -119,7 +119,11 @@ class Hdf5Formatter(qt.QObject): return text def humanReadableType(self, dataset, full=False): - dtype = dataset.dtype + if hasattr(dataset, "dtype"): + dtype = dataset.dtype + else: + # Fallback... + dtype = type(dataset) return self.humanReadableDType(dtype, full) def humanReadableDType(self, dtype, full=False): @@ -164,6 +168,16 @@ class Hdf5Formatter(qt.QObject): return "enum" text = str(dtype.newbyteorder('N')) + if numpy.issubdtype(dtype, numpy.floating): + if hasattr(numpy, "float128") and dtype == numpy.float128: + text = "float80" + if full: + text += " (padding 128bits)" + elif hasattr(numpy, "float96") and dtype == numpy.float96: + text = "float80" + if full: + text += " (padding 96bits)" + if full: if dtype.byteorder == "<": text = "Little-endian " + text diff --git a/silx/gui/hdf5/Hdf5TreeModel.py b/silx/gui/hdf5/Hdf5TreeModel.py index 2d62429..835708a 100644 --- a/silx/gui/hdf5/Hdf5TreeModel.py +++ b/silx/gui/hdf5/Hdf5TreeModel.py @@ -25,7 +25,7 @@ __authors__ = ["V. Valls"] __license__ = "MIT" -__date__ = "29/11/2017" +__date__ = "11/06/2018" import os @@ -205,7 +205,23 @@ class Hdf5TreeModel(qt.QAbstractItemModel): ] """List of logical columns available""" - def __init__(self, parent=None): + sigH5pyObjectLoaded = qt.Signal(object) + """Emitted when a new root item was loaded and inserted to the model.""" + + sigH5pyObjectRemoved = qt.Signal(object) + """Emitted when a root item is removed from the model.""" + + sigH5pyObjectSynchronized = qt.Signal(object, object) + """Emitted when an item was synchronized.""" + + def __init__(self, parent=None, ownFiles=True): + """ + Constructor + + :param qt.QWidget parent: Parent widget + :param bool ownFiles: If true (default) the model will manage the files + life cycle when they was added using path (like DnD). + """ super(Hdf5TreeModel, self).__init__(parent) self.header_labels = [None] * len(self.COLUMN_IDS) @@ -221,6 +237,7 @@ class Hdf5TreeModel(qt.QAbstractItemModel): self.__root = Hdf5Node() self.__fileDropEnabled = True self.__fileMoveEnabled = True + self.__datasetDragEnabled = False self.__animatedIcon = icons.getWaitIcon() self.__animatedIcon.iconChanged.connect(self.__updateLoadingItems) @@ -235,6 +252,7 @@ class Hdf5TreeModel(qt.QAbstractItemModel): self.__icons.append(icons.getQIcon("item-3dim")) self.__icons.append(icons.getQIcon("item-ndim")) + self.__ownFiles = ownFiles self.__openedFiles = [] """Store the list of files opened by the model itself.""" # FIXME: It should be managed one by one by Hdf5Item itself @@ -285,16 +303,25 @@ class Hdf5TreeModel(qt.QAbstractItemModel): newItem = _unwrapNone(newItem) error = _unwrapNone(error) row = self.__root.indexOfChild(oldItem) + rootIndex = qt.QModelIndex() self.beginRemoveRows(rootIndex, row, row) self.__root.removeChildAtIndex(row) self.endRemoveRows() + if newItem is not None: rootIndex = qt.QModelIndex() - self.__openedFiles.append(newItem.obj) + if self.__ownFiles: + self.__openedFiles.append(newItem.obj) self.beginInsertRows(rootIndex, row, row) self.__root.insertChild(row, newItem) self.endInsertRows() + + if isinstance(oldItem, Hdf5LoadingItem): + self.sigH5pyObjectLoaded.emit(newItem.obj) + else: + self.sigH5pyObjectSynchronized.emit(oldItem.obj, newItem.obj) + # FIXME the error must be displayed def isFileDropEnabled(self): @@ -306,6 +333,15 @@ class Hdf5TreeModel(qt.QAbstractItemModel): fileDropEnabled = qt.Property(bool, isFileDropEnabled, setFileDropEnabled) """Property to enable/disable file dropping in the model.""" + def isDatasetDragEnabled(self): + return self.__datasetDragEnabled + + def setDatasetDragEnabled(self, enabled): + self.__datasetDragEnabled = enabled + + datasetDragEnabled = qt.Property(bool, isDatasetDragEnabled, setDatasetDragEnabled) + """Property to enable/disable drag of datasets.""" + def isFileMoveEnabled(self): return self.__fileMoveEnabled @@ -323,10 +359,12 @@ class Hdf5TreeModel(qt.QAbstractItemModel): return 0 def mimeTypes(self): + types = [] if self.__fileMoveEnabled: - return [_utils.Hdf5NodeMimeData.MIME_TYPE] - else: - return [] + types.append(_utils.Hdf5NodeMimeData.MIME_TYPE) + if self.__datasetDragEnabled: + types.append(_utils.Hdf5DatasetMimeData.MIME_TYPE) + return types def mimeData(self, indexes): """ @@ -336,7 +374,7 @@ class Hdf5TreeModel(qt.QAbstractItemModel): :param List[qt.QModelIndex] indexes: List of indexes :rtype: qt.QMimeData """ - if not self.__fileMoveEnabled or len(indexes) == 0: + if len(indexes) == 0: return None indexes = [i for i in indexes if i.column() == 0] @@ -346,7 +384,13 @@ class Hdf5TreeModel(qt.QAbstractItemModel): raise NotImplementedError("Drag of cell is not implemented") node = self.nodeFromIndex(indexes[0]) - mimeData = _utils.Hdf5NodeMimeData(node) + + if self.__fileMoveEnabled and node.parent is self.__root: + mimeData = _utils.Hdf5NodeMimeData(node=node) + elif self.__datasetDragEnabled: + mimeData = _utils.Hdf5DatasetMimeData(node=node) + else: + mimeData = None return mimeData def flags(self, index): @@ -357,6 +401,8 @@ class Hdf5TreeModel(qt.QAbstractItemModel): if self.__fileMoveEnabled and node.parent is self.__root: # that's a root return qt.Qt.ItemIsDragEnabled | defaultFlags + elif self.__datasetDragEnabled: + return qt.Qt.ItemIsDragEnabled | defaultFlags return defaultFlags elif self.__fileDropEnabled or self.__fileMoveEnabled: return qt.Qt.ItemIsDropEnabled | defaultFlags @@ -543,8 +589,7 @@ class Hdf5TreeModel(qt.QAbstractItemModel): return filename = node.obj.filename - self.removeIndex(index) - self.insertFileAsync(filename, index.row()) + self.insertFileAsync(filename, index.row(), synchronizingNode=node) def synchronizeH5pyObject(self, h5pyObject): """ @@ -560,8 +605,7 @@ class Hdf5TreeModel(qt.QAbstractItemModel): if item.obj is h5pyObject: qindex = self.index(index, 0, qt.QModelIndex()) self.synchronizeIndex(qindex) - else: - index += 1 + index += 1 def removeIndex(self, index): """ @@ -576,6 +620,7 @@ class Hdf5TreeModel(qt.QAbstractItemModel): self.beginRemoveRows(qt.QModelIndex(), index.row(), index.row()) self.__root.removeChildAtIndex(index.row()) self.endRemoveRows() + self.sigH5pyObjectRemoved.emit(node.obj) def removeH5pyObject(self, h5pyObject): """ @@ -608,14 +653,17 @@ class Hdf5TreeModel(qt.QAbstractItemModel): def hasPendingOperations(self): return len(self.__runnerSet) > 0 - def insertFileAsync(self, filename, row=-1): + def insertFileAsync(self, filename, row=-1, synchronizingNode=None): if not os.path.isfile(filename): raise IOError("Filename '%s' must be a file path" % filename) # create temporary item - text = os.path.basename(filename) - item = Hdf5LoadingItem(text=text, parent=self.__root, animatedIcon=self.__animatedIcon) - self.insertNode(row, item) + if synchronizingNode is None: + text = os.path.basename(filename) + item = Hdf5LoadingItem(text=text, parent=self.__root, animatedIcon=self.__animatedIcon) + self.insertNode(row, item) + else: + item = synchronizingNode # start loading the real one runnable = LoadingItemRunnable(filename, item) @@ -634,12 +682,20 @@ class Hdf5TreeModel(qt.QAbstractItemModel): """ try: h5file = silx_io.open(filename) - self.__openedFiles.append(h5file) + if self.__ownFiles: + self.__openedFiles.append(h5file) + self.sigH5pyObjectLoaded.emit(h5file) self.insertH5pyObject(h5file, row=row) except IOError: _logger.debug("File '%s' can't be read.", filename, exc_info=True) raise + def clear(self): + """Remove all the content of the model""" + for _ in range(self.rowCount()): + qindex = self.index(0, 0, qt.QModelIndex()) + self.removeIndex(qindex) + def appendFile(self, filename): self.insertFile(filename, -1) diff --git a/silx/gui/hdf5/Hdf5TreeView.py b/silx/gui/hdf5/Hdf5TreeView.py index 78b5c19..a86140a 100644 --- a/silx/gui/hdf5/Hdf5TreeView.py +++ b/silx/gui/hdf5/Hdf5TreeView.py @@ -25,7 +25,7 @@ __authors__ = ["V. Valls"] __license__ = "MIT" -__date__ = "20/02/2018" +__date__ = "30/04/2018" import logging @@ -66,10 +66,8 @@ class Hdf5TreeView(qt.QTreeView): """ qt.QTreeView.__init__(self, parent) - model = Hdf5TreeModel(self) - proxy_model = NexusSortFilterProxyModel(self) - proxy_model.setSourceModel(model) - self.setModel(proxy_model) + model = self.createDefaultModel() + self.setModel(model) self.setHeader(Hdf5HeaderView(qt.Qt.Horizontal, self)) self.setSelectionBehavior(qt.QAbstractItemView.SelectRows) @@ -87,6 +85,15 @@ class Hdf5TreeView(qt.QTreeView): self.setContextMenuPolicy(qt.Qt.CustomContextMenu) self.customContextMenuRequested.connect(self._createContextMenu) + def createDefaultModel(self): + """Creates and returns the default model. + + Inherite to custom the default model""" + model = Hdf5TreeModel(self) + proxy_model = NexusSortFilterProxyModel(self) + proxy_model.setSourceModel(model) + return proxy_model + def __removeContextMenuProxies(self, ref): """Callback to remove dead proxy from the list""" self.__context_menu_callbacks.remove(ref) diff --git a/silx/gui/hdf5/NexusSortFilterProxyModel.py b/silx/gui/hdf5/NexusSortFilterProxyModel.py index 9a27968..3f2cf8d 100644 --- a/silx/gui/hdf5/NexusSortFilterProxyModel.py +++ b/silx/gui/hdf5/NexusSortFilterProxyModel.py @@ -25,7 +25,7 @@ __authors__ = ["V. Valls"] __license__ = "MIT" -__date__ = "10/10/2017" +__date__ = "25/06/2018" import logging @@ -34,6 +34,7 @@ import numpy from .. import qt from .Hdf5TreeModel import Hdf5TreeModel import silx.io.utils +from silx.gui import icons _logger = logging.getLogger(__name__) @@ -45,6 +46,7 @@ class NexusSortFilterProxyModel(qt.QSortFilterProxyModel): def __init__(self, parent=None): qt.QSortFilterProxyModel.__init__(self, parent) self.__split = re.compile("(\\d+|\\D+)") + self.__iconCache = {} def lessThan(self, sourceLeft, sourceRight): """Returns True if the value of the item referred to by the given @@ -86,6 +88,14 @@ class NexusSortFilterProxyModel(qt.QSortFilterProxyModel): nxClass = node.obj.attrs.get("NX_class", None) return nxClass == "NXentry" + def __isNXnode(self, node): + """Returns true if the node is an NX concept""" + class_ = node.h5Class + if class_ is None or class_ != silx.io.utils.H5Type.GROUP: + return False + nxClass = node.obj.attrs.get("NX_class", None) + return nxClass is not None + def getWordsAndNumbers(self, name): """ Returns a list of words and integers composing the name. @@ -96,11 +106,14 @@ class NexusSortFilterProxyModel(qt.QSortFilterProxyModel): :param str name: A name :rtype: List """ + nonSensitive = self.sortCaseSensitivity() == qt.Qt.CaseInsensitive words = self.__split.findall(name) result = [] for i in words: if i[0].isdigit(): i = int(i) + elif nonSensitive: + i = i.lower() result.append(i) return result @@ -145,3 +158,47 @@ class NexusSortFilterProxyModel(qt.QSortFilterProxyModel): except Exception: _logger.debug("Exception occurred", exc_info=True) return None + + def __createCompoundIcon(self, backgroundIcon, foregroundIcon): + icon = qt.QIcon() + + sizes = backgroundIcon.availableSizes() + sizes = sorted(sizes, key=lambda s: s.height()) + sizes = filter(lambda s: s.height() < 100, sizes) + sizes = list(sizes) + if len(sizes) > 0: + baseSize = sizes[-1] + else: + baseSize = qt.QSize(32, 32) + + modes = [qt.QIcon.Normal, qt.QIcon.Disabled] + for mode in modes: + pixmap = qt.QPixmap(baseSize) + pixmap.fill(qt.Qt.transparent) + painter = qt.QPainter(pixmap) + painter.drawPixmap(0, 0, backgroundIcon.pixmap(baseSize, mode=mode)) + painter.drawPixmap(0, 0, foregroundIcon.pixmap(baseSize, mode=mode)) + painter.end() + icon.addPixmap(pixmap, mode=mode) + + return icon + + def __getNxIcon(self, baseIcon): + iconHash = baseIcon.cacheKey() + icon = self.__iconCache.get(iconHash, None) + if icon is None: + nxIcon = icons.getQIcon("layer-nx") + icon = self.__createCompoundIcon(baseIcon, nxIcon) + self.__iconCache[iconHash] = icon + return icon + + def data(self, index, role=qt.Qt.DisplayRole): + result = super(NexusSortFilterProxyModel, self).data(index, role) + + if index.column() == Hdf5TreeModel.NAME_COLUMN: + if role == qt.Qt.DecorationRole: + sourceIndex = self.mapToSource(index) + item = self.sourceModel().data(sourceIndex, Hdf5TreeModel.H5PY_ITEM_ROLE) + if self.__isNXnode(item): + result = self.__getNxIcon(result) + return result diff --git a/silx/gui/hdf5/_utils.py b/silx/gui/hdf5/_utils.py index ddf4db5..8385129 100644 --- a/silx/gui/hdf5/_utils.py +++ b/silx/gui/hdf5/_utils.py @@ -28,7 +28,7 @@ package `silx.gui.hdf5` package. __authors__ = ["V. Valls"] __license__ = "MIT" -__date__ = "20/12/2017" +__date__ = "04/05/2018" import logging @@ -102,6 +102,26 @@ def htmlFromDict(dictionary, title=None): return result +class Hdf5DatasetMimeData(qt.QMimeData): + """Mimedata class to identify an internal drag and drop of a Hdf5Node.""" + + MIME_TYPE = "application/x-internal-h5py-dataset" + + def __init__(self, node=None, dataset=None): + qt.QMimeData.__init__(self) + self.__dataset = dataset + self.__node = node + self.setData(self.MIME_TYPE, "".encode(encoding='utf-8')) + + def node(self): + return self.__node + + def dataset(self): + if self.__node is not None: + return self.__node.obj + return self.__dataset + + class Hdf5NodeMimeData(qt.QMimeData): """Mimedata class to identify an internal drag and drop of a Hdf5Node.""" diff --git a/silx/gui/hdf5/test/test_hdf5.py b/silx/gui/hdf5/test/test_hdf5.py index 44c4456..fc27f6b 100644 --- a/silx/gui/hdf5/test/test_hdf5.py +++ b/silx/gui/hdf5/test/test_hdf5.py @@ -26,7 +26,7 @@ __authors__ = ["V. Valls"] __license__ = "MIT" -__date__ = "20/02/2018" +__date__ = "03/05/2018" import time @@ -39,6 +39,7 @@ from contextlib import contextmanager from silx.gui import qt from silx.gui.test.utils import TestCaseQt from silx.gui import hdf5 +from silx.gui.test.utils import SignalListener from silx.io import commonh5 import weakref @@ -48,6 +49,29 @@ except ImportError: h5py = None +_tmpDirectory = None + + +def setUpModule(): + global _tmpDirectory + _tmpDirectory = tempfile.mkdtemp(prefix=__name__) + + if h5py is not None: + filename = _tmpDirectory + "/data.h5" + + # create h5 data + f = h5py.File(filename, "w") + g = f.create_group("arrays") + g.create_dataset("scalar", data=10) + f.close() + + +def tearDownModule(): + global _tmpDirectory + shutil.rmtree(_tmpDirectory) + _tmpDirectory = None + + _called = 0 @@ -71,7 +95,7 @@ class TestHdf5TreeModel(TestCaseQt): self.skipTest("h5py is not available") def waitForPendingOperations(self, model): - for i in range(10): + for _ in range(10): if not model.hasPendingOperations(): break self.qWait(10) @@ -97,53 +121,53 @@ class TestHdf5TreeModel(TestCaseQt): self.assertIsNotNone(model) def testAppendFilename(self): - with self.h5TempFile() as filename: + filename = _tmpDirectory + "/data.h5" + model = hdf5.Hdf5TreeModel() + self.assertEquals(model.rowCount(qt.QModelIndex()), 0) + model.appendFile(filename) + self.assertEquals(model.rowCount(qt.QModelIndex()), 1) + # clean up + index = model.index(0, 0, qt.QModelIndex()) + h5File = model.data(index, hdf5.Hdf5TreeModel.H5PY_OBJECT_ROLE) + ref = weakref.ref(model) + model = None + self.qWaitForDestroy(ref) + + def testAppendBadFilename(self): + model = hdf5.Hdf5TreeModel() + self.assertRaises(IOError, model.appendFile, "#%$") + + def testInsertFilename(self): + filename = _tmpDirectory + "/data.h5" + try: model = hdf5.Hdf5TreeModel() self.assertEquals(model.rowCount(qt.QModelIndex()), 0) - model.appendFile(filename) + model.insertFile(filename) self.assertEquals(model.rowCount(qt.QModelIndex()), 1) # clean up index = model.index(0, 0, qt.QModelIndex()) h5File = model.data(index, hdf5.Hdf5TreeModel.H5PY_OBJECT_ROLE) + self.assertIsNotNone(h5File) + finally: ref = weakref.ref(model) model = None self.qWaitForDestroy(ref) - def testAppendBadFilename(self): - model = hdf5.Hdf5TreeModel() - self.assertRaises(IOError, model.appendFile, "#%$") - - def testInsertFilename(self): - with self.h5TempFile() as filename: - try: - model = hdf5.Hdf5TreeModel() - self.assertEquals(model.rowCount(qt.QModelIndex()), 0) - model.insertFile(filename) - self.assertEquals(model.rowCount(qt.QModelIndex()), 1) - # clean up - index = model.index(0, 0, qt.QModelIndex()) - h5File = model.data(index, hdf5.Hdf5TreeModel.H5PY_OBJECT_ROLE) - self.assertIsNotNone(h5File) - finally: - ref = weakref.ref(model) - model = None - self.qWaitForDestroy(ref) - def testInsertFilenameAsync(self): - with self.h5TempFile() as filename: - try: - model = hdf5.Hdf5TreeModel() - self.assertEquals(model.rowCount(qt.QModelIndex()), 0) - model.insertFileAsync(filename) - index = model.index(0, 0, qt.QModelIndex()) - self.assertIsInstance(model.nodeFromIndex(index), hdf5.Hdf5LoadingItem.Hdf5LoadingItem) - self.waitForPendingOperations(model) - index = model.index(0, 0, qt.QModelIndex()) - self.assertIsInstance(model.nodeFromIndex(index), hdf5.Hdf5Item.Hdf5Item) - finally: - ref = weakref.ref(model) - model = None - self.qWaitForDestroy(ref) + filename = _tmpDirectory + "/data.h5" + try: + model = hdf5.Hdf5TreeModel() + self.assertEquals(model.rowCount(qt.QModelIndex()), 0) + model.insertFileAsync(filename) + index = model.index(0, 0, qt.QModelIndex()) + self.assertIsInstance(model.nodeFromIndex(index), hdf5.Hdf5LoadingItem.Hdf5LoadingItem) + self.waitForPendingOperations(model) + index = model.index(0, 0, qt.QModelIndex()) + self.assertIsInstance(model.nodeFromIndex(index), hdf5.Hdf5Item.Hdf5Item) + finally: + ref = weakref.ref(model) + model = None + self.qWaitForDestroy(ref) def testInsertObject(self): h5 = commonh5.File("/foo/bar/1.mock", "w") @@ -162,36 +186,37 @@ class TestHdf5TreeModel(TestCaseQt): self.assertEquals(model.rowCount(qt.QModelIndex()), 0) def testSynchronizeObject(self): - with self.h5TempFile() as filename: - h5 = h5py.File(filename) - model = hdf5.Hdf5TreeModel() - model.insertH5pyObject(h5) - self.assertEquals(model.rowCount(qt.QModelIndex()), 1) - index = model.index(0, 0, qt.QModelIndex()) - node1 = model.nodeFromIndex(index) - model.synchronizeH5pyObject(h5) - # Now h5 was loaded from it's filename - # Another ref is owned by the model - h5.close() + filename = _tmpDirectory + "/data.h5" + h5 = h5py.File(filename) + model = hdf5.Hdf5TreeModel() + model.insertH5pyObject(h5) + self.assertEquals(model.rowCount(qt.QModelIndex()), 1) + index = model.index(0, 0, qt.QModelIndex()) + node1 = model.nodeFromIndex(index) + model.synchronizeH5pyObject(h5) + self.waitForPendingOperations(model) + # Now h5 was loaded from it's filename + # Another ref is owned by the model + h5.close() - index = model.index(0, 0, qt.QModelIndex()) - node2 = model.nodeFromIndex(index) - self.assertIsNot(node1, node2) - # after sync - time.sleep(0.1) - self.qapp.processEvents() - time.sleep(0.1) - index = model.index(0, 0, qt.QModelIndex()) - self.assertIsInstance(model.nodeFromIndex(index), hdf5.Hdf5Item.Hdf5Item) - # clean up - index = model.index(0, 0, qt.QModelIndex()) - h5File = model.data(index, hdf5.Hdf5TreeModel.H5PY_OBJECT_ROLE) - self.assertIsNotNone(h5File) - h5File = None - # delete the model - ref = weakref.ref(model) - model = None - self.qWaitForDestroy(ref) + index = model.index(0, 0, qt.QModelIndex()) + node2 = model.nodeFromIndex(index) + self.assertIsNot(node1, node2) + # after sync + time.sleep(0.1) + self.qapp.processEvents() + time.sleep(0.1) + index = model.index(0, 0, qt.QModelIndex()) + self.assertIsInstance(model.nodeFromIndex(index), hdf5.Hdf5Item.Hdf5Item) + # clean up + index = model.index(0, 0, qt.QModelIndex()) + h5File = model.data(index, hdf5.Hdf5TreeModel.H5PY_OBJECT_ROLE) + self.assertIsNotNone(h5File) + h5File = None + # delete the model + ref = weakref.ref(model) + model = None + self.qWaitForDestroy(ref) def testFileMoveState(self): model = hdf5.Hdf5TreeModel() @@ -222,24 +247,24 @@ class TestHdf5TreeModel(TestCaseQt): self.assertNotEquals(model.supportedDropActions(), 0) def testDropExternalFile(self): - with self.h5TempFile() as filename: - model = hdf5.Hdf5TreeModel() - mimeData = qt.QMimeData() - mimeData.setUrls([qt.QUrl.fromLocalFile(filename)]) - model.dropMimeData(mimeData, qt.Qt.CopyAction, 0, 0, qt.QModelIndex()) - self.assertEquals(model.rowCount(qt.QModelIndex()), 1) - # after sync - self.waitForPendingOperations(model) - index = model.index(0, 0, qt.QModelIndex()) - self.assertIsInstance(model.nodeFromIndex(index), hdf5.Hdf5Item.Hdf5Item) - # clean up - index = model.index(0, 0, qt.QModelIndex()) - h5File = model.data(index, role=hdf5.Hdf5TreeModel.H5PY_OBJECT_ROLE) - self.assertIsNotNone(h5File) - h5File = None - ref = weakref.ref(model) - model = None - self.qWaitForDestroy(ref) + filename = _tmpDirectory + "/data.h5" + model = hdf5.Hdf5TreeModel() + mimeData = qt.QMimeData() + mimeData.setUrls([qt.QUrl.fromLocalFile(filename)]) + model.dropMimeData(mimeData, qt.Qt.CopyAction, 0, 0, qt.QModelIndex()) + self.assertEquals(model.rowCount(qt.QModelIndex()), 1) + # after sync + self.waitForPendingOperations(model) + index = model.index(0, 0, qt.QModelIndex()) + self.assertIsInstance(model.nodeFromIndex(index), hdf5.Hdf5Item.Hdf5Item) + # clean up + index = model.index(0, 0, qt.QModelIndex()) + h5File = model.data(index, role=hdf5.Hdf5TreeModel.H5PY_OBJECT_ROLE) + self.assertIsNotNone(h5File) + h5File = None + ref = weakref.ref(model) + model = None + self.qWaitForDestroy(ref) def getRowDataAsDict(self, model, row): displayed = {} @@ -337,6 +362,66 @@ class TestHdf5TreeModel(TestCaseQt): self.assertEquals(index, qt.QModelIndex()) +class TestHdf5TreeModelSignals(TestCaseQt): + + def setUp(self): + TestCaseQt.setUp(self) + self.model = hdf5.Hdf5TreeModel() + filename = _tmpDirectory + "/data.h5" + self.h5 = h5py.File(filename) + self.model.insertH5pyObject(self.h5) + + self.listener = SignalListener() + self.model.sigH5pyObjectLoaded.connect(self.listener.partial(signal="loaded")) + self.model.sigH5pyObjectRemoved.connect(self.listener.partial(signal="removed")) + self.model.sigH5pyObjectSynchronized.connect(self.listener.partial(signal="synchronized")) + + def tearDown(self): + self.signals = None + ref = weakref.ref(self.model) + self.model = None + self.qWaitForDestroy(ref) + self.h5.close() + self.h5 = None + TestCaseQt.tearDown(self) + + def waitForPendingOperations(self, model): + for _ in range(10): + if not model.hasPendingOperations(): + break + self.qWait(10) + else: + raise RuntimeError("Still waiting for a pending operation") + + def testInsert(self): + filename = _tmpDirectory + "/data.h5" + h5 = h5py.File(filename) + self.model.insertH5pyObject(h5) + self.assertEquals(self.listener.callCount(), 0) + + def testLoaded(self): + filename = _tmpDirectory + "/data.h5" + self.model.insertFile(filename) + self.assertEquals(self.listener.callCount(), 1) + self.assertEquals(self.listener.karguments(argumentName="signal")[0], "loaded") + self.assertIsNot(self.listener.arguments(callIndex=0)[0], self.h5) + self.assertEquals(self.listener.arguments(callIndex=0)[0].filename, filename) + + def testRemoved(self): + self.model.removeH5pyObject(self.h5) + self.assertEquals(self.listener.callCount(), 1) + self.assertEquals(self.listener.karguments(argumentName="signal")[0], "removed") + self.assertIs(self.listener.arguments(callIndex=0)[0], self.h5) + + def testSynchonized(self): + self.model.synchronizeH5pyObject(self.h5) + self.waitForPendingOperations(self.model) + self.assertEquals(self.listener.callCount(), 1) + self.assertEquals(self.listener.karguments(argumentName="signal")[0], "synchronized") + self.assertIs(self.listener.arguments(callIndex=0)[0], self.h5) + self.assertIsNot(self.listener.arguments(callIndex=0)[1], self.h5) + + class TestNexusSortFilterProxyModel(TestCaseQt): def getChildNames(self, model, index): @@ -873,6 +958,7 @@ def suite(): test_suite = unittest.TestSuite() loadTests = unittest.defaultTestLoader.loadTestsFromTestCase test_suite.addTest(loadTests(TestHdf5TreeModel)) + test_suite.addTest(loadTests(TestHdf5TreeModelSignals)) test_suite.addTest(loadTests(TestNexusSortFilterProxyModel)) test_suite.addTest(loadTests(TestHdf5TreeView)) test_suite.addTest(loadTests(TestH5Node)) diff --git a/silx/gui/icons.py b/silx/gui/icons.py index 0108b3a..bd10300 100644 --- a/silx/gui/icons.py +++ b/silx/gui/icons.py @@ -29,7 +29,7 @@ Use :func:`getQIcon` to create Qt QIcon from the name identifying an icon. __authors__ = ["T. Vincent"] __license__ = "MIT" -__date__ = "06/09/2017" +__date__ = "19/06/2018" import os @@ -193,10 +193,13 @@ class MultiImageAnimatedIcon(AbstractAnimatedIcon): self.__frames = [] for i in range(100): try: - pixmap = getQPixmap("%s/%02d" % (filename, i)) + filename = getQFile("%s/%02d" % (filename, i)) + except ValueError: + break + try: + icon = qt.QIcon(filename.fileName()) except ValueError: break - icon = qt.QIcon(pixmap) self.__frames.append(icon) if len(self.__frames) == 0: @@ -328,8 +331,7 @@ def getQIcon(name): """ if name not in _cached_icons: qfile = getQFile(name) - pixmap = qt.QPixmap(qfile.fileName()) - icon = qt.QIcon(pixmap) + icon = qt.QIcon(qfile.fileName()) _cached_icons[name] = icon else: icon = _cached_icons[name] @@ -392,7 +394,7 @@ def getQFile(name): for format_ in _supported_formats: format_ = str(format_) filename = silx.resources._resource_filename('%s.%s' % (name, format_), - default_directory=os.path.join('gui', 'icons')) + default_directory=os.path.join('gui', 'icons')) qfile = qt.QFile(filename) if qfile.exists(): return qfile diff --git a/silx/gui/plot/ColorBar.py b/silx/gui/plot/ColorBar.py index 2db7b79..0941e82 100644 --- a/silx/gui/plot/ColorBar.py +++ b/silx/gui/plot/ColorBar.py @@ -1,7 +1,7 @@ # coding: utf-8 # /*########################################################################## # -# Copyright (c) 2016-2017 European Synchrotron Radiation Facility +# Copyright (c) 2016-2018 European Synchrotron Radiation Facility # # Permission is hereby granted, free of charge, to any person obtaining a copy # of this software and associated documentation files (the "Software"), to deal @@ -27,14 +27,16 @@ __authors__ = ["H. Payno", "T. Vincent"] __license__ = "MIT" -__date__ = "15/02/2018" +__date__ = "24/04/2018" import logging +import weakref import numpy + from ._utils import ticklayout -from .. import qt, icons -from silx.gui.plot import Colormap +from .. import qt +from silx.gui import colors _logger = logging.getLogger(__name__) @@ -70,7 +72,7 @@ class ColorBarWidget(qt.QWidget): def __init__(self, parent=None, plot=None, legend=None): self._isConnected = False - self._plot = None + self._plotRef = None self._colormap = None self._data = None @@ -96,7 +98,7 @@ class ColorBarWidget(qt.QWidget): def getPlot(self): """Returns the :class:`Plot` associated to this widget or None""" - return self._plot + return None if self._plotRef is None else self._plotRef() def setPlot(self, plot): """Associate a plot to the ColorBar @@ -105,27 +107,38 @@ class ColorBarWidget(qt.QWidget): If None will remove any connection with a previous plot. """ self._disconnectPlot() - self._plot = plot + self._plotRef = None if plot is None else weakref.ref(plot) self._connectPlot() def _disconnectPlot(self): """Disconnect from Plot signals""" - if self._plot is not None and self._isConnected: + plot = self.getPlot() + if plot is not None and self._isConnected: self._isConnected = False - self._plot.sigActiveImageChanged.disconnect( + plot.sigActiveImageChanged.disconnect( self._activeImageChanged) - self._plot.sigPlotSignal.disconnect(self._defaultColormapChanged) + plot.sigActiveScatterChanged.disconnect( + self._activeScatterChanged) + plot.sigPlotSignal.disconnect(self._defaultColormapChanged) def _connectPlot(self): """Connect to Plot signals""" - if self._plot is not None and not self._isConnected: - activeImageLegend = self._plot.getActiveImage(just_legend=True) - if activeImageLegend is None: # Show plot default colormap + plot = self.getPlot() + if plot is not None and not self._isConnected: + activeImageLegend = plot.getActiveImage(just_legend=True) + activeScatterLegend = plot._getActiveItem( + kind='scatter', just_legend=True) + if activeImageLegend is None and activeScatterLegend is None: + # Show plot default colormap self._syncWithDefaultColormap() - else: # Show active image colormap + elif activeImageLegend is not None: # Show active image colormap self._activeImageChanged(None, activeImageLegend) - self._plot.sigActiveImageChanged.connect(self._activeImageChanged) - self._plot.sigPlotSignal.connect(self._defaultColormapChanged) + elif activeScatterLegend is not None: # Show active scatter colormap + self._activeScatterChanged(None, activeScatterLegend) + + plot.sigActiveImageChanged.connect(self._activeImageChanged) + plot.sigActiveScatterChanged.connect(self._activeScatterChanged) + plot.sigPlotSignal.connect(self._defaultColormapChanged) self._isConnected = True def setVisible(self, isVisible): @@ -196,36 +209,58 @@ class ColorBarWidget(qt.QWidget): """ return self.legend.getText() - def _activeImageChanged(self, previous, legend): - """Handle plot active curve changed""" - if legend is None: # No active image, display no colormap - self.setColormap(colormap=None) - return + def _activeScatterChanged(self, previous, legend): + """Handle plot active scatter changed""" + plot = self.getPlot() - # Sync with active image - image = self._plot.getActiveImage().getData(copy=False) + # Do not handle active scatter while there is an image + if plot.getActiveImage() is not None: + return - # RGB(A) image, display default colormap - if image.ndim != 2: + if legend is None: # No active scatter, display no colormap self.setColormap(colormap=None) return - # data image, sync with image colormap - # do we need the copy here : used in the case we are changing - # vmin and vmax but should have already be done by the plot - self.setColormap(colormap=self._plot.getActiveImage().getColormap(), - data=image) + # Sync with active scatter + activeScatter = plot._getActiveItem(kind='scatter') + + self.setColormap(colormap=activeScatter.getColormap(), + data=activeScatter.getValueData(copy=False)) + + def _activeImageChanged(self, previous, legend): + """Handle plot active image changed""" + plot = self.getPlot() + + if legend is None: # No active image, try with active scatter + activeScatterLegend = plot._getActiveItem( + kind='scatter', just_legend=True) + # No more active image, use active scatter if any + self._activeScatterChanged(None, activeScatterLegend) + else: + # Sync with active image + image = plot.getActiveImage().getData(copy=False) + + # RGB(A) image, display default colormap + if image.ndim != 2: + self.setColormap(colormap=None) + return + + # data image, sync with image colormap + # do we need the copy here : used in the case we are changing + # vmin and vmax but should have already be done by the plot + self.setColormap(colormap=plot.getActiveImage().getColormap(), + data=image) def _defaultColormapChanged(self, event): """Handle plot default colormap changed""" if (event['event'] == 'defaultColormapChanged' and - self._plot.getActiveImage() is None): + self.getPlot().getActiveImage() is None): # No active image, take default colormap update into account self._syncWithDefaultColormap() def _syncWithDefaultColormap(self, data=None): """Update colorbar according to plot default colormap""" - self.setColormap(self._plot.getDefaultColormap(), data) + self.setColormap(self.getPlot().getDefaultColormap(), data) def getColorScaleBar(self): """ @@ -316,9 +351,9 @@ class ColorScaleBar(qt.QWidget): if colormap: vmin, vmax = colormap.getColormapRange(data) else: - vmin, vmax = Colormap.DEFAULT_MIN_LIN, Colormap.DEFAULT_MAX_LIN + vmin, vmax = colors.DEFAULT_MIN_LIN, colors.DEFAULT_MAX_LIN - norm = colormap.getNormalization() if colormap else Colormap.Colormap.LINEAR + norm = colormap.getNormalization() if colormap else colors.Colormap.LINEAR self.tickbar = _TickBar(vmin=vmin, vmax=vmax, norm=norm, @@ -503,7 +538,7 @@ class _ColorScale(qt.QWidget): if colormap is None: self.vmin, self.vmax = None, None else: - assert colormap.getNormalization() in Colormap.Colormap.NORMALIZATIONS + assert colormap.getNormalization() in colors.Colormap.NORMALIZATIONS self.vmin, self.vmax = self._colormap.getColormapRange(data=data) self._updateColorGradient() self.update() @@ -575,9 +610,9 @@ class _ColorScale(qt.QWidget): vmin = self.vmin vmax = self.vmax - if colormap.getNormalization() == Colormap.Colormap.LINEAR: + if colormap.getNormalization() == colors.Colormap.LINEAR: return vmin + (vmax - vmin) * value - elif colormap.getNormalization() == Colormap.Colormap.LOGARITHM: + elif colormap.getNormalization() == colors.Colormap.LOGARITHM: rpos = (numpy.log10(vmax) - numpy.log10(vmin)) * value + numpy.log10(vmin) return numpy.power(10., rpos) else: @@ -706,9 +741,9 @@ class _TickBar(qt.QWidget): # No range: no ticks self.ticks = () self.subTicks = () - elif self._norm == Colormap.Colormap.LOGARITHM: + elif self._norm == colors.Colormap.LOGARITHM: self._computeTicksLog(nticks) - elif self._norm == Colormap.Colormap.LINEAR: + elif self._norm == colors.Colormap.LINEAR: self._computeTicksLin(nticks) else: err = 'TickBar - Wrong normalization %s' % self._norm @@ -765,9 +800,9 @@ class _TickBar(qt.QWidget): def _getRelativePosition(self, val): """Return the relative position of val according to min and max value """ - if self._norm == Colormap.Colormap.LINEAR: + if self._norm == colors.Colormap.LINEAR: return 1 - (val - self._vmin) / (self._vmax - self._vmin) - elif self._norm == Colormap.Colormap.LOGARITHM: + elif self._norm == colors.Colormap.LOGARITHM: return 1 - (numpy.log10(val) - numpy.log10(self._vmin)) / (numpy.log10(self._vmax) - numpy.log(self._vmin)) else: raise ValueError('Norm is not recognized') diff --git a/silx/gui/plot/Colormap.py b/silx/gui/plot/Colormap.py index 9adf0d4..e797d89 100644 --- a/silx/gui/plot/Colormap.py +++ b/silx/gui/plot/Colormap.py @@ -22,568 +22,23 @@ # THE SOFTWARE. # # ###########################################################################*/ -"""This module provides the Colormap object +"""Deprecated module providing the Colormap object """ from __future__ import absolute_import __authors__ = ["T. Vincent", "H.Payno"] __license__ = "MIT" -__date__ = "08/01/2018" +__date__ = "24/04/2018" -from silx.gui import qt -import copy as copy_mdl -import numpy -from .matplotlib import Colormap as MPLColormap -import logging -from silx.math.combo import min_max -from silx.utils.exceptions import NotEditableError +import silx.utils.deprecation -_logger = logging.getLogger(__file__) +silx.utils.deprecation.deprecated_warning("Module", + name="silx.gui.plot.Colormap", + reason="moved", + replacement="silx.gui.colors.Colormap", + since_version="0.8.0", + only_once=True, + skip_backtrace_count=1) -DEFAULT_COLORMAPS = ( - 'gray', 'reversed gray', 'temperature', 'red', 'green', 'blue') -"""Tuple of supported colormap names.""" - -DEFAULT_MIN_LIN = 0 -"""Default min value if in linear normalization""" -DEFAULT_MAX_LIN = 1 -"""Default max value if in linear normalization""" -DEFAULT_MIN_LOG = 1 -"""Default min value if in log normalization""" -DEFAULT_MAX_LOG = 10 -"""Default max value if in log normalization""" - - -class Colormap(qt.QObject): - """Description of a colormap - - :param str name: Name of the colormap - :param tuple colors: optional, custom colormap. - Nx3 or Nx4 numpy array of RGB(A) colors, - either uint8 or float in [0, 1]. - If 'name' is None, then this array is used as the colormap. - :param str normalization: Normalization: 'linear' (default) or 'log' - :param float vmin: - Lower bound of the colormap or None for autoscale (default) - :param float vmax: - Upper bounds of the colormap or None for autoscale (default) - """ - - LINEAR = 'linear' - """constant for linear normalization""" - - LOGARITHM = 'log' - """constant for logarithmic normalization""" - - NORMALIZATIONS = (LINEAR, LOGARITHM) - """Tuple of managed normalizations""" - - sigChanged = qt.Signal() - """Signal emitted when the colormap has changed.""" - - def __init__(self, name='gray', colors=None, normalization=LINEAR, vmin=None, vmax=None): - qt.QObject.__init__(self) - assert normalization in Colormap.NORMALIZATIONS - assert not (name is None and colors is None) - if normalization is Colormap.LOGARITHM: - if (vmin is not None and vmin < 0) or (vmax is not None and vmax < 0): - m = "Unsuported vmin (%s) and/or vmax (%s) given for a log scale." - m += ' Autoscale will be performed.' - m = m % (vmin, vmax) - _logger.warning(m) - vmin = None - vmax = None - - self._name = str(name) if name is not None else None - self._setColors(colors) - self._normalization = str(normalization) - self._vmin = float(vmin) if vmin is not None else None - self._vmax = float(vmax) if vmax is not None else None - self._editable = True - - def isAutoscale(self): - """Return True if both min and max are in autoscale mode""" - return self._vmin is None and self._vmax is None - - def getName(self): - """Return the name of the colormap - :rtype: str - """ - return self._name - - def _setColors(self, colors): - if colors is None: - self._colors = None - else: - self._colors = numpy.array(colors, copy=True) - - def getNColors(self, nbColors=None): - """Returns N colors computed by sampling the colormap regularly. - - :param nbColors: - The number of colors in the returned array or None for the default value. - The default value is 256 for colormap with a name (see :meth:`setName`) and - it is the size of the LUT for colormap defined with :meth:`setColormapLUT`. - :type nbColors: int or None - :return: 2D array of uint8 of shape (nbColors, 4) - :rtype: numpy.ndarray - """ - # Handle default value for nbColors - if nbColors is None: - lut = self.getColormapLUT() - if lut is not None: # In this case uses LUT length - nbColors = len(lut) - else: # Default to 256 - nbColors = 256 - - nbColors = int(nbColors) - - colormap = self.copy() - colormap.setNormalization(Colormap.LINEAR) - colormap.setVRange(vmin=None, vmax=None) - colors = colormap.applyToData( - numpy.arange(nbColors, dtype=numpy.int)) - return colors - - def setName(self, name): - """Set the name of the colormap to use. - - :param str name: The name of the colormap. - At least the following names are supported: 'gray', - 'reversed gray', 'temperature', 'red', 'green', 'blue', 'jet', - 'viridis', 'magma', 'inferno', 'plasma'. - """ - if self.isEditable() is False: - raise NotEditableError('Colormap is not editable') - assert name in self.getSupportedColormaps() - self._name = str(name) - self._colors = None - self.sigChanged.emit() - - def getColormapLUT(self): - """Return the list of colors for the colormap or None if not set - - :return: the list of colors for the colormap or None if not set - :rtype: numpy.ndarray or None - """ - if self._colors is None: - return None - else: - return numpy.array(self._colors, copy=True) - - def setColormapLUT(self, colors): - """Set the colors of the colormap. - - :param numpy.ndarray colors: the colors of the LUT - - .. warning: this will set the value of name to None - """ - if self.isEditable() is False: - raise NotEditableError('Colormap is not editable') - self._setColors(colors) - if len(colors) is 0: - self._colors = None - - self._name = None - self.sigChanged.emit() - - def getNormalization(self): - """Return the normalization of the colormap ('log' or 'linear') - - :return: the normalization of the colormap - :rtype: str - """ - return self._normalization - - def setNormalization(self, norm): - """Set the norm ('log', 'linear') - - :param str norm: the norm to set - """ - if self.isEditable() is False: - raise NotEditableError('Colormap is not editable') - self._normalization = str(norm) - self.sigChanged.emit() - - def getVMin(self): - """Return the lower bound of the colormap - - :return: the lower bound of the colormap - :rtype: float or None - """ - return self._vmin - - def setVMin(self, vmin): - """Set the minimal value of the colormap - - :param float vmin: Lower bound of the colormap or None for autoscale - (default) - value) - """ - if self.isEditable() is False: - raise NotEditableError('Colormap is not editable') - if vmin is not None: - if self._vmax is not None and vmin > self._vmax: - err = "Can't set vmin because vmin >= vmax. " \ - "vmin = %s, vmax = %s" % (vmin, self._vmax) - raise ValueError(err) - - self._vmin = vmin - self.sigChanged.emit() - - def getVMax(self): - """Return the upper bounds of the colormap or None - - :return: the upper bounds of the colormap or None - :rtype: float or None - """ - return self._vmax - - def setVMax(self, vmax): - """Set the maximal value of the colormap - - :param float vmax: Upper bounds of the colormap or None for autoscale - (default) - """ - if self.isEditable() is False: - raise NotEditableError('Colormap is not editable') - if vmax is not None: - if self._vmin is not None and vmax < self._vmin: - err = "Can't set vmax because vmax <= vmin. " \ - "vmin = %s, vmax = %s" % (self._vmin, vmax) - raise ValueError(err) - - self._vmax = vmax - self.sigChanged.emit() - - def isEditable(self): - """ Return if the colormap is editable or not - - :return: editable state of the colormap - :rtype: bool - """ - return self._editable - - def setEditable(self, editable): - """ - Set the editable state of the colormap - - :param bool editable: is the colormap editable - """ - assert type(editable) is bool - self._editable = editable - self.sigChanged.emit() - - def getColormapRange(self, data=None): - """Return (vmin, vmax) - - :return: the tuple vmin, vmax fitting vmin, vmax, normalization and - data if any given - :rtype: tuple - """ - vmin = self._vmin - vmax = self._vmax - assert vmin is None or vmax is None or vmin <= vmax # TODO handle this in setters - - if self.getNormalization() == self.LOGARITHM: - # Handle negative bounds as autoscale - if vmin is not None and (vmin is not None and vmin <= 0.): - mess = 'negative vmin, moving to autoscale for lower bound' - _logger.warning(mess) - vmin = None - if vmax is not None and (vmax is not None and vmax <= 0.): - mess = 'negative vmax, moving to autoscale for upper bound' - _logger.warning(mess) - vmax = None - - if vmin is None or vmax is None: # Handle autoscale - # Get min/max from data - if data is not None: - data = numpy.array(data, copy=False) - if data.size == 0: # Fallback an array but no data - min_, max_ = self._getDefaultMin(), self._getDefaultMax() - else: - if self.getNormalization() == self.LOGARITHM: - result = min_max(data, min_positive=True, finite=True) - min_ = result.min_positive # >0 or None - max_ = result.maximum # can be <= 0 - else: - min_, max_ = min_max(data, min_positive=False, finite=True) - - # Handle fallback - if min_ is None or not numpy.isfinite(min_): - min_ = self._getDefaultMin() - if max_ is None or not numpy.isfinite(max_): - max_ = self._getDefaultMax() - else: # Fallback if no data is provided - min_, max_ = self._getDefaultMin(), self._getDefaultMax() - - if vmin is None: # Set vmin respecting provided vmax - vmin = min_ if vmax is None else min(min_, vmax) - - if vmax is None: - vmax = max(max_, vmin) # Handle max_ <= 0 for log scale - - return vmin, vmax - - def setVRange(self, vmin, vmax): - """Set the bounds of the colormap - - :param vmin: Lower bound of the colormap or None for autoscale - (default) - :param vmax: Upper bounds of the colormap or None for autoscale - (default) - """ - if self.isEditable() is False: - raise NotEditableError('Colormap is not editable') - if vmin is not None and vmax is not None: - if vmin > vmax: - err = "Can't set vmin and vmax because vmin >= vmax " \ - "vmin = %s, vmax = %s" % (vmin, vmax) - raise ValueError(err) - - if self._vmin == vmin and self._vmax == vmax: - return - - self._vmin = vmin - self._vmax = vmax - self.sigChanged.emit() - - def __getitem__(self, item): - if item == 'autoscale': - return self.isAutoscale() - elif item == 'name': - return self.getName() - elif item == 'normalization': - return self.getNormalization() - elif item == 'vmin': - return self.getVMin() - elif item == 'vmax': - return self.getVMax() - elif item == 'colors': - return self.getColormapLUT() - else: - raise KeyError(item) - - def _toDict(self): - """Return the equivalent colormap as a dictionary - (old colormap representation) - - :return: the representation of the Colormap as a dictionary - :rtype: dict - """ - return { - 'name': self._name, - 'colors': copy_mdl.copy(self._colors), - 'vmin': self._vmin, - 'vmax': self._vmax, - 'autoscale': self.isAutoscale(), - 'normalization': self._normalization - } - - def _setFromDict(self, dic): - """Set values to the colormap from a dictionary - - :param dict dic: the colormap as a dictionary - """ - if self.isEditable() is False: - raise NotEditableError('Colormap is not editable') - name = dic['name'] if 'name' in dic else None - colors = dic['colors'] if 'colors' in dic else None - vmin = dic['vmin'] if 'vmin' in dic else None - vmax = dic['vmax'] if 'vmax' in dic else None - if 'normalization' in dic: - normalization = dic['normalization'] - else: - warn = 'Normalization not given in the dictionary, ' - warn += 'set by default to ' + Colormap.LINEAR - _logger.warning(warn) - normalization = Colormap.LINEAR - - if name is None and colors is None: - err = 'The colormap should have a name defined or a tuple of colors' - raise ValueError(err) - if normalization not in Colormap.NORMALIZATIONS: - err = 'Given normalization is not recoginized (%s)' % normalization - raise ValueError(err) - - # If autoscale, then set boundaries to None - if dic.get('autoscale', False): - vmin, vmax = None, None - - self._name = name - self._colors = colors - self._vmin = vmin - self._vmax = vmax - self._autoscale = True if (vmin is None and vmax is None) else False - self._normalization = normalization - - self.sigChanged.emit() - - @staticmethod - def _fromDict(dic): - colormap = Colormap(name="") - colormap._setFromDict(dic) - return colormap - - def copy(self): - """Return a copy of the Colormap. - - :rtype: silx.gui.plot.Colormap.Colormap - """ - return Colormap(name=self._name, - colors=copy_mdl.copy(self._colors), - vmin=self._vmin, - vmax=self._vmax, - normalization=self._normalization) - - def applyToData(self, data): - """Apply the colormap to the data - - :param numpy.ndarray data: The data to convert. - """ - rgbaImage = MPLColormap.applyColormapToData(colormap=self, data=data) - return rgbaImage - - @staticmethod - def getSupportedColormaps(): - """Get the supported colormap names as a tuple of str. - - The list should at least contain and start by: - ('gray', 'reversed gray', 'temperature', 'red', 'green', 'blue') - :rtype: tuple - """ - maps = MPLColormap.getSupportedColormaps() - return DEFAULT_COLORMAPS + maps - - def __str__(self): - return str(self._toDict()) - - def _getDefaultMin(self): - return DEFAULT_MIN_LIN if self._normalization == Colormap.LINEAR else DEFAULT_MIN_LOG - - def _getDefaultMax(self): - return DEFAULT_MAX_LIN if self._normalization == Colormap.LINEAR else DEFAULT_MAX_LOG - - def __eq__(self, other): - """Compare colormap values and not pointers""" - return (self.getName() == other.getName() and - self.getNormalization() == other.getNormalization() and - self.getVMin() == other.getVMin() and - self.getVMax() == other.getVMax() and - numpy.array_equal(self.getColormapLUT(), other.getColormapLUT()) - ) - - _SERIAL_VERSION = 1 - - def restoreState(self, byteArray): - """ - Read the colormap state from a QByteArray. - - :param qt.QByteArray byteArray: Stream containing the state - :return: True if the restoration sussseed - :rtype: bool - """ - if self.isEditable() is False: - raise NotEditableError('Colormap is not editable') - stream = qt.QDataStream(byteArray, qt.QIODevice.ReadOnly) - - className = stream.readQString() - if className != self.__class__.__name__: - _logger.warning("Classname mismatch. Found %s." % className) - return False - - version = stream.readUInt32() - if version != self._SERIAL_VERSION: - _logger.warning("Serial version mismatch. Found %d." % version) - return False - - name = stream.readQString() - isNull = stream.readBool() - if not isNull: - vmin = stream.readQVariant() - else: - vmin = None - isNull = stream.readBool() - if not isNull: - vmax = stream.readQVariant() - else: - vmax = None - normalization = stream.readQString() - - # emit change event only once - old = self.blockSignals(True) - try: - self.setName(name) - self.setNormalization(normalization) - self.setVRange(vmin, vmax) - finally: - self.blockSignals(old) - self.sigChanged.emit() - return True - - def saveState(self): - """ - Save state of the colomap into a QDataStream. - - :rtype: qt.QByteArray - """ - data = qt.QByteArray() - stream = qt.QDataStream(data, qt.QIODevice.WriteOnly) - - stream.writeQString(self.__class__.__name__) - stream.writeUInt32(self._SERIAL_VERSION) - stream.writeQString(self.getName()) - stream.writeBool(self.getVMin() is None) - if self.getVMin() is not None: - stream.writeQVariant(self.getVMin()) - stream.writeBool(self.getVMax() is None) - if self.getVMax() is not None: - stream.writeQVariant(self.getVMax()) - stream.writeQString(self.getNormalization()) - return data - - -_PREFERRED_COLORMAPS = DEFAULT_COLORMAPS -""" -Tuple of preferred colormap names accessed with :meth:`preferredColormaps`. -""" - - -def preferredColormaps(): - """Returns the name of the preferred colormaps. - - This list is used by widgets allowing to change the colormap - like the :class:`ColormapDialog` as a subset of colormap choices. - - :rtype: tuple of str - """ - return _PREFERRED_COLORMAPS - - -def setPreferredColormaps(colormaps): - """Set the list of preferred colormap names. - - Warning: If a colormap name is not available - it will be removed from the list. - - :param colormaps: Not empty list of colormap names - :type colormaps: iterable of str - :raise ValueError: if the list of available preferred colormaps is empty. - """ - supportedColormaps = Colormap.getSupportedColormaps() - colormaps = tuple( - cmap for cmap in colormaps if cmap in supportedColormaps) - if len(colormaps) == 0: - raise ValueError("Cannot set preferred colormaps to an empty list") - - global _PREFERRED_COLORMAPS - _PREFERRED_COLORMAPS = colormaps - - -# Initialize preferred colormaps -setPreferredColormaps(('gray', 'reversed gray', - 'temperature', 'red', 'green', 'blue', 'jet', - 'viridis', 'magma', 'inferno', 'plasma', - 'hsv')) +from ..colors import * # noqa diff --git a/silx/gui/plot/ColormapDialog.py b/silx/gui/plot/ColormapDialog.py index 4aefab6..7c66cb8 100644 --- a/silx/gui/plot/ColormapDialog.py +++ b/silx/gui/plot/ColormapDialog.py @@ -22,960 +22,22 @@ # THE SOFTWARE. # # ###########################################################################*/ -"""A QDialog widget to set-up the colormap. +"""Deprecated module providing ColormapDialog.""" -It uses a description of colormaps as dict compatible with :class:`Plot`. +from __future__ import absolute_import -To run the following sample code, a QApplication must be initialized. - -Create the colormap dialog and set the colormap description and data range: - ->>> from silx.gui.plot.ColormapDialog import ColormapDialog ->>> from silx.gui.plot.Colormap import Colormap - ->>> dialog = ColormapDialog() ->>> colormap = Colormap(name='red', normalization='log', -... vmin=1., vmax=2.) - ->>> dialog.setColormap(colormap) ->>> colormap.setVRange(1., 100.) # This scale the width of the plot area ->>> dialog.show() - -Get the colormap description (compatible with :class:`Plot`) from the dialog: - ->>> cmap = dialog.getColormap() ->>> cmap.getName() -'red' - -It is also possible to display an histogram of the image in the dialog. -This updates the data range with the range of the bins. - ->>> import numpy ->>> image = numpy.random.normal(size=512 * 512).reshape(512, -1) ->>> hist, bin_edges = numpy.histogram(image, bins=10) ->>> dialog.setHistogram(hist, bin_edges) - -The updates of the colormap description are also available through the signal: -:attr:`ColormapDialog.sigColormapChanged`. -""" # noqa - -from __future__ import division - -__authors__ = ["V.A. Sole", "T. Vincent", "H. Payno"] +__authors__ = ["T. Vincent", "H.Payno"] __license__ = "MIT" -__date__ = "09/02/2018" - - -import logging - -import numpy - -from .. import qt -from .Colormap import Colormap, preferredColormaps -from . import PlotWidget -from silx.gui.widgets.FloatEdit import FloatEdit -import weakref -from silx.math.combo import min_max -from silx.third_party import enum -from silx.gui import icons -from silx.math.histogram import Histogramnd - -_logger = logging.getLogger(__name__) - - -_colormapIconPreview = {} - - -class _BoundaryWidget(qt.QWidget): - """Widget to edit a boundary of the colormap (vmin, vmax)""" - sigValueChanged = qt.Signal(object) - """Signal emitted when value is changed""" - - def __init__(self, parent=None, value=0.0): - qt.QWidget.__init__(self, parent=None) - self.setLayout(qt.QHBoxLayout()) - self.layout().setContentsMargins(0, 0, 0, 0) - self._numVal = FloatEdit(parent=self, value=value) - self.layout().addWidget(self._numVal) - self._autoCB = qt.QCheckBox('auto', parent=self) - self.layout().addWidget(self._autoCB) - self._autoCB.setChecked(False) - - self._autoCB.toggled.connect(self._autoToggled) - self.sigValueChanged = self._autoCB.toggled - self.textEdited = self._numVal.textEdited - self.editingFinished = self._numVal.editingFinished - self._dataValue = None - - def isAutoChecked(self): - return self._autoCB.isChecked() - - def getValue(self): - return None if self._autoCB.isChecked() else self._numVal.value() - - def getFiniteValue(self): - if not self._autoCB.isChecked(): - return self._numVal.value() - elif self._dataValue is None: - return self._numVal.value() - else: - return self._dataValue - - def _autoToggled(self, enabled): - self._numVal.setEnabled(not enabled) - self._updateDisplayedText() - - def _updateDisplayedText(self): - # if dataValue is finite - if self._autoCB.isChecked() and self._dataValue is not None: - old = self._numVal.blockSignals(True) - self._numVal.setValue(self._dataValue) - self._numVal.blockSignals(old) - - def setDataValue(self, dataValue): - self._dataValue = dataValue - self._updateDisplayedText() - - def setFiniteValue(self, value): - assert(value is not None) - old = self._numVal.blockSignals(True) - self._numVal.setValue(value) - self._numVal.blockSignals(old) - - def setValue(self, value, isAuto=False): - self._autoCB.setChecked(isAuto or value is None) - if value is not None: - self._numVal.setValue(value) - self._updateDisplayedText() - - -class _ColormapNameCombox(qt.QComboBox): - def __init__(self, parent=None): - qt.QComboBox.__init__(self, parent) - self.__initItems() - - ORIGINAL_NAME = qt.Qt.UserRole + 1 - - def __initItems(self): - for colormapName in preferredColormaps(): - index = self.count() - self.addItem(str.title(colormapName)) - self.setItemIcon(index, self.getIconPreview(colormapName)) - self.setItemData(index, colormapName, role=self.ORIGINAL_NAME) - - def getIconPreview(self, colormapName): - """Return an icon preview from a LUT name. - - This icons are cached into a global structure. - - :param str colormapName: str - :rtype: qt.QIcon - """ - if colormapName not in _colormapIconPreview: - icon = self.createIconPreview(colormapName) - _colormapIconPreview[colormapName] = icon - return _colormapIconPreview[colormapName] - - def createIconPreview(self, colormapName): - """Create and return an icon preview from a LUT name. - - This icons are cached into a global structure. - - :param str colormapName: Name of the LUT - :rtype: qt.QIcon - """ - colormap = Colormap(colormapName) - size = 32 - lut = colormap.getNColors(size) - if lut is None or len(lut) == 0: - return qt.QIcon() - - pixmap = qt.QPixmap(size, size) - painter = qt.QPainter(pixmap) - for i in range(size): - rgb = lut[i] - r, g, b = rgb[0], rgb[1], rgb[2] - painter.setPen(qt.QColor(r, g, b)) - painter.drawPoint(qt.QPoint(i, 0)) - - painter.drawPixmap(0, 1, size, size - 1, pixmap, 0, 0, size, 1) - painter.end() - - return qt.QIcon(pixmap) - - def getCurrentName(self): - return self.itemData(self.currentIndex(), self.ORIGINAL_NAME) - - def findColormap(self, name): - return self.findData(name, role=self.ORIGINAL_NAME) - - def setCurrentName(self, name): - index = self.findColormap(name) - if index < 0: - index = self.count() - self.addItem(str.title(name)) - self.setItemIcon(index, self.getIconPreview(name)) - self.setItemData(index, name, role=self.ORIGINAL_NAME) - self.setCurrentIndex(index) - - -@enum.unique -class _DataInPlotMode(enum.Enum): - """Enum for each mode of display of the data in the plot.""" - NONE = 'none' - RANGE = 'range' - HISTOGRAM = 'histogram' - - -class ColormapDialog(qt.QDialog): - """A QDialog widget to set the colormap. - - :param parent: See :class:`QDialog` - :param str title: The QDialog title - """ - - visibleChanged = qt.Signal(bool) - """This event is sent when the dialog visibility change""" - - def __init__(self, parent=None, title="Colormap Dialog"): - qt.QDialog.__init__(self, parent) - self.setWindowTitle(title) - - self._colormap = None - self._data = None - self._dataInPlotMode = _DataInPlotMode.RANGE - - self._ignoreColormapChange = False - """Used as a semaphore to avoid editing the colormap object when we are - only attempt to display it. - Used instead of n connect and disconnect of the sigChanged. The - disconnection to sigChanged was also limiting when this colormapdialog - is used in the colormapaction and associated to the activeImageChanged. - (because the activeImageChanged is send when the colormap changed and - the self.setcolormap is a callback) - """ - - self._histogramData = None - self._minMaxWasEdited = False - self._initialRange = None - - self._dataRange = None - """If defined 3-tuple containing information from a data: - minimum, positive minimum, maximum""" - - self._colormapStoredState = None - - # Make the GUI - vLayout = qt.QVBoxLayout(self) - - formWidget = qt.QWidget(parent=self) - vLayout.addWidget(formWidget) - formLayout = qt.QFormLayout(formWidget) - formLayout.setContentsMargins(10, 10, 10, 10) - formLayout.setSpacing(0) - - # Colormap row - self._comboBoxColormap = _ColormapNameCombox(parent=formWidget) - self._comboBoxColormap.currentIndexChanged[int].connect(self._updateName) - formLayout.addRow('Colormap:', self._comboBoxColormap) - - # Normalization row - self._normButtonLinear = qt.QRadioButton('Linear') - self._normButtonLinear.setChecked(True) - self._normButtonLog = qt.QRadioButton('Log') - self._normButtonLog.toggled.connect(self._activeLogNorm) - - normButtonGroup = qt.QButtonGroup(self) - normButtonGroup.setExclusive(True) - normButtonGroup.addButton(self._normButtonLinear) - normButtonGroup.addButton(self._normButtonLog) - self._normButtonLinear.toggled[bool].connect(self._updateLinearNorm) - - normLayout = qt.QHBoxLayout() - normLayout.setContentsMargins(0, 0, 0, 0) - normLayout.setSpacing(10) - normLayout.addWidget(self._normButtonLinear) - normLayout.addWidget(self._normButtonLog) - - formLayout.addRow('Normalization:', normLayout) - - # Min row - self._minValue = _BoundaryWidget(parent=self, value=1.0) - self._minValue.textEdited.connect(self._minMaxTextEdited) - self._minValue.editingFinished.connect(self._minEditingFinished) - self._minValue.sigValueChanged.connect(self._updateMinMax) - formLayout.addRow('\tMin:', self._minValue) - - # Max row - self._maxValue = _BoundaryWidget(parent=self, value=10.0) - self._maxValue.textEdited.connect(self._minMaxTextEdited) - self._maxValue.sigValueChanged.connect(self._updateMinMax) - self._maxValue.editingFinished.connect(self._maxEditingFinished) - formLayout.addRow('\tMax:', self._maxValue) - - # Add plot for histogram - self._plotToolbar = qt.QToolBar(self) - self._plotToolbar.setFloatable(False) - self._plotToolbar.setMovable(False) - self._plotToolbar.setIconSize(qt.QSize(8, 8)) - self._plotToolbar.setStyleSheet("QToolBar { border: 0px }") - self._plotToolbar.setOrientation(qt.Qt.Vertical) - - group = qt.QActionGroup(self._plotToolbar) - group.setExclusive(True) - - action = qt.QAction("Nothing", self) - action.setToolTip("No range nor histogram are displayed. No extra computation have to be done.") - action.setIcon(icons.getQIcon('colormap-none')) - action.setCheckable(True) - action.setData(_DataInPlotMode.NONE) - action.setChecked(action.data() == self._dataInPlotMode) - self._plotToolbar.addAction(action) - group.addAction(action) - action = qt.QAction("Data range", self) - action.setToolTip("Display the data range within the colormap range. A fast data processing have to be done.") - action.setIcon(icons.getQIcon('colormap-range')) - action.setCheckable(True) - action.setData(_DataInPlotMode.RANGE) - action.setChecked(action.data() == self._dataInPlotMode) - self._plotToolbar.addAction(action) - group.addAction(action) - action = qt.QAction("Histogram", self) - action.setToolTip("Display the data histogram within the colormap range. A slow data processing have to be done. ") - action.setIcon(icons.getQIcon('colormap-histogram')) - action.setCheckable(True) - action.setData(_DataInPlotMode.HISTOGRAM) - action.setChecked(action.data() == self._dataInPlotMode) - self._plotToolbar.addAction(action) - group.addAction(action) - group.triggered.connect(self._displayDataInPlotModeChanged) - - self._plotBox = qt.QWidget(self) - self._plotInit() - - plotBoxLayout = qt.QHBoxLayout() - plotBoxLayout.setContentsMargins(0, 0, 0, 0) - plotBoxLayout.setSpacing(2) - plotBoxLayout.addWidget(self._plotToolbar) - plotBoxLayout.addWidget(self._plot) - plotBoxLayout.setSizeConstraint(qt.QLayout.SetMinimumSize) - self._plotBox.setLayout(plotBoxLayout) - vLayout.addWidget(self._plotBox) - - # define modal buttons - types = qt.QDialogButtonBox.Ok | qt.QDialogButtonBox.Cancel - self._buttonsModal = qt.QDialogButtonBox(parent=self) - self._buttonsModal.setStandardButtons(types) - self.layout().addWidget(self._buttonsModal) - self._buttonsModal.accepted.connect(self.accept) - self._buttonsModal.rejected.connect(self.reject) - - # define non modal buttons - types = qt.QDialogButtonBox.Close | qt.QDialogButtonBox.Reset - self._buttonsNonModal = qt.QDialogButtonBox(parent=self) - self._buttonsNonModal.setStandardButtons(types) - self.layout().addWidget(self._buttonsNonModal) - self._buttonsNonModal.button(qt.QDialogButtonBox.Close).clicked.connect(self.accept) - self._buttonsNonModal.button(qt.QDialogButtonBox.Reset).clicked.connect(self.resetColormap) - - # Set the colormap to default values - self.setColormap(Colormap(name='gray', normalization='linear', - vmin=None, vmax=None)) - - self.setModal(self.isModal()) - - vLayout.setSizeConstraint(qt.QLayout.SetMinimumSize) - self.setFixedSize(self.sizeHint()) - self._applyColormap() - - def showEvent(self, event): - self.visibleChanged.emit(True) - super(ColormapDialog, self).showEvent(event) - - def closeEvent(self, event): - if not self.isModal(): - self.accept() - super(ColormapDialog, self).closeEvent(event) - - def hideEvent(self, event): - self.visibleChanged.emit(False) - super(ColormapDialog, self).hideEvent(event) - - def close(self): - self.accept() - qt.QDialog.close(self) - - def setModal(self, modal): - assert type(modal) is bool - self._buttonsNonModal.setVisible(not modal) - self._buttonsModal.setVisible(modal) - qt.QDialog.setModal(self, modal) - - def exec_(self): - wasModal = self.isModal() - self.setModal(True) - result = super(ColormapDialog, self).exec_() - self.setModal(wasModal) - return result - - def _plotInit(self): - """Init the plot to display the range and the values""" - self._plot = PlotWidget() - self._plot.setDataMargins(yMinMargin=0.125, yMaxMargin=0.125) - self._plot.getXAxis().setLabel("Data Values") - self._plot.getYAxis().setLabel("") - self._plot.setInteractiveMode('select', zoomOnWheel=False) - self._plot.setActiveCurveHandling(False) - self._plot.setMinimumSize(qt.QSize(250, 200)) - self._plot.sigPlotSignal.connect(self._plotSlot) - - self._plotUpdate() - - def sizeHint(self): - return self.layout().minimumSize() - - def _plotUpdate(self, updateMarkers=True): - """Update the plot content - - :param bool updateMarkers: True to update markers, False otherwith - """ - colormap = self.getColormap() - if colormap is None: - if self._plotBox.isVisibleTo(self): - self._plotBox.setVisible(False) - self.setFixedSize(self.sizeHint()) - return - - if not self._plotBox.isVisibleTo(self): - self._plotBox.setVisible(True) - self.setFixedSize(self.sizeHint()) - - minData, maxData = self._minValue.getFiniteValue(), self._maxValue.getFiniteValue() - if minData > maxData: - # avoid a full collapse - minData, maxData = maxData, minData - minimum = minData - maximum = maxData - - if self._dataRange is not None: - minRange = self._dataRange[0] - maxRange = self._dataRange[2] - minimum = min(minimum, minRange) - maximum = max(maximum, maxRange) - - if self._histogramData is not None: - minHisto = self._histogramData[1][0] - maxHisto = self._histogramData[1][-1] - minimum = min(minimum, minHisto) - maximum = max(maximum, maxHisto) - - marge = abs(maximum - minimum) / 6.0 - if marge < 0.0001: - # Smaller that the QLineEdit precision - marge = 0.0001 - - minView, maxView = minimum - marge, maximum + marge - - if updateMarkers: - # Save the state in we are not moving the markers - self._initialRange = minView, maxView - elif self._initialRange is not None: - minView = min(minView, self._initialRange[0]) - maxView = max(maxView, self._initialRange[1]) - - x = [minView, minData, maxData, maxView] - y = [0, 0, 1, 1] - - self._plot.addCurve(x, y, - legend="ConstrainedCurve", - color='black', - symbol='o', - linestyle='-', - resetzoom=False) - - if updateMarkers: - minDraggable = (self._colormap().isEditable() and - not self._minValue.isAutoChecked()) - self._plot.addXMarker( - self._minValue.getFiniteValue(), - legend='Min', - text='Min', - draggable=minDraggable, - color='blue', - constraint=self._plotMinMarkerConstraint) - - maxDraggable = (self._colormap().isEditable() and - not self._maxValue.isAutoChecked()) - self._plot.addXMarker( - self._maxValue.getFiniteValue(), - legend='Max', - text='Max', - draggable=maxDraggable, - color='blue', - constraint=self._plotMaxMarkerConstraint) - - self._plot.resetZoom() - - def _plotMinMarkerConstraint(self, x, y): - """Constraint of the min marker""" - return min(x, self._maxValue.getFiniteValue()), y - - def _plotMaxMarkerConstraint(self, x, y): - """Constraint of the max marker""" - return max(x, self._minValue.getFiniteValue()), y - - def _plotSlot(self, event): - """Handle events from the plot""" - if event['event'] in ('markerMoving', 'markerMoved'): - value = float(str(event['xdata'])) - if event['label'] == 'Min': - self._minValue.setValue(value) - elif event['label'] == 'Max': - self._maxValue.setValue(value) - - # This will recreate the markers while interacting... - # It might break if marker interaction is changed - if event['event'] == 'markerMoved': - self._initialRange = None - self._updateMinMax() - else: - self._plotUpdate(updateMarkers=False) - - @staticmethod - def computeDataRange(data): - """Compute the data range as used by :meth:`setDataRange`. - - :param data: The data to process - :rtype: Tuple(float, float, float) - """ - if data is None or len(data) == 0: - return None, None, None - - dataRange = min_max(data, min_positive=True, finite=True) - if dataRange.minimum is None: - # Only non-finite data - dataRange = None - - if dataRange is not None: - min_positive = dataRange.min_positive - if min_positive is None: - min_positive = float('nan') - dataRange = dataRange.minimum, min_positive, dataRange.maximum - - if dataRange is None or len(dataRange) != 3: - qt.QMessageBox.warning( - None, "No Data", - "Image data does not contain any real value") - dataRange = 1., 1., 10. - - return dataRange - - @staticmethod - def computeHistogram(data): - """Compute the data histogram as used by :meth:`setHistogram`. - - :param data: The data to process - :rtype: Tuple(List(float),List(float) - """ - _data = data - if _data.ndim == 3: # RGB(A) images - _logger.info('Converting current image from RGB(A) to grayscale\ - in order to compute the intensity distribution') - _data = (_data[:, :, 0] * 0.299 + - _data[:, :, 1] * 0.587 + - _data[:, :, 2] * 0.114) - - if len(_data) == 0: - return None, None - - xmin, xmax = min_max(_data, min_positive=False, finite=True) - nbins = min(256, int(numpy.sqrt(_data.size))) - data_range = xmin, xmax - - # bad hack: get 256 bins in the case we have a B&W - if numpy.issubdtype(_data.dtype, numpy.integer): - if nbins > xmax - xmin: - nbins = xmax - xmin - - nbins = max(2, nbins) - _data = _data.ravel().astype(numpy.float32) - - histogram = Histogramnd(_data, n_bins=nbins, histo_range=data_range) - return histogram.histo, histogram.edges[0] - - def _getData(self): - if self._data is None: - return None - return self._data() - - def setData(self, data): - """Store the data as a weakref. - - According to the state of the dialog, the data will be used to display - the data range or the histogram of the data using :meth:`setDataRange` - and :meth:`setHistogram` - """ - oldData = self._getData() - if oldData is data: - return - - if data is None: - self.setDataRange() - self.setHistogram() - self._data = None - return - - self._data = weakref.ref(data, self._dataAboutToFinalize) - - self._updateDataInPlot() - - def _setDataInPlotMode(self, mode): - if self._dataInPlotMode == mode: - return - self._dataInPlotMode = mode - self._updateDataInPlot() - - def _displayDataInPlotModeChanged(self, action): - mode = action.data() - self._setDataInPlotMode(mode) - - def _updateDataInPlot(self): - data = self._getData() - if data is None: - return - - mode = self._dataInPlotMode - - if mode == _DataInPlotMode.NONE: - self.setHistogram() - self.setDataRange() - elif mode == _DataInPlotMode.RANGE: - result = self.computeDataRange(data) - self.setHistogram() - self.setDataRange(*result) - elif mode == _DataInPlotMode.HISTOGRAM: - # The histogram should be done in a worker thread - result = self.computeHistogram(data) - self.setHistogram(*result) - self.setDataRange() - - def _colormapAboutToFinalize(self, weakrefColormap): - """Callback when the data weakref is about to be finalized.""" - if self._colormap is weakrefColormap: - self.setColormap(None) - - def _dataAboutToFinalize(self, weakrefData): - """Callback when the data weakref is about to be finalized.""" - if self._data is weakrefData: - self.setData(None) - - def getHistogram(self): - """Returns the counts and bin edges of the displayed histogram. - - :return: (hist, bin_edges) - :rtype: 2-tuple of numpy arrays""" - if self._histogramData is None: - return None - else: - bins, counts = self._histogramData - return numpy.array(bins, copy=True), numpy.array(counts, copy=True) - - def setHistogram(self, hist=None, bin_edges=None): - """Set the histogram to display. - - This update the data range with the bounds of the bins. - - :param hist: array-like of counts or None to hide histogram - :param bin_edges: array-like of bins edges or None to hide histogram - """ - if hist is None or bin_edges is None: - self._histogramData = None - self._plot.remove(legend='Histogram', kind='histogram') - else: - hist = numpy.array(hist, copy=True) - bin_edges = numpy.array(bin_edges, copy=True) - self._histogramData = hist, bin_edges - norm_hist = hist / max(hist) - self._plot.addHistogram(norm_hist, - bin_edges, - legend="Histogram", - color='gray', - align='center', - fill=True) - self._updateMinMaxData() - - def getColormap(self): - """Return the colormap description as a :class:`.Colormap`. - - """ - if self._colormap is None: - return None - return self._colormap() - - def resetColormap(self): - """ - Reset the colormap state before modification. - - ..note :: the colormap reference state is the state when set or the - state when validated - """ - colormap = self.getColormap() - if colormap is not None and self._colormapStoredState is not None: - if self._colormap()._toDict() != self._colormapStoredState: - self._ignoreColormapChange = True - colormap._setFromDict(self._colormapStoredState) - self._ignoreColormapChange = False - self._applyColormap() - - def setDataRange(self, minimum=None, positiveMin=None, maximum=None): - """Set the range of data to use for the range of the histogram area. - - :param float minimum: The minimum of the data - :param float positiveMin: The positive minimum of the data - :param float maximum: The maximum of the data - """ - if minimum is None or positiveMin is None or maximum is None: - self._dataRange = None - self._plot.remove(legend='Range', kind='histogram') - else: - hist = numpy.array([1]) - bin_edges = numpy.array([minimum, maximum]) - self._plot.addHistogram(hist, - bin_edges, - legend="Range", - color='gray', - align='center', - fill=True) - self._dataRange = minimum, positiveMin, maximum - self._updateMinMaxData() - - def _updateMinMaxData(self): - """Update the min and max of the data according to the data range and - the histogram preset.""" - colormap = self.getColormap() - - minimum = float("+inf") - maximum = float("-inf") - - if colormap is not None and colormap.getNormalization() == colormap.LOGARITHM: - # find a range in the positive part of the data - if self._dataRange is not None: - minimum = min(minimum, self._dataRange[1]) - maximum = max(maximum, self._dataRange[2]) - if self._histogramData is not None: - positives = list(filter(lambda x: x > 0, self._histogramData[1])) - if len(positives) > 0: - minimum = min(minimum, positives[0]) - maximum = max(maximum, positives[-1]) - else: - if self._dataRange is not None: - minimum = min(minimum, self._dataRange[0]) - maximum = max(maximum, self._dataRange[2]) - if self._histogramData is not None: - minimum = min(minimum, self._histogramData[1][0]) - maximum = max(maximum, self._histogramData[1][-1]) - - if not numpy.isfinite(minimum): - minimum = None - if not numpy.isfinite(maximum): - maximum = None - - self._minValue.setDataValue(minimum) - self._maxValue.setDataValue(maximum) - self._plotUpdate() - - def accept(self): - self.storeCurrentState() - qt.QDialog.accept(self) - - def storeCurrentState(self): - """ - save the current value sof the colormap if the user want to undo is - modifications - """ - colormap = self.getColormap() - if colormap is not None: - self._colormapStoredState = colormap._toDict() - else: - self._colormapStoredState = None - - def reject(self): - self.resetColormap() - qt.QDialog.reject(self) - - def setColormap(self, colormap): - """Set the colormap description - - :param :class:`Colormap` colormap: the colormap to edit - """ - assert colormap is None or isinstance(colormap, Colormap) - if self._ignoreColormapChange is True: - return - - oldColormap = self.getColormap() - if oldColormap is colormap: - return - if oldColormap is not None: - oldColormap.sigChanged.disconnect(self._applyColormap) - - if colormap is not None: - colormap.sigChanged.connect(self._applyColormap) - colormap = weakref.ref(colormap, self._colormapAboutToFinalize) - - self._colormap = colormap - self.storeCurrentState() - self._updateResetButton() - self._applyColormap() - - def _updateResetButton(self): - resetButton = self._buttonsNonModal.button(qt.QDialogButtonBox.Reset) - rStateEnabled = False - colormap = self.getColormap() - if colormap is not None and colormap.isEditable(): - # can reset only in the case the colormap changed - rStateEnabled = colormap._toDict() != self._colormapStoredState - resetButton.setEnabled(rStateEnabled) - - def _applyColormap(self): - self._updateResetButton() - if self._ignoreColormapChange is True: - return - - colormap = self.getColormap() - if colormap is None: - self._comboBoxColormap.setEnabled(False) - self._normButtonLinear.setEnabled(False) - self._normButtonLog.setEnabled(False) - self._minValue.setEnabled(False) - self._maxValue.setEnabled(False) - else: - self._ignoreColormapChange = True - - if colormap.getName() is not None: - name = colormap.getName() - self._comboBoxColormap.setCurrentName(name) - self._comboBoxColormap.setEnabled(self._colormap().isEditable()) - - assert colormap.getNormalization() in Colormap.NORMALIZATIONS - self._normButtonLinear.setChecked( - colormap.getNormalization() == Colormap.LINEAR) - self._normButtonLog.setChecked( - colormap.getNormalization() == Colormap.LOGARITHM) - vmin = colormap.getVMin() - vmax = colormap.getVMax() - dataRange = colormap.getColormapRange() - self._normButtonLinear.setEnabled(self._colormap().isEditable()) - self._normButtonLog.setEnabled(self._colormap().isEditable()) - self._minValue.setValue(vmin or dataRange[0], isAuto=vmin is None) - self._maxValue.setValue(vmax or dataRange[1], isAuto=vmax is None) - self._minValue.setEnabled(self._colormap().isEditable()) - self._maxValue.setEnabled(self._colormap().isEditable()) - self._ignoreColormapChange = False - - self._plotUpdate() - - def _updateMinMax(self): - if self._ignoreColormapChange is True: - return - - vmin = self._minValue.getFiniteValue() - vmax = self._maxValue.getFiniteValue() - if vmax is not None and vmin is not None and vmax < vmin: - # If only one autoscale is checked constraints are too strong - # We have to edit a user value anyway it is not requested - # TODO: It would be better IMO to disable the auto checkbox before - # this case occur (valls) - cmin = self._minValue.isAutoChecked() - cmax = self._maxValue.isAutoChecked() - if cmin is False: - self._minValue.setFiniteValue(vmax) - if cmax is False: - self._maxValue.setFiniteValue(vmin) - - vmin = self._minValue.getValue() - vmax = self._maxValue.getValue() - self._ignoreColormapChange = True - colormap = self._colormap() - if colormap is not None: - colormap.setVRange(vmin, vmax) - self._ignoreColormapChange = False - self._plotUpdate() - self._updateResetButton() - - def _updateName(self): - if self._ignoreColormapChange is True: - return - - if self._colormap(): - self._ignoreColormapChange = True - self._colormap().setName( - self._comboBoxColormap.getCurrentName()) - self._ignoreColormapChange = False - - def _updateLinearNorm(self, isNormLinear): - if self._ignoreColormapChange is True: - return - - if self._colormap(): - self._ignoreColormapChange = True - norm = Colormap.LINEAR if isNormLinear else Colormap.LOGARITHM - self._colormap().setNormalization(norm) - self._ignoreColormapChange = False - - def _minMaxTextEdited(self, text): - """Handle _minValue and _maxValue textEdited signal""" - self._minMaxWasEdited = True - - def _minEditingFinished(self): - """Handle _minValue editingFinished signal - - Together with :meth:`_minMaxTextEdited`, this avoids to notify - colormap change when the min and max value where not edited. - """ - if self._minMaxWasEdited: - self._minMaxWasEdited = False - - # Fix start value - if (self._maxValue.getValue() is not None and - self._minValue.getValue() > self._maxValue.getValue()): - self._minValue.setValue(self._maxValue.getValue()) - self._updateMinMax() - - def _maxEditingFinished(self): - """Handle _maxValue editingFinished signal - - Together with :meth:`_minMaxTextEdited`, this avoids to notify - colormap change when the min and max value where not edited. - """ - if self._minMaxWasEdited: - self._minMaxWasEdited = False - - # Fix end value - if (self._minValue.getValue() is not None and - self._minValue.getValue() > self._maxValue.getValue()): - self._maxValue.setValue(self._minValue.getValue()) - self._updateMinMax() +__date__ = "24/04/2018" - def keyPressEvent(self, event): - """Override key handling. +import silx.utils.deprecation - It disables leaving the dialog when editing a text field. - """ - if event.key() == qt.Qt.Key_Enter and (self._minValue.hasFocus() or - self._maxValue.hasFocus()): - # Bypass QDialog keyPressEvent - # To avoid leaving the dialog when pressing enter on a text field - super(qt.QDialog, self).keyPressEvent(event) - else: - # Use QDialog keyPressEvent - super(ColormapDialog, self).keyPressEvent(event) +silx.utils.deprecation.deprecated_warning("Module", + name="silx.gui.plot.ColormapDialog", + reason="moved", + replacement="silx.gui.dialog.ColormapDialog", + since_version="0.8.0", + only_once=True, + skip_backtrace_count=1) - def _activeLogNorm(self, isLog): - if self._ignoreColormapChange is True: - return - if self._colormap(): - self._ignoreColormapChange = True - norm = Colormap.LOGARITHM if isLog is True else Colormap.LINEAR - self._colormap().setNormalization(norm) - self._ignoreColormapChange = False - self._updateMinMaxData() +from ..dialog.ColormapDialog import * # noqa diff --git a/silx/gui/plot/Colors.py b/silx/gui/plot/Colors.py index 2d44d4d..277e104 100644 --- a/silx/gui/plot/Colors.py +++ b/silx/gui/plot/Colors.py @@ -28,120 +28,22 @@ from __future__ import absolute_import __authors__ = ["V.A. Sole", "T. Vincent"] __license__ = "MIT" -__date__ = "15/05/2017" +__date__ = "14/06/2018" +import silx.utils.deprecation -from silx.utils.deprecation import deprecated -import logging -import numpy +silx.utils.deprecation.deprecated_warning("Module", + name="silx.gui.plot.Colors", + reason="moved", + replacement="silx.gui.colors", + since_version="0.8.0", + only_once=True, + skip_backtrace_count=1) -from .Colormap import Colormap +from ..colors import * # noqa -_logger = logging.getLogger(__name__) - - -COLORDICT = {} -"""Dictionary of common colors.""" - -COLORDICT['b'] = COLORDICT['blue'] = '#0000ff' -COLORDICT['r'] = COLORDICT['red'] = '#ff0000' -COLORDICT['g'] = COLORDICT['green'] = '#00ff00' -COLORDICT['k'] = COLORDICT['black'] = '#000000' -COLORDICT['w'] = COLORDICT['white'] = '#ffffff' -COLORDICT['pink'] = '#ff66ff' -COLORDICT['brown'] = '#a52a2a' -COLORDICT['orange'] = '#ff9900' -COLORDICT['violet'] = '#6600ff' -COLORDICT['gray'] = COLORDICT['grey'] = '#a0a0a4' -# COLORDICT['darkGray'] = COLORDICT['darkGrey'] = '#808080' -# COLORDICT['lightGray'] = COLORDICT['lightGrey'] = '#c0c0c0' -COLORDICT['y'] = COLORDICT['yellow'] = '#ffff00' -COLORDICT['m'] = COLORDICT['magenta'] = '#ff00ff' -COLORDICT['c'] = COLORDICT['cyan'] = '#00ffff' -COLORDICT['darkBlue'] = '#000080' -COLORDICT['darkRed'] = '#800000' -COLORDICT['darkGreen'] = '#008000' -COLORDICT['darkBrown'] = '#660000' -COLORDICT['darkCyan'] = '#008080' -COLORDICT['darkYellow'] = '#808000' -COLORDICT['darkMagenta'] = '#800080' - - -def rgba(color, colorDict=None): - """Convert color code '#RRGGBB' and '#RRGGBBAA' to (R, G, B, A) - - It also convert RGB(A) values from uint8 to float in [0, 1] and - accept a QColor as color argument. - - :param str color: The color to convert - :param dict colorDict: A dictionary of color name conversion to color code - :returns: RGBA colors as floats in [0., 1.] - :rtype: tuple - """ - if colorDict is None: - colorDict = COLORDICT - - if hasattr(color, 'getRgbF'): # QColor support - color = color.getRgbF() - - values = numpy.asarray(color).ravel() - - if values.dtype.kind in 'iuf': # integer or float - # Color is an array - assert len(values) in (3, 4) - - # Convert from integers in [0, 255] to float in [0, 1] - if values.dtype.kind in 'iu': - values = values / 255. - - # Clip to [0, 1] - values[values < 0.] = 0. - values[values > 1.] = 1. - - if len(values) == 3: - return values[0], values[1], values[2], 1. - else: - return tuple(values) - - # We assume color is a string - if not color.startswith('#'): - color = colorDict[color] - - assert len(color) in (7, 9) and color[0] == '#' - r = int(color[1:3], 16) / 255. - g = int(color[3:5], 16) / 255. - b = int(color[5:7], 16) / 255. - a = int(color[7:9], 16) / 255. if len(color) == 9 else 1. - return r, g, b, a - - -_COLORMAP_CURSOR_COLORS = { - 'gray': 'pink', - 'reversed gray': 'pink', - 'temperature': 'pink', - 'red': 'green', - 'green': 'pink', - 'blue': 'yellow', - 'jet': 'pink', - 'viridis': 'pink', - 'magma': 'green', - 'inferno': 'green', - 'plasma': 'green', -} - - -def cursorColorForColormap(colormapName): - """Get a color suitable for overlay over a colormap. - - :param str colormapName: The name of the colormap. - :return: Name of the color. - :rtype: str - """ - return _COLORMAP_CURSOR_COLORS.get(colormapName, 'black') - - -@deprecated(replacement='silx.gui.plot.Colormap.applyColormap') +@silx.utils.deprecation.deprecated(replacement='silx.gui.colors.Colormap.applyColormap') def applyColormapToData(data, name='gray', normalization='linear', @@ -178,7 +80,7 @@ def applyColormapToData(data, return colormap.applyToData(data) -@deprecated(replacement='silx.gui.plot.Colormap.getSupportedColormaps') +@silx.utils.deprecation.deprecated(replacement='silx.gui.colors.Colormap.getSupportedColormaps') def getSupportedColormaps(): """Get the supported colormap names as a tuple of str. diff --git a/silx/gui/plot/ComplexImageView.py b/silx/gui/plot/ComplexImageView.py index ebff175..bbcb0a5 100644 --- a/silx/gui/plot/ComplexImageView.py +++ b/silx/gui/plot/ComplexImageView.py @@ -32,7 +32,7 @@ from __future__ import absolute_import __authors__ = ["Vincent Favre-Nicolin", "T. Vincent"] __license__ = "MIT" -__date__ = "19/01/2018" +__date__ = "24/04/2018" import logging @@ -410,7 +410,7 @@ class ComplexImageView(qt.QWidget): WARNING: This colormap is not used when displaying both amplitude and phase. - :param ~silx.gui.plot.Colormap.Colormap colormap: The colormap + :param ~silx.gui.colors.Colormap colormap: The colormap :param Mode mode: If specified, set the colormap of this specific mode """ self._plotImage.setColormap(colormap, mode) @@ -419,7 +419,7 @@ class ComplexImageView(qt.QWidget): """Returns the colormap used to display the data. :param Mode mode: If specified, set the colormap of this specific mode - :rtype: ~silx.gui.plot.Colormap.Colormap + :rtype: ~silx.gui.colors.Colormap """ return self._plotImage.getColormap(mode=mode) diff --git a/silx/gui/plot/CurvesROIWidget.py b/silx/gui/plot/CurvesROIWidget.py index ccb6866..81e684e 100644 --- a/silx/gui/plot/CurvesROIWidget.py +++ b/silx/gui/plot/CurvesROIWidget.py @@ -1,7 +1,7 @@ # coding: utf-8 # /*########################################################################## # -# Copyright (c) 2004-2017 European Synchrotron Radiation Facility +# Copyright (c) 2004-2018 European Synchrotron Radiation Facility # # Permission is hereby granted, free of charge, to any person obtaining a copy # of this software and associated documentation files (the "Software"), to deal @@ -33,13 +33,11 @@ ROI are defined by : This can be used to apply or not some ROI to a curve and do some post processing. - The x coordinate of the left limit (`from` column) - The x coordinate of the right limit (`to` column) -- Raw counts: integral of the curve between the - min ROI point and the max ROI point to the y = 0 line +- Raw counts: Sum of the curve's values in the defined Region Of Intereset. .. image:: img/rawCounts.png -- Net counts: the integral of the curve between the - min ROI point and the max ROI point to [ROI min point, ROI max point] segment +- Net counts: Raw counts minus background .. image:: img/netCounts.png """ @@ -53,6 +51,7 @@ from collections import OrderedDict import logging import os import sys +import weakref import numpy @@ -93,7 +92,8 @@ class CurvesROIWidget(qt.QWidget): if name is not None: self.setWindowTitle(name) assert plot is not None - self.plot = plot + self._plotRef = weakref.ref(plot) + layout = qt.QVBoxLayout(self) layout.setContentsMargins(0, 0, 0, 0) layout.setSpacing(0) @@ -162,6 +162,13 @@ class CurvesROIWidget(qt.QWidget): self._isConnected = False # True if connected to plot signals self._isInit = False + def getPlotWidget(self): + """Returns the associated PlotWidget or None + + :rtype: Union[~silx.gui.plot.PlotWidget,None] + """ + return None if self._plotRef is None else self._plotRef() + def showEvent(self, event): self._visibilityChangedHandler(visible=True) qt.QWidget.showEvent(self, event) @@ -400,14 +407,18 @@ class CurvesROIWidget(qt.QWidget): def _roiSignal(self, ddict): """Handle ROI widget signal""" _logger.debug("CurvesROIWidget._roiSignal %s", str(ddict)) + plot = self.getPlotWidget() + if plot is None: + return + if ddict['event'] == "AddROI": - xmin, xmax = self.plot.getXAxis().getLimits() + xmin, xmax = plot.getXAxis().getLimits() fromdata = xmin + 0.25 * (xmax - xmin) todata = xmin + 0.75 * (xmax - xmin) - self.plot.remove('ROI min', kind='marker') - self.plot.remove('ROI max', kind='marker') + plot.remove('ROI min', kind='marker') + plot.remove('ROI max', kind='marker') if self._middleROIMarkerFlag: - self.plot.remove('ROI middle', kind='marker') + plot.remove('ROI middle', kind='marker') roiList, roiDict = self.roiTable.getROIListAndDict() nrois = len(roiList) if nrois == 0: @@ -416,6 +427,7 @@ class CurvesROIWidget(qt.QWidget): draggable = False color = 'black' else: + # find the next index free for newroi. for i in range(nrois): i += 1 newroi = "newroi %d" % i @@ -423,29 +435,29 @@ class CurvesROIWidget(qt.QWidget): break color = 'blue' draggable = True - self.plot.addXMarker(fromdata, - legend='ROI min', - text='ROI min', - color=color, - draggable=draggable) - self.plot.addXMarker(todata, - legend='ROI max', - text='ROI max', - color=color, - draggable=draggable) + plot.addXMarker(fromdata, + legend='ROI min', + text='ROI min', + color=color, + draggable=draggable) + plot.addXMarker(todata, + legend='ROI max', + text='ROI max', + color=color, + draggable=draggable) if draggable and self._middleROIMarkerFlag: pos = 0.5 * (fromdata + todata) - self.plot.addXMarker(pos, - legend='ROI middle', - text="", - color='yellow', - draggable=draggable) + plot.addXMarker(pos, + legend='ROI middle', + text="", + color='yellow', + draggable=draggable) roiList.append(newroi) roiDict[newroi] = {} if newroi == "ICR": roiDict[newroi]['type'] = "Default" else: - roiDict[newroi]['type'] = self.plot.getXAxis().getLabel() + roiDict[newroi]['type'] = plot.getXAxis().getLabel() roiDict[newroi]['from'] = fromdata roiDict[newroi]['to'] = todata self.roiTable.fillFromROIDict(roilist=roiList, @@ -454,10 +466,10 @@ class CurvesROIWidget(qt.QWidget): self.currentROI = newroi self.calculateRois() elif ddict['event'] in ['DelROI', "ResetROI"]: - self.plot.remove('ROI min', kind='marker') - self.plot.remove('ROI max', kind='marker') + plot.remove('ROI min', kind='marker') + plot.remove('ROI max', kind='marker') if self._middleROIMarkerFlag: - self.plot.remove('ROI middle', kind='marker') + plot.remove('ROI middle', kind='marker') roiList, roiDict = self.roiTable.getROIListAndDict() roiDictKeys = list(roiDict.keys()) if len(roiDictKeys): @@ -480,37 +492,37 @@ class CurvesROIWidget(qt.QWidget): self.roilist, self.roidict = self.roiTable.getROIListAndDict() fromdata = ddict['roi']['from'] todata = ddict['roi']['to'] - self.plot.remove('ROI min', kind='marker') - self.plot.remove('ROI max', kind='marker') + plot.remove('ROI min', kind='marker') + plot.remove('ROI max', kind='marker') if ddict['key'] == 'ICR': draggable = False color = 'black' else: draggable = True color = 'blue' - self.plot.addXMarker(fromdata, - legend='ROI min', - text='ROI min', - color=color, - draggable=draggable) - self.plot.addXMarker(todata, - legend='ROI max', - text='ROI max', - color=color, - draggable=draggable) + plot.addXMarker(fromdata, + legend='ROI min', + text='ROI min', + color=color, + draggable=draggable) + plot.addXMarker(todata, + legend='ROI max', + text='ROI max', + color=color, + draggable=draggable) if draggable and self._middleROIMarkerFlag: pos = 0.5 * (fromdata + todata) - self.plot.addXMarker(pos, - legend='ROI middle', - text="", - color='yellow', - draggable=True) + plot.addXMarker(pos, + legend='ROI middle', + text="", + color='yellow', + draggable=True) self.currentROI = ddict['key'] if ddict['colheader'] in ['From', 'To']: dict0 = {} dict0['event'] = "SetActiveCurveEvent" - dict0['legend'] = self.plot.getActiveCurve(just_legend=1) - self.plot.setActiveCurve(dict0['legend']) + dict0['legend'] = plot.getActiveCurve(just_legend=1) + plot.setActiveCurve(dict0['legend']) elif ddict['colheader'] == 'Raw Counts': pass elif ddict['colheader'] == 'Net Counts': @@ -523,7 +535,8 @@ class CurvesROIWidget(qt.QWidget): def _getAllLimits(self): """Retrieve the limits based on the curves.""" - curves = self.plot.getAllCurves() + plot = self.getPlotWidget() + curves = () if plot is None else plot.getAllCurves() if not curves: return 1.0, 1.0, 100., 100. @@ -562,7 +575,12 @@ class CurvesROIWidget(qt.QWidget): if roiList is None or roiDict is None: roiList, roiDict = self.roiTable.getROIListAndDict() - activeCurve = self.plot.getActiveCurve(just_legend=False) + plot = self.getPlotWidget() + if plot is None: + activeCurve = None + else: + activeCurve = plot.getActiveCurve(just_legend=False) + if activeCurve is None: xproc = None yproc = None @@ -640,6 +658,11 @@ class CurvesROIWidget(qt.QWidget): return if self.currentROI not in roiDict: return + + plot = self.getPlotWidget() + if plot is None: + return + x = ddict['x'] if label == 'ROI min': @@ -647,36 +670,36 @@ class CurvesROIWidget(qt.QWidget): if self._middleROIMarkerFlag: pos = 0.5 * (roiDict[self.currentROI]['to'] + roiDict[self.currentROI]['from']) - self.plot.addXMarker(pos, - legend='ROI middle', - text='', - color='yellow', - draggable=True) + plot.addXMarker(pos, + legend='ROI middle', + text='', + color='yellow', + draggable=True) elif label == 'ROI max': roiDict[self.currentROI]['to'] = x if self._middleROIMarkerFlag: pos = 0.5 * (roiDict[self.currentROI]['to'] + roiDict[self.currentROI]['from']) - self.plot.addXMarker(pos, - legend='ROI middle', - text='', - color='yellow', - draggable=True) + plot.addXMarker(pos, + legend='ROI middle', + text='', + color='yellow', + draggable=True) elif label == 'ROI middle': delta = x - 0.5 * (roiDict[self.currentROI]['from'] + roiDict[self.currentROI]['to']) roiDict[self.currentROI]['from'] += delta roiDict[self.currentROI]['to'] += delta - self.plot.addXMarker(roiDict[self.currentROI]['from'], - legend='ROI min', - text='ROI min', - color='blue', - draggable=True) - self.plot.addXMarker(roiDict[self.currentROI]['to'], - legend='ROI max', - text='ROI max', - color='blue', - draggable=True) + plot.addXMarker(roiDict[self.currentROI]['from'], + legend='ROI min', + text='ROI min', + color='blue', + draggable=True) + plot.addXMarker(roiDict[self.currentROI]['to'], + legend='ROI max', + text='ROI max', + color='blue', + draggable=True) else: return self.calculateRois(roiList, roiDict) @@ -687,32 +710,39 @@ class CurvesROIWidget(qt.QWidget): It is connected to plot signals only when visible. """ + plot = self.getPlotWidget() + if visible: if not self._isInit: # Deferred ROI widget init finalization - self._isInit = True - self.sigROIWidgetSignal.connect(self._roiSignal) - # initialize with the ICR - self._roiSignal({'event': "AddROI"}) - - if not self._isConnected: - self.plot.sigPlotSignal.connect(self._handleROIMarkerEvent) - self.plot.sigActiveCurveChanged.connect( + self._finalizeInit() + + if not self._isConnected and plot is not None: + plot.sigPlotSignal.connect(self._handleROIMarkerEvent) + plot.sigActiveCurveChanged.connect( self._activeCurveChanged) self._isConnected = True self.calculateRois() else: if self._isConnected: - self.plot.sigPlotSignal.disconnect(self._handleROIMarkerEvent) - self.plot.sigActiveCurveChanged.disconnect( - self._activeCurveChanged) + if plot is not None: + plot.sigPlotSignal.disconnect(self._handleROIMarkerEvent) + plot.sigActiveCurveChanged.disconnect( + self._activeCurveChanged) self._isConnected = False def _activeCurveChanged(self, *args): """Recompute ROIs when active curve changed.""" self.calculateRois() + def _finalizeInit(self): + self._isInit = True + self.sigROIWidgetSignal.connect(self._roiSignal) + # initialize with the ICR if no ROi existing yet + if len(self.getRois()) is 0: + self._roiSignal({'event': "AddROI"}) + class ROITable(qt.QTableWidget): """Table widget displaying ROI information. @@ -977,9 +1007,6 @@ class CurvesROIDockWidget(qt.QDockWidget): def __init__(self, parent=None, plot=None, name=None): super(CurvesROIDockWidget, self).__init__(name, parent) - assert plot is not None - self.plot = plot - self.roiWidget = CurvesROIWidget(self, name, plot=plot) """Main widget of type :class:`CurvesROIWidget`""" diff --git a/silx/gui/plot/ImageView.py b/silx/gui/plot/ImageView.py index 46e56e6..c28ffca 100644 --- a/silx/gui/plot/ImageView.py +++ b/silx/gui/plot/ImageView.py @@ -1,7 +1,7 @@ # coding: utf-8 # /*########################################################################## # -# Copyright (c) 2015-2017 European Synchrotron Radiation Facility +# Copyright (c) 2015-2018 European Synchrotron Radiation Facility # # Permission is hereby granted, free of charge, to any person obtaining a copy # of this software and associated documentation files (the "Software"), to deal @@ -42,18 +42,19 @@ from __future__ import division __authors__ = ["T. Vincent"] __license__ = "MIT" -__date__ = "17/08/2017" +__date__ = "26/04/2018" import logging import numpy +import silx from .. import qt from . import items, PlotWindow, PlotWidget, actions -from .Colormap import Colormap -from .Colors import cursorColorForColormap -from .PlotTools import LimitsToolBar +from ..colors import Colormap +from ..colors import cursorColorForColormap +from .tools import LimitsToolBar from .Profile import ProfileToolBar @@ -296,6 +297,9 @@ class ImageView(PlotWindow): if parent is None: self.setWindowTitle('ImageView') + if silx.config.DEFAULT_PLOT_IMAGE_Y_AXIS_ORIENTATION == 'downward': + self.getYAxis().setInverted(True) + self._initWidgets(backend) self.profile = ProfileToolBar(plot=self) @@ -356,7 +360,7 @@ class ImageView(PlotWindow): layout.setSpacing(0) layout.setContentsMargins(0, 0, 0, 0) - centralWidget = qt.QWidget() + centralWidget = qt.QWidget(self) centralWidget.setLayout(layout) self.setCentralWidget(centralWidget) @@ -773,7 +777,7 @@ class ImageView(PlotWindow): legend=self._imageLegend, origin=origin, scale=scale, colormap=self.getColormap(), - replace=False, resetzoom=False) + resetzoom=False) self.setActiveImage(self._imageLegend) self._updateHistograms() @@ -810,17 +814,17 @@ class ImageViewMainWindow(ImageView): self.statusBar() menu = self.menuBar().addMenu('File') - menu.addAction(self.saveAction) - menu.addAction(self.printAction) + menu.addAction(self.getOutputToolBar().getSaveAction()) + menu.addAction(self.getOutputToolBar().getPrintAction()) menu.addSeparator() action = menu.addAction('Quit') action.triggered[bool].connect(qt.QApplication.instance().quit) menu = self.menuBar().addMenu('Edit') - menu.addAction(self.copyAction) + menu.addAction(self.getOutputToolBar().getCopyAction()) menu.addSeparator() - menu.addAction(self.resetZoomAction) - menu.addAction(self.colormapAction) + menu.addAction(self.getResetZoomAction()) + menu.addAction(self.getColormapAction()) menu.addAction(actions.control.KeepAspectRatioAction(self, self)) menu.addAction(actions.control.YAxisInvertedAction(self, self)) diff --git a/silx/gui/plot/MaskToolsWidget.py b/silx/gui/plot/MaskToolsWidget.py index 09c5ca5..797068e 100644 --- a/silx/gui/plot/MaskToolsWidget.py +++ b/silx/gui/plot/MaskToolsWidget.py @@ -1,7 +1,7 @@ # coding: utf-8 # /*########################################################################## # -# Copyright (c) 2017 European Synchrotron Radiation Facility +# Copyright (c) 2017-2018 European Synchrotron Radiation Facility # # Permission is hereby granted, free of charge, to any person obtaining a copy # of this software and associated documentation files (the "Software"), to deal @@ -35,7 +35,7 @@ from __future__ import division __authors__ = ["T. Vincent", "P. Knobel"] __license__ = "MIT" -__date__ = "20/06/2017" +__date__ = "24/04/2018" import os @@ -48,7 +48,7 @@ from silx.image import shapes from ._BaseMaskToolsWidget import BaseMask, BaseMaskToolsWidget, BaseMaskToolsDockWidget from . import items -from .Colors import cursorColorForColormap, rgba +from ..colors import cursorColorForColormap, rgba from .. import qt from silx.third_party.EdfFile import EdfFile @@ -76,6 +76,7 @@ class ImageMask(BaseMask): :param image: :class:`silx.gui.plot.items.ImageBase` instance """ BaseMask.__init__(self, image) + self.reset(shape=(0, 0)) # Init the mask with a 2D shape def getDataValues(self): """Return image data as a 2D or 3D array (if it is a RGBA image). @@ -222,7 +223,8 @@ class MaskToolsWidget(BaseMaskToolsWidget): def setSelectionMask(self, mask, copy=True): """Set the mask to a new array. - :param numpy.ndarray mask: The array to use for the mask. + :param numpy.ndarray mask: + The array to use for the mask or None to reset the mask. :type mask: numpy.ndarray of uint8 of dimension 2, C-contiguous. Array of other types are converted. :param bool copy: True (the default) to copy the array, @@ -231,11 +233,19 @@ class MaskToolsWidget(BaseMaskToolsWidget): The mask can be cropped or padded to fit active image, the returned shape is that of the active image. """ + if mask is None: + self.resetSelectionMask() + return self._data.shape[:2] + mask = numpy.array(mask, copy=False, dtype=numpy.uint8) if len(mask.shape) != 2: _logger.error('Not an image, shape: %d', len(mask.shape)) return None + # if mask has not changed, do nothing + if numpy.array_equal(mask, self.getSelectionMask()): + return mask.shape + # ensure all mask attributes are synchronized with the active image # and connect listener activeImage = self.plot.getActiveImage() @@ -265,7 +275,7 @@ class MaskToolsWidget(BaseMaskToolsWidget): def _updatePlotMask(self): """Update mask image in plot""" mask = self.getSelectionMask(copy=False) - if len(mask): + if mask is not None: # get the mask from the plot maskItem = self.plot.getImage(self._maskName) mustBeAdded = maskItem is None @@ -303,7 +313,7 @@ class MaskToolsWidget(BaseMaskToolsWidget): if not self.browseAction.isChecked(): self.browseAction.trigger() # Disable drawing tool - if len(self.getSelectionMask(copy=False)): + if self.getSelectionMask(copy=False) is not None: self.plot.sigActiveImageChanged.connect( self._activeImageChangedAfterCare) @@ -328,6 +338,13 @@ class MaskToolsWidget(BaseMaskToolsWidget): activeImage = self.plot.getActiveImage() if activeImage is None or activeImage.getLegend() == self._maskName: # No active image or active image is the mask... + self._data = numpy.zeros((0, 0), dtype=numpy.uint8) + self._mask.setDataItem(None) + self._mask.reset() + + if self.plot.getImage(self._maskName): + self.plot.remove(self._maskName, kind='image') + self.plot.sigActiveImageChanged.disconnect( self._activeImageChangedAfterCare) else: @@ -340,7 +357,7 @@ class MaskToolsWidget(BaseMaskToolsWidget): self._scale = activeImage.getScale() self._z = activeImage.getZValue() + 1 self._data = activeImage.getData(copy=False) - if self._data.shape[:2] != self.getSelectionMask(copy=False).shape: + if self._data.shape[:2] != self._mask.getMask(copy=False).shape: # Image has not the same size, remove mask and stop listening if self.plot.getImage(self._maskName): self.plot.remove(self._maskName, kind='image') @@ -378,7 +395,7 @@ class MaskToolsWidget(BaseMaskToolsWidget): self._z = activeImage.getZValue() + 1 self._data = activeImage.getData(copy=False) self._mask.setDataItem(activeImage) - if self._data.shape[:2] != self.getSelectionMask(copy=False).shape: + if self._data.shape[:2] != self._mask.getMask(copy=False).shape: self._mask.reset(self._data.shape[:2]) self._mask.commit() else: @@ -597,7 +614,7 @@ class MaskToolsWidget(BaseMaskToolsWidget): # convert from plot to array coords col, row = (event['points'][-1] - self._origin) / self._scale col, row = int(col), int(row) - brushSize = self.pencilSpinBox.value() + brushSize = self._getPencilWidth() if self._lastPencilPos != (row, col): if self._lastPencilPos is not None: diff --git a/silx/gui/plot/PlotInteraction.py b/silx/gui/plot/PlotInteraction.py index 865073b..356bda6 100644 --- a/silx/gui/plot/PlotInteraction.py +++ b/silx/gui/plot/PlotInteraction.py @@ -1,7 +1,7 @@ # coding: utf-8 # /*########################################################################## # -# Copyright (c) 2014-2017 European Synchrotron Radiation Facility +# Copyright (c) 2014-2018 European Synchrotron Radiation Facility # # Permission is hereby granted, free of charge, to any person obtaining a copy # of this software and associated documentation files (the "Software"), to deal @@ -26,7 +26,7 @@ __authors__ = ["T. Vincent"] __license__ = "MIT" -__date__ = "27/06/2017" +__date__ = "24/04/2018" import math @@ -34,7 +34,8 @@ import numpy import time import weakref -from . import Colors +from .. import colors +from .. import qt from . import items from .Interaction import (ClickOrDrag, LEFT_BTN, RIGHT_BTN, State, StateMachine) @@ -115,11 +116,52 @@ class _ZoomOnWheel(ClickOrDrag, _PlotInteraction): Base class for :class:`Pan` and :class:`Zoom` """ + + _DOUBLE_CLICK_TIMEOUT = 0.4 + class ZoomIdle(ClickOrDrag.Idle): def onWheel(self, x, y, angle): scaleF = 1.1 if angle > 0 else 1. / 1.1 applyZoomToPlot(self.machine.plot, scaleF, (x, y)) + def click(self, x, y, btn): + """Handle clicks by sending events + + :param int x: Mouse X position in pixels + :param int y: Mouse Y position in pixels + :param btn: Clicked mouse button + """ + if btn == LEFT_BTN: + lastClickTime, lastClickPos = self._lastClick + + # Signal mouse double clicked event first + if (time.time() - lastClickTime) <= self._DOUBLE_CLICK_TIMEOUT: + # Use position of first click + eventDict = prepareMouseSignal('mouseDoubleClicked', 'left', + *lastClickPos) + self.plot.notify(**eventDict) + + self._lastClick = 0., None + else: + # Signal mouse clicked event + dataPos = self.plot.pixelToData(x, y) + assert dataPos is not None + eventDict = prepareMouseSignal('mouseClicked', 'left', + dataPos[0], dataPos[1], + x, y) + self.plot.notify(**eventDict) + + self._lastClick = time.time(), (dataPos[0], dataPos[1], x, y) + + elif btn == RIGHT_BTN: + # Signal mouse clicked event + dataPos = self.plot.pixelToData(x, y) + assert dataPos is not None + eventDict = prepareMouseSignal('mouseClicked', 'right', + dataPos[0], dataPos[1], + x, y) + self.plot.notify(**eventDict) + def __init__(self, plot): """Init. @@ -135,6 +177,8 @@ class _ZoomOnWheel(ClickOrDrag, _PlotInteraction): } StateMachine.__init__(self, states, 'idle') + self._lastClick = 0., None + # Pan ######################################################################### @@ -229,11 +273,9 @@ class Zoom(_ZoomOnWheel): Zoom-in on selected area, zoom-out on right click, and zoom on mouse wheel. """ - _DOUBLE_CLICK_TIMEOUT = 0.4 def __init__(self, plot, color): self.color = color - self._lastClick = 0., None super(Zoom, self).__init__(plot) self.plot.getLimitsHistory().clear() @@ -263,38 +305,6 @@ class Zoom(_ZoomOnWheel): return areaX0, areaY0, areaX1, areaY1 - def click(self, x, y, btn): - if btn == LEFT_BTN: - lastClickTime, lastClickPos = self._lastClick - - # Signal mouse double clicked event first - if (time.time() - lastClickTime) <= self._DOUBLE_CLICK_TIMEOUT: - # Use position of first click - eventDict = prepareMouseSignal('mouseDoubleClicked', 'left', - *lastClickPos) - self.plot.notify(**eventDict) - - self._lastClick = 0., None - else: - # Signal mouse clicked event - dataPos = self.plot.pixelToData(x, y) - assert dataPos is not None - eventDict = prepareMouseSignal('mouseClicked', 'left', - dataPos[0], dataPos[1], - x, y) - self.plot.notify(**eventDict) - - self._lastClick = time.time(), (dataPos[0], dataPos[1], x, y) - - elif btn == RIGHT_BTN: - # Signal mouse clicked event - dataPos = self.plot.pixelToData(x, y) - assert dataPos is not None - eventDict = prepareMouseSignal('mouseClicked', 'right', - dataPos[0], dataPos[1], - x, y) - self.plot.notify(**eventDict) - def beginDrag(self, x, y): dataPos = self.plot.pixelToData(x, y) assert dataPos is not None @@ -424,7 +434,7 @@ class SelectPolygon(Select): """Update drawing first point, using self._firstPos""" x, y = self.machine.plot.dataToPixel(*self._firstPos, check=False) - offset = self.machine.DRAG_THRESHOLD_DIST + offset = self.machine.getDragThreshold() points = [(x - offset, y - offset), (x - offset, y + offset), (x + offset, y + offset), @@ -458,10 +468,10 @@ class SelectPolygon(Select): check=False) dx, dy = abs(firstPos[0] - x), abs(firstPos[1] - y) + threshold = self.machine.getDragThreshold() + # Only allow to close polygon after first point - if (len(self.points) > 2 and - dx < self.machine.DRAG_THRESHOLD_DIST and - dy < self.machine.DRAG_THRESHOLD_DIST): + if len(self.points) > 2 and dx <= threshold and dy <= threshold: self.machine.resetSelectionArea() self.points[-1] = self.points[0] @@ -489,8 +499,7 @@ class SelectPolygon(Select): previousPos = self.machine.plot.dataToPixel(*self.points[-2], check=False) dx, dy = abs(previousPos[0] - x), abs(previousPos[1] - y) - if(dx >= self.machine.DRAG_THRESHOLD_DIST or - dy >= self.machine.DRAG_THRESHOLD_DIST): + if dx >= threshold or dy >= threshold: self.points.append(dataPos) else: self.points[-1] = dataPos @@ -502,8 +511,9 @@ class SelectPolygon(Select): firstPos = self.machine.plot.dataToPixel(*self._firstPos, check=False) dx, dy = abs(firstPos[0] - x), abs(firstPos[1] - y) - if (dx < self.machine.DRAG_THRESHOLD_DIST and - dy < self.machine.DRAG_THRESHOLD_DIST): + threshold = self.machine.getDragThreshold() + + if dx <= threshold and dy <= threshold: x, y = firstPos # Snap to first point dataPos = self.machine.plot.pixelToData(x, y) @@ -523,6 +533,17 @@ class SelectPolygon(Select): if isinstance(self.state, self.states['select']): self.resetSelectionArea() + def getDragThreshold(self): + """Return dragging ratio with device to pixel ratio applied. + + :rtype: float + """ + ratio = 1. + if qt.BINDING in ('PyQt5', 'PySide2'): + ratio = self.plot.window().windowHandle().devicePixelRatio() + return self.DRAG_THRESHOLD_DIST * ratio + + class Select2Points(Select): """Base class for drawing selection based on 2 input points.""" @@ -1204,6 +1225,48 @@ class ItemsInteraction(ClickOrDrag, _PlotInteraction): self.plot.setGraphCursorShape() +class ItemsInteractionForCombo(ItemsInteraction): + """Interaction with items to combine through :class:`FocusManager`. + """ + + class Idle(ItemsInteraction.Idle): + def onPress(self, x, y, btn): + if btn == LEFT_BTN: + def test(item): + return (item.isSelectable() or + (isinstance(item, items.DraggableMixIn) and + item.isDraggable())) + + picked = self.machine.plot._pickMarker(x, y, test) + if picked is not None: + itemInteraction = True + + else: + picked = self.machine.plot._pickImageOrCurve(x, y, test) + itemInteraction = picked is not None + + if itemInteraction: # Request focus and handle interaction + self.goto('clickOrDrag', x, y) + return True + else: # Do not request focus + return False + + elif btn == RIGHT_BTN: + self.goto('rightClick', x, y) + return True + + def __init__(self, plot): + _PlotInteraction.__init__(self, plot) + + states = { + 'idle': ItemsInteractionForCombo.Idle, + 'rightClick': ClickOrDrag.RightClick, + 'clickOrDrag': ClickOrDrag.ClickOrDrag, + 'drag': ClickOrDrag.Drag + } + StateMachine.__init__(self, states, 'idle') + + # FocusManager ################################################################ class FocusManager(StateMachine): @@ -1344,6 +1407,74 @@ class ZoomAndSelect(ItemsInteraction): return super(ZoomAndSelect, self).endDrag(startPos, endPos) +class PanAndSelect(ItemsInteraction): + """Combine Pan and ItemInteraction state machine. + + :param plot: The Plot to which this interaction is attached + """ + + def __init__(self, plot): + super(PanAndSelect, self).__init__(plot) + self._pan = Pan(plot) + self._doPan = False + + def click(self, x, y, btn): + """Handle mouse click + + :param x: X position of the mouse in pixels + :param y: Y position of the mouse in pixels + :param btn: Pressed button id + :return: True if click is catched by an item, False otherwise + """ + eventDict = self._handleClick(x, y, btn) + + if eventDict is not None: + # Signal mouse clicked event + dataPos = self.plot.pixelToData(x, y) + assert dataPos is not None + clickedEventDict = prepareMouseSignal('mouseClicked', btn, + dataPos[0], dataPos[1], + x, y) + self.plot.notify(**clickedEventDict) + + self.plot.notify(**eventDict) + + else: + self._pan.click(x, y, btn) + + def beginDrag(self, x, y): + """Handle start drag and switching between zoom and item drag. + + :param x: X position in pixels + :param y: Y position in pixels + """ + self._doPan = not super(PanAndSelect, self).beginDrag(x, y) + if self._doPan: + self._pan.beginDrag(x, y) + + def drag(self, x, y): + """Handle drag, eventually forwarding to zoom. + + :param x: X position in pixels + :param y: Y position in pixels + """ + if self._doPan: + return self._pan.drag(x, y) + else: + return super(PanAndSelect, self).drag(x, y) + + def endDrag(self, startPos, endPos): + """Handle end of drag, eventually forwarding to zoom. + + :param startPos: (x, y) position at the beginning of the drag + :param endPos: (x, y) position at the end of the drag + """ + if self._doPan: + return self._pan.endDrag(startPos, endPos) + else: + return super(PanAndSelect, self).endDrag(startPos, endPos) + + # Interaction mode control #################################################### class PlotInteraction(object): @@ -1384,12 +1515,21 @@ class PlotInteraction(object): if isinstance(self._eventHandler, ZoomAndSelect): return {'mode': 'zoom', 'color': self._eventHandler.color} + elif isinstance(self._eventHandler, FocusManager): + drawHandler = self._eventHandler.eventHandlers[1] + if not isinstance(drawHandler, Select): + raise RuntimeError('Unknown interactive mode') + + result = drawHandler.parameters.copy() + result['mode'] = 'draw' + return result + elif isinstance(self._eventHandler, Select): result = self._eventHandler.parameters.copy() result['mode'] = 'draw' return result - elif isinstance(self._eventHandler, Pan): + elif isinstance(self._eventHandler, PanAndSelect): return {'mode': 'pan'} else: @@ -1400,7 +1540,7 @@ class PlotInteraction(object): """Switch the interactive mode. :param str mode: The name of the interactive mode. - In 'draw', 'pan', 'select', 'zoom'. + In 'draw', 'pan', 'select', 'select-draw', 'zoom'. :param color: Only for 'draw' and 'zoom' modes. Color to use for drawing selection area. Default black. If None, selection area is not drawn. @@ -1413,15 +1553,15 @@ class PlotInteraction(object): :param str label: Only for 'draw' mode. :param float width: Width of the pencil. Only for draw pencil mode. """ - assert mode in ('draw', 'pan', 'select', 'zoom') + assert mode in ('draw', 'pan', 'select', 'select-draw', 'zoom') plot = self._plot() assert plot is not None if color not in (None, 'video inverted'): - color = Colors.rgba(color) + color = colors.rgba(color) - if mode == 'draw': + if mode in ('draw', 'select-draw'): assert shape in self._DRAW_MODES eventHandlerClass = self._DRAW_MODES[shape] parameters = { @@ -1430,14 +1570,21 @@ class PlotInteraction(object): 'color': color, 'width': width, } + eventHandler = eventHandlerClass(plot, parameters) self._eventHandler.cancel() - self._eventHandler = eventHandlerClass(plot, parameters) + + if mode == 'draw': + self._eventHandler = eventHandler + + else: # mode == 'select-draw' + self._eventHandler = FocusManager( + (ItemsInteractionForCombo(plot), eventHandler)) elif mode == 'pan': # Ignores color, shape and label self._eventHandler.cancel() - self._eventHandler = Pan(plot) + self._eventHandler = PanAndSelect(plot) elif mode == 'zoom': # Ignores shape and label diff --git a/silx/gui/plot/PlotToolButtons.py b/silx/gui/plot/PlotToolButtons.py index fc5fcf4..e354877 100644 --- a/silx/gui/plot/PlotToolButtons.py +++ b/silx/gui/plot/PlotToolButtons.py @@ -30,6 +30,7 @@ The following QToolButton are available: - :class:`.AspectToolButton` - :class:`.YAxisOriginToolButton` - :class:`.ProfileToolButton` +- :class:`.SymbolToolButton` """ @@ -38,10 +39,15 @@ __license__ = "MIT" __date__ = "27/06/2017" +import functools import logging +import weakref + from .. import icons from .. import qt +from .items import SymbolMixIn + _logger = logging.getLogger(__name__) @@ -52,7 +58,7 @@ class PlotToolButton(qt.QToolButton): def __init__(self, parent=None, plot=None): super(PlotToolButton, self).__init__(parent) - self._plot = None + self._plotRef = None if plot is not None: self.setPlot(plot) @@ -60,7 +66,7 @@ class PlotToolButton(qt.QToolButton): """ Returns the plot connected to the widget. """ - return self._plot + return None if self._plotRef is None else self._plotRef() def setPlot(self, plot): """ @@ -68,13 +74,18 @@ class PlotToolButton(qt.QToolButton): :param plot: :class:`.PlotWidget` instance on which to operate. """ - if self._plot is plot: + previousPlot = self.plot() + + if previousPlot is plot: return - if self._plot is not None: - self._disconnectPlot(self._plot) - self._plot = plot - if self._plot is not None: - self._connectPlot(self._plot) + if previousPlot is not None: + self._disconnectPlot(previousPlot) + + if plot is None: + self._plotRef = None + else: + self._plotRef = weakref.ref(plot) + self._connectPlot(plot) def _connectPlot(self, plot): """ @@ -282,3 +293,71 @@ class ProfileToolButton(PlotToolButton): def computeProfileIn2D(self): self._profileDimensionChanged(2) + + +class SymbolToolButton(PlotToolButton): + """A tool button with a drop-down menu to control symbol size and marker. + + :param parent: See QWidget + :param plot: The `~silx.gui.plot.PlotWidget` to control + """ + + def __init__(self, parent=None, plot=None): + super(SymbolToolButton, self).__init__(parent=parent, plot=plot) + + self.setToolTip('Set symbol size and marker') + self.setIcon(icons.getQIcon('plot-symbols')) + + menu = qt.QMenu(self) + + # Size slider + + slider = qt.QSlider(qt.Qt.Horizontal) + slider.setRange(1, 20) + slider.setValue(SymbolMixIn._DEFAULT_SYMBOL_SIZE) + slider.setTracking(False) + slider.valueChanged.connect(self._sizeChanged) + widgetAction = qt.QWidgetAction(menu) + widgetAction.setDefaultWidget(slider) + menu.addAction(widgetAction) + + menu.addSeparator() + + # Marker actions + + for marker, name in zip(SymbolMixIn.getSupportedSymbols(), + SymbolMixIn.getSupportedSymbolNames()): + action = qt.QAction(name, menu) + action.setCheckable(False) + action.triggered.connect( + functools.partial(self._markerChanged, marker)) + menu.addAction(action) + + self.setMenu(menu) + self.setPopupMode(qt.QToolButton.InstantPopup) + + def _sizeChanged(self, value): + """Manage slider value changed + + :param int value: Marker size + """ + plot = self.plot() + if plot is None: + return + + for item in plot._getItems(withhidden=True): + if isinstance(item, SymbolMixIn): + item.setSymbolSize(value) + + def _markerChanged(self, marker): + """Manage change of marker. + + :param str marker: Letter describing the marker + """ + plot = self.plot() + if plot is None: + return + + for item in plot._getItems(withhidden=True): + if isinstance(item, SymbolMixIn): + item.setSymbol(marker) diff --git a/silx/gui/plot/PlotTools.py b/silx/gui/plot/PlotTools.py index 7fadfd2..5929473 100644 --- a/silx/gui/plot/PlotTools.py +++ b/silx/gui/plot/PlotTools.py @@ -25,288 +25,19 @@ """Set of widgets to associate with a :class:'PlotWidget'. """ -from __future__ import division +from __future__ import absolute_import -__authors__ = ["V.A. Sole", "T. Vincent"] +__authors__ = ["T. Vincent"] __license__ = "MIT" -__date__ = "16/10/2017" +__date__ = "01/03/2018" -import logging -import numbers -import traceback -import weakref +from ...utils.deprecation import deprecated_warning -import numpy +deprecated_warning(type_='module', + name=__file__, + reason='Plot tools refactoring', + replacement='silx.gui.plot.tools', + since_version='0.8') -from .. import qt -from silx.gui.widgets.FloatEdit import FloatEdit - -_logger = logging.getLogger(__name__) - - -# PositionInfo ################################################################ - -class PositionInfo(qt.QWidget): - """QWidget displaying coords converted from data coords of the mouse. - - Provide this widget with a list of couple: - - - A name to display before the data - - A function that takes (x, y) as arguments and returns something that - gets converted to a string. - If the result is a float it is converted with '%.7g' format. - - To run the following sample code, a QApplication must be initialized. - First, create a PlotWindow and add a QToolBar where to place the - PositionInfo widget. - - >>> from silx.gui.plot import PlotWindow - >>> from silx.gui import qt - - >>> plot = PlotWindow() # Create a PlotWindow to add the widget to - >>> toolBar = qt.QToolBar() # Create a toolbar to place the widget in - >>> plot.addToolBar(qt.Qt.BottomToolBarArea, toolBar) # Add it to plot - - Then, create the PositionInfo widget and add it to the toolbar. - The PositionInfo widget is created with a list of converters, here - to display polar coordinates of the mouse position. - - >>> import numpy - >>> from silx.gui.plot.PlotTools import PositionInfo - - >>> position = PositionInfo(plot=plot, converters=[ - ... ('Radius', lambda x, y: numpy.sqrt(x*x + y*y)), - ... ('Angle', lambda x, y: numpy.degrees(numpy.arctan2(y, x)))]) - >>> toolBar.addWidget(position) # Add the widget to the toolbar - <...> - >>> plot.show() # To display the PlotWindow with the position widget - - :param plot: The PlotWidget this widget is displaying data coords from. - :param converters: - List of 2-tuple: name to display and conversion function from (x, y) - in data coords to displayed value. - If None, the default, it displays X and Y. - :param parent: Parent widget - """ - - def __init__(self, parent=None, plot=None, converters=None): - assert plot is not None - self._plotRef = weakref.ref(plot) - - super(PositionInfo, self).__init__(parent) - - if converters is None: - converters = (('X', lambda x, y: x), ('Y', lambda x, y: y)) - - self.autoSnapToActiveCurve = False - """Toggle snapping use position to active curve. - - - True to snap used coordinates to the active curve if the active curve - is displayed with symbols and mouse is close enough. - If the mouse is not close to a point of the curve, values are - displayed in red. - - False (the default) to always use mouse coordinates. - - """ - - self._fields = [] # To store (QLineEdit, name, function (x, y)->v) - - # Create a new layout with new widgets - layout = qt.QHBoxLayout() - layout.setContentsMargins(0, 0, 0, 0) - # layout.setSpacing(0) - - # Create all QLabel and store them with the corresponding converter - for name, func in converters: - layout.addWidget(qt.QLabel('<b>' + name + ':</b>')) - - contentWidget = qt.QLabel() - contentWidget.setText('------') - contentWidget.setTextInteractionFlags(qt.Qt.TextSelectableByMouse) - contentWidget.setFixedWidth( - contentWidget.fontMetrics().width('##############')) - layout.addWidget(contentWidget) - self._fields.append((contentWidget, name, func)) - - layout.addStretch(1) - self.setLayout(layout) - - # Connect to Plot events - plot.sigPlotSignal.connect(self._plotEvent) - - @property - def plot(self): - """The :class:`.PlotWindow` this widget is attached to.""" - return self._plotRef() - - def getConverters(self): - """Return the list of converters as 2-tuple (name, function).""" - return [(name, func) for _label, name, func in self._fields] - - def _plotEvent(self, event): - """Handle events from the Plot. - - :param dict event: Plot event - """ - if event['event'] == 'mouseMoved': - x, y = event['x'], event['y'] - xPixel, yPixel = event['xpixel'], event['ypixel'] - self._updateStatusBar(x, y, xPixel, yPixel) - - def _updateStatusBar(self, x, y, xPixel, yPixel): - """Update information from the status bar using the definitions. - - :param float x: Position-x in data - :param float y: Position-y in data - :param float xPixel: Position-x in pixels - :param float yPixel: Position-y in pixels - """ - styleSheet = "color: rgb(0, 0, 0);" # Default style - - if self.autoSnapToActiveCurve and self.plot.getGraphCursor(): - # Check if near active curve with symbols. - - styleSheet = "color: rgb(255, 0, 0);" # Style far from curve - - activeCurve = self.plot.getActiveCurve() - if activeCurve: - xData = activeCurve.getXData(copy=False) - yData = activeCurve.getYData(copy=False) - if activeCurve.getSymbol(): # Only handled if symbols on curve - closestIndex = numpy.argmin( - pow(xData - x, 2) + pow(yData - y, 2)) - - xClosest = xData[closestIndex] - yClosest = yData[closestIndex] - - closestInPixels = self.plot.dataToPixel( - xClosest, yClosest, axis=activeCurve.getYAxis()) - if closestInPixels is not None: - if (abs(closestInPixels[0] - xPixel) < 5 and - abs(closestInPixels[1] - yPixel) < 5): - # Update label style sheet - styleSheet = "color: rgb(0, 0, 0);" - - # if close enough, wrap to data point coords - x, y = xClosest, yClosest - - for label, name, func in self._fields: - label.setStyleSheet(styleSheet) - - try: - value = func(x, y) - text = self.valueToString(value) - label.setText(text) - except: - label.setText('Error') - _logger.error( - "Error while converting coordinates (%f, %f)" - "with converter '%s'" % (x, y, name)) - _logger.error(traceback.format_exc()) - - def valueToString(self, value): - if isinstance(value, (tuple, list)): - value = [self.valueToString(v) for v in value] - return ", ".join(value) - elif isinstance(value, numbers.Real): - # Use this for floats and int - return '%.7g' % value - else: - # Fallback for other types - return str(value) - -# LimitsToolBar ############################################################## - -class LimitsToolBar(qt.QToolBar): - """QToolBar displaying and controlling the limits of a :class:`PlotWidget`. - - To run the following sample code, a QApplication must be initialized. - First, create a PlotWindow: - - >>> from silx.gui.plot import PlotWindow - >>> plot = PlotWindow() # Create a PlotWindow to add the toolbar to - - Then, create the LimitsToolBar and add it to the PlotWindow. - - >>> from silx.gui import qt - >>> from silx.gui.plot.PlotTools import LimitsToolBar - - >>> toolbar = LimitsToolBar(plot=plot) # Create the toolbar - >>> plot.addToolBar(qt.Qt.BottomToolBarArea, toolbar) # Add it to the plot - >>> plot.show() # To display the PlotWindow with the limits toolbar - - :param parent: See :class:`QToolBar`. - :param plot: :class:`PlotWidget` instance on which to operate. - :param str title: See :class:`QToolBar`. - """ - - def __init__(self, parent=None, plot=None, title='Limits'): - super(LimitsToolBar, self).__init__(title, parent) - assert plot is not None - self._plot = plot - self._plot.sigPlotSignal.connect(self._plotWidgetSlot) - - self._initWidgets() - - @property - def plot(self): - """The :class:`PlotWidget` the toolbar is attached to.""" - return self._plot - - def _initWidgets(self): - """Create and init Toolbar widgets.""" - xMin, xMax = self.plot.getXAxis().getLimits() - yMin, yMax = self.plot.getYAxis().getLimits() - - self.addWidget(qt.QLabel('Limits: ')) - self.addWidget(qt.QLabel(' X: ')) - self._xMinFloatEdit = FloatEdit(self, xMin) - self._xMinFloatEdit.editingFinished[()].connect( - self._xFloatEditChanged) - self.addWidget(self._xMinFloatEdit) - - self._xMaxFloatEdit = FloatEdit(self, xMax) - self._xMaxFloatEdit.editingFinished[()].connect( - self._xFloatEditChanged) - self.addWidget(self._xMaxFloatEdit) - - self.addWidget(qt.QLabel(' Y: ')) - self._yMinFloatEdit = FloatEdit(self, yMin) - self._yMinFloatEdit.editingFinished[()].connect( - self._yFloatEditChanged) - self.addWidget(self._yMinFloatEdit) - - self._yMaxFloatEdit = FloatEdit(self, yMax) - self._yMaxFloatEdit.editingFinished[()].connect( - self._yFloatEditChanged) - self.addWidget(self._yMaxFloatEdit) - - def _plotWidgetSlot(self, event): - """Listen to :class:`PlotWidget` events.""" - if event['event'] not in ('limitsChanged',): - return - - xMin, xMax = self.plot.getXAxis().getLimits() - yMin, yMax = self.plot.getYAxis().getLimits() - - self._xMinFloatEdit.setValue(xMin) - self._xMaxFloatEdit.setValue(xMax) - self._yMinFloatEdit.setValue(yMin) - self._yMaxFloatEdit.setValue(yMax) - - def _xFloatEditChanged(self): - """Handle X limits changed from the GUI.""" - xMin, xMax = self._xMinFloatEdit.value(), self._xMaxFloatEdit.value() - if xMax < xMin: - xMin, xMax = xMax, xMin - - self.plot.getXAxis().setLimits(xMin, xMax) - - def _yFloatEditChanged(self): - """Handle Y limits changed from the GUI.""" - yMin, yMax = self._yMinFloatEdit.value(), self._yMaxFloatEdit.value() - if yMax < yMin: - yMin, yMax = yMax, yMin - - self.plot.getYAxis().setLimits(yMin, yMax) +from .tools import PositionInfo, LimitsToolBar # noqa diff --git a/silx/gui/plot/PlotWidget.py b/silx/gui/plot/PlotWidget.py index 3641b8c..2f7132c 100644 --- a/silx/gui/plot/PlotWidget.py +++ b/silx/gui/plot/PlotWidget.py @@ -31,37 +31,43 @@ from __future__ import division __authors__ = ["V.A. Sole", "T. Vincent"] __license__ = "MIT" -__date__ = "18/10/2017" +__date__ = "14/06/2018" from collections import OrderedDict, namedtuple from contextlib import contextmanager +import datetime as dt import itertools import logging import numpy +import silx +from silx.utils.weakref import WeakMethodProxy +from silx.utils import deprecation +from silx.utils.property import classproperty from silx.utils.deprecation import deprecated # Import matplotlib backend here to init matplotlib our way from .backends.BackendMatplotlib import BackendMatplotlibQt -from .Colormap import Colormap -from . import Colors +from ..colors import Colormap +from .. import colors from . import PlotInteraction from . import PlotEvents from .LimitsHistory import LimitsHistory from . import _utils from . import items +from .items.axis import TickMode from .. import qt from ._utils.panzoom import ViewConstraints - +from ...gui.plot._utils.dtime_ticklayout import timestamp _logger = logging.getLogger(__name__) -_COLORDICT = Colors.COLORDICT +_COLORDICT = colors.COLORDICT _COLORLIST = [_COLORDICT['black'], _COLORDICT['blue'], _COLORDICT['red'], @@ -110,8 +116,12 @@ class PlotWidget(qt.QMainWindow): :type backend: str or :class:`BackendBase.BackendBase` """ - DEFAULT_BACKEND = 'matplotlib' - """Class attribute setting the default backend for all instances.""" + # TODO: Can be removed for silx 0.10 + @classproperty + @deprecation.deprecated(replacement="silx.config.DEFAULT_PLOT_BACKEND", since_version="0.8", skip_backtrace_count=2) + def DEFAULT_BACKEND(self): + """Class attribute setting the default backend for all instances.""" + return silx.config.DEFAULT_PLOT_BACKEND colorList = _COLORLIST colorDict = _COLORDICT @@ -209,7 +219,7 @@ class PlotWidget(qt.QMainWindow): self.setWindowTitle('PlotWidget') if backend is None: - backend = self.DEFAULT_BACKEND + backend = silx.config.DEFAULT_PLOT_BACKEND if hasattr(backend, "__call__"): self._backend = backend(self, parent) @@ -296,7 +306,9 @@ class PlotWidget(qt.QMainWindow): self.setGraphYLimits(0., 100., axis='right') self.setGraphYLimits(0., 100., axis='left') + # TODO: Can be removed for silx 0.10 @staticmethod + @deprecation.deprecated(replacement="silx.config.DEFAULT_PLOT_BACKEND", since_version="0.8", skip_backtrace_count=2) def setDefaultBackend(backend): """Set system wide default plot backend. @@ -306,7 +318,7 @@ class PlotWidget(qt.QMainWindow): 'matplotlib' (default), 'mpl', 'opengl', 'gl', 'none' or a :class:`BackendBase.BackendBase` class """ - PlotWidget.DEFAULT_BACKEND = backend + silx.config.DEFAULT_PLOT_BACKEND = backend def _getDirtyPlot(self): """Return the plot dirty flag. @@ -525,7 +537,9 @@ class PlotWidget(qt.QMainWindow): :param numpy.ndarray x: The data corresponding to the x coordinates. If you attempt to plot an histogram you can set edges values in x. - In this case len(x) = len(y) + 1 + In this case len(x) = len(y) + 1. + If x contains datetime objects the XAxis tickMode is set to + TickMode.TIME_SERIES. :param numpy.ndarray y: The data corresponding to the y coordinates :param str legend: The legend to be associated to the curve (or None) :param info: User-defined information associated to the curve @@ -533,7 +547,7 @@ class PlotWidget(qt.QMainWindow): curves :param color: color(s) to be used :type color: str ("#RRGGBB") or (npoints, 4) unsigned byte array or - one of the predefined color names defined in Colors.py + one of the predefined color names defined in colors.py :param str symbol: Symbol to be drawn at each (x, y) position:: - 'o' circle @@ -686,6 +700,13 @@ class PlotWidget(qt.QMainWindow): if yerror is None: yerror = curve.getYErrorData(copy=False) + # Convert x to timestamps so that the internal representation + # remains floating points. The user is expected to set the axis' + # tickMode to TickMode.TIME_SERIES and, if necessary, set the axis + # to the correct time zone. + if len(x) > 0 and isinstance(x[0], dt.datetime): + x = [timestamp(d) for d in x] + curve.setData(x, y, xerror, yerror, copy=copy) if replace: # Then remove all other curves @@ -739,7 +760,7 @@ class PlotWidget(qt.QMainWindow): The legend to be associated to the histogram (or None) :param color: color to be used :type color: str ("#RRGGBB") or RGB unsigned byte array or - one of the predefined color names defined in Colors.py + one of the predefined color names defined in colors.py :param bool fill: True to fill the curve, False otherwise (default). :param str align: In case histogram values and edges have the same length N, @@ -785,7 +806,7 @@ class PlotWidget(qt.QMainWindow): return legend def addImage(self, data, legend=None, info=None, - replace=True, replot=None, + replace=False, replot=None, xScale=None, yScale=None, z=None, selectable=None, draggable=None, colormap=None, pixmap=None, @@ -811,7 +832,8 @@ class PlotWidget(qt.QMainWindow): Note: boolean values are converted to int8. :param str legend: The legend to be associated to the image (or None) :param info: User-defined information associated to the image - :param bool replace: True (default) to delete already existing images + :param bool replace: + True to delete already existing images (Default: False). :param int z: Layer on which to draw the image (default: 0) This allows to control the overlay. :param bool selectable: Indicate if the image can be selected. @@ -821,7 +843,7 @@ class PlotWidget(qt.QMainWindow): :param colormap: Description of the :class:`.Colormap` to use (or None). This is ignored if data is a RGB(A) image. - :type colormap: Union[silx.gui.plot.Colormap.Colormap, dict] + :type colormap: Union[silx.gui.colors.Colormap, dict] :param pixmap: Pixmap representation of the data (if any) :type pixmap: (nrows, ncolumns, RGBA) ubyte array or None (default) :param str xlabel: X axis label to show when this curve is active, @@ -964,7 +986,7 @@ class PlotWidget(qt.QMainWindow): :param numpy.ndarray y: The data corresponding to the y coordinates :param numpy.ndarray value: The data value associated with each point :param str legend: The legend to be associated to the scatter (or None) - :param silx.gui.plot.Colormap.Colormap colormap: + :param silx.gui.colors.Colormap colormap: The :class:`.Colormap`. to be used for the scatter (or None) :param info: User-defined information associated to the curve :param str symbol: Symbol to be drawn at each (x, y) position:: @@ -1477,7 +1499,7 @@ class PlotWidget(qt.QMainWindow): :param bool flag: Toggle the display of a crosshair cursor. The crosshair cursor is hidden by default. :param color: The color to use for the crosshair. - :type color: A string (either a predefined color name in Colors.py + :type color: A string (either a predefined color name in colors.py or "#RRGGBB")) or a 4 columns unsigned byte array (Default: black). :param int linewidth: The width of the lines of the crosshair @@ -2264,13 +2286,13 @@ class PlotWidget(qt.QMainWindow): It only affects future calls to :meth:`addImage` without the colormap parameter. - :param silx.gui.plot.Colormap.Colormap colormap: + :param silx.gui.colors.Colormap colormap: The description of the default colormap, or None to set the :class:`.Colormap` to a linear autoscale gray colormap. """ if colormap is None: - colormap = Colormap(name='gray', + colormap = Colormap(name=silx.config.DEFAULT_COLORMAP_NAME, normalization='linear', vmin=None, vmax=None) @@ -2370,10 +2392,10 @@ class PlotWidget(qt.QMainWindow): to handle the graph events If None (default), use a default listener. """ - # TODO allow multiple listeners, keep a weakref on it + # TODO allow multiple listeners # allow register listener by event type if callbackFunction is None: - callbackFunction = self.graphCallback + callbackFunction = WeakMethodProxy(self.graphCallback) self._callback = callbackFunction def graphCallback(self, ddict=None): @@ -2392,6 +2414,8 @@ class PlotWidget(qt.QMainWindow): if ddict['button'] == "left": self.setActiveCurve(ddict['label']) qt.QToolTip.showText(self.cursor().pos(), ddict['label']) + elif ddict['event'] == 'mouseClicked' and ddict['button'] == 'left': + self.setActiveCurve(None) def saveGraph(self, filename, fileFormat=None, dpi=None, **kw): """Save a snapshot of the plot. @@ -2519,9 +2543,8 @@ class PlotWidget(qt.QMainWindow): # Compute bbox wth figure aspect ratio plotWidth, plotHeight = self.getPlotBoundsInPixels()[2:] - plotRatio = plotHeight / plotWidth - - if plotRatio > 0.: + if plotWidth > 0 and plotHeight > 0: + plotRatio = plotHeight / plotWidth dataRatio = (ymax - ymin) / (xmax - xmin) if dataRatio < plotRatio: # Increase y range @@ -2741,6 +2764,39 @@ class PlotWidget(qt.QMainWindow): return None + def _pick(self, x, y): + """Pick items in the plot at given position. + + :param float x: X position in pixels + :param float y: Y position in pixels + :return: Iterable of (plot item, indices) at picked position. + Items are ordered from back to front. + """ + items = [] + + # Convert backend result to plot items + for itemInfo in self._backend.pickItems( + x, y, kinds=('marker', 'curve', 'image')): + kind, legend = itemInfo['kind'], itemInfo['legend'] + + if kind in ('marker', 'image'): + item = self._getItem(kind=kind, legend=legend) + indices = None # TODO compute indices for images + + else: # backend kind == 'curve' + for kind in ('curve', 'histogram', 'scatter'): + item = self._getItem(kind=kind, legend=legend) + if item is not None: + indices = itemInfo['indices'] + break + else: + _logger.error( + 'Cannot find corresponding picked item') + continue + items.append((item, indices)) + + return tuple(items) + # User event handling # def _isPositionInPlotArea(self, x, y): @@ -2846,7 +2902,7 @@ class PlotWidget(qt.QMainWindow): """Switch the interactive mode. :param str mode: The name of the interactive mode. - In 'draw', 'pan', 'select', 'zoom'. + In 'draw', 'pan', 'select', 'select-draw', 'zoom'. :param color: Only for 'draw' and 'zoom' modes. Color to use for drawing selection area. Default black. :type color: Color description: The name as a str or @@ -2959,7 +3015,7 @@ class PlotWidget(qt.QMainWindow): :param str label: Associated text for identifying draw signals :param color: The color to use to draw the selection area :type color: string ("#RRGGBB") or 4 column unsigned byte array or - one of the predefined color names defined in Colors.py + one of the predefined color names defined in colors.py """ _logger.warning( 'setDrawModeEnabled deprecated, use setInteractiveMode instead') @@ -3011,7 +3067,7 @@ class PlotWidget(qt.QMainWindow): (Default: 'black') :param color: The color to use to draw the selection area :type color: string ("#RRGGBB") or 4 column unsigned byte array or - one of the predefined color names defined in Colors.py + one of the predefined color names defined in colors.py """ _logger.warning( 'setZoomModeEnabled deprecated, use setInteractiveMode instead') diff --git a/silx/gui/plot/PlotWindow.py b/silx/gui/plot/PlotWindow.py index 5c7e661..459ffdc 100644 --- a/silx/gui/plot/PlotWindow.py +++ b/silx/gui/plot/PlotWindow.py @@ -1,7 +1,7 @@ # coding: utf-8 # /*########################################################################## # -# Copyright (c) 2004-2017 European Synchrotron Radiation Facility +# Copyright (c) 2004-2018 European Synchrotron Radiation Facility # # Permission is hereby granted, free of charge, to any person obtaining a copy # of this software and associated documentation files (the "Software"), to deal @@ -29,11 +29,14 @@ The :class:`PlotWindow` is a subclass of :class:`.PlotWidget`. __authors__ = ["V.A. Sole", "T. Vincent"] __license__ = "MIT" -__date__ = "15/02/2018" +__date__ = "05/06/2018" import collections import logging +import weakref +import silx +from silx.utils.weakref import WeakMethodProxy from silx.utils.deprecation import deprecated from . import PlotWidget @@ -44,11 +47,12 @@ from .actions import fit as actions_fit from .actions import control as actions_control from .actions import histogram as actions_histogram from . import PlotToolButtons -from .PlotTools import PositionInfo +from . import tools from .Profile import ProfileToolBar from .LegendSelector import LegendsDockWidget from .CurvesROIWidget import CurvesROIDockWidget from .MaskToolsWidget import MaskToolsDockWidget +from .StatsWidget import BasicStatsWidget from .ColorBar import ColorBarWidget try: from ..console import IPythonDockWidget @@ -90,7 +94,7 @@ class PlotWindow(PlotWidget): (Default: False). It also supports a list of (name, funct(x, y)->value) to customize the displayed values. - See :class:`silx.gui.plot.PlotTools.PositionInfo`. + See :class:`~silx.gui.plot.tools.PositionInfo`. :param bool roi: Toggle visibilty of ROI action. :param bool mask: Toggle visibilty of mask action. :param bool fit: Toggle visibilty of fit action. @@ -114,6 +118,7 @@ class PlotWindow(PlotWidget): self._curvesROIDockWidget = None self._maskToolsDockWidget = None self._consoleDockWidget = None + self._statsWidget = None # Create color bar, hidden by default for backward compatibility self._colorbar = ColorBarWidget(parent=self, plot=self) @@ -122,11 +127,6 @@ class PlotWindow(PlotWidget): self.group = qt.QActionGroup(self) self.group.setExclusive(False) - self.zoomModeAction = self.group.addAction( - actions.mode.ZoomModeAction(self)) - self.panModeAction = self.group.addAction( - actions.mode.PanModeAction(self)) - self.resetZoomAction = self.group.addAction( actions.control.ResetZoomAction(self)) self.resetZoomAction.setVisible(resetzoom) @@ -205,28 +205,13 @@ class PlotWindow(PlotWidget): actions_medfilt.MedianFilter1DAction(self)) self._medianFilter1DAction.setVisible(False) - self._separator = qt.QAction('separator', self) - self._separator.setSeparator(True) - self.group.addAction(self._separator) - - self.copyAction = self.group.addAction(actions.io.CopyAction(self)) - self.copyAction.setVisible(copy) - self.addAction(self.copyAction) - - self.saveAction = self.group.addAction(actions.io.SaveAction(self)) - self.saveAction.setVisible(save) - self.addAction(self.saveAction) - - self.printAction = self.group.addAction(actions.io.PrintAction(self)) - self.printAction.setVisible(print_) - self.addAction(self.printAction) - self.fitAction = self.group.addAction(actions_fit.FitAction(self)) self.fitAction.setVisible(fit) self.addAction(self.fitAction) # lazy loaded actions needed by the controlButton menu self._consoleAction = None + self._statsAction = None self._panWithArrowKeysAction = None self._crosshairAction = None @@ -244,10 +229,12 @@ class PlotWindow(PlotWidget): gridLayout.addWidget(self._colorbar, 0, 1) gridLayout.setRowStretch(0, 1) gridLayout.setColumnStretch(0, 1) - centralWidget = qt.QWidget() + centralWidget = qt.QWidget(self) centralWidget.setLayout(gridLayout) self.setCentralWidget(centralWidget) + self._positionWidget = None + if control or position: hbox = qt.QHBoxLayout() hbox.setContentsMargins(0, 0, 0, 0) @@ -270,22 +257,69 @@ class PlotWindow(PlotWidget): converters = position else: converters = None - self.positionWidget = PositionInfo( + self._positionWidget = tools.PositionInfo( plot=self, converters=converters) - self.positionWidget.autoSnapToActiveCurve = True + # Set a snapping mode that is consistent with legacy one + self._positionWidget.setSnappingMode( + tools.PositionInfo.SNAPPING_CROSSHAIR | + tools.PositionInfo.SNAPPING_ACTIVE_ONLY | + tools.PositionInfo.SNAPPING_SYMBOLS_ONLY | + tools.PositionInfo.SNAPPING_CURVE | + tools.PositionInfo.SNAPPING_SCATTER) - hbox.addWidget(self.positionWidget) + hbox.addWidget(self._positionWidget) hbox.addStretch(1) - bottomBar = qt.QWidget() + bottomBar = qt.QWidget(centralWidget) bottomBar.setLayout(hbox) gridLayout.addWidget(bottomBar, 1, 0, 1, -1) # Creating the toolbar also create actions for toolbuttons + self._interactiveModeToolBar = tools.InteractiveModeToolBar( + parent=self, plot=self) + self.addToolBar(self._interactiveModeToolBar) + self._toolbar = self._createToolBar(title='Plot', parent=None) self.addToolBar(self._toolbar) + self._outputToolBar = tools.OutputToolBar(parent=self, plot=self) + self._outputToolBar.getCopyAction().setVisible(copy) + self._outputToolBar.getSaveAction().setVisible(save) + self._outputToolBar.getPrintAction().setVisible(print_) + self.addToolBar(self._outputToolBar) + + # Activate shortcuts in PlotWindow widget: + for toolbar in (self._interactiveModeToolBar, self._outputToolBar): + for action in toolbar.actions(): + self.addAction(action) + + def getInteractiveModeToolBar(self): + """Returns QToolBar controlling interactive mode. + + :rtype: QToolBar + """ + return self._interactiveModeToolBar + + def getOutputToolBar(self): + """Returns QToolBar containing save, copy and print actions + + :rtype: QToolBar + """ + return self._outputToolBar + + @property + @deprecated(replacement="getPositionInfoWidget()", since_version="0.8.0") + def positionWidget(self): + return self.getPositionInfoWidget() + + def getPositionInfoWidget(self): + """Returns the widget displaying current cursor position information + + :rtype: ~silx.gui.plot.tools.PositionInfo + """ + return self._positionWidget + def getSelectionMask(self): """Return the current mask handled by :attr:`maskToolsDockWidget`. @@ -313,7 +347,7 @@ class PlotWindow(PlotWidget): show it or hide it.""" # create widget if needed (first call) if self._consoleDockWidget is None: - available_vars = {"plt": self} + available_vars = {"plt": weakref.proxy(self)} banner = "The variable 'plt' is available. Use the 'whos' " banner += "and 'help(plt)' commands for more information.\n\n" self._consoleDockWidget = IPythonDockWidget( @@ -327,6 +361,9 @@ class PlotWindow(PlotWidget): self._consoleDockWidget.setVisible(isChecked) + def _toggleStatsVisibility(self, isChecked=False): + self.getStatsWidget().parent().setVisible(isChecked) + def _createToolBar(self, title, parent): """Create a QToolBar from the QAction of the PlotWindow. @@ -355,8 +392,6 @@ class PlotWindow(PlotWidget): self.yAxisInvertedAction = toolbar.addWidget(obj) else: raise RuntimeError() - if obj is self.panModeAction: - toolbar.addSeparator() return toolbar def toolBar(self): @@ -381,6 +416,7 @@ class PlotWindow(PlotWidget): controlMenu.clear() controlMenu.addAction(self.getLegendsDockWidget().toggleViewAction()) controlMenu.addAction(self.getRoiAction()) + controlMenu.addAction(self.getStatsAction()) controlMenu.addAction(self.getMaskAction()) controlMenu.addAction(self.getConsoleAction()) @@ -474,8 +510,35 @@ class PlotWindow(PlotWidget): self.addTabbedDockWidget(self._maskToolsDockWidget) return self._maskToolsDockWidget + def getStatsWidget(self): + """Returns a BasicStatsWidget connected to this plot + + :rtype: BasicStatsWidget + """ + if self._statsWidget is None: + dockWidget = qt.QDockWidget(parent=self) + dockWidget.setWindowTitle("Curves stats") + dockWidget.layout().setContentsMargins(0, 0, 0, 0) + self._statsWidget = BasicStatsWidget(parent=self, plot=self) + dockWidget.setWidget(self._statsWidget) + dockWidget.hide() + self.addTabbedDockWidget(dockWidget) + return self._statsWidget + # getters for actions @property + @deprecated(replacement="getInteractiveModeToolBar().getZoomModeAction()", + since_version="0.8.0") + def zoomModeAction(self): + return self.getInteractiveModeToolBar().getZoomModeAction() + + @property + @deprecated(replacement="getInteractiveModeToolBar().getPanModeAction()", + since_version="0.8.0") + def panModeAction(self): + return self.getInteractiveModeToolBar().getPanModeAction() + + @property @deprecated(replacement="getConsoleAction()", since_version="0.4.0") def consoleAction(self): return self.getConsoleAction() @@ -545,6 +608,14 @@ class PlotWindow(PlotWidget): def roiAction(self): return self.getRoiAction() + def getStatsAction(self): + if self._statsAction is None: + self._statsAction = qt.QAction('Curves stats', self) + self._statsAction.setCheckable(True) + self._statsAction.setChecked(self.getStatsWidget().parent().isVisible()) + self._statsAction.toggled.connect(self._toggleStatsVisibility) + return self._statsAction + def getRoiAction(self): """QAction toggling curve ROI dock widget @@ -667,21 +738,21 @@ class PlotWindow(PlotWidget): :rtype: actions.PlotAction """ - return self.copyAction + return self.getOutputToolBar().getCopyAction() def getSaveAction(self): """Action to save plot :rtype: actions.PlotAction """ - return self.saveAction + return self.getOutputToolBar().getSaveAction() def getPrintAction(self): """Action to print plot :rtype: actions.PlotAction """ - return self.printAction + return self.getOutputToolBar().getPrintAction() def getFitAction(self): """Action to fit selected curve @@ -757,7 +828,7 @@ class Plot2D(PlotWindow): posInfo = [ ('X', lambda x, y: x), ('Y', lambda x, y: y), - ('Data', self._getImageValue)] + ('Data', WeakMethodProxy(self._getImageValue))] super(Plot2D, self).__init__(parent=parent, backend=backend, resetzoom=True, autoScale=False, @@ -772,6 +843,9 @@ class Plot2D(PlotWindow): self.getXAxis().setLabel('Columns') self.getYAxis().setLabel('Rows') + if silx.config.DEFAULT_PLOT_IMAGE_Y_AXIS_ORIENTATION == 'downward': + self.getYAxis().setInverted(True) + self.profile = ProfileToolBar(plot=self) self.addToolBar(self.profile) @@ -780,10 +854,41 @@ class Plot2D(PlotWindow): # Put colorbar action after colormap action actions = self.toolBar().actions() - for index, action in enumerate(actions): + for action in actions: if action is self.getColormapAction(): break + self.sigActiveImageChanged.connect(self.__activeImageChanged) + + def __activeImageChanged(self, previous, legend): + """Handle change of active image + + :param Union[str,None] previous: Legend of previous active image + :param Union[str,None] legend: Legend of current active image + """ + if previous is not None: + item = self.getImage(previous) + if item is not None: + item.sigItemChanged.disconnect(self.__imageChanged) + + if legend is not None: + item = self.getImage(legend) + item.sigItemChanged.connect(self.__imageChanged) + + positionInfo = self.getPositionInfoWidget() + if positionInfo is not None: + positionInfo.updateInfo() + + def __imageChanged(self, event): + """Handle update of active image item + + :param event: Type of changed event + """ + if event == items.ItemChangedType.DATA: + positionInfo = self.getPositionInfoWidget() + if positionInfo is not None: + positionInfo.updateInfo() + def _getImageValue(self, x, y): """Get status bar value of top most image at position (x, y) diff --git a/silx/gui/plot/Profile.py b/silx/gui/plot/Profile.py index f61412d..5a733fe 100644 --- a/silx/gui/plot/Profile.py +++ b/silx/gui/plot/Profile.py @@ -1,7 +1,7 @@ # coding: utf-8 # /*########################################################################## # -# Copyright (c) 2004-2017 European Synchrotron Radiation Facility +# Copyright (c) 2004-2018 European Synchrotron Radiation Facility # # Permission is hereby granted, free of charge, to any person obtaining a copy # of this software and associated documentation files (the "Software"), to deal @@ -28,7 +28,7 @@ and stacks of images""" __authors__ = ["V.A. Sole", "T. Vincent", "P. Knobel", "H. Payno"] __license__ = "MIT" -__date__ = "17/08/2017" +__date__ = "24/04/2018" import weakref @@ -40,7 +40,7 @@ from silx.image.bilinear import BilinearImage from .. import icons from .. import qt from . import items -from .Colors import cursorColorForColormap +from ..colors import cursorColorForColormap from . import actions from .PlotToolButtons import ProfileToolButton from .ProfileMainWindow import ProfileMainWindow @@ -637,6 +637,12 @@ class ProfileToolBar(qt.QToolBar): colormap=colormap) else: coords = numpy.arange(len(profile[0]), dtype=numpy.float32) + # Scale horizontal and vertical profile coordinates + if self._roiInfo[2] == 'X': + coords = coords * scale[0] + origin[0] + elif self._roiInfo[2] == 'Y': + coords = coords * scale[1] + origin[1] + self.getProfilePlot().addCurve(coords, profile[0], legend=profileName, diff --git a/silx/gui/plot/ProfileMainWindow.py b/silx/gui/plot/ProfileMainWindow.py index 835de2c..3738511 100644 --- a/silx/gui/plot/ProfileMainWindow.py +++ b/silx/gui/plot/ProfileMainWindow.py @@ -1,7 +1,7 @@ # coding: utf-8 # /*########################################################################## # -# Copyright (c) 2017 European Synchrotron Radiation Facility +# Copyright (c) 2017-2018 European Synchrotron Radiation Facility # # Permission is hereby granted, free of charge, to any person obtaining a copy # of this software and associated documentation files (the "Software"), to deal @@ -73,6 +73,8 @@ class ProfileMainWindow(qt.QMainWindow): self._plot2D.setParent(None) # necessary to avoid widget destruction if self._plot1D is None: self._plot1D = Plot1D() + self._plot1D.setGraphYLabel('Profile') + self._plot1D.setGraphXLabel('') self.setCentralWidget(self._plot1D) elif self._profileType == "2D": if self._plot1D is not None: diff --git a/silx/gui/plot/ScatterMaskToolsWidget.py b/silx/gui/plot/ScatterMaskToolsWidget.py index a9c1073..2a10f6d 100644 --- a/silx/gui/plot/ScatterMaskToolsWidget.py +++ b/silx/gui/plot/ScatterMaskToolsWidget.py @@ -1,7 +1,7 @@ # coding: utf-8 # /*########################################################################## # -# Copyright (c) 2017 European Synchrotron Radiation Facility +# Copyright (c) 2018 European Synchrotron Radiation Facility # # Permission is hereby granted, free of charge, to any person obtaining a copy # of this software and associated documentation files (the "Software"), to deal @@ -35,7 +35,7 @@ from __future__ import division __authors__ = ["P. Knobel"] __license__ = "MIT" -__date__ = "07/04/2017" +__date__ = "24/04/2018" import math @@ -45,10 +45,11 @@ import numpy import sys from .. import qt +from ...math.combo import min_max from ...image import shapes from ._BaseMaskToolsWidget import BaseMask, BaseMaskToolsWidget, BaseMaskToolsDockWidget -from .Colors import cursorColorForColormap, rgba +from ..colors import cursorColorForColormap, rgba _logger = logging.getLogger(__name__) @@ -186,13 +187,18 @@ class ScatterMaskToolsWidget(BaseMaskToolsWidget): self._z = 2 # Mask layer in plot self._data_scatter = None """plot Scatter item for data""" + + self._data_extent = None + """Maximum extent of the data i.e., max(xMax-xMin, yMax-yMin)""" + self._mask_scatter = None """plot Scatter item for representing the mask""" def setSelectionMask(self, mask, copy=True): """Set the mask to a new array. - :param numpy.ndarray mask: The array to use for the mask. + :param numpy.ndarray mask: + The array to use for the mask or None to reset the mask. :type mask: numpy.ndarray of uint8, C-contiguous. Array of other types are converted. :param bool copy: True (the default) to copy the array, @@ -201,6 +207,10 @@ class ScatterMaskToolsWidget(BaseMaskToolsWidget): The mask can be cropped or padded to fit active scatter, the returned shape is that of the scatter data. """ + if mask is None: + self.resetSelectionMask() + return self._data_scatter.getXData(copy=False).shape + mask = numpy.array(mask, copy=False, dtype=numpy.uint8) if self._data_scatter.getXData(copy=False).shape == (0,) \ @@ -216,7 +226,7 @@ class ScatterMaskToolsWidget(BaseMaskToolsWidget): def _updatePlotMask(self): """Update mask image in plot""" mask = self.getSelectionMask(copy=False) - if len(mask): + if mask is not None: self.plot.addScatter(self._data_scatter.getXData(), self._data_scatter.getYData(), mask, @@ -226,8 +236,7 @@ class ScatterMaskToolsWidget(BaseMaskToolsWidget): self._mask_scatter = self.plot._getItem(kind="scatter", legend=self._maskName) self._mask_scatter.setSymbolSize( - self._data_scatter.getSymbolSize() * 4.0 - ) + self._data_scatter.getSymbolSize() + 2.0) elif self.plot._getItem(kind="scatter", legend=self._maskName) is not None: self.plot.remove(self._maskName, kind='scatter') @@ -248,7 +257,7 @@ class ScatterMaskToolsWidget(BaseMaskToolsWidget): if not self.browseAction.isChecked(): self.browseAction.trigger() # Disable drawing tool - if len(self.getSelectionMask(copy=False)): + if self.getSelectionMask(copy=False) is not None: self.plot.sigActiveScatterChanged.connect( self._activeScatterChangedAfterCare) @@ -265,6 +274,9 @@ class ScatterMaskToolsWidget(BaseMaskToolsWidget): # No active scatter or active scatter is the mask... self.plot.sigActiveScatterChanged.disconnect( self._activeScatterChangedAfterCare) + self._data_extent = None + self._data_scatter = None + else: colormap = activeScatter.getColormap() self._defaultOverlayColor = rgba(cursorColorForColormap(colormap['name'])) @@ -274,13 +286,22 @@ class ScatterMaskToolsWidget(BaseMaskToolsWidget): self._z = activeScatter.getZValue() + 1 self._data_scatter = activeScatter - if self._data_scatter.getXData(copy=False).shape != self.getSelectionMask(copy=False).shape: + + # Adjust brush size to data range + xMin, xMax = min_max(self._data_scatter.getXData(copy=False)) + yMin, yMax = min_max(self._data_scatter.getYData(copy=False)) + self._data_extent = max(xMax - xMin, yMax - yMin) + + if self._data_scatter.getXData(copy=False).shape != self._mask.getMask(copy=False).shape: # scatter has not the same size, remove mask and stop listening if self.plot._getItem(kind="scatter", legend=self._maskName): self.plot.remove(self._maskName, kind='scatter') self.plot.sigActiveScatterChanged.disconnect( self._activeScatterChangedAfterCare) + self._data_extent = None + self._data_scatter = None + else: # Refresh in case z changed self._mask.setDataItem(self._data_scatter) @@ -295,6 +316,7 @@ class ScatterMaskToolsWidget(BaseMaskToolsWidget): self.setEnabled(False) self._data_scatter = None + self._data_extent = None self._mask.reset() self._mask.commit() @@ -309,8 +331,19 @@ class ScatterMaskToolsWidget(BaseMaskToolsWidget): self._z = activeScatter.getZValue() + 1 self._data_scatter = activeScatter + + # Adjust brush size to data range + xData = self._data_scatter.getXData(copy=False) + yData = self._data_scatter.getYData(copy=False) + if xData.size > 0 and yData.size > 0: + xMin, xMax = min_max(xData) + yMin, yMax = min_max(yData) + self._data_extent = max(xMax - xMin, yMax - yMin) + else: + self._data_extent = None + self._mask.setDataItem(self._data_scatter) - if self._data_scatter.getXData(copy=False).shape != self.getSelectionMask(copy=False).shape: + if self._data_scatter.getXData(copy=False).shape != self._mask.getMask(copy=False).shape: self._mask.reset(self._data_scatter.getXData(copy=False).shape) self._mask.commit() else: @@ -439,6 +472,16 @@ class ScatterMaskToolsWidget(BaseMaskToolsWidget): shape=self._data_scatter.getXData(copy=False).shape) self._mask.commit() + def _getPencilWidth(self): + """Returns the width of the pencil to use in data coordinates` + + :rtype: float + """ + width = super(ScatterMaskToolsWidget, self)._getPencilWidth() + if self._data_extent is not None: + width *= 0.01 * self._data_extent + return width + def _plotDrawEvent(self, event): """Handle draw events from the plot""" if (self._drawingMode is None or @@ -467,7 +510,7 @@ class ScatterMaskToolsWidget(BaseMaskToolsWidget): event['event'] == 'drawingFinished'): doMask = self._isMasking() vertices = event['points'] - vertices = vertices.astype(numpy.int)[:, (1, 0)] # (y, x) + vertices = vertices[:, (1, 0)] # (y, x) self._mask.updatePolygon(level, vertices, doMask) self._mask.commit() @@ -475,7 +518,8 @@ class ScatterMaskToolsWidget(BaseMaskToolsWidget): doMask = self._isMasking() # convert from plot to array coords x, y = event['points'][-1] - brushSize = self.pencilSpinBox.value() + + brushSize = self._getPencilWidth() if self._lastPencilPos != (y, x): if self._lastPencilPos is not None: diff --git a/silx/gui/plot/ScatterView.py b/silx/gui/plot/ScatterView.py new file mode 100644 index 0000000..f830cb3 --- /dev/null +++ b/silx/gui/plot/ScatterView.py @@ -0,0 +1,353 @@ +# coding: utf-8 +# /*########################################################################## +# +# Copyright (c) 2018 European Synchrotron Radiation Facility +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. +# +# ###########################################################################*/ +"""A widget dedicated to display scatter plots + +It is based on a :class:`~silx.gui.plot.PlotWidget` with additional tools +for scatter plots. +""" + +__authors__ = ["T. Vincent"] +__license__ = "MIT" +__date__ = "14/06/2018" + + +import logging +import weakref + +import numpy + +from . import items +from . import PlotWidget +from . import tools +from .tools.profile import ScatterProfileToolBar +from .ColorBar import ColorBarWidget +from .ScatterMaskToolsWidget import ScatterMaskToolsWidget + +from ..widgets.BoxLayoutDockWidget import BoxLayoutDockWidget +from .. import qt, icons + + +_logger = logging.getLogger(__name__) + + +class ScatterView(qt.QMainWindow): + """Main window with a PlotWidget and tools specific for scatter plots. + + :param parent: The parent of this widget + :param backend: The backend to use for the plot (default: matplotlib). + See :class:`~silx.gui.plot.PlotWidget` for the list of supported backend. + :type backend: Union[str,~silx.gui.plot.backends.BackendBase.BackendBase] + """ + + _SCATTER_LEGEND = ' ' + """Legend used for the scatter item""" + + def __init__(self, parent=None, backend=None): + super(ScatterView, self).__init__(parent=parent) + if parent is not None: + # behave as a widget + self.setWindowFlags(qt.Qt.Widget) + else: + self.setWindowTitle('ScatterView') + + # Create plot widget + plot = PlotWidget(parent=self, backend=backend) + self._plot = weakref.ref(plot) + + # Add an empty scatter + plot.addScatter(x=(), y=(), value=(), legend=self._SCATTER_LEGEND) + + # Create colorbar widget with white background + self._colorbar = ColorBarWidget(parent=self, plot=plot) + self._colorbar.setAutoFillBackground(True) + palette = self._colorbar.palette() + palette.setColor(qt.QPalette.Background, qt.Qt.white) + palette.setColor(qt.QPalette.Window, qt.Qt.white) + self._colorbar.setPalette(palette) + + # Create PositionInfo widget + self.__lastPickingPos = None + self.__pickingCache = None + self._positionInfo = tools.PositionInfo( + plot=plot, + converters=(('X', lambda x, y: x), + ('Y', lambda x, y: y), + ('Data', lambda x, y: self._getScatterValue(x, y)), + ('Index', lambda x, y: self._getScatterIndex(x, y)))) + + # Combine plot, position info and colorbar into central widget + gridLayout = qt.QGridLayout() + gridLayout.setSpacing(0) + gridLayout.setContentsMargins(0, 0, 0, 0) + gridLayout.addWidget(plot, 0, 0) + gridLayout.addWidget(self._colorbar, 0, 1) + gridLayout.addWidget(self._positionInfo, 1, 0, 1, -1) + gridLayout.setRowStretch(0, 1) + gridLayout.setColumnStretch(0, 1) + centralWidget = qt.QWidget(self) + centralWidget.setLayout(gridLayout) + self.setCentralWidget(centralWidget) + + # Create mask tool dock widget + self._maskToolsWidget = ScatterMaskToolsWidget(parent=self, plot=plot) + self._maskDock = BoxLayoutDockWidget() + self._maskDock.setWindowTitle('Scatter Mask') + self._maskDock.setWidget(self._maskToolsWidget) + self._maskDock.setVisible(False) + self.addDockWidget(qt.Qt.BottomDockWidgetArea, self._maskDock) + + self._maskAction = self._maskDock.toggleViewAction() + self._maskAction.setIcon(icons.getQIcon('image-mask')) + self._maskAction.setToolTip("Display/hide mask tools") + + # Create toolbars + self._interactiveModeToolBar = tools.InteractiveModeToolBar( + parent=self, plot=plot) + + self._scatterToolBar = tools.ScatterToolBar( + parent=self, plot=plot) + self._scatterToolBar.addAction(self._maskAction) + + self._profileToolBar = ScatterProfileToolBar(parent=self, plot=plot) + + self._outputToolBar = tools.OutputToolBar(parent=self, plot=plot) + + # Activate shortcuts in PlotWindow widget: + for toolbar in (self._interactiveModeToolBar, + self._scatterToolBar, + self._profileToolBar, + self._outputToolBar): + self.addToolBar(toolbar) + for action in toolbar.actions(): + self.addAction(action) + + def _pickScatterData(self, x, y): + """Get data and index and value of top most scatter plot at position (x, y) + + :param float x: X position in plot coordinates + :param float y: Y position in plot coordinates + :return: The data index and value at that point or None + """ + pickingPos = x, y + if self.__lastPickingPos != pickingPos: + self.__pickingCache = None + self.__lastPickingPos = pickingPos + + plot = self.getPlotWidget() + if plot is not None: + pixelPos = plot.dataToPixel(x, y) + if pixelPos is not None: + # Start from top-most item + for item, indices in reversed(plot._pick(*pixelPos)): + if isinstance(item, items.Scatter): + # Get last index + # with matplotlib it should be the top-most point + dataIndex = indices[-1] + self.__pickingCache = ( + dataIndex, + item.getValueData(copy=False)[dataIndex]) + break + + return self.__pickingCache + + def _getScatterValue(self, x, y): + """Get data value of top most scatter plot at position (x, y) + + :param float x: X position in plot coordinates + :param float y: Y position in plot coordinates + :return: The data value at that point or '-' + """ + picking = self._pickScatterData(x, y) + return '-' if picking is None else picking[1] + + def _getScatterIndex(self, x, y): + """Get data index of top most scatter plot at position (x, y) + + :param float x: X position in plot coordinates + :param float y: Y position in plot coordinates + :return: The data index at that point or '-' + """ + picking = self._pickScatterData(x, y) + return '-' if picking is None else picking[0] + + _PICK_OFFSET = 3 # Offset in pixel used for picking + + def _mouseInPlotArea(self, x, y): + """Clip mouse coordinates to plot area coordinates + + :param float x: X position in pixels + :param float y: Y position in pixels + :return: (x, y) in data coordinates + """ + plot = self.getPlotWidget() + left, top, width, height = plot.getPlotBoundsInPixels() + xPlot = numpy.clip(x, left, left + width - 1) + yPlot = numpy.clip(y, top, top + height - 1) + return xPlot, yPlot + + def getPlotWidget(self): + """Returns the :class:`~silx.gui.plot.PlotWidget` this window is based on. + + :rtype: ~silx.gui.plot.PlotWidget + """ + return self._plot() + + def getPositionInfoWidget(self): + """Returns the widget display mouse coordinates information. + + :rtype: ~silx.gui.plot.tools.PositionInfo + """ + return self._positionInfo + + def getMaskToolsWidget(self): + """Returns the widget controlling mask drawing + + :rtype: ~silx.gui.plot.ScatterMaskToolsWidget + """ + return self._maskToolsWidget + + def getInteractiveModeToolBar(self): + """Returns QToolBar controlling interactive mode. + + :rtype: ~silx.gui.plot.tools.InteractiveModeToolBar + """ + return self._interactiveModeToolBar + + def getScatterToolBar(self): + """Returns QToolBar providing scatter plot tools. + + :rtype: ~silx.gui.plot.tools.ScatterToolBar + """ + return self._scatterToolBar + + def getScatterProfileToolBar(self): + """Returns QToolBar providing scatter profile tools. + + :rtype: ~silx.gui.plot.tools.profile.ScatterProfileToolBar + """ + return self._profileToolBar + + def getOutputToolBar(self): + """Returns QToolBar containing save, copy and print actions + + :rtype: ~silx.gui.plot.tools.OutputToolBar + """ + return self._outputToolBar + + def setColormap(self, colormap=None): + """Set the colormap for the displayed scatter and the + default plot colormap. + + :param ~silx.gui.colors.Colormap colormap: + The description of the colormap. + """ + self.getScatterItem().setColormap(colormap) + # Resilient to call to PlotWidget API (e.g., clear) + self.getPlotWidget().setDefaultColormap(colormap) + + def getColormap(self): + """Return the :class:`.Colormap` in use. + + :return: Colormap currently in use + :rtype: ~silx.gui.colors.Colormap + """ + self.getScatterItem().getColormap() + + # Control displayed scatter plot + + def setData(self, x, y, value, xerror=None, yerror=None, copy=True): + """Set the data of the scatter plot. + + To reset the scatter plot, set x, y and value to None. + + :param Union[numpy.ndarray,None] x: X coordinates. + :param Union[numpy.ndarray,None] y: Y coordinates. + :param Union[numpy.ndarray,None] value: + The data corresponding to the value of the data points. + :param xerror: Values with the uncertainties on the x values. + If it is an array, it can either be a 1D array of + same length as the data or a 2D array with 2 rows + of same length as the data: row 0 for positive errors, + row 1 for negative errors. + :type xerror: A float, or a numpy.ndarray of float32. + + :param yerror: Values with the uncertainties on the y values + :type yerror: A float, or a numpy.ndarray of float32. See xerror. + :param bool copy: True make a copy of the data (default), + False to use provided arrays. + """ + x = () if x is None else x + y = () if y is None else y + value = () if value is None else value + + self.getScatterItem().setData( + x=x, y=y, value=value, xerror=xerror, yerror=yerror, copy=copy) + + def getData(self, *args, **kwargs): + return self.getScatterItem().getData(*args, **kwargs) + + getData.__doc__ = items.Scatter.getData.__doc__ + + def getScatterItem(self): + """Returns the plot item displaying the scatter data. + + This allows to set the style of the displayed scatter. + + :rtype: ~silx.gui.plot.items.Scatter + """ + plot = self.getPlotWidget() + scatter = plot._getItem(kind='scatter', legend=self._SCATTER_LEGEND) + if scatter is None: # Resilient to call to PlotWidget API (e.g., clear) + plot.addScatter(x=(), y=(), value=(), legend=self._SCATTER_LEGEND) + scatter = plot._getItem( + kind='scatter', legend=self._SCATTER_LEGEND) + return scatter + + # Convenient proxies + + def getXAxis(self, *args, **kwargs): + return self.getPlotWidget().getXAxis(*args, **kwargs) + + getXAxis.__doc__ = PlotWidget.getXAxis.__doc__ + + def getYAxis(self, *args, **kwargs): + return self.getPlotWidget().getYAxis(*args, **kwargs) + + getYAxis.__doc__ = PlotWidget.getYAxis.__doc__ + + def setGraphTitle(self, *args, **kwargs): + return self.getPlotWidget().setGraphTitle(*args, **kwargs) + + setGraphTitle.__doc__ = PlotWidget.setGraphTitle.__doc__ + + def getGraphTitle(self, *args, **kwargs): + return self.getPlotWidget().getGraphTitle(*args, **kwargs) + + getGraphTitle.__doc__ = PlotWidget.getGraphTitle.__doc__ + + def resetZoom(self, *args, **kwargs): + return self.getPlotWidget().resetZoom(*args, **kwargs) + + resetZoom.__doc__ = PlotWidget.resetZoom.__doc__ diff --git a/silx/gui/plot/StackView.py b/silx/gui/plot/StackView.py index 1fb188c..d1e8e3c 100644 --- a/silx/gui/plot/StackView.py +++ b/silx/gui/plot/StackView.py @@ -69,16 +69,18 @@ Example:: __authors__ = ["P. Knobel", "H. Payno"] __license__ = "MIT" -__date__ = "15/02/2018" +__date__ = "26/04/2018" import numpy +import logging +import silx from silx.gui import qt from .. import icons from . import items, PlotWindow, actions -from .Colormap import Colormap -from .Colors import cursorColorForColormap -from .PlotTools import LimitsToolBar +from ..colors import Colormap +from ..colors import cursorColorForColormap +from .tools import LimitsToolBar from .Profile import Profile3DToolBar from ..widgets.FrameBrowser import HorizontalSliderWithBrowser @@ -96,6 +98,8 @@ except ImportError: else: from silx.io.utils import is_dataset +_logger = logging.getLogger(__name__) + class StackView(qt.QMainWindow): """Stack view widget, to display and browse through stack of @@ -156,6 +160,12 @@ class StackView(qt.QMainWindow): integer. """ + sigFrameChanged = qt.Signal(int) + """Signal emitter when the frame number has changed. + + This signal provides the current frame number. + """ + def __init__(self, parent=None, resetzoom=True, backend=None, autoScale=False, logScale=False, grid=False, colormap=True, aspectRatio=True, yinverted=True, @@ -206,6 +216,9 @@ class StackView(qt.QMainWindow): self.sigActiveImageChanged = self._plot.sigActiveImageChanged self.sigPlotSignal = self._plot.sigPlotSignal + if silx.config.DEFAULT_PLOT_IMAGE_Y_AXIS_ORIENTATION == 'downward': + self._plot.getYAxis().setInverted(True) + self._addColorBarAction() self._plot.profile = Profile3DToolBar(parent=self._plot, @@ -221,6 +234,7 @@ class StackView(qt.QMainWindow): self._browser_label = qt.QLabel("Image index (Dim0):") self._browser = HorizontalSliderWithBrowser(central_widget) + self._browser.setRange(0, 0) self._browser.valueChanged[int].connect(self.__updateFrameNumber) self._browser.setEnabled(False) @@ -313,7 +327,7 @@ class StackView(qt.QMainWindow): assert self._stack is not None assert 0 <= self._perspective < 3 - # ensure we have the stack encapsulated in an array like object + # ensure we have the stack encapsulated in an array-like object # having a transpose() method if isinstance(self._stack, numpy.ndarray): self.__transposed_view = self._stack @@ -324,7 +338,7 @@ class StackView(qt.QMainWindow): elif isinstance(self._stack, ListOfImages): self.__transposed_view = ListOfImages(self._stack) - # transpose the array like object if necessary + # transpose the array-like object if necessary if self._perspective == 1: self.__transposed_view = self.__transposed_view.transpose((1, 0, 2)) elif self._perspective == 2: @@ -338,13 +352,16 @@ class StackView(qt.QMainWindow): :param index: index of the frame to be displayed """ - assert self.__transposed_view is not None + if self.__transposed_view is None: + # no data set + return self._plot.addImage(self.__transposed_view[index, :, :], origin=self._getImageOrigin(), scale=self._getImageScale(), legend=self.__imageLegend, - resetzoom=False, replace=False) + resetzoom=False) self._updateTitle() + self.sigFrameChanged.emit(index) def _set3DScaleAndOrigin(self, calibrations): """Set scale and origin for all 3 axes, to be used when plotting @@ -358,7 +375,7 @@ class StackView(qt.QMainWindow): calibration.NoCalibration()) else: self.calibrations3D = [] - for calib in calibrations: + for i, calib in enumerate(calibrations): if hasattr(calib, "__len__") and len(calib) == 2: calib = calibration.LinearCalibration(calib[0], calib[1]) elif calib is None: @@ -367,9 +384,19 @@ class StackView(qt.QMainWindow): raise TypeError("calibration must be a 2-tuple, None or" + " an instance of an AbstractCalibration " + "subclass") + elif not calib.is_affine(): + _logger.warning( + "Calibration for dimension %d is not linear, " + "it will be ignored for scaling the graph axes.", + i) self.calibrations3D.append(calib) def _getXYZCalibs(self): + """Return calibrations sorted in the XYZ graph order. + + If the X or Y calibration is not linear, it will be replaced + with a :class:`calibration.NoCalibration` object + and as a result the corresponding axis will not be scaled.""" xy_dims = [0, 1, 2] xy_dims.remove(self._perspective) @@ -377,6 +404,12 @@ class StackView(qt.QMainWindow): ycalib = self.calibrations3D[min(xy_dims)] zcalib = self.calibrations3D[self._perspective] + # filter out non-linear calibration for graph axes + if not xcalib.is_affine(): + xcalib = calibration.NoCalibration() + if not ycalib.is_affine(): + ycalib = calibration.NoCalibration() + return xcalib, ycalib, zcalib def _getImageScale(self): @@ -469,6 +502,7 @@ class StackView(qt.QMainWindow): colormap=self.getColormap(), origin=self._getImageOrigin(), scale=self._getImageScale(), + replace=True, resetzoom=False) self._plot.setActiveImage(self.__imageLegend) self._plot.setGraphTitle("Image z=%g" % self._getImageZ(0)) @@ -586,6 +620,14 @@ class StackView(qt.QMainWindow): """ self._browser.setValue(number) + def getFrameNumber(self): + """Set the frame selection to a specific value + + :return: Index of currently displayed frame + :rtype: int + """ + return self._browser.value() + def setFirstStackDimension(self, first_stack_dimension): """When viewing the last 3 dimensions of an n-D array (n>3), you can use this method to change the text in the combobox. @@ -641,6 +683,8 @@ class StackView(qt.QMainWindow): self.__transposed_view = None self._perspective = 0 self._browser.setEnabled(False) + # reset browser range + self._browser.setRange(0, 0) self._plot.clear() def setLabels(self, labels=None): @@ -1101,17 +1145,17 @@ class StackViewMainWindow(StackView): self.statusBar() menu = self.menuBar().addMenu('File') - menu.addAction(self._plot.saveAction) - menu.addAction(self._plot.printAction) + menu.addAction(self._plot.getOutputToolBar().getSaveAction()) + menu.addAction(self._plot.getOutputToolBar().getPrintAction()) menu.addSeparator() action = menu.addAction('Quit') action.triggered[bool].connect(qt.QApplication.instance().quit) menu = self.menuBar().addMenu('Edit') - menu.addAction(self._plot.copyAction) + menu.addAction(self._plot.getOutputToolBar().getCopyAction()) menu.addSeparator() - menu.addAction(self._plot.resetZoomAction) - menu.addAction(self._plot.colormapAction) + menu.addAction(self._plot.getResetZoomAction()) + menu.addAction(self._plot.getColormapAction()) menu.addAction(self.getColorBarAction()) menu.addAction(actions.control.KeepAspectRatioAction(self._plot, self)) diff --git a/silx/gui/plot/StatsWidget.py b/silx/gui/plot/StatsWidget.py new file mode 100644 index 0000000..a36dd9f --- /dev/null +++ b/silx/gui/plot/StatsWidget.py @@ -0,0 +1,572 @@ +# coding: utf-8 +# /*########################################################################## +# +# Copyright (c) 2017 European Synchrotron Radiation Facility +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. +# +# ###########################################################################*/ +""" +Module containing widgets displaying stats from items of a plot. +""" + +__authors__ = ["H. Payno"] +__license__ = "MIT" +__date__ = "12/06/2018" + + +import functools +import logging +import numpy +from collections import OrderedDict + +import silx.utils.weakref +from silx.gui import qt +from silx.gui import icons +from silx.gui.plot.items.curve import Curve as CurveItem +from silx.gui.plot.items.histogram import Histogram as HistogramItem +from silx.gui.plot.items.image import ImageBase as ImageItem +from silx.gui.plot.items.scatter import Scatter as ScatterItem +from silx.gui.plot import stats as statsmdl +from silx.gui.widgets.TableWidget import TableWidget +from silx.gui.plot.stats.statshandler import StatsHandler, StatFormatter + +logger = logging.getLogger(__name__) + + +class StatsWidget(qt.QWidget): + """ + Widget displaying a set of :class:`Stat` to be displayed on a + :class:`StatsTable` and to be apply on items contained in the :class:`Plot` + Also contains options to: + + * compute statistics on all the data or on visible data only + * show statistics of all items or only the active one + + :param parent: Qt parent + :param plot: the plot containing items on which we want statistics. + """ + + NUMBER_FORMAT = '{0:.3f}' + + class OptionsWidget(qt.QToolBar): + + def __init__(self, parent=None): + qt.QToolBar.__init__(self, parent) + self.setIconSize(qt.QSize(16, 16)) + + action = qt.QAction(self) + action.setIcon(icons.getQIcon("stats-active-items")) + action.setText("Active items only") + action.setToolTip("Display stats for active items only.") + action.setCheckable(True) + action.setChecked(True) + self.__displayActiveItems = action + + action = qt.QAction(self) + action.setIcon(icons.getQIcon("stats-whole-items")) + action.setText("All items") + action.setToolTip("Display stats for all available items.") + action.setCheckable(True) + self.__displayWholeItems = action + + action = qt.QAction(self) + action.setIcon(icons.getQIcon("stats-visible-data")) + action.setText("Use the visible data range") + action.setToolTip("Use the visible data range.<br/>" + "If activated the data is filtered to only use" + "visible data of the plot." + "The filtering is a data sub-sampling." + "No interpolation is made to fit data to" + "boundaries.") + action.setCheckable(True) + self.__useVisibleData = action + + action = qt.QAction(self) + action.setIcon(icons.getQIcon("stats-whole-data")) + action.setText("Use the full data range") + action.setToolTip("Use the full data range.") + action.setCheckable(True) + action.setChecked(True) + self.__useWholeData = action + + self.addAction(self.__displayWholeItems) + self.addAction(self.__displayActiveItems) + self.addSeparator() + self.addAction(self.__useVisibleData) + self.addAction(self.__useWholeData) + + self.itemSelection = qt.QActionGroup(self) + self.itemSelection.setExclusive(True) + self.itemSelection.addAction(self.__displayActiveItems) + self.itemSelection.addAction(self.__displayWholeItems) + + self.dataRangeSelection = qt.QActionGroup(self) + self.dataRangeSelection.setExclusive(True) + self.dataRangeSelection.addAction(self.__useWholeData) + self.dataRangeSelection.addAction(self.__useVisibleData) + + def isActiveItemMode(self): + return self.itemSelection.checkedAction() is self.__displayActiveItems + + def isVisibleDataRangeMode(self): + return self.dataRangeSelection.checkedAction() is self.__useVisibleData + + def __init__(self, parent=None, plot=None, stats=None): + qt.QWidget.__init__(self, parent) + self.setLayout(qt.QVBoxLayout()) + self.layout().setContentsMargins(0, 0, 0, 0) + self._options = self.OptionsWidget(parent=self) + self.layout().addWidget(self._options) + self._statsTable = StatsTable(parent=self, plot=plot) + self.setStats = self._statsTable.setStats + self.setStats(stats) + + self.layout().addWidget(self._statsTable) + self.setPlot = self._statsTable.setPlot + + self._options.itemSelection.triggered.connect( + self._optSelectionChanged) + self._options.dataRangeSelection.triggered.connect( + self._optDataRangeChanged) + self._optSelectionChanged() + self._optDataRangeChanged() + + self.setDisplayOnlyActiveItem = self._statsTable.setDisplayOnlyActiveItem + self.setStatsOnVisibleData = self._statsTable.setStatsOnVisibleData + + def _optSelectionChanged(self, action=None): + self._statsTable.setDisplayOnlyActiveItem(self._options.isActiveItemMode()) + + def _optDataRangeChanged(self, action=None): + self._statsTable.setStatsOnVisibleData(self._options.isVisibleDataRangeMode()) + + +class BasicStatsWidget(StatsWidget): + """ + Widget defining a simple set of :class:`Stat` to be displayed on a + :class:`StatsWidget`. + + :param parent: Qt parent + :param plot: the plot containing items on which we want statistics. + """ + + STATS = StatsHandler(( + (statsmdl.StatMin(), StatFormatter()), + statsmdl.StatCoordMin(), + (statsmdl.StatMax(), StatFormatter()), + statsmdl.StatCoordMax(), + (('std', numpy.std), StatFormatter()), + (('mean', numpy.mean), StatFormatter()), + statsmdl.StatCOM() + )) + + def __init__(self, parent=None, plot=None): + StatsWidget.__init__(self, parent=parent, plot=plot, stats=self.STATS) + + +class StatsTable(TableWidget): + """ + TableWidget displaying for each curves contained by the Plot some + information: + + * legend + * minimal value + * maximal value + * standard deviation (std) + + :param parent: The widget's parent. + :param plot: :class:`.PlotWidget` instance on which to operate + """ + + COMPATIBLE_KINDS = { + 'curve': CurveItem, + 'image': ImageItem, + 'scatter': ScatterItem, + 'histogram': HistogramItem + } + + COMPATIBLE_ITEMS = tuple(COMPATIBLE_KINDS.values()) + + def __init__(self, parent=None, plot=None): + TableWidget.__init__(self, parent) + """Next freeID for the curve""" + self.plot = None + self._displayOnlyActItem = False + self._statsOnVisibleData = False + self._lgdAndKindToItems = {} + """Associate to a tuple(legend, kind) the items legend""" + self.callbackImage = None + self.callbackScatter = None + self.callbackCurve = None + """Associate the curve legend to his first item""" + self._statsHandler = None + self._legendsSet = [] + """list of legends actually displayed""" + self._resetColumns() + + self.setColumnCount(len(self._columns)) + self.setSelectionBehavior(qt.QAbstractItemView.SelectRows) + self.setPlot(plot) + self.setSortingEnabled(True) + + def _resetColumns(self): + self._columns_index = OrderedDict([('legend', 0), ('kind', 1)]) + self._columns = self._columns_index.keys() + self.setColumnCount(len(self._columns)) + + def setStats(self, statsHandler): + """ + + :param statsHandler: Set the statistics to be displayed and how to + format them using + :rtype: :class:`StatsHandler` + """ + _statsHandler = statsHandler + if statsHandler is None: + _statsHandler = StatsHandler(statFormatters=()) + if isinstance(_statsHandler, (list, tuple)): + _statsHandler = StatsHandler(_statsHandler) + assert isinstance(_statsHandler, StatsHandler) + self._resetColumns() + self.clear() + + for statName, stat in list(_statsHandler.stats.items()): + assert isinstance(stat, statsmdl.StatBase) + self._columns_index[statName] = len(self._columns_index) + self._statsHandler = _statsHandler + self._columns = self._columns_index.keys() + self.setColumnCount(len(self._columns)) + + self._updateItemObserve() + self._updateAllStats() + + def getStatsHandler(self): + return self._statsHandler + + def _updateAllStats(self): + for (legend, kind) in self._lgdAndKindToItems: + self._updateStats(legend, kind) + + @staticmethod + def _getKind(myItem): + if isinstance(myItem, CurveItem): + return 'curve' + elif isinstance(myItem, ImageItem): + return 'image' + elif isinstance(myItem, ScatterItem): + return 'scatter' + elif isinstance(myItem, HistogramItem): + return 'histogram' + else: + return None + + def setPlot(self, plot): + """ + Define the plot to interact with + + :param plot: the plot containing the items on which statistics are + applied + :rtype: :class:`.PlotWidget` + """ + if self.plot: + self._dealWithPlotConnection(create=False) + self.plot = plot + self.clear() + if self.plot: + self._dealWithPlotConnection(create=True) + self._updateItemObserve() + + def _updateItemObserve(self): + if self.plot: + self.clear() + if self._displayOnlyActItem is True: + activeCurve = self.plot.getActiveCurve(just_legend=False) + activeScatter = self.plot._getActiveItem(kind='scatter', + just_legend=False) + activeImage = self.plot.getActiveImage(just_legend=False) + if activeCurve: + self._addItem(activeCurve) + if activeImage: + self._addItem(activeImage) + if activeScatter: + self._addItem(activeScatter) + else: + [self._addItem(curve) for curve in self.plot.getAllCurves()] + [self._addItem(image) for image in self.plot.getAllImages()] + scatters = self.plot._getItems(kind='scatter', + just_legend=False, + withhidden=True) + [self._addItem(scatter) for scatter in scatters] + histograms = self.plot._getItems(kind='histogram', + just_legend=False, + withhidden=True) + [self._addItem(histogram) for histogram in histograms] + + def _dealWithPlotConnection(self, create=True): + """ + Manage connection to plot signals + + Note: connection on Item are managed by the _removeItem function + """ + if self.plot is None: + return + if self._displayOnlyActItem: + if create is True: + if self.callbackImage is None: + self.callbackImage = functools.partial(self._activeItemChanged, 'image') + self.callbackScatter = functools.partial(self._activeItemChanged, 'scatter') + self.callbackCurve = functools.partial(self._activeItemChanged, 'curve') + self.plot.sigActiveImageChanged.connect(self.callbackImage) + self.plot.sigActiveScatterChanged.connect(self.callbackScatter) + self.plot.sigActiveCurveChanged.connect(self.callbackCurve) + else: + if self.callbackImage is not None: + self.plot.sigActiveImageChanged.disconnect(self.callbackImage) + self.plot.sigActiveScatterChanged.disconnect(self.callbackScatter) + self.plot.sigActiveCurveChanged.disconnect(self.callbackCurve) + self.callbackImage = None + self.callbackScatter = None + self.callbackCurve = None + else: + if create is True: + self.plot.sigContentChanged.connect(self._plotContentChanged) + else: + self.plot.sigContentChanged.disconnect(self._plotContentChanged) + if create is True: + self.plot.sigPlotSignal.connect(self._zoomPlotChanged) + else: + self.plot.sigPlotSignal.disconnect(self._zoomPlotChanged) + + def clear(self): + """ + Clear all existing items + """ + lgdsAndKinds = list(self._lgdAndKindToItems.keys()) + for lgdAndKind in lgdsAndKinds: + self._removeItem(legend=lgdAndKind[0], kind=lgdAndKind[1]) + self._lgdAndKindToItems = {} + qt.QTableWidget.clear(self) + self.setRowCount(0) + + # It have to called befor3e accessing to the header items + self.setHorizontalHeaderLabels(self._columns) + + if self._statsHandler is not None: + for columnId, name in enumerate(self._columns): + item = self.horizontalHeaderItem(columnId) + if name in self._statsHandler.stats: + stat = self._statsHandler.stats[name] + text = stat.name[0].upper() + stat.name[1:] + if stat.description is not None: + tooltip = stat.description + else: + tooltip = "" + else: + text = name[0].upper() + name[1:] + tooltip = "" + item.setToolTip(tooltip) + item.setText(text) + + if hasattr(self.horizontalHeader(), 'setSectionResizeMode'): # Qt5 + self.horizontalHeader().setSectionResizeMode(qt.QHeaderView.ResizeToContents) + else: # Qt4 + self.horizontalHeader().setResizeMode(qt.QHeaderView.ResizeToContents) + self.setColumnHidden(self._columns_index['kind'], True) + + def _addItem(self, item): + assert isinstance(item, self.COMPATIBLE_ITEMS) + if (item.getLegend(), self._getKind(item)) in self._lgdAndKindToItems: + self._updateStats(item.getLegend(), self._getKind(item)) + return + + self.setRowCount(self.rowCount() + 1) + indexTable = self.rowCount() - 1 + kind = self._getKind(item) + + self._lgdAndKindToItems[(item.getLegend(), kind)] = {} + + # the get item will manage the item creation of not existing + _createItem = self._getItem + for itemName in self._columns: + _createItem(name=itemName, legend=item.getLegend(), kind=kind, + indexTable=indexTable) + + self._updateStats(legend=item.getLegend(), kind=kind) + + callback = functools.partial( + silx.utils.weakref.WeakMethodProxy(self._updateStats), + item.getLegend(), kind) + item.sigItemChanged.connect(callback) + self.setColumnHidden(self._columns_index['kind'], + item.getLegend() not in self._legendsSet) + self._legendsSet.append(item.getLegend()) + + def _getItem(self, name, legend, kind, indexTable): + if (legend, kind) not in self._lgdAndKindToItems: + self._lgdAndKindToItems[(legend, kind)] = {} + if not (name in self._lgdAndKindToItems[(legend, kind)] and + self._lgdAndKindToItems[(legend, kind)]): + if name in ('legend', 'kind'): + _item = qt.QTableWidgetItem(type=qt.QTableWidgetItem.Type) + if name == 'legend': + _item.setText(legend) + else: + assert name == 'kind' + _item.setText(kind) + else: + if self._statsHandler.formatters[name]: + _item = self._statsHandler.formatters[name].tabWidgetItemClass() + else: + _item = qt.QTableWidgetItem() + tooltip = self._statsHandler.stats[name].getToolTip(kind=kind) + if tooltip is not None: + _item.setToolTip(tooltip) + + _item.setFlags(qt.Qt.ItemIsEnabled | qt.Qt.ItemIsSelectable) + self.setItem(indexTable, self._columns_index[name], _item) + self._lgdAndKindToItems[(legend, kind)][name] = _item + + return self._lgdAndKindToItems[(legend, kind)][name] + + def _removeItem(self, legend, kind): + if (legend, kind) not in self._lgdAndKindToItems or not self.plot: + return + + self.firstItem = self._lgdAndKindToItems[(legend, kind)]['legend'] + del self._lgdAndKindToItems[(legend, kind)] + self.removeRow(self.firstItem.row()) + self._legendsSet.remove(legend) + self.setColumnHidden(self._columns_index['kind'], + legend not in self._legendsSet) + + def _updateCurrentStats(self): + for lgdAndKind in self._lgdAndKindToItems: + self._updateStats(lgdAndKind[0], lgdAndKind[1]) + + def _updateStats(self, legend, kind, event=None): + if self._statsHandler is None: + return + + assert kind in ('curve', 'image', 'scatter', 'histogram') + if kind == 'curve': + item = self.plot.getCurve(legend) + elif kind == 'image': + item = self.plot.getImage(legend) + elif kind == 'scatter': + item = self.plot.getScatter(legend) + elif kind == 'histogram': + item = self.plot.getHistogram(legend) + else: + raise ValueError('kind not managed') + + if not item or (item.getLegend(), kind) not in self._lgdAndKindToItems: + return + + assert isinstance(item, self.COMPATIBLE_ITEMS) + + statsValDict = self._statsHandler.calculate(item, self.plot, + self._statsOnVisibleData) + + lgdItem = self._lgdAndKindToItems[(item.getLegend(), kind)]['legend'] + assert lgdItem + rowStat = lgdItem.row() + + for statName, statVal in list(statsValDict.items()): + assert statName in self._lgdAndKindToItems[(item.getLegend(), kind)] + tableItem = self._getItem(name=statName, legend=item.getLegend(), + kind=kind, indexTable=rowStat) + tableItem.setText(str(statVal)) + + def currentChanged(self, current, previous): + if current.row() >= 0: + legendItem = self.item(current.row(), self._columns_index['legend']) + assert legendItem + kindItem = self.item(current.row(), self._columns_index['kind']) + kind = kindItem.text() + if kind == 'curve': + self.plot.setActiveCurve(legendItem.text()) + elif kind == 'image': + self.plot.setActiveImage(legendItem.text()) + elif kind == 'scatter': + self.plot._setActiveItem('scatter', legendItem.text()) + elif kind == 'histogram': + # active histogram not managed by the plot actually + pass + else: + raise ValueError('kind not managed') + qt.QTableWidget.currentChanged(self, current, previous) + + def setDisplayOnlyActiveItem(self, displayOnlyActItem): + """ + + :param bool displayOnlyActItem: True if we want to only show active + item + """ + if self._displayOnlyActItem == displayOnlyActItem: + return + self._displayOnlyActItem = displayOnlyActItem + self._dealWithPlotConnection(create=False) + self._updateItemObserve() + self._dealWithPlotConnection(create=True) + + def setStatsOnVisibleData(self, b): + """ + .. warning:: When visible data is activated we will process to a simple + filtering of visible data by the user. The filtering is a + simple data sub-sampling. No interpolation is made to fit + data to boundaries. + + :param bool b: True if we want to apply statistics only on visible data + """ + if self._statsOnVisibleData != b: + self._statsOnVisibleData = b + self._updateCurrentStats() + + def _activeItemChanged(self, kind): + """Callback used when plotting only the active item""" + assert kind in ('curve', 'image', 'scatter', 'histogram') + self._updateItemObserve() + + def _plotContentChanged(self, action, kind, legend): + """Callback used when plotting all the plot items""" + if kind not in ('curve', 'image', 'scatter', 'histogram'): + return + if kind == 'curve': + item = self.plot.getCurve(legend) + elif kind == 'image': + item = self.plot.getImage(legend) + elif kind == 'scatter': + item = self.plot.getScatter(legend) + elif kind == 'histogram': + item = self.plot.getHistogram(legend) + else: + raise ValueError('kind not managed') + + if action == 'add': + if item is None: + raise ValueError('Item from legend "%s" do not exists' % legend) + self._addItem(item) + elif action == 'remove': + self._removeItem(legend, kind) + + def _zoomPlotChanged(self, event): + if self._statsOnVisibleData is True: + if 'event' in event and event['event'] == 'limitsChanged': + self._updateCurrentStats() diff --git a/silx/gui/plot/_BaseMaskToolsWidget.py b/silx/gui/plot/_BaseMaskToolsWidget.py index 35a48ae..da0dbf5 100644 --- a/silx/gui/plot/_BaseMaskToolsWidget.py +++ b/silx/gui/plot/_BaseMaskToolsWidget.py @@ -1,7 +1,7 @@ # coding: utf-8 # /*########################################################################## # -# Copyright (c) 2017 European Synchrotron Radiation Facility +# Copyright (c) 2017-2018 European Synchrotron Radiation Facility # # Permission is hereby granted, free of charge, to any person obtaining a copy # of this software and associated documentation files (the "Software"), to deal @@ -29,16 +29,17 @@ from __future__ import division __authors__ = ["T. Vincent", "P. Knobel"] __license__ = "MIT" -__date__ = "02/10/2017" +__date__ = "24/04/2018" import os +import weakref import numpy from silx.gui import qt, icons from silx.gui.widgets.FloatEdit import FloatEdit -from silx.gui.plot.Colormap import Colormap -from silx.gui.plot.Colors import rgba +from silx.gui.colors import Colormap +from silx.gui.colors import rgba from .actions.mode import PanModeAction @@ -372,7 +373,7 @@ class BaseMaskToolsWidget(qt.QWidget): # as parent have to be the first argument of the widget to fit # QtDesigner need but here plot can't be None by default. assert plot is not None - self._plot = plot + self._plotRef = weakref.ref(plot) self._maskName = '__MASK_TOOLS_%d' % id(self) # Legend of the mask self._colormap = Colormap(name="", @@ -409,12 +410,21 @@ class BaseMaskToolsWidget(qt.QWidget): :param bool copy: True (default) to get a copy of the mask. If False, the returned array MUST not be modified. - :return: The array of the mask with dimension of the 'active' plot item. - If there is no active image or scatter, an empty array is - returned. - :rtype: numpy.ndarray of uint8 + :return: The mask (as an array of uint8) with dimension of + the 'active' plot item. + If there is no active image or scatter, it returns None. + :rtype: Union[numpy.ndarray,None] """ - return self._mask.getMask(copy=copy) + mask = self._mask.getMask(copy=copy) + return None if mask.size == 0 else mask + + def setSelectionMask(self, mask): + """Set the mask: Must be implemented in subclass""" + raise NotImplementedError() + + def resetSelectionMask(self): + """Reset the mask: Must be implemented in subclass""" + raise NotImplementedError() def multipleMasks(self): """Return the current mode of multiple masks support. @@ -453,7 +463,11 @@ class BaseMaskToolsWidget(qt.QWidget): @property def plot(self): """The :class:`.PlotWindow` this widget is attached to.""" - return self._plot + plot = self._plotRef() + if plot is None: + raise RuntimeError( + 'Mask widget attached to a PlotWidget that no longer exists') + return plot def setDirection(self, direction=qt.QBoxLayout.LeftToRight): """Set the direction of the layout of the widget @@ -604,8 +618,8 @@ class BaseMaskToolsWidget(qt.QWidget): self.polygonAction.setShortcut(qt.QKeySequence(qt.Qt.Key_S)) self.polygonAction.setToolTip( 'Polygon selection tool: (Un)Mask a polygonal region <b>S</b><br>' - 'Left-click to place polygon corners<br>' - 'Right-click to place the last corner') + 'Left-click to place new polygon corners<br>' + 'Left-click on first corner to close the polygon') self.polygonAction.setCheckable(True) self.polygonAction.triggered.connect(self._activePolygonMode) self.addAction(self.polygonAction) @@ -962,13 +976,20 @@ class BaseMaskToolsWidget(qt.QWidget): self.plot.setInteractiveMode('draw', shape='polygon', source=self, color=color) self._updateDrawingModeWidgets() + def _getPencilWidth(self): + """Returns the width of the pencil to use in data coordinates` + + :rtype: float + """ + return self.pencilSpinBox.value() + def _activePencilMode(self): """Handle pencil action mode triggering""" self._releaseDrawingMode() self._drawingMode = 'pencil' self.plot.sigPlotSignal.connect(self._plotDrawEvent) color = self.getCurrentMaskColor() - width = self.pencilSpinBox.value() + width = self._getPencilWidth() self.plot.setInteractiveMode( 'draw', shape='pencil', source=self, color=color, width=width) self._updateDrawingModeWidgets() diff --git a/silx/gui/plot/__init__.py b/silx/gui/plot/__init__.py index b03392d..3a141b3 100644 --- a/silx/gui/plot/__init__.py +++ b/silx/gui/plot/__init__.py @@ -1,7 +1,7 @@ # coding: utf-8 # /*########################################################################## # -# Copyright (c) 2016-2017 European Synchrotron Radiation Facility +# Copyright (c) 2016-2018 European Synchrotron Radiation Facility # # Permission is hereby granted, free of charge, to any person obtaining a copy # of this software and associated documentation files (the "Software"), to deal @@ -37,6 +37,7 @@ List of Qt widgets: - :mod:`.PlotWindow`: A :mod:`.PlotWidget` with a configurable set of tools. - :class:`.Plot1D`: A widget with tools for curves. - :class:`.Plot2D`: A widget with tools for images. +- :class:`.ScatterView`: A widget with tools for scatter plot. - :class:`.ImageView`: A widget with tools for images and a side histogram. - :class:`.StackView`: A widget with tools for a stack of images. @@ -61,8 +62,10 @@ __date__ = "03/05/2017" from .PlotWidget import PlotWidget # noqa from .PlotWindow import PlotWindow, Plot1D, Plot2D # noqa +from .items.axis import TickMode from .ImageView import ImageView # noqa from .StackView import StackView # noqa +from .ScatterView import ScatterView # noqa __all__ = ['ImageView', 'PlotWidget', 'PlotWindow', 'Plot1D', 'Plot2D', - 'StackView'] + 'StackView', 'ScatterView', 'TickMode'] diff --git a/silx/gui/plot/_utils/dtime_ticklayout.py b/silx/gui/plot/_utils/dtime_ticklayout.py new file mode 100644 index 0000000..95fc235 --- /dev/null +++ b/silx/gui/plot/_utils/dtime_ticklayout.py @@ -0,0 +1,438 @@ +# coding: utf-8 +# /*########################################################################## +# +# Copyright (c) 2014-2017 European Synchrotron Radiation Facility +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. +# +# ###########################################################################*/ +"""This module implements date-time labels layout on graph axes.""" + +from __future__ import absolute_import, division, unicode_literals + +__authors__ = ["P. Kenter"] +__license__ = "MIT" +__date__ = "04/04/2018" + + +import datetime as dt +import logging +import math +import time + +import dateutil.tz + +from dateutil.relativedelta import relativedelta + +from silx.third_party import enum +from .ticklayout import niceNumGeneric + +_logger = logging.getLogger(__name__) + + +MICROSECONDS_PER_SECOND = 1000000 +SECONDS_PER_MINUTE = 60 +SECONDS_PER_HOUR = 60 * SECONDS_PER_MINUTE +SECONDS_PER_DAY = 24 * SECONDS_PER_HOUR +SECONDS_PER_YEAR = 365.25 * SECONDS_PER_DAY +SECONDS_PER_MONTH_AVERAGE = SECONDS_PER_YEAR / 12 # Seconds per average month + + +# No dt.timezone in Python 2.7 so we use dateutil.tz.tzutc +_EPOCH = dt.datetime(1970, 1, 1, tzinfo=dateutil.tz.tzutc()) + +def timestamp(dtObj): + """ Returns POSIX timestamp of a datetime objects. + + If the dtObj object has a timestamp() method (python 3.3), this is + used. Otherwise (e.g. python 2.7) it is calculated here. + + The POSIX timestamp is a floating point value of the number of seconds + since the start of an epoch (typically 1970-01-01). For details see: + https://docs.python.org/3/library/datetime.html#datetime.datetime.timestamp + + :param datetime.datetime dtObj: date-time representation. + :return: POSIX timestamp + :rtype: float + """ + if hasattr(dtObj, "timestamp"): + return dtObj.timestamp() + else: + # Back ported from Python 3.5 + if dtObj.tzinfo is None: + return time.mktime((dtObj.year, dtObj.month, dtObj.day, + dtObj.hour, dtObj.minute, dtObj.second, + -1, -1, -1)) + dtObj.microsecond / 1e6 + else: + return (dtObj - _EPOCH).total_seconds() + + +@enum.unique +class DtUnit(enum.Enum): + YEARS = 0 + MONTHS = 1 + DAYS = 2 + HOURS = 3 + MINUTES = 4 + SECONDS = 5 + MICRO_SECONDS = 6 # a fraction of a second + + +def getDateElement(dateTime, unit): + """ Picks the date element with the unit from the dateTime + + E.g. getDateElement(datetime(1970, 5, 6), DtUnit.Day) will return 6 + + :param datetime dateTime: date/time to pick from + :param DtUnit unit: The unit describing the date element. + """ + if unit == DtUnit.YEARS: + return dateTime.year + elif unit == DtUnit.MONTHS: + return dateTime.month + elif unit == DtUnit.DAYS: + return dateTime.day + elif unit == DtUnit.HOURS: + return dateTime.hour + elif unit == DtUnit.MINUTES: + return dateTime.minute + elif unit == DtUnit.SECONDS: + return dateTime.second + elif unit == DtUnit.MICRO_SECONDS: + return dateTime.microsecond + else: + raise ValueError("Unexpected DtUnit: {}".format(unit)) + + +def setDateElement(dateTime, value, unit): + """ Returns a copy of dateTime with the tickStep unit set to value + + :param datetime.datetime: date time object + :param int value: value to set + :param DtUnit unit: unit + :return: datetime.datetime + """ + intValue = int(value) + _logger.debug("setDateElement({}, {} (int={}), {})" + .format(dateTime, value, intValue, unit)) + + year = dateTime.year + month = dateTime.month + day = dateTime.day + hour = dateTime.hour + minute = dateTime.minute + second = dateTime.second + microsecond = dateTime.microsecond + + if unit == DtUnit.YEARS: + year = intValue + elif unit == DtUnit.MONTHS: + month = intValue + elif unit == DtUnit.DAYS: + day = intValue + elif unit == DtUnit.HOURS: + hour = intValue + elif unit == DtUnit.MINUTES: + minute = intValue + elif unit == DtUnit.SECONDS: + second = intValue + elif unit == DtUnit.MICRO_SECONDS: + microsecond = intValue + else: + raise ValueError("Unexpected DtUnit: {}".format(unit)) + + _logger.debug("creating date time {}" + .format((year, month, day, hour, minute, second, microsecond))) + + return dt.datetime(year, month, day, hour, minute, second, microsecond, + tzinfo=dateTime.tzinfo) + + + +def roundToElement(dateTime, unit): + """ Returns a copy of dateTime with the + + :param datetime.datetime: date time object + :param DtUnit unit: unit + :return: datetime.datetime + """ + year = dateTime.year + month = dateTime.month + day = dateTime.day + hour = dateTime.hour + minute = dateTime.minute + second = dateTime.second + microsecond = dateTime.microsecond + + if unit.value < DtUnit.YEARS.value: + pass # Never round years + if unit.value < DtUnit.MONTHS.value: + month = 1 + if unit.value < DtUnit.DAYS.value: + day = 1 + if unit.value < DtUnit.HOURS.value: + hour = 0 + if unit.value < DtUnit.MINUTES.value: + minute = 0 + if unit.value < DtUnit.SECONDS.value: + second = 0 + if unit.value < DtUnit.MICRO_SECONDS.value: + microsecond = 0 + + result = dt.datetime(year, month, day, hour, minute, second, microsecond, + tzinfo=dateTime.tzinfo) + + return result + + +def addValueToDate(dateTime, value, unit): + """ Adds a value with unit to a dateTime. + + Uses dateutil.relativedelta.relativedelta from the standard library to do + the actual math. This function doesn't allow for fractional month or years, + so month and year are truncated to integers before adding. + + :param datetime dateTime: date time + :param float value: value to be added + :param DtUnit unit: of the value + :return: + """ + #logger.debug("addValueToDate({}, {}, {})".format(dateTime, value, unit)) + + if unit == DtUnit.YEARS: + intValue = int(value) # floats not implemented in relativeDelta(years) + return dateTime + relativedelta(years=intValue) + elif unit == DtUnit.MONTHS: + intValue = int(value) # floats not implemented in relativeDelta(mohths) + return dateTime + relativedelta(months=intValue) + elif unit == DtUnit.DAYS: + return dateTime + relativedelta(days=value) + elif unit == DtUnit.HOURS: + return dateTime + relativedelta(hours=value) + elif unit == DtUnit.MINUTES: + return dateTime + relativedelta(minutes=value) + elif unit == DtUnit.SECONDS: + return dateTime + relativedelta(seconds=value) + elif unit == DtUnit.MICRO_SECONDS: + return dateTime + relativedelta(microseconds=value) + else: + raise ValueError("Unexpected DtUnit: {}".format(unit)) + + +def bestUnit(durationInSeconds): + """ Gets the best tick spacing given a duration in seconds. + + :param durationInSeconds: time span duration in seconds + :return: DtUnit enumeration. + """ + + # Based on; https://stackoverflow.com/a/2144398/ + # If the duration is longer than two years the tick spacing will be in + # years. Else, if the duration is longer than two months, the spacing will + # be in months, Etcetera. + # + # This factor differs per unit. As a baseline it is 2, but for instance, + # for Months this needs to be higher (3>), This because it is impossible to + # have partial months so the tick spacing is always at least 1 month. A + # duration of two months would result in two ticks, which is too few. + # months would then results + + if durationInSeconds > SECONDS_PER_YEAR * 3: + return (durationInSeconds / SECONDS_PER_YEAR, DtUnit.YEARS) + elif durationInSeconds > SECONDS_PER_MONTH_AVERAGE * 3: + return (durationInSeconds / SECONDS_PER_MONTH_AVERAGE, DtUnit.MONTHS) + elif durationInSeconds > SECONDS_PER_DAY * 2: + return (durationInSeconds / SECONDS_PER_DAY, DtUnit.DAYS) + elif durationInSeconds > SECONDS_PER_HOUR * 2: + return (durationInSeconds / SECONDS_PER_HOUR, DtUnit.HOURS) + elif durationInSeconds > SECONDS_PER_MINUTE * 2: + return (durationInSeconds / SECONDS_PER_MINUTE, DtUnit.MINUTES) + elif durationInSeconds > 1 * 2: + return (durationInSeconds, DtUnit.SECONDS) + else: + return (durationInSeconds * MICROSECONDS_PER_SECOND, + DtUnit.MICRO_SECONDS) + + +NICE_DATE_VALUES = { + DtUnit.YEARS: [1, 2, 5, 10], + DtUnit.MONTHS: [1, 2, 3, 4, 6, 12], + DtUnit.DAYS: [1, 2, 3, 7, 14, 28], + DtUnit.HOURS: [1, 2, 3, 4, 6, 12], + DtUnit.MINUTES: [1, 2, 3, 5, 10, 15, 30], + DtUnit.SECONDS: [1, 2, 3, 5, 10, 15, 30], + DtUnit.MICRO_SECONDS : [1.0, 2.0, 5.0, 10.0], # floats for microsec +} + + +def bestFormatString(spacing, unit): + """ Finds the best format string given the spacing and DtUnit. + + If the spacing is a fractional number < 1 the format string will take this + into account + + :param spacing: spacing between ticks + :param DtUnit unit: + :return: Format string for use in strftime + :rtype: str + """ + isSmall = spacing < 1 + + if unit == DtUnit.YEARS: + return "%Y-m" if isSmall else "%Y" + elif unit == DtUnit.MONTHS: + return "%Y-%m-%d" if isSmall else "%Y-%m" + elif unit == DtUnit.DAYS: + return "%H:%M" if isSmall else "%Y-%m-%d" + elif unit == DtUnit.HOURS: + return "%H:%M" if isSmall else "%H:%M" + elif unit == DtUnit.MINUTES: + return "%H:%M:%S" if isSmall else "%H:%M" + elif unit == DtUnit.SECONDS: + return "%S.%f" if isSmall else "%H:%M:%S" + elif unit == DtUnit.MICRO_SECONDS: + return "%S.%f" + else: + raise ValueError("Unexpected DtUnit: {}".format(unit)) + + +def niceDateTimeElement(value, unit, isRound=False): + """ Uses the Nice Numbers algorithm to determine a nice value. + + The fractions are optimized for the unit of the date element. + """ + + niceValues = NICE_DATE_VALUES[unit] + elemValue = niceNumGeneric(value, niceValues, isRound=isRound) + + if unit == DtUnit.YEARS or unit == DtUnit.MONTHS: + elemValue = max(1, int(elemValue)) + + return elemValue + + +def findStartDate(dMin, dMax, nTicks): + """ Rounds a date down to the nearest nice number of ticks + """ + assert dMax > dMin, \ + "dMin ({}) should come before dMax ({})".format(dMin, dMax) + + delta = dMax - dMin + lengthSec = delta.total_seconds() + _logger.debug("findStartDate: {}, {} (duration = {} sec, {} days)" + .format(dMin, dMax, lengthSec, lengthSec / SECONDS_PER_DAY)) + + length, unit = bestUnit(delta.total_seconds()) + niceLength = niceDateTimeElement(length, unit) + + _logger.debug("Length: {:8.3f} {} (nice = {})" + .format(length, unit.name, niceLength)) + + niceSpacing = niceDateTimeElement(niceLength / nTicks, unit, isRound=True) + + _logger.debug("Spacing: {:8.3f} {} (nice = {})" + .format(niceLength / nTicks, unit.name, niceSpacing)) + + dVal = getDateElement(dMin, unit) + + if unit == DtUnit.MONTHS: # TODO: better rounding? + niceVal = math.floor((dVal-1) / niceSpacing) * niceSpacing + 1 + elif unit == DtUnit.DAYS: + niceVal = math.floor((dVal-1) / niceSpacing) * niceSpacing + 1 + else: + niceVal = math.floor(dVal / niceSpacing) * niceSpacing + + _logger.debug("StartValue: dVal = {}, niceVal: {} ({})" + .format(dVal, niceVal, unit.name)) + + startDate = roundToElement(dMin, unit) + startDate = setDateElement(startDate, niceVal, unit) + + return startDate, niceSpacing, unit + + +def dateRange(dMin, dMax, step, unit, includeFirstBeyond = False): + """ Generates a range of dates + + :param datetime dMin: start date + :param datetime dMax: end date + :param int step: the step size + :param DtUnit unit: the unit of the step size + :param bool includeFirstBeyond: if True the first date later than dMax will + be included in the range. If False (the default), the last generated + datetime will always be smaller than dMax. + :return: + """ + if (unit == DtUnit.YEARS or unit == DtUnit.MONTHS or + unit == DtUnit.MICRO_SECONDS): + + # Month and years will be converted to integers + assert int(step) > 0, "Integer value or tickstep is 0" + else: + assert step > 0, "tickstep is 0" + + dateTime = dMin + while dateTime < dMax: + yield dateTime + dateTime = addValueToDate(dateTime, step, unit) + + if includeFirstBeyond: + yield dateTime + + + +def calcTicks(dMin, dMax, nTicks): + """Returns tick positions. + + :param datetime.datetime dMin: The min value on the axis + :param datetime.datetime dMax: The max value on the axis + :param int nTicks: The target number of ticks. The actual number of found + ticks may differ. + :returns: (list of datetimes, DtUnit) tuple + """ + _logger.debug("Calc calcTicks({}, {}, nTicks={})" + .format(dMin, dMax, nTicks)) + + startDate, niceSpacing, unit = findStartDate(dMin, dMax, nTicks) + + result = [] + for d in dateRange(startDate, dMax, niceSpacing, unit, + includeFirstBeyond=True): + result.append(d) + + assert result[0] <= dMin, \ + "First nice date ({}) should be <= dMin {}".format(result[0], dMin) + + assert result[-1] >= dMax, \ + "Last nice date ({}) should be >= dMax {}".format(result[-1], dMax) + + return result, niceSpacing, unit + + +def calcTicksAdaptive(dMin, dMax, axisLength, tickDensity): + """ Calls calcTicks with a variable number of ticks, depending on axisLength + """ + # At least 2 ticks + nticks = max(2, int(round(tickDensity * axisLength))) + return calcTicks(dMin, dMax, nticks) + + + + + diff --git a/silx/gui/plot/_utils/test/__init__.py b/silx/gui/plot/_utils/test/__init__.py index 4a443ac..624dbcb 100644 --- a/silx/gui/plot/_utils/test/__init__.py +++ b/silx/gui/plot/_utils/test/__init__.py @@ -1,7 +1,7 @@ # coding: utf-8 # /*########################################################################## # -# Copyright (c) 2016-2017 European Synchrotron Radiation Facility +# Copyright (c) 2016-2018 European Synchrotron Radiation Facility # # Permission is hereby granted, free of charge, to any person obtaining a copy # of this software and associated documentation files (the "Software"), to deal @@ -32,10 +32,12 @@ __date__ = "18/10/2016" import unittest +from .test_dtime_ticklayout import suite as test_dtime_ticklayout_suite from .test_ticklayout import suite as test_ticklayout_suite def suite(): testsuite = unittest.TestSuite() + testsuite.addTest(test_dtime_ticklayout_suite()) testsuite.addTest(test_ticklayout_suite()) return testsuite diff --git a/silx/gui/plot/_utils/test/testColormap.py b/silx/gui/plot/_utils/test/testColormap.py new file mode 100644 index 0000000..d77fa65 --- /dev/null +++ b/silx/gui/plot/_utils/test/testColormap.py @@ -0,0 +1,648 @@ +# coding: utf-8 +# /*########################################################################## +# +# Copyright (c) 2017 European Synchrotron Radiation Facility +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. +# +# ###########################################################################*/ + +import logging +import time +import unittest + +import numpy +from PyMca5 import spslut + +from silx.image.colormap import dataToRGBAColormap + +_logger = logging.getLogger(__name__) + +# TODOs: +# what to do with max < min: as SPS LUT or also invert outside boundaries? +# test usedMin and usedMax +# benchmark + + +# common ###################################################################### + +class _TestColormap(unittest.TestCase): + # Array data types to test + FLOATING_DTYPES = numpy.float16, numpy.float32, numpy.float64 + SIGNED_DTYPES = FLOATING_DTYPES + (numpy.int8, numpy.int16, + numpy.int32, numpy.int64) + UNSIGNED_DTYPES = numpy.uint8, numpy.uint16, numpy.uint32, numpy.uint64 + DTYPES = SIGNED_DTYPES + UNSIGNED_DTYPES + + # Array sizes to test + SIZES = 2, 10, 256, 1024 # , 2048, 4096 + + # Colormaps definitions + _LUT_RED_256 = numpy.zeros((256, 4), dtype=numpy.uint8) + _LUT_RED_256[:, 0] = numpy.arange(256, dtype=numpy.uint8) + _LUT_RED_256[:, 3] = 255 + + _LUT_RGB_3 = numpy.array(((255, 0, 0, 255), + (0, 255, 0, 255), + (0, 0, 255, 255)), dtype=numpy.uint8) + + _LUT_RGB_768 = numpy.zeros((768, 4), dtype=numpy.uint8) + _LUT_RGB_768[0:256, 0] = numpy.arange(256, dtype=numpy.uint8) + _LUT_RGB_768[256:512, 1] = numpy.arange(256, dtype=numpy.uint8) + _LUT_RGB_768[512:768, 1] = numpy.arange(256, dtype=numpy.uint8) + _LUT_RGB_768[:, 3] = 255 + + COLORMAPS = { + 'red 256': _LUT_RED_256, + 'rgb 3': _LUT_RGB_3, + 'rgb 768': _LUT_RGB_768, + } + + @staticmethod + def _log(*args): + """Logging used by test for debugging.""" + _logger.debug(str(args)) + + @staticmethod + def buildControlPixmap(data, colormap, start=None, end=None, + isLog10=False): + """Generate a pixmap used to test C pixmap.""" + if isLog10: # Convert to log + if start is None: + posValue = data[numpy.nonzero(data > 0)] + if posValue.size != 0: + start = numpy.nanmin(posValue) + else: + start = 0. + + if end is None: + end = numpy.nanmax(data) + + start = 0. if start <= 0. else numpy.log10(start, + dtype=numpy.float64) + end = 0. if end <= 0. else numpy.log10(end, + dtype=numpy.float64) + + data = numpy.log10(data, dtype=numpy.float64) + else: + if start is None: + start = numpy.nanmin(data) + if end is None: + end = numpy.nanmax(data) + + start, end = float(start), float(end) + min_, max_ = min(start, end), max(start, end) + + if start == end: + indices = numpy.asarray((len(colormap) - 1) * (data >= max_), + dtype=numpy.int) + else: + clipData = numpy.clip(data, min_, max_) # Clip first avoid overflow + scale = len(colormap) / (end - start) + normData = scale * (numpy.asarray(clipData, numpy.float64) - start) + + # Clip again to makes sure <= len(colormap) - 1 + indices = numpy.asarray(numpy.clip(normData, + 0, len(colormap) - 1), + dtype=numpy.uint32) + + pixmap = numpy.take(colormap, indices, axis=0) + pixmap.shape = data.shape + (4,) + return numpy.ascontiguousarray(pixmap) + + @staticmethod + def buildSPSLUTRedPixmap(data, start=None, end=None, isLog10=False): + """Generate a pixmap with SPS LUT. + Only supports red colormap with 256 colors. + """ + colormap = spslut.RED + mapping = spslut.LOG if isLog10 else spslut.LINEAR + + if start is None and end is None: + autoScale = 1 + start, end = 0, 1 + else: + autoScale = 0 + if start is None: + start = data.min() + if end is None: + end = data.max() + + pixmap, size, minMax = spslut.transform(data, + (1, 0), + (mapping, 3.0), + 'RGBX', + colormap, + autoScale, + (start, end), + (0, 255), + 1) + pixmap.shape = data.shape[0], data.shape[1], 4 + + return pixmap + + def _testColormap(self, data, colormap, start, end, control=None, + isLog10=False, nanColor=None): + """Test pixmap built with C code against SPS LUT if possible, + else against Python control code.""" + startTime = time.time() + pixmap = dataToRGBAColormap(data, + colormap, + start, + end, + isLog10, + nanColor) + duration = time.time() - startTime + + # Compare with result + controlType = 'array' + if control is None: + startTime = time.time() + + # Compare with SPS LUT if possible + if (colormap.shape == self.COLORMAPS['red 256'].shape and + numpy.all(numpy.equal(colormap, self.COLORMAPS['red 256'])) and + data.size % 2 == 0 and + data.dtype in (numpy.float32, numpy.float64)): + # Only works with red colormap and even size + # as it needs 2D data + if len(data.shape) == 1: + data.shape = data.size // 2, -1 + pixmap.shape = data.shape + (4,) + control = self.buildSPSLUTRedPixmap(data, start, end, isLog10) + controlType = 'SPS LUT' + + # Compare with python test implementation + else: + control = self.buildControlPixmap(data, colormap, start, end, + isLog10) + controlType = 'Python control code' + + controlDuration = time.time() - startTime + if duration >= controlDuration: + self._log('duration', duration, 'control', controlDuration) + # Allows duration to be 20% over SPS LUT duration + # self.assertTrue(duration < 1.2 * controlDuration) + + difference = numpy.fabs(numpy.asarray(pixmap, dtype=numpy.float64) - + numpy.asarray(control, dtype=numpy.float64)) + if numpy.any(difference != 0.0): + self._log('control', controlType) + self._log('data', data) + self._log('pixmap', pixmap) + self._log('control', control) + self._log('errors', numpy.ravel(difference)) + self._log('errors', difference[difference != 0]) + self._log('in pixmap', pixmap[difference != 0]) + self._log('in control', control[difference != 0]) + self._log('Max error', difference.max()) + + # Allows a difference of 1 per channel + self.assertTrue(numpy.all(difference <= 1.0)) + + return duration + + +# TestColormap ################################################################ + +class TestColormap(_TestColormap): + """Test common limit case for colormap in C with both linear and log mode. + + Test with different: data types, sizes, colormaps (with different sizes), + mapping range. + """ + + def testNoData(self): + """Test pixmap generation with empty data.""" + self._log("TestColormap.testNoData") + cmapName = 'red 256' + colormap = self.COLORMAPS[cmapName] + + for dtype in self.DTYPES: + for isLog10 in (False, True): + data = numpy.array((), dtype=dtype) + result = numpy.array((), dtype=numpy.uint8) + result.shape = 0, 4 + duration = self._testColormap(data, colormap, + None, None, result, isLog10) + self._log('No data', 'red 256', dtype, len(data), (None, None), + 'isLog10:', isLog10, duration) + + def testNaN(self): + """Test pixmap generation with NaN values and no NaN color.""" + self._log("TestColormap.testNaN") + cmapName = 'red 256' + colormap = self.COLORMAPS[cmapName] + + for dtype in self.FLOATING_DTYPES: + for isLog10 in (False, True): + # All NaNs + data = numpy.array((float('nan'),) * 4, dtype=dtype) + result = numpy.array(((0, 0, 0, 255), + (0, 0, 0, 255), + (0, 0, 0, 255), + (0, 0, 0, 255)), dtype=numpy.uint8) + duration = self._testColormap(data, colormap, + None, None, result, isLog10) + self._log('All NaNs', 'red 256', dtype, len(data), + (None, None), 'isLog10:', isLog10, duration) + + # Some NaNs + data = numpy.array((1., float('nan'), 0., float('nan')), + dtype=dtype) + result = numpy.array(((255, 0, 0, 255), + (0, 0, 0, 255), + (0, 0, 0, 255), + (0, 0, 0, 255)), dtype=numpy.uint8) + duration = self._testColormap(data, colormap, + None, None, result, isLog10) + self._log('Some NaNs', 'red 256', dtype, len(data), + (None, None), 'isLog10:', isLog10, duration) + + def testNaNWithColor(self): + """Test pixmap generation with NaN values with a NaN color.""" + self._log("TestColormap.testNaNWithColor") + cmapName = 'red 256' + colormap = self.COLORMAPS[cmapName] + + for dtype in self.FLOATING_DTYPES: + for isLog10 in (False, True): + # All NaNs + data = numpy.array((float('nan'),) * 4, dtype=dtype) + result = numpy.array(((128, 128, 128, 255), + (128, 128, 128, 255), + (128, 128, 128, 255), + (128, 128, 128, 255)), dtype=numpy.uint8) + duration = self._testColormap(data, colormap, + None, None, result, isLog10, + nanColor=(128, 128, 128, 255)) + self._log('All NaNs', 'red 256', dtype, len(data), + (None, None), 'isLog10:', isLog10, duration) + + # Some NaNs + data = numpy.array((1., float('nan'), 0., float('nan')), + dtype=dtype) + result = numpy.array(((255, 0, 0, 255), + (128, 128, 128, 255), + (0, 0, 0, 255), + (128, 128, 128, 255)), dtype=numpy.uint8) + duration = self._testColormap(data, colormap, + None, None, result, isLog10, + nanColor=(128, 128, 128, 255)) + self._log('Some NaNs', 'red 256', dtype, len(data), + (None, None), 'isLog10:', isLog10, duration) + + +# TestLinearColormap ########################################################## + +class TestLinearColormap(_TestColormap): + """Test fill pixmap with colormap in C with linear mode. + + Test with different: data types, sizes, colormaps (with different sizes), + mapping range. + """ + + # Colormap ranges to map + RANGES = (None, None), (1, 10) + + def test1DData(self): + """Test pixmap generation for 1D data of different size and types.""" + self._log("TestLinearColormap.test1DData") + for cmapName, colormap in self.COLORMAPS.items(): + for size in self.SIZES: + for dtype in self.DTYPES: + for start, end in self.RANGES: + # Increasing values + data = numpy.arange(size, dtype=dtype) + duration = self._testColormap(data, colormap, + start, end) + + self._log('1D', cmapName, dtype, size, (start, end), + duration) + + # Reverse order + data = data[::-1] + duration = self._testColormap(data, colormap, + start, end) + + self._log('1D', cmapName, dtype, size, (start, end), + duration) + + def test2DData(self): + """Test pixmap generation for 2D data of different size and types.""" + self._log("TestLinearColormap.test2DData") + for cmapName, colormap in self.COLORMAPS.items(): + for size in self.SIZES: + for dtype in self.DTYPES: + for start, end in self.RANGES: + # Increasing values + data = numpy.arange(size * size, dtype=dtype) + data = numpy.nan_to_num(data) + data.shape = size, size + duration = self._testColormap(data, colormap, + start, end) + + self._log('2D', cmapName, dtype, size, (start, end), + duration) + + # Reverse order + data = data[::-1, ::-1] + duration = self._testColormap(data, colormap, + start, end) + + self._log('2D', cmapName, dtype, size, (start, end), + duration) + + def testInf(self): + """Test pixmap generation with Inf values.""" + self._log("TestLinearColormap.testInf") + + for dtype in self.FLOATING_DTYPES: + # All positive Inf + data = numpy.array((float('inf'),) * 4, dtype=dtype) + result = numpy.array(((255, 0, 0, 255), + (255, 0, 0, 255), + (255, 0, 0, 255), + (255, 0, 0, 255)), dtype=numpy.uint8) + duration = self._testColormap(data, self.COLORMAPS['red 256'], + None, None, result) + self._log('All +Inf', 'red 256', dtype, len(data), (None, None), + duration) + + # All negative Inf + data = numpy.array((float('-inf'),) * 4, dtype=dtype) + result = numpy.array(((255, 0, 0, 255), + (255, 0, 0, 255), + (255, 0, 0, 255), + (255, 0, 0, 255)), dtype=numpy.uint8) + duration = self._testColormap(data, self.COLORMAPS['red 256'], + None, None, result) + self._log('All -Inf', 'red 256', dtype, len(data), (None, None), + duration) + + # All +/-Inf + data = numpy.array((float('inf'), float('-inf'), + float('-inf'), float('inf')), dtype=dtype) + result = numpy.array(((255, 0, 0, 255), + (0, 0, 0, 255), + (0, 0, 0, 255), + (255, 0, 0, 255)), dtype=numpy.uint8) + duration = self._testColormap(data, self.COLORMAPS['red 256'], + None, None, result) + self._log('All +/-Inf', 'red 256', dtype, len(data), (None, None), + duration) + + # Some +/-Inf + data = numpy.array((float('inf'), 0., float('-inf'), -10.), + dtype=dtype) + result = numpy.array(((255, 0, 0, 255), + (0, 0, 0, 255), + (0, 0, 0, 255), + (0, 0, 0, 255)), dtype=numpy.uint8) + duration = self._testColormap(data, self.COLORMAPS['red 256'], + None, None, + result) # Seg Fault with SPS + self._log('Some +/-Inf', 'red 256', dtype, len(data), (None, None), + duration) + + @unittest.skip("Not for reproductible tests") + def test1DDataRandom(self): + """Test pixmap generation for 1D data of different size and types.""" + self._log("TestLinearColormap.test1DDataRandom") + for cmapName, colormap in self.COLORMAPS.items(): + for size in self.SIZES: + for dtype in self.DTYPES: + for start, end in self.RANGES: + try: + dtypeMax = numpy.iinfo(dtype).max + except ValueError: + dtypeMax = numpy.finfo(dtype).max + data = numpy.asarray(numpy.random.rand(size) * dtypeMax, + dtype=dtype) + duration = self._testColormap(data, colormap, + start, end) + + self._log('1D Random', cmapName, dtype, size, + (start, end), duration) + + +# TestLog10Colormap ########################################################### + +class TestLog10Colormap(_TestColormap): + """Test fill pixmap with colormap in C with log mode. + + Test with different: data types, sizes, colormaps (with different sizes), + mapping range. + """ + # Colormap ranges to map + RANGES = (None, None), (1, 10) # , (10, 1) + + def test1DDataAllPositive(self): + """Test pixmap generation for all positive 1D data.""" + self._log("TestLog10Colormap.test1DDataAllPositive") + for cmapName, colormap in self.COLORMAPS.items(): + for size in self.SIZES: + for dtype in self.DTYPES: + for start, end in self.RANGES: + # Increasing values + data = numpy.arange(size, dtype=dtype) + 1 + duration = self._testColormap(data, colormap, + start, end, + isLog10=True) + + self._log('1D', cmapName, dtype, size, (start, end), + duration) + + # Reverse order + data = data[::-1] + duration = self._testColormap(data, colormap, + start, end, + isLog10=True) + + self._log('1D', cmapName, dtype, size, (start, end), + duration) + + def test2DDataAllPositive(self): + """Test pixmap generation for all positive 2D data.""" + self._log("TestLog10Colormap.test2DDataAllPositive") + for cmapName, colormap in self.COLORMAPS.items(): + for size in self.SIZES: + for dtype in self.DTYPES: + for start, end in self.RANGES: + # Increasing values + data = numpy.arange(size * size, dtype=dtype) + 1 + data = numpy.nan_to_num(data) + data.shape = size, size + duration = self._testColormap(data, colormap, + start, end, + isLog10=True) + + self._log('2D', cmapName, dtype, size, (start, end), + duration) + + # Reverse order + data = data[::-1, ::-1] + duration = self._testColormap(data, colormap, + start, end, + isLog10=True) + + self._log('2D', cmapName, dtype, size, (start, end), + duration) + + def testAllNegative(self): + """Test pixmap generation for all negative 1D data.""" + self._log("TestLog10Colormap.testAllNegative") + for cmapName, colormap in self.COLORMAPS.items(): + for size in self.SIZES: + for dtype in self.SIGNED_DTYPES: + for start, end in self.RANGES: + # Increasing values + data = numpy.arange(-size, 0, dtype=dtype) + duration = self._testColormap(data, colormap, + start, end, + isLog10=True) + + self._log('1D', cmapName, dtype, size, (start, end), + duration) + + # Reverse order + data = data[::-1] + duration = self._testColormap(data, colormap, + start, end, + isLog10=True) + + self._log('1D', cmapName, dtype, size, (start, end), + duration) + + def testCrossingZero(self): + """Test pixmap generation for 1D data with negative and zero.""" + self._log("TestLog10Colormap.testCrossingZero") + for cmapName, colormap in self.COLORMAPS.items(): + for size in self.SIZES: + for dtype in self.SIGNED_DTYPES: + for start, end in self.RANGES: + # Increasing values + data = numpy.arange(-size/2, size/2 + 1, dtype=dtype) + duration = self._testColormap(data, colormap, + start, end, + isLog10=True) + + self._log('1D', cmapName, dtype, size, (start, end), + duration) + + # Reverse order + data = data[::-1] + duration = self._testColormap(data, colormap, + start, end, + isLog10=True) + + self._log('1D', cmapName, dtype, size, (start, end), + duration) + + @unittest.skip("Not for reproductible tests") + def test1DDataRandom(self): + """Test pixmap generation for 1D data of different size and types.""" + self._log("TestLog10Colormap.test1DDataRandom") + for cmapName, colormap in self.COLORMAPS.items(): + for size in self.SIZES: + for dtype in self.DTYPES: + for start, end in self.RANGES: + try: + dtypeMax = numpy.iinfo(dtype).max + dtypeMin = numpy.iinfo(dtype).min + except ValueError: + dtypeMax = numpy.finfo(dtype).max + dtypeMin = numpy.finfo(dtype).min + if dtypeMin < 0: + data = numpy.asarray(-dtypeMax/2. + + numpy.random.rand(size) * dtypeMax, + dtype=dtype) + else: + data = numpy.asarray(numpy.random.rand(size) * dtypeMax, + dtype=dtype) + + duration = self._testColormap(data, colormap, + start, end, + isLog10=True) + + self._log('1D Random', cmapName, dtype, size, + (start, end), duration) + + def testInf(self): + """Test pixmap generation with Inf values.""" + self._log("TestLog10Colormap.testInf") + + for dtype in self.FLOATING_DTYPES: + # All positive Inf + data = numpy.array((float('inf'),) * 4, dtype=dtype) + result = numpy.array(((255, 0, 0, 255), + (255, 0, 0, 255), + (255, 0, 0, 255), + (255, 0, 0, 255)), dtype=numpy.uint8) + duration = self._testColormap(data, self.COLORMAPS['red 256'], + None, None, result, isLog10=True) + self._log('All +Inf', 'red 256', dtype, len(data), (None, None), + duration) + + # All negative Inf + data = numpy.array((float('-inf'),) * 4, dtype=dtype) + result = numpy.array(((0, 0, 0, 255), + (0, 0, 0, 255), + (0, 0, 0, 255), + (0, 0, 0, 255)), dtype=numpy.uint8) + duration = self._testColormap(data, self.COLORMAPS['red 256'], + None, None, result, isLog10=True) + self._log('All -Inf', 'red 256', dtype, len(data), (None, None), + duration) + + # All +/-Inf + data = numpy.array((float('inf'), float('-inf'), + float('-inf'), float('inf')), dtype=dtype) + result = numpy.array(((255, 0, 0, 255), + (0, 0, 0, 255), + (0, 0, 0, 255), + (255, 0, 0, 255)), dtype=numpy.uint8) + duration = self._testColormap(data, self.COLORMAPS['red 256'], + None, None, result, isLog10=True) + self._log('All +/-Inf', 'red 256', dtype, len(data), (None, None), + duration) + + # Some +/-Inf + data = numpy.array((float('inf'), 0., float('-inf'), -10.), + dtype=dtype) + result = numpy.array(((255, 0, 0, 255), + (0, 0, 0, 255), + (0, 0, 0, 255), + (0, 0, 0, 255)), dtype=numpy.uint8) + duration = self._testColormap(data, self.COLORMAPS['red 256'], + None, None, result, isLog10=True) + self._log('Some +/-Inf', 'red 256', dtype, len(data), (None, None), + duration) + + +def suite(): + testSuite = unittest.TestSuite() + for testClass in (TestColormap, TestLinearColormap): # , TestLog10Colormap): + testSuite.addTest( + unittest.defaultTestLoader.loadTestsFromTestCase(testClass)) + return testSuite + + +if __name__ == '__main__': + unittest.main(defaultTest='suite') diff --git a/silx/gui/plot/_utils/test/test_dtime_ticklayout.py b/silx/gui/plot/_utils/test/test_dtime_ticklayout.py new file mode 100644 index 0000000..2b87148 --- /dev/null +++ b/silx/gui/plot/_utils/test/test_dtime_ticklayout.py @@ -0,0 +1,93 @@ +# coding: utf-8 +# /*########################################################################## +# +# Copyright (c) 2015-2018 European Synchrotron Radiation Facility +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. +# +# ###########################################################################*/ + +from __future__ import absolute_import, division, unicode_literals + +__authors__ = ["P. Kenter"] +__license__ = "MIT" +__date__ = "06/04/2018" + + +import datetime as dt +import unittest + + +from silx.gui.plot._utils.dtime_ticklayout import ( + calcTicks, DtUnit, SECONDS_PER_YEAR) + + +class DtTestTickLayout(unittest.TestCase): + """Test ticks layout algorithms""" + + def testSmallMonthlySpacing(self): + """ Tests a range that did result in a spacing of less than 1 month. + It is impossible to add fractional month so the unit must be in days + """ + from dateutil import parser + d1 = parser.parse("2017-01-03 13:15:06.000044") + d2 = parser.parse("2017-03-08 09:16:16.307584") + _ticks, _units, spacing = calcTicks(d1, d2, nTicks=4) + + self.assertEqual(spacing, DtUnit.DAYS) + + + def testNoCrash(self): + """ Creates many combinations of and number-of-ticks and end-dates; + tests that it doesn't give an exception and returns a reasonable number + of ticks. + """ + d1 = dt.datetime(2017, 1, 3, 13, 15, 6, 44) + + value = 100e-6 # Start at 100 micro sec range. + + while value <= 200 * SECONDS_PER_YEAR: + + d2 = d1 + dt.timedelta(microseconds=value*1e6) # end date range + + for numTicks in range(2, 12): + ticks, _, _ = calcTicks(d1, d2, numTicks) + + margin = 2.5 + self.assertTrue( + numTicks/margin <= len(ticks) <= numTicks*margin, + "Condition {} <= {} <= {} failed for # ticks={} and d2={}:" + .format(numTicks/margin, len(ticks), numTicks * margin, + numTicks, d2)) + + value = value * 1.5 # let date period grow exponentially + + + + + +def suite(): + testsuite = unittest.TestSuite() + testsuite.addTest( + unittest.defaultTestLoader.loadTestsFromTestCase(DtTestTickLayout)) + return testsuite + + +if __name__ == '__main__': + unittest.main() diff --git a/silx/gui/plot/_utils/ticklayout.py b/silx/gui/plot/_utils/ticklayout.py index 6e9f654..c9fd3e6 100644 --- a/silx/gui/plot/_utils/ticklayout.py +++ b/silx/gui/plot/_utils/ticklayout.py @@ -1,7 +1,7 @@ # coding: utf-8 # /*########################################################################## # -# Copyright (c) 2014-2017 European Synchrotron Radiation Facility +# Copyright (c) 2014-2018 European Synchrotron Radiation Facility # # Permission is hereby granted, free of charge, to any person obtaining a copy # of this software and associated documentation files (the "Software"), to deal @@ -51,28 +51,65 @@ def numberOfDigits(tickSpacing): # Nice Numbers ################################################################ -def _niceNum(value, isRound=False): - expvalue = math.floor(math.log10(value)) - frac = value/pow(10., expvalue) - if isRound: - if frac < 1.5: - nicefrac = 1. - elif frac < 3.: - nicefrac = 2. - elif frac < 7.: - nicefrac = 5. - else: - nicefrac = 10. +# This is the original niceNum implementation. For the date time ticks a more +# generic implementation was needed. +# +# def _niceNum(value, isRound=False): +# expvalue = math.floor(math.log10(value)) +# frac = value/pow(10., expvalue) +# if isRound: +# if frac < 1.5: +# nicefrac = 1. +# elif frac < 3.: # In niceNumGeneric this is (2+5)/2 = 3.5 +# nicefrac = 2. +# elif frac < 7.: +# nicefrac = 5. # In niceNumGeneric this is (5+10)/2 = 7.5 +# else: +# nicefrac = 10. +# else: +# if frac <= 1.: +# nicefrac = 1. +# elif frac <= 2.: +# nicefrac = 2. +# elif frac <= 5.: +# nicefrac = 5. +# else: +# nicefrac = 10. +# return nicefrac * pow(10., expvalue) + + +def niceNumGeneric(value, niceFractions=None, isRound=False): + """ A more generic implementation of the _niceNum function + + Allows the user to specify the fractions instead of using a hardcoded + list of [1, 2, 5, 10.0]. + """ + if value == 0: + return value + + if niceFractions is None: # Use default values + niceFractions = 1., 2., 5., 10. + roundFractions = (1.5, 3., 7., 10.) if isRound else niceFractions + else: - if frac <= 1.: - nicefrac = 1. - elif frac <= 2.: - nicefrac = 2. - elif frac <= 5.: - nicefrac = 5. - else: - nicefrac = 10. - return nicefrac * pow(10., expvalue) + roundFractions = list(niceFractions) + if isRound: + # Take the average with the next element. The last remains the same. + for i in range(len(roundFractions) - 1): + roundFractions[i] = (niceFractions[i] + niceFractions[i+1]) / 2 + + highest = niceFractions[-1] + value = float(value) + + expvalue = math.floor(math.log(value, highest)) + frac = value / pow(highest, expvalue) + + for niceFrac, roundFrac in zip(niceFractions, roundFractions): + if frac <= roundFrac: + return niceFrac * pow(highest, expvalue) + + # should not come here + assert False, "should not come here" def niceNumbers(vMin, vMax, nTicks=5): @@ -89,8 +126,8 @@ def niceNumbers(vMin, vMax, nTicks=5): number of fractional digit to show :rtype: tuple """ - vrange = _niceNum(vMax - vMin, False) - spacing = _niceNum(vrange / nTicks, True) + vrange = niceNumGeneric(vMax - vMin, isRound=False) + spacing = niceNumGeneric(vrange / nTicks, isRound=True) graphmin = math.floor(vMin / spacing) * spacing graphmax = math.ceil(vMax / spacing) * spacing nfrac = numberOfDigits(spacing) diff --git a/silx/gui/plot/actions/control.py b/silx/gui/plot/actions/control.py index ac6dc2f..6e08f21 100644 --- a/silx/gui/plot/actions/control.py +++ b/silx/gui/plot/actions/control.py @@ -50,12 +50,11 @@ from __future__ import division __authors__ = ["V.A. Sole", "T. Vincent", "P. Knobel"] __license__ = "MIT" -__date__ = "15/02/2018" +__date__ = "24/04/2018" from . import PlotAction import logging from silx.gui.plot import items -from silx.gui.plot.ColormapDialog import ColormapDialog from silx.gui.plot._utils import applyZoomToPlot as _applyZoomToPlot from silx.gui import qt from silx.gui import icons @@ -328,6 +327,7 @@ class ColormapAction(PlotAction): triggered=self._actionTriggered, checkable=True, parent=parent) self.plot.sigActiveImageChanged.connect(self._updateColormap) + self.plot.sigActiveScatterChanged.connect(self._updateColormap) def setColorDialog(self, colorDialog): """Set a specific color dialog instead of using the default dialog.""" @@ -344,6 +344,7 @@ class ColormapAction(PlotAction): :parent QWidget parent: Parent of the new colormap :rtype: ColormapDialog """ + from silx.gui.dialog.ColormapDialog import ColormapDialog dialog = ColormapDialog(parent=parent) dialog.setModal(False) return dialog @@ -393,10 +394,19 @@ class ColormapAction(PlotAction): else: # No active image or active image is RGBA, - # set dialog from default info - colormap = self.plot.getDefaultColormap() - # Reset histogram and range if any - self._dialog.setData(None) + # Check for active scatter plot + scatter = self.plot._getActiveItem(kind='scatter') + if scatter is not None: + colormap = scatter.getColormap() + data = scatter.getValueData(copy=False) + self._dialog.setData(data) + + else: + # No active data image nor scatter, + # set dialog from default info + colormap = self.plot.getDefaultColormap() + # Reset histogram and range if any + self._dialog.setData(None) self._dialog.setColormap(colormap) @@ -408,7 +418,7 @@ class ColorBarAction(PlotAction): :param parent: See :class:`QAction` """ def __init__(self, plot, parent=None): - self._dialog = None # To store an instance of ColormapDialog + self._dialog = None # To store an instance of ColorBar super(ColorBarAction, self).__init__( plot, icon='colorbar', text='Colorbar', tooltip="Show/Hide the colorbar", diff --git a/silx/gui/plot/actions/histogram.py b/silx/gui/plot/actions/histogram.py index 40ef873..d6e3269 100644 --- a/silx/gui/plot/actions/histogram.py +++ b/silx/gui/plot/actions/histogram.py @@ -34,7 +34,7 @@ The following QAction are available: from __future__ import division __authors__ = ["V.A. Sole", "T. Vincent", "P. Knobel"] -__date__ = "27/06/2017" +__date__ = "30/04/2018" __license__ = "MIT" from . import PlotAction @@ -129,7 +129,7 @@ class PixelIntensitiesHistoAction(PlotAction): edges=edges, legend='pixel intensity', fill=True, - color='red') + color='#66aad7') plot.resetZoom() def eventFilter(self, qobject, event): diff --git a/silx/gui/plot/actions/io.py b/silx/gui/plot/actions/io.py index d6d5909..ac06942 100644 --- a/silx/gui/plot/actions/io.py +++ b/silx/gui/plot/actions/io.py @@ -44,13 +44,16 @@ from silx.io.utils import save1D, savespec from silx.io.nxdata import save_NXdata import logging import sys +import os.path from collections import OrderedDict import traceback import numpy -from silx.gui import qt +from silx.utils.deprecation import deprecated +from silx.gui import qt, printer +from silx.gui.dialog.GroupDialog import GroupDialog from silx.third_party.EdfFile import EdfFile from silx.third_party.TiffIO import TiffIO -from silx.gui._utils import convertArrayToQImage +from ...utils._image import convertArrayToQImage if sys.version_info[0] == 3: from io import BytesIO else: @@ -60,10 +63,26 @@ else: _logger = logging.getLogger(__name__) -_NEXUS_HDF5_EXT = [".nx5", ".nxs", ".hdf", ".hdf5", ".cxi", ".h5"] +_NEXUS_HDF5_EXT = [".h5", ".nx5", ".nxs", ".hdf", ".hdf5", ".cxi"] _NEXUS_HDF5_EXT_STR = ' '.join(['*' + ext for ext in _NEXUS_HDF5_EXT]) +def selectOutputGroup(h5filename): + """Open a dialog to prompt the user to select a group in + which to output data. + + :param str h5filename: name of an existing HDF5 file + :rtype: str + :return: Name of output group, or None if the dialog was cancelled + """ + dialog = GroupDialog() + dialog.addFile(h5filename) + dialog.setWindowTitle("Select an output group") + if not dialog.exec_(): + return None + return dialog.getSelectedDataUrl().data_path() + + class SaveAction(PlotAction): """QAction for saving Plot content. @@ -72,12 +91,11 @@ class SaveAction(PlotAction): :param plot: :class:`.PlotWidget` instance on which to operate. :param parent: See :class:`QAction`. """ - # TODO find a way to make the filter list selectable and extensible SNAPSHOT_FILTER_SVG = 'Plot Snapshot as SVG (*.svg)' SNAPSHOT_FILTER_PNG = 'Plot Snapshot as PNG (*.png)' - SNAPSHOT_FILTERS = (SNAPSHOT_FILTER_PNG, SNAPSHOT_FILTER_SVG) + DEFAULT_ALL_FILTERS = (SNAPSHOT_FILTER_PNG, SNAPSHOT_FILTER_SVG) # Dict of curve filters with CSV-like format # Using ordered dict to guarantee filters order @@ -101,10 +119,10 @@ class SaveAction(PlotAction): CURVE_FILTER_NXDATA = 'Curve as NXdata (%s)' % _NEXUS_HDF5_EXT_STR - CURVE_FILTERS = list(CURVE_FILTERS_TXT.keys()) + [CURVE_FILTER_NPY, - CURVE_FILTER_NXDATA] + DEFAULT_CURVE_FILTERS = list(CURVE_FILTERS_TXT.keys()) + [ + CURVE_FILTER_NPY, CURVE_FILTER_NXDATA] - ALL_CURVES_FILTERS = ("All curves as SpecFile (*.dat)", ) + DEFAULT_ALL_CURVES_FILTERS = ("All curves as SpecFile (*.dat)",) IMAGE_FILTER_EDF = 'Image data as EDF (*.edf)' IMAGE_FILTER_TIFF = 'Image data as TIFF (*.tif)' @@ -114,23 +132,53 @@ class SaveAction(PlotAction): IMAGE_FILTER_CSV_SEMICOLON = 'Image data as ;-separated CSV (*.csv)' IMAGE_FILTER_CSV_TAB = 'Image data as tab-separated CSV (*.csv)' IMAGE_FILTER_RGB_PNG = 'Image as PNG (*.png)' - IMAGE_FILTER_RGB_TIFF = 'Image as TIFF (*.tif)' IMAGE_FILTER_NXDATA = 'Image as NXdata (%s)' % _NEXUS_HDF5_EXT_STR - IMAGE_FILTERS = (IMAGE_FILTER_EDF, - IMAGE_FILTER_TIFF, - IMAGE_FILTER_NUMPY, - IMAGE_FILTER_ASCII, - IMAGE_FILTER_CSV_COMMA, - IMAGE_FILTER_CSV_SEMICOLON, - IMAGE_FILTER_CSV_TAB, - IMAGE_FILTER_RGB_PNG, - IMAGE_FILTER_RGB_TIFF, - IMAGE_FILTER_NXDATA) + DEFAULT_IMAGE_FILTERS = (IMAGE_FILTER_EDF, + IMAGE_FILTER_TIFF, + IMAGE_FILTER_NUMPY, + IMAGE_FILTER_ASCII, + IMAGE_FILTER_CSV_COMMA, + IMAGE_FILTER_CSV_SEMICOLON, + IMAGE_FILTER_CSV_TAB, + IMAGE_FILTER_RGB_PNG, + IMAGE_FILTER_NXDATA) SCATTER_FILTER_NXDATA = 'Scatter as NXdata (%s)' % _NEXUS_HDF5_EXT_STR - SCATTER_FILTERS = (SCATTER_FILTER_NXDATA, ) + DEFAULT_SCATTER_FILTERS = (SCATTER_FILTER_NXDATA,) + + # filters for which we don't want an "overwrite existing file" warning + DEFAULT_APPEND_FILTERS = (CURVE_FILTER_NXDATA, IMAGE_FILTER_NXDATA, + SCATTER_FILTER_NXDATA) def __init__(self, plot, parent=None): + self._filters = { + 'all': OrderedDict(), + 'curve': OrderedDict(), + 'curves': OrderedDict(), + 'image': OrderedDict(), + 'scatter': OrderedDict()} + + # Initialize filters + for nameFilter in self.DEFAULT_ALL_FILTERS: + self.setFileFilter( + dataKind='all', nameFilter=nameFilter, func=self._saveSnapshot) + + for nameFilter in self.DEFAULT_CURVE_FILTERS: + self.setFileFilter( + dataKind='curve', nameFilter=nameFilter, func=self._saveCurve) + + for nameFilter in self.DEFAULT_ALL_CURVES_FILTERS: + self.setFileFilter( + dataKind='curves', nameFilter=nameFilter, func=self._saveCurves) + + for nameFilter in self.DEFAULT_IMAGE_FILTERS: + self.setFileFilter( + dataKind='image', nameFilter=nameFilter, func=self._saveImage) + + for nameFilter in self.DEFAULT_SCATTER_FILTERS: + self.setFileFilter( + dataKind='scatter', nameFilter=nameFilter, func=self._saveScatter) + super(SaveAction, self).__init__( plot, icon='document-save', text='Save as...', tooltip='Save curve/image/plot snapshot dialog', @@ -148,7 +196,7 @@ class SaveAction(PlotAction): msg.setDetailedText(traceback.format_exc()) msg.exec_() - def _saveSnapshot(self, filename, nameFilter): + def _saveSnapshot(self, plot, filename, nameFilter): """Save a snapshot of the :class:`PlotWindow` widget. :param str filename: The name of the file to write @@ -165,10 +213,51 @@ class SaveAction(PlotAction): 'Saving plot snapshot failed: format not supported') return False - self.plot.saveGraph(filename, fileFormat=fileFormat) + plot.saveGraph(filename, fileFormat=fileFormat) return True - def _saveCurve(self, filename, nameFilter): + def _getAxesLabels(self, item): + # If curve has no associated label, get the default from the plot + xlabel = item.getXLabel() or self.plot.getXAxis().getLabel() + ylabel = item.getYLabel() or self.plot.getYAxis().getLabel() + return xlabel, ylabel + + def _selectWriteableOutputGroup(self, filename): + if os.path.exists(filename) and os.path.isfile(filename) \ + and os.access(filename, os.W_OK): + entryPath = selectOutputGroup(filename) + if entryPath is None: + _logger.info("Save operation cancelled") + return None + return entryPath + elif not os.path.exists(filename): + # create new entry in new file + return "/entry" + else: + self._errorMessage('Save failed (file access issue)\n') + return None + + def _saveCurveAsNXdata(self, curve, filename): + entryPath = self._selectWriteableOutputGroup(filename) + if entryPath is None: + return False + + xlabel, ylabel = self._getAxesLabels(curve) + + return save_NXdata( + filename, + nxentry_name=entryPath, + signal=curve.getYData(copy=False), + axes=[curve.getXData(copy=False)], + signal_name="y", + axes_names=["x"], + signal_long_name=ylabel, + axes_long_names=[xlabel], + signal_errors=curve.getYErrorData(copy=False), + axes_errors=[curve.getXErrorData(copy=True)], + title=self.plot.getGraphTitle()) + + def _saveCurve(self, plot, filename, nameFilter): """Save a curve from the plot. :param str filename: The name of the file to write @@ -176,15 +265,15 @@ class SaveAction(PlotAction): :return: False if format is not supported or save failed, True otherwise. """ - if nameFilter not in self.CURVE_FILTERS: + if nameFilter not in self.DEFAULT_CURVE_FILTERS: return False # Check if a curve is to be saved - curve = self.plot.getActiveCurve() + curve = plot.getActiveCurve() # before calling _saveCurve, if there is no selected curve, we # make sure there is only one curve on the graph if curve is None: - curves = self.plot.getAllCurves() + curves = plot.getAllCurves() if not curves: self._errorMessage("No curve to be saved") return False @@ -199,26 +288,10 @@ class SaveAction(PlotAction): # .npy or nxdata fmt, csvdelim, autoheader = ("", "", False) - # If curve has no associated label, get the default from the plot - xlabel = curve.getXLabel() - if xlabel is None: - xlabel = self.plot.getXAxis().getLabel() - ylabel = curve.getYLabel() - if ylabel is None: - ylabel = self.plot.getYAxis().getLabel() + xlabel, ylabel = self._getAxesLabels(curve) if nameFilter == self.CURVE_FILTER_NXDATA: - return save_NXdata( - filename, - signal=curve.getYData(copy=False), - axes=[curve.getXData(copy=False)], - signal_name="y", - axes_names=["x"], - signal_long_name=ylabel, - axes_long_names=[xlabel], - signal_errors=curve.getYErrorData(copy=False), - axes_errors=[curve.getXErrorData(copy=True)], - title=self.plot.getGraphTitle()) + return self._saveCurveAsNXdata(curve, filename) try: save1D(filename, @@ -233,7 +306,7 @@ class SaveAction(PlotAction): return True - def _saveCurves(self, filename, nameFilter): + def _saveCurves(self, plot, filename, nameFilter): """Save all curves from the plot. :param str filename: The name of the file to write @@ -241,10 +314,10 @@ class SaveAction(PlotAction): :return: False if format is not supported or save failed, True otherwise. """ - if nameFilter not in self.ALL_CURVES_FILTERS: + if nameFilter not in self.DEFAULT_ALL_CURVES_FILTERS: return False - curves = self.plot.getAllCurves() + curves = plot.getAllCurves() if not curves: self._errorMessage("No curves to be saved") return False @@ -252,8 +325,8 @@ class SaveAction(PlotAction): curve = curves[0] scanno = 1 try: - xlabel = curve.getXLabel() or self.plot.getGraphXLabel() - ylabel = curve.getYLabel() or self.plot.getGraphYLabel(curve.getYAxis()) + xlabel = curve.getXLabel() or plot.getGraphXLabel() + ylabel = curve.getYLabel() or plot.getGraphYLabel(curve.getYAxis()) specfile = savespec(filename, curve.getXData(copy=False), curve.getYData(copy=False), @@ -269,8 +342,8 @@ class SaveAction(PlotAction): for curve in curves[1:]: try: scanno += 1 - xlabel = curve.getXLabel() or self.plot.getGraphXLabel() - ylabel = curve.getYLabel() or self.plot.getGraphYLabel(curve.getYAxis()) + xlabel = curve.getXLabel() or plot.getGraphXLabel() + ylabel = curve.getYLabel() or plot.getGraphYLabel(curve.getYAxis()) specfile = savespec(specfile, curve.getXData(copy=False), curve.getYData(copy=False), @@ -286,7 +359,7 @@ class SaveAction(PlotAction): return True - def _saveImage(self, filename, nameFilter): + def _saveImage(self, plot, filename, nameFilter): """Save an image from the plot. :param str filename: The name of the file to write @@ -294,13 +367,13 @@ class SaveAction(PlotAction): :return: False if format is not supported or save failed, True otherwise. """ - if nameFilter not in self.IMAGE_FILTERS: + if nameFilter not in self.DEFAULT_IMAGE_FILTERS: return False - image = self.plot.getActiveImage() + image = plot.getActiveImage() if image is None: qt.QMessageBox.warning( - self.plot, "No Data", "No image to be saved") + plot, "No Data", "No image to be saved") return False data = image.getData(copy=False) @@ -325,21 +398,24 @@ class SaveAction(PlotAction): return True elif nameFilter == self.IMAGE_FILTER_NXDATA: + entryPath = self._selectWriteableOutputGroup(filename) + if entryPath is None: + return False xorigin, yorigin = image.getOrigin() xscale, yscale = image.getScale() xaxis = xorigin + xscale * numpy.arange(data.shape[1]) yaxis = yorigin + yscale * numpy.arange(data.shape[0]) - xlabel = image.getXLabel() or self.plot.getGraphXLabel() - ylabel = image.getYLabel() or self.plot.getGraphYLabel() + xlabel, ylabel = self._getAxesLabels(image) interpretation = "image" if len(data.shape) == 2 else "rgba-image" return save_NXdata(filename, + nxentry_name=entryPath, signal=data, axes=[yaxis, xaxis], signal_name="image", axes_names=["y", "x"], axes_long_names=[ylabel, xlabel], - title=self.plot.getGraphTitle(), + title=plot.getGraphTitle(), interpretation=interpretation) elif nameFilter in (self.IMAGE_FILTER_ASCII, @@ -368,19 +444,13 @@ class SaveAction(PlotAction): return False return True - elif nameFilter in (self.IMAGE_FILTER_RGB_PNG, - self.IMAGE_FILTER_RGB_TIFF): + elif nameFilter == self.IMAGE_FILTER_RGB_PNG: # Get displayed image rgbaImage = image.getRgbaImageData(copy=False) # Convert RGB QImage qimage = convertArrayToQImage(rgbaImage[:, :, :3]) - if nameFilter == self.IMAGE_FILTER_RGB_PNG: - fileFormat = 'PNG' - else: - fileFormat = 'TIFF' - - if qimage.save(filename, fileFormat): + if qimage.save(filename, 'PNG'): return True else: _logger.error('Failed to save image as %s', filename) @@ -391,7 +461,7 @@ class SaveAction(PlotAction): return False - def _saveScatter(self, filename, nameFilter): + def _saveScatter(self, plot, filename, nameFilter): """Save an image from the plot. :param str filename: The name of the file to write @@ -399,12 +469,15 @@ class SaveAction(PlotAction): :return: False if format is not supported or save failed, True otherwise. """ - if nameFilter not in self.SCATTER_FILTERS: + if nameFilter not in self.DEFAULT_SCATTER_FILTERS: return False if nameFilter == self.SCATTER_FILTER_NXDATA: - scatter = self.plot.getScatter() - # TODO: we could get all scatters on this plot and concatenate their (x, y, values) + entryPath = self._selectWriteableOutputGroup(filename) + if entryPath is None: + return False + scatter = plot.getScatter() + x = scatter.getXData(copy=False) y = scatter.getYData(copy=False) z = scatter.getValueData(copy=False) @@ -417,51 +490,92 @@ class SaveAction(PlotAction): if isinstance(yerror, float): yerror = yerror * numpy.ones(x.shape, dtype=numpy.float32) - xlabel = self.plot.getGraphXLabel() - ylabel = self.plot.getGraphYLabel() + xlabel = plot.getGraphXLabel() + ylabel = plot.getGraphYLabel() return save_NXdata( filename, + nxentry_name=entryPath, signal=z, axes=[x, y], signal_name="values", axes_names=["x", "y"], axes_long_names=[xlabel, ylabel], axes_errors=[xerror, yerror], - title=self.plot.getGraphTitle()) + title=plot.getGraphTitle()) + + def setFileFilter(self, dataKind, nameFilter, func): + """Set a name filter to add/replace a file format support + + :param str dataKind: + The kind of data for which the provided filter is valid. + One of: 'all', 'curve', 'curves', 'image', 'scatter' + :param str nameFilter: The name filter in the QFileDialog. + See :meth:`QFileDialog.setNameFilters`. + :param callable func: The function to call to perform saving. + Expected signature is: + bool func(PlotWidget plot, str filename, str nameFilter) + """ + assert dataKind in ('all', 'curve', 'curves', 'image', 'scatter') + + self._filters[dataKind][nameFilter] = func + + def getFileFilters(self, dataKind): + """Returns the nameFilter and associated function for a kind of data. + + :param str dataKind: + The kind of data for which the provided filter is valid. + On of: 'all', 'curve', 'curves', 'image', 'scatter' + :return: {nameFilter: function} associations. + :rtype: collections.OrderedDict + """ + assert dataKind in ('all', 'curve', 'curves', 'image', 'scatter') + + return self._filters[dataKind].copy() def _actionTriggered(self, checked=False): """Handle save action.""" # Set-up filters - filters = [] + filters = OrderedDict() # Add image filters if there is an active image if self.plot.getActiveImage() is not None: - filters.extend(self.IMAGE_FILTERS) + filters.update(self._filters['image'].items()) # Add curve filters if there is a curve to save if (self.plot.getActiveCurve() is not None or len(self.plot.getAllCurves()) == 1): - filters.extend(self.CURVE_FILTERS) + filters.update(self._filters['curve'].items()) if len(self.plot.getAllCurves()) > 1: - filters.extend(self.ALL_CURVES_FILTERS) + filters.update(self._filters['curves'].items()) # Add scatter filters if there is a scatter # todo: CSV if self.plot.getScatter() is not None: - filters.extend(self.SCATTER_FILTERS) + filters.update(self._filters['scatter'].items()) - filters.extend(self.SNAPSHOT_FILTERS) + filters.update(self._filters['all'].items()) # Create and run File dialog dialog = qt.QFileDialog(self.plot) + dialog.setOption(dialog.DontUseNativeDialog) dialog.setWindowTitle("Output File Selection") dialog.setModal(1) - dialog.setNameFilters(filters) + dialog.setNameFilters(list(filters.keys())) dialog.setFileMode(dialog.AnyFile) dialog.setAcceptMode(dialog.AcceptSave) + def onFilterSelection(filt_): + # disable overwrite confirmation for NXdata types, + # because we append the data to existing files + if filt_ in self.DEFAULT_APPEND_FILTERS: + dialog.setOption(dialog.DontConfirmOverwrite) + else: + dialog.setOption(dialog.DontConfirmOverwrite, False) + + dialog.filterSelected.connect(onFilterSelection) + if not dialog.exec_(): return False @@ -469,34 +583,25 @@ class SaveAction(PlotAction): filename = dialog.selectedFiles()[0] dialog.close() - # Forces the filename extension to match the chosen filter - if "NXdata" in nameFilter: - has_allowed_ext = False - for ext in _NEXUS_HDF5_EXT: + if '(' in nameFilter and ')' == nameFilter.strip()[-1]: + # Check for correct file extension + # Extract file extensions as .something + extensions = [ext[ext.find('.'):] for ext in + nameFilter[nameFilter.find('(')+1:-1].split()] + for ext in extensions: if (len(filename) > len(ext) and filename[-len(ext):].lower() == ext.lower()): - has_allowed_ext = True - if not has_allowed_ext: - filename += ".h5" - else: - default_extension = nameFilter.split()[-1][2:-1] - if (len(filename) <= len(default_extension) or - filename[-len(default_extension):].lower() != default_extension.lower()): - filename += default_extension + break + else: # filename has no extension supported in nameFilter, add one + if len(extensions) >= 1: + filename += extensions[0] # Handle save - if nameFilter in self.SNAPSHOT_FILTERS: - return self._saveSnapshot(filename, nameFilter) - elif nameFilter in self.CURVE_FILTERS: - return self._saveCurve(filename, nameFilter) - elif nameFilter in self.ALL_CURVES_FILTERS: - return self._saveCurves(filename, nameFilter) - elif nameFilter in self.IMAGE_FILTERS: - return self._saveImage(filename, nameFilter) - elif nameFilter in self.SCATTER_FILTERS: - return self._saveScatter(filename, nameFilter) + func = filters.get(nameFilter, None) + if func is not None: + return func(self.plot, filename, nameFilter) else: - _logger.warning('Unsupported file filter: %s', nameFilter) + _logger.error('Unsupported file filter: %s', nameFilter) return False @@ -526,9 +631,6 @@ class PrintAction(PlotAction): :param parent: See :class:`QAction`. """ - # Share QPrinter instance to propose latest used as default - _printer = None - def __init__(self, plot, parent=None): super(PrintAction, self).__init__( plot, icon='document-print', text='Print...', @@ -538,15 +640,17 @@ class PrintAction(PlotAction): self.setShortcut(qt.QKeySequence.Print) self.setShortcutContext(qt.Qt.WidgetShortcut) - @property - def printer(self): - """The QPrinter instance used by the actions. + def getPrinter(self): + """The QPrinter instance used by the PrintAction. - This is shared accross all instances of PrintAct + :rtype: QPrinter """ - if self._printer is None: - PrintAction._printer = qt.QPrinter() - return self._printer + return printer.getDefaultPrinter() + + @property + @deprecated(replacement="getPrinter()", since_version="0.8.0") + def printer(self): + return self.getPrinter() def printPlotAsWidget(self): """Open the print dialog and print the plot. @@ -555,7 +659,7 @@ class PrintAction(PlotAction): :return: True if successful """ - dialog = qt.QPrintDialog(self.printer, self.plot) + dialog = qt.QPrintDialog(self.getPrinter(), self.plot) dialog.setWindowTitle('Print Plot') if not dialog.exec_(): return False @@ -564,10 +668,10 @@ class PrintAction(PlotAction): widget = self.plot.centralWidget() painter = qt.QPainter() - if not painter.begin(self.printer): + if not painter.begin(self.getPrinter()): return False - pageRect = self.printer.pageRect() + pageRect = self.getPrinter().pageRect() xScale = pageRect.width() / widget.width() yScale = pageRect.height() / widget.height() scale = min(xScale, yScale) @@ -588,7 +692,7 @@ class PrintAction(PlotAction): :return: True if successful """ # Init printer and start printer dialog - dialog = qt.QPrintDialog(self.printer, self.plot) + dialog = qt.QPrintDialog(self.getPrinter(), self.plot) dialog.setWindowTitle('Print Plot') if not dialog.exec_(): return False @@ -599,13 +703,13 @@ class PrintAction(PlotAction): pixmap = qt.QPixmap() pixmap.loadFromData(pngData, 'png') - xScale = self.printer.pageRect().width() / pixmap.width() - yScale = self.printer.pageRect().height() / pixmap.height() + xScale = self.getPrinter().pageRect().width() / pixmap.width() + yScale = self.getPrinter().pageRect().height() / pixmap.height() scale = min(xScale, yScale) # Draw pixmap with painter painter = qt.QPainter() - if not painter.begin(self.printer): + if not painter.begin(self.getPrinter()): return False painter.drawPixmap(0, 0, diff --git a/silx/gui/plot/actions/mode.py b/silx/gui/plot/actions/mode.py index 026a94d..ee05256 100644 --- a/silx/gui/plot/actions/mode.py +++ b/silx/gui/plot/actions/mode.py @@ -1,7 +1,7 @@ # coding: utf-8 # /*########################################################################## # -# Copyright (c) 2004-2017 European Synchrotron Radiation Facility +# Copyright (c) 2004-2018 European Synchrotron Radiation Facility # # Permission is hereby granted, free of charge, to any person obtaining a copy # of this software and associated documentation files (the "Software"), to deal @@ -69,7 +69,9 @@ class ZoomModeAction(PlotAction): self.blockSignals(old) def _actionTriggered(self, checked=False): - self.plot.setInteractiveMode('zoom', source=self) + plot = self.plot + if plot is not None: + plot.setInteractiveMode('zoom', source=self) class PanModeAction(PlotAction): @@ -97,4 +99,6 @@ class PanModeAction(PlotAction): self.blockSignals(old) def _actionTriggered(self, checked=False): - self.plot.setInteractiveMode('pan', source=self) + plot = self.plot + if plot is not None: + plot.setInteractiveMode('pan', source=self) diff --git a/silx/gui/plot/backends/BackendBase.py b/silx/gui/plot/backends/BackendBase.py index 45bf785..8352ea0 100644 --- a/silx/gui/plot/backends/BackendBase.py +++ b/silx/gui/plot/backends/BackendBase.py @@ -31,8 +31,7 @@ This API is a simplified version of PyMca PlotBackend API. __authors__ = ["V.A. Sole", "T. Vincent"] __license__ = "MIT" -__date__ = "16/08/2017" - +__date__ = "24/04/2018" import weakref from ... import qt @@ -59,6 +58,7 @@ class BackendBase(object): self.__yLimits = {'left': (1., 100.), 'right': (1., 100.)} self.__yAxisInverted = False self.__keepDataAspectRatio = False + self._xAxisTimeZone = None self._axesDisplayed = True # Store a weakref to get access to the plot state. self._setPlot(plot) @@ -109,7 +109,7 @@ class BackendBase(object): :param str legend: The legend to be associated to the curve :param color: color(s) to be used :type color: string ("#RRGGBB") or (npoints, 4) unsigned byte array or - one of the predefined color names defined in Colors.py + one of the predefined color names defined in colors.py :param str symbol: Symbol to be drawn at each (x, y) position:: - ' ' or '' no symbol @@ -252,7 +252,7 @@ class BackendBase(object): :param bool flag: Toggle the display of a crosshair cursor. :param color: The color to use for the crosshair. - :type color: A string (either a predefined color name in Colors.py + :type color: A string (either a predefined color name in colors.py or "#RRGGBB")) or a 4 columns unsigned byte array. :param int linewidth: The width of the lines of the crosshair. :param linestyle: Type of line:: @@ -406,6 +406,39 @@ class BackendBase(object): # Graph axes + + def getXAxisTimeZone(self): + """Returns tzinfo that is used if the X-Axis plots date-times. + + None means the datetimes are interpreted as local time. + + :rtype: datetime.tzinfo of None. + """ + return self._xAxisTimeZone + + def setXAxisTimeZone(self, tz): + """Sets tzinfo that is used if the X-Axis plots date-times. + + Use None to let the datetimes be interpreted as local time. + + :rtype: datetime.tzinfo of None. + """ + self._xAxisTimeZone = tz + + def isXAxisTimeSeries(self): + """Return True if the X-axis scale shows datetime objects. + + :rtype: bool + """ + raise NotImplementedError() + + def setXAxisTimeSeries(self, isTimeSeries): + """Set whether the X-axis is a time series + + :param bool flag: True to switch to time series, False for regular axis. + """ + raise NotImplementedError() + def setXAxisLogarithmic(self, flag): """Set the X axis scale between linear and log. @@ -503,4 +536,4 @@ class BackendBase(object): are displayed and not the other. This only check status set to axes from the public API """ - return self._axesDisplayed
\ No newline at end of file + return self._axesDisplayed diff --git a/silx/gui/plot/backends/BackendMatplotlib.py b/silx/gui/plot/backends/BackendMatplotlib.py index f9a1fe5..49c4540 100644 --- a/silx/gui/plot/backends/BackendMatplotlib.py +++ b/silx/gui/plot/backends/BackendMatplotlib.py @@ -32,9 +32,11 @@ __date__ = "18/10/2017" import logging - +import datetime as dt import numpy +from pkg_resources import parse_version as _parse_version + _logger = logging.getLogger(__name__) @@ -42,7 +44,6 @@ _logger = logging.getLogger(__name__) from ... import qt # First of all init matplotlib and set its backend -from ..matplotlib import Colormap as MPLColormap from ..matplotlib import FigureCanvasQTAgg import matplotlib from matplotlib.container import Container @@ -52,10 +53,103 @@ from matplotlib.image import AxesImage from matplotlib.backend_bases import MouseEvent from matplotlib.lines import Line2D from matplotlib.collections import PathCollection, LineCollection +from matplotlib.ticker import Formatter, ScalarFormatter, Locator + + from ..matplotlib.ModestImage import ModestImage from . import BackendBase from .._utils import FLOAT32_MINPOS +from .._utils.dtime_ticklayout import calcTicks, bestFormatString, timestamp + + + +class NiceDateLocator(Locator): + """ + Matplotlib Locator that uses Nice Numbers algorithm (adapted to dates) + to find the tick locations. This results in the same number behaviour + as when using the silx Open GL backend. + + Expects the data to be posix timestampes (i.e. seconds since 1970) + """ + def __init__(self, numTicks=5, tz=None): + """ + :param numTicks: target number of ticks + :param datetime.tzinfo tz: optional time zone. None is local time. + """ + super(NiceDateLocator, self).__init__() + self.numTicks = numTicks + + self._spacing = None + self._unit = None + self.tz = tz + + @property + def spacing(self): + """ The current spacing. Will be updated when new tick value are made""" + return self._spacing + + @property + def unit(self): + """ The current DtUnit. Will be updated when new tick value are made""" + return self._unit + + def __call__(self): + """Return the locations of the ticks""" + vmin, vmax = self.axis.get_view_interval() + return self.tick_values(vmin, vmax) + + def tick_values(self, vmin, vmax): + """ Calculates tick values + """ + if vmax < vmin: + vmin, vmax = vmax, vmin + + # vmin and vmax should be timestamps (i.e. seconds since 1 Jan 1970) + dtMin = dt.datetime.fromtimestamp(vmin, tz=self.tz) + dtMax = dt.datetime.fromtimestamp(vmax, tz=self.tz) + dtTicks, self._spacing, self._unit = \ + calcTicks(dtMin, dtMax, self.numTicks) + + # Convert datetime back to time stamps. + ticks = [timestamp(dtTick) for dtTick in dtTicks] + return ticks + + + +class NiceAutoDateFormatter(Formatter): + """ + Matplotlib FuncFormatter that is linked to a NiceDateLocator and gives the + best possible formats given the locators current spacing an date unit. + """ + + def __init__(self, locator, tz=None): + """ + :param niceDateLocator: a NiceDateLocator object + :param datetime.tzinfo tz: optional time zone. None is local time. + """ + super(NiceAutoDateFormatter, self).__init__() + self.locator = locator + self.tz = tz + + @property + def formatString(self): + if self.locator.spacing is None or self.locator.unit is None: + # Locator has no spacing or units yet. Return elaborate fmtString + return "Y-%m-%d %H:%M:%S" + else: + return bestFormatString(self.locator.spacing, self.locator.unit) + + + def __call__(self, x, pos=None): + """Return the format for tick val *x* at position *pos* + Expects x to be a POSIX timestamp (seconds since 1 Jan 1970) + """ + dateTime = dt.datetime.fromtimestamp(x, tz=self.tz) + tickStr = dateTime.strftime(self.formatString) + return tickStr + + class _MarkerContainer(Container): @@ -130,6 +224,7 @@ class BackendMatplotlib(BackendBase.BackendBase): # when getting the limits at the expense of a replot self._dirtyLimits = True self._axesDisplayed = True + self._matplotlibVersion = _parse_version(matplotlib.__version__) self.fig = Figure() self.fig.set_facecolor("w") @@ -153,7 +248,7 @@ class BackendMatplotlib(BackendBase.BackendBase): self.ax2.set_autoscaley_on(True) self.ax.set_zorder(1) # this works but the figure color is left - if matplotlib.__version__[0] < '2': + if self._matplotlibVersion < _parse_version('2'): self.ax.set_axis_bgcolor('none') else: self.ax.set_facecolor('none') @@ -165,9 +260,9 @@ class BackendMatplotlib(BackendBase.BackendBase): self._colormaps = {} self._graphCursor = tuple() - self.matplotlibVersion = matplotlib.__version__ self._enableAxis('right', False) + self._isXAxisTimeSeries = False # Add methods @@ -235,7 +330,7 @@ class BackendMatplotlib(BackendBase.BackendBase): color=actualColor, marker=symbol, picker=picker, - s=symbolsize) + s=symbolsize**2) artists.append(scatter) if fill: @@ -286,7 +381,7 @@ class BackendMatplotlib(BackendBase.BackendBase): # No transparent colormap with matplotlib < 1.2.0 # Add support for transparent colormap for uint8 data with # colormap with 256 colors, linear norm, [0, 255] range - if matplotlib.__version__ < '1.2.0': + if self._matplotlibVersion < _parse_version('1.2.0'): if (len(data.shape) == 2 and colormap.getName() is None and colormap.getColormapLUT() is not None): colors = colormap.getColormapLUT() @@ -313,29 +408,14 @@ class BackendMatplotlib(BackendBase.BackendBase): else: imageClass = AxesImage - # the normalization can be a source of time waste - # Two possibilities, we receive data or a ready to show image - if len(data.shape) == 3: # RGBA image - image = imageClass(self.ax, - label="__IMAGE__" + legend, - interpolation='nearest', - picker=picker, - zorder=z, - origin='lower') + # All image are shown as RGBA image + image = imageClass(self.ax, + label="__IMAGE__" + legend, + interpolation='nearest', + picker=picker, + zorder=z, + origin='lower') - else: - # Convert colormap argument to matplotlib colormap - scalarMappable = MPLColormap.getScalarMappable(colormap, data) - - # try as data - image = imageClass(self.ax, - label="__IMAGE__" + legend, - interpolation='nearest', - cmap=scalarMappable.cmap, - picker=picker, - zorder=z, - norm=scalarMappable.norm, - origin='lower') if alpha < 1: image.set_alpha(alpha) @@ -359,14 +439,17 @@ class BackendMatplotlib(BackendBase.BackendBase): ystep = 1 if scale[1] >= 0. else -1 data = data[::ystep, ::xstep] - if matplotlib.__version__ < "2.1": + if self._matplotlibVersion < _parse_version('2.1'): # matplotlib 1.4.2 do not support float128 dtype = data.dtype if dtype.kind == "f" and dtype.itemsize >= 16: _logger.warning("Your matplotlib version do not support " - "float128. Data converted to floa64.") + "float128. Data converted to float64.") data = data.astype(numpy.float64) + if data.ndim == 2: # Data image, convert to RGBA image + data = colormap.applyToData(data) + image.set_data(data) self.ax.add_artist(image) @@ -671,11 +754,39 @@ class BackendMatplotlib(BackendBase.BackendBase): # Graph axes + def setXAxisTimeZone(self, tz): + super(BackendMatplotlib, self).setXAxisTimeZone(tz) + + # Make new formatter and locator with the time zone. + self.setXAxisTimeSeries(self.isXAxisTimeSeries()) + + def isXAxisTimeSeries(self): + return self._isXAxisTimeSeries + + def setXAxisTimeSeries(self, isTimeSeries): + self._isXAxisTimeSeries = isTimeSeries + if self._isXAxisTimeSeries: + # We can't use a matplotlib.dates.DateFormatter because it expects + # the data to be in datetimes. Silx works internally with + # timestamps (floats). + locator = NiceDateLocator(tz=self.getXAxisTimeZone()) + self.ax.xaxis.set_major_locator(locator) + self.ax.xaxis.set_major_formatter( + NiceAutoDateFormatter(locator, tz=self.getXAxisTimeZone())) + else: + try: + scalarFormatter = ScalarFormatter(useOffset=False) + except: + _logger.warning('Cannot disabled axes offsets in %s ' % + matplotlib.__version__) + scalarFormatter = ScalarFormatter() + self.ax.xaxis.set_major_formatter(scalarFormatter) + def setXAxisLogarithmic(self, flag): # Workaround for matplotlib 2.1.0 when one tries to set an axis # to log scale with both limits <= 0 # In this case a draw with positive limits is needed first - if flag and matplotlib.__version__ >= '2.1.0': + if flag and self._matplotlibVersion >= _parse_version('2.1.0'): xlim = self.ax.get_xlim() if xlim[0] <= 0 and xlim[1] <= 0: self.ax.set_xlim(1, 10) @@ -685,15 +796,17 @@ class BackendMatplotlib(BackendBase.BackendBase): self.ax.set_xscale('log' if flag else 'linear') def setYAxisLogarithmic(self, flag): - # Workaround for matplotlib 2.1.0 when one tries to set an axis - # to log scale with both limits <= 0 - # In this case a draw with positive limits is needed first - if flag and matplotlib.__version__ >= '2.1.0': + # Workaround for matplotlib 2.0 issue with negative bounds + # before switching to log scale + if flag and self._matplotlibVersion >= _parse_version('2.0.0'): redraw = False - for axis in (self.ax, self.ax2): + for axis, dataRangeIndex in ((self.ax, 1), (self.ax2, 2)): ylim = axis.get_ylim() - if ylim[0] <= 0 and ylim[1] <= 0: - axis.set_ylim(1, 10) + if ylim[0] <= 0 or ylim[1] <= 0: + dataRange = self._plot.getDataRange()[dataRangeIndex] + if dataRange is None: + dataRange = 1, 100 # Fallback + axis.set_ylim(*dataRange) redraw = True if redraw: self.draw() @@ -722,16 +835,31 @@ class BackendMatplotlib(BackendBase.BackendBase): # Data <-> Pixel coordinates conversion + def _mplQtYAxisCoordConversion(self, y): + """Qt origin (top) to/from matplotlib origin (bottom) conversion. + + :rtype: float + """ + height = self.fig.get_window_extent().height + return height - y + def dataToPixel(self, x, y, axis): ax = self.ax2 if axis == "right" else self.ax pixels = ax.transData.transform_point((x, y)) xPixel, yPixel = pixels.T + + # Convert from matplotlib origin (bottom) to Qt origin (top) + yPixel = self._mplQtYAxisCoordConversion(yPixel) + return xPixel, yPixel def pixelToData(self, x, y, axis, check): ax = self.ax2 if axis == "right" else self.ax + # Convert from Qt origin (top) to matplotlib origin (bottom) + y = self._mplQtYAxisCoordConversion(y) + inv = ax.transData.inverted() x, y = inv.transform_point((x, y)) @@ -745,12 +873,12 @@ class BackendMatplotlib(BackendBase.BackendBase): return x, y def getPlotBoundsInPixels(self): - bbox = self.ax.get_window_extent().transformed( - self.fig.dpi_scale_trans.inverted()) - dpi = self.fig.dpi + bbox = self.ax.get_window_extent() # Warning this is not returning int... - return (bbox.bounds[0] * dpi, bbox.bounds[1] * dpi, - bbox.bounds[2] * dpi, bbox.bounds[3] * dpi) + return (bbox.xmin, + self._mplQtYAxisCoordConversion(bbox.ymax), + bbox.width, + bbox.height) def setAxesDisplayed(self, displayed): """Display or not the axes. @@ -822,7 +950,8 @@ class BackendMatplotlibQt(FigureCanvasQTAgg, BackendMatplotlib): def _onMousePress(self, event): self._plot.onMousePress( - event.x, event.y, self._MPL_TO_PLOT_BUTTONS[event.button]) + event.x, self._mplQtYAxisCoordConversion(event.y), + self._MPL_TO_PLOT_BUTTONS[event.button]) def _onMouseMove(self, event): if self._graphCursor: @@ -839,14 +968,17 @@ class BackendMatplotlibQt(FigureCanvasQTAgg, BackendMatplotlib): self._plot._setDirtyPlot(overlayOnly=True) # onMouseMove must trigger replot if dirty flag is raised - self._plot.onMouseMove(event.x, event.y) + self._plot.onMouseMove( + event.x, self._mplQtYAxisCoordConversion(event.y)) def _onMouseRelease(self, event): self._plot.onMouseRelease( - event.x, event.y, self._MPL_TO_PLOT_BUTTONS[event.button]) + event.x, self._mplQtYAxisCoordConversion(event.y), + self._MPL_TO_PLOT_BUTTONS[event.button]) def _onMouseWheel(self, event): - self._plot.onMouseWheel(event.x, event.y, event.step) + self._plot.onMouseWheel( + event.x, self._mplQtYAxisCoordConversion(event.y), event.step) def leaveEvent(self, event): """QWidget event handler""" @@ -880,7 +1012,8 @@ class BackendMatplotlibQt(FigureCanvasQTAgg, BackendMatplotlib): self._picked = [] # Weird way to do an explicit picking: Simulate a button press event - mouseEvent = MouseEvent('button_press_event', self, x, y) + mouseEvent = MouseEvent('button_press_event', + self, x, self._mplQtYAxisCoordConversion(y)) cid = self.mpl_connect('pick_event', self._onPick) self.fig.pick(mouseEvent) self.mpl_disconnect(cid) @@ -924,7 +1057,7 @@ class BackendMatplotlibQt(FigureCanvasQTAgg, BackendMatplotlib): """ # Starting with mpl 2.1.0, toggling autoscale raises a ValueError # in some situations. See #1081, #1136, #1163, - if matplotlib.__version__ >= "2.0.0": + if self._matplotlibVersion >= _parse_version("2.0.0"): try: FigureCanvasQTAgg.draw(self) except ValueError as err: @@ -956,7 +1089,6 @@ class BackendMatplotlibQt(FigureCanvasQTAgg, BackendMatplotlib): if yRightLimits != self.ax2.get_ybound(): self._plot.getYAxis(axis='right')._emitLimitsChanged() - self._drawOverlays() def replot(self): @@ -975,6 +1107,12 @@ class BackendMatplotlibQt(FigureCanvasQTAgg, BackendMatplotlib): elif dirtyFlag: # Need full redraw self.draw() + # Workaround issue of rendering overlays with some matplotlib versions + if (_parse_version('1.5') <= self._matplotlibVersion < _parse_version('2.1') and + not hasattr(self, '_firstReplot')): + self._firstReplot = False + if self._overlays or self._graphCursor: + qt.QTimer.singleShot(0, self.draw) # Request async draw # cursor diff --git a/silx/gui/plot/backends/BackendOpenGL.py b/silx/gui/plot/backends/BackendOpenGL.py index 3c18f4f..0001bb9 100644 --- a/silx/gui/plot/backends/BackendOpenGL.py +++ b/silx/gui/plot/backends/BackendOpenGL.py @@ -28,7 +28,7 @@ from __future__ import division __authors__ = ["T. Vincent"] __license__ = "MIT" -__date__ = "16/08/2017" +__date__ = "24/04/2018" from collections import OrderedDict, namedtuple from ctypes import c_void_p @@ -38,8 +38,7 @@ import numpy from .._utils import FLOAT32_MINPOS from . import BackendBase -from .. import Colors -from ..Colormap import Colormap +from ... import colors from ... import qt from ..._glutils import gl @@ -355,7 +354,6 @@ class BackendOpenGL(BackendBase.BackendBase, glu.OpenGLWidget): self._markers = OrderedDict() self._items = OrderedDict() self._plotContent = PlotDataContent() # For images and curves - self._selectionAreas = OrderedDict() self._glGarbageCollector = [] self._plotFrame = GLPlotFrame2D( @@ -399,7 +397,7 @@ class BackendOpenGL(BackendBase.BackendBase, glu.OpenGLWidget): previousMousePosInPixels = self._mousePosInPixels self._mousePosInPixels = (xPixel, yPixel) if isCursorInPlot else None if (self._crosshairCursor is not None and - previousMousePosInPixels != self._crosshairCursor): + previousMousePosInPixels != self._mousePosInPixels): # Avoid replot when cursor remains outside plot area self._plot._setDirtyPlot(overlayOnly=True) @@ -431,14 +429,6 @@ class BackendOpenGL(BackendBase.BackendBase, glu.OpenGLWidget): # OpenGLWidget API - @staticmethod - def _setBlendFuncGL(): - # gl.glBlendFunc(gl.GL_SRC_ALPHA, gl.GL_ONE_MINUS_SRC_ALPHA) - gl.glBlendFuncSeparate(gl.GL_SRC_ALPHA, - gl.GL_ONE_MINUS_SRC_ALPHA, - gl.GL_ONE, - gl.GL_ONE) - def initializeGL(self): gl.testGL() @@ -446,7 +436,11 @@ class BackendOpenGL(BackendBase.BackendBase, glu.OpenGLWidget): gl.glClearStencil(0) gl.glEnable(gl.GL_BLEND) - self._setBlendFuncGL() + # gl.glBlendFunc(gl.GL_SRC_ALPHA, gl.GL_ONE_MINUS_SRC_ALPHA) + gl.glBlendFuncSeparate(gl.GL_SRC_ALPHA, + gl.GL_ONE_MINUS_SRC_ALPHA, + gl.GL_ONE, + gl.GL_ONE) # For lines gl.glHint(gl.GL_LINE_SMOOTH_HINT, gl.GL_NICEST) @@ -500,7 +494,7 @@ class BackendOpenGL(BackendBase.BackendBase, glu.OpenGLWidget): gl.glUniform1i(self._progTex.uniforms['tex'], texUnit) gl.glUniformMatrix4fv(self._progTex.uniforms['matrix'], 1, gl.GL_TRUE, - mat4Identity()) + mat4Identity().astype(numpy.float32)) stride = self._plotVertices.shape[-1] * self._plotVertices.itemsize gl.glEnableVertexAttribArray(self._progTex.attributes['position']) @@ -649,24 +643,20 @@ class BackendOpenGL(BackendBase.BackendBase, glu.OpenGLWidget): plotWidth, plotHeight = self.getPlotBoundsInPixels()[2:] - isXLog = self._plotFrame.xAxis.isLog - isYLog = self._plotFrame.yAxis.isLog - # Render in plot area gl.glScissor(self._plotFrame.margins.left, self._plotFrame.margins.bottom, plotWidth, plotHeight) gl.glEnable(gl.GL_SCISSOR_TEST) - gl.glViewport(self._plotFrame.margins.left, - self._plotFrame.margins.bottom, - plotWidth, plotHeight) + gl.glViewport(0, 0, self._plotFrame.size[0], self._plotFrame.size[1]) # Prepare vertical and horizontal markers rendering self._progBase.use() - gl.glUniformMatrix4fv(self._progBase.uniforms['matrix'], 1, gl.GL_TRUE, - self._plotFrame.transformedDataProjMat) - gl.glUniform2i(self._progBase.uniforms['isLog'], isXLog, isYLog) + gl.glUniformMatrix4fv( + self._progBase.uniforms['matrix'], 1, gl.GL_TRUE, + self.matScreenProj.astype(numpy.float32)) + gl.glUniform2i(self._progBase.uniforms['isLog'], False, False) gl.glUniform1i(self._progBase.uniforms['hatchStep'], 0) gl.glUniform1f(self._progBase.uniforms['tickLen'], 0.) posAttrib = self._progBase.attributes['position'] @@ -677,10 +667,12 @@ class BackendOpenGL(BackendBase.BackendBase, glu.OpenGLWidget): for marker in self._markers.values(): xCoord, yCoord = marker['x'], marker['y'] - if ((isXLog and xCoord is not None and - xCoord < FLOAT32_MINPOS) or - (isYLog and yCoord is not None and - yCoord < FLOAT32_MINPOS)): + if ((self._plotFrame.xAxis.isLog and + xCoord is not None and + xCoord <= 0) or + (self._plotFrame.yAxis.isLog and + yCoord is not None and + yCoord <= 0)): # Do not render markers with negative coords on log axis continue @@ -706,9 +698,9 @@ class BackendOpenGL(BackendBase.BackendBase, glu.OpenGLWidget): align=RIGHT, valign=BOTTOM) labels.append(label) - xMin, xMax = self._plotFrame.dataRanges.x - vertices = numpy.array(((xMin, yCoord), - (xMax, yCoord)), + width = self._plotFrame.size[0] + vertices = numpy.array(((0, pixelPos[1]), + (width, pixelPos[1])), dtype=numpy.float32) else: # yCoord is None: vertical line in data space @@ -721,13 +713,12 @@ class BackendOpenGL(BackendBase.BackendBase, glu.OpenGLWidget): align=LEFT, valign=TOP) labels.append(label) - yMin, yMax = self._plotFrame.dataRanges.y - vertices = numpy.array(((xCoord, yMin), - (xCoord, yMax)), + height = self._plotFrame.size[1] + vertices = numpy.array(((pixelPos[0], 0), + (pixelPos[0], height)), dtype=numpy.float32) self._progBase.use() - gl.glUniform4f(self._progBase.uniforms['color'], *marker['color']) @@ -759,13 +750,12 @@ class BackendOpenGL(BackendBase.BackendBase, glu.OpenGLWidget): # For now simple implementation: using a curve for each marker # Should pack all markers to a single set of points markerCurve = GLPlotCurve2D( - numpy.array((xCoord,), dtype=numpy.float32), - numpy.array((yCoord,), dtype=numpy.float32), + numpy.array((pixelPos[0],), dtype=numpy.float64), + numpy.array((pixelPos[1],), dtype=numpy.float64), marker=marker['symbol'], markerColor=marker['color'], markerSize=11) - markerCurve.render(self._plotFrame.transformedDataProjMat, - isXLog, isYLog) + markerCurve.render(self.matScreenProj, False, False) markerCurve.discard() gl.glViewport(0, 0, self._plotFrame.size[0], self._plotFrame.size[1]) @@ -777,8 +767,8 @@ class BackendOpenGL(BackendBase.BackendBase, glu.OpenGLWidget): gl.glDisable(gl.GL_SCISSOR_TEST) def _renderOverlayGL(self): - # Render selection area and crosshair cursor - if self._selectionAreas or self._crosshairCursor is not None: + # Render crosshair cursor + if self._crosshairCursor is not None: plotWidth, plotHeight = self.getPlotBoundsInPixels()[2:] # Scissor to plot area @@ -788,41 +778,21 @@ class BackendOpenGL(BackendBase.BackendBase, glu.OpenGLWidget): gl.glEnable(gl.GL_SCISSOR_TEST) self._progBase.use() - gl.glUniform2i(self._progBase.uniforms['isLog'], - self._plotFrame.xAxis.isLog, - self._plotFrame.yAxis.isLog) + gl.glUniform2i(self._progBase.uniforms['isLog'], False, False) gl.glUniform1f(self._progBase.uniforms['tickLen'], 0.) posAttrib = self._progBase.attributes['position'] matrixUnif = self._progBase.uniforms['matrix'] colorUnif = self._progBase.uniforms['color'] hatchStepUnif = self._progBase.uniforms['hatchStep'] - # Render selection area in plot area - if self._selectionAreas: - gl.glViewport(self._plotFrame.margins.left, - self._plotFrame.margins.bottom, - plotWidth, plotHeight) - - gl.glUniformMatrix4fv(matrixUnif, 1, gl.GL_TRUE, - self._plotFrame.transformedDataProjMat) - - for shape in self._selectionAreas.values(): - if shape.isVideoInverted: - gl.glBlendFunc(gl.GL_ONE_MINUS_DST_COLOR, gl.GL_ZERO) - - shape.render(posAttrib, colorUnif, hatchStepUnif) - - if shape.isVideoInverted: - self._setBlendFuncGL() - - # Render crosshair cursor is screen frame but with scissor + # Render crosshair cursor in screen frame but with scissor if (self._crosshairCursor is not None and self._mousePosInPixels is not None): gl.glViewport( 0, 0, self._plotFrame.size[0], self._plotFrame.size[1]) gl.glUniformMatrix4fv(matrixUnif, 1, gl.GL_TRUE, - self.matScreenProj) + self.matScreenProj.astype(numpy.float32)) color, lineWidth = self._crosshairCursor gl.glUniform4f(colorUnif, *color) @@ -881,31 +851,30 @@ class BackendOpenGL(BackendBase.BackendBase, glu.OpenGLWidget): isXLog, isYLog) # Render Items + gl.glViewport(0, 0, self._plotFrame.size[0], self._plotFrame.size[1]) + self._progBase.use() gl.glUniformMatrix4fv(self._progBase.uniforms['matrix'], 1, gl.GL_TRUE, - self._plotFrame.transformedDataProjMat) - gl.glUniform2i(self._progBase.uniforms['isLog'], - self._plotFrame.xAxis.isLog, - self._plotFrame.yAxis.isLog) + self.matScreenProj.astype(numpy.float32)) + gl.glUniform2i(self._progBase.uniforms['isLog'], False, False) gl.glUniform1f(self._progBase.uniforms['tickLen'], 0.) for item in self._items.values(): - shape2D = item.get('_shape2D') - if shape2D is None: - closed = item['shape'] != 'polylines' - shape2D = Shape2D(tuple(zip(item['x'], item['y'])), - fill=item['fill'], - fillColor=item['color'], - stroke=True, - strokeColor=item['color'], - strokeClosed=closed) - item['_shape2D'] = shape2D - - if ((isXLog and shape2D.xMin < FLOAT32_MINPOS) or - (isYLog and shape2D.yMin < FLOAT32_MINPOS)): + if ((isXLog and numpy.min(item['x']) < FLOAT32_MINPOS) or + (isYLog and numpy.min(item['y']) < FLOAT32_MINPOS)): # Ignore items <= 0. on log axes continue + closed = item['shape'] != 'polylines' + points = [self.dataToPixel(x, y, axis='left', check=False) + for (x, y) in zip(item['x'], item['y'])] + shape2D = Shape2D(points, + fill=item['fill'], + fillColor=item['color'], + stroke=True, + strokeColor=item['color'], + strokeClosed=closed) + posAttrib = self._progBase.attributes['position'] colorUnif = self._progBase.uniforms['color'] hatchStepUnif = self._progBase.uniforms['hatchStep'] @@ -944,6 +913,21 @@ class BackendOpenGL(BackendBase.BackendBase, glu.OpenGLWidget): # Add methods + @staticmethod + def _castArrayTo(v): + """Returns best floating type to cast the array to. + + :param numpy.ndarray v: Array to cast + :rtype: numpy.dtype + :raise ValueError: If dtype is not supported + """ + if numpy.issubdtype(v.dtype, numpy.floating): + return numpy.float32 if v.itemsize <= 4 else numpy.float64 + elif numpy.issubdtype(v.dtype, numpy.integer): + return numpy.float32 if v.itemsize <= 2 else numpy.float64 + else: + raise ValueError('Unsupported data type') + def addCurve(self, x, y, legend, color, symbol, linewidth, linestyle, yaxis, @@ -954,8 +938,21 @@ class BackendOpenGL(BackendBase.BackendBase, glu.OpenGLWidget): assert parameter is not None assert yaxis in ('left', 'right') - x = numpy.array(x, dtype=numpy.float32, copy=False, order='C') - y = numpy.array(y, dtype=numpy.float32, copy=False, order='C') + # Convert input data + x = numpy.array(x, copy=False) + y = numpy.array(y, copy=False) + + # Check if float32 is enough + if (self._castArrayTo(x) is numpy.float32 and + self._castArrayTo(y) is numpy.float32): + dtype = numpy.float32 + else: + dtype = numpy.float64 + + x = numpy.array(x, dtype=dtype, copy=False, order='C') + y = numpy.array(y, dtype=dtype, copy=False, order='C') + + # Convert errors to float32 if xerror is not None: xerror = numpy.array( xerror, dtype=numpy.float32, copy=False, order='C') @@ -963,6 +960,47 @@ class BackendOpenGL(BackendBase.BackendBase, glu.OpenGLWidget): yerror = numpy.array( yerror, dtype=numpy.float32, copy=False, order='C') + # Handle axes log scale: convert data + + if self._plotFrame.xAxis.isLog: + logX = numpy.log10(x) + + if xerror is not None: + # Transform xerror so that + # log10(x) +/- xerror' = log10(x +/- xerror) + if hasattr(xerror, 'shape') and len(xerror.shape) == 2: + xErrorMinus, xErrorPlus = xerror[0], xerror[1] + else: + xErrorMinus, xErrorPlus = xerror, xerror + xErrorMinus = logX - numpy.log10(x - xErrorMinus) + xErrorPlus = numpy.log10(x + xErrorPlus) - logX + xerror = numpy.array((xErrorMinus, xErrorPlus), + dtype=numpy.float32) + + x = logX + + isYLog = (yaxis == 'left' and self._plotFrame.yAxis.isLog) or ( + yaxis == 'right' and self._plotFrame.y2Axis.isLog) + + if isYLog: + logY = numpy.log10(y) + + if yerror is not None: + # Transform yerror so that + # log10(y) +/- yerror' = log10(y +/- yerror) + if hasattr(yerror, 'shape') and len(yerror.shape) == 2: + yErrorMinus, yErrorPlus = yerror[0], yerror[1] + else: + yErrorMinus, yErrorPlus = yerror, yerror + yErrorMinus = logY - numpy.log10(y - yErrorMinus) + yErrorPlus = numpy.log10(y + yErrorPlus) - logY + yerror = numpy.array((yErrorMinus, yErrorPlus), + dtype=numpy.float32) + + y = logY + + # TODO check if need more filtering of error (e.g., clip to positive) + # TODO check and improve this if (len(color) == 4 and type(color[3]) in [type(1), numpy.uint8, numpy.int8]): @@ -973,7 +1011,7 @@ class BackendOpenGL(BackendBase.BackendBase, glu.OpenGLWidget): color = None else: colorArray = None - color = Colors.rgba(color) + color = colors.rgba(color) if alpha < 1.: # Apply image transparency if colorArray is not None and colorArray.shape[1] == 4: @@ -995,7 +1033,8 @@ class BackendOpenGL(BackendBase.BackendBase, glu.OpenGLWidget): marker=symbol, markerColor=color, markerSize=symbolsize, - fillColor=color if fill else None) + fillColor=color if fill else None, + isYLog=isYLog) curve.info = { 'legend': legend, 'zOrder': z, @@ -1054,7 +1093,13 @@ class BackendOpenGL(BackendBase.BackendBase, glu.OpenGLWidget): elif len(data.shape) == 3: # For RGB, RGBA data assert data.shape[2] in (3, 4) - assert data.dtype in (numpy.float32, numpy.uint8) + + if numpy.issubdtype(data.dtype, numpy.floating): + data = numpy.array(data, dtype=numpy.float32, copy=False) + elif numpy.issubdtype(data.dtype, numpy.integer): + data = numpy.array(data, dtype=numpy.uint8, copy=False) + else: + raise ValueError('Unsupported data type') image = GLPlotRGBAImage(data, origin, scale, alpha) @@ -1106,7 +1151,7 @@ class BackendOpenGL(BackendBase.BackendBase, glu.OpenGLWidget): self._items[legend] = { 'shape': shape, - 'color': Colors.rgba(color), + 'color': colors.rgba(color), 'fill': 'hatch' if fill else None, 'x': x, 'y': y @@ -1133,19 +1178,12 @@ class BackendOpenGL(BackendBase.BackendBase, glu.OpenGLWidget): if isConstraint: x, y = constraint(x, y) - if x is not None and self._plotFrame.xAxis.isLog and x <= 0.: - raise RuntimeError( - 'Cannot add marker with X <= 0 with X axis log scale') - if y is not None and self._plotFrame.yAxis.isLog and y <= 0.: - raise RuntimeError( - 'Cannot add marker with Y <= 0 with Y axis log scale') - self._markers[legend] = { 'x': x, 'y': y, 'legend': legend, 'text': text, - 'color': Colors.rgba(color), + 'color': colors.rgba(color), 'behaviors': behaviors, 'constraint': constraint if isConstraint else None, 'symbol': symbol, @@ -1204,7 +1242,7 @@ class BackendOpenGL(BackendBase.BackendBase, glu.OpenGLWidget): "BackendOpenGL.setGraphCursor linestyle parameter ignored") if flag: - color = Colors.rgba(color) + color = colors.rgba(color) crosshairCursor = color, linewidth else: crosshairCursor = None @@ -1304,6 +1342,16 @@ class BackendOpenGL(BackendBase.BackendBase, glu.OpenGLWidget): else: yPickMin, yPickMax = yPick1, yPick0 + # Apply log scale if axis is log + if self._plotFrame.xAxis.isLog: + xPickMin = numpy.log10(xPickMin) + xPickMax = numpy.log10(xPickMax) + + if (yAxis == 'left' and self._plotFrame.yAxis.isLog) or ( + yAxis == 'right' and self._plotFrame.y2Axis.isLog): + yPickMin = numpy.log10(yPickMin) + yPickMax = numpy.log10(yPickMax) + pickedIndices = item.pick(xPickMin, yPickMin, xPickMax, yPickMax) if pickedIndices: @@ -1548,6 +1596,18 @@ class BackendOpenGL(BackendBase.BackendBase, glu.OpenGLWidget): # Graph axes + def getXAxisTimeZone(self): + return self._plotFrame.xAxis.timeZone + + def setXAxisTimeZone(self, tz): + self._plotFrame.xAxis.timeZone = tz + + def isXAxisTimeSeries(self): + return self._plotFrame.xAxis.isTimeSeries + + def setXAxisTimeSeries(self, isTimeSeries): + self._plotFrame.xAxis.isTimeSeries = isTimeSeries + def setXAxisLogarithmic(self, flag): if flag != self._plotFrame.xAxis.isLog: if flag and self._keepDataAspectRatio: @@ -1657,4 +1717,4 @@ class BackendOpenGL(BackendBase.BackendBase, glu.OpenGLWidget): def setAxesDisplayed(self, displayed): BackendBase.BackendBase.setAxesDisplayed(self, displayed) - self._plotFrame.displayed = displayed
\ No newline at end of file + self._plotFrame.displayed = displayed diff --git a/silx/gui/plot/backends/glutils/GLPlotCurve.py b/silx/gui/plot/backends/glutils/GLPlotCurve.py index 124a3da..12b6bbe 100644 --- a/silx/gui/plot/backends/glutils/GLPlotCurve.py +++ b/silx/gui/plot/backends/glutils/GLPlotCurve.py @@ -1,7 +1,7 @@ # coding: utf-8 # /*########################################################################## # -# Copyright (c) 2014-2017 European Synchrotron Radiation Facility +# Copyright (c) 2014-2018 European Synchrotron Radiation Facility # # Permission is hereby granted, free of charge, to any person obtaining a copy # of this software and associated documentation files (the "Software"), to deal @@ -26,6 +26,8 @@ This module provides classes to render 2D lines and scatter plots """ +from __future__ import division + __authors__ = ["T. Vincent"] __license__ = "MIT" __date__ = "03/04/2017" @@ -33,73 +35,73 @@ __date__ = "03/04/2017" import math import logging +import warnings import numpy from silx.math.combo import min_max from ...._glutils import gl -from ...._glutils import numpyToGLType, Program, vertexBuffer -from ..._utils import FLOAT32_MINPOS -from .GLSupport import buildFillMaskIndices +from ...._glutils import Program, vertexBuffer +from .GLSupport import buildFillMaskIndices, mat4Identity, mat4Translate _logger = logging.getLogger(__name__) _MPL_NONES = None, 'None', '', ' ' +"""Possible values for None""" -# fill ######################################################################## +def _notNaNSlices(array, length=1): + """Returns slices of none NaN values in the array. -class _Fill2D(object): - _LINEAR, _LOG10_X, _LOG10_Y, _LOG10_X_Y = 0, 1, 2, 3 + :param numpy.ndarray array: 1D array from which to get slices + :param int length: Slices shorter than length gets discarded + :return: Array of (start, end) slice indices + :rtype: numpy.ndarray + """ + isnan = numpy.isnan(numpy.array(array, copy=False).reshape(-1)) + notnan = numpy.logical_not(isnan) + start = numpy.where(numpy.logical_and(isnan[:-1], notnan[1:]))[0] + 1 + if notnan[0]: + start = numpy.append(0, start) + end = numpy.where(numpy.logical_and(notnan[:-1], isnan[1:]))[0] + 1 + if notnan[-1]: + end = numpy.append(end, len(array)) + slices = numpy.transpose((start, end)) + if length > 1: + # discard slices with less than length values + slices = slices[numpy.diff(slices, axis=1).ravel() >= length] + return slices - _SHADERS = { - 'vertexTransforms': { - _LINEAR: """ - vec4 transformXY(float x, float y) { - return vec4(x, y, 0.0, 1.0); - } - """, - _LOG10_X: """ - const float oneOverLog10 = 0.43429448190325176; - vec4 transformXY(float x, float y) { - return vec4(oneOverLog10 * log(x), y, 0.0, 1.0); - } - """, - _LOG10_Y: """ - const float oneOverLog10 = 0.43429448190325176; +# fill ######################################################################## - vec4 transformXY(float x, float y) { - return vec4(x, oneOverLog10 * log(y), 0.0, 1.0); - } - """, - _LOG10_X_Y: """ - const float oneOverLog10 = 0.43429448190325176; +class _Fill2D(object): + """Object rendering curve filling as polygons + + :param numpy.ndarray xData: X coordinates of points + :param numpy.ndarray yData: Y coordinates of points + :param float baseline: Y value of the 'bottom' of the fill. + 0 for linear Y scale, -38 for log Y scale + :param List[float] color: RGBA color as 4 float in [0, 1] + :param List[float] offset: Translation of coordinates (ox, oy) + """ - vec4 transformXY(float x, float y) { - return vec4(oneOverLog10 * log(x), - oneOverLog10 * log(y), - 0.0, 1.0); - } - """ - }, - 'vertex': """ + _PROGRAM = Program( + vertexShader=""" #version 120 uniform mat4 matrix; attribute float xPos; attribute float yPos; - %s - void main(void) { - gl_Position = matrix * transformXY(xPos, yPos); + gl_Position = matrix * vec4(xPos, yPos, 0.0, 1.0); } """, - 'fragment': """ + fragmentShader=""" #version 120 uniform vec4 color; @@ -107,72 +109,95 @@ class _Fill2D(object): void main(void) { gl_FragColor = color; } - """ - } - - _programs = { - _LINEAR: Program( - _SHADERS['vertex'] % _SHADERS['vertexTransforms'][_LINEAR], - _SHADERS['fragment'], attrib0='xPos'), - _LOG10_X: Program( - _SHADERS['vertex'] % _SHADERS['vertexTransforms'][_LOG10_X], - _SHADERS['fragment'], attrib0='xPos'), - _LOG10_Y: Program( - _SHADERS['vertex'] % _SHADERS['vertexTransforms'][_LOG10_Y], - _SHADERS['fragment'], attrib0='xPos'), - _LOG10_X_Y: Program( - _SHADERS['vertex'] % _SHADERS['vertexTransforms'][_LOG10_X_Y], - _SHADERS['fragment'], attrib0='xPos'), - } - - def __init__(self, xFillVboData=None, yFillVboData=None, - xMin=None, yMin=None, xMax=None, yMax=None, - color=(0., 0., 0., 1.)): - self.xFillVboData = xFillVboData - self.yFillVboData = yFillVboData - self.xMin, self.yMin = xMin, yMin - self.xMax, self.yMax = xMax, yMax + """, + attrib0='xPos') + + def __init__(self, xData=None, yData=None, + baseline=0, + color=(0., 0., 0., 1.), + offset=(0., 0.)): + self.xData = xData + self.yData = yData + self._xFillVboData = None + self._yFillVboData = None self.color = color + self.offset = offset - self._bboxVertices = None - self._indices = None - self._indicesType = None + # Offset baseline + self.baseline = baseline - self.offset[1] def prepare(self): - if self._indices is None: - self._indices = buildFillMaskIndices(self.xFillVboData.size) - self._indicesType = numpyToGLType(self._indices.dtype) - - if self._bboxVertices is None: - yMin, yMax = min(self.yMin, 1e-32), max(self.yMax, 1e-32) - self._bboxVertices = numpy.array(((self.xMin, self.xMin, - self.xMax, self.xMax), - (yMin, yMax, yMin, yMax)), - dtype=numpy.float32) - - def render(self, matrix, isXLog, isYLog): + """Rendering preparation: build indices and bounding box vertices""" + if (self._xFillVboData is None and + self.xData is not None and self.yData is not None): + + # Get slices of not NaN values longer than 1 element + isnan = numpy.logical_or(numpy.isnan(self.xData), + numpy.isnan(self.yData)) + notnan = numpy.logical_not(isnan) + start = numpy.where(numpy.logical_and(isnan[:-1], notnan[1:]))[0] + 1 + if notnan[0]: + start = numpy.append(0, start) + end = numpy.where(numpy.logical_and(notnan[:-1], isnan[1:]))[0] + 1 + if notnan[-1]: + end = numpy.append(end, len(isnan)) + slices = numpy.transpose((start, end)) + # discard slices with less than length values + slices = slices[numpy.diff(slices, axis=1).reshape(-1) >= 2] + + # Number of points: slice + 2 * leading and trailing points + # Twice leading and trailing points to produce degenerated triangles + nbPoints = numpy.sum(numpy.diff(slices, axis=1)) + 4 * len(slices) + points = numpy.empty((nbPoints, 2), dtype=numpy.float32) + + offset = 0 + for start, end in slices: + # Duplicate first point for connecting degenerated triangle + points[offset:offset+2] = self.xData[start], self.baseline + + # 2nd point of the polygon is last point + points[offset+2] = self.xData[end-1], self.baseline + + # Add all points from the data + indices = start + buildFillMaskIndices(end - start) + + points[offset+3:offset+3+len(indices), 0] = self.xData[indices] + points[offset+3:offset+3+len(indices), 1] = self.yData[indices] + + # Duplicate last point for connecting degenerated triangle + points[offset+3+len(indices)] = points[offset+3+len(indices)-1] + + offset += len(indices) + 4 + + self._xFillVboData, self._yFillVboData = vertexBuffer(points.T) + + def render(self, matrix): + """Perform rendering + + :param numpy.ndarray matrix: 4x4 transform matrix to use + """ self.prepare() - if isXLog: - transform = self._LOG10_X_Y if isYLog else self._LOG10_X - else: - transform = self._LOG10_Y if isYLog else self._LINEAR + if self._xFillVboData is None: + return # Nothing to display - prog = self._programs[transform] - prog.use() + self._PROGRAM.use() - gl.glUniformMatrix4fv(prog.uniforms['matrix'], 1, gl.GL_TRUE, matrix) + gl.glUniformMatrix4fv( + self._PROGRAM.uniforms['matrix'], 1, gl.GL_TRUE, + numpy.dot(matrix, + mat4Translate(*self.offset)).astype(numpy.float32)) - gl.glUniform4f(prog.uniforms['color'], *self.color) + gl.glUniform4f(self._PROGRAM.uniforms['color'], *self.color) - xPosAttrib = prog.attributes['xPos'] - yPosAttrib = prog.attributes['yPos'] + xPosAttrib = self._PROGRAM.attributes['xPos'] + yPosAttrib = self._PROGRAM.attributes['yPos'] gl.glEnableVertexAttribArray(xPosAttrib) - self.xFillVboData.setVertexAttrib(xPosAttrib) + self._xFillVboData.setVertexAttrib(xPosAttrib) gl.glEnableVertexAttribArray(yPosAttrib) - self.yFillVboData.setVertexAttrib(yPosAttrib) + self._yFillVboData.setVertexAttrib(yPosAttrib) # Prepare fill mask gl.glEnable(gl.GL_STENCIL_TEST) @@ -182,8 +207,7 @@ class _Fill2D(object): gl.glColorMask(gl.GL_FALSE, gl.GL_FALSE, gl.GL_FALSE, gl.GL_FALSE) gl.glDepthMask(gl.GL_FALSE) - gl.glDrawElements(gl.GL_TRIANGLE_STRIP, self._indices.size, - self._indicesType, self._indices) + gl.glDrawArrays(gl.GL_TRIANGLE_STRIP, 0, self._xFillVboData.size) gl.glStencilFunc(gl.GL_EQUAL, 1, 1) # Reset stencil while drawing @@ -191,14 +215,30 @@ class _Fill2D(object): gl.glColorMask(gl.GL_TRUE, gl.GL_TRUE, gl.GL_TRUE, gl.GL_TRUE) gl.glDepthMask(gl.GL_TRUE) - gl.glVertexAttribPointer(xPosAttrib, 1, gl.GL_FLOAT, gl.GL_FALSE, 0, - self._bboxVertices[0]) - gl.glVertexAttribPointer(yPosAttrib, 1, gl.GL_FLOAT, gl.GL_FALSE, 0, - self._bboxVertices[1]) - gl.glDrawArrays(gl.GL_TRIANGLE_STRIP, 0, self._bboxVertices[0].size) + # Draw directly in NDC + gl.glUniformMatrix4fv(self._PROGRAM.uniforms['matrix'], 1, gl.GL_TRUE, + mat4Identity().astype(numpy.float32)) + + # NDC vertices + gl.glVertexAttribPointer( + xPosAttrib, 1, gl.GL_FLOAT, gl.GL_FALSE, 0, + numpy.array((-1., -1., 1., 1.), dtype=numpy.float32)) + gl.glVertexAttribPointer( + yPosAttrib, 1, gl.GL_FLOAT, gl.GL_FALSE, 0, + numpy.array((-1., 1., -1., 1.), dtype=numpy.float32)) + + gl.glDrawArrays(gl.GL_TRIANGLE_STRIP, 0, 4) gl.glDisable(gl.GL_STENCIL_TEST) + def discard(self): + """Release VBOs""" + if self._xFillVboData is not None: + self._xFillVboData.vbo.discard() + + self._xFillVboData = None + self._yFillVboData = None + # line ######################################################################## @@ -206,44 +246,25 @@ SOLID, DASHED, DASHDOT, DOTTED = '-', '--', '-.', ':' class _Lines2D(object): + """Object rendering curve as a polyline + + :param xVboData: X coordinates VBO + :param yVboData: Y coordinates VBO + :param colorVboData: VBO of colors + :param distVboData: VBO of distance along the polyline + :param str style: Line style in: '-', '--', '-.', ':' + :param List[float] color: RGBA color as 4 float in [0, 1] + :param float width: Line width + :param float dashPeriod: Period of dashes + :param drawMode: OpenGL drawing mode + :param List[float] offset: Translation of coordinates (ox, oy) + """ + STYLES = SOLID, DASHED, DASHDOT, DOTTED """Supported line styles""" - _LINEAR, _LOG10_X, _LOG10_Y, _LOG10_X_Y = 0, 1, 2, 3 - - _SHADERS = { - 'vertexTransforms': { - _LINEAR: """ - vec4 transformXY(float x, float y) { - return vec4(x, y, 0.0, 1.0); - } - """, - _LOG10_X: """ - const float oneOverLog10 = 0.43429448190325176; - - vec4 transformXY(float x, float y) { - return vec4(oneOverLog10 * log(x), y, 0.0, 1.0); - } - """, - _LOG10_Y: """ - const float oneOverLog10 = 0.43429448190325176; - - vec4 transformXY(float x, float y) { - return vec4(x, oneOverLog10 * log(y), 0.0, 1.0); - } - """, - _LOG10_X_Y: """ - const float oneOverLog10 = 0.43429448190325176; - - vec4 transformXY(float x, float y) { - return vec4(oneOverLog10 * log(x), - oneOverLog10 * log(y), - 0.0, 1.0); - } - """ - }, - 'solid': { - 'vertex': """ + _SOLID_PROGRAM = Program( + vertexShader=""" #version 120 uniform mat4 matrix; @@ -253,14 +274,12 @@ class _Lines2D(object): varying vec4 vColor; - %s - void main(void) { - gl_Position = matrix * transformXY(xPos, yPos); + gl_Position = matrix * vec4(xPos, yPos, 0., 1.) ; vColor = color; } """, - 'fragment': """ + fragmentShader=""" #version 120 varying vec4 vColor; @@ -268,15 +287,14 @@ class _Lines2D(object): void main(void) { gl_FragColor = vColor; } - """ - }, - + """, + attrib0='xPos') - # Limitation: Dash using an estimate of distance in screen coord - # to avoid computing distance when viewport is resized - # results in inequal dashes when viewport aspect ratio is far from 1 - 'dashed': { - 'vertex': """ + # Limitation: Dash using an estimate of distance in screen coord + # to avoid computing distance when viewport is resized + # results in inequal dashes when viewport aspect ratio is far from 1 + _DASH_PROGRAM = Program( + vertexShader=""" #version 120 uniform mat4 matrix; @@ -289,10 +307,8 @@ class _Lines2D(object): varying float vDist; varying vec4 vColor; - %s - void main(void) { - gl_Position = matrix * transformXY(xPos, yPos); + gl_Position = matrix * vec4(xPos, yPos, 0., 1.); //Estimate distance in pixels vec2 probe = vec2(matrix * vec4(1., 1., 0., 0.)) * halfViewportSize; @@ -301,7 +317,7 @@ class _Lines2D(object): vColor = color; } """, - 'fragment': """ + fragmentShader=""" #version 120 /* Dashes: [0, x], [y, z] @@ -318,16 +334,14 @@ class _Lines2D(object): } gl_FragColor = vColor; } - """ - } - } - - _programs = {} + """, + attrib0='xPos') def __init__(self, xVboData=None, yVboData=None, colorVboData=None, distVboData=None, style=SOLID, color=(0., 0., 0., 1.), - width=1, dashPeriod=20, drawMode=None): + width=1, dashPeriod=20, drawMode=None, + offset=(0., 0.)): self.xVboData = xVboData self.yVboData = yVboData self.distVboData = distVboData @@ -335,136 +349,83 @@ class _Lines2D(object): self.useColorVboData = colorVboData is not None self.color = color - self._width = 1 self.width = width self._style = None self.style = style self.dashPeriod = dashPeriod + self.offset = offset self._drawMode = drawMode if drawMode is not None else gl.GL_LINE_STRIP @property def style(self): + """Line style (Union[str,None])""" return self._style @style.setter def style(self, style): if style in _MPL_NONES: self._style = None - self.render = self._renderNone else: assert style in self.STYLES self._style = style - if style == SOLID: - self.render = self._renderSolid - else: # DASHED, DASHDOT, DOTTED - self.render = self._renderDash - - @property - def width(self): - return self._width - - @width.setter - def width(self, width): - # try: - # widthRange = self._widthRange - # except AttributeError: - # widthRange = gl.glGetFloatv(gl.GL_ALIASED_LINE_WIDTH_RANGE) - # # Shared among contexts, this should be enough.. - # _Lines2D._widthRange = widthRange - # assert width >= widthRange[0] and width <= widthRange[1] - self._width = width - - @classmethod - def _getProgram(cls, transform, style): - try: - prgm = cls._programs[(transform, style)] - except KeyError: - sources = cls._SHADERS[style] - vertexShdr = sources['vertex'] % \ - cls._SHADERS['vertexTransforms'][transform] - prgm = Program(vertexShdr, sources['fragment'], attrib0='xPos') - cls._programs[(transform, style)] = prgm - return prgm @classmethod def init(cls): + """OpenGL context initialization""" gl.glHint(gl.GL_LINE_SMOOTH_HINT, gl.GL_NICEST) - def _renderNone(self, matrix, isXLog, isYLog): - pass - - render = _renderNone # Overridden in style setter - - def _renderSolid(self, matrix, isXLog, isYLog): - if isXLog: - transform = self._LOG10_X_Y if isYLog else self._LOG10_X - else: - transform = self._LOG10_Y if isYLog else self._LINEAR - - prog = self._getProgram(transform, 'solid') - prog.use() - - gl.glEnable(gl.GL_LINE_SMOOTH) - - gl.glUniformMatrix4fv(prog.uniforms['matrix'], 1, gl.GL_TRUE, matrix) - - colorAttrib = prog.attributes['color'] - if self.useColorVboData and self.colorVboData is not None: - gl.glEnableVertexAttribArray(colorAttrib) - self.colorVboData.setVertexAttrib(colorAttrib) - else: - gl.glDisableVertexAttribArray(colorAttrib) - gl.glVertexAttrib4f(colorAttrib, *self.color) - - xPosAttrib = prog.attributes['xPos'] - gl.glEnableVertexAttribArray(xPosAttrib) - self.xVboData.setVertexAttrib(xPosAttrib) - - yPosAttrib = prog.attributes['yPos'] - gl.glEnableVertexAttribArray(yPosAttrib) - self.yVboData.setVertexAttrib(yPosAttrib) + def render(self, matrix): + """Perform rendering - gl.glLineWidth(self.width) - gl.glDrawArrays(self._drawMode, 0, self.xVboData.size) + :param numpy.ndarray matrix: 4x4 transform matrix to use + """ + style = self.style + if style is None: + return - gl.glDisable(gl.GL_LINE_SMOOTH) + elif style == SOLID: + program = self._SOLID_PROGRAM + program.use() + + else: # DASHED, DASHDOT, DOTTED + program = self._DASH_PROGRAM + program.use() + + x, y, viewWidth, viewHeight = gl.glGetFloatv(gl.GL_VIEWPORT) + gl.glUniform2f(program.uniforms['halfViewportSize'], + 0.5 * viewWidth, 0.5 * viewHeight) + + if self.style == DOTTED: + dash = (0.1 * self.dashPeriod, + 0.6 * self.dashPeriod, + 0.7 * self.dashPeriod, + self.dashPeriod) + elif self.style == DASHDOT: + dash = (0.3 * self.dashPeriod, + 0.5 * self.dashPeriod, + 0.6 * self.dashPeriod, + self.dashPeriod) + else: + dash = (0.5 * self.dashPeriod, + self.dashPeriod, + self.dashPeriod, + self.dashPeriod) - def _renderDash(self, matrix, isXLog, isYLog): - if isXLog: - transform = self._LOG10_X_Y if isYLog else self._LOG10_X - else: - transform = self._LOG10_Y if isYLog else self._LINEAR + gl.glUniform4f(program.uniforms['dash'], *dash) - prog = self._getProgram(transform, 'dashed') - prog.use() + distAttrib = program.attributes['distance'] + gl.glEnableVertexAttribArray(distAttrib) + self.distVboData.setVertexAttrib(distAttrib) gl.glEnable(gl.GL_LINE_SMOOTH) - gl.glUniformMatrix4fv(prog.uniforms['matrix'], 1, gl.GL_TRUE, matrix) - x, y, viewWidth, viewHeight = gl.glGetFloatv(gl.GL_VIEWPORT) - gl.glUniform2f(prog.uniforms['halfViewportSize'], - 0.5 * viewWidth, 0.5 * viewHeight) - - if self.style == DOTTED: - dash = (0.1 * self.dashPeriod, - 0.6 * self.dashPeriod, - 0.7 * self.dashPeriod, - self.dashPeriod) - elif self.style == DASHDOT: - dash = (0.3 * self.dashPeriod, - 0.5 * self.dashPeriod, - 0.6 * self.dashPeriod, - self.dashPeriod) - else: - dash = (0.5 * self.dashPeriod, - self.dashPeriod, - self.dashPeriod, - self.dashPeriod) + matrix = numpy.dot(matrix, + mat4Translate(*self.offset)).astype(numpy.float32) + gl.glUniformMatrix4fv(program.uniforms['matrix'], + 1, gl.GL_TRUE, matrix) - gl.glUniform4f(prog.uniforms['dash'], *dash) - - colorAttrib = prog.attributes['color'] + colorAttrib = program.attributes['color'] if self.useColorVboData and self.colorVboData is not None: gl.glEnableVertexAttribArray(colorAttrib) self.colorVboData.setVertexAttrib(colorAttrib) @@ -472,15 +433,11 @@ class _Lines2D(object): gl.glDisableVertexAttribArray(colorAttrib) gl.glVertexAttrib4f(colorAttrib, *self.color) - distAttrib = prog.attributes['distance'] - gl.glEnableVertexAttribArray(distAttrib) - self.distVboData.setVertexAttrib(distAttrib) - - xPosAttrib = prog.attributes['xPos'] + xPosAttrib = program.attributes['xPos'] gl.glEnableVertexAttribArray(xPosAttrib) self.xVboData.setVertexAttrib(xPosAttrib) - yPosAttrib = prog.attributes['yPos'] + yPosAttrib = program.attributes['yPos'] gl.glEnableVertexAttribArray(yPosAttrib) self.yVboData.setVertexAttrib(yPosAttrib) @@ -491,6 +448,12 @@ class _Lines2D(object): def _distancesFromArrays(xData, yData): + """Returns distances between each points + + :param numpy.ndarray xData: X coordinate of points + :param numpy.ndarray yData: Y coordinate of points + :rtype: numpy.ndarray + """ deltas = numpy.dstack(( numpy.ediff1d(xData, to_begin=numpy.float32(0.)), numpy.ediff1d(yData, to_begin=numpy.float32(0.))))[0] @@ -506,43 +469,22 @@ H_LINE, V_LINE = '_', '|' class _Points2D(object): + """Object rendering curve markers + + :param xVboData: X coordinates VBO + :param yVboData: Y coordinates VBO + :param colorVboData: VBO of colors + :param str marker: Kind of symbol to use, see :attr:`MARKERS`. + :param List[float] color: RGBA color as 4 float in [0, 1] + :param float size: Marker size + :param List[float] offset: Translation of coordinates (ox, oy) + """ + MARKERS = (DIAMOND, CIRCLE, SQUARE, PLUS, X_MARKER, POINT, PIXEL, ASTERISK, H_LINE, V_LINE) + """List of supported markers""" - _LINEAR, _LOG10_X, _LOG10_Y, _LOG10_X_Y = 0, 1, 2, 3 - - _SHADERS = { - 'vertexTransforms': { - _LINEAR: """ - vec4 transformXY(float x, float y) { - return vec4(x, y, 0.0, 1.0); - } - """, - _LOG10_X: """ - const float oneOverLog10 = 0.43429448190325176; - - vec4 transformXY(float x, float y) { - return vec4(oneOverLog10 * log(x), y, 0.0, 1.0); - } - """, - _LOG10_Y: """ - const float oneOverLog10 = 0.43429448190325176; - - vec4 transformXY(float x, float y) { - return vec4(x, oneOverLog10 * log(y), 0.0, 1.0); - } - """, - _LOG10_X_Y: """ - const float oneOverLog10 = 0.43429448190325176; - - vec4 transformXY(float x, float y) { - return vec4(oneOverLog10 * log(x), - oneOverLog10 * log(y), - 0.0, 1.0); - } - """ - }, - 'vertex': """ + _VERTEX_SHADER = """ #version 120 uniform mat4 matrix; @@ -554,16 +496,14 @@ class _Points2D(object): varying vec4 vColor; - %s - void main(void) { - gl_Position = matrix * transformXY(xPos, yPos); + gl_Position = matrix * vec4(xPos, yPos, 0., 1.); vColor = color; gl_PointSize = size; } - """, + """ - 'fragmentSymbols': { + _FRAGMENT_SHADER_SYMBOLS = { DIAMOND: """ float alphaSymbol(vec2 coord, float size) { vec2 centerCoord = abs(coord - vec2(0.5, 0.5)); @@ -640,9 +580,9 @@ class _Points2D(object): } } """ - }, + } - 'fragment': """ + _FRAGMENT_SHADER_TEMPLATE = """ #version 120 uniform float size; @@ -660,17 +600,17 @@ class _Points2D(object): } } """ - } - _programs = {} + _PROGRAMS = {} def __init__(self, xVboData=None, yVboData=None, colorVboData=None, - marker=SQUARE, color=(0., 0., 0., 1.), size=7): + marker=SQUARE, color=(0., 0., 0., 1.), size=7, + offset=(0., 0.)): self.color = color self._marker = None self.marker = marker - self._size = 1 self.size = size + self.offset = offset self.xVboData = xVboData self.yVboData = yVboData @@ -679,54 +619,37 @@ class _Points2D(object): @property def marker(self): + """Symbol used to display markers (str)""" return self._marker @marker.setter def marker(self, marker): if marker in _MPL_NONES: self._marker = None - self.render = self._renderNone else: assert marker in self.MARKERS self._marker = marker - self.render = self._renderMarkers - - @property - def size(self): - return self._size - - @size.setter - def size(self, size): - # try: - # sizeRange = self._sizeRange - # except AttributeError: - # sizeRange = gl.glGetFloatv(gl.GL_POINT_SIZE_RANGE) - # # Shared among contexts, this should be enough.. - # _Points2D._sizeRange = sizeRange - # assert size >= sizeRange[0] and size <= sizeRange[1] - self._size = size @classmethod - def _getProgram(cls, transform, marker): + def _getProgram(cls, marker): """On-demand shader program creation.""" if marker == PIXEL: marker = SQUARE elif marker == POINT: marker = CIRCLE - try: - prgm = cls._programs[(transform, marker)] - except KeyError: - vertShdr = cls._SHADERS['vertex'] % \ - cls._SHADERS['vertexTransforms'][transform] - fragShdr = cls._SHADERS['fragment'] % \ - cls._SHADERS['fragmentSymbols'][marker] - prgm = Program(vertShdr, fragShdr, attrib0='xPos') - - cls._programs[(transform, marker)] = prgm - return prgm + + if marker not in cls._PROGRAMS: + cls._PROGRAMS[marker] = Program( + vertexShader=cls._VERTEX_SHADER, + fragmentShader=(cls._FRAGMENT_SHADER_TEMPLATE % + cls._FRAGMENT_SHADER_SYMBOLS[marker]), + attrib0='xPos') + + return cls._PROGRAMS[marker] @classmethod def init(cls): + """OpenGL context initialization""" version = gl.glGetString(gl.GL_VERSION) majorVersion = int(version[0]) assert majorVersion >= 2 @@ -735,30 +658,31 @@ class _Points2D(object): if majorVersion >= 3: # OpenGL 3 gl.glEnable(gl.GL_PROGRAM_POINT_SIZE) - def _renderNone(self, matrix, isXLog, isYLog): - pass + def render(self, matrix): + """Perform rendering - render = _renderNone + :param numpy.ndarray matrix: 4x4 transform matrix to use + """ + if self.marker is None: + return - def _renderMarkers(self, matrix, isXLog, isYLog): - if isXLog: - transform = self._LOG10_X_Y if isYLog else self._LOG10_X - else: - transform = self._LOG10_Y if isYLog else self._LINEAR + program = self._getProgram(self.marker) + program.use() + + matrix = numpy.dot(matrix, + mat4Translate(*self.offset)).astype(numpy.float32) + gl.glUniformMatrix4fv(program.uniforms['matrix'], 1, gl.GL_TRUE, matrix) - prog = self._getProgram(transform, self.marker) - prog.use() - gl.glUniformMatrix4fv(prog.uniforms['matrix'], 1, gl.GL_TRUE, matrix) if self.marker == PIXEL: size = 1 elif self.marker == POINT: size = math.ceil(0.5 * self.size) + 1 # Mimic Matplotlib point else: size = self.size - gl.glUniform1f(prog.uniforms['size'], size) + gl.glUniform1f(program.uniforms['size'], size) # gl.glPointSize(self.size) - cAttrib = prog.attributes['color'] + cAttrib = program.attributes['color'] if self.useColorVboData and self.colorVboData is not None: gl.glEnableVertexAttribArray(cAttrib) self.colorVboData.setVertexAttrib(cAttrib) @@ -766,11 +690,11 @@ class _Points2D(object): gl.glDisableVertexAttribArray(cAttrib) gl.glVertexAttrib4f(cAttrib, *self.color) - xAttrib = prog.attributes['xPos'] + xAttrib = program.attributes['xPos'] gl.glEnableVertexAttribArray(xAttrib) self.xVboData.setVertexAttrib(xAttrib) - yAttrib = prog.attributes['yPos'] + yAttrib = program.attributes['yPos'] gl.glEnableVertexAttribArray(yAttrib) self.yVboData.setVertexAttrib(yAttrib) @@ -786,40 +710,35 @@ class _ErrorBars(object): This is using its own VBO as opposed to fill/points/lines. There is no picking on error bars. - As is, there is no way to update data and errors, but it handles - log scales by removing data <= 0 and clipping error bars to positive - range. It uses 2 vertices per error bars and uses :class:`_Lines2D` to render error bars and :class:`_Points2D` to render the ends. + + :param numpy.ndarray xData: X coordinates of the data. + :param numpy.ndarray yData: Y coordinates of the data. + :param xError: The absolute error on the X axis. + :type xError: A float, or a numpy.ndarray of float32. + If it is an array, it can either be a 1D array of + same length as the data or a 2D array with 2 rows + of same length as the data: row 0 for negative errors, + row 1 for positive errors. + :param yError: The absolute error on the Y axis. + :type yError: A float, or a numpy.ndarray of float32. See xError. + :param float xMin: The min X value already computed by GLPlotCurve2D. + :param float yMin: The min Y value already computed by GLPlotCurve2D. + :param List[float] color: RGBA color as 4 float in [0, 1] + :param List[float] offset: Translation of coordinates (ox, oy) """ def __init__(self, xData, yData, xError, yError, xMin, yMin, - color=(0., 0., 0., 1.)): - """Initialization. - - :param numpy.ndarray xData: X coordinates of the data. - :param numpy.ndarray yData: Y coordinates of the data. - :param xError: The absolute error on the X axis. - :type xError: A float, or a numpy.ndarray of float32. - If it is an array, it can either be a 1D array of - same length as the data or a 2D array with 2 rows - of same length as the data: row 0 for negative errors, - row 1 for positive errors. - :param yError: The absolute error on the Y axis. - :type yError: A float, or a numpy.ndarray of float32. See xError. - :param float xMin: The min X value already computed by GLPlotCurve2D. - :param float yMin: The min Y value already computed by GLPlotCurve2D. - :param color: The color to use for both lines and ending points. - :type color: tuple of 4 floats - """ + color=(0., 0., 0., 1.), + offset=(0., 0.)): self._attribs = None - self._isXLog, self._isYLog = False, False self._xMin, self._yMin = xMin, yMin + self.offset = offset if xError is not None or yError is not None: - assert len(xData) == len(yData) self._xData = numpy.array( xData, order='C', dtype=numpy.float32, copy=False) self._yData = numpy.array( @@ -834,61 +753,19 @@ class _ErrorBars(object): self._xData, self._yData = None, None self._xError, self._yError = None, None - self._lines = _Lines2D(None, None, color=color, drawMode=gl.GL_LINES) - self._xErrPoints = _Points2D(None, None, color=color, marker=V_LINE) - self._yErrPoints = _Points2D(None, None, color=color, marker=H_LINE) + self._lines = _Lines2D( + None, None, color=color, drawMode=gl.GL_LINES, offset=offset) + self._xErrPoints = _Points2D( + None, None, color=color, marker=V_LINE, offset=offset) + self._yErrPoints = _Points2D( + None, None, color=color, marker=H_LINE, offset=offset) - def _positiveValueFilter(self, onlyXPos, onlyYPos): - """Filter data (x, y) and errors (xError, yError) to remove - negative and null data values on required axis (onlyXPos, onlyYPos). + def _buildVertices(self): + """Generates error bars vertices""" + nbLinesPerDataPts = (0 if self._xError is None else 2) + \ + (0 if self._yError is None else 2) - Returned arrays might be NOT contiguous. - - :return: Filtered xData, yData, xError and yError arrays. - """ - if ((not onlyXPos or self._xMin > 0.) and - (not onlyYPos or self._yMin > 0.)): - # No need to filter, all values are > 0 on log axes - return self._xData, self._yData, self._xError, self._yError - - _logger.warning( - 'Removing values <= 0 of curve with error bars on a log axis.') - - x, y = self._xData, self._yData - xError, yError = self._xError, self._yError - - # First remove negative data - if onlyXPos and onlyYPos: - mask = (x > 0.) & (y > 0.) - elif onlyXPos: - mask = x > 0. - else: # onlyYPos - mask = y > 0. - x, y = x[mask], y[mask] - - # Remove corresponding values from error arrays - if xError is not None and xError.size != 1: - if len(xError.shape) == 1: - xError = xError[mask] - else: # 2 rows - xError = xError[:, mask] - if yError is not None and yError.size != 1: - if len(yError.shape) == 1: - yError = yError[mask] - else: # 2 rows - yError = yError[:, mask] - - return x, y, xError, yError - - def _buildVertices(self, isXLog, isYLog): - """Generates error bars vertices according to log scales.""" - xData, yData, xError, yError = self._positiveValueFilter( - isXLog, isYLog) - - nbLinesPerDataPts = 1 if xError is not None else 0 - nbLinesPerDataPts += 1 if yError is not None else 0 - - nbDataPts = len(xData) + nbDataPts = len(self._xData) # interleave coord+error, coord-error. # xError vertices first if any, then yError vertices if any. @@ -897,64 +774,61 @@ class _ErrorBars(object): yCoords = numpy.empty(nbDataPts * nbLinesPerDataPts * 2, dtype=numpy.float32) - if xError is not None: # errors on the X axis - if len(xError.shape) == 2: - xErrorMinus, xErrorPlus = xError[0], xError[1] + if self._xError is not None: # errors on the X axis + if len(self._xError.shape) == 2: + xErrorMinus, xErrorPlus = self._xError[0], self._xError[1] else: # numpy arrays of len 1 or len(xData) - xErrorMinus, xErrorPlus = xError, xError + xErrorMinus, xErrorPlus = self._xError, self._xError # Interleave vertices for xError - endXError = 2 * nbDataPts - xCoords[0:endXError-1:2] = xData + xErrorPlus + endXError = 4 * nbDataPts + xCoords[0:endXError-3:4] = self._xData + xErrorPlus + xCoords[1:endXError-2:4] = self._xData + xCoords[2:endXError-1:4] = self._xData + xCoords[3:endXError:4] = self._xData - xErrorMinus - minValues = xData - xErrorMinus - if isXLog: - # Clip min bounds to positive value - minValues[minValues <= 0] = FLOAT32_MINPOS - xCoords[1:endXError:2] = minValues + yCoords[0:endXError-3:4] = self._yData + yCoords[1:endXError-2:4] = self._yData + yCoords[2:endXError-1:4] = self._yData + yCoords[3:endXError:4] = self._yData - yCoords[0:endXError-1:2] = yData - yCoords[1:endXError:2] = yData else: endXError = 0 - if yError is not None: # errors on the Y axis - if len(yError.shape) == 2: - yErrorMinus, yErrorPlus = yError[0], yError[1] + if self._yError is not None: # errors on the Y axis + if len(self._yError.shape) == 2: + yErrorMinus, yErrorPlus = self._yError[0], self._yError[1] else: # numpy arrays of len 1 or len(yData) - yErrorMinus, yErrorPlus = yError, yError + yErrorMinus, yErrorPlus = self._yError, self._yError # Interleave vertices for yError - xCoords[endXError::2] = xData - xCoords[endXError+1::2] = xData - yCoords[endXError::2] = yData + yErrorPlus - minValues = yData - yErrorMinus - if isYLog: - # Clip min bounds to positive value - minValues[minValues <= 0] = FLOAT32_MINPOS - yCoords[endXError+1::2] = minValues + xCoords[endXError::4] = self._xData + xCoords[endXError+1::4] = self._xData + xCoords[endXError+2::4] = self._xData + xCoords[endXError+3::4] = self._xData + + yCoords[endXError::4] = self._yData + yErrorPlus + yCoords[endXError+1::4] = self._yData + yCoords[endXError+2::4] = self._yData + yCoords[endXError+3::4] = self._yData - yErrorMinus return xCoords, yCoords - def prepare(self, isXLog, isYLog): + def prepare(self): + """Rendering preparation: build indices and bounding box vertices""" if self._xData is None: return - if self._isXLog != isXLog or self._isYLog != isYLog: - # Log state has changed - self._isXLog, self._isYLog = isXLog, isYLog - - self.discard() # discard existing VBOs - if self._attribs is None: - xCoords, yCoords = self._buildVertices(isXLog, isYLog) + xCoords, yCoords = self._buildVertices() xAttrib, yAttrib = vertexBuffer((xCoords, yCoords)) self._attribs = xAttrib, yAttrib - self._lines.xVboData, self._lines.yVboData = xAttrib, yAttrib + self._lines.xVboData = xAttrib + self._lines.yVboData = yAttrib # Set xError points using the same VBO as lines self._xErrPoints.xVboData = xAttrib.copy() @@ -972,13 +846,20 @@ class _ErrorBars(object): self._yErrPoints.yVboData.offset += (yAttrib.itemsize * yAttrib.size // 2) - def render(self, matrix, isXLog, isYLog): + def render(self, matrix): + """Perform rendering + + :param numpy.ndarray matrix: 4x4 transform matrix to use + """ + self.prepare() + if self._attribs is not None: - self._lines.render(matrix, isXLog, isYLog) - self._xErrPoints.render(matrix, isXLog, isYLog) - self._yErrPoints.render(matrix, isXLog, isYLog) + self._lines.render(matrix) + self._xErrPoints.render(matrix) + self._yErrPoints.render(matrix) def discard(self): + """Release VBOs""" if self._attribs is not None: self._lines.xVboData, self._lines.yVboData = None, None self._xErrPoints.xVboData, self._xErrPoints.yVboData = None, None @@ -1014,71 +895,80 @@ def _proxyProperty(*componentsAttributes): class GLPlotCurve2D(object): def __init__(self, xData, yData, colorData=None, xError=None, yError=None, - lineStyle=None, lineColor=None, - lineWidth=None, lineDashPeriod=None, - marker=None, markerColor=None, markerSize=None, - fillColor=None): - self._isXLog = False - self._isYLog = False - self.xData, self.yData, self.colorData = xData, yData, colorData - - if fillColor is not None: - self.fill = _Fill2D(color=fillColor) - else: - self.fill = None + lineStyle=SOLID, + lineColor=(0., 0., 0., 1.), + lineWidth=1, + lineDashPeriod=20, + marker=SQUARE, + markerColor=(0., 0., 0., 1.), + markerSize=7, + fillColor=None, + isYLog=False): + + self.colorData = colorData # Compute x bounds if xError is None: - result = min_max(xData, min_positive=True) - self.xMin = result.minimum - self.xMinPos = result.min_positive - self.xMax = result.maximum + self.xMin, self.xMax = min_max(xData, min_positive=False) else: # Takes the error into account if hasattr(xError, 'shape') and len(xError.shape) == 2: - xErrorPlus, xErrorMinus = xError[0], xError[1] + xErrorMinus, xErrorPlus = xError[0], xError[1] else: - xErrorPlus, xErrorMinus = xError, xError - result = min_max(xData - xErrorMinus, min_positive=True) - self.xMin = result.minimum - self.xMinPos = result.min_positive - self.xMax = (xData + xErrorPlus).max() + xErrorMinus, xErrorPlus = xError, xError + self.xMin = numpy.nanmin(xData - xErrorMinus) + self.xMax = numpy.nanmax(xData + xErrorPlus) # Compute y bounds if yError is None: - result = min_max(yData, min_positive=True) - self.yMin = result.minimum - self.yMinPos = result.min_positive - self.yMax = result.maximum + self.yMin, self.yMax = min_max(yData, min_positive=False) else: # Takes the error into account if hasattr(yError, 'shape') and len(yError.shape) == 2: - yErrorPlus, yErrorMinus = yError[0], yError[1] + yErrorMinus, yErrorPlus = yError[0], yError[1] else: - yErrorPlus, yErrorMinus = yError, yError - result = min_max(yData - yErrorMinus, min_positive=True) - self.yMin = result.minimum - self.yMinPos = result.min_positive - self.yMax = (yData + yErrorPlus).max() - - self._errorBars = _ErrorBars(xData, yData, xError, yError, - self.xMin, self.yMin) - - kwargs = {'style': lineStyle} - if lineColor is not None: - kwargs['color'] = lineColor - if lineWidth is not None: - kwargs['width'] = lineWidth - if lineDashPeriod is not None: - kwargs['dashPeriod'] = lineDashPeriod - self.lines = _Lines2D(**kwargs) - - kwargs = {'marker': marker} - if markerColor is not None: - kwargs['color'] = markerColor - if markerSize is not None: - kwargs['size'] = markerSize - self.points = _Points2D(**kwargs) + yErrorMinus, yErrorPlus = yError, yError + self.yMin = numpy.nanmin(yData - yErrorMinus) + self.yMax = numpy.nanmax(yData + yErrorPlus) + + # Handle data offset + if xData.itemsize > 4 or yData.itemsize > 4: # Use normalization + # offset data, do not offset error as it is relative + self.offset = self.xMin, self.yMin + self.xData = (xData - self.offset[0]).astype(numpy.float32) + self.yData = (yData - self.offset[1]).astype(numpy.float32) + + else: # float32 + self.offset = 0., 0. + self.xData = xData + self.yData = yData + + if fillColor is not None: + # Use different baseline depending of Y log scale + self.fill = _Fill2D(self.xData, self.yData, + baseline=-38 if isYLog else 0, + color=fillColor, + offset=self.offset) + else: + self.fill = None + + self._errorBars = _ErrorBars(self.xData, self.yData, + xError, yError, + self.xMin, self.yMin, + offset=self.offset) + + self.lines = _Lines2D() + self.lines.style = lineStyle + self.lines.color = lineColor + self.lines.width = lineWidth + self.lines.dashPeriod = lineDashPeriod + self.lines.offset = self.offset + + self.points = _Points2D() + self.points.marker = marker + self.points.color = markerColor + self.points.size = markerSize + self.points.offset = self.offset xVboData = _proxyProperty(('lines', 'xVboData'), ('points', 'xVboData')) @@ -1108,123 +998,53 @@ class GLPlotCurve2D(object): @classmethod def init(cls): + """OpenGL context initialization""" _Lines2D.init() _Points2D.init() - @staticmethod - def _logFilterData(x, y, color=None, xLog=False, yLog=False): - # Copied from Plot.py - if xLog and yLog: - idx = numpy.nonzero((x > 0) & (y > 0))[0] - x = numpy.take(x, idx) - y = numpy.take(y, idx) - elif yLog: - idx = numpy.nonzero(y > 0)[0] - x = numpy.take(x, idx) - y = numpy.take(y, idx) - elif xLog: - idx = numpy.nonzero(x > 0)[0] - x = numpy.take(x, idx) - y = numpy.take(y, idx) - else: - idx = None - - if idx is not None and isinstance(color, numpy.ndarray): - colors = numpy.zeros((x.size, 4), color.dtype) - colors[:, 0] = color[idx, 0] - colors[:, 1] = color[idx, 1] - colors[:, 2] = color[idx, 2] - colors[:, 3] = color[idx, 3] - else: - colors = color - return x, y, colors - - def prepare(self, isXLog, isYLog): - # init only supports updating isXLog, isYLog - xData, yData, colorData = self.xData, self.yData, self.colorData - - if self._isXLog != isXLog or self._isYLog != isYLog: - # Log state has changed - self._isXLog, self._isYLog = isXLog, isYLog - - # Check if data <= 0. with log scale - if (isXLog and self.xMin <= 0.) or (isYLog and self.yMin <= 0.): - # Filtering data is needed - xData, yData, colorData = self._logFilterData( - self.xData, self.yData, self.colorData, - self._isXLog, self._isYLog) - - self.discard() # discard existing VBOs - + def prepare(self): + """Rendering preparation: build indices and bounding box vertices""" if self.xVboData is None: xAttrib, yAttrib, cAttrib, dAttrib = None, None, None, None if self.lineStyle in (DASHED, DASHDOT, DOTTED): - dists = _distancesFromArrays(xData, yData) + dists = _distancesFromArrays(self.xData, self.yData) if self.colorData is None: xAttrib, yAttrib, dAttrib = vertexBuffer( - (xData, yData, dists), - prefix=(1, 1, 0), suffix=(1, 1, 0)) + (self.xData, self.yData, dists)) else: xAttrib, yAttrib, cAttrib, dAttrib = vertexBuffer( - (xData, yData, colorData, dists), - prefix=(1, 1, 0, 0), suffix=(1, 1, 0, 0)) + (self.xData, self.yData, self.colorData, dists)) elif self.colorData is None: - xAttrib, yAttrib = vertexBuffer( - (xData, yData), prefix=(1, 1), suffix=(1, 1)) + xAttrib, yAttrib = vertexBuffer((self.xData, self.yData)) else: xAttrib, yAttrib, cAttrib = vertexBuffer( - (xData, yData, colorData), prefix=(1, 1, 0)) - - # Shrink VBO - self.xVboData = xAttrib.copy() - self.xVboData.size -= 2 - self.xVboData.offset += xAttrib.itemsize + (self.xData, self.yData, self.colorData)) - self.yVboData = yAttrib.copy() - self.yVboData.size -= 2 - self.yVboData.offset += yAttrib.itemsize + self.xVboData = xAttrib + self.yVboData = yAttrib + self.distVboData = dAttrib - if cAttrib is not None and colorData.dtype.kind == 'u': + if cAttrib is not None and self.colorData.dtype.kind == 'u': cAttrib.normalization = True # Normalize uint to [0, 1] self.colorVboData = cAttrib self.useColorVboData = cAttrib is not None - self.distVboData = dAttrib - - if self.fill is not None: - xData = xData.reshape(xData.size, 1) - zero = numpy.array((1e-32,), dtype=self.yData.dtype) - - # Add one point before data: (x0, 0.) - xAttrib.vbo.update(xData[0], xAttrib.offset, - xData[0].itemsize) - yAttrib.vbo.update(zero, yAttrib.offset, zero.itemsize) - - # Add one point after data: (xN, 0.) - xAttrib.vbo.update(xData[-1], - xAttrib.offset + - (xAttrib.size - 1) * xAttrib.itemsize, - xData[-1].itemsize) - yAttrib.vbo.update(zero, - yAttrib.offset + - (yAttrib.size - 1) * yAttrib.itemsize, - zero.itemsize) - - self.fill.xFillVboData = xAttrib - self.fill.yFillVboData = yAttrib - self.fill.xMin, self.fill.yMin = self.xMin, self.yMin - self.fill.xMax, self.fill.yMax = self.xMax, self.yMax - - self._errorBars.prepare(isXLog, isYLog) def render(self, matrix, isXLog, isYLog): - self.prepare(isXLog, isYLog) + """Perform rendering + + :param numpy.ndarray matrix: 4x4 transform matrix to use + :param bool isXLog: + :param bool isYLog: + """ + self.prepare() if self.fill is not None: - self.fill.render(matrix, isXLog, isYLog) - self._errorBars.render(matrix, isXLog, isYLog) - self.lines.render(matrix, isXLog, isYLog) - self.points.render(matrix, isXLog, isYLog) + self.fill.render(matrix) + self._errorBars.render(matrix) + self.lines.render(matrix) + self.points.render(matrix) def discard(self): + """Release VBOs""" if self.xVboData is not None: self.xVboData.vbo.discard() @@ -1234,6 +1054,8 @@ class GLPlotCurve2D(object): self.distVboData = None self._errorBars.discard() + if self.fill is not None: + self.fill.discard() def pick(self, xPickMin, yPickMin, xPickMax, yPickMax): """Perform picking on the curve according to its rendering. @@ -1251,19 +1073,29 @@ class GLPlotCurve2D(object): if (self.marker is None and self.lineStyle is None) or \ self.xMin > xPickMax or xPickMin > self.xMax or \ self.yMin > yPickMax or yPickMin > self.yMax: - # Note: With log scale the bounding box is too large if - # some data <= 0. return None - elif self.lineStyle is not None: + # offset picking bounds + xPickMin = xPickMin - self.offset[0] + xPickMax = xPickMax - self.offset[0] + yPickMin = yPickMin - self.offset[1] + yPickMax = yPickMax - self.offset[1] + + if self.lineStyle is not None: # Using Cohen-Sutherland algorithm for line clipping - codes = ((self.yData > yPickMax) << 3) | \ + with warnings.catch_warnings(): # Ignore NaN comparison warnings + warnings.simplefilter('ignore', category=RuntimeWarning) + codes = ((self.yData > yPickMax) << 3) | \ ((self.yData < yPickMin) << 2) | \ ((self.xData > xPickMax) << 1) | \ (self.xData < xPickMin) + notNaN = numpy.logical_not(numpy.logical_or( + numpy.isnan(self.xData), numpy.isnan(self.yData))) + # Add all points that are inside the picking area - indices = numpy.nonzero(codes == 0)[0].tolist() + indices = numpy.nonzero( + numpy.logical_and(codes == 0, notNaN))[0].tolist() # Segment that might cross the area with no end point inside it segToTestIdx = numpy.nonzero((codes[:-1] != 0) & @@ -1309,9 +1141,11 @@ class GLPlotCurve2D(object): indices.sort() else: - indices = numpy.nonzero((self.xData >= xPickMin) & - (self.xData <= xPickMax) & - (self.yData >= yPickMin) & - (self.yData <= yPickMax))[0].tolist() + with warnings.catch_warnings(): # Ignore NaN comparison warnings + warnings.simplefilter('ignore', category=RuntimeWarning) + indices = numpy.nonzero((self.xData >= xPickMin) & + (self.xData <= xPickMax) & + (self.yData >= yPickMin) & + (self.yData <= yPickMax))[0].tolist() return indices diff --git a/silx/gui/plot/backends/glutils/GLPlotFrame.py b/silx/gui/plot/backends/glutils/GLPlotFrame.py index eb101c4..4ad1547 100644 --- a/silx/gui/plot/backends/glutils/GLPlotFrame.py +++ b/silx/gui/plot/backends/glutils/GLPlotFrame.py @@ -1,7 +1,7 @@ # coding: utf-8 # /*########################################################################## # -# Copyright (c) 2014-2017 European Synchrotron Radiation Facility +# Copyright (c) 2014-2018 European Synchrotron Radiation Facility # # Permission is hereby granted, free of charge, to any person obtaining a copy # of this software and associated documentation files (the "Software"), to deal @@ -35,6 +35,7 @@ __date__ = "03/04/2017" # keep aspect ratio managed here? # smarter dirty flag handling? +import datetime as dt import math import weakref import logging @@ -47,7 +48,8 @@ from ..._utils import FLOAT32_SAFE_MIN, FLOAT32_MINPOS, FLOAT32_SAFE_MAX from .GLSupport import mat4Ortho from .GLText import Text2D, CENTER, BOTTOM, TOP, LEFT, RIGHT, ROTATE_270 from ..._utils.ticklayout import niceNumbersAdaptative, niceNumbersForLog10 - +from ..._utils.dtime_ticklayout import calcTicksAdaptive, bestFormatString +from ..._utils.dtime_ticklayout import timestamp _logger = logging.getLogger(__name__) @@ -68,6 +70,8 @@ class PlotAxis(object): self._plot = weakref.ref(plot) + self._isDateTime = False + self._timeZone = None self._isLog = False self._dataRange = 1., 100. self._displayCoords = (0., 0.), (1., 0.) @@ -110,6 +114,29 @@ class PlotAxis(object): self._dirtyTicks() @property + def timeZone(self): + """Returnss datetime.tzinfo that is used if this axis plots date times.""" + return self._timeZone + + @timeZone.setter + def timeZone(self, tz): + """Sets dateetime.tzinfo that is used if this axis plots date times.""" + self._timeZone = tz + self._dirtyTicks() + + @property + def isTimeSeries(self): + """Whether the axis is showing floats as datetime objects""" + return self._isDateTime + + @isTimeSeries.setter + def isTimeSeries(self, isTimeSeries): + isTimeSeries = bool(isTimeSeries) + if isTimeSeries != self._isDateTime: + self._isDateTime = isTimeSeries + self._dirtyTicks() + + @property def displayCoords(self): """The coordinates of the start and end points of the axis in display space (i.e., in pixels) as a tuple of 2 tuples of @@ -235,6 +262,10 @@ class PlotAxis(object): (x0, y0), (x1, y1) = self.displayCoords if self.isLog: + + if self.isTimeSeries: + _logger.warning("Time series not implemented for log-scale") + logMin, logMax = math.log10(dataMin), math.log10(dataMax) tickMin, tickMax, step, _ = niceNumbersForLog10(logMin, logMax) @@ -269,19 +300,41 @@ class PlotAxis(object): # Density of 1.3 label per 92 pixels # i.e., 1.3 label per inch on a 92 dpi screen - tickMin, tickMax, step, nbFrac = niceNumbersAdaptative( - dataMin, dataMax, nbPixels, 1.3 / 92) - - for dataPos in self._frange(tickMin, tickMax, step): - if dataMin <= dataPos <= dataMax: - xPixel = x0 + (dataPos - dataMin) * xScale - yPixel = y0 + (dataPos - dataMin) * yScale - - if nbFrac == 0: - text = '%g' % dataPos - else: - text = ('%.' + str(nbFrac) + 'f') % dataPos - yield ((xPixel, yPixel), dataPos, text) + tickDensity = 1.3 / 92 + + if not self.isTimeSeries: + tickMin, tickMax, step, nbFrac = niceNumbersAdaptative( + dataMin, dataMax, nbPixels, tickDensity) + + for dataPos in self._frange(tickMin, tickMax, step): + if dataMin <= dataPos <= dataMax: + xPixel = x0 + (dataPos - dataMin) * xScale + yPixel = y0 + (dataPos - dataMin) * yScale + + if nbFrac == 0: + text = '%g' % dataPos + else: + text = ('%.' + str(nbFrac) + 'f') % dataPos + yield ((xPixel, yPixel), dataPos, text) + else: + # Time series + dtMin = dt.datetime.fromtimestamp(dataMin, tz=self.timeZone) + dtMax = dt.datetime.fromtimestamp(dataMax, tz=self.timeZone) + + tickDateTimes, spacing, unit = calcTicksAdaptive( + dtMin, dtMax, nbPixels, tickDensity) + + for tickDateTime in tickDateTimes: + if dtMin <= tickDateTime <= dtMax: + + dataPos = timestamp(tickDateTime) + xPixel = x0 + (dataPos - dataMin) * xScale + yPixel = y0 + (dataPos - dataMin) * yScale + + fmtStr = bestFormatString(spacing, unit) + text = tickDateTime.strftime(fmtStr) + + yield ((xPixel, yPixel), dataPos, text) # GLPlotFrame ################################################################# @@ -501,7 +554,8 @@ class GLPlotFrame(object): gl.glLineWidth(self._LINE_WIDTH) - gl.glUniformMatrix4fv(prog.uniforms['matrix'], 1, gl.GL_TRUE, matProj) + gl.glUniformMatrix4fv(prog.uniforms['matrix'], 1, gl.GL_TRUE, + matProj.astype(numpy.float32)) gl.glUniform4f(prog.uniforms['color'], 0., 0., 0., 1.) gl.glUniform1f(prog.uniforms['tickFactor'], 0.) @@ -534,7 +588,8 @@ class GLPlotFrame(object): prog.use() gl.glLineWidth(self._LINE_WIDTH) - gl.glUniformMatrix4fv(prog.uniforms['matrix'], 1, gl.GL_TRUE, matProj) + gl.glUniformMatrix4fv(prog.uniforms['matrix'], 1, gl.GL_TRUE, + matProj.astype(numpy.float32)) gl.glUniform4f(prog.uniforms['color'], 0.7, 0.7, 0.7, 1.) gl.glUniform1f(prog.uniforms['tickFactor'], 0.) # 1/2.) # 1/tickLen @@ -810,11 +865,11 @@ class GLPlotFrame2D(GLPlotFrame): # Non-orthogonal axes if self.baseVectors != self.DEFAULT_BASE_VECTORS: (xx, xy), (yx, yy) = self.baseVectors - mat = mat * numpy.matrix(( + mat = numpy.dot(mat, numpy.array(( (xx, yx, 0., 0.), (xy, yy, 0., 0.), (0., 0., 1., 0.), - (0., 0., 0., 1.)), dtype=numpy.float32) + (0., 0., 0., 1.)), dtype=numpy.float64)) self._transformedDataProjMat = mat @@ -839,11 +894,11 @@ class GLPlotFrame2D(GLPlotFrame): # Non-orthogonal axes if self.baseVectors != self.DEFAULT_BASE_VECTORS: (xx, xy), (yx, yy) = self.baseVectors - mat = mat * numpy.matrix(( + mat = numpy.dot(mat, numpy.matrix(( (xx, yx, 0., 0.), (xy, yy, 0., 0.), (0., 0., 1., 0.), - (0., 0., 0., 1.)), dtype=numpy.float32) + (0., 0., 0., 1.)), dtype=numpy.float64)) self._transformedDataY2ProjMat = mat diff --git a/silx/gui/plot/backends/glutils/GLPlotImage.py b/silx/gui/plot/backends/glutils/GLPlotImage.py index df5b289..6f3c487 100644 --- a/silx/gui/plot/backends/glutils/GLPlotImage.py +++ b/silx/gui/plot/backends/glutils/GLPlotImage.py @@ -1,7 +1,7 @@ # coding: utf-8 # /*########################################################################## # -# Copyright (c) 2014-2017 European Synchrotron Radiation Facility +# Copyright (c) 2014-2018 European Synchrotron Radiation Facility # # Permission is hereby granted, free of charge, to any person obtaining a copy # of this software and associated documentation files (the "Software"), to deal @@ -350,8 +350,11 @@ class GLPlotColormap(_GLPlotData2D): gl.glUniform1i(prog.uniforms['data'], self._DATA_TEX_UNIT) - mat = matrix * mat4Translate(*self.origin) * mat4Scale(*self.scale) - gl.glUniformMatrix4fv(prog.uniforms['matrix'], 1, gl.GL_TRUE, mat) + mat = numpy.dot(numpy.dot(matrix, + mat4Translate(*self.origin)), + mat4Scale(*self.scale)) + gl.glUniformMatrix4fv(prog.uniforms['matrix'], 1, gl.GL_TRUE, + mat.astype(numpy.float32)) gl.glUniform1f(prog.uniforms['alpha'], self.alpha) @@ -377,9 +380,11 @@ class GLPlotColormap(_GLPlotData2D): gl.glUniform1i(prog.uniforms['data'], self._DATA_TEX_UNIT) - gl.glUniformMatrix4fv(prog.uniforms['matrix'], 1, gl.GL_TRUE, matrix) - mat = mat4Translate(ox, oy) * mat4Scale(*self.scale) - gl.glUniformMatrix4fv(prog.uniforms['matOffset'], 1, gl.GL_TRUE, mat) + gl.glUniformMatrix4fv(prog.uniforms['matrix'], 1, gl.GL_TRUE, + matrix.astype(numpy.float32)) + mat = numpy.dot(mat4Translate(ox, oy), mat4Scale(*self.scale)) + gl.glUniformMatrix4fv(prog.uniforms['matOffset'], 1, gl.GL_TRUE, + mat.astype(numpy.float32)) gl.glUniform2i(prog.uniforms['isLog'], isXLog, isYLog) @@ -598,8 +603,10 @@ class GLPlotRGBAImage(_GLPlotData2D): gl.glUniform1i(prog.uniforms['tex'], self._DATA_TEX_UNIT) - mat = matrix * mat4Translate(*self.origin) * mat4Scale(*self.scale) - gl.glUniformMatrix4fv(prog.uniforms['matrix'], 1, gl.GL_TRUE, mat) + mat = numpy.dot(numpy.dot(matrix, mat4Translate(*self.origin)), + mat4Scale(*self.scale)) + gl.glUniformMatrix4fv(prog.uniforms['matrix'], 1, gl.GL_TRUE, + mat.astype(numpy.float32)) gl.glUniform1f(prog.uniforms['alpha'], self.alpha) @@ -617,9 +624,11 @@ class GLPlotRGBAImage(_GLPlotData2D): gl.glUniform1i(prog.uniforms['tex'], self._DATA_TEX_UNIT) - gl.glUniformMatrix4fv(prog.uniforms['matrix'], 1, gl.GL_TRUE, matrix) - mat = mat4Translate(ox, oy) * mat4Scale(*self.scale) - gl.glUniformMatrix4fv(prog.uniforms['matOffset'], 1, gl.GL_TRUE, mat) + gl.glUniformMatrix4fv(prog.uniforms['matrix'], 1, gl.GL_TRUE, + matrix.astype(numpy.float32)) + mat = numpy.dot(mat4Translate(ox, oy), mat4Scale(*self.scale)) + gl.glUniformMatrix4fv(prog.uniforms['matOffset'], 1, gl.GL_TRUE, + mat.astype(numpy.float32)) gl.glUniform2i(prog.uniforms['isLog'], isXLog, isYLog) diff --git a/silx/gui/plot/backends/glutils/GLSupport.py b/silx/gui/plot/backends/glutils/GLSupport.py index 3f473be..18c5eb7 100644 --- a/silx/gui/plot/backends/glutils/GLSupport.py +++ b/silx/gui/plot/backends/glutils/GLSupport.py @@ -1,7 +1,7 @@ # coding: utf-8 # /*########################################################################## # -# Copyright (c) 2014-2017 European Synchrotron Radiation Facility +# Copyright (c) 2014-2018 European Synchrotron Radiation Facility # # Permission is hereby granted, free of charge, to any person obtaining a copy # of this software and associated documentation files (the "Software"), to deal @@ -36,11 +36,20 @@ import numpy from ...._glutils import gl -def buildFillMaskIndices(nIndices): - if nIndices <= numpy.iinfo(numpy.uint16).max + 1: - dtype = numpy.uint16 - else: - dtype = numpy.uint32 +def buildFillMaskIndices(nIndices, dtype=None): + """Returns triangle strip indices for rendering a filled polygon mask + + :param int nIndices: Number of points + :param Union[numpy.dtype,None] dtype: + If specified the dtype of the returned indices array + :return: 1D array of indices constructing a triangle strip + :rtype: numpy.ndarray + """ + if dtype is None: + if nIndices <= numpy.iinfo(numpy.uint16).max + 1: + dtype = numpy.uint16 + else: + dtype = numpy.uint32 lastIndex = nIndices - 1 splitIndex = lastIndex // 2 + 1 @@ -158,35 +167,35 @@ class Shape2D(object): def mat4Ortho(left, right, bottom, top, near, far): """Orthographic projection matrix (row-major)""" - return numpy.matrix(( + return numpy.array(( (2./(right - left), 0., 0., -(right+left)/float(right-left)), (0., 2./(top - bottom), 0., -(top+bottom)/float(top-bottom)), (0., 0., -2./(far-near), -(far+near)/float(far-near)), - (0., 0., 0., 1.)), dtype=numpy.float32) + (0., 0., 0., 1.)), dtype=numpy.float64) def mat4Translate(x=0., y=0., z=0.): """Translation matrix (row-major)""" - return numpy.matrix(( + return numpy.array(( (1., 0., 0., x), (0., 1., 0., y), (0., 0., 1., z), - (0., 0., 0., 1.)), dtype=numpy.float32) + (0., 0., 0., 1.)), dtype=numpy.float64) def mat4Scale(sx=1., sy=1., sz=1.): """Scale matrix (row-major)""" - return numpy.matrix(( + return numpy.array(( (sx, 0., 0., 0.), (0., sy, 0., 0.), (0., 0., sz, 0.), - (0., 0., 0., 1.)), dtype=numpy.float32) + (0., 0., 0., 1.)), dtype=numpy.float64) def mat4Identity(): """Identity matrix""" - return numpy.matrix(( + return numpy.array(( (1., 0., 0., 0.), (0., 1., 0., 0.), (0., 0., 1., 0.), - (0., 0., 0., 1.)), dtype=numpy.float32) + (0., 0., 0., 1.)), dtype=numpy.float64) diff --git a/silx/gui/plot/backends/glutils/GLText.py b/silx/gui/plot/backends/glutils/GLText.py index cef0c5a..1540e26 100644 --- a/silx/gui/plot/backends/glutils/GLText.py +++ b/silx/gui/plot/backends/glutils/GLText.py @@ -1,7 +1,7 @@ # coding: utf-8 # /*########################################################################## # -# Copyright (c) 2014-2017 European Synchrotron Radiation Facility +# Copyright (c) 2014-2018 European Synchrotron Radiation Facility # # Permission is hereby granted, free of charge, to any person obtaining a copy # of this software and associated documentation files (the "Software"), to deal @@ -195,8 +195,9 @@ class Text2D(object): gl.glUniform1i(prog.uniforms['texText'], texUnit) + mat = numpy.dot(matrix, mat4Translate(int(self.x), int(self.y))) gl.glUniformMatrix4fv(prog.uniforms['matrix'], 1, gl.GL_TRUE, - matrix * mat4Translate(int(self.x), int(self.y))) + mat.astype(numpy.float32)) gl.glUniform4f(prog.uniforms['color'], *self.color) if self.bgColor is not None: diff --git a/silx/gui/plot/items/axis.py b/silx/gui/plot/items/axis.py index d7e6eff..3d9fe14 100644 --- a/silx/gui/plot/items/axis.py +++ b/silx/gui/plot/items/axis.py @@ -1,7 +1,7 @@ # coding: utf-8 # /*########################################################################## # -# Copyright (c) 2017 European Synchrotron Radiation Facility +# Copyright (c) 2017-2018 European Synchrotron Radiation Facility # # Permission is hereby granted, free of charge, to any person obtaining a copy # of this software and associated documentation files (the "Software"), to deal @@ -29,12 +29,24 @@ __authors__ = ["V. Valls"] __license__ = "MIT" __date__ = "06/12/2017" +import datetime as dt import logging + +import dateutil.tz + from ... import qt +from silx.third_party import enum + _logger = logging.getLogger(__name__) +class TickMode(enum.Enum): + """Determines if ticks are regular number or datetimes.""" + DEFAULT = 0 # Ticks are regular numbers + TIME_SERIES = 1 # Ticks are datetime objects + + class Axis(qt.QObject): """This class describes and controls a plot axis. @@ -82,7 +94,23 @@ class Axis(qt.QObject): # Store currently displayed labels # Current label can differ from input one with active curve handling self._currentLabel = '' - self._plot = plot + + def _getPlot(self): + """Returns the PlotWidget this Axis belongs to. + + :rtype: PlotWidget + """ + plot = self.parent() + if plot is None: + raise RuntimeError("Axis no longer attached to a PlotWidget") + return plot + + def _getBackend(self): + """Returns the backend + + :rtype: BackendBase + """ + return self._getPlot()._backend def getLimits(self): """Get the limits of this axis. @@ -102,7 +130,7 @@ class Axis(qt.QObject): return self._internalSetLimits(vmin, vmax) - self._plot._setDirtyPlot() + self._getPlot()._setDirtyPlot() self._emitLimitsChanged() @@ -110,7 +138,7 @@ class Axis(qt.QObject): """Emit axis sigLimitsChanged and PlotWidget limitsChanged event""" vmin, vmax = self.getLimits() self.sigLimitsChanged.emit(vmin, vmax) - self._plot._notifyLimitsChanged(emitSignal=False) + self._getPlot()._notifyLimitsChanged(emitSignal=False) def _checkLimits(self, vmin, vmax): """Makes sure axis range is not empty @@ -172,7 +200,7 @@ class Axis(qt.QObject): """ self._defaultLabel = label self._setCurrentLabel(label) - self._plot._setDirtyPlot() + self._getPlot()._setDirtyPlot() def _setCurrentLabel(self, label): """Define the label currently displayed. @@ -207,6 +235,14 @@ class Axis(qt.QObject): # For the backward compatibility signal emitLog = self._scale == self.LOGARITHMIC or scale == self.LOGARITHMIC + self._scale = scale + + # TODO hackish way of forcing update of curves and images + plot = self._getPlot() + for item in plot._getItems(withhidden=True): + item._updated() + plot._invalidateDataRange() + if scale == self.LOGARITHMIC: self._internalSetLogarithmic(True) elif scale == self.LINEAR: @@ -214,13 +250,7 @@ class Axis(qt.QObject): else: raise ValueError("Scale %s unsupported" % scale) - self._scale = scale - - # TODO hackish way of forcing update of curves and images - for item in self._plot._getItems(withhidden=True): - item._updated() - self._plot._invalidateDataRange() - self._plot._forceResetZoom() + plot._forceResetZoom() self.sigScaleChanged.emit(self._scale) if emitLog: @@ -241,6 +271,40 @@ class Axis(qt.QObject): flag = bool(flag) self.setScale(self.LOGARITHMIC if flag else self.LINEAR) + def getTimeZone(self): + """Sets tzinfo that is used if this axis plots date times. + + None means the datetimes are interpreted as local time. + + :rtype: datetime.tzinfo of None. + """ + raise NotImplementedError() + + def setTimeZone(self, tz): + """Sets tzinfo that is used if this axis' tickMode is TIME_SERIES + + The tz must be a descendant of the datetime.tzinfo class, "UTC" or None. + Use None to let the datetimes be interpreted as local time. + Use the string "UTC" to let the date datetimes be in UTC time. + + :param tz: datetime.tzinfo, "UTC" or None. + """ + raise NotImplementedError() + + def getTickMode(self): + """Determines if axis ticks are number or datetimes. + + :rtype: TickMode enum. + """ + raise NotImplementedError() + + def setTickMode(self, tickMode): + """Determines if axis ticks are number or datetimes. + + :param TickMode tickMode: tick mode enum. + """ + raise NotImplementedError() + def isAutoScale(self): """Return True if axis is automatically adjusting its limits. @@ -271,7 +335,7 @@ class Axis(qt.QObject): """ updated = self._setLimitsConstraints(minPos, maxPos) if updated: - plot = self._plot + plot = self._getPlot() xMin, xMax = plot.getXAxis().getLimits() yMin, yMax = plot.getYAxis().getLimits() y2Min, y2Max = plot.getYAxis('right').getLimits() @@ -294,7 +358,7 @@ class Axis(qt.QObject): """ updated = self._setRangeConstraints(minRange, maxRange) if updated: - plot = self._plot + plot = self._getPlot() xMin, xMax = plot.getXAxis().getLimits() yMin, yMax = plot.getYAxis().getLimits() y2Min, y2Max = plot.getYAxis('right').getLimits() @@ -308,25 +372,51 @@ class XAxis(Axis): # TODO With some changes on the backend, it will be able to remove all this # specialised implementations (prefixel by '_internal') + def getTimeZone(self): + return self._getBackend().getXAxisTimeZone() + + def setTimeZone(self, tz): + if isinstance(tz, str) and tz.upper() == "UTC": + tz = dateutil.tz.tzutc() + elif not(tz is None or isinstance(tz, dt.tzinfo)): + raise TypeError("tz must be a dt.tzinfo object, None or 'UTC'.") + + self._getBackend().setXAxisTimeZone(tz) + self._getPlot()._setDirtyPlot() + + def getTickMode(self): + if self._getBackend().isXAxisTimeSeries(): + return TickMode.TIME_SERIES + else: + return TickMode.DEFAULT + + def setTickMode(self, tickMode): + if tickMode == TickMode.DEFAULT: + self._getBackend().setXAxisTimeSeries(False) + elif tickMode == TickMode.TIME_SERIES: + self._getBackend().setXAxisTimeSeries(True) + else: + raise ValueError("Unexpected TickMode: {}".format(tickMode)) + def _internalSetCurrentLabel(self, label): - self._plot._backend.setGraphXLabel(label) + self._getBackend().setGraphXLabel(label) def _internalGetLimits(self): - return self._plot._backend.getGraphXLimits() + return self._getBackend().getGraphXLimits() def _internalSetLimits(self, xmin, xmax): - self._plot._backend.setGraphXLimits(xmin, xmax) + self._getBackend().setGraphXLimits(xmin, xmax) def _internalSetLogarithmic(self, flag): - self._plot._backend.setXAxisLogarithmic(flag) + self._getBackend().setXAxisLogarithmic(flag) def _setLimitsConstraints(self, minPos=None, maxPos=None): - constrains = self._plot._getViewConstraints() + constrains = self._getPlot()._getViewConstraints() updated = constrains.update(xMin=minPos, xMax=maxPos) return updated def _setRangeConstraints(self, minRange=None, maxRange=None): - constrains = self._plot._getViewConstraints() + constrains = self._getPlot()._getViewConstraints() updated = constrains.update(minXRange=minRange, maxXRange=maxRange) return updated @@ -338,16 +428,16 @@ class YAxis(Axis): # specialised implementations (prefixel by '_internal') def _internalSetCurrentLabel(self, label): - self._plot._backend.setGraphYLabel(label, axis='left') + self._getBackend().setGraphYLabel(label, axis='left') def _internalGetLimits(self): - return self._plot._backend.getGraphYLimits(axis='left') + return self._getBackend().getGraphYLimits(axis='left') def _internalSetLimits(self, ymin, ymax): - self._plot._backend.setGraphYLimits(ymin, ymax, axis='left') + self._getBackend().setGraphYLimits(ymin, ymax, axis='left') def _internalSetLogarithmic(self, flag): - self._plot._backend.setYAxisLogarithmic(flag) + self._getBackend().setYAxisLogarithmic(flag) def setInverted(self, flag=True): """Set the axis orientation. @@ -358,8 +448,8 @@ class YAxis(Axis): False for Y axis going from bottom to top """ flag = bool(flag) - self._plot._backend.setYAxisInverted(flag) - self._plot._setDirtyPlot() + self._getBackend().setYAxisInverted(flag) + self._getPlot()._setDirtyPlot() self.sigInvertedChanged.emit(flag) def isInverted(self): @@ -368,15 +458,15 @@ class YAxis(Axis): :rtype: bool """ - return self._plot._backend.isYAxisInverted() + return self._getBackend().isYAxisInverted() def _setLimitsConstraints(self, minPos=None, maxPos=None): - constrains = self._plot._getViewConstraints() + constrains = self._getPlot()._getViewConstraints() updated = constrains.update(yMin=minPos, yMax=maxPos) return updated def _setRangeConstraints(self, minRange=None, maxRange=None): - constrains = self._plot._getViewConstraints() + constrains = self._getPlot()._getViewConstraints() updated = constrains.update(minYRange=minRange, maxYRange=maxRange) return updated @@ -419,13 +509,13 @@ class YRightAxis(Axis): return self.__mainAxis.sigAutoScaleChanged def _internalSetCurrentLabel(self, label): - self._plot._backend.setGraphYLabel(label, axis='right') + self._getBackend().setGraphYLabel(label, axis='right') def _internalGetLimits(self): - return self._plot._backend.getGraphYLimits(axis='right') + return self._getBackend().getGraphYLimits(axis='right') def _internalSetLimits(self, ymin, ymax): - self._plot._backend.setGraphYLimits(ymin, ymax, axis='right') + self._getBackend().setGraphYLimits(ymin, ymax, axis='right') def setInverted(self, flag=True): """Set the Y axis orientation. diff --git a/silx/gui/plot/items/complex.py b/silx/gui/plot/items/complex.py index ba57e85..535b0a9 100644 --- a/silx/gui/plot/items/complex.py +++ b/silx/gui/plot/items/complex.py @@ -29,7 +29,7 @@ from __future__ import absolute_import __authors__ = ["Vincent Favre-Nicolin", "T. Vincent"] __license__ = "MIT" -__date__ = "19/01/2018" +__date__ = "14/06/2018" import logging @@ -37,7 +37,7 @@ import numpy from silx.third_party import enum -from ..Colormap import Colormap +from ...colors import Colormap from .core import ColormapMixIn, ItemChangedType from .image import ImageBase @@ -229,7 +229,7 @@ class ImageComplexData(ImageBase, ColormapMixIn): def setColormap(self, colormap, mode=None): """Set the colormap for this specific mode. - :param ~silx.gui.plot.Colormap.Colormap colormap: The colormap + :param ~silx.gui.colors.Colormap colormap: The colormap :param Mode mode: If specified, set the colormap of this specific mode. Default: current mode. @@ -249,7 +249,7 @@ class ImageComplexData(ImageBase, ColormapMixIn): :param Mode mode: If specified, get the colormap of this specific mode. Default: current mode. - :rtype: ~silx.gui.plot.Colormap.Colormap + :rtype: ~silx.gui.colors.Colormap """ if mode is None: mode = self.getVisualizationMode() diff --git a/silx/gui/plot/items/core.py b/silx/gui/plot/items/core.py index bcb6dd1..4ed0914 100644 --- a/silx/gui/plot/items/core.py +++ b/silx/gui/plot/items/core.py @@ -27,18 +27,19 @@ __authors__ = ["T. Vincent"] __license__ = "MIT" -__date__ = "27/06/2017" +__date__ = "14/06/2018" import collections from copy import deepcopy import logging +import warnings import weakref import numpy from silx.third_party import six, enum from ... import qt -from .. import Colors -from ..Colormap import Colormap +from ... import colors +from ...colors import Colormap _logger = logging.getLogger(__name__) @@ -409,7 +410,7 @@ class ColormapMixIn(ItemMixInBase): def setColormap(self, colormap): """Set the colormap of this image - :param silx.gui.plot.Colormap.Colormap colormap: colormap description + :param silx.gui.colors.Colormap colormap: colormap description """ if isinstance(colormap, dict): colormap = Colormap._fromDict(colormap) @@ -619,17 +620,17 @@ class ColorMixIn(ItemMixInBase): :param color: color(s) to be used :type color: str ("#RRGGBB") or (npoints, 4) unsigned byte array or - one of the predefined color names defined in Colors.py + one of the predefined color names defined in colors.py :param bool copy: True (Default) to get a copy, False to use internal representation (do not modify!) """ if isinstance(color, six.string_types): - color = Colors.rgba(color) + color = colors.rgba(color) else: color = numpy.array(color, copy=copy) # TODO more checks + improve color array support if color.ndim == 1: # Single RGBA color - color = Colors.rgba(color) + color = colors.rgba(color) else: # Array of colors assert color.ndim == 2 @@ -767,7 +768,10 @@ class Points(Item, SymbolMixIn, AlphaMixIn): error = numpy.ravel(error) # Supports error being scalar, N or 2xN array - errorClipped = (value - numpy.atleast_2d(error)[0]) <= 0 + valueMinusError = value - numpy.atleast_2d(error)[0] + errorClipped = numpy.isnan(valueMinusError) + mask = numpy.logical_not(errorClipped) + errorClipped[mask] = valueMinusError[mask] <= 0 if numpy.any(errorClipped): # Need filtering @@ -805,10 +809,20 @@ class Points(Item, SymbolMixIn, AlphaMixIn): """ assert xPositive or yPositive if (xPositive, yPositive) not in self._clippedCache: - x = self.getXData(copy=False) - y = self.getYData(copy=False) - xclipped = (x <= 0) if xPositive else False - yclipped = (y <= 0) if yPositive else False + xclipped, yclipped = False, False + + if xPositive: + x = self.getXData(copy=False) + with warnings.catch_warnings(): # Ignore NaN warnings + warnings.simplefilter('ignore', category=RuntimeWarning) + xclipped = x <= 0 + + if yPositive: + y = self.getYData(copy=False) + with warnings.catch_warnings(): # Ignore NaN warnings + warnings.simplefilter('ignore', category=RuntimeWarning) + yclipped = y <= 0 + self._clippedCache[(xPositive, yPositive)] = \ numpy.logical_or(xclipped, yclipped) return self._clippedCache[(xPositive, yPositive)] diff --git a/silx/gui/plot/items/curve.py b/silx/gui/plot/items/curve.py index 0ba475d..50ad86d 100644 --- a/silx/gui/plot/items/curve.py +++ b/silx/gui/plot/items/curve.py @@ -27,13 +27,13 @@ __authors__ = ["T. Vincent"] __license__ = "MIT" -__date__ = "06/03/2017" +__date__ = "24/04/2018" import logging import numpy -from .. import Colors +from ... import colors from .core import (Points, LabelsMixIn, ColorMixIn, YAxisMixIn, FillMixIn, LineMixIn, ItemChangedType) @@ -170,9 +170,9 @@ class Curve(Points, ColorMixIn, YAxisMixIn, FillMixIn, LabelsMixIn, LineMixIn): :param color: color(s) to be used for highlight :type color: str ("#RRGGBB") or (npoints, 4) unsigned byte array or - one of the predefined color names defined in Colors.py + one of the predefined color names defined in colors.py """ - color = Colors.rgba(color) + color = colors.rgba(color) if color != self._highlightColor: self._highlightColor = color self._updated(ItemChangedType.HIGHLIGHTED_COLOR) diff --git a/silx/gui/plot/items/histogram.py b/silx/gui/plot/items/histogram.py index ad89677..3545345 100644 --- a/silx/gui/plot/items/histogram.py +++ b/silx/gui/plot/items/histogram.py @@ -29,7 +29,6 @@ __authors__ = ["H. Payno", "T. Vincent"] __license__ = "MIT" __date__ = "27/06/2017" - import logging import numpy @@ -37,7 +36,6 @@ import numpy from .core import (Item, AlphaMixIn, ColorMixIn, FillMixIn, LineMixIn, YAxisMixIn, ItemChangedType) - _logger = logging.getLogger(__name__) @@ -290,5 +288,40 @@ class Histogram(Item, AlphaMixIn, ColorMixIn, FillMixIn, self._histogram = histogram self._edges = edges + self._alignement = align self._updated(ItemChangedType.DATA) + + def getAlignment(self): + """ + + :return: histogram alignement. Value in ('center', 'left', 'right'). + """ + return self._alignement + + def _revertComputeEdges(self, x, histogramType): + """Compute the edges from a set of xs and a rule to generate the edges + + :param x: the x value of the curve to transform into an histogram + :param histogramType: the type of histogram we wan't to generate. + This define the way to center the histogram values compared to the + curve value. Possible values can be:: + + - 'left' + - 'right' + - 'center' + + :return: the edges for the given x and the histogramType + """ + # for now we consider that the spaces between xs are constant + edges = x.copy() + if histogramType is 'left': + return edges[1:] + if histogramType is 'center': + edges = (edges[1:] + edges[:-1]) / 2.0 + if histogramType is 'right': + width = 1 + if len(x) > 1: + width = x[-1] + x[-2] + edges = edges[:-1] + return edges diff --git a/silx/gui/plot/items/roi.py b/silx/gui/plot/items/roi.py new file mode 100644 index 0000000..f55ef91 --- /dev/null +++ b/silx/gui/plot/items/roi.py @@ -0,0 +1,1416 @@ +# coding: utf-8 +# /*########################################################################## +# +# Copyright (c) 2018 European Synchrotron Radiation Facility +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. +# +# ###########################################################################*/ +"""This module provides ROI item for the :class:`~silx.gui.plot.PlotWidget`. +""" + +__authors__ = ["T. Vincent"] +__license__ = "MIT" +__date__ = "28/06/2018" + + +import functools +import itertools +import logging +import collections +import numpy + +from ....utils.weakref import WeakList +from ... import qt +from .. import items +from ...colors import rgba + + +logger = logging.getLogger(__name__) + + +class RegionOfInterest(qt.QObject): + """Object describing a region of interest in a plot. + + :param QObject parent: + The RegionOfInterestManager that created this object + """ + + _kind = None + """Label for this kind of ROI. + + Should be setted by inherited classes to custom the ROI manager widget. + """ + + sigRegionChanged = qt.Signal() + """Signal emitted everytime the shape or position of the ROI changes""" + + def __init__(self, parent=None): + # Avoid circular dependancy + from ..tools import roi as roi_tools + assert parent is None or isinstance(parent, roi_tools.RegionOfInterestManager) + super(RegionOfInterest, self).__init__(parent) + self._color = rgba('red') + self._items = WeakList() + self._editAnchors = WeakList() + self._points = None + self._label = '' + self._labelItem = None + self._editable = False + + def __del__(self): + # Clean-up plot items + self._removePlotItems() + + def setParent(self, parent): + """Set the parent of the RegionOfInterest + + :param Union[None,RegionOfInterestManager] parent: + """ + # Avoid circular dependancy + from ..tools import roi as roi_tools + if (parent is not None and not isinstance(parent, roi_tools.RegionOfInterestManager)): + raise ValueError('Unsupported parent') + + self._removePlotItems() + super(RegionOfInterest, self).setParent(parent) + self._createPlotItems() + + @classmethod + def _getKind(cls): + """Return an human readable kind of ROI + + :rtype: str + """ + return cls._kind + + def getColor(self): + """Returns the color of this ROI + + :rtype: QColor + """ + return qt.QColor.fromRgbF(*self._color) + + def _getAnchorColor(self, color): + """Returns the anchor color from the base ROI color + + :param Union[numpy.array,Tuple,List]: color + :rtype: Union[numpy.array,Tuple,List] + """ + return color[:3] + (0.5,) + + def setColor(self, color): + """Set the color used for this ROI. + + :param color: The color to use for ROI shape as + either a color name, a QColor, a list of uint8 or float in [0, 1]. + """ + color = rgba(color) + if color != self._color: + self._color = color + + # Update color of shape items in the plot + rgbaColor = rgba(color) + for item in list(self._items): + if isinstance(item, items.ColorMixIn): + item.setColor(rgbaColor) + item = self._getLabelItem() + if isinstance(item, items.ColorMixIn): + item.setColor(rgbaColor) + + rgbaColor = self._getAnchorColor(rgbaColor) + for item in list(self._editAnchors): + if isinstance(item, items.ColorMixIn): + item.setColor(rgbaColor) + + def getLabel(self): + """Returns the label displayed for this ROI. + + :rtype: str + """ + return self._label + + def setLabel(self, label): + """Set the label displayed with this ROI. + + :param str label: The text label to display + """ + label = str(label) + if label != self._label: + self._label = label + self._updateLabelItem(label) + + def isEditable(self): + """Returns whether the ROI is editable by the user or not. + + :rtype: bool + """ + return self._editable + + def setEditable(self, editable): + """Set whether the ROI can be changed interactively. + + :param bool editable: True to allow edition by the user, + False to disable. + """ + editable = bool(editable) + if self._editable != editable: + self._editable = editable + # Recreate plot items + # This can be avoided once marker.setDraggable is public + self._createPlotItems() + + def _getControlPoints(self): + """Returns the current ROI control points. + + It returns an empty tuple if there is currently no ROI. + + :return: Array of (x, y) position in plot coordinates + :rtype: numpy.ndarray + """ + return None if self._points is None else numpy.array(self._points) + + @classmethod + def showFirstInteractionShape(cls): + """Returns True if the shape created by the first interaction and + managed by the plot have to be visible. + + :rtype: bool + """ + return True + + @classmethod + def getFirstInteractionShape(cls): + """Returns the shape kind which will be used by the very first + interaction with the plot. + + This interactions are hardcoded inside the plot + + :rtype: str + """ + return cls._plotShape + + def setFirstShapePoints(self, points): + """"Initialize the ROI using the points from the first interaction. + + This interaction is constains by the plot API and only supports few + shapes. + """ + points = self._createControlPointsFromFirstShape(points) + self._setControlPoints(points) + + def _createControlPointsFromFirstShape(self, points): + """Returns the list of control points from the very first shape + provided. + + This shape is provided by the plot interaction and constained by the + class of the ROI itself. + """ + return points + + def _setControlPoints(self, points): + """Set this ROI control points. + + :param points: Iterable of (x, y) control points + """ + points = numpy.array(points) + + nbPointsChanged = (self._points is None or + points.shape != self._points.shape) + + if nbPointsChanged or not numpy.all(numpy.equal(points, self._points)): + self._points = points + + self._updateShape() + if self._items and not nbPointsChanged: # Update plot items + item = self._getLabelItem() + if item is not None: + markerPos = self._getLabelPosition() + item.setPosition(*markerPos) + + if self._editAnchors: # Update anchors + for anchor, point in zip(self._editAnchors, points): + old = anchor.blockSignals(True) + anchor.setPosition(*point) + anchor.blockSignals(old) + + else: # No items or new point added + # re-create plot items + self._createPlotItems() + + self.sigRegionChanged.emit() + + def _updateShape(self): + """Called when shape must be updated. + + Must be reimplemented if a shape item have to be updated. + """ + return + + def _getLabelPosition(self): + """Compute position of the label + + :return: (x, y) position of the marker + """ + return None + + def _createPlotItems(self): + """Create items displaying the ROI in the plot. + + It first removes any existing plot items. + """ + roiManager = self.parent() + if roiManager is None: + return + plot = roiManager.parent() + + self._removePlotItems() + + legendPrefix = "__RegionOfInterest-%d__" % id(self) + itemIndex = 0 + + controlPoints = self._getControlPoints() + + if self._labelItem is None: + self._labelItem = self._createLabelItem() + if self._labelItem is not None: + self._labelItem._setLegend(legendPrefix + "label") + plot._add(self._labelItem) + + self._items = WeakList() + plotItems = self._createShapeItems(controlPoints) + for item in plotItems: + item._setLegend(legendPrefix + str(itemIndex)) + plot._add(item) + self._items.append(item) + itemIndex += 1 + + self._editAnchors = WeakList() + if self.isEditable(): + plotItems = self._createAnchorItems(controlPoints) + color = rgba(self.getColor()) + color = self._getAnchorColor(color) + for index, item in enumerate(plotItems): + item._setLegend(legendPrefix + str(itemIndex)) + item.setColor(color) + plot._add(item) + item.sigItemChanged.connect(functools.partial( + self._controlPointAnchorChanged, index)) + self._editAnchors.append(item) + itemIndex += 1 + + def _updateLabelItem(self, label): + """Update the marker displaying the label. + + Inherite this method to custom the way the ROI display the label. + + :param str label: The new label to use + """ + item = self._getLabelItem() + if item is not None: + item.setText(label) + + def _createLabelItem(self): + """Returns a created marker which will be used to dipslay the label of + this ROI. + + Inherite this method to return nothing if no new items have to be + created, or your own marker. + + :rtype: Union[None,Marker] + """ + # Add label marker + markerPos = self._getLabelPosition() + marker = items.Marker() + marker.setPosition(*markerPos) + marker.setText(self.getLabel()) + marker.setColor(rgba(self.getColor())) + marker.setSymbol('') + marker._setDraggable(False) + return marker + + def _getLabelItem(self): + """Returns the marker displaying the label of this ROI. + + Inherite this method to choose your own item. In case this item is also + a control point. + """ + return self._labelItem + + def _createShapeItems(self, points): + """Create shape items from the current control points. + + :rtype: List[PlotItem] + """ + return [] + + def _createAnchorItems(self, points): + """Create anchor items from the current control points. + + :rtype: List[Marker] + """ + return [] + + def _controlPointAnchorChanged(self, index, event): + """Handle update of position of an edition anchor + + :param int index: Index of the anchor + :param ItemChangedType event: Event type + """ + if event == items.ItemChangedType.POSITION: + anchor = self._editAnchors[index] + previous = self._points[index].copy() + current = anchor.getPosition() + self._controlPointAnchorPositionChanged(index, current, previous) + + def _controlPointAnchorPositionChanged(self, index, current, previous): + """Called when an anchor is manually edited. + + This function have to be inherited to change the behaviours of the + control points. This function have to call :meth:`_getControlPoints` to + reach the previous state of the control points. Updated the positions + of the changed control points. Then call :meth:`_setControlPoints` to + update the anchors and send signals. + """ + points = self._getControlPoints() + points[index] = current + self._setControlPoints(points) + + def _removePlotItems(self): + """Remove items from their plot.""" + for item in itertools.chain(list(self._items), + list(self._editAnchors)): + + plot = item.getPlot() + if plot is not None: + plot._remove(item) + self._items = WeakList() + self._editAnchors = WeakList() + + if self._labelItem is not None: + item = self._labelItem + plot = item.getPlot() + if plot is not None: + plot._remove(item) + self._labelItem = None + + def __str__(self): + """Returns parameters of the ROI as a string.""" + points = self._getControlPoints() + params = '; '.join('(%f; %f)' % (pt[0], pt[1]) for pt in points) + return "%s(%s)" % (self.__class__.__name__, params) + + +class PointROI(RegionOfInterest): + """A ROI identifying a point in a 2D plot.""" + + _kind = "Point" + """Label for this kind of ROI""" + + _plotShape = "point" + """Plot shape which is used for the first interaction""" + + def getPosition(self): + """Returns the position of this ROI + + :rtype: numpy.ndarray + """ + return self._points[0].copy() + + def setPosition(self, pos): + """Set the position of this ROI + + :param numpy.ndarray pos: 2d-coordinate of this point + """ + controlPoints = numpy.array([pos]) + self._setControlPoints(controlPoints) + + def _createLabelItem(self): + return None + + def _updateLabelItem(self, label): + if self.isEditable(): + item = self._editAnchors[0] + else: + item = self._items[0] + item.setText(label) + + def _createShapeItems(self, points): + if self.isEditable(): + return [] + marker = items.Marker() + marker.setPosition(points[0][0], points[0][1]) + marker.setText(self.getLabel()) + marker.setColor(rgba(self.getColor())) + marker._setDraggable(False) + return [marker] + + def _createAnchorItems(self, points): + marker = items.Marker() + marker.setPosition(points[0][0], points[0][1]) + marker.setText(self.getLabel()) + marker._setDraggable(self.isEditable()) + return [marker] + + def __str__(self): + points = self._getControlPoints() + params = '%f %f' % (points[0, 0], points[0, 1]) + return "%s(%s)" % (self.__class__.__name__, params) + + +class LineROI(RegionOfInterest): + """A ROI identifying a line in a 2D plot. + + This ROI provides 1 anchor for each boundary of the line, plus an center + in the center to translate the full ROI. + """ + + _kind = "Line" + """Label for this kind of ROI""" + + _plotShape = "line" + """Plot shape which is used for the first interaction""" + + def _createControlPointsFromFirstShape(self, points): + center = numpy.mean(points, axis=0) + controlPoints = numpy.array([points[0], points[1], center]) + return controlPoints + + def setEndPoints(self, startPoint, endPoint): + """Set this line location using the endding points + + :param numpy.ndarray startPoint: Staring bounding point of the line + :param numpy.ndarray endPoint: Endding bounding point of the line + """ + assert(startPoint.shape == (2,) and endPoint.shape == (2,)) + shapePoints = numpy.array([startPoint, endPoint]) + controlPoints = self._createControlPointsFromFirstShape(shapePoints) + self._setControlPoints(controlPoints) + + def getEndPoints(self): + """Returns bounding points of this ROI. + + :rtype: Tuple(numpy.ndarray,numpy.ndarray) + """ + startPoint = self._points[0].copy() + endPoint = self._points[1].copy() + return (startPoint, endPoint) + + def _getLabelPosition(self): + points = self._getControlPoints() + return points[-1] + + def _updateShape(self): + if len(self._items) == 0: + return + shape = self._items[0] + points = self._getControlPoints() + points = self._getShapeFromControlPoints(points) + shape.setPoints(points) + + def _getShapeFromControlPoints(self, points): + # Remove the center from the control points + return points[0:2] + + def _createShapeItems(self, points): + shapePoints = self._getShapeFromControlPoints(points) + item = items.Shape("polylines") + item.setPoints(shapePoints) + item.setColor(rgba(self.getColor())) + item.setFill(False) + item.setOverlay(True) + return [item] + + def _createAnchorItems(self, points): + anchors = [] + for point in points[0:-1]: + anchor = items.Marker() + anchor.setPosition(*point) + anchor.setText('') + anchor.setSymbol('s') + anchor._setDraggable(True) + anchors.append(anchor) + + # Add an anchor to the center of the rectangle + center = numpy.mean(points, axis=0) + anchor = items.Marker() + anchor.setPosition(*center) + anchor.setText('') + anchor.setSymbol('+') + anchor._setDraggable(True) + anchors.append(anchor) + + return anchors + + def _controlPointAnchorPositionChanged(self, index, current, previous): + if index == len(self._editAnchors) - 1: + # It is the center anchor + points = self._getControlPoints() + center = numpy.mean(points[0:-1], axis=0) + offset = current - previous + points[-1] = current + points[0:-1] = points[0:-1] + offset + self._setControlPoints(points) + else: + # Update the center + points = self._getControlPoints() + points[index] = current + center = numpy.mean(points[0:-1], axis=0) + points[-1] = center + self._setControlPoints(points) + + def __str__(self): + points = self._getControlPoints() + params = points[0][0], points[0][1], points[1][0], points[1][1] + params = 'start: %f %f; end: %f %f' % params + return "%s(%s)" % (self.__class__.__name__, params) + + +class HorizontalLineROI(RegionOfInterest): + """A ROI identifying an horizontal line in a 2D plot.""" + + _kind = "HLine" + """Label for this kind of ROI""" + + _plotShape = "hline" + """Plot shape which is used for the first interaction""" + + def _createControlPointsFromFirstShape(self, points): + points = numpy.array([(float('nan'), points[0, 1])], + dtype=numpy.float64) + return points + + def getPosition(self): + """Returns the position of this line if the horizontal axis + + :rtype: float + """ + return self._points[0, 1] + + def setPosition(self, pos): + """Set the position of this ROI + + :param float pos: Horizontal position of this line + """ + controlPoints = numpy.array([[float('nan'), pos]]) + self._setControlPoints(controlPoints) + + def _createLabelItem(self): + return None + + def _updateLabelItem(self, label): + if self.isEditable(): + item = self._editAnchors[0] + else: + item = self._items[0] + item.setText(label) + + def _updateShape(self): + if not self.isEditable(): + if len(self._items) > 0: + controlPoints = self._getControlPoints() + item = self._items[0] + item.setPosition(*controlPoints[0]) + + def _createShapeItems(self, points): + if self.isEditable(): + return [] + marker = items.YMarker() + marker.setPosition(points[0][0], points[0][1]) + marker.setText(self.getLabel()) + marker.setColor(rgba(self.getColor())) + marker._setDraggable(False) + return [marker] + + def _createAnchorItems(self, points): + marker = items.YMarker() + marker.setPosition(points[0][0], points[0][1]) + marker.setText(self.getLabel()) + marker._setDraggable(self.isEditable()) + return [marker] + + def __str__(self): + points = self._getControlPoints() + params = 'y: %f' % points[0, 1] + return "%s(%s)" % (self.__class__.__name__, params) + + +class VerticalLineROI(RegionOfInterest): + """A ROI identifying a vertical line in a 2D plot.""" + + _kind = "VLine" + """Label for this kind of ROI""" + + _plotShape = "vline" + """Plot shape which is used for the first interaction""" + + def _createControlPointsFromFirstShape(self, points): + points = numpy.array([(points[0, 0], float('nan'))], + dtype=numpy.float64) + return points + + def getPosition(self): + """Returns the position of this line if the horizontal axis + + :rtype: float + """ + return self._points[0, 0] + + def setPosition(self, pos): + """Set the position of this ROI + + :param float pos: Horizontal position of this line + """ + controlPoints = numpy.array([[pos, float('nan')]]) + self._setControlPoints(controlPoints) + + def _createLabelItem(self): + return None + + def _updateLabelItem(self, label): + if self.isEditable(): + item = self._editAnchors[0] + else: + item = self._items[0] + item.setText(label) + + def _updateShape(self): + if not self.isEditable(): + if len(self._items) > 0: + controlPoints = self._getControlPoints() + item = self._items[0] + item.setPosition(*controlPoints[0]) + + def _createShapeItems(self, points): + if self.isEditable(): + return [] + marker = items.XMarker() + marker.setPosition(points[0][0], points[0][1]) + marker.setText(self.getLabel()) + marker.setColor(rgba(self.getColor())) + marker._setDraggable(False) + return [marker] + + def _createAnchorItems(self, points): + marker = items.XMarker() + marker.setPosition(points[0][0], points[0][1]) + marker.setText(self.getLabel()) + marker._setDraggable(self.isEditable()) + return [marker] + + def __str__(self): + points = self._getControlPoints() + params = 'x: %f' % points[0, 0] + return "%s(%s)" % (self.__class__.__name__, params) + + +class RectangleROI(RegionOfInterest): + """A ROI identifying a rectangle in a 2D plot. + + This ROI provides 1 anchor for each corner, plus an anchor in the + center to translate the full ROI. + """ + + _kind = "Rectangle" + """Label for this kind of ROI""" + + _plotShape = "rectangle" + """Plot shape which is used for the first interaction""" + + def _createControlPointsFromFirstShape(self, points): + point0 = points[0] + point1 = points[1] + + # 4 corners + controlPoints = numpy.array([ + point0[0], point0[1], + point0[0], point1[1], + point1[0], point1[1], + point1[0], point0[1], + ]) + # Central + center = numpy.mean(points, axis=0) + controlPoints = numpy.append(controlPoints, center) + controlPoints.shape = -1, 2 + return controlPoints + + def getCenter(self): + """Returns the central point of this rectangle + + :rtype: numpy.ndarray([float,float]) + """ + return numpy.mean(self._points, axis=0) + + def getOrigin(self): + """Returns the corner point with the smaller coordinates + + :rtype: numpy.ndarray([float,float]) + """ + return numpy.min(self._points, axis=0) + + def getSize(self): + """Returns the size of this rectangle + + :rtype: numpy.ndarray([float,float]) + """ + minPoint = numpy.min(self._points, axis=0) + maxPoint = numpy.max(self._points, axis=0) + return maxPoint - minPoint + + def setOrigin(self, position): + """Set the origin position of this ROI + + :param numpy.ndarray position: Location of the smaller corner of the ROI + """ + size = self.getSize() + self.setGeometry(origin=position, size=size) + + def setSize(self, size): + """Set the size of this ROI + + :param numpy.ndarray size: Size of the center of the ROI + """ + origin = self.getOrigin() + self.setGeometry(origin=origin, size=size) + + def setCenter(self, position): + """Set the size of this ROI + + :param numpy.ndarray position: Location of the center of the ROI + """ + size = self.getSize() + self.setGeometry(center=position, size=size) + + def setGeometry(self, origin=None, size=None, center=None): + """Set the geometry of the ROI + """ + if origin is not None: + origin = numpy.array(origin) + size = numpy.array(size) + points = numpy.array([origin, origin + size]) + controlPoints = self._createControlPointsFromFirstShape(points) + elif center is not None: + center = numpy.array(center) + size = numpy.array(size) + points = numpy.array([center - size * 0.5, center + size * 0.5]) + controlPoints = self._createControlPointsFromFirstShape(points) + else: + raise ValueError("Origin or cengter expected") + self._setControlPoints(controlPoints) + + def _getLabelPosition(self): + points = self._getControlPoints() + return points.min(axis=0) + + def _updateShape(self): + if len(self._items) == 0: + return + shape = self._items[0] + points = self._getControlPoints() + points = self._getShapeFromControlPoints(points) + shape.setPoints(points) + + def _getShapeFromControlPoints(self, points): + minPoint = points.min(axis=0) + maxPoint = points.max(axis=0) + return numpy.array([minPoint, maxPoint]) + + def _createShapeItems(self, points): + shapePoints = self._getShapeFromControlPoints(points) + item = items.Shape("rectangle") + item.setPoints(shapePoints) + item.setColor(rgba(self.getColor())) + item.setFill(False) + item.setOverlay(True) + return [item] + + def _createAnchorItems(self, points): + # Remove the center control point + points = points[0:-1] + + anchors = [] + for point in points: + anchor = items.Marker() + anchor.setPosition(*point) + anchor.setText('') + anchor.setSymbol('s') + anchor._setDraggable(True) + anchors.append(anchor) + + # Add an anchor to the center of the rectangle + center = numpy.mean(points, axis=0) + anchor = items.Marker() + anchor.setPosition(*center) + anchor.setText('') + anchor.setSymbol('+') + anchor._setDraggable(True) + anchors.append(anchor) + + return anchors + + def _controlPointAnchorPositionChanged(self, index, current, previous): + if index == len(self._editAnchors) - 1: + # It is the center anchor + points = self._getControlPoints() + center = numpy.mean(points[0:-1], axis=0) + offset = current - previous + points[-1] = current + points[0:-1] = points[0:-1] + offset + self._setControlPoints(points) + else: + # Fix other corners + constrains = [(1, 3), (0, 2), (3, 1), (2, 0)] + constrains = constrains[index] + points = self._getControlPoints() + points[index] = current + points[constrains[0]][0] = current[0] + points[constrains[1]][1] = current[1] + # Update the center + center = numpy.mean(points[0:-1], axis=0) + points[-1] = center + self._setControlPoints(points) + + def __str__(self): + origin = self.getOrigin() + w, h = self.getSize() + params = origin[0], origin[1], w, h + params = 'origin: %f %f; width: %f; height: %f' % params + return "%s(%s)" % (self.__class__.__name__, params) + + +class PolygonROI(RegionOfInterest): + """A ROI identifying a closed polygon in a 2D plot. + + This ROI provides 1 anchor for each point of the polygon. + """ + + _kind = "Polygon" + """Label for this kind of ROI""" + + _plotShape = "polygon" + """Plot shape which is used for the first interaction""" + + def getPoints(self): + """Returns the list of the points of this polygon. + + :rtype: numpy.ndarray + """ + return self._points.copy() + + def setPoints(self, points): + """Set the position of this ROI + + :param numpy.ndarray pos: 2d-coordinate of this point + """ + assert(len(points.shape) == 2 and points.shape[1] == 2) + if len(points) > 0: + controlPoints = numpy.array(points) + else: + controlPoints = numpy.empty((0, 2)) + self._setControlPoints(controlPoints) + + def _getLabelPosition(self): + points = self._getControlPoints() + if len(points) == 0: + # FIXME: we should return none, this polygon have no location + return numpy.array([0, 0]) + return points[numpy.argmin(points[:, 1])] + + def _updateShape(self): + if len(self._items) == 0: + return + shape = self._items[0] + points = self._getControlPoints() + shape.setPoints(points) + + def _createShapeItems(self, points): + if len(points) == 0: + return [] + else: + item = items.Shape("polygon") + item.setPoints(points) + item.setColor(rgba(self.getColor())) + item.setFill(False) + item.setOverlay(True) + return [item] + + def _createAnchorItems(self, points): + anchors = [] + for point in points: + anchor = items.Marker() + anchor.setPosition(*point) + anchor.setText('') + anchor.setSymbol('s') + anchor._setDraggable(True) + anchors.append(anchor) + return anchors + + def __str__(self): + points = self._getControlPoints() + params = '; '.join('%f %f' % (pt[0], pt[1]) for pt in points) + return "%s(%s)" % (self.__class__.__name__, params) + + +class ArcROI(RegionOfInterest): + """A ROI identifying an arc of a circle with a width. + + This ROI provides 3 anchors to control the curvature, 1 anchor to control + the weigth, and 1 anchor to translate the shape. + """ + + _kind = "Arc" + """Label for this kind of ROI""" + + _plotShape = "line" + """Plot shape which is used for the first interaction""" + + _ArcGeometry = collections.namedtuple('ArcGeometry', ['center', + 'startPoint', 'endPoint', + 'radius', 'weight', + 'startAngle', 'endAngle']) + + def __init__(self, parent=None): + RegionOfInterest.__init__(self, parent=parent) + self._geometry = None + + def _getInternalGeometry(self): + """Returns the object storing the internal geometry of this ROI. + + This geometry is derived from the control points and cached for + efficiency. Calling :meth:`_setControlPoints` invalidate the cache. + """ + if self._geometry is None: + controlPoints = self._getControlPoints() + self._geometry = self._createGeometryFromControlPoint(controlPoints) + return self._geometry + + @classmethod + def showFirstInteractionShape(cls): + return False + + def _getLabelPosition(self): + points = self._getControlPoints() + return points.min(axis=0) + + def _updateShape(self): + if len(self._items) == 0: + return + shape = self._items[0] + points = self._getControlPoints() + points = self._getShapeFromControlPoints(points) + shape.setPoints(points) + + def _controlPointAnchorPositionChanged(self, index, current, previous): + controlPoints = self._getControlPoints() + currentWeigth = numpy.linalg.norm(controlPoints[3] - controlPoints[1]) * 2 + + if index in [0, 2]: + # Moving start or end will maintain the same curvature + # Then we have to custom the curvature control point + startPoint = controlPoints[0] + endPoint = controlPoints[2] + center = (startPoint + endPoint) * 0.5 + normal = (endPoint - startPoint) + normal = numpy.array((normal[1], -normal[0])) + distance = numpy.linalg.norm(normal) + # Compute the coeficient which have to be constrained + if distance != 0: + normal /= distance + midVector = controlPoints[1] - center + constainedCoef = numpy.dot(midVector, normal) / distance + else: + constainedCoef = 1.0 + + # Compute the location of the curvature point + controlPoints[index] = current + startPoint = controlPoints[0] + endPoint = controlPoints[2] + center = (startPoint + endPoint) * 0.5 + normal = (endPoint - startPoint) + normal = numpy.array((normal[1], -normal[0])) + distance = numpy.linalg.norm(normal) + if distance != 0: + # BTW we dont need to divide by the distance here + # Cause we compute normal * distance after all + normal /= distance + midPoint = center + normal * constainedCoef * distance + controlPoints[1] = midPoint + + # The weight have to be fixed + self._updateWeightControlPoint(controlPoints, currentWeigth) + self._setControlPoints(controlPoints) + + elif index == 1: + # The weight have to be fixed + controlPoints[index] = current + self._updateWeightControlPoint(controlPoints, currentWeigth) + self._setControlPoints(controlPoints) + else: + super(ArcROI, self)._controlPointAnchorPositionChanged(index, current, previous) + + def _updateWeightControlPoint(self, controlPoints, weigth): + startPoint = controlPoints[0] + midPoint = controlPoints[1] + endPoint = controlPoints[2] + normal = (endPoint - startPoint) + normal = numpy.array((normal[1], -normal[0])) + distance = numpy.linalg.norm(normal) + if distance != 0: + normal /= distance + controlPoints[3] = midPoint + normal * weigth * 0.5 + + def _createGeometryFromControlPoint(self, controlPoints): + """Returns the geometry of the object""" + weigth = numpy.linalg.norm(controlPoints[3] - controlPoints[1]) * 2 + if numpy.allclose(controlPoints[0], controlPoints[2]): + # Special arc: It's a closed circle + center = (controlPoints[0] + controlPoints[1]) * 0.5 + radius = numpy.linalg.norm(controlPoints[0] - center) + v = controlPoints[0] - center + startAngle = numpy.angle(complex(v[0], v[1])) + endAngle = startAngle + numpy.pi * 2.0 + return self._ArcGeometry(center, controlPoints[0], controlPoints[2], + radius, weigth, startAngle, endAngle) + + elif numpy.linalg.norm( + numpy.cross(controlPoints[1] - controlPoints[0], + controlPoints[2] - controlPoints[0])) < 1e-5: + # Degenerated arc, it's a rectangle + return self._ArcGeometry(None, controlPoints[0], controlPoints[2], + None, weigth, None, None) + else: + center, radius = self._circleEquation(*controlPoints[:3]) + v = controlPoints[0] - center + startAngle = numpy.angle(complex(v[0], v[1])) + v = controlPoints[1] - center + midAngle = numpy.angle(complex(v[0], v[1])) + v = controlPoints[2] - center + endAngle = numpy.angle(complex(v[0], v[1])) + # Is it clockwise or anticlockwise + if (midAngle - startAngle + 2 * numpy.pi) % (2 * numpy.pi) <= numpy.pi: + if endAngle < startAngle: + endAngle += 2 * numpy.pi + else: + if endAngle > startAngle: + endAngle -= 2 * numpy.pi + + return self._ArcGeometry(center, controlPoints[0], controlPoints[2], + radius, weigth, startAngle, endAngle) + + def _isCircle(self, geometry): + """Returns True if the geometry is a closed circle""" + delta = numpy.abs(geometry.endAngle - geometry.startAngle) + return numpy.isclose(delta, numpy.pi * 2) + + def _getShapeFromControlPoints(self, controlPoints): + geometry = self._createGeometryFromControlPoint(controlPoints) + if geometry.center is None: + # It is not an arc + # but we can display it as an the intermediat shape + normal = (geometry.endPoint - geometry.startPoint) + normal = numpy.array((normal[1], -normal[0])) + distance = numpy.linalg.norm(normal) + if distance != 0: + normal /= distance + points = numpy.array([ + geometry.startPoint + normal * geometry.weight * 0.5, + geometry.endPoint + normal * geometry.weight * 0.5, + geometry.endPoint - normal * geometry.weight * 0.5, + geometry.startPoint - normal * geometry.weight * 0.5]) + else: + innerRadius = geometry.radius - geometry.weight * 0.5 + outerRadius = geometry.radius + geometry.weight * 0.5 + + if numpy.isnan(geometry.startAngle): + # Degenerated, it's a point + # At least 2 points are expected + return numpy.array([geometry.startPoint, geometry.startPoint]) + + delta = 0.1 if geometry.endAngle >= geometry.startAngle else -0.1 + if geometry.startAngle == geometry.endAngle: + # Degenerated, it's a line (single radius) + angle = geometry.startAngle + direction = numpy.array([numpy.cos(angle), numpy.sin(angle)]) + points = [] + points.append(geometry.center + direction * innerRadius) + points.append(geometry.center + direction * outerRadius) + return numpy.array(points) + + angles = numpy.arange(geometry.startAngle, geometry.endAngle, delta) + if angles[-1] != geometry.endAngle: + angles = numpy.append(angles, geometry.endAngle) + + isCircle = self._isCircle(geometry) + + if isCircle: + if innerRadius <= 0: + # It's a circle + points = [] + numpy.append(angles, angles[-1]) + for angle in angles: + direction = numpy.array([numpy.cos(angle), numpy.sin(angle)]) + points.append(geometry.center + direction * outerRadius) + else: + # It's a donut + points = [] + # NOTE: NaN value allow to create 2 separated circle shapes + # using a single plot item. It's a kind of cheat + points.append(numpy.array([float("nan"), float("nan")])) + for angle in angles: + direction = numpy.array([numpy.cos(angle), numpy.sin(angle)]) + points.insert(0, geometry.center + direction * innerRadius) + points.append(geometry.center + direction * outerRadius) + points.append(numpy.array([float("nan"), float("nan")])) + else: + if innerRadius <= 0: + # It's a part of camembert + points = [] + points.append(geometry.center) + points.append(geometry.startPoint) + delta = 0.1 if geometry.endAngle >= geometry.startAngle else -0.1 + for angle in angles: + direction = numpy.array([numpy.cos(angle), numpy.sin(angle)]) + points.append(geometry.center + direction * outerRadius) + points.append(geometry.endPoint) + points.append(geometry.center) + else: + # It's a part of donut + points = [] + points.append(geometry.startPoint) + for angle in angles: + direction = numpy.array([numpy.cos(angle), numpy.sin(angle)]) + points.insert(0, geometry.center + direction * innerRadius) + points.append(geometry.center + direction * outerRadius) + points.insert(0, geometry.endPoint) + points.append(geometry.endPoint) + points = numpy.array(points) + + return points + + def _setControlPoints(self, points): + # Invalidate the geometry + self._geometry = None + RegionOfInterest._setControlPoints(self, points) + + def getGeometry(self): + """Returns a tuple containing the geometry of this ROI + + It is a symetric fonction of :meth:`setGeometry`. + + If `startAngle` is smaller than `endAngle` the rotation is clockwise, + else the rotation is anticlockwise. + + :rtype: Tuple[numpy.ndarray,float,float,float,float] + :raise ValueError: In case the ROI can't be representaed as section of + a circle + """ + geometry = self._getInternalGeometry() + if geometry.center is None: + raise ValueError("This ROI can't be represented as a section of circle") + return geometry.center, self.getInnerRadius(), self.getOuterRadius(), geometry.startAngle, geometry.endAngle + + def isClosed(self): + """Returns true if the arc is a closed shape, like a circle or a donut. + + :rtype: bool + """ + geometry = self._getInternalGeometry() + return self._isCircle(geometry) + + def getCenter(self): + """Returns the center of the circle used to draw arcs of this ROI. + + This center is usually outside the the shape itself. + + :rtype: numpy.ndarray + """ + geometry = self._getInternalGeometry() + return geometry.center + + def getStartAngle(self): + """Returns the angle of the start of the section of this ROI (in radian). + + If `startAngle` is smaller than `endAngle` the rotation is clockwise, + else the rotation is anticlockwise. + + :rtype: float + """ + geometry = self._getInternalGeometry() + return geometry.startAngle + + def getEndAngle(self): + """Returns the angle of the end of the section of this ROI (in radian). + + If `startAngle` is smaller than `endAngle` the rotation is clockwise, + else the rotation is anticlockwise. + + :rtype: float + """ + geometry = self._getInternalGeometry() + return geometry.endAngle + + def getInnerRadius(self): + """Returns the radius of the smaller arc used to draw this ROI. + + :rtype: float + """ + geometry = self._getInternalGeometry() + radius = geometry.radius - geometry.weight * 0.5 + if radius < 0: + radius = 0 + return radius + + def getOuterRadius(self): + """Returns the radius of the bigger arc used to draw this ROI. + + :rtype: float + """ + geometry = self._getInternalGeometry() + radius = geometry.radius + geometry.weight * 0.5 + return radius + + def setGeometry(self, center, innerRadius, outerRadius, startAngle, endAngle): + """ + Set the geometry of this arc. + + :param numpy.ndarray center: Center of the circle. + :param float innerRadius: Radius of the smaller arc of the section. + :param float outerRadius: Weight of the bigger arc of the section. + It have to be bigger than `innerRadius` + :param float startAngle: Location of the start of the section (in radian) + :param float endAngle: Location of the end of the section (in radian). + If `startAngle` is smaller than `endAngle` the rotation is clockwise, + else the rotation is anticlockwise. + """ + assert(innerRadius <= outerRadius) + assert(numpy.abs(startAngle - endAngle) <= 2 * numpy.pi) + center = numpy.array(center) + radius = (innerRadius + outerRadius) * 0.5 + weight = outerRadius - innerRadius + geometry = self._ArcGeometry(center, None, None, radius, weight, startAngle, endAngle) + controlPoints = self._createControlPointsFromGeometry(geometry) + self._setControlPoints(controlPoints) + + def _createControlPointsFromGeometry(self, geometry): + if geometry.startPoint or geometry.endPoint: + # Duplication with the angles + raise NotImplementedError("This general case is not implemented") + + angle = geometry.startAngle + direction = numpy.array([numpy.cos(angle), numpy.sin(angle)]) + startPoint = geometry.center + direction * geometry.radius + + angle = geometry.endAngle + direction = numpy.array([numpy.cos(angle), numpy.sin(angle)]) + endPoint = geometry.center + direction * geometry.radius + + angle = (geometry.startAngle + geometry.endAngle) * 0.5 + direction = numpy.array([numpy.cos(angle), numpy.sin(angle)]) + curvaturePoint = geometry.center + direction * geometry.radius + weightPoint = curvaturePoint + direction * geometry.weight * 0.5 + + return numpy.array([startPoint, curvaturePoint, endPoint, weightPoint]) + + def _createControlPointsFromFirstShape(self, points): + # The first shape is a line + point0 = points[0] + point1 = points[1] + + # Compute a non colineate point for the curvature + center = (point1 + point0) * 0.5 + normal = point1 - center + normal = numpy.array((normal[1], -normal[0])) + defaultCurvature = numpy.pi / 5.0 + defaultWeight = 0.20 # percentage + curvaturePoint = center - normal * defaultCurvature + weightPoint = center - normal * defaultCurvature * (1.0 + defaultWeight) + + # 3 corners + controlPoints = numpy.array([ + point0, + curvaturePoint, + point1, + weightPoint + ]) + return controlPoints + + def _createShapeItems(self, points): + shapePoints = self._getShapeFromControlPoints(points) + item = items.Shape("polygon") + item.setPoints(shapePoints) + item.setColor(rgba(self.getColor())) + item.setFill(False) + item.setOverlay(True) + return [item] + + def _createAnchorItems(self, points): + anchors = [] + symbols = ['o', 'o', 'o', 's'] + + for index, point in enumerate(points): + if index in [1, 3]: + constraint = self._arcCurvatureMarkerConstraint + else: + constraint = None + anchor = items.Marker() + anchor.setPosition(*point) + anchor.setText('') + anchor.setSymbol(symbols[index]) + anchor._setDraggable(True) + if constraint is not None: + anchor._setConstraint(constraint) + anchors.append(anchor) + + return anchors + + def _arcCurvatureMarkerConstraint(self, x, y): + """Curvature marker remains on "mediatrice" """ + start = self._points[0] + end = self._points[2] + midPoint = (start + end) / 2. + normal = (end - start) + normal = numpy.array((normal[1], -normal[0])) + distance = numpy.linalg.norm(normal) + if distance != 0: + normal /= distance + v = numpy.dot(normal, (numpy.array((x, y)) - midPoint)) + x, y = midPoint + v * normal + return x, y + + @staticmethod + def _circleEquation(pt1, pt2, pt3): + """Circle equation from 3 (x, y) points + + :return: Position of the center of the circle and the radius + :rtype: Tuple[Tuple[float,float],float] + """ + x, y, z = complex(*pt1), complex(*pt2), complex(*pt3) + w = z - x + w /= y - x + c = (x - y) * (w - abs(w) ** 2) / 2j / w.imag - x + return ((-c.real, -c.imag), abs(c + x)) + + def __str__(self): + try: + center, innerRadius, outerRadius, startAngle, endAngle = self.getGeometry() + params = center[0], center[1], innerRadius, outerRadius, startAngle, endAngle + params = 'center: %f %f; radius: %f %f; angles: %f %f' % params + except ValueError: + params = "invalid" + return "%s(%s)" % (self.__class__.__name__, params) diff --git a/silx/gui/plot/items/scatter.py b/silx/gui/plot/items/scatter.py index 98ed473..72b8496 100644 --- a/silx/gui/plot/items/scatter.py +++ b/silx/gui/plot/items/scatter.py @@ -1,7 +1,7 @@ # coding: utf-8 # /*########################################################################## # -# Copyright (c) 2017 European Synchrotron Radiation Facility +# Copyright (c) 2017-2018 European Synchrotron Radiation Facility # # Permission is hereby granted, free of charge, to any person obtaining a copy # of this software and associated documentation files (the "Software"), to deal @@ -42,6 +42,10 @@ _logger = logging.getLogger(__name__) class Scatter(Points, ColormapMixIn): """Description of a scatter""" + + _DEFAULT_SELECTABLE = True + """Default selectable state for scatter plots""" + _DEFAULT_SYMBOL = 'o' """Default symbol of the scatter plots""" diff --git a/silx/gui/plot/matplotlib/Colormap.py b/silx/gui/plot/matplotlib/Colormap.py index d035605..772a473 100644 --- a/silx/gui/plot/matplotlib/Colormap.py +++ b/silx/gui/plot/matplotlib/Colormap.py @@ -1,6 +1,6 @@ # coding: utf-8 # /*########################################################################## -# Copyright (C) 2017 European Synchrotron Radiation Facility +# Copyright (C) 2017-2018 European Synchrotron Radiation Facility # # Permission is hereby granted, free of charge, to any person obtaining a copy # of this software and associated documentation files (the "Software"), to deal @@ -29,6 +29,8 @@ from matplotlib.colors import ListedColormap import matplotlib.colors import matplotlib.cm import silx.resources +from silx.utils.deprecation import deprecated + _logger = logging.getLogger(__name__) @@ -177,6 +179,8 @@ def getScalarMappable(colormap, data=None): return matplotlib.cm.ScalarMappable(norm=norm, cmap=cmap) +@deprecated(replacement='silx.colors.Colormap.applyToData', + since_version='0.8.0') def applyColormapToData(data, colormap): """Apply a colormap to the data and returns the RGBA image diff --git a/silx/gui/plot/matplotlib/__init__.py b/silx/gui/plot/matplotlib/__init__.py index 384d049..a4dc235 100644 --- a/silx/gui/plot/matplotlib/__init__.py +++ b/silx/gui/plot/matplotlib/__init__.py @@ -22,6 +22,9 @@ # THE SOFTWARE. # # ###########################################################################*/ + +from __future__ import absolute_import + """This module inits matplotlib and setups the backend to use. It MUST be imported prior to any other import of matplotlib. @@ -32,7 +35,7 @@ to the used backend. __authors__ = ["T. Vincent"] __license__ = "MIT" -__date__ = "04/05/2017" +__date__ = "02/05/2018" import sys @@ -41,31 +44,54 @@ import logging _logger = logging.getLogger(__name__) -if 'matplotlib' in sys.modules: - _logger.warning( - 'matplotlib already loaded, setting its backend may not work') - +_matplotlib_already_loaded = 'matplotlib' in sys.modules +"""If true, matplotlib was already loaded""" +import matplotlib from ... import qt -import matplotlib + +def _configure(backend, backend_qt4=None, backend_qt5=None, check=False): + """Configure matplotlib using a specific backend. + + It initialize `matplotlib.rcParams` using the requested backend, or check + if it is already configured as requested. + + :param bool check: If true, the function only check that matplotlib + is already initialized as request. If not a warning is emitted. + If `check` is false, matplotlib is initialized. + """ + if check: + valid = matplotlib.rcParams['backend'] == backend + if backend_qt4 is not None: + valid = valid and matplotlib.rcParams['backend.qt4'] == backend_qt4 + if backend_qt5 is not None: + valid = valid and matplotlib.rcParams['backend.qt5'] == backend_qt5 + + if not valid: + _logger.warning('matplotlib already loaded, setting its backend may not work') + else: + matplotlib.rcParams['backend'] = backend + if backend_qt4 is not None: + matplotlib.rcParams['backend.qt4'] = backend_qt4 + if backend_qt5 is not None: + matplotlib.rcParams['backend.qt5'] = backend_qt5 + if qt.BINDING == 'PySide': - matplotlib.rcParams['backend'] = 'Qt4Agg' - matplotlib.rcParams['backend.qt4'] = 'PySide' + _configure('Qt4Agg', backend_qt4='PySide', check=_matplotlib_already_loaded) import matplotlib.backends.backend_qt4agg as backend elif qt.BINDING == 'PyQt4': - matplotlib.rcParams['backend'] = 'Qt4Agg' + _configure('Qt4Agg', check=_matplotlib_already_loaded) import matplotlib.backends.backend_qt4agg as backend elif qt.BINDING == 'PySide2': - matplotlib.rcParams['backend'] = 'Qt5Agg' - matplotlib.rcParams['backend.qt5'] = 'PySide2' + _configure('Qt5Agg', backend_qt5="PySide2", check=_matplotlib_already_loaded) import matplotlib.backends.backend_qt5agg as backend elif qt.BINDING == 'PyQt5': - matplotlib.rcParams['backend'] = 'Qt5Agg' + _configure('Qt5Agg', check=_matplotlib_already_loaded) import matplotlib.backends.backend_qt5agg as backend else: diff --git a/silx/gui/plot/setup.py b/silx/gui/plot/setup.py index 205c5fa..e0b2c91 100644 --- a/silx/gui/plot/setup.py +++ b/silx/gui/plot/setup.py @@ -1,7 +1,7 @@ # coding: utf-8 # /*########################################################################## # -# Copyright (c) 2016-2017 European Synchrotron Radiation Facility +# Copyright (c) 2016-2018 European Synchrotron Radiation Facility # # Permission is hereby granted, free of charge, to any person obtaining a copy # of this software and associated documentation files (the "Software"), to deal @@ -35,10 +35,14 @@ def configuration(parent_package='', top_path=None): config.add_subpackage('_utils') config.add_subpackage('utils') config.add_subpackage('matplotlib') + config.add_subpackage('stats') config.add_subpackage('backends') config.add_subpackage('backends.glutils') config.add_subpackage('items') config.add_subpackage('test') + config.add_subpackage('tools') + config.add_subpackage('tools.profile') + config.add_subpackage('tools.test') config.add_subpackage('actions') return config diff --git a/silx/gui/plot/stats/__init__.py b/silx/gui/plot/stats/__init__.py new file mode 100644 index 0000000..04a5327 --- /dev/null +++ b/silx/gui/plot/stats/__init__.py @@ -0,0 +1,33 @@ +# coding: utf-8 +# /*########################################################################## +# +# Copyright (c) 2017 European Synchrotron Radiation Facility +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. +# +# ###########################################################################*/ +""" +""" + +__authors__ = ["H. Payno"] +__license__ = "MIT" +__date__ = "07/03/2018" + + +from .stats import * diff --git a/silx/gui/plot/stats/stats.py b/silx/gui/plot/stats/stats.py new file mode 100644 index 0000000..a753989 --- /dev/null +++ b/silx/gui/plot/stats/stats.py @@ -0,0 +1,491 @@ +# coding: utf-8 +# /*########################################################################## +# +# Copyright (c) 2017-2018 European Synchrotron Radiation Facility +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. +# +# ###########################################################################*/ +"""This module provides the :class:`Scatter` item of the :class:`Plot`. +""" + +__authors__ = ["H. Payno"] +__license__ = "MIT" +__date__ = "06/06/2018" + + +import numpy +from silx.gui.plot.items.curve import Curve as CurveItem +from silx.gui.plot.items.image import ImageBase as ImageItem +from silx.gui.plot.items.scatter import Scatter as ScatterItem +from silx.gui.plot.items.histogram import Histogram as HistogramItem +from silx.math.combo import min_max +from collections import OrderedDict +import logging + +logger = logging.getLogger(__name__) + + +class Stats(OrderedDict): + """Class to define a set of statistic relative to a dataset + (image, curve...). + + The goal of this class is to avoid multiple recalculation of some + basic operations such as filtering data area where the statistics has to + be apply. + Min and max are also stored because they can be used several time. + + :param List statslist: List of the :class:`Stat` object to be computed. + """ + def __init__(self, statslist=None): + OrderedDict.__init__(self) + _statslist = statslist if not None else [] + if statslist is not None: + for stat in _statslist: + self.add(stat) + + def calculate(self, item, plot, onlimits): + """ + Call all :class:`Stat` object registred and return the result of the + computation. + + :param item: the item for which we want statistics + :param plot: plot containing the item + :param bool onlimits: True if we want to apply statistic only on + visible data. + :return dict: dictionary with :class:`Stat` name as ket and result + of the calculation as value + """ + res = {} + if isinstance(item, CurveItem): + context = _CurveContext(item, plot, onlimits) + elif isinstance(item, ImageItem): + context = _ImageContext(item, plot, onlimits) + elif isinstance(item, ScatterItem): + context = _ScatterContext(item, plot, onlimits) + elif isinstance(item, HistogramItem): + context = _HistogramContext(item, plot, onlimits) + else: + raise ValueError('Item type not managed') + for statName, stat in list(self.items()): + if context.kind not in stat.compatibleKinds: + logger.debug('kind %s not managed by statistic %s' + % (context.kind, stat.name)) + res[statName] = None + else: + res[statName] = stat.calculate(context) + return res + + def __setitem__(self, key, value): + assert isinstance(value, StatBase) + OrderedDict.__setitem__(self, key, value) + + def add(self, stat): + self.__setitem__(key=stat.name, value=stat) + + +class _StatsContext(object): + """ + The context is designed to be a simple buffer and avoid repetition of + calculations that can appear during stats evaluation. + + .. warning:: this class gives access to the data to be used for computation + . It deal with filtering data visible by the user on plot. + The filtering is a simple data sub-sampling. No interpolation + is made to fit data to boundaries. + + :param item: the item for which we want to compute the context + :param str kind: the kind of the item + :param plot: the plot containing the item + :param bool onlimits: True if we want to apply statistic only on + visible data. + """ + def __init__(self, item, kind, plot, onlimits): + assert item + assert plot + assert type(onlimits) is bool + self.kind = kind + self.min = None + self.max = None + self.data = None + self.values = None + self.createContext(item, plot, onlimits) + + def createContext(self, item, plot, onlimits): + raise NotImplementedError("Base class") + + +class _CurveContext(_StatsContext): + """ + StatsContext for :class:`Curve` + + :param item: the item for which we want to compute the context + :param plot: the plot containing the item + :param bool onlimits: True if we want to apply statistic only on + visible data. + """ + def __init__(self, item, plot, onlimits): + _StatsContext.__init__(self, kind='curve', item=item, + plot=plot, onlimits=onlimits) + + def createContext(self, item, plot, onlimits): + xData, yData = item.getData(copy=True)[0:2] + + if onlimits: + minX, maxX = plot.getXAxis().getLimits() + yData = yData[(minX <= xData) & (xData <= maxX)] + xData = xData[(minX <= xData) & (xData <= maxX)] + + self.xData = xData + self.yData = yData + if len(yData) > 0: + self.min, self.max = min_max(yData) + else: + self.min, self.max = None, None + self.data = (xData, yData) + self.values = yData + + +class _HistogramContext(_StatsContext): + """ + StatsContext for :class:`Curve` + + :param item: the item for which we want to compute the context + :param plot: the plot containing the item + :param bool onlimits: True if we want to apply statistic only on + visible data. + """ + def __init__(self, item, plot, onlimits): + _StatsContext.__init__(self, kind='histogram', item=item, + plot=plot, onlimits=onlimits) + + def createContext(self, item, plot, onlimits): + xData, edges = item.getData(copy=True)[0:2] + yData = item._revertComputeEdges(x=edges, histogramType=item.getAlignment()) + if onlimits: + minX, maxX = plot.getXAxis().getLimits() + yData = yData[(minX <= xData) & (xData <= maxX)] + xData = xData[(minX <= xData) & (xData <= maxX)] + + self.xData = xData + self.yData = yData + if len(yData) > 0: + self.min, self.max = min_max(yData) + else: + self.min, self.max = None, None + self.data = (xData, yData) + self.values = yData + + +class _ScatterContext(_StatsContext): + """ + StatsContext for :class:`Scatter` + + :param item: the item for which we want to compute the context + :param plot: the plot containing the item + :param bool onlimits: True if we want to apply statistic only on + visible data. + """ + def __init__(self, item, plot, onlimits): + _StatsContext.__init__(self, kind='scatter', item=item, plot=plot, + onlimits=onlimits) + + def createContext(self, item, plot, onlimits): + xData, yData, valueData, xerror, yerror = item.getData(copy=True) + assert plot + if onlimits: + minX, maxX = plot.getXAxis().getLimits() + minY, maxY = plot.getYAxis().getLimits() + # filter on X axis + valueData = valueData[(minX <= xData) & (xData <= maxX)] + yData = yData[(minX <= xData) & (xData <= maxX)] + xData = xData[(minX <= xData) & (xData <= maxX)] + # filter on Y axis + valueData = valueData[(minY <= yData) & (yData <= maxY)] + xData = xData[(minY <= yData) & (yData <= maxY)] + yData = yData[(minY <= yData) & (yData <= maxY)] + if len(valueData) > 0: + self.min, self.max = min_max(valueData) + else: + self.min, self.max = None, None + self.data = (xData, yData, valueData) + self.values = valueData + + +class _ImageContext(_StatsContext): + """ + StatsContext for :class:`ImageBase` + + :param item: the item for which we want to compute the context + :param plot: the plot containing the item + :param bool onlimits: True if we want to apply statistic only on + visible data. + """ + def __init__(self, item, plot, onlimits): + _StatsContext.__init__(self, kind='image', item=item, + plot=plot, onlimits=onlimits) + + def createContext(self, item, plot, onlimits): + self.origin = item.getOrigin() + self.scale = item.getScale() + self.data = item.getData() + + if onlimits: + minX, maxX = plot.getXAxis().getLimits() + minY, maxY = plot.getYAxis().getLimits() + + XMinBound = int((minX - self.origin[0]) / self.scale[0]) + YMinBound = int((minY - self.origin[1]) / self.scale[1]) + XMaxBound = int((maxX - self.origin[0]) / self.scale[0]) + YMaxBound = int((maxY - self.origin[1]) / self.scale[1]) + + XMinBound = max(XMinBound, 0) + YMinBound = max(YMinBound, 0) + + if XMaxBound <= XMinBound or YMaxBound <= YMinBound: + return self.noDataSelected() + data = item.getData() + self.data = data[YMinBound:YMaxBound + 1, XMinBound:XMaxBound + 1] + else: + self.data = item.getData() + + if self.data.size > 0: + self.min, self.max = min_max(self.data) + else: + self.min, self.max = None, None + self.values = self.data + + +BASIC_COMPATIBLE_KINDS = { + 'curve': CurveItem, + 'image': ImageItem, + 'scatter': ScatterItem, + 'histogram': HistogramItem, +} + + +class StatBase(object): + """ + Base class for defining a statistic. + + :param str name: the name of the statistic. Must be unique. + :param compatibleKinds: the kind of items (curve, scatter...) for which + the statistic apply. + :rtype: List or tuple + """ + def __init__(self, name, compatibleKinds=BASIC_COMPATIBLE_KINDS, description=None): + self.name = name + self.compatibleKinds = compatibleKinds + self.description = description + + def calculate(self, context): + """ + compute the statistic for the given :class:`StatsContext` + + :param context: + :return dict: key is stat name, statistic computed is the dict value + """ + raise NotImplementedError('Base class') + + def getToolTip(self, kind): + """ + If necessary add a tooltip for a stat kind + + :param str kinf: the kind of item the statistic is compute for. + :return: tooltip or None if no tooltip + """ + return None + + +class Stat(StatBase): + """ + Create a StatBase class based on a function pointer. + + :param str name: name of the statistic. Used as id + :param fct: function which should have as unique mandatory parameter the + data. Should be able to adapt to all `kinds` defined as + compatible + :param tuple kinds: the compatible item kinds of the function (curve, + image...) + """ + def __init__(self, name, fct, kinds=BASIC_COMPATIBLE_KINDS): + StatBase.__init__(self, name, kinds) + self._fct = fct + + def calculate(self, context): + if context.kind in self.compatibleKinds: + return self._fct(context.values) + else: + raise ValueError('Kind %s not managed by %s' + '' % (context.kind, self.name)) + + +class StatMin(StatBase): + """ + Compute the minimal value on data + """ + def __init__(self): + StatBase.__init__(self, name='min') + + def calculate(self, context): + return context.min + + +class StatMax(StatBase): + """ + Compute the maximal value on data + """ + def __init__(self): + StatBase.__init__(self, name='max') + + def calculate(self, context): + return context.max + + +class StatDelta(StatBase): + """ + Compute the delta between minimal and maximal on data + """ + def __init__(self): + StatBase.__init__(self, name='delta') + + def calculate(self, context): + return context.max - context.min + + +class StatCoordMin(StatBase): + """ + Compute the first coordinates of the data minimal value + """ + def __init__(self): + StatBase.__init__(self, name='coords min') + + def calculate(self, context): + if context.kind in ('curve', 'histogram'): + return context.xData[numpy.argmin(context.yData)] + elif context.kind == 'scatter': + xData, yData, valueData = context.data + return (xData[numpy.argmin(valueData)], + yData[numpy.argmin(valueData)]) + elif context.kind == 'image': + scaleX, scaleY = context.scale + originX, originY = context.origin + index1D = numpy.argmin(context.data) + ySize = (context.data.shape[1]) + x = index1D % context.data.shape[1] + y = (index1D - x) / ySize + x = x * scaleX + originX + y = y * scaleY + originY + return (x, y) + else: + raise ValueError('kind not managed') + + def getToolTip(self, kind): + if kind in ('scatter', 'image'): + return '(x, y)' + else: + return None + +class StatCoordMax(StatBase): + """ + Compute the first coordinates of the data minimal value + """ + def __init__(self): + StatBase.__init__(self, name='coords max') + + def calculate(self, context): + if context.kind in ('curve', 'histogram'): + return context.xData[numpy.argmax(context.yData)] + elif context.kind == 'scatter': + xData, yData, valueData = context.data + return (xData[numpy.argmax(valueData)], + yData[numpy.argmax(valueData)]) + elif context.kind == 'image': + scaleX, scaleY = context.scale + originX, originY = context.origin + index1D = numpy.argmax(context.data) + ySize = (context.data.shape[1]) + x = index1D % context.data.shape[1] + y = (index1D - x) / ySize + x = x * scaleX + originX + y = y * scaleY + originY + return (x, y) + else: + raise ValueError('kind not managed') + + def getToolTip(self, kind): + if kind in ('scatter', 'image'): + return '(x, y)' + else: + return None + +class StatCOM(StatBase): + """ + Compute data center of mass + """ + def __init__(self): + StatBase.__init__(self, name='COM', description='Center of mass') + + def calculate(self, context): + if context.kind in ('curve', 'histogram'): + xData, yData = context.data + deno = numpy.sum(yData).astype(numpy.float32) + if deno == 0.: + return numpy.nan + else: + return numpy.sum(xData * yData).astype(numpy.float32) / deno + elif context.kind == 'scatter': + xData, yData, values = context.data + deno = numpy.sum(values).astype(numpy.float32) + if deno == 0.: + return numpy.nan, numpy.nan + else: + xcom = numpy.sum(xData * values).astype(numpy.float32) / deno + ycom = numpy.sum(yData * values).astype(numpy.float32) / deno + return (xcom, ycom) + elif context.kind == 'image': + yData = numpy.sum(context.data, axis=1) + xData = numpy.sum(context.data, axis=0) + dataXRange = range(context.data.shape[1]) + dataYRange = range(context.data.shape[0]) + xScale, yScale = context.scale + xOrigin, yOrigin = context.origin + + denoY = numpy.sum(yData) + if denoY == 0.: + ycom = numpy.nan + else: + ycom = numpy.sum(yData * dataYRange) / denoY + ycom = ycom * yScale + yOrigin + + denoX = numpy.sum(xData) + if denoX == 0.: + xcom = numpy.nan + else: + xcom = numpy.sum(xData * dataXRange) / denoX + xcom = xcom * xScale + xOrigin + return (xcom, ycom) + else: + raise ValueError('kind not managed') + + def getToolTip(self, kind): + if kind in ('scatter', 'image'): + return '(x, y)' + else: + return None diff --git a/silx/gui/plot/stats/statshandler.py b/silx/gui/plot/stats/statshandler.py new file mode 100644 index 0000000..0a62b31 --- /dev/null +++ b/silx/gui/plot/stats/statshandler.py @@ -0,0 +1,190 @@ +# coding: utf-8 +# /*########################################################################## +# +# Copyright (c) 2017-2018 European Synchrotron Radiation Facility +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. +# +# ###########################################################################*/ +""" +""" + +__authors__ = ["H. Payno"] +__license__ = "MIT" +__date__ = "05/06/2018" + + +import logging + +from silx.gui import qt +from silx.gui.plot import stats as statsmdl + +logger = logging.getLogger(__name__) + + +class _FloatItem(qt.QTableWidgetItem): + """Simple QTableWidgetItem allowing ordering on floats""" + + def __init__(self, type=qt.QTableWidgetItem.Type): + qt.QTableWidgetItem.__init__(self, type=type) + + def __lt__(self, other): + return float(self.text()) < float(other.text()) + + +class StatFormatter(object): + """ + Class used to apply format on :class:`Stat` + + :param formatter: the formatter. Defined as str.format() + :param qItemClass: the class inheriting from :class:`QTableWidgetItem` + which will be used to display the result of the + statistic computation. + """ + DEFAULT_FORMATTER = '{0:.3f}' + + def __init__(self, formatter=DEFAULT_FORMATTER, qItemClass=_FloatItem): + self.formatter = formatter + self.tabWidgetItemClass = qItemClass + + def format(self, val): + if self.formatter is None or val is None: + return str(val) + else: + return self.formatter.format(val) + + +class StatsHandler(object): + """ + Give + create: + + * Stats object which will manage the statistic computation + * Associate formatter and :class:`Stat` + + :param statFormatters: Stat and optional formatter. + If elements are given as a tuple, elements + should be (:class:`Stat`, formatter). + Otherwise should be :class:`Stat` elements. + :rtype: List or tuple + """ + + def __init__(self, statFormatters): + self.stats = statsmdl.Stats() + self.formatters = {} + for elmt in statFormatters: + helper = _StatHelper(elmt) + self.add(stat=helper.stat, formatter=helper.statFormatter) + + def add(self, stat, formatter=None): + assert isinstance(stat, statsmdl.StatBase) + self.stats.add(stat) + _formatter = formatter + if type(_formatter) is str: + _formatter = StatFormatter(formatter=_formatter) + self.formatters[stat.name] = _formatter + + def format(self, name, val): + """ + Apply the format for the `name` statistic and the given value + :param name: the name of the associated statistic + :param val: value before formatting + :return: formatted value + """ + if name not in self.formatters: + logger.warning("statistic %s haven't been registred" % name) + return val + else: + if self.formatters[name] is None: + return str(val) + else: + if isinstance(val, (tuple, list)): + res = [] + [res.append(self.formatters[name].format(_val)) for _val in val] + return ', '.join(res) + else: + return self.formatters[name].format(val) + + def calculate(self, item, plot, onlimits): + """ + compute all statistic registred and return the list of formatted + statistics result. + + :param item: item for which we want to compute statistics + :param plot: plot containing the item + :param onlimits: True if we want to compute statistics on visible data + only + :return: list of formatted statistics (as str) + :rtype: dict + """ + res = self.stats.calculate(item, plot, onlimits) + for resName, resValue in list(res.items()): + res[resName] = self.format(resName, res[resName]) + return res + + +class _StatHelper(object): + """ + Helper class to generated the requested StatBase instance and the + associated StatFormatter + """ + def __init__(self, arg): + self.statFormatter = None + self.stat = None + + if isinstance(arg, statsmdl.StatBase): + self.stat = arg + else: + assert len(arg) > 0 + if isinstance(arg[0], statsmdl.StatBase): + self.dealWithStatAndFormatter(arg) + else: + _arg = arg + if isinstance(arg[0], tuple): + _arg = arg[0] + if len(arg) > 1: + self.statFormatter = arg[1] + self.createStatInstanceAndFormatter(_arg) + + def dealWithStatAndFormatter(self, arg): + assert isinstance(arg[0], statsmdl.StatBase) + self.stat = arg[0] + if len(arg) > 2: + raise ValueError('To many argument with %s. At most one ' + 'argument can be associated with the ' + 'BaseStat (the `StatFormatter`') + if len(arg) is 2: + assert isinstance(arg[1], (StatFormatter, type(None), str)) + self.statFormatter = arg[1] + + def createStatInstanceAndFormatter(self, arg): + if type(arg[0]) is not str: + raise ValueError('first element of the tuple should be a string' + ' or a StatBase instance') + if len(arg) is 1: + raise ValueError('A function should be associated with the' + 'stat name') + if len(arg) > 3: + raise ValueError('Two much argument given for defining statistic.' + 'Take at most three arguments (name, function, ' + 'kinds)') + if len(arg) is 2: + self.stat = statsmdl.Stat(name=arg[0], fct=arg[1]) + else: + self.stat = statsmdl.Stat(name=arg[0], fct=arg[1], kinds=arg[2]) diff --git a/silx/gui/plot/test/__init__.py b/silx/gui/plot/test/__init__.py index 154a70a..1428bad 100644 --- a/silx/gui/plot/test/__init__.py +++ b/silx/gui/plot/test/__init__.py @@ -1,7 +1,7 @@ # coding: utf-8 # /*########################################################################## # -# Copyright (c) 2016-2017 European Synchrotron Radiation Facility +# Copyright (c) 2016-2018 European Synchrotron Radiation Facility # # Permission is hereby granted, free of charge, to any person obtaining a copy # of this software and associated documentation files (the "Software"), to deal @@ -24,24 +24,21 @@ # ###########################################################################*/ __authors__ = ["T. Vincent"] __license__ = "MIT" -__date__ = "28/11/2017" +__date__ = "24/04/2018" import unittest from .._utils import test from . import testColorBar -from . import testColormap -from . import testColormapDialog -from . import testColors from . import testCurvesROIWidget +from . import testStats from . import testAlphaSlider from . import testInteraction from . import testLegendSelector from . import testMaskToolsWidget from . import testScatterMaskToolsWidget from . import testPlotInteraction -from . import testPlotTools from . import testPlotWidgetNoBackend from . import testPlotWidget from . import testPlotWindow @@ -53,16 +50,21 @@ from . import testLimitConstraints from . import testComplexImageView from . import testImageView from . import testSaveAction +from . import testScatterView +from . import testPixelIntensityHistoAction def suite(): + # Lazy-loading to avoid cyclic reference + from ..tools import test as testTools + test_suite = unittest.TestSuite() test_suite.addTests( [test.suite(), + testTools.suite(), testColorBar.suite(), - testColors.suite(), - testColormapDialog.suite(), testCurvesROIWidget.suite(), + testStats.suite(), testAlphaSlider.suite(), testInteraction.suite(), testLegendSelector.suite(), @@ -70,16 +72,17 @@ def suite(): testScatterMaskToolsWidget.suite(), testPlotInteraction.suite(), testPlotWidgetNoBackend.suite(), - testPlotTools.suite(), testPlotWidget.suite(), testPlotWindow.suite(), testProfile.suite(), testStackView.suite(), - testColormap.suite(), testItem.suite(), testUtilsAxis.suite(), testLimitConstraints.suite(), testComplexImageView.suite(), testImageView.suite(), - testSaveAction.suite()]) + testSaveAction.suite(), + testScatterView.suite(), + testPixelIntensityHistoAction.suite() + ]) return test_suite diff --git a/silx/gui/plot/test/testColorBar.py b/silx/gui/plot/test/testColorBar.py index 80ae6a8..0d1c952 100644 --- a/silx/gui/plot/test/testColorBar.py +++ b/silx/gui/plot/test/testColorBar.py @@ -26,13 +26,13 @@ __authors__ = ["H. Payno"] __license__ = "MIT" -__date__ = "11/04/2017" +__date__ = "24/04/2018" import unittest from silx.gui.test.utils import TestCaseQt from silx.gui.plot.ColorBar import _ColorScale from silx.gui.plot.ColorBar import ColorBarWidget -from silx.gui.plot.Colormap import Colormap +from silx.gui.colors import Colormap from silx.gui.plot import Plot2D from silx.gui import qt import numpy @@ -64,7 +64,7 @@ class TestColorScale(TestCaseQt): vmin=0.0, vmax=1.0) self.colorScaleWidget.setColormap(self.colorMapLin1) - + self.assertTrue( self.colorScaleWidget.getValueFromRelativePosition(0.25) == 0.25) self.assertTrue( @@ -77,7 +77,7 @@ class TestColorScale(TestCaseQt): vmin=-10, vmax=0) self.colorScaleWidget.setColormap(self.colorMapLin2) - + self.assertTrue( self.colorScaleWidget.getValueFromRelativePosition(0.25) == -7.5) self.assertTrue( @@ -98,7 +98,7 @@ class TestColorScale(TestCaseQt): val = self.colorScaleWidget.getValueFromRelativePosition(0.5) self.assertTrue(val == 10.0) - + val = self.colorScaleWidget.getValueFromRelativePosition(0.0) self.assertTrue(val == 1.0) @@ -225,7 +225,7 @@ class TestColorBarWidget(TestCaseQt): self.assertTrue(self.colorBar.getColorScaleBar().maxVal == 30) # if data is positive - data[data<1] = data.max() + data[data < 1] = data.max() self.plot.addImage(data=data, colormap=colormapLog, legend='toto', diff --git a/silx/gui/plot/test/testColors.py b/silx/gui/plot/test/testColors.py deleted file mode 100644 index 4d617eb..0000000 --- a/silx/gui/plot/test/testColors.py +++ /dev/null @@ -1,94 +0,0 @@ -# coding: utf-8 -# /*########################################################################## -# -# Copyright (c) 2016-2017 European Synchrotron Radiation Facility -# -# Permission is hereby granted, free of charge, to any person obtaining a copy -# of this software and associated documentation files (the "Software"), to deal -# in the Software without restriction, including without limitation the rights -# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -# copies of the Software, and to permit persons to whom the Software is -# furnished to do so, subject to the following conditions: -# -# The above copyright notice and this permission notice shall be included in -# all copies or substantial portions of the Software. -# -# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN -# THE SOFTWARE. -# -# ###########################################################################*/ -"""Basic tests for Colors""" - -__authors__ = ["T. Vincent"] -__license__ = "MIT" -__date__ = "17/01/2018" - - -import numpy - -import unittest -from silx.utils.testutils import ParametricTestCase - -from silx.gui.plot import Colors -from silx.gui.plot.Colormap import Colormap - -class TestRGBA(ParametricTestCase): - """Basic tests of rgba function""" - - def testRGBA(self): - """"Test rgba function with accepted values""" - tests = { # name: (colors, expected values) - 'blue': ('blue', (0., 0., 1., 1.)), - '#010203': ('#010203', (1. / 255., 2. / 255., 3. / 255., 1.)), - '#01020304': ('#01020304', (1. / 255., 2. / 255., 3. / 255., 4. / 255.)), - '3 x uint8': (numpy.array((1, 255, 0), dtype=numpy.uint8), - (1 / 255., 1., 0., 1.)), - '4 x uint8': (numpy.array((1, 255, 0, 1), dtype=numpy.uint8), - (1 / 255., 1., 0., 1 / 255.)), - '3 x float overflow': ((3., 0.5, 1.), (1., 0.5, 1., 1.)), - } - - for name, test in tests.items(): - color, expected = test - with self.subTest(msg=name): - result = Colors.rgba(color) - self.assertEqual(result, expected) - - -class TestApplyColormapToData(ParametricTestCase): - """Tests of applyColormapToData function""" - - def testApplyColormapToData(self): - """Simple test of applyColormapToData function""" - colormap = Colormap(name='gray', normalization='linear', - vmin=0, vmax=255) - - size = 10 - expected = numpy.empty((size, 4), dtype='uint8') - expected[:, 0] = numpy.arange(size, dtype='uint8') - expected[:, 1] = expected[:, 0] - expected[:, 2] = expected[:, 0] - expected[:, 3] = 255 - - for dtype in ('uint8', 'int32', 'float32', 'float64'): - with self.subTest(dtype=dtype): - array = numpy.arange(size, dtype=dtype) - result = colormap.applyToData(data=array) - self.assertTrue(numpy.all(numpy.equal(result, expected))) - - -def suite(): - test_suite = unittest.TestSuite() - for testClass in (TestRGBA, TestApplyColormapToData): - test_suite.addTest( - unittest.defaultTestLoader.loadTestsFromTestCase(testClass)) - return test_suite - - -if __name__ == '__main__': - unittest.main(defaultTest='suite') diff --git a/silx/gui/plot/test/testCurvesROIWidget.py b/silx/gui/plot/test/testCurvesROIWidget.py index 0fd2456..7a2e3d1 100644 --- a/silx/gui/plot/test/testCurvesROIWidget.py +++ b/silx/gui/plot/test/testCurvesROIWidget.py @@ -24,7 +24,7 @@ # ###########################################################################*/ """Basic tests for CurvesROIWidget""" -__authors__ = ["T. Vincent", "P. Knobel"] +__authors__ = ["T. Vincent", "P. Knobel", "H. Payno"] __license__ = "MIT" __date__ = "16/11/2017" @@ -32,9 +32,8 @@ __date__ = "16/11/2017" import logging import os.path import unittest - +from collections import OrderedDict import numpy - from silx.gui import qt from silx.test.utils import temp_dir from silx.gui.test.utils import TestCaseQt @@ -153,6 +152,25 @@ class TestCurvesROIWidget(TestCaseQt): self.assertEqual(output["negative"]["rawcounts"], y[selection].sum(), "Calculation failed on negative X coordinates") + def testDeferedInit(self): + x = numpy.arange(100.) + y = numpy.arange(100.) + self.plot.addCurve(x=x, y=y, legend="name", replace="True") + roisDefs = OrderedDict([ + ["range1", + OrderedDict([["from", 20], ["to", 200], ["type", "energy"]])], + ["range2", + OrderedDict([["from", 300], ["to", 500], ["type", "energy"]])] + ]) + + roiWidget = self.plot.getCurvesRoiDockWidget().roiWidget + self.assertFalse(roiWidget._isInit) + self.plot.getCurvesRoiDockWidget().setRois(roisDefs) + self.assertTrue(len(roiWidget.getRois()) is len(roisDefs)) + self.plot.getCurvesRoiDockWidget().setVisible(True) + self.assertTrue(len(roiWidget.getRois()) is len(roisDefs)) + + def suite(): test_suite = unittest.TestSuite() for TestClass in (TestCurvesROIWidget,): diff --git a/silx/gui/plot/test/testImageView.py b/silx/gui/plot/test/testImageView.py index 641d438..5059a0b 100644 --- a/silx/gui/plot/test/testImageView.py +++ b/silx/gui/plot/test/testImageView.py @@ -26,7 +26,7 @@ __authors__ = ["T. Vincent"] __license__ = "MIT" -__date__ = "22/09/2017" +__date__ = "24/04/2018" import unittest @@ -36,7 +36,7 @@ from silx.gui import qt from silx.gui.test.utils import TestCaseQt from silx.gui.plot import ImageView -from silx.gui.plot.Colormap import Colormap +from silx.gui.colors import Colormap class TestImageView(TestCaseQt): diff --git a/silx/gui/plot/test/testLimitConstraints.py b/silx/gui/plot/test/testLimitConstraints.py index 94aae76..5e7e0b1 100644 --- a/silx/gui/plot/test/testLimitConstraints.py +++ b/silx/gui/plot/test/testLimitConstraints.py @@ -1,7 +1,7 @@ # coding: utf-8 # /*########################################################################## # -# Copyright (c) 2016 European Synchrotron Radiation Facility +# Copyright (c) 2016-2018 European Synchrotron Radiation Facility # # Permission is hereby granted, free of charge, to any person obtaining a copy # of this software and associated documentation files (the "Software"), to deal @@ -44,9 +44,9 @@ class TestLimitConstaints(unittest.TestCase): def testApi(self): """Test availability of the API""" - self.plot.getXAxis().setLimitsConstraints(minPos=1, maxPos=1) + self.plot.getXAxis().setLimitsConstraints(minPos=1, maxPos=10) self.plot.getXAxis().setRangeConstraints(minRange=1, maxRange=1) - self.plot.getYAxis().setLimitsConstraints(minPos=1, maxPos=1) + self.plot.getYAxis().setLimitsConstraints(minPos=1, maxPos=10) self.plot.getYAxis().setRangeConstraints(minRange=1, maxRange=1) def testXMinMax(self): diff --git a/silx/gui/plot/test/testPixelIntensityHistoAction.py b/silx/gui/plot/test/testPixelIntensityHistoAction.py new file mode 100644 index 0000000..987e5b2 --- /dev/null +++ b/silx/gui/plot/test/testPixelIntensityHistoAction.py @@ -0,0 +1,104 @@ +# coding: utf-8 +# /*########################################################################## +# +# Copyright (c) 2016-2018 European Synchrotron Radiation Facility +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. +# +# ###########################################################################*/ +"""Basic tests for PixelIntensitiesHistoAction""" + +__authors__ = ["T. Vincent"] +__license__ = "MIT" +__date__ = "02/03/2018" + + +import numpy +import unittest + +from silx.utils.testutils import ParametricTestCase +from silx.gui.test.utils import TestCaseQt, getQToolButtonFromAction +from silx.gui import qt +from silx.gui.plot import Plot2D + + +class TestPixelIntensitiesHisto(TestCaseQt, ParametricTestCase): + """Tests for PixelIntensitiesHistoAction widget.""" + + def setUp(self): + super(TestPixelIntensitiesHisto, self).setUp() + self.image = numpy.random.rand(100, 100) + self.plotImage = Plot2D() + self.plotImage.getIntensityHistogramAction().setVisible(True) + + def tearDown(self): + del self.plotImage + super(TestPixelIntensitiesHisto, self).tearDown() + + def testShowAndHide(self): + """Simple test that the plot is showing and hiding when activating the + action""" + self.plotImage.addImage(self.image, origin=(0, 0), legend='sino') + self.plotImage.show() + + histoAction = self.plotImage.getIntensityHistogramAction() + + # test the pixel intensity diagram is showing + button = getQToolButtonFromAction(histoAction) + self.assertIsNot(button, None) + self.mouseMove(button) + self.mouseClick(button, qt.Qt.LeftButton) + self.qapp.processEvents() + self.assertTrue(histoAction.getHistogramPlotWidget().isVisible()) + + # test the pixel intensity diagram is hiding + self.qapp.setActiveWindow(self.plotImage) + self.qapp.processEvents() + self.mouseMove(button) + self.mouseClick(button, qt.Qt.LeftButton) + self.qapp.processEvents() + self.assertFalse(histoAction.getHistogramPlotWidget().isVisible()) + + def testImageFormatInput(self): + """Test multiple type as image input""" + typesToTest = [numpy.uint8, numpy.int8, numpy.int16, numpy.int32, + numpy.float32, numpy.float64] + self.plotImage.addImage(self.image, origin=(0, 0), legend='sino') + self.plotImage.show() + button = getQToolButtonFromAction( + self.plotImage.getIntensityHistogramAction()) + self.mouseMove(button) + self.mouseClick(button, qt.Qt.LeftButton) + self.qapp.processEvents() + for typeToTest in typesToTest: + with self.subTest(typeToTest=typeToTest): + self.plotImage.addImage(self.image.astype(typeToTest), + origin=(0, 0), legend='sino') + + +def suite(): + test_suite = unittest.TestSuite() + test_suite.addTest( + unittest.defaultTestLoader.loadTestsFromTestCase( + TestPixelIntensitiesHisto)) + return test_suite + + +if __name__ == '__main__': + unittest.main(defaultTest='suite') diff --git a/silx/gui/plot/test/testPlotWidget.py b/silx/gui/plot/test/testPlotWidget.py index 72617e5..dac6580 100644 --- a/silx/gui/plot/test/testPlotWidget.py +++ b/silx/gui/plot/test/testPlotWidget.py @@ -1,7 +1,7 @@ # coding: utf-8 # /*########################################################################## # -# Copyright (c) 2016-2017 European Synchrotron Radiation Facility +# Copyright (c) 2016-2018 European Synchrotron Radiation Facility # # Permission is hereby granted, free of charge, to any person obtaining a copy # of this software and associated documentation files (the "Software"), to deal @@ -26,22 +26,24 @@ __authors__ = ["T. Vincent"] __license__ = "MIT" -__date__ = "26/01/2018" +__date__ = "24/04/2018" import unittest import logging import numpy -from silx.utils.testutils import ParametricTestCase +from silx.utils.testutils import ParametricTestCase, parameterize from silx.gui.test.utils import SignalListener from silx.gui.test.utils import TestCaseQt from silx.utils import testutils from silx.utils import deprecation +from silx.test.utils import test_options + from silx.gui import qt from silx.gui.plot import PlotWidget -from silx.gui.plot.Colormap import Colormap +from silx.gui.colors import Colormap from .utils import PlotWidgetTestCase @@ -188,7 +190,7 @@ class TestPlotImage(PlotWidgetTestCase, ParametricTestCase): self.plot.addImage(rgb, legend="rgb", origin=(0, 0), scale=(10, 10), - replace=False, resetzoom=False) + resetzoom=False) rgba = numpy.array( (((0, 0, 0, .5), (.5, 0, 0, 1), (1, 0, 0, .5)), @@ -197,7 +199,7 @@ class TestPlotImage(PlotWidgetTestCase, ParametricTestCase): self.plot.addImage(rgba, legend="rgba", origin=(5, 5), scale=(10, 10), - replace=False, resetzoom=False) + resetzoom=False) self.plot.resetZoom() @@ -212,7 +214,7 @@ class TestPlotImage(PlotWidgetTestCase, ParametricTestCase): colors=((0., 0., 0.), (1., 0., 0.), (0., 1., 0.), (0., 0., 1.))) self.plot.addImage(DATA_2D, legend="image 1", colormap=colormap, - replace=False, resetzoom=False) + resetzoom=False) colormap = Colormap(name=None, normalization=Colormap.LINEAR, @@ -224,7 +226,7 @@ class TestPlotImage(PlotWidgetTestCase, ParametricTestCase): dtype=numpy.uint8)) self.plot.addImage(DATA_2D, legend="image 2", colormap=colormap, origin=(DATA_2D.shape[0], 0), - replace=False, resetzoom=False) + resetzoom=False) self.plot.resetZoom() def testImageOriginScale(self): @@ -614,13 +616,13 @@ class TestPlotActiveCurveImage(PlotWidgetTestCase): self.plot.getYAxis().setLabel('YLabel') # labels changed as active curve - self.plot.addImage(numpy.arange(100).reshape(10, 10), replace=False, + self.plot.addImage(numpy.arange(100).reshape(10, 10), legend='1', xlabel='x1', ylabel='y1') self.assertEqual(self.plot.getXAxis().getLabel(), 'x1') self.assertEqual(self.plot.getYAxis().getLabel(), 'y1') # labels not changed as not active curve - self.plot.addImage(numpy.arange(100).reshape(10, 10), replace=False, + self.plot.addImage(numpy.arange(100).reshape(10, 10), legend='2') self.assertEqual(self.plot.getXAxis().getLabel(), 'x1') self.assertEqual(self.plot.getYAxis().getLabel(), 'y1') @@ -660,9 +662,13 @@ class TestPlotAxes(TestCaseQt, ParametricTestCase): xData = numpy.arange(1, 10) yData = xData ** 2 + def __init__(self, methodName='runTest', backend=None): + unittest.TestCase.__init__(self, methodName) + self.__backend = backend + def setUp(self): super(TestPlotAxes, self).setUp() - self.plot = PlotWidget() + self.plot = PlotWidget(backend=self.__backend) # It is not needed to display the plot # It saves a lot of time # self.plot.show() @@ -721,7 +727,7 @@ class TestPlotAxes(TestCaseQt, ParametricTestCase): if getter is not None: self.assertEqual(getter(), expected) - @testutils.test_logging(deprecation.depreclog.name, warning=2) + @testutils.test_logging(deprecation.depreclog.name) def testOldPlotAxis_Logarithmic(self): """Test silx API prior to silx 0.6""" x = self.plot.getXAxis() @@ -760,7 +766,7 @@ class TestPlotAxes(TestCaseQt, ParametricTestCase): self.assertEqual(self.plot.isYAxisLogarithmic(), False) self.assertEqual(listener.arguments(callIndex=-1), ("y", False)) - @testutils.test_logging(deprecation.depreclog.name, warning=2) + @testutils.test_logging(deprecation.depreclog.name) def testOldPlotAxis_AutoScale(self): """Test silx API prior to silx 0.6""" x = self.plot.getXAxis() @@ -799,7 +805,7 @@ class TestPlotAxes(TestCaseQt, ParametricTestCase): self.assertEqual(self.plot.isYAxisAutoScale(), True) self.assertEqual(listener.arguments(callIndex=-1), ("y", True)) - @testutils.test_logging(deprecation.depreclog.name, warning=1) + @testutils.test_logging(deprecation.depreclog.name) def testOldPlotAxis_Inverted(self): """Test silx API prior to silx 0.6""" x = self.plot.getXAxis() @@ -1162,7 +1168,7 @@ class TestPlotImageLog(PlotWidgetTestCase): vmax=None) self.plot.addImage(DATA_2D, legend="image 1", origin=(1., 1.), scale=(1., 1.), - replace=False, resetzoom=False, colormap=colormap) + resetzoom=False, colormap=colormap) self.plot.resetZoom() def testPlotColormapGrayLogY(self): @@ -1175,7 +1181,7 @@ class TestPlotImageLog(PlotWidgetTestCase): vmax=None) self.plot.addImage(DATA_2D, legend="image 1", origin=(1., 1.), scale=(1., 1.), - replace=False, resetzoom=False, colormap=colormap) + resetzoom=False, colormap=colormap) self.plot.resetZoom() def testPlotColormapGrayLogXY(self): @@ -1189,7 +1195,7 @@ class TestPlotImageLog(PlotWidgetTestCase): vmax=None) self.plot.addImage(DATA_2D, legend="image 1", origin=(1., 1.), scale=(1., 1.), - replace=False, resetzoom=False, colormap=colormap) + resetzoom=False, colormap=colormap) self.plot.resetZoom() def testPlotRgbRgbaLogXY(self): @@ -1204,7 +1210,7 @@ class TestPlotImageLog(PlotWidgetTestCase): self.plot.addImage(rgb, legend="rgb", origin=(1, 1), scale=(10, 10), - replace=False, resetzoom=False) + resetzoom=False) rgba = numpy.array( (((0, 0, 0, .5), (.5, 0, 0, 1), (1, 0, 0, .5)), @@ -1213,7 +1219,7 @@ class TestPlotImageLog(PlotWidgetTestCase): self.plot.addImage(rgba, legend="rgba", origin=(5., 5.), scale=(10., 10.), - replace=False, resetzoom=False) + resetzoom=False) self.plot.resetZoom() @@ -1355,19 +1361,22 @@ class TestPlotItemLog(PlotWidgetTestCase): def suite(): + testClasses = (TestPlotWidget, TestPlotImage, TestPlotCurve, + TestPlotMarker, TestPlotItem, TestPlotAxes, + TestPlotEmptyLog, TestPlotCurveLog, TestPlotImageLog, + TestPlotMarkerLog, TestPlotItemLog) + test_suite = unittest.TestSuite() - loadTests = unittest.defaultTestLoader.loadTestsFromTestCase - test_suite.addTest(loadTests(TestPlotWidget)) - test_suite.addTest(loadTests(TestPlotImage)) - test_suite.addTest(loadTests(TestPlotCurve)) - test_suite.addTest(loadTests(TestPlotMarker)) - test_suite.addTest(loadTests(TestPlotItem)) - test_suite.addTest(loadTests(TestPlotAxes)) - test_suite.addTest(loadTests(TestPlotEmptyLog)) - test_suite.addTest(loadTests(TestPlotCurveLog)) - test_suite.addTest(loadTests(TestPlotImageLog)) - test_suite.addTest(loadTests(TestPlotMarkerLog)) - test_suite.addTest(loadTests(TestPlotItemLog)) + + # Tests with matplotlib + for testClass in testClasses: + test_suite.addTest(parameterize(testClass, backend=None)) + + if test_options.WITH_GL_TEST: + # Tests with OpenGL backend + for testClass in testClasses: + test_suite.addTest(parameterize(testClass, backend='gl')) + return test_suite diff --git a/silx/gui/plot/test/testPlotWidgetNoBackend.py b/silx/gui/plot/test/testPlotWidgetNoBackend.py index 0d0ddc4..cd7cbb3 100644 --- a/silx/gui/plot/test/testPlotWidgetNoBackend.py +++ b/silx/gui/plot/test/testPlotWidgetNoBackend.py @@ -460,8 +460,8 @@ class TestPlotGetCurveImage(unittest.TestCase): image = plot.getImage() self.assertIsNone(image) - plot.addImage(((0, 1), (2, 3)), legend='image 0', replace=False) - plot.addImage(((0, 1), (2, 3)), legend='image 1', replace=False) + plot.addImage(((0, 1), (2, 3)), legend='image 0') + plot.addImage(((0, 1), (2, 3)), legend='image 1') # Active image active = plot.getActiveImage() @@ -470,7 +470,7 @@ class TestPlotGetCurveImage(unittest.TestCase): self.assertEqual(image.getLegend(), 'image 0') # No active image - plot.addImage(((0, 1), (2, 3)), legend='image 2', replace=False) + plot.addImage(((0, 1), (2, 3)), legend='image 2') plot.setActiveImage(None) active = plot.getActiveImage() self.assertIsNone(active) @@ -496,7 +496,7 @@ class TestPlotGetCurveImage(unittest.TestCase): image = numpy.arange(10).astype(numpy.float32) image.shape = 5, 2 - plot.addImage(image, legend='image 0', info=["Hi!"], replace=False) + plot.addImage(image, legend='image 0', info=["Hi!"]) # Active image data, legend, info, something, params = plot.getActiveImage() @@ -515,8 +515,8 @@ class TestPlotGetCurveImage(unittest.TestCase): # 2 images data = numpy.arange(100).reshape(10, 10) - plot.addImage(data, legend='1', replace=False) - plot.addImage(data, origin=(10, 10), legend='2', replace=False) + plot.addImage(data, legend='1') + plot.addImage(data, origin=(10, 10), legend='2') images = plot.getAllImages(just_legend=True) self.assertEqual(list(images), ['1', '2']) images = plot.getAllImages(just_legend=False) diff --git a/silx/gui/plot/test/testSaveAction.py b/silx/gui/plot/test/testSaveAction.py index 4dfe373..85669bf 100644 --- a/silx/gui/plot/test/testSaveAction.py +++ b/silx/gui/plot/test/testSaveAction.py @@ -1,7 +1,7 @@ # coding: utf-8 # /*########################################################################## # -# Copyright (c) 2017 European Synchrotron Radiation Facility +# Copyright (c) 2017-2018 European Synchrotron Radiation Facility # # Permission is hereby granted, free of charge, to any person obtaining a copy # of this software and associated documentation files (the "Software"), to deal @@ -33,11 +33,13 @@ import unittest import tempfile import os +from silx.gui.plot.test.utils import PlotWidgetTestCase + from silx.gui.plot import PlotWidget from silx.gui.plot.actions.io import SaveAction -class TestSaveAction(unittest.TestCase): +class TestSaveActionSaveCurvesAsSpec(unittest.TestCase): def setUp(self): self.plot = PlotWidget(backend='none') @@ -63,8 +65,9 @@ class TestSaveAction(unittest.TestCase): ylabel="curve2 Y") self.plot.addCurve([3, 1], [7, 6], "curve with no labels") - self.saveAction._saveCurves(self.out_fname, - SaveAction.ALL_CURVES_FILTERS[0]) # "All curves as SpecFile (*.dat)" + self.saveAction._saveCurves(self.plot, + self.out_fname, + SaveAction.DEFAULT_ALL_CURVES_FILTERS[0]) # "All curves as SpecFile (*.dat)" with open(self.out_fname, "rb") as f: file_content = f.read() @@ -86,10 +89,35 @@ class TestSaveAction(unittest.TestCase): self.assertIn("#L graph x label graph y label", file_content) +class TestSaveActionExtension(PlotWidgetTestCase): + """Test SaveAction file filter API""" + + def _dummySaveFunction(self, plot, filename, nameFilter): + pass + + def testFileFilterAPI(self): + """Test addition/update of a file filter""" + saveAction = SaveAction(plot=self.plot, parent=self.plot) + + # Add a new file filter + nameFilter = 'Dummy file (*.dummy)' + saveAction.setFileFilter('all', nameFilter, self._dummySaveFunction) + self.assertTrue(nameFilter in saveAction.getFileFilters('all')) + self.assertEqual(saveAction.getFileFilters('all')[nameFilter], + self._dummySaveFunction) + + # Update an existing file filter + nameFilter = SaveAction.IMAGE_FILTER_EDF + saveAction.setFileFilter('image', nameFilter, self._dummySaveFunction) + self.assertEqual(saveAction.getFileFilters('image')[nameFilter], + self._dummySaveFunction) + + def suite(): test_suite = unittest.TestSuite() - test_suite.addTest( - unittest.defaultTestLoader.loadTestsFromTestCase(TestSaveAction)) + for cls in (TestSaveActionSaveCurvesAsSpec, TestSaveActionExtension): + test_suite.addTest( + unittest.defaultTestLoader.loadTestsFromTestCase(cls)) return test_suite diff --git a/silx/gui/plot/test/testScatterView.py b/silx/gui/plot/test/testScatterView.py new file mode 100644 index 0000000..40fdac6 --- /dev/null +++ b/silx/gui/plot/test/testScatterView.py @@ -0,0 +1,115 @@ +# coding: utf-8 +# /*########################################################################## +# +# Copyright (c) 2018 European Synchrotron Radiation Facility +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. +# +# ###########################################################################*/ +"""Basic tests for ScatterView""" + +__authors__ = ["T. Vincent"] +__license__ = "MIT" +__date__ = "06/03/2018" + + +import unittest + +import numpy + +from silx.gui.plot.items import Axis, Scatter +from silx.gui.plot import ScatterView +from silx.gui.plot.test.utils import PlotWidgetTestCase + + +class TestScatterView(PlotWidgetTestCase): + """Test of ScatterView widget""" + + def _createPlot(self): + return ScatterView() + + def test(self): + """Simple tests""" + x = numpy.arange(100) + y = numpy.arange(100) + value = numpy.arange(100) + self.plot.setData(x, y, value) + self.qapp.processEvents() + + data = self.plot.getData() + self.assertEqual(len(data), 5) + self.assertTrue(numpy.all(numpy.equal(x, data[0]))) + self.assertTrue(numpy.all(numpy.equal(y, data[1]))) + self.assertTrue(numpy.all(numpy.equal(value, data[2]))) + self.assertIsNone(data[3]) # xerror + self.assertIsNone(data[4]) # yerror + + # Test access to scatter item + self.assertIsInstance(self.plot.getScatterItem(), Scatter) + + # Test toolbar actions + + action = self.plot.getScatterToolBar().getXAxisLogarithmicAction() + action.trigger() + self.qapp.processEvents() + + maskAction = self.plot.getScatterToolBar().actions()[-1] + maskAction.trigger() + self.qapp.processEvents() + + # Test proxy API + + self.plot.resetZoom() + self.qapp.processEvents() + + scale = self.plot.getXAxis().getScale() + self.assertEqual(scale, Axis.LOGARITHMIC) + + scale = self.plot.getYAxis().getScale() + self.assertEqual(scale, Axis.LINEAR) + + title = 'Test ScatterView' + self.plot.setGraphTitle(title) + self.assertEqual(self.plot.getGraphTitle(), title) + + self.qapp.processEvents() + + # Reset scatter data + + self.plot.setData(None, None, None) + self.qapp.processEvents() + + data = self.plot.getData() + self.assertEqual(len(data), 5) + self.assertEqual(len(data[0]), 0) # x + self.assertEqual(len(data[1]), 0) # y + self.assertEqual(len(data[2]), 0) # value + self.assertIsNone(data[3]) # xerror + self.assertIsNone(data[4]) # yerror + + +def suite(): + test_suite = unittest.TestSuite() + loadTests = unittest.defaultTestLoader.loadTestsFromTestCase + test_suite.addTest(loadTests(TestScatterView)) + return test_suite + + +if __name__ == '__main__': + unittest.main(defaultTest='suite') diff --git a/silx/gui/plot/test/testStackView.py b/silx/gui/plot/test/testStackView.py index 8d2a0ee..3dcea36 100644 --- a/silx/gui/plot/test/testStackView.py +++ b/silx/gui/plot/test/testStackView.py @@ -1,7 +1,7 @@ # coding: utf-8 # /*########################################################################## # -# Copyright (c) 2016-2017 European Synchrotron Radiation Facility +# Copyright (c) 2016-2018 European Synchrotron Radiation Facility # # Permission is hereby granted, free of charge, to any person obtaining a copy # of this software and associated documentation files (the "Software"), to deal @@ -32,7 +32,7 @@ __date__ = "20/03/2017" import unittest import numpy -from silx.gui.test.utils import TestCaseQt +from silx.gui.test.utils import TestCaseQt, SignalListener from silx.gui import qt from silx.gui.plot import StackView @@ -187,6 +187,17 @@ class TestStackView(TestCaseQt): "beau sirop, mi-sirop, siroté, gagne-sirop, sirop-grelot," " passe-montagne, sirop au bon goût.") + def testStackFrameNumber(self): + self.stackview.setStack(self.mystack) + self.assertEqual(self.stackview.getFrameNumber(), 0) + + listener = SignalListener() + self.stackview.sigFrameChanged.connect(listener) + + self.stackview.setFrameNumber(1) + self.assertEqual(self.stackview.getFrameNumber(), 1) + self.assertEqual(listener.arguments(), [(1,)]) + class TestStackViewMainWindow(TestCaseQt): """Base class for tests of StackView.""" diff --git a/silx/gui/plot/test/testStats.py b/silx/gui/plot/test/testStats.py new file mode 100644 index 0000000..123eb89 --- /dev/null +++ b/silx/gui/plot/test/testStats.py @@ -0,0 +1,561 @@ +# coding: utf-8 +# /*########################################################################## +# +# Copyright (c) 2016-2018 European Synchrotron Radiation Facility +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. +# +# ###########################################################################*/ +"""Basic tests for CurvesROIWidget""" + +__authors__ = ["H. Payno"] +__license__ = "MIT" +__date__ = "07/03/2018" + + +from silx.gui import qt +from silx.gui.plot.stats import stats +from silx.gui.plot import StatsWidget +from silx.gui.plot.stats import statshandler +from silx.gui.test.utils import TestCaseQt +from silx.gui.plot import Plot1D, Plot2D +import unittest +import logging +import numpy + +_logger = logging.getLogger(__name__) + + +class TestStats(TestCaseQt): + """ + Test :class:`BaseClass` class and inheriting classes + """ + def setUp(self): + TestCaseQt.setUp(self) + self.createCurveContext() + self.createImageContext() + self.createScatterContext() + + def tearDown(self): + self.plot1d.setAttribute(qt.Qt.WA_DeleteOnClose) + self.plot1d.close() + self.plot2d.setAttribute(qt.Qt.WA_DeleteOnClose) + self.plot2d.close() + self.scatterPlot.setAttribute(qt.Qt.WA_DeleteOnClose) + self.scatterPlot.close() + + def createCurveContext(self): + self.plot1d = Plot1D() + x = range(20) + y = range(20) + self.plot1d.addCurve(x, y, legend='curve0') + + self.curveContext = stats._CurveContext( + item=self.plot1d.getCurve('curve0'), + plot=self.plot1d, + onlimits=False) + + def createScatterContext(self): + self.scatterPlot = Plot2D() + lgd = 'scatter plot' + self.xScatterData = numpy.array([0, 1, 2, 20, 50, 60, 36]) + self.yScatterData = numpy.array([2, 3, 4, 26, 69, 6, 18]) + self.valuesScatterData = numpy.array([5, 6, 7, 10, 90, 20, 5]) + self.scatterPlot.addScatter(self.xScatterData, self.yScatterData, + self.valuesScatterData, legend=lgd) + self.scatterContext = stats._ScatterContext( + item=self.scatterPlot.getScatter(lgd), + plot=self.scatterPlot, + onlimits=False + ) + + def createImageContext(self): + self.plot2d = Plot2D() + self._imgLgd = 'test image' + self.imageData = numpy.arange(32*128).reshape(32, 128) + self.plot2d.addImage(data=self.imageData, + legend=self._imgLgd, replace=False) + self.imageContext = stats._ImageContext( + item=self.plot2d.getImage(self._imgLgd), + plot=self.plot2d, + onlimits=False + ) + + def getBasicStats(self): + return { + 'min': stats.StatMin(), + 'minCoords': stats.StatCoordMin(), + 'max': stats.StatMax(), + 'maxCoords': stats.StatCoordMax(), + 'std': stats.Stat(name='std', fct=numpy.std), + 'mean': stats.Stat(name='mean', fct=numpy.mean), + 'com': stats.StatCOM() + } + + def testBasicStatsCurve(self): + """Test result for simple stats on a curve""" + _stats = self.getBasicStats() + xData = yData = numpy.array(range(20)) + self.assertTrue(_stats['min'].calculate(self.curveContext) == 0) + self.assertTrue(_stats['max'].calculate(self.curveContext) == 19) + self.assertTrue(_stats['minCoords'].calculate(self.curveContext) == [0]) + self.assertTrue(_stats['maxCoords'].calculate(self.curveContext) == [19]) + self.assertTrue(_stats['std'].calculate(self.curveContext) == numpy.std(yData)) + self.assertTrue(_stats['mean'].calculate(self.curveContext) == numpy.mean(yData)) + com = numpy.sum(xData * yData) / numpy.sum(yData) + self.assertTrue(_stats['com'].calculate(self.curveContext) == com) + + def testBasicStatsImage(self): + """Test result for simple stats on an image""" + _stats = self.getBasicStats() + self.assertTrue(_stats['min'].calculate(self.imageContext) == 0) + self.assertTrue(_stats['max'].calculate(self.imageContext) == 128 * 32 - 1) + self.assertTrue(_stats['minCoords'].calculate(self.imageContext) == (0, 0)) + self.assertTrue(_stats['maxCoords'].calculate(self.imageContext) == (127, 31)) + self.assertTrue(_stats['std'].calculate(self.imageContext) == numpy.std(self.imageData)) + self.assertTrue(_stats['mean'].calculate(self.imageContext) == numpy.mean(self.imageData)) + + yData = numpy.sum(self.imageData, axis=1) + xData = numpy.sum(self.imageData, axis=0) + dataXRange = range(self.imageData.shape[1]) + dataYRange = range(self.imageData.shape[0]) + + ycom = numpy.sum(yData*dataYRange) / numpy.sum(yData) + xcom = numpy.sum(xData*dataXRange) / numpy.sum(xData) + + self.assertTrue(_stats['com'].calculate(self.imageContext) == (xcom, ycom)) + + def testStatsImageAdv(self): + """Test that scale and origin are taking into account for images""" + + image2Data = numpy.arange(32 * 128).reshape(32, 128) + self.plot2d.addImage(data=image2Data, legend=self._imgLgd, + replace=True, origin=(100, 10), scale=(2, 0.5)) + image2Context = stats._ImageContext( + item=self.plot2d.getImage(self._imgLgd), + plot=self.plot2d, + onlimits=False + ) + _stats = self.getBasicStats() + self.assertTrue(_stats['min'].calculate(image2Context) == 0) + self.assertTrue( + _stats['max'].calculate(image2Context) == 128 * 32 - 1) + self.assertTrue( + _stats['minCoords'].calculate(image2Context) == (100, 10)) + self.assertTrue( + _stats['maxCoords'].calculate(image2Context) == (127*2. + 100, + 31 * 0.5 + 10) + ) + self.assertTrue( + _stats['std'].calculate(image2Context) == numpy.std( + self.imageData)) + self.assertTrue( + _stats['mean'].calculate(image2Context) == numpy.mean( + self.imageData)) + + yData = numpy.sum(self.imageData, axis=1) + xData = numpy.sum(self.imageData, axis=0) + dataXRange = range(self.imageData.shape[1]) + dataYRange = range(self.imageData.shape[0]) + + ycom = numpy.sum(yData * dataYRange) / numpy.sum(yData) + ycom = (ycom * 0.5) + 10 + xcom = numpy.sum(xData * dataXRange) / numpy.sum(xData) + xcom = (xcom * 2.) + 100 + self.assertTrue( + _stats['com'].calculate(image2Context) == (xcom, ycom)) + + def testBasicStatsScatter(self): + """Test result for simple stats on a scatter""" + _stats = self.getBasicStats() + self.assertTrue(_stats['min'].calculate(self.scatterContext) == 5) + self.assertTrue(_stats['max'].calculate(self.scatterContext) == 90) + self.assertTrue(_stats['minCoords'].calculate(self.scatterContext) == (0, 2)) + self.assertTrue(_stats['maxCoords'].calculate(self.scatterContext) == (50, 69)) + self.assertTrue(_stats['std'].calculate(self.scatterContext) == numpy.std(self.valuesScatterData)) + self.assertTrue(_stats['mean'].calculate(self.scatterContext) == numpy.mean(self.valuesScatterData)) + + comx = numpy.sum(self.xScatterData * self.valuesScatterData).astype(numpy.float32) / numpy.sum( + self.valuesScatterData).astype(numpy.float32) + comy = numpy.sum(self.yScatterData * self.valuesScatterData).astype(numpy.float32) / numpy.sum( + self.valuesScatterData).astype(numpy.float32) + self.assertTrue(numpy.all( + numpy.equal(_stats['com'].calculate(self.scatterContext), + (comx, comy))) + ) + + def testKindNotManagedByStat(self): + """Make sure an exception is raised if we try to execute calculate + of the base class""" + b = stats.StatBase(name='toto', compatibleKinds='curve') + with self.assertRaises(NotImplementedError): + b.calculate(self.imageContext) + + def testKindNotManagedByContext(self): + """ + Make sure an error is raised if we try to calculate a statistic with + a context not managed + """ + myStat = stats.Stat(name='toto', fct=numpy.std, kinds=('curve')) + myStat.calculate(self.curveContext) + with self.assertRaises(ValueError): + myStat.calculate(self.scatterContext) + with self.assertRaises(ValueError): + myStat.calculate(self.imageContext) + + def testOnLimits(self): + stat = stats.StatMin() + + self.plot1d.getXAxis().setLimitsConstraints(minPos=2, maxPos=5) + curveContextOnLimits = stats._CurveContext( + item=self.plot1d.getCurve('curve0'), + plot=self.plot1d, + onlimits=True) + self.assertTrue(stat.calculate(curveContextOnLimits) == 2) + + self.plot2d.getXAxis().setLimitsConstraints(minPos=32) + imageContextOnLimits = stats._ImageContext( + item=self.plot2d.getImage('test image'), + plot=self.plot2d, + onlimits=True) + self.assertTrue(stat.calculate(imageContextOnLimits) == 32) + + self.scatterPlot.getXAxis().setLimitsConstraints(minPos=40) + scatterContextOnLimits = stats._ScatterContext( + item=self.scatterPlot.getScatter('scatter plot'), + plot=self.scatterPlot, + onlimits=True) + self.assertTrue(stat.calculate(scatterContextOnLimits) == 20) + + +class TestStatsFormatter(TestCaseQt): + """Simple test to check usage of the :class:`StatsFormatter`""" + def setUp(self): + self.plot1d = Plot1D() + x = range(20) + y = range(20) + self.plot1d.addCurve(x, y, legend='curve0') + + self.curveContext = stats._CurveContext( + item=self.plot1d.getCurve('curve0'), + plot=self.plot1d, + onlimits=False) + + self.stat = stats.StatMin() + + def tearDown(self): + self.plot1d.setAttribute(qt.Qt.WA_DeleteOnClose) + self.plot1d.close() + + def testEmptyFormatter(self): + """Make sure a formatter with no formatter definition will return a + simple cast to str""" + emptyFormatter = statshandler.StatFormatter() + self.assertTrue( + emptyFormatter.format(self.stat.calculate(self.curveContext)) == '0.000') + + def testSettedFormatter(self): + """Make sure a formatter with no formatter definition will return a + simple cast to str""" + formatter= statshandler.StatFormatter(formatter='{0:.3f}') + self.assertTrue( + formatter.format(self.stat.calculate(self.curveContext)) == '0.000') + + +class TestStatsHandler(unittest.TestCase): + """Make sure the StatHandler is correctly making the link between + :class:`StatBase` and :class:`StatFormatter` and checking the API is valid + """ + def setUp(self): + self.plot1d = Plot1D() + x = range(20) + y = range(20) + self.plot1d.addCurve(x, y, legend='curve0') + self.curveItem = self.plot1d.getCurve('curve0') + + self.stat = stats.StatMin() + + def tearDown(self): + self.plot1d.setAttribute(qt.Qt.WA_DeleteOnClose) + self.plot1d.close() + + def testConstructor(self): + """Make sure the constructor can deal will all possible arguments: + + * tuple of :class:`StatBase` derivated classes + * tuple of tuples (:class:`StatBase`, :class:`StatFormatter`) + * tuple of tuples (str, pointer to function, kind) + """ + handler0 = statshandler.StatsHandler( + (stats.StatMin(), stats.StatMax()) + ) + + res = handler0.calculate(item=self.curveItem, plot=self.plot1d, + onlimits=False) + self.assertTrue('min' in res) + self.assertTrue(res['min'] == '0') + self.assertTrue('max' in res) + self.assertTrue(res['max'] == '19') + + handler1 = statshandler.StatsHandler( + ( + (stats.StatMin(), statshandler.StatFormatter(formatter=None)), + (stats.StatMax(), statshandler.StatFormatter()) + ) + ) + + res = handler1.calculate(item=self.curveItem, plot=self.plot1d, + onlimits=False) + self.assertTrue('min' in res) + self.assertTrue(res['min'] == '0') + self.assertTrue('max' in res) + self.assertTrue(res['max'] == '19.000') + + handler2 = statshandler.StatsHandler( + ( + (stats.StatMin(), None), + (stats.StatMax(), statshandler.StatFormatter()) + )) + + res = handler2.calculate(item=self.curveItem, plot=self.plot1d, + onlimits=False) + self.assertTrue('min' in res) + self.assertTrue(res['min'] == '0') + self.assertTrue('max' in res) + self.assertTrue(res['max'] == '19.000') + + handler3 = statshandler.StatsHandler(( + (('amin', numpy.argmin), statshandler.StatFormatter()), + ('amax', numpy.argmax) + )) + + res = handler3.calculate(item=self.curveItem, plot=self.plot1d, + onlimits=False) + self.assertTrue('amin' in res) + self.assertTrue(res['amin'] == '0.000') + self.assertTrue('amax' in res) + self.assertTrue(res['amax'] == '19') + + with self.assertRaises(ValueError): + statshandler.StatsHandler(('name')) + + +class TestStatsWidgetWithCurves(TestCaseQt): + """Basic test for StatsWidget with curves""" + def setUp(self): + TestCaseQt.setUp(self) + self.plot = Plot1D() + x = range(20) + y = range(20) + self.plot.addCurve(x, y, legend='curve0') + y = range(12, 32) + self.plot.addCurve(x, y, legend='curve1') + y = range(-2, 18) + self.plot.addCurve(x, y, legend='curve2') + self.widget = StatsWidget.StatsTable(plot=self.plot) + + mystats = statshandler.StatsHandler(( + stats.StatMin(), + (stats.StatCoordMin(), statshandler.StatFormatter(None, qt.QTableWidgetItem)), + stats.StatMax(), + (stats.StatCoordMax(), statshandler.StatFormatter(None, qt.QTableWidgetItem)), + stats.StatDelta(), + ('std', numpy.std), + ('mean', numpy.mean), + stats.StatCOM() + )) + + self.widget.setStats(mystats) + + def tearDown(self): + self.plot.setAttribute(qt.Qt.WA_DeleteOnClose) + self.plot.close() + self.widget.setAttribute(qt.Qt.WA_DeleteOnClose) + self.widget.close() + self.widget = None + self.plot = None + TestCaseQt.tearDown(self) + + def testInit(self): + """Make sure all the curves are registred on initialization""" + self.assertTrue(self.widget.rowCount() is 3) + + def testRemoveCurve(self): + """Make sure the Curves stats take into account the curve removal from + plot""" + self.plot.removeCurve('curve2') + self.assertTrue(self.widget.rowCount() is 2) + for iRow in range(2): + self.assertTrue(self.widget.item(iRow, 0).text() in ('curve0', 'curve1')) + + self.plot.removeCurve('curve0') + self.assertTrue(self.widget.rowCount() is 1) + self.plot.removeCurve('curve1') + self.assertTrue(self.widget.rowCount() is 0) + + def testAddCurve(self): + """Make sure the Curves stats take into account the add curve action""" + self.plot.addCurve(legend='curve3', x=range(10), y=range(10)) + self.assertTrue(self.widget.rowCount() is 4) + + def testUpdateCurveFrmAddCurve(self): + """Make sure the stats of the cuve will be removed after updating a + curve""" + self.plot.addCurve(legend='curve0', x=range(10), y=range(10)) + self.assertTrue(self.widget.rowCount() is 3) + itemMax = self.widget._getItem(name='max', legend='curve0', + kind='curve', indexTable=None) + self.assertTrue(itemMax.text() == '9') + + def testUpdateCurveFrmCurveObj(self): + self.plot.getCurve('curve0').setData(x=range(4), y=range(4)) + self.assertTrue(self.widget.rowCount() is 3) + itemMax = self.widget._getItem(name='max', legend='curve0', + kind='curve', indexTable=None) + self.assertTrue(itemMax.text() == '3') + + def testSetAnotherPlot(self): + plot2 = Plot1D() + plot2.addCurve(x=range(26), y=range(26), legend='new curve') + self.widget.setPlot(plot2) + self.assertTrue(self.widget.rowCount() is 1) + self.qapp.processEvents() + plot2.setAttribute(qt.Qt.WA_DeleteOnClose) + plot2.close() + plot2 = None + + +class TestStatsWidgetWithImages(TestCaseQt): + """Basic test for StatsWidget with images""" + def setUp(self): + TestCaseQt.setUp(self) + self.plot = Plot2D() + + self.plot.addImage(data=numpy.arange(128*128).reshape(128, 128), + legend='test image', replace=False) + + self.widget = StatsWidget.StatsTable(plot=self.plot) + + mystats = statshandler.StatsHandler(( + (stats.StatMin(), statshandler.StatFormatter()), + (stats.StatCoordMin(), statshandler.StatFormatter(None, qt.QTableWidgetItem)), + (stats.StatMax(), statshandler.StatFormatter()), + (stats.StatCoordMax(), statshandler.StatFormatter(None, qt.QTableWidgetItem)), + (stats.StatDelta(), statshandler.StatFormatter()), + ('std', numpy.std), + ('mean', numpy.mean), + (stats.StatCOM(), statshandler.StatFormatter(None)) + )) + + self.widget.setStats(mystats) + + def tearDown(self): + self.plot.setAttribute(qt.Qt.WA_DeleteOnClose) + self.plot.close() + self.widget.setAttribute(qt.Qt.WA_DeleteOnClose) + self.widget.close() + self.widget = None + self.plot = None + TestCaseQt.tearDown(self) + + def test(self): + columnsIndex = self.widget._columns_index + itemLegend = self.widget._lgdAndKindToItems[('test image', 'image')]['legend'] + itemMin = self.widget.item(itemLegend.row(), columnsIndex['min']) + itemMax = self.widget.item(itemLegend.row(), columnsIndex['max']) + itemDelta = self.widget.item(itemLegend.row(), columnsIndex['delta']) + itemCoordsMin = self.widget.item(itemLegend.row(), + columnsIndex['coords min']) + itemCoordsMax = self.widget.item(itemLegend.row(), + columnsIndex['coords max']) + max = (128 * 128) - 1 + self.assertTrue(itemMin.text() == '0.000') + self.assertTrue(itemMax.text() == '{0:.3f}'.format(max)) + self.assertTrue(itemDelta.text() == '{0:.3f}'.format(max)) + self.assertTrue(itemCoordsMin.text() == '0.0, 0.0') + self.assertTrue(itemCoordsMax.text() == '127.0, 127.0') + + +class TestStatsWidgetWithScatters(TestCaseQt): + def setUp(self): + TestCaseQt.setUp(self) + self.scatterPlot = Plot2D() + self.scatterPlot.addScatter([0, 1, 2, 20, 50, 60], + [2, 3, 4, 26, 69, 6], + [5, 6, 7, 10, 90, 20], + legend='scatter plot') + self.widget = StatsWidget.StatsTable(plot=self.scatterPlot) + + mystats = statshandler.StatsHandler(( + stats.StatMin(), + (stats.StatCoordMin(), statshandler.StatFormatter(None, qt.QTableWidgetItem)), + stats.StatMax(), + (stats.StatCoordMax(), statshandler.StatFormatter(None, qt.QTableWidgetItem)), + stats.StatDelta(), + ('std', numpy.std), + ('mean', numpy.mean), + stats.StatCOM() + )) + + self.widget.setStats(mystats) + + def tearDown(self): + self.scatterPlot.setAttribute(qt.Qt.WA_DeleteOnClose) + self.scatterPlot.close() + self.widget.setAttribute(qt.Qt.WA_DeleteOnClose) + self.widget.close() + self.widget = None + self.scatterPlot = None + TestCaseQt.tearDown(self) + + def testStats(self): + columnsIndex = self.widget._columns_index + itemLegend = self.widget._lgdAndKindToItems[('scatter plot', 'scatter')]['legend'] + itemMin = self.widget.item(itemLegend.row(), columnsIndex['min']) + itemMax = self.widget.item(itemLegend.row(), columnsIndex['max']) + itemDelta = self.widget.item(itemLegend.row(), columnsIndex['delta']) + itemCoordsMin = self.widget.item(itemLegend.row(), + columnsIndex['coords min']) + itemCoordsMax = self.widget.item(itemLegend.row(), + columnsIndex['coords max']) + self.assertTrue(itemMin.text() == '5') + self.assertTrue(itemMax.text() == '90') + self.assertTrue(itemDelta.text() == '85') + self.assertTrue(itemCoordsMin.text() == '0, 2') + self.assertTrue(itemCoordsMax.text() == '50, 69') + + +class TestEmptyStatsWidget(TestCaseQt): + def test(self): + widget = StatsWidget.StatsWidget() + widget.show() + + +def suite(): + test_suite = unittest.TestSuite() + for TestClass in (TestStats, TestStatsHandler, TestStatsWidgetWithScatters, + TestStatsWidgetWithImages, TestStatsWidgetWithCurves, + TestStatsFormatter, TestEmptyStatsWidget): + test_suite.addTest( + unittest.defaultTestLoader.loadTestsFromTestCase(TestClass)) + return test_suite + + +if __name__ == '__main__': + unittest.main(defaultTest='suite') diff --git a/silx/gui/plot/test/utils.py b/silx/gui/plot/test/utils.py index ec9bc7c..efba39c 100644 --- a/silx/gui/plot/test/utils.py +++ b/silx/gui/plot/test/utils.py @@ -1,7 +1,7 @@ # coding: utf-8 # /*########################################################################## # -# Copyright (c) 2016 European Synchrotron Radiation Facility +# Copyright (c) 2016-2018 European Synchrotron Radiation Facility # # Permission is hereby granted, free of charge, to any person obtaining a copy # of this software and associated documentation files (the "Software"), to deal @@ -48,11 +48,12 @@ class PlotWidgetTestCase(TestCaseQt): __screenshot_already_taken = False - def __init__(self, methodName='runTest'): + def __init__(self, methodName='runTest', backend=None): TestCaseQt.__init__(self, methodName=methodName) + self.__backend = backend def _createPlot(self): - return PlotWidget() + return PlotWidget(backend=self.__backend) def setUp(self): super(PlotWidgetTestCase, self).setUp() diff --git a/silx/gui/plot/tools/LimitsToolBar.py b/silx/gui/plot/tools/LimitsToolBar.py new file mode 100644 index 0000000..fc192a6 --- /dev/null +++ b/silx/gui/plot/tools/LimitsToolBar.py @@ -0,0 +1,131 @@ +# coding: utf-8 +# /*########################################################################## +# +# Copyright (c) 2016-2018 European Synchrotron Radiation Facility +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. +# +# ###########################################################################*/ +"""A toolbar to display and edit limits of a PlotWidget +""" + + +from __future__ import division + +__authors__ = ["V.A. Sole", "T. Vincent"] +__license__ = "MIT" +__date__ = "16/10/2017" + + +from ... import qt +from ...widgets.FloatEdit import FloatEdit + + +class LimitsToolBar(qt.QToolBar): + """QToolBar displaying and controlling the limits of a :class:`PlotWidget`. + + To run the following sample code, a QApplication must be initialized. + First, create a PlotWindow: + + >>> from silx.gui.plot import PlotWindow + >>> plot = PlotWindow() # Create a PlotWindow to add the toolbar to + + Then, create the LimitsToolBar and add it to the PlotWindow. + + >>> from silx.gui import qt + >>> from silx.gui.plot.tools import LimitsToolBar + + >>> toolbar = LimitsToolBar(plot=plot) # Create the toolbar + >>> plot.addToolBar(qt.Qt.BottomToolBarArea, toolbar) # Add it to the plot + >>> plot.show() # To display the PlotWindow with the limits toolbar + + :param parent: See :class:`QToolBar`. + :param plot: :class:`PlotWidget` instance on which to operate. + :param str title: See :class:`QToolBar`. + """ + + def __init__(self, parent=None, plot=None, title='Limits'): + super(LimitsToolBar, self).__init__(title, parent) + assert plot is not None + self._plot = plot + self._plot.sigPlotSignal.connect(self._plotWidgetSlot) + + self._initWidgets() + + @property + def plot(self): + """The :class:`PlotWidget` the toolbar is attached to.""" + return self._plot + + def _initWidgets(self): + """Create and init Toolbar widgets.""" + xMin, xMax = self.plot.getXAxis().getLimits() + yMin, yMax = self.plot.getYAxis().getLimits() + + self.addWidget(qt.QLabel('Limits: ')) + self.addWidget(qt.QLabel(' X: ')) + self._xMinFloatEdit = FloatEdit(self, xMin) + self._xMinFloatEdit.editingFinished[()].connect( + self._xFloatEditChanged) + self.addWidget(self._xMinFloatEdit) + + self._xMaxFloatEdit = FloatEdit(self, xMax) + self._xMaxFloatEdit.editingFinished[()].connect( + self._xFloatEditChanged) + self.addWidget(self._xMaxFloatEdit) + + self.addWidget(qt.QLabel(' Y: ')) + self._yMinFloatEdit = FloatEdit(self, yMin) + self._yMinFloatEdit.editingFinished[()].connect( + self._yFloatEditChanged) + self.addWidget(self._yMinFloatEdit) + + self._yMaxFloatEdit = FloatEdit(self, yMax) + self._yMaxFloatEdit.editingFinished[()].connect( + self._yFloatEditChanged) + self.addWidget(self._yMaxFloatEdit) + + def _plotWidgetSlot(self, event): + """Listen to :class:`PlotWidget` events.""" + if event['event'] not in ('limitsChanged',): + return + + xMin, xMax = self.plot.getXAxis().getLimits() + yMin, yMax = self.plot.getYAxis().getLimits() + + self._xMinFloatEdit.setValue(xMin) + self._xMaxFloatEdit.setValue(xMax) + self._yMinFloatEdit.setValue(yMin) + self._yMaxFloatEdit.setValue(yMax) + + def _xFloatEditChanged(self): + """Handle X limits changed from the GUI.""" + xMin, xMax = self._xMinFloatEdit.value(), self._xMaxFloatEdit.value() + if xMax < xMin: + xMin, xMax = xMax, xMin + + self.plot.getXAxis().setLimits(xMin, xMax) + + def _yFloatEditChanged(self): + """Handle Y limits changed from the GUI.""" + yMin, yMax = self._yMinFloatEdit.value(), self._yMaxFloatEdit.value() + if yMax < yMin: + yMin, yMax = yMax, yMin + + self.plot.getYAxis().setLimits(yMin, yMax) diff --git a/silx/gui/plot/tools/PositionInfo.py b/silx/gui/plot/tools/PositionInfo.py new file mode 100644 index 0000000..83b61bd --- /dev/null +++ b/silx/gui/plot/tools/PositionInfo.py @@ -0,0 +1,347 @@ +# coding: utf-8 +# /*########################################################################## +# +# Copyright (c) 2016-2018 European Synchrotron Radiation Facility +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. +# +# ###########################################################################*/ +"""This module provides a widget displaying mouse coordinates in a PlotWidget. + +It can be configured to provide more information. +""" + +from __future__ import division + +__authors__ = ["V.A. Sole", "T. Vincent"] +__license__ = "MIT" +__date__ = "16/10/2017" + + +import logging +import numbers +import traceback +import weakref + +import numpy + +from ....utils.deprecation import deprecated +from ... import qt +from .. import items + + +_logger = logging.getLogger(__name__) + + +# PositionInfo ################################################################ + +class PositionInfo(qt.QWidget): + """QWidget displaying coords converted from data coords of the mouse. + + Provide this widget with a list of couple: + + - A name to display before the data + - A function that takes (x, y) as arguments and returns something that + gets converted to a string. + If the result is a float it is converted with '%.7g' format. + + To run the following sample code, a QApplication must be initialized. + First, create a PlotWindow and add a QToolBar where to place the + PositionInfo widget. + + >>> from silx.gui.plot import PlotWindow + >>> from silx.gui import qt + + >>> plot = PlotWindow() # Create a PlotWindow to add the widget to + >>> toolBar = qt.QToolBar() # Create a toolbar to place the widget in + >>> plot.addToolBar(qt.Qt.BottomToolBarArea, toolBar) # Add it to plot + + Then, create the PositionInfo widget and add it to the toolbar. + The PositionInfo widget is created with a list of converters, here + to display polar coordinates of the mouse position. + + >>> import numpy + >>> from silx.gui.plot.tools import PositionInfo + + >>> position = PositionInfo(plot=plot, converters=[ + ... ('Radius', lambda x, y: numpy.sqrt(x*x + y*y)), + ... ('Angle', lambda x, y: numpy.degrees(numpy.arctan2(y, x)))]) + >>> toolBar.addWidget(position) # Add the widget to the toolbar + <...> + >>> plot.show() # To display the PlotWindow with the position widget + + :param plot: The PlotWidget this widget is displaying data coords from. + :param converters: + List of 2-tuple: name to display and conversion function from (x, y) + in data coords to displayed value. + If None, the default, it displays X and Y. + :param parent: Parent widget + """ + + SNAP_THRESHOLD_DIST = 5 + + def __init__(self, parent=None, plot=None, converters=None): + assert plot is not None + self._plotRef = weakref.ref(plot) + self._snappingMode = self.SNAPPING_DISABLED + + super(PositionInfo, self).__init__(parent) + + if converters is None: + converters = (('X', lambda x, y: x), ('Y', lambda x, y: y)) + + self._fields = [] # To store (QLineEdit, name, function (x, y)->v) + + # Create a new layout with new widgets + layout = qt.QHBoxLayout() + layout.setContentsMargins(0, 0, 0, 0) + # layout.setSpacing(0) + + # Create all QLabel and store them with the corresponding converter + for name, func in converters: + layout.addWidget(qt.QLabel('<b>' + name + ':</b>')) + + contentWidget = qt.QLabel() + contentWidget.setText('------') + contentWidget.setTextInteractionFlags(qt.Qt.TextSelectableByMouse) + contentWidget.setFixedWidth( + contentWidget.fontMetrics().width('##############')) + layout.addWidget(contentWidget) + self._fields.append((contentWidget, name, func)) + + layout.addStretch(1) + self.setLayout(layout) + + # Connect to Plot events + plot.sigPlotSignal.connect(self._plotEvent) + + def getPlotWidget(self): + """Returns the PlotWidget this widget is attached to or None. + + :rtype: Union[~silx.gui.plot.PlotWidget,None] + """ + return self._plotRef() + + @property + @deprecated(replacement='getPlotWidget', since_version='0.8.0') + def plot(self): + return self.getPlotWidget() + + def getConverters(self): + """Return the list of converters as 2-tuple (name, function).""" + return [(name, func) for _label, name, func in self._fields] + + def _plotEvent(self, event): + """Handle events from the Plot. + + :param dict event: Plot event + """ + if event['event'] == 'mouseMoved': + x, y = event['x'], event['y'] + xPixel, yPixel = event['xpixel'], event['ypixel'] + self._updateStatusBar(x, y, xPixel, yPixel) + + def updateInfo(self): + """Update displayed information""" + plot = self.getPlotWidget() + if plot is None: + _logger.error("Trying to update PositionInfo " + "while PlotWidget no longer exists") + return + + widget = plot.getWidgetHandle() + position = widget.mapFromGlobal(qt.QCursor.pos()) + xPixel, yPixel = position.x(), position.y() + dataPos = plot.pixelToData(xPixel, yPixel, check=True) + if dataPos is not None: # Inside plot area + x, y = dataPos + self._updateStatusBar(x, y, xPixel, yPixel) + + def _updateStatusBar(self, x, y, xPixel, yPixel): + """Update information from the status bar using the definitions. + + :param float x: Position-x in data + :param float y: Position-y in data + :param float xPixel: Position-x in pixels + :param float yPixel: Position-y in pixels + """ + plot = self.getPlotWidget() + if plot is None: + return + + styleSheet = "color: rgb(0, 0, 0);" # Default style + xData, yData = x, y + + snappingMode = self.getSnappingMode() + + # Snapping when crosshair either not requested or active + if (snappingMode & (self.SNAPPING_CURVE | self.SNAPPING_SCATTER) and + (not (snappingMode & self.SNAPPING_CROSSHAIR) or + plot.getGraphCursor())): + styleSheet = "color: rgb(255, 0, 0);" # Style far from item + + if snappingMode & self.SNAPPING_ACTIVE_ONLY: + selectedItems = [] + + if snappingMode & self.SNAPPING_CURVE: + activeCurve = plot.getActiveCurve() + if activeCurve: + selectedItems.append(activeCurve) + + if snappingMode & self.SNAPPING_SCATTER: + activeScatter = plot._getActiveItem(kind='scatter') + if activeScatter: + selectedItems.append(activeScatter) + + else: + kinds = [] + if snappingMode & self.SNAPPING_CURVE: + kinds.append('curve') + if snappingMode & self.SNAPPING_SCATTER: + kinds.append('scatter') + selectedItems = plot._getItems(kind=kinds) + + # Compute distance threshold + if qt.BINDING in ('PyQt5', 'PySide2'): + window = plot.window() + windowHandle = window.windowHandle() + if windowHandle is not None: + ratio = windowHandle.devicePixelRatio() + else: + ratio = qt.QGuiApplication.primaryScreen().devicePixelRatio() + else: + ratio = 1. + + # Baseline squared distance threshold + distInPixels = (self.SNAP_THRESHOLD_DIST * ratio)**2 + + for item in selectedItems: + if (snappingMode & self.SNAPPING_SYMBOLS_ONLY and + not item.getSymbol()): + # Only handled if item symbols are visible + continue + + xArray = item.getXData(copy=False) + yArray = item.getYData(copy=False) + closestIndex = numpy.argmin( + pow(xArray - x, 2) + pow(yArray - y, 2)) + + xClosest = xArray[closestIndex] + yClosest = yArray[closestIndex] + + if isinstance(item, items.YAxisMixIn): + axis = item.getYAxis() + else: + axis = 'left' + + closestInPixels = plot.dataToPixel( + xClosest, yClosest, axis=axis) + if closestInPixels is not None: + curveDistInPixels = ( + (closestInPixels[0] - xPixel)**2 + + (closestInPixels[1] - yPixel)**2) + + if curveDistInPixels <= distInPixels: + # Update label style sheet + styleSheet = "color: rgb(0, 0, 0);" + + # if close enough, snap to data point coord + xData, yData = xClosest, yClosest + distInPixels = curveDistInPixels + + for label, name, func in self._fields: + label.setStyleSheet(styleSheet) + + try: + value = func(xData, yData) + text = self.valueToString(value) + label.setText(text) + except: + label.setText('Error') + _logger.error( + "Error while converting coordinates (%f, %f)" + "with converter '%s'" % (xPixel, yPixel, name)) + _logger.error(traceback.format_exc()) + + def valueToString(self, value): + if isinstance(value, (tuple, list)): + value = [self.valueToString(v) for v in value] + return ", ".join(value) + elif isinstance(value, numbers.Real): + # Use this for floats and int + return '%.7g' % value + else: + # Fallback for other types + return str(value) + + # Snapping mode + + SNAPPING_DISABLED = 0 + """No snapping occurs""" + + SNAPPING_CROSSHAIR = 1 << 0 + """Snapping only enabled when crosshair cursor is enabled""" + + SNAPPING_ACTIVE_ONLY = 1 << 1 + """Snapping only enabled for active item""" + + SNAPPING_SYMBOLS_ONLY = 1 << 2 + """Snapping only when symbols are visible""" + + SNAPPING_CURVE = 1 << 3 + """Snapping on curves""" + + SNAPPING_SCATTER = 1 << 4 + """Snapping on scatter""" + + def setSnappingMode(self, mode): + """Set the snapping mode. + + The mode is a mask. + + :param int mode: The mode to use + """ + if mode != self._snappingMode: + self._snappingMode = mode + self.updateInfo() + + def getSnappingMode(self): + """Returns the snapping mode as a mask + + :rtype: int + """ + return self._snappingMode + + _SNAPPING_LEGACY = (SNAPPING_CROSSHAIR | + SNAPPING_ACTIVE_ONLY | + SNAPPING_SYMBOLS_ONLY | + SNAPPING_CURVE | + SNAPPING_SCATTER) + """Legacy snapping mode""" + + @property + @deprecated(replacement="getSnappingMode", since_version="0.8") + def autoSnapToActiveCurve(self): + return self.getSnappingMode() == self._SNAPPING_LEGACY + + @autoSnapToActiveCurve.setter + @deprecated(replacement="setSnappingMode", since_version="0.8") + def autoSnapToActiveCurve(self, flag): + self.setSnappingMode( + self._SNAPPING_LEGACY if flag else self.SNAPPING_DISABLED) diff --git a/silx/gui/plot/tools/__init__.py b/silx/gui/plot/tools/__init__.py new file mode 100644 index 0000000..09f468c --- /dev/null +++ b/silx/gui/plot/tools/__init__.py @@ -0,0 +1,50 @@ +# coding: utf-8 +# /*########################################################################## +# +# Copyright (c) 2018 European Synchrotron Radiation Facility +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. +# +# ###########################################################################*/ +"""This package provides a set of widgets working with :class:`PlotWidget`. + +It provides some QToolBar and QWidget: + +- :class:`InteractiveModeToolBar` +- :class:`OutputToolBar` +- :class:`ImageToolBar` +- :class:`CurveToolBar` +- :class:`LimitsToolBar` +- :class:`PositionInfo` + +It also provides a :mod:`~silx.gui.plot.tools.roi` module to handle +interactive region of interest on a :class:`~silx.gui.plot.PlotWidget`. +""" + +__authors__ = ["T. Vincent"] +__license__ = "MIT" +__date__ = "01/03/2018" + + +from .toolbars import InteractiveModeToolBar # noqa +from .toolbars import OutputToolBar # noqa +from .toolbars import ImageToolBar, CurveToolBar, ScatterToolBar # noqa + +from .LimitsToolBar import LimitsToolBar # noqa +from .PositionInfo import PositionInfo # noqa diff --git a/silx/gui/plot/tools/profile/ImageProfileToolBar.py b/silx/gui/plot/tools/profile/ImageProfileToolBar.py new file mode 100644 index 0000000..207a2e2 --- /dev/null +++ b/silx/gui/plot/tools/profile/ImageProfileToolBar.py @@ -0,0 +1,271 @@ +# TODO quick & dirty proof of concept + +import numpy + +from silx.gui.plot.tools.profile.ScatterProfileToolBar import _BaseProfileToolBar +from .. import items +from ...colors import cursorColorForColormap +from ....image.bilinear import BilinearImage + + +def _alignedPartialProfile(data, rowRange, colRange, axis): + """Mean of a rectangular region (ROI) of a stack of images + along a given axis. + + Returned values and all parameters are in image coordinates. + + :param numpy.ndarray data: 3D volume (stack of 2D images) + The first dimension is the image index. + :param rowRange: [min, max[ of ROI rows (upper bound excluded). + :type rowRange: 2-tuple of int (min, max) with min < max + :param colRange: [min, max[ of ROI columns (upper bound excluded). + :type colRange: 2-tuple of int (min, max) with min < max + :param int axis: The axis along which to take the profile of the ROI. + 0: Sum rows along columns. + 1: Sum columns along rows. + :return: Profile image along the ROI as the mean of the intersection + of the ROI and the image. + """ + assert axis in (0, 1) + assert len(data.shape) == 3 + assert rowRange[0] < rowRange[1] + assert colRange[0] < colRange[1] + + nimages, height, width = data.shape + + # Range aligned with the integration direction + profileRange = colRange if axis == 0 else rowRange + + profileLength = abs(profileRange[1] - profileRange[0]) + + # Subset of the image to use as intersection of ROI and image + rowStart = min(max(0, rowRange[0]), height) + rowEnd = min(max(0, rowRange[1]), height) + colStart = min(max(0, colRange[0]), width) + colEnd = min(max(0, colRange[1]), width) + + imgProfile = numpy.mean(data[:, rowStart:rowEnd, colStart:colEnd], + axis=axis + 1, dtype=numpy.float32) + + # Profile including out of bound area + profile = numpy.zeros((nimages, profileLength), dtype=numpy.float32) + + # Place imgProfile in full profile + offset = - min(0, profileRange[0]) + profile[:, offset:offset + imgProfile.shape[1]] = imgProfile + + return profile + + +def createProfile(points, data, origin, scale, lineWidth): + """Create the profile line for the the given image. + + :param points: Coords of profile end points: (x0, y0, x1, y1) + :param numpy.ndarray data: the 2D image or the 3D stack of images + on which we compute the profile. + :param origin: (ox, oy) the offset from origin + :type origin: 2-tuple of float + :param scale: (sx, sy) the scale to use + :type scale: 2-tuple of float + :param int lineWidth: width of the profile line + :return: `profile, area`, where: + - profile is a 2D array of the profiles of the stack of images. + For a single image, the profile is a curve, so this parameter + has a shape *(1, len(curve))* + - area is a tuple of two 1D arrays with 4 values each. They represent + the effective ROI area corners in plot coords. + + :rtype: tuple(ndarray, (ndarray, ndarray), str, str) + """ + if data is None or points is None or lineWidth is None: + raise ValueError("createProfile called with invalid arguments") + + # force 3D data (stack of images) + if len(data.shape) == 2: + data3D = data.reshape((1,) + data.shape) + elif len(data.shape) == 3: + data3D = data + + roiWidth = max(1, lineWidth) + x0, y0, x1, y1 = points + + # Convert start and end points in image coords as (row, col) + startPt = ((y0 - origin[1]) / scale[1], + (x0 - origin[0]) / scale[0]) + endPt = ((y1 - origin[1]) / scale[1], + (x1 - origin[0]) / scale[0]) + + if (int(startPt[0]) == int(endPt[0]) or + int(startPt[1]) == int(endPt[1])): + # Profile is aligned with one of the axes + + # Convert to int + startPt = int(startPt[0]), int(startPt[1]) + endPt = int(endPt[0]), int(endPt[1]) + + # Ensure startPt <= endPt + if startPt[0] > endPt[0] or startPt[1] > endPt[1]: + startPt, endPt = endPt, startPt + + if startPt[0] == endPt[0]: # Row aligned + rowRange = (int(startPt[0] + 0.5 - 0.5 * roiWidth), + int(startPt[0] + 0.5 + 0.5 * roiWidth)) + colRange = startPt[1], endPt[1] + 1 + profile = _alignedPartialProfile(data3D, + rowRange, colRange, + axis=0) + + else: # Column aligned + rowRange = startPt[0], endPt[0] + 1 + colRange = (int(startPt[1] + 0.5 - 0.5 * roiWidth), + int(startPt[1] + 0.5 + 0.5 * roiWidth)) + profile = _alignedPartialProfile(data3D, + rowRange, colRange, + axis=1) + + # Convert ranges to plot coords to draw ROI area + area = ( + numpy.array( + (colRange[0], colRange[1], colRange[1], colRange[0]), + dtype=numpy.float32) * scale[0] + origin[0], + numpy.array( + (rowRange[0], rowRange[0], rowRange[1], rowRange[1]), + dtype=numpy.float32) * scale[1] + origin[1]) + + else: # General case: use bilinear interpolation + + # Ensure startPt <= endPt + if (startPt[1] > endPt[1] or ( + startPt[1] == endPt[1] and startPt[0] > endPt[0])): + startPt, endPt = endPt, startPt + + profile = [] + for slice_idx in range(data3D.shape[0]): + bilinear = BilinearImage(data3D[slice_idx, :, :]) + + profile.append(bilinear.profile_line( + (startPt[0] - 0.5, startPt[1] - 0.5), + (endPt[0] - 0.5, endPt[1] - 0.5), + roiWidth)) + profile = numpy.array(profile) + + # Extend ROI with half a pixel on each end, and + # Convert back to plot coords (x, y) + length = numpy.sqrt((endPt[0] - startPt[0]) ** 2 + + (endPt[1] - startPt[1]) ** 2) + dRow = (endPt[0] - startPt[0]) / length + dCol = (endPt[1] - startPt[1]) / length + + # Extend ROI with half a pixel on each end + startPt = startPt[0] - 0.5 * dRow, startPt[1] - 0.5 * dCol + endPt = endPt[0] + 0.5 * dRow, endPt[1] + 0.5 * dCol + + # Rotate deltas by 90 degrees to apply line width + dRow, dCol = dCol, -dRow + + area = ( + numpy.array((startPt[1] - 0.5 * roiWidth * dCol, + startPt[1] + 0.5 * roiWidth * dCol, + endPt[1] + 0.5 * roiWidth * dCol, + endPt[1] - 0.5 * roiWidth * dCol), + dtype=numpy.float32) * scale[0] + origin[0], + numpy.array((startPt[0] - 0.5 * roiWidth * dRow, + startPt[0] + 0.5 * roiWidth * dRow, + endPt[0] + 0.5 * roiWidth * dRow, + endPt[0] - 0.5 * roiWidth * dRow), + dtype=numpy.float32) * scale[1] + origin[1]) + + xProfile = numpy.arange(len(profile[0]), dtype=numpy.float64) + + return (xProfile, profile[0]), area + + +class ImageProfileToolBar(_BaseProfileToolBar): + + def __init__(self, parent=None, plot=None, title='Image Profile'): + super(ImageProfileToolBar, self).__init__(parent, plot, title) + plot.sigActiveImageChanged.connect(self.__activeImageChanged) + + roiManager = self._getRoiManager() + if roiManager is None: + _logger.error( + "Error during scatter profile toolbar initialisation") + else: + roiManager.sigInteractiveModeStarted.connect( + self.__interactionStarted) + roiManager.sigInteractiveModeFinished.connect( + self.__interactionFinished) + if roiManager.isStarted(): + self.__interactionStarted(roiManager.getRegionOfInterestKind()) + + def __interactionStarted(self, kind): + """Handle start of ROI interaction""" + plot = self.getPlotWidget() + if plot is None: + return + + plot.sigActiveImageChanged.connect(self.__activeImageChanged) + + image = plot.getActiveImage() + legend = None if image is None else image.getLegend() + self.__activeImageChanged(None, legend) + + def __interactionFinished(self, rois): + """Handle end of ROI interaction""" + plot = self.getPlotWidget() + if plot is None: + return + + plot.sigActiveImageChanged.disconnect(self.__activeImageChanged) + + image = plot.getActiveImage() + legend = None if image is None else image.getLegend() + self.__activeImageChanged(legend, None) + + def __activeImageChanged(self, previous, legend): + """Handle active image change: toggle enabled toolbar, update curve""" + plot = self.getPlotWidget() + if plot is None: + return + + activeImage = plot.getActiveImage() + if activeImage is None: + self.setEnabled(False) + else: + # Disable for empty image + self.setEnabled(activeImage.getData(copy=False).size > 0) + + # Update default profile color + if isinstance(activeImage, items.ColormapMixIn): + self.setColor(cursorColorForColormap( + activeImage.getColormap()['name'])) # TODO change thsi + else: + self.setColor('black') + + self.updateProfile() + + def computeProfile(self, x0, y0, x1, y1): + """Compute corresponding profile + + :param float x0: Profile start point X coord + :param float y0: Profile start point Y coord + :param float x1: Profile end point X coord + :param float y1: Profile end point Y coord + :return: (x, y) profile data or None + """ + plot = self.getPlotWidget() + if plot is None: + return None + + image = plot.getActiveImage() + if image is None: + return None + + profile, area = createProfile( + points=(x0, y0, x1, y1), + data=image.getData(copy=False), + origin=image.getOrigin(), + scale=image.getScale(), + lineWidth=1) # TODO + + return profile
\ No newline at end of file diff --git a/silx/gui/plot/tools/profile/ScatterProfileToolBar.py b/silx/gui/plot/tools/profile/ScatterProfileToolBar.py new file mode 100644 index 0000000..fd21515 --- /dev/null +++ b/silx/gui/plot/tools/profile/ScatterProfileToolBar.py @@ -0,0 +1,431 @@ +# coding: utf-8 +# /*########################################################################## +# +# Copyright (c) 2018 European Synchrotron Radiation Facility +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. +# +# ###########################################################################*/ +"""This module profile tools for scatter plots. +""" + +__authors__ = ["T. Vincent"] +__license__ = "MIT" +__date__ = "28/06/2018" + + +import logging +import threading +import time + +import numpy + +try: + from scipy.interpolate import LinearNDInterpolator +except ImportError: + LinearNDInterpolator = None + + # Fallback using local Delaunay and matplotlib interpolator + from silx.third_party.scipy_spatial import Delaunay + import matplotlib.tri + +from ._BaseProfileToolBar import _BaseProfileToolBar +from .... import qt +from ... import items + + +_logger = logging.getLogger(__name__) + + +# TODO support log scale + + +class _InterpolatorInitThread(qt.QThread): + """Thread building a scatter interpolator + + This works in greedy mode in that the signal is only emitted + when no other request is pending + """ + + sigInterpolatorReady = qt.Signal(object) + """Signal emitted whenever an interpolator is ready + + It provides a 3-tuple (points, values, interpolator) + """ + + _RUNNING_THREADS_TO_DELETE = [] + """Store reference of no more used threads but still running""" + + def __init__(self): + super(_InterpolatorInitThread, self).__init__() + self._lock = threading.RLock() + self._pendingData = None + self._firstFallbackRun = True + + def discard(self, obj=None): + """Wait for pending thread to complete and delete then + + Connect this to the destroyed signal of widget using this thread + """ + if self.isRunning(): + self.cancel() + self._RUNNING_THREADS_TO_DELETE.append(self) # Keep a reference + self.finished.connect(self.__finished) + + def __finished(self): + """Handle finished signal of threads to delete""" + try: + self._RUNNING_THREADS_TO_DELETE.remove(self) + except ValueError: + _logger.warning('Finished thread no longer in reference list') + + def request(self, points, values): + """Request new initialisation of interpolator + + :param numpy.ndarray points: Point coordinates (N, D) + :param numpy.ndarray values: Values the N points (1D array) + """ + with self._lock: + # Possibly replace already pending data + self._pendingData = points, values + + if not self.isRunning(): + self.start() + + def cancel(self): + """Cancel any running/pending requests""" + with self._lock: + self._pendingData = 'cancelled' + + def run(self): + """Run the init of the scatter interpolator""" + if LinearNDInterpolator is None: + self.run_matplotlib() + else: + self.run_scipy() + + def run_matplotlib(self): + """Run the init of the scatter interpolator""" + if self._firstFallbackRun: + self._firstFallbackRun = False + _logger.warning( + "scipy.spatial.LinearNDInterpolator not available: " + "Scatter plot interpolator initialisation can freeze the GUI.") + + while True: + with self._lock: + data = self._pendingData + self._pendingData = None + + if data in (None, 'cancelled'): + return + + points, values = data + + startTime = time.time() + try: + delaunay = Delaunay(points) + except: + _logger.warning( + "Cannot triangulate scatter data") + else: + with self._lock: + data = self._pendingData + + if data is not None: # Break point + _logger.info('Interpolator discarded after %f s', + time.time() - startTime) + else: + + x, y = points.T + triangulation = matplotlib.tri.Triangulation( + x, y, triangles=delaunay.simplices) + + interpolator = matplotlib.tri.LinearTriInterpolator( + triangulation, values) + + with self._lock: + data = self._pendingData + + if data is not None: + _logger.info('Interpolator discarded after %f s', + time.time() - startTime) + else: + # No other processing requested: emit the signal + _logger.info("Interpolator initialised in %f s", + time.time() - startTime) + + # Wrap interpolator to have same API as scipy's one + def wrapper(points): + return interpolator(*points.T) + + self.sigInterpolatorReady.emit( + (points, values, wrapper)) + + def run_scipy(self): + """Run the init of the scatter interpolator""" + while True: + with self._lock: + data = self._pendingData + self._pendingData = None + + if data in (None, 'cancelled'): + return + + points, values = data + + startTime = time.time() + try: + interpolator = LinearNDInterpolator(points, values) + except: + _logger.warning( + "Cannot initialise scatter profile interpolator") + else: + with self._lock: + data = self._pendingData + + if data is not None: # Break point + _logger.info('Interpolator discarded after %f s', + time.time() - startTime) + else: + # First call takes a while, do it here + interpolator([(0., 0.)]) + + with self._lock: + data = self._pendingData + + if data is not None: + _logger.info('Interpolator discarded after %f s', + time.time() - startTime) + else: + # No other processing requested: emit the signal + _logger.info("Interpolator initialised in %f s", + time.time() - startTime) + self.sigInterpolatorReady.emit( + (points, values, interpolator)) + + +class ScatterProfileToolBar(_BaseProfileToolBar): + """QToolBar providing scatter plot profiling tools + + :param parent: See :class:`QToolBar`. + :param plot: :class:`~silx.gui.plot.PlotWidget` on which to operate. + :param str title: See :class:`QToolBar`. + """ + + def __init__(self, parent=None, plot=None, title='Scatter Profile'): + super(ScatterProfileToolBar, self).__init__(parent, plot, title) + + self.__nPoints = 1024 + self.__interpolator = None + self.__interpolatorCache = None # points, values, interpolator + + self.__initThread = _InterpolatorInitThread() + self.destroyed.connect(self.__initThread.discard) + self.__initThread.sigInterpolatorReady.connect( + self.__interpolatorReady) + + roiManager = self._getRoiManager() + if roiManager is None: + _logger.error( + "Error during scatter profile toolbar initialisation") + else: + roiManager.sigInteractiveModeStarted.connect( + self.__interactionStarted) + roiManager.sigInteractiveModeFinished.connect( + self.__interactionFinished) + if roiManager.isStarted(): + self.__interactionStarted(roiManager.getCurrentInteractionModeRoiClass()) + + def __interactionStarted(self, roiClass): + """Handle start of ROI interaction""" + plot = self.getPlotWidget() + if plot is None: + return + + plot.sigActiveScatterChanged.connect(self.__activeScatterChanged) + + scatter = plot._getActiveItem(kind='scatter') + legend = None if scatter is None else scatter.getLegend() + self.__activeScatterChanged(None, legend) + + def __interactionFinished(self): + """Handle end of ROI interaction""" + plot = self.getPlotWidget() + if plot is None: + return + + plot.sigActiveScatterChanged.disconnect(self.__activeScatterChanged) + + scatter = plot._getActiveItem(kind='scatter') + legend = None if scatter is None else scatter.getLegend() + self.__activeScatterChanged(legend, None) + + def __activeScatterChanged(self, previous, legend): + """Handle change of active scatter + + :param Union[str,None] previous: + :param Union[str,None] legend: + """ + self.__initThread.cancel() + + # Reset interpolator + self.__interpolator = None + + plot = self.getPlotWidget() + if plot is None: + _logger.error("Associated PlotWidget no longer exists") + + else: + if previous is not None: # Disconnect signal + scatter = plot.getScatter(previous) + if scatter is not None: + scatter.sigItemChanged.disconnect( + self.__scatterItemChanged) + + if legend is not None: + scatter = plot.getScatter(legend) + if scatter is None: + _logger.error("Cannot retrieve active scatter") + + else: + scatter.sigItemChanged.connect(self.__scatterItemChanged) + points = numpy.transpose(numpy.array(( + scatter.getXData(copy=False), + scatter.getYData(copy=False)))) + values = scatter.getValueData(copy=False) + + self.__updateInterpolator(points, values) + + # Refresh profile + self.updateProfile() + + def __scatterItemChanged(self, event): + """Handle update of active scatter plot item + + :param ItemChangedType event: + """ + if event == items.ItemChangedType.DATA: + self.__interpolator = None + scatter = self.sender() + if scatter is None: + _logger.error("Cannot retrieve updated scatter item") + + else: + points = numpy.transpose(numpy.array(( + scatter.getXData(copy=False), + scatter.getYData(copy=False)))) + values = scatter.getValueData(copy=False) + + self.__updateInterpolator(points, values) + + # Handle interpolator init thread + + def __updateInterpolator(self, points, values): + """Update used interpolator with new data""" + if (self.__interpolatorCache is not None and + len(points) == len(self.__interpolatorCache[0]) and + numpy.all(numpy.equal(self.__interpolatorCache[0], points)) and + numpy.all(numpy.equal(self.__interpolatorCache[1], values))): + # Reuse previous interpolator + _logger.info( + 'Scatter changed: Reuse previous interpolator') + self.__interpolator = self.__interpolatorCache[2] + + else: + # Interpolator needs update: Start background processing + _logger.info( + 'Scatter changed: Rebuild interpolator') + self.__interpolator = None + self.__interpolatorCache = None + self.__initThread.request(points, values) + + def __interpolatorReady(self, data): + """Handle end of init interpolator thread""" + points, values, interpolator = data + self.__interpolator = interpolator + self.__interpolatorCache = None if interpolator is None else data + self.updateProfile() + + def hasPendingOperations(self): + return self.__initThread.isRunning() + + # Number of points + + def getNPoints(self): + """Returns the number of points of the profiles + + :rtype: int + """ + return self.__nPoints + + def setNPoints(self, npoints): + """Set the number of points of the profiles + + :param int npoints: + """ + npoints = int(npoints) + if npoints < 1: + raise ValueError("Unsupported number of points: %d" % npoints) + else: + self.__nPoints = npoints + + # Overridden methods + + def computeProfileTitle(self, x0, y0, x1, y1): + """Compute corresponding plot title + + :param float x0: Profile start point X coord + :param float y0: Profile start point Y coord + :param float x1: Profile end point X coord + :param float y1: Profile end point Y coord + :return: Title to use + :rtype: str + """ + if self.hasPendingOperations(): + return 'Pre-processing data...' + + else: + return super(ScatterProfileToolBar, self).computeProfileTitle( + x0, y0, x1, y1) + + def computeProfile(self, x0, y0, x1, y1): + """Compute corresponding profile + + :param float x0: Profile start point X coord + :param float y0: Profile start point Y coord + :param float x1: Profile end point X coord + :param float y1: Profile end point Y coord + :return: (points, values) profile data or None + """ + if self.__interpolator is None: + return None + + nPoints = self.getNPoints() + + points = numpy.transpose(( + numpy.linspace(x0, x1, nPoints, endpoint=True), + numpy.linspace(y0, y1, nPoints, endpoint=True))) + + values = self.__interpolator(points) + + if not numpy.any(numpy.isfinite(values)): + return None # Profile outside convex hull + + return points, values diff --git a/silx/gui/plot/tools/profile/_BaseProfileToolBar.py b/silx/gui/plot/tools/profile/_BaseProfileToolBar.py new file mode 100644 index 0000000..6d9d6d4 --- /dev/null +++ b/silx/gui/plot/tools/profile/_BaseProfileToolBar.py @@ -0,0 +1,430 @@ +# coding: utf-8 +# /*########################################################################## +# +# Copyright (c) 2018 European Synchrotron Radiation Facility +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. +# +# ###########################################################################*/ +"""This module provides the base class for profile toolbars.""" + +__authors__ = ["T. Vincent"] +__license__ = "MIT" +__date__ = "28/06/2018" + + +import logging +import weakref + +import numpy + +from silx.utils.weakref import WeakMethodProxy +from silx.gui import qt, icons, colors +from silx.gui.plot import PlotWidget, items +from silx.gui.plot.ProfileMainWindow import ProfileMainWindow +from silx.gui.plot.tools.roi import RegionOfInterestManager +from silx.gui.plot.items import roi as roi_items + + +_logger = logging.getLogger(__name__) + + +class _BaseProfileToolBar(qt.QToolBar): + """Base class for QToolBar plot profiling tools + + :param parent: See :class:`QToolBar`. + :param plot: :class:`~silx.gui.plot.PlotWidget` on which to operate. + :param str title: See :class:`QToolBar`. + """ + + sigProfileChanged = qt.Signal() + """Signal emitted when the profile has changed""" + + def __init__(self, parent=None, plot=None, title=''): + super(_BaseProfileToolBar, self).__init__(title, parent) + + self.__profile = None + self.__profileTitle = '' + + assert isinstance(plot, PlotWidget) + self._plotRef = weakref.ref( + plot, WeakMethodProxy(self.__plotDestroyed)) + + self._profileWindow = None + + # Set-up interaction manager + roiManager = RegionOfInterestManager(plot) + self._roiManagerRef = weakref.ref(roiManager) + + roiManager.sigInteractiveModeFinished.connect(self.__interactionFinished) + roiManager.sigRoiChanged.connect(self.updateProfile) + roiManager.sigRoiAdded.connect(self.__roiAdded) + + # Add interactive mode actions + for kind, icon, tooltip in ( + (roi_items.HorizontalLineROI, 'shape-horizontal', + 'Enables horizontal line profile selection mode'), + (roi_items.VerticalLineROI, 'shape-vertical', + 'Enables vertical line profile selection mode'), + (roi_items.LineROI, 'shape-diagonal', + 'Enables line profile selection mode')): + action = roiManager.getInteractionModeAction(kind) + action.setIcon(icons.getQIcon(icon)) + action.setToolTip(tooltip) + self.addAction(action) + + # Add clear action + action = qt.QAction(icons.getQIcon('profile-clear'), + 'Clear Profile', self) + action.setToolTip('Clear the profile') + action.setCheckable(False) + action.triggered.connect(self.clearProfile) + self.addAction(action) + + # Initialize color + self._color = None + self.setColor('red') + + # Listen to plot limits changed + plot.getXAxis().sigLimitsChanged.connect(self.updateProfile) + plot.getYAxis().sigLimitsChanged.connect(self.updateProfile) + + # Listen to plot scale + plot.getXAxis().sigScaleChanged.connect(self.__plotAxisScaleChanged) + plot.getYAxis().sigScaleChanged.connect(self.__plotAxisScaleChanged) + + self.setDefaultProfileWindowEnabled(True) + + def getProfilePoints(self, copy=True): + """Returns the profile sampling points as (x, y) or None + + :param bool copy: True to get a copy, + False to get internal arrays (do not modify) + :rtype: Union[numpy.ndarray,None] + """ + if self.__profile is None: + return None + else: + return numpy.array(self.__profile[0], copy=copy) + + def getProfileValues(self, copy=True): + """Returns the values of the profile or None + + :param bool copy: True to get a copy, + False to get internal arrays (do not modify) + :rtype: Union[numpy.ndarray,None] + """ + if self.__profile is None: + return None + else: + return numpy.array(self.__profile[1], copy=copy) + + def getProfileTitle(self): + """Returns the profile title + + :rtype: str + """ + return self.__profileTitle + + # Handle plot reference + + def __plotDestroyed(self, ref): + """Handle finalization of PlotWidget + + :param ref: weakref to the plot + """ + self._plotRef = None + self.setEnabled(False) # Profile is pointless + for action in self.actions(): # TODO useful? + self.removeAction(action) + + def getPlotWidget(self): + """The :class:`~silx.gui.plot.PlotWidget` associated to the toolbar. + + :rtype: Union[~silx.gui.plot.PlotWidget,None] + """ + return None if self._plotRef is None else self._plotRef() + + def _getRoiManager(self): + """Returns the used ROI manager + + :rtype: RegionOfInterestManager + """ + return self._roiManagerRef() + + # Profile Plot + + def isDefaultProfileWindowEnabled(self): + """Returns True if the default floating profile window is used + + :rtype: bool + """ + return self.getDefaultProfileWindow() is not None + + def setDefaultProfileWindowEnabled(self, enabled): + """Set whether to use or not the default floating profile window. + + :param bool enabled: True to use, False to disable + """ + if self.isDefaultProfileWindowEnabled() != enabled: + if enabled: + self._profileWindow = ProfileMainWindow(self) + self._profileWindow.sigClose.connect(self.clearProfile) + self.sigProfileChanged.connect(self.__updateDefaultProfilePlot) + + else: + self.sigProfileChanged.disconnect(self.__updateDefaultProfilePlot) + self._profileWindow.sigClose.disconnect(self.clearProfile) + self._profileWindow.close() + self._profileWindow = None + + def getDefaultProfileWindow(self): + """Returns the default floating profile window if in use else None. + + See :meth:`isDefaultProfileWindowEnabled` + + :rtype: Union[ProfileMainWindow,None] + """ + return self._profileWindow + + def __updateDefaultProfilePlot(self): + """Update the plot of the default profile window""" + profileWindow = self.getDefaultProfileWindow() + if profileWindow is None: + return + + profilePlot = profileWindow.getPlot() + if profilePlot is None: + return + + profilePlot.clear() + profilePlot.setGraphTitle(self.getProfileTitle()) + + points = self.getProfilePoints(copy=False) + values = self.getProfileValues(copy=False) + + if points is not None and values is not None: + if (numpy.abs(points[-1, 0] - points[0, 0]) > + numpy.abs(points[-1, 1] - points[0, 1])): + xProfile = points[:, 0] + profilePlot.getXAxis().setLabel('X') + else: + xProfile = points[:, 1] + profilePlot.getXAxis().setLabel('Y') + + profilePlot.addCurve( + xProfile, values, legend='Profile', color=self._color) + + self._showDefaultProfileWindow() + + def _showDefaultProfileWindow(self): + """If profile window was created by this toolbar, + try to avoid overlapping with the toolbar's parent window. + """ + profileWindow = self.getDefaultProfileWindow() + roiManager = self._getRoiManager() + if profileWindow is None or roiManager is None: + return + + if roiManager.isStarted() and not profileWindow.isVisible(): + profileWindow.show() + profileWindow.raise_() + + window = self.window() + winGeom = window.frameGeometry() + qapp = qt.QApplication.instance() + desktop = qapp.desktop() + screenGeom = desktop.availableGeometry(self) + spaceOnLeftSide = winGeom.left() + spaceOnRightSide = screenGeom.width() - winGeom.right() + + frameGeometry = profileWindow.frameGeometry() + profileWindowWidth = frameGeometry.width() + if profileWindowWidth < spaceOnRightSide: + # Place profile on the right + profileWindow.move(winGeom.right(), winGeom.top()) + elif profileWindowWidth < spaceOnLeftSide: + # Place profile on the left + profileWindow.move( + max(0, winGeom.left() - profileWindowWidth), winGeom.top()) + + # Handle plot in log scale + + def __plotAxisScaleChanged(self, scale): + """Handle change of axis scale in the plot widget""" + plot = self.getPlotWidget() + if plot is None: + return + + xScale = plot.getXAxis().getScale() + yScale = plot.getYAxis().getScale() + + if xScale == items.Axis.LINEAR and yScale == items.Axis.LINEAR: + self.setEnabled(True) + + else: + roiManager = self._getRoiManager() + if roiManager is not None: + roiManager.stop() # Stop interactive mode + + self.clearProfile() + self.setEnabled(False) + + # Profile color + + def getColor(self): + """Returns the color used for the profile and ROI + + :rtype: QColor + """ + return qt.QColor.fromRgbF(*self._color) + + def setColor(self, color): + """Set the color to use for ROI and profile. + + :param color: + Either a color name, a QColor, a list of uint8 or float in [0, 1]. + """ + self._color = colors.rgba(color) + roiManager = self._getRoiManager() + if roiManager is not None: + roiManager.setColor(self._color) + for roi in roiManager.getRois(): + roi.setColor(self._color) + self.updateProfile() + + # Handle ROI manager + + def __interactionFinished(self): + """Handle end of interactive mode""" + self.clearProfile() + + profileWindow = self.getDefaultProfileWindow() + if profileWindow is not None: + profileWindow.hide() + + def __roiAdded(self, roi): + """Handle new ROI""" + roi.setLabel('Profile') + roi.setEditable(True) + + # Remove any other ROI + roiManager = self._getRoiManager() + if roiManager is not None: + for regionOfInterest in list(roiManager.getRois()): + if regionOfInterest is not roi: + roiManager.removeRoi(regionOfInterest) + + def computeProfile(self, x0, y0, x1, y1): + """Compute corresponding profile + + Override in subclass to compute profile + + :param float x0: Profile start point X coord + :param float y0: Profile start point Y coord + :param float x1: Profile end point X coord + :param float y1: Profile end point Y coord + :return: (points, values) profile data or None + """ + return None + + def computeProfileTitle(self, x0, y0, x1, y1): + """Compute corresponding plot title + + This can be overridden to change title behavior. + + :param float x0: Profile start point X coord + :param float y0: Profile start point Y coord + :param float x1: Profile end point X coord + :param float y1: Profile end point Y coord + :return: Title to use + :rtype: str + """ + if x0 == x1: + title = 'X = %g; Y = [%g, %g]' % (x0, y0, y1) + elif y0 == y1: + title = 'Y = %g; X = [%g, %g]' % (y0, x0, x1) + else: + m = (y1 - y0) / (x1 - x0) + b = y0 - m * x0 + title = 'Y = %g * X %+g' % (m, b) + + return title + + def updateProfile(self): + """Update profile according to current ROI""" + roiManager = self._getRoiManager() + if roiManager is None: + roi = None + else: + rois = roiManager.getRois() + roi = None if len(rois) == 0 else rois[0] + + if roi is None: + self._setProfile(profile=None, title='') + return + + # Get end points + if isinstance(roi, roi_items.LineROI): + points = roi.getEndPoints() + x0, y0 = points[0] + x1, y1 = points[1] + elif isinstance(roi, (roi_items.VerticalLineROI, roi_items.HorizontalLineROI)): + plot = self.getPlotWidget() + if plot is None: + self._setProfile(profile=None, title='') + return + + elif isinstance(roi, roi_items.HorizontalLineROI): + x0, x1 = plot.getXAxis().getLimits() + y0 = y1 = roi.getPosition() + + elif isinstance(roi, roi_items.VerticalLineROI): + x0 = x1 = roi.getPosition() + y0, y1 = plot.getYAxis().getLimits() + + else: + raise RuntimeError('Unsupported ROI for profile: {}'.format(roi.__class__)) + + if x1 < x0 or (x1 == x0 and y1 < y0): + # Invert points + x0, y0, x1, y1 = x1, y1, x0, y0 + + profile = self.computeProfile(x0, y0, x1, y1) + title = self.computeProfileTitle(x0, y0, x1, y1) + self._setProfile(profile=profile, title=title) + + def _setProfile(self, profile=None, title=''): + """Set profile data and emit signal. + + :param profile: points and profile values + :param str title: + """ + self.__profile = profile + self.__profileTitle = title + + self.sigProfileChanged.emit() + + def clearProfile(self): + """Clear the current line ROI and associated profile""" + roiManager = self._getRoiManager() + if roiManager is not None: + roiManager.clear() + + self._setProfile(profile=None, title='') diff --git a/silx/gui/plot/tools/profile/__init__.py b/silx/gui/plot/tools/profile/__init__.py new file mode 100644 index 0000000..d91191e --- /dev/null +++ b/silx/gui/plot/tools/profile/__init__.py @@ -0,0 +1,38 @@ +# coding: utf-8 +# /*########################################################################## +# +# Copyright (c) 2018 European Synchrotron Radiation Facility +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. +# +# ###########################################################################*/ +"""This module provides tools to get profiles on plot data. + +It provides: + +- :class:`ScatterProfileToolBar`: a QToolBar to handle profile on scatter data + +""" + +__authors__ = ["T. Vincent"] +__license__ = "MIT" +__date__ = "07/06/2018" + + +from .ScatterProfileToolBar import ScatterProfileToolBar # noqa diff --git a/silx/gui/plot/tools/roi.py b/silx/gui/plot/tools/roi.py new file mode 100644 index 0000000..d58c041 --- /dev/null +++ b/silx/gui/plot/tools/roi.py @@ -0,0 +1,934 @@ +# coding: utf-8 +# /*########################################################################## +# +# Copyright (c) 2018 European Synchrotron Radiation Facility +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. +# +# ###########################################################################*/ +"""This module provides ROI interaction for :class:`~silx.gui.plot.PlotWidget`. +""" + +__authors__ = ["T. Vincent"] +__license__ = "MIT" +__date__ = "28/06/2018" + + +import collections +import functools +import logging +import time +import weakref + +import numpy + +from ....third_party import enum +from ....utils.weakref import WeakMethodProxy +from ... import qt, icons +from .. import PlotWidget +from ..items import roi as roi_items + +from ...colors import rgba + + +logger = logging.getLogger(__name__) + + +class RegionOfInterestManager(qt.QObject): + """Class handling ROI interaction on a PlotWidget. + + It supports the multiple ROIs: points, rectangles, polygons, + lines, horizontal and vertical lines. + + See ``plotInteractiveImageROI.py`` sample code (:ref:`sample-code`). + + :param silx.gui.plot.PlotWidget parent: + The plot widget in which to control the ROIs. + """ + + sigRoiAdded = qt.Signal(roi_items.RegionOfInterest) + """Signal emitted when a new ROI has been added. + + It provides the newly add :class:`RegionOfInterest` object. + """ + + sigRoiAboutToBeRemoved = qt.Signal(roi_items.RegionOfInterest) + """Signal emitted just before a ROI is removed. + + It provides the :class:`RegionOfInterest` object that is about to be removed. + """ + + sigRoiChanged = qt.Signal() + """Signal emitted whenever the ROIs have changed.""" + + sigInteractiveModeStarted = qt.Signal(object) + """Signal emitted when switching to ROI drawing interactive mode. + + It provides the class of the ROI which will be created by the interactive + mode. + """ + + sigInteractiveModeFinished = qt.Signal() + """Signal emitted when leaving and interactive ROI drawing. + + It provides the list of ROIs. + """ + + _MODE_ACTIONS_PARAMS = collections.OrderedDict() + # Interactive mode: (icon name, text) + _MODE_ACTIONS_PARAMS[roi_items.PointROI] = 'add-shape-point', 'Add point markers' + _MODE_ACTIONS_PARAMS[roi_items.RectangleROI] = 'add-shape-rectangle', 'Add rectangle ROI' + _MODE_ACTIONS_PARAMS[roi_items.PolygonROI] = 'add-shape-polygon', 'Add polygon ROI' + _MODE_ACTIONS_PARAMS[roi_items.LineROI] = 'add-shape-diagonal', 'Add line ROI' + _MODE_ACTIONS_PARAMS[roi_items.HorizontalLineROI] = 'add-shape-horizontal', 'Add horizontal line ROI' + _MODE_ACTIONS_PARAMS[roi_items.VerticalLineROI] = 'add-shape-vertical', 'Add vertical line ROI' + _MODE_ACTIONS_PARAMS[roi_items.ArcROI] = 'add-shape-arc', 'Add arc ROI' + + def __init__(self, parent): + assert isinstance(parent, PlotWidget) + super(RegionOfInterestManager, self).__init__(parent) + self._rois = [] # List of ROIs + self._drawnROI = None # New ROI being currently drawn + + self._roiClass = None + self._color = rgba('red') + + self._label = "__RegionOfInterestManager__%d" % id(self) + + self._eventLoop = None + + self._modeActions = {} + + parent.sigInteractiveModeChanged.connect( + self._plotInteractiveModeChanged) + + @classmethod + def getSupportedRoiClasses(cls): + """Returns the default available ROI classes + + :rtype: List[class] + """ + return tuple(cls._MODE_ACTIONS_PARAMS.keys()) + + # Associated QActions + + def getInteractionModeAction(self, roiClass): + """Returns the QAction corresponding to a kind of ROI + + The QAction allows to enable the corresponding drawing + interactive mode. + + :param str roiClass: The ROI class which will be crated by this action. + :rtype: QAction + :raise ValueError: If kind is not supported + """ + if not issubclass(roiClass, roi_items.RegionOfInterest): + raise ValueError('Unsupported ROI class %s' % roiClass) + + action = self._modeActions.get(roiClass, None) + if action is None: # Lazy-loading + if roiClass in self._MODE_ACTIONS_PARAMS: + iconName, text = self._MODE_ACTIONS_PARAMS[roiClass] + else: + iconName = "add-shape-unknown" + name = roiClass._getKind() + if name is None: + name = roiClass.__name__ + text = 'Add %s' % name + action = qt.QAction(self) + action.setIcon(icons.getQIcon(iconName)) + action.setText(text) + action.setCheckable(True) + action.setChecked(self.getCurrentInteractionModeRoiClass() is roiClass) + action.setToolTip(text) + + action.triggered[bool].connect(functools.partial( + WeakMethodProxy(self._modeActionTriggered), roiClass=roiClass)) + self._modeActions[roiClass] = action + return action + + def _modeActionTriggered(self, checked, roiClass): + """Handle mode actions being checked by the user + + :param bool checked: + :param str kind: Corresponding shape kind + """ + if checked: + self.start(roiClass) + else: # Keep action checked + action = self.sender() + action.setChecked(True) + + def _updateModeActions(self): + """Check/Uncheck action corresponding to current mode""" + for roiClass, action in self._modeActions.items(): + action.setChecked(roiClass == self.getCurrentInteractionModeRoiClass()) + + # PlotWidget eventFilter and listeners + + def _plotInteractiveModeChanged(self, source): + """Handle change of interactive mode in the plot""" + if source is not self: + self.__roiInteractiveModeEnded() + + else: # Check the corresponding action + self._updateModeActions() + + # Handle ROI interaction + + def _handleInteraction(self, event): + """Handle mouse interaction for ROI addition""" + roiClass = self.getCurrentInteractionModeRoiClass() + if roiClass is None: + return # Should not happen + + kind = roiClass.getFirstInteractionShape() + if kind == 'point': + if event['event'] == 'mouseClicked' and event['button'] == 'left': + points = numpy.array([(event['x'], event['y'])], + dtype=numpy.float64) + self.createRoi(roiClass, points=points) + + else: # other shapes + if (event['event'] in ('drawingProgress', 'drawingFinished') and + event['parameters']['label'] == self._label): + points = numpy.array((event['xdata'], event['ydata']), + dtype=numpy.float64).T + + if self._drawnROI is None: # Create new ROI + self._drawnROI = self.createRoi(roiClass, points=points) + else: + self._drawnROI.setFirstShapePoints(points) + + if event['event'] == 'drawingFinished': + if kind == 'polygon' and len(points) > 1: + self._drawnROI.setFirstShapePoints(points[:-1]) + self._drawnROI = None # Stop drawing + + # RegionOfInterest API + + def getRois(self): + """Returns the list of ROIs. + + It returns an empty tuple if there is currently no ROI. + + :return: Tuple of arrays of objects describing the ROIs + :rtype: List[RegionOfInterest] + """ + return tuple(self._rois) + + def clear(self): + """Reset current ROIs + + :return: True if ROIs were reset. + :rtype: bool + """ + if self.getRois(): # Something to reset + for roi in self._rois: + roi.sigRegionChanged.disconnect( + self._regionOfInterestChanged) + roi.setParent(None) + self._rois = [] + self._roisUpdated() + return True + + else: + return False + + def _regionOfInterestChanged(self): + """Handle ROI object changed""" + self.sigRoiChanged.emit() + + def createRoi(self, roiClass, points, label='', index=None): + """Create a new ROI and add it to list of ROIs. + + :param class roiClass: The class of the ROI to create + :param numpy.ndarray points: The first shape used to create the ROI + :param str label: The label to display along with the ROI. + :param int index: The position where to insert the ROI. + By default it is appended to the end of the list. + :return: The created ROI object + :rtype: roi_items.RegionOfInterest + :raise RuntimeError: When ROI cannot be added because the maximum + number of ROIs has been reached. + """ + roi = roiClass(parent=None) + roi.setLabel(str(label)) + roi.setFirstShapePoints(points) + + self.addRoi(roi, index) + return roi + + def addRoi(self, roi, index=None, useManagerColor=True): + """Add the ROI to the list of ROIs. + + :param roi_items.RegionOfInterest roi: The ROI to add + :param int index: The position where to insert the ROI, + By default it is appended to the end of the list of ROIs + :raise RuntimeError: When ROI cannot be added because the maximum + number of ROIs has been reached. + """ + plot = self.parent() + if plot is None: + raise RuntimeError( + 'Cannot add ROI: PlotWidget no more available') + + roi.setParent(self) + + if useManagerColor: + roi.setColor(self.getColor()) + + roi.sigRegionChanged.connect(self._regionOfInterestChanged) + + if index is None: + self._rois.append(roi) + else: + self._rois.insert(index, roi) + self.sigRoiAdded.emit(roi) + self._roisUpdated() + + def removeRoi(self, roi): + """Remove a ROI from the list of ROIs. + + :param roi_items.RegionOfInterest roi: The ROI to remove + :raise ValueError: When ROI does not belong to this object + """ + if not (isinstance(roi, roi_items.RegionOfInterest) and + roi.parent() is self and + roi in self._rois): + raise ValueError( + 'RegionOfInterest does not belong to this instance') + + self.sigRoiAboutToBeRemoved.emit(roi) + + self._rois.remove(roi) + roi.sigRegionChanged.disconnect(self._regionOfInterestChanged) + roi.setParent(None) + self._roisUpdated() + + def _roisUpdated(self): + """Handle update of the ROI list""" + self.sigRoiChanged.emit() + + # RegionOfInterest parameters + + def getColor(self): + """Return the default color of created ROIs + + :rtype: QColor + """ + return qt.QColor.fromRgbF(*self._color) + + def setColor(self, color): + """Set the default color to use when creating ROIs. + + Existing ROIs are not affected. + + :param color: The color to use for displaying ROIs as + either a color name, a QColor, a list of uint8 or float in [0, 1]. + """ + self._color = rgba(color) + + # Control ROI + + def getCurrentInteractionModeRoiClass(self): + """Returns the current ROI class used by the interactive drawing mode. + + Returns None if the ROI manager is not in an interactive mode. + + :rtype: Union[class,None] + """ + return self._roiClass + + def isStarted(self): + """Returns True if an interactive ROI drawing mode is active. + + :rtype: bool + """ + return self._roiClass is not None + + def start(self, roiClass): + """Start an interactive ROI drawing mode. + + :param class roiClass: The ROI class to create. It have to inherite from + `roi_items.RegionOfInterest`. + :return: True if interactive ROI drawing was started, False otherwise + :rtype: bool + :raise ValueError: If roiClass is not supported + """ + self.stop() + + if not issubclass(roiClass, roi_items.RegionOfInterest): + raise ValueError('Unsupported ROI class %s' % roiClass) + + plot = self.parent() + if plot is None: + return False + + self._roiClass = roiClass + firstInteractionShapeKind = roiClass.getFirstInteractionShape() + + if firstInteractionShapeKind == 'point': + plot.setInteractiveMode(mode='select', source=self) + else: + if roiClass.showFirstInteractionShape(): + color = rgba(self.getColor()) + else: + color = None + plot.setInteractiveMode(mode='select-draw', + source=self, + shape=firstInteractionShapeKind, + color=color, + label=self._label) + + plot.sigPlotSignal.connect(self._handleInteraction) + + self.sigInteractiveModeStarted.emit(roiClass) + + return True + + def __roiInteractiveModeEnded(self): + """Handle end of ROI draw interactive mode""" + if self.isStarted(): + self._roiClass = None + + if self._drawnROI is not None: + # Cancel ROI create + self.removeRoi(self._drawnROI) + self._drawnROI = None + + plot = self.parent() + if plot is not None: + plot.sigPlotSignal.disconnect(self._handleInteraction) + + self._updateModeActions() + + self.sigInteractiveModeFinished.emit() + + def stop(self): + """Stop interactive ROI drawing mode. + + :return: True if an interactive ROI drawing mode was actually stopped + :rtype: bool + """ + if not self.isStarted(): + return False + + plot = self.parent() + if plot is not None: + # This leads to call __roiInteractiveModeEnded through + # interactive mode changed signal + plot.setInteractiveMode(mode='zoom', source=None) + else: # Fallback + self.__roiInteractiveModeEnded() + + return True + + def exec_(self, roiClass): + """Block until :meth:`quit` is called. + + :param class kind: The class of the ROI which have to be created. + See `silx.gui.plot.items.roi`. + :return: The list of ROIs + :rtype: tuple + """ + self.start(roiClass) + + plot = self.parent() + plot.show() + plot.raise_() + + self._eventLoop = qt.QEventLoop() + self._eventLoop.exec_() + self._eventLoop = None + + self.stop() + + rois = self.getRois() + self.clear() + return rois + + def quit(self): + """Stop a blocking :meth:`exec_` and call :meth:`stop`""" + if self._eventLoop is not None: + self._eventLoop.quit() + self._eventLoop = None + self.stop() + + +class InteractiveRegionOfInterestManager(RegionOfInterestManager): + """RegionOfInterestManager with features for use from interpreter. + + It is meant to be used through the :meth:`exec_`. + It provides some messages to display in a status bar and + different modes to end blocking calls to :meth:`exec_`. + + :param parent: See QObject + """ + + sigMessageChanged = qt.Signal(str) + """Signal emitted when a new message should be displayed to the user + + It provides the message as a str. + """ + + def __init__(self, parent): + super(InteractiveRegionOfInterestManager, self).__init__(parent) + self._maxROI = None + self.__timeoutEndTime = None + self.__message = '' + self.__validationMode = self.ValidationMode.ENTER + self.__execClass = None + + self.sigRoiAdded.connect(self.__added) + self.sigRoiAboutToBeRemoved.connect(self.__aboutToBeRemoved) + self.sigInteractiveModeStarted.connect(self.__started) + self.sigInteractiveModeFinished.connect(self.__finished) + + # Max ROI + + def getMaxRois(self): + """Returns the maximum number of ROIs or None if no limit. + + :rtype: Union[int,None] + """ + return self._maxROI + + def setMaxRois(self, max_): + """Set the maximum number of ROIs. + + :param Union[int,None] max_: The max limit or None for no limit. + :raise ValueError: If there is more ROIs than max value + """ + if max_ is not None: + max_ = int(max_) + if max_ <= 0: + raise ValueError('Max limit must be strictly positive') + + if len(self.getRois()) > max_: + raise ValueError( + 'Cannot set max limit: Already too many ROIs') + + self._maxROI = max_ + + def isMaxRois(self): + """Returns True if the maximum number of ROIs is reached. + + :rtype: bool + """ + max_ = self.getMaxRois() + return max_ is not None and len(self.getRois()) >= max_ + + # Validation mode + + @enum.unique + class ValidationMode(enum.Enum): + """Mode of validation to leave blocking :meth:`exec_`""" + + AUTO = 'auto' + """Automatically ends the interactive mode once + the user terminates the last ROI shape.""" + + ENTER = 'enter' + """Ends the interactive mode when the *Enter* key is pressed.""" + + AUTO_ENTER = 'auto_enter' + """Ends the interactive mode when reaching max ROIs or + when the *Enter* key is pressed. + """ + + NONE = 'none' + """Do not provide the user a way to end the interactive mode. + + The end of :meth:`exec_` is done through :meth:`quit` or timeout. + """ + + def getValidationMode(self): + """Returns the interactive mode validation in use. + + :rtype: ValidationMode + """ + return self.__validationMode + + def setValidationMode(self, mode): + """Set the way to perform interactive mode validation. + + See :class:`ValidationMode` enumeration for the supported + validation modes. + + :param ValidationMode mode: The interactive mode validation to use. + """ + assert isinstance(mode, self.ValidationMode) + if mode != self.__validationMode: + self.__validationMode = mode + + if self.isExec(): + if (self.isMaxRois() and self.getValidationMode() in + (self.ValidationMode.AUTO, + self.ValidationMode.AUTO_ENTER)): + self.quit() + + self.__updateMessage() + + def eventFilter(self, obj, event): + if event.type() == qt.QEvent.Hide: + self.quit() + + if event.type() == qt.QEvent.KeyPress: + key = event.key() + if (key in (qt.Qt.Key_Return, qt.Qt.Key_Enter) and + self.getValidationMode() in ( + self.ValidationMode.ENTER, + self.ValidationMode.AUTO_ENTER)): + # Stop on return key pressed + self.quit() + return True # Stop further handling of this keys + + if (key in (qt.Qt.Key_Delete, qt.Qt.Key_Backspace) or ( + key == qt.Qt.Key_Z and + event.modifiers() & qt.Qt.ControlModifier)): + rois = self.getRois() + if rois: # Something to undo + self.removeRoi(rois[-1]) + # Stop further handling of keys if something was undone + return True + + return super(InteractiveRegionOfInterestManager, self).eventFilter(obj, event) + + # Message API + + def getMessage(self): + """Returns the current status message. + + This message is meant to be displayed in a status bar. + + :rtype: str + """ + if self.__timeoutEndTime is None: + return self.__message + else: + remaining = self.__timeoutEndTime - time.time() + return self.__message + (' - %d seconds remaining' % + max(1, int(remaining))) + + # Listen to ROI updates + + def __added(self, *args, **kwargs): + """Handle new ROI added""" + max_ = self.getMaxRois() + if max_ is not None: + # When reaching max number of ROIs, redo last one + while len(self.getRois()) > max_: + self.removeRoi(self.getRois()[-2]) + + self.__updateMessage() + if (self.isMaxRois() and + self.getValidationMode() in (self.ValidationMode.AUTO, + self.ValidationMode.AUTO_ENTER)): + self.quit() + + def __aboutToBeRemoved(self, *args, **kwargs): + """Handle removal of a ROI""" + # RegionOfInterest not removed yet + self.__updateMessage(nbrois=len(self.getRois()) - 1) + + def __started(self, roiKind): + """Handle interactive mode started""" + self.__updateMessage() + + def __finished(self): + """Handle interactive mode finished""" + self.__updateMessage() + + def __updateMessage(self, nbrois=None): + """Update message""" + if not self.isExec(): + message = 'Done' + + elif not self.isStarted(): + message = 'Use %s ROI edition mode' % self.__execClass + + else: + if nbrois is None: + nbrois = len(self.getRois()) + + kind = self.__execClass._getKind() + max_ = self.getMaxRois() + + if max_ is None: + message = 'Select %ss (%d selected)' % (kind, nbrois) + + elif max_ <= 1: + message = 'Select a %s' % kind + else: + message = 'Select %d/%d %ss' % (nbrois, max_, kind) + + if (self.getValidationMode() == self.ValidationMode.ENTER and + self.isMaxRois()): + message += ' - Press Enter to confirm' + + if message != self.__message: + self.__message = message + # Use getMessage to add timeout message + self.sigMessageChanged.emit(self.getMessage()) + + # Handle blocking call + + def __timeoutUpdate(self): + """Handle update of timeout""" + if (self.__timeoutEndTime is not None and + (self.__timeoutEndTime - time.time()) > 0): + self.sigMessageChanged.emit(self.getMessage()) + else: # Stop interactive mode and message timer + timer = self.sender() + if timer is not None: + timer.stop() + self.__timeoutEndTime = None + self.quit() + + def isExec(self): + """Returns True if :meth:`exec_` is currently running. + + :rtype: bool""" + return self.__execClass is not None + + def exec_(self, roiClass, timeout=0): + """Block until ROI selection is done or timeout is elapsed. + + :meth:`quit` also ends this blocking call. + + :param class roiClass: The class of the ROI which have to be created. + See `silx.gui.plot.items.roi`. + :param int timeout: Maximum duration in seconds to block. + Default: No timeout + :return: The list of ROIs + :rtype: List[RegionOfInterest] + """ + plot = self.parent() + if plot is None: + return + + self.__execClass = roiClass + + plot.installEventFilter(self) + + if timeout > 0: + self.__timeoutEndTime = time.time() + timeout + timer = qt.QTimer(self) + timer.timeout.connect(self.__timeoutUpdate) + timer.start(1000) + + rois = super(InteractiveRegionOfInterestManager, self).exec_(roiClass) + + timer.stop() + self.__timeoutEndTime = None + + else: + rois = super(InteractiveRegionOfInterestManager, self).exec_(roiClass) + + plot.removeEventFilter(self) + + self.__execClass = None + self.__updateMessage() + + return rois + + +class _DeleteRegionOfInterestToolButton(qt.QToolButton): + """Tool button deleting a ROI object + + :param parent: See QWidget + :param RegionOfInterest roi: The ROI to delete + """ + + def __init__(self, parent, roi): + super(_DeleteRegionOfInterestToolButton, self).__init__(parent) + self.setIcon(icons.getQIcon('remove')) + self.setToolTip("Remove this ROI") + self.__roiRef = roi if roi is None else weakref.ref(roi) + self.clicked.connect(self.__clicked) + + def __clicked(self, checked): + """Handle button clicked""" + roi = None if self.__roiRef is None else self.__roiRef() + if roi is not None: + manager = roi.parent() + if manager is not None: + manager.removeRoi(roi) + self.__roiRef = None + + +class RegionOfInterestTableWidget(qt.QTableWidget): + """Widget displaying the ROIs of a :class:`RegionOfInterestManager`""" + + def __init__(self, parent=None): + super(RegionOfInterestTableWidget, self).__init__(parent) + self._roiManagerRef = None + + self.setColumnCount(5) + self.setHorizontalHeaderLabels( + ['Label', 'Edit', 'Kind', 'Coordinates', '']) + + horizontalHeader = self.horizontalHeader() + horizontalHeader.setDefaultAlignment(qt.Qt.AlignLeft) + if hasattr(horizontalHeader, 'setResizeMode'): # Qt 4 + setSectionResizeMode = horizontalHeader.setResizeMode + else: # Qt5 + setSectionResizeMode = horizontalHeader.setSectionResizeMode + + setSectionResizeMode(0, qt.QHeaderView.Interactive) + setSectionResizeMode(1, qt.QHeaderView.ResizeToContents) + setSectionResizeMode(2, qt.QHeaderView.ResizeToContents) + setSectionResizeMode(3, qt.QHeaderView.Stretch) + setSectionResizeMode(4, qt.QHeaderView.ResizeToContents) + + verticalHeader = self.verticalHeader() + verticalHeader.setVisible(False) + + self.setSelectionMode(qt.QAbstractItemView.NoSelection) + self.setFocusPolicy(qt.Qt.NoFocus) + + self.itemChanged.connect(self.__itemChanged) + + @staticmethod + def __itemChanged(item): + """Handle item updates""" + column = item.column() + roi = item.data(qt.Qt.UserRole) + if column == 0: + roi.setLabel(item.text()) + elif column == 1: + roi.setEditable( + item.checkState() == qt.Qt.Checked) + elif column in (2, 3, 4): + pass # TODO + else: + logger.error('Unhandled column %d', column) + + def setRegionOfInterestManager(self, manager): + """Set the :class:`RegionOfInterestManager` object to sync with + + :param RegionOfInterestManager manager: + """ + assert manager is None or isinstance(manager, RegionOfInterestManager) + + previousManager = self.getRegionOfInterestManager() + + if previousManager is not None: + previousManager.sigRoiChanged.disconnect(self._sync) + self.setRowCount(0) + + self._roiManagerRef = weakref.ref(manager) + + self._sync() + + if manager is not None: + manager.sigRoiChanged.connect(self._sync) + + def _getReadableRoiDescription(self, roi): + """Returns modelisation of a ROI as a readable sequence of values. + + :rtype: str + """ + text = str(roi) + try: + # Extract the params from syntax "CLASSNAME(PARAMS)" + elements = text.split("(", 1) + if len(elements) != 2: + return text + result = elements[1] + result = result.strip() + if not result.endswith(")"): + return text + result = result[0:-1] + # Capitalize each words + result = result.title() + return result + except Exception: + logger.debug("Backtrace", exc_info=True) + return text + + def _sync(self): + """Update widget content according to ROI manger""" + manager = self.getRegionOfInterestManager() + + if manager is None: + self.setRowCount(0) + return + + rois = manager.getRois() + + self.setRowCount(len(rois)) + for index, roi in enumerate(rois): + baseFlags = qt.Qt.ItemIsSelectable | qt.Qt.ItemIsEnabled + + # Label + label = roi.getLabel() + item = qt.QTableWidgetItem(label) + item.setFlags(baseFlags | qt.Qt.ItemIsEditable) + item.setData(qt.Qt.UserRole, roi) + self.setItem(index, 0, item) + + # Editable + item = qt.QTableWidgetItem() + item.setFlags(baseFlags | qt.Qt.ItemIsUserCheckable) + item.setData(qt.Qt.UserRole, roi) + item.setCheckState( + qt.Qt.Checked if roi.isEditable() else qt.Qt.Unchecked) + self.setItem(index, 1, item) + item.setTextAlignment(qt.Qt.AlignCenter) + item.setText(None) + + # Kind + label = roi._getKind() + if label is None: + # Default value if kind is not overrided + label = roi.__class__.__name__ + item = qt.QTableWidgetItem(label.capitalize()) + item.setFlags(baseFlags) + self.setItem(index, 2, item) + + item = qt.QTableWidgetItem() + item.setFlags(baseFlags) + + # Coordinates + text = self._getReadableRoiDescription(roi) + item.setText(text) + self.setItem(index, 3, item) + + # Delete + delBtn = _DeleteRegionOfInterestToolButton(None, roi) + widget = qt.QWidget(self) + layout = qt.QHBoxLayout() + layout.setContentsMargins(2, 2, 2, 2) + layout.setSpacing(0) + widget.setLayout(layout) + layout.addStretch(1) + layout.addWidget(delBtn) + layout.addStretch(1) + self.setCellWidget(index, 4, widget) + + def getRegionOfInterestManager(self): + """Returns the :class:`RegionOfInterestManager` this widget supervise. + + It returns None if not sync with an :class:`RegionOfInterestManager`. + + :rtype: RegionOfInterestManager + """ + return None if self._roiManagerRef is None else self._roiManagerRef() diff --git a/silx/gui/plot/tools/test/__init__.py b/silx/gui/plot/tools/test/__init__.py new file mode 100644 index 0000000..79301ab --- /dev/null +++ b/silx/gui/plot/tools/test/__init__.py @@ -0,0 +1,48 @@ +# coding: utf-8 +# /*########################################################################## +# +# Copyright (c) 2018 European Synchrotron Radiation Facility +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. +# +# ###########################################################################*/ +__authors__ = ["T. Vincent"] +__license__ = "MIT" +__date__ = "26/03/2018" + + +import unittest + +from . import testROI +from . import testTools +from . import testScatterProfileToolBar + + +def suite(): + test_suite = unittest.TestSuite() + test_suite.addTests( + [testROI.suite(), + testTools.suite(), + testScatterProfileToolBar.suite(), + ]) + return test_suite + + +if __name__ == '__main__': + unittest.main(defaultTest='suite') diff --git a/silx/gui/plot/tools/test/testROI.py b/silx/gui/plot/tools/test/testROI.py new file mode 100644 index 0000000..5032036 --- /dev/null +++ b/silx/gui/plot/tools/test/testROI.py @@ -0,0 +1,456 @@ +# coding: utf-8 +# /*########################################################################## +# +# Copyright (c) 2018 European Synchrotron Radiation Facility +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. +# +# ###########################################################################*/ +__authors__ = ["T. Vincent"] +__license__ = "MIT" +__date__ = "28/06/2018" + + +import unittest +import numpy.testing + +from silx.gui import qt +from silx.utils.testutils import ParametricTestCase +from silx.gui.test.utils import TestCaseQt, SignalListener +from silx.gui.plot import PlotWindow +import silx.gui.plot.items.roi as roi_items +from silx.gui.plot.tools import roi + + +class TestRoiItems(TestCaseQt): + + def testLine_geometry(self): + item = roi_items.LineROI() + startPoint = numpy.array([1, 2]) + endPoint = numpy.array([3, 4]) + item.setEndPoints(startPoint, endPoint) + numpy.testing.assert_allclose(item.getEndPoints()[0], startPoint) + numpy.testing.assert_allclose(item.getEndPoints()[1], endPoint) + + def testHLine_geometry(self): + item = roi_items.HorizontalLineROI() + item.setPosition(15) + self.assertEqual(item.getPosition(), 15) + + def testVLine_geometry(self): + item = roi_items.VerticalLineROI() + item.setPosition(15) + self.assertEqual(item.getPosition(), 15) + + def testPoint_geometry(self): + point = numpy.array([1, 2]) + item = roi_items.VerticalLineROI() + item.setPosition(point) + numpy.testing.assert_allclose(item.getPosition(), point) + + def testRectangle_originGeometry(self): + origin = numpy.array([0, 0]) + size = numpy.array([10, 20]) + center = numpy.array([5, 10]) + item = roi_items.RectangleROI() + item.setGeometry(origin=origin, size=size) + numpy.testing.assert_allclose(item.getOrigin(), origin) + numpy.testing.assert_allclose(item.getSize(), size) + numpy.testing.assert_allclose(item.getCenter(), center) + + def testRectangle_centerGeometry(self): + origin = numpy.array([0, 0]) + size = numpy.array([10, 20]) + center = numpy.array([5, 10]) + item = roi_items.RectangleROI() + item.setGeometry(center=center, size=size) + numpy.testing.assert_allclose(item.getOrigin(), origin) + numpy.testing.assert_allclose(item.getSize(), size) + numpy.testing.assert_allclose(item.getCenter(), center) + + def testRectangle_setCenterGeometry(self): + origin = numpy.array([0, 0]) + size = numpy.array([10, 20]) + item = roi_items.RectangleROI() + item.setGeometry(origin=origin, size=size) + newCenter = numpy.array([0, 0]) + item.setCenter(newCenter) + expectedOrigin = numpy.array([-5, -10]) + numpy.testing.assert_allclose(item.getOrigin(), expectedOrigin) + numpy.testing.assert_allclose(item.getCenter(), newCenter) + numpy.testing.assert_allclose(item.getSize(), size) + + def testRectangle_setOriginGeometry(self): + origin = numpy.array([0, 0]) + size = numpy.array([10, 20]) + item = roi_items.RectangleROI() + item.setGeometry(origin=origin, size=size) + newOrigin = numpy.array([10, 10]) + item.setOrigin(newOrigin) + expectedCenter = numpy.array([15, 20]) + numpy.testing.assert_allclose(item.getOrigin(), newOrigin) + numpy.testing.assert_allclose(item.getCenter(), expectedCenter) + numpy.testing.assert_allclose(item.getSize(), size) + + def testPolygon_emptyGeometry(self): + points = numpy.empty((0, 2)) + item = roi_items.PolygonROI() + item.setPoints(points) + numpy.testing.assert_allclose(item.getPoints(), points) + + def testPolygon_geometry(self): + points = numpy.array([[10, 10], [12, 10], [50, 1]]) + item = roi_items.PolygonROI() + item.setPoints(points) + numpy.testing.assert_allclose(item.getPoints(), points) + + def testArc_getToSetGeometry(self): + """Test that we can use getGeometry as input to setGeometry""" + item = roi_items.ArcROI() + item.setFirstShapePoints(numpy.array([[5, 10], [50, 100]])) + item.setGeometry(*item.getGeometry()) + + def testArc_degenerated_point(self): + item = roi_items.ArcROI() + center = numpy.array([10, 20]) + innerRadius, outerRadius, startAngle, endAngle = 0, 0, 0, 0 + item.setGeometry(center, innerRadius, outerRadius, startAngle, endAngle) + + def testArc_degenerated_line(self): + item = roi_items.ArcROI() + center = numpy.array([10, 20]) + innerRadius, outerRadius, startAngle, endAngle = 0, 100, numpy.pi, numpy.pi + item.setGeometry(center, innerRadius, outerRadius, startAngle, endAngle) + + def testArc_special_circle(self): + item = roi_items.ArcROI() + center = numpy.array([10, 20]) + innerRadius, outerRadius, startAngle, endAngle = 0, 100, numpy.pi, 3 * numpy.pi + item.setGeometry(center, innerRadius, outerRadius, startAngle, endAngle) + numpy.testing.assert_allclose(item.getCenter(), center) + self.assertAlmostEqual(item.getInnerRadius(), innerRadius) + self.assertAlmostEqual(item.getOuterRadius(), outerRadius) + self.assertAlmostEqual(item.getStartAngle(), item.getEndAngle() - numpy.pi * 2.0) + self.assertAlmostEqual(item.isClosed(), True) + + def testArc_special_donut(self): + item = roi_items.ArcROI() + center = numpy.array([10, 20]) + innerRadius, outerRadius, startAngle, endAngle = 1, 100, numpy.pi, 3 * numpy.pi + item.setGeometry(center, innerRadius, outerRadius, startAngle, endAngle) + numpy.testing.assert_allclose(item.getCenter(), center) + self.assertAlmostEqual(item.getInnerRadius(), innerRadius) + self.assertAlmostEqual(item.getOuterRadius(), outerRadius) + self.assertAlmostEqual(item.getStartAngle(), item.getEndAngle() - numpy.pi * 2.0) + self.assertAlmostEqual(item.isClosed(), True) + + def testArc_clockwiseGeometry(self): + """Test that we can use getGeometry as input to setGeometry""" + item = roi_items.ArcROI() + center = numpy.array([10, 20]) + innerRadius, outerRadius, startAngle, endAngle = 1, 100, numpy.pi * 0.5, numpy.pi + item.setGeometry(center, innerRadius, outerRadius, startAngle, endAngle) + numpy.testing.assert_allclose(item.getCenter(), center) + self.assertAlmostEqual(item.getInnerRadius(), innerRadius) + self.assertAlmostEqual(item.getOuterRadius(), outerRadius) + self.assertAlmostEqual(item.getStartAngle(), startAngle) + self.assertAlmostEqual(item.getEndAngle(), endAngle) + self.assertAlmostEqual(item.isClosed(), False) + + def testArc_anticlockwiseGeometry(self): + """Test that we can use getGeometry as input to setGeometry""" + item = roi_items.ArcROI() + center = numpy.array([10, 20]) + innerRadius, outerRadius, startAngle, endAngle = 1, 100, numpy.pi * 0.5, -numpy.pi * 0.5 + item.setGeometry(center, innerRadius, outerRadius, startAngle, endAngle) + numpy.testing.assert_allclose(item.getCenter(), center) + self.assertAlmostEqual(item.getInnerRadius(), innerRadius) + self.assertAlmostEqual(item.getOuterRadius(), outerRadius) + self.assertAlmostEqual(item.getStartAngle(), startAngle) + self.assertAlmostEqual(item.getEndAngle(), endAngle) + self.assertAlmostEqual(item.isClosed(), False) + + +class TestRegionOfInterestManager(TestCaseQt, ParametricTestCase): + """Tests for RegionOfInterestManager class""" + + def setUp(self): + super(TestRegionOfInterestManager, self).setUp() + self.plot = PlotWindow() + + self.roiTableWidget = roi.RegionOfInterestTableWidget() + dock = qt.QDockWidget() + dock.setWidget(self.roiTableWidget) + self.plot.addDockWidget(qt.Qt.BottomDockWidgetArea, dock) + + self.plot.show() + self.qWaitForWindowExposed(self.plot) + + def tearDown(self): + del self.roiTableWidget + self.qapp.processEvents() + self.plot.setAttribute(qt.Qt.WA_DeleteOnClose) + self.plot.close() + del self.plot + super(TestRegionOfInterestManager, self).tearDown() + + def test(self): + """Test ROI of different shapes""" + tests = ( # shape, points=[list of (x, y), list of (x, y)] + (roi_items.PointROI, numpy.array(([(10., 15.)], [(20., 25.)]))), + (roi_items.RectangleROI, + numpy.array((((1., 10.), (11., 20.)), + ((2., 3.), (12., 13.))))), + (roi_items.PolygonROI, + numpy.array((((0., 1.), (0., 10.), (10., 0.)), + ((5., 6.), (5., 16.), (15., 6.))))), + (roi_items.LineROI, + numpy.array((((10., 20.), (10., 30.)), + ((30., 40.), (30., 50.))))), + (roi_items.HorizontalLineROI, + numpy.array((((10., 20.), (10., 30.)), + ((30., 40.), (30., 50.))))), + (roi_items.VerticalLineROI, + numpy.array((((10., 20.), (10., 30.)), + ((30., 40.), (30., 50.))))), + ) + + for roiClass, points in tests: + with self.subTest(roiClass=roiClass): + manager = roi.RegionOfInterestManager(self.plot) + self.roiTableWidget.setRegionOfInterestManager(manager) + manager.start(roiClass) + + self.assertEqual(manager.getRois(), ()) + + finishListener = SignalListener() + manager.sigInteractiveModeFinished.connect(finishListener) + + changedListener = SignalListener() + manager.sigRoiChanged.connect(changedListener) + + # Add a point + manager.createRoi(roiClass, points[0]) + self.qapp.processEvents() + self.assertTrue(len(manager.getRois()), 1) + self.assertEqual(changedListener.callCount(), 1) + + # Remove it + manager.removeRoi(manager.getRois()[0]) + self.assertEqual(manager.getRois(), ()) + self.assertEqual(changedListener.callCount(), 2) + + # Add two point + manager.createRoi(roiClass, points[0]) + self.qapp.processEvents() + manager.createRoi(roiClass, points[1]) + self.qapp.processEvents() + self.assertTrue(len(manager.getRois()), 2) + self.assertEqual(changedListener.callCount(), 4) + + # Reset it + result = manager.clear() + self.assertTrue(result) + self.assertEqual(manager.getRois(), ()) + self.assertEqual(changedListener.callCount(), 5) + + changedListener.clear() + + # Add two point + manager.createRoi(roiClass, points[0]) + self.qapp.processEvents() + manager.createRoi(roiClass, points[1]) + self.qapp.processEvents() + self.assertTrue(len(manager.getRois()), 2) + self.assertEqual(changedListener.callCount(), 2) + + # stop + result = manager.stop() + self.assertTrue(result) + self.assertTrue(len(manager.getRois()), 1) + self.qapp.processEvents() + self.assertEqual(finishListener.callCount(), 1) + + manager.clear() + + def testRoiDisplay(self): + rois = [] + + # Line + item = roi_items.LineROI() + startPoint = numpy.array([1, 2]) + endPoint = numpy.array([3, 4]) + item.setEndPoints(startPoint, endPoint) + rois.append(item) + # Horizontal line + item = roi_items.HorizontalLineROI() + item.setPosition(15) + rois.append(item) + # Vertical line + item = roi_items.VerticalLineROI() + item.setPosition(15) + rois.append(item) + # Point + item = roi_items.PointROI() + point = numpy.array([1, 2]) + item.setPosition(point) + rois.append(item) + # Rectangle + item = roi_items.RectangleROI() + origin = numpy.array([0, 0]) + size = numpy.array([10, 20]) + item.setGeometry(origin=origin, size=size) + rois.append(item) + # Polygon + item = roi_items.PolygonROI() + points = numpy.array([[10, 10], [12, 10], [50, 1]]) + item.setPoints(points) + rois.append(item) + # Degenerated polygon: No points + item = roi_items.PolygonROI() + points = numpy.empty((0, 2)) + item.setPoints(points) + rois.append(item) + # Degenerated polygon: A single point + item = roi_items.PolygonROI() + points = numpy.array([[5, 10]]) + item.setPoints(points) + rois.append(item) + # Degenerated arc: it's a point + item = roi_items.ArcROI() + center = numpy.array([10, 20]) + innerRadius, outerRadius, startAngle, endAngle = 0, 0, 0, 0 + item.setGeometry(center, innerRadius, outerRadius, startAngle, endAngle) + rois.append(item) + # Degenerated arc: it's a line + item = roi_items.ArcROI() + center = numpy.array([10, 20]) + innerRadius, outerRadius, startAngle, endAngle = 0, 100, numpy.pi, numpy.pi + item.setGeometry(center, innerRadius, outerRadius, startAngle, endAngle) + rois.append(item) + # Special arc: it's a donut + item = roi_items.ArcROI() + center = numpy.array([10, 20]) + innerRadius, outerRadius, startAngle, endAngle = 1, 100, numpy.pi, 3 * numpy.pi + item.setGeometry(center, innerRadius, outerRadius, startAngle, endAngle) + rois.append(item) + # Arc + item = roi_items.ArcROI() + center = numpy.array([10, 20]) + innerRadius, outerRadius, startAngle, endAngle = 1, 100, numpy.pi * 0.5, numpy.pi + item.setGeometry(center, innerRadius, outerRadius, startAngle, endAngle) + rois.append(item) + + manager = roi.RegionOfInterestManager(self.plot) + self.roiTableWidget.setRegionOfInterestManager(manager) + for item in rois: + with self.subTest(roi=str(item)): + manager.addRoi(item) + self.qapp.processEvents() + item.setEditable(True) + self.qapp.processEvents() + item.setEditable(False) + self.qapp.processEvents() + manager.removeRoi(item) + self.qapp.processEvents() + + def testMaxROI(self): + """Test Max ROI""" + origin1 = numpy.array([1., 10.]) + size1 = numpy.array([10., 10.]) + origin2 = numpy.array([2., 3.]) + size2 = numpy.array([10., 10.]) + + manager = roi.InteractiveRegionOfInterestManager(self.plot) + self.roiTableWidget.setRegionOfInterestManager(manager) + self.assertEqual(manager.getRois(), ()) + + changedListener = SignalListener() + manager.sigRoiChanged.connect(changedListener) + + # Add two point + item = roi_items.RectangleROI() + item.setGeometry(origin=origin1, size=size1) + manager.addRoi(item) + item = roi_items.RectangleROI() + item.setGeometry(origin=origin2, size=size2) + manager.addRoi(item) + self.qapp.processEvents() + self.assertEqual(changedListener.callCount(), 2) + self.assertEqual(len(manager.getRois()), 2) + + # Try to set max ROI to 1 while there is 2 ROIs + with self.assertRaises(ValueError): + manager.setMaxRois(1) + + manager.clear() + self.assertEqual(len(manager.getRois()), 0) + self.assertEqual(changedListener.callCount(), 3) + + # Set max limit to 1 + manager.setMaxRois(1) + + # Add a point + item = roi_items.RectangleROI() + item.setGeometry(origin=origin1, size=size1) + manager.addRoi(item) + self.qapp.processEvents() + self.assertEqual(changedListener.callCount(), 4) + + # Add a 2nd point while max ROI is 1 + item = roi_items.RectangleROI() + item.setGeometry(origin=origin1, size=size1) + manager.addRoi(item) + self.qapp.processEvents() + self.assertEqual(changedListener.callCount(), 6) + self.assertEqual(len(manager.getRois()), 1) + + def testChangeInteractionMode(self): + """Test change of interaction mode""" + manager = roi.RegionOfInterestManager(self.plot) + self.roiTableWidget.setRegionOfInterestManager(manager) + manager.start(roi_items.PointROI) + + interactiveModeToolBar = self.plot.getInteractiveModeToolBar() + panAction = interactiveModeToolBar.getPanModeAction() + + for roiClass in manager.getSupportedRoiClasses(): + with self.subTest(roiClass=roiClass): + # Change to pan mode + panAction.trigger() + + # Change to interactive ROI mode + action = manager.getInteractionModeAction(roiClass) + action.trigger() + + self.assertEqual(roiClass, manager.getCurrentInteractionModeRoiClass()) + + manager.clear() + + +def suite(): + test_suite = unittest.TestSuite() + loadTests = unittest.defaultTestLoader.loadTestsFromTestCase + test_suite.addTest(loadTests(TestRoiItems)) + test_suite.addTest(loadTests(TestRegionOfInterestManager)) + return test_suite + + +if __name__ == '__main__': + unittest.main(defaultTest='suite') diff --git a/silx/gui/plot/tools/test/testScatterProfileToolBar.py b/silx/gui/plot/tools/test/testScatterProfileToolBar.py new file mode 100644 index 0000000..16972f9 --- /dev/null +++ b/silx/gui/plot/tools/test/testScatterProfileToolBar.py @@ -0,0 +1,216 @@ +# coding: utf-8 +# /*########################################################################## +# +# Copyright (c) 2018 European Synchrotron Radiation Facility +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. +# +# ###########################################################################*/ +__authors__ = ["T. Vincent"] +__license__ = "MIT" +__date__ = "28/06/2018" + + +import unittest +import numpy + +from silx.gui import qt +from silx.utils.testutils import ParametricTestCase +from silx.gui.test.utils import TestCaseQt +from silx.gui.plot import PlotWindow +from silx.gui.plot.tools import profile +import silx.gui.plot.items.roi as roi_items + + +class TestScatterProfileToolBar(TestCaseQt, ParametricTestCase): + """Tests for ScatterProfileToolBar class""" + + def setUp(self): + super(TestScatterProfileToolBar, self).setUp() + self.plot = PlotWindow() + + self.profile = profile.ScatterProfileToolBar(plot=self.plot) + + self.plot.addToolBar(self.profile) + + self.plot.show() + self.qWaitForWindowExposed(self.plot) + + def tearDown(self): + del self.profile + self.qapp.processEvents() + self.plot.setAttribute(qt.Qt.WA_DeleteOnClose) + self.plot.close() + del self.plot + super(TestScatterProfileToolBar, self).tearDown() + + def testNoProfile(self): + """Test ScatterProfileToolBar without profile""" + self.assertEqual(self.profile.getPlotWidget(), self.plot) + + # Add a scatter plot + self.plot.addScatter( + x=(0., 1., 1., 0.), y=(0., 0., 1., 1.), value=(0., 1., 2., 3.)) + self.plot.resetZoom(dataMargins=(.1, .1, .1, .1)) + self.qapp.processEvents() + + # Check that there is no profile + self.assertIsNone(self.profile.getProfileValues()) + self.assertIsNone(self.profile.getProfilePoints()) + + def testHorizontalProfile(self): + """Test ScatterProfileToolBar horizontal profile""" + nPoints = 8 + self.profile.setNPoints(nPoints) + self.assertEqual(self.profile.getNPoints(), nPoints) + + # Add a scatter plot + self.plot.addScatter( + x=(0., 1., 1., 0.), y=(0., 0., 1., 1.), value=(0., 1., 2., 3.)) + self.plot.resetZoom(dataMargins=(.1, .1, .1, .1)) + self.qapp.processEvents() + + # Activate Horizontal profile + hlineAction = self.profile.actions()[0] + hlineAction.trigger() + self.qapp.processEvents() + + # Set a ROI profile + roi = roi_items.HorizontalLineROI() + roi.setPosition(0.5) + self.profile._getRoiManager().addRoi(roi) + + # Wait for async interpolator init + for _ in range(10): + self.qWait(200) + if not self.profile.hasPendingOperations(): + break + + self.assertIsNotNone(self.profile.getProfileValues()) + points = self.profile.getProfilePoints() + self.assertEqual(len(points), nPoints) + + # Check that profile has same limits than Plot + xLimits = self.plot.getXAxis().getLimits() + self.assertEqual(points[0, 0], xLimits[0]) + self.assertEqual(points[-1, 0], xLimits[1]) + + # Clear the profile + clearAction = self.profile.actions()[-1] + clearAction.trigger() + self.qapp.processEvents() + + self.assertIsNone(self.profile.getProfileValues()) + self.assertIsNone(self.profile.getProfilePoints()) + self.assertEqual(self.profile.getProfileTitle(), '') + + def testVerticalProfile(self): + """Test ScatterProfileToolBar vertical profile""" + nPoints = 8 + self.profile.setNPoints(nPoints) + self.assertEqual(self.profile.getNPoints(), nPoints) + + # Add a scatter plot + self.plot.addScatter( + x=(0., 1., 1., 0.), y=(0., 0., 1., 1.), value=(0., 1., 2., 3.)) + self.plot.resetZoom(dataMargins=(.1, .1, .1, .1)) + self.qapp.processEvents() + + # Activate vertical profile + vlineAction = self.profile.actions()[1] + vlineAction.trigger() + self.qapp.processEvents() + + # Set a ROI profile + roi = roi_items.VerticalLineROI() + roi.setPosition(0.5) + self.profile._getRoiManager().addRoi(roi) + + # Wait for async interpolator init + for _ in range(10): + self.qWait(200) + if not self.profile.hasPendingOperations(): + break + + self.assertIsNotNone(self.profile.getProfileValues()) + points = self.profile.getProfilePoints() + self.assertEqual(len(points), nPoints) + + # Check that profile has same limits than Plot + yLimits = self.plot.getYAxis().getLimits() + self.assertEqual(points[0, 1], yLimits[0]) + self.assertEqual(points[-1, 1], yLimits[1]) + + # Check that profile limits are updated when changing limits + self.plot.getYAxis().setLimits(yLimits[0] + 1, yLimits[1] + 10) + self.qapp.processEvents() + yLimits = self.plot.getYAxis().getLimits() + points = self.profile.getProfilePoints() + self.assertEqual(points[0, 1], yLimits[0]) + self.assertEqual(points[-1, 1], yLimits[1]) + + # Clear the plot + self.plot.clear() + self.qapp.processEvents() + self.assertIsNone(self.profile.getProfileValues()) + self.assertIsNone(self.profile.getProfilePoints()) + + def testLineProfile(self): + """Test ScatterProfileToolBar line profile""" + nPoints = 8 + self.profile.setNPoints(nPoints) + self.assertEqual(self.profile.getNPoints(), nPoints) + + # Activate line profile + lineAction = self.profile.actions()[2] + lineAction.trigger() + self.qapp.processEvents() + + # Add a scatter plot + self.plot.addScatter( + x=(0., 1., 1., 0.), y=(0., 0., 1., 1.), value=(0., 1., 2., 3.)) + self.plot.resetZoom(dataMargins=(.1, .1, .1, .1)) + self.qapp.processEvents() + + # Set a ROI profile + roi = roi_items.LineROI() + roi.setEndPoints(numpy.array([0., 0.]), numpy.array([1., 1.])) + self.profile._getRoiManager().addRoi(roi) + + # Wait for async interpolator init + for _ in range(10): + self.qWait(200) + if not self.profile.hasPendingOperations(): + break + + self.assertIsNotNone(self.profile.getProfileValues()) + points = self.profile.getProfilePoints() + self.assertEqual(len(points), nPoints) + + +def suite(): + test_suite = unittest.TestSuite() + test_suite.addTest( + unittest.defaultTestLoader.loadTestsFromTestCase( + TestScatterProfileToolBar)) + return test_suite + + +if __name__ == '__main__': + unittest.main(defaultTest='suite') diff --git a/silx/gui/plot/test/testPlotTools.py b/silx/gui/plot/tools/test/testTools.py index 3d5849f..810b933 100644 --- a/silx/gui/plot/test/testPlotTools.py +++ b/silx/gui/plot/tools/test/testTools.py @@ -1,7 +1,7 @@ # coding: utf-8 # /*########################################################################## # -# Copyright (c) 2016-2017 European Synchrotron Radiation Facility +# Copyright (c) 2016-2018 European Synchrotron Radiation Facility # # Permission is hereby granted, free of charge, to any person obtaining a copy # of this software and associated documentation files (the "Software"), to deal @@ -22,22 +22,23 @@ # THE SOFTWARE. # # ###########################################################################*/ -"""Basic tests for PlotTools""" +"""Basic tests for silx.gui.plot.tools package""" __authors__ = ["T. Vincent"] __license__ = "MIT" -__date__ = "17/01/2018" +__date__ = "02/03/2018" -import numpy +import functools import unittest +import numpy -from silx.utils.testutils import ParametricTestCase, TestLogging -from silx.gui.test.utils import ( - qWaitForWindowExposedAndActivate, TestCaseQt, getQToolButtonFromAction) +from silx.utils.testutils import TestLogging +from silx.gui.test.utils import qWaitForWindowExposedAndActivate from silx.gui import qt -from silx.gui.plot import Plot2D, PlotWindow, PlotTools -from .utils import PlotWidgetTestCase +from silx.gui.plot import PlotWindow +from silx.gui.plot import tools +from silx.gui.plot.test.utils import PlotWidgetTestCase # Makes sure a QApplication exists @@ -99,7 +100,7 @@ class TestPositionInfo(PlotWidgetTestCase): for index, name in enumerate(converterNames): self.assertEqual(converters[index][0], name) - with TestLogging(PlotTools.__name__, **kwargs): + with TestLogging(tools.__name__, **kwargs): # Move mouse to center center = self.plot.size() / 2 self.mouseMove(self.plot, pos=(center.width(), center.height())) @@ -108,7 +109,7 @@ class TestPositionInfo(PlotWidgetTestCase): def testDefaultConverters(self): """Test PositionInfo with default converters""" - positionWidget = PlotTools.PositionInfo(plot=self.plot) + positionWidget = tools.PositionInfo(plot=self.plot) self._test(positionWidget, ('X', 'Y')) def testCustomConverters(self): @@ -118,8 +119,8 @@ class TestPositionInfo(PlotWidgetTestCase): ('Radius', lambda x, y: numpy.sqrt(x * x + y * y)), ('Angle', lambda x, y: numpy.degrees(numpy.arctan2(y, x))) ] - positionWidget = PlotTools.PositionInfo(plot=self.plot, - converters=converters) + positionWidget = tools.PositionInfo(plot=self.plot, + converters=converters) self._test(positionWidget, ('Coords', 'Radius', 'Angle')) def testFailingConverters(self): @@ -127,70 +128,44 @@ class TestPositionInfo(PlotWidgetTestCase): def raiseException(x, y): raise RuntimeError() - positionWidget = PlotTools.PositionInfo( + positionWidget = tools.PositionInfo( plot=self.plot, converters=[('Exception', raiseException)]) self._test(positionWidget, ['Exception'], error=2) + def testUpdate(self): + """Test :meth:`PositionInfo.updateInfo`""" + calls = [] -class TestPixelIntensitiesHisto(TestCaseQt, ParametricTestCase): - """Tests for ProfileToolBar widget.""" + def update(calls, x, y): # Get number of calls + calls.append((x, y)) + return len(calls) - def setUp(self): - super(TestPixelIntensitiesHisto, self).setUp() - self.image = numpy.random.rand(100, 100) - self.plotImage = Plot2D() - self.plotImage.getIntensityHistogramAction().setVisible(True) + positionWidget = tools.PositionInfo( + plot=self.plot, + converters=[('Call count', functools.partial(update, calls))]) - def tearDown(self): - del self.plotImage - super(TestPixelIntensitiesHisto, self).tearDown() - - def testShowAndHide(self): - """Simple test that the plot is showing and hiding when activating the - action""" - self.plotImage.addImage(self.image, origin=(0, 0), legend='sino') - self.plotImage.show() - - histoAction = self.plotImage.getIntensityHistogramAction() - - # test the pixel intensity diagram is showing - button = getQToolButtonFromAction(histoAction) - self.assertIsNot(button, None) - self.mouseMove(button) - self.mouseClick(button, qt.Qt.LeftButton) - self.qapp.processEvents() - self.assertTrue(histoAction.getHistogramPlotWidget().isVisible()) + positionWidget.updateInfo() + self.assertEqual(len(calls), 1) - # test the pixel intensity diagram is hiding - self.qapp.setActiveWindow(self.plotImage) - self.qapp.processEvents() - self.mouseMove(button) - self.mouseClick(button, qt.Qt.LeftButton) - self.qapp.processEvents() - self.assertFalse(histoAction.getHistogramPlotWidget().isVisible()) - - def testImageFormatInput(self): - """Test multiple type as image input""" - typesToTest = [numpy.uint8, numpy.int8, numpy.int16, numpy.int32, - numpy.float32, numpy.float64] - self.plotImage.addImage(self.image, origin=(0, 0), legend='sino') - self.plotImage.show() - button = getQToolButtonFromAction( - self.plotImage.getIntensityHistogramAction()) - self.mouseMove(button) - self.mouseClick(button, qt.Qt.LeftButton) - self.qapp.processEvents() - for typeToTest in typesToTest: - with self.subTest(typeToTest=typeToTest): - self.plotImage.addImage(self.image.astype(typeToTest), - origin=(0, 0), legend='sino') + +class TestPlotToolsToolbars(PlotWidgetTestCase): + """Tests toolbars from silx.gui.plot.tools""" + + def test(self): + """"Add all toolbars""" + for tbClass in (tools.InteractiveModeToolBar, + tools.ImageToolBar, + tools.CurveToolBar, + tools.OutputToolBar): + tb = tbClass(parent=self.plot, plot=self.plot) + self.plot.addToolBar(tb) def suite(): test_suite = unittest.TestSuite() # test_suite.addTest(positionInfoTestSuite) - for testClass in (TestPositionInfo, TestPixelIntensitiesHisto): + for testClass in (TestPositionInfo, TestPlotToolsToolbars): test_suite.addTest(unittest.defaultTestLoader.loadTestsFromTestCase( testClass)) return test_suite diff --git a/silx/gui/plot/tools/toolbars.py b/silx/gui/plot/tools/toolbars.py new file mode 100644 index 0000000..28fb7f9 --- /dev/null +++ b/silx/gui/plot/tools/toolbars.py @@ -0,0 +1,356 @@ +# coding: utf-8 +# /*########################################################################## +# +# Copyright (c) 2018 European Synchrotron Radiation Facility +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. +# +# ###########################################################################*/ +"""This module provides toolbars that work with :class:`PlotWidget`. +""" + +__authors__ = ["T. Vincent"] +__license__ = "MIT" +__date__ = "01/03/2018" + + +from ... import qt +from .. import actions +from ..PlotWidget import PlotWidget +from .. import PlotToolButtons + + +class InteractiveModeToolBar(qt.QToolBar): + """Toolbar with interactive mode actions + + :param parent: See :class:`QWidget` + :param silx.gui.plot.PlotWidget plot: PlotWidget to control + :param str title: Title of the toolbar. + """ + + def __init__(self, parent=None, plot=None, title='Plot Interaction'): + super(InteractiveModeToolBar, self).__init__(title, parent) + + assert isinstance(plot, PlotWidget) + + self._zoomModeAction = actions.mode.ZoomModeAction( + parent=self, plot=plot) + self.addAction(self._zoomModeAction) + + self._panModeAction = actions.mode.PanModeAction( + parent=self, plot=plot) + self.addAction(self._panModeAction) + + def getZoomModeAction(self): + """Returns the zoom mode QAction. + + :rtype: PlotAction + """ + return self._zoomModeAction + + def getPanModeAction(self): + """Returns the pan mode QAction + + :rtype: PlotAction + """ + return self._panModeAction + + +class OutputToolBar(qt.QToolBar): + """Toolbar providing icons to copy, save and print a PlotWidget + + :param parent: See :class:`QWidget` + :param silx.gui.plot.PlotWidget plot: PlotWidget to control + :param str title: Title of the toolbar. + """ + + def __init__(self, parent=None, plot=None, title='Plot Output'): + super(OutputToolBar, self).__init__(title, parent) + + assert isinstance(plot, PlotWidget) + + self._copyAction = actions.io.CopyAction(parent=self, plot=plot) + self.addAction(self._copyAction) + + self._saveAction = actions.io.SaveAction(parent=self, plot=plot) + self.addAction(self._saveAction) + + self._printAction = actions.io.PrintAction(parent=self, plot=plot) + self.addAction(self._printAction) + + def getCopyAction(self): + """Returns the QAction performing copy to clipboard of the PlotWidget + + :rtype: PlotAction + """ + return self._copyAction + + def getSaveAction(self): + """Returns the QAction performing save to file of the PlotWidget + + :rtype: PlotAction + """ + return self._saveAction + + def getPrintAction(self): + """Returns the QAction performing printing of the PlotWidget + + :rtype: PlotAction + """ + return self._printAction + + +class ImageToolBar(qt.QToolBar): + """Toolbar providing PlotAction suited when displaying images + + :param parent: See :class:`QWidget` + :param silx.gui.plot.PlotWidget plot: PlotWidget to control + :param str title: Title of the toolbar. + """ + + def __init__(self, parent=None, plot=None, title='Image'): + super(ImageToolBar, self).__init__(title, parent) + + assert isinstance(plot, PlotWidget) + + self._resetZoomAction = actions.control.ResetZoomAction( + parent=self, plot=plot) + self.addAction(self._resetZoomAction) + + self._colormapAction = actions.control.ColormapAction( + parent=self, plot=plot) + self.addAction(self._colormapAction) + + self._keepDataAspectRatioButton = PlotToolButtons.AspectToolButton( + parent=self, plot=plot) + self.addWidget(self._keepDataAspectRatioButton) + + self._yAxisInvertedButton = PlotToolButtons.YAxisOriginToolButton( + parent=self, plot=plot) + self.addWidget(self._yAxisInvertedButton) + + def getResetZoomAction(self): + """Returns the QAction to reset the zoom. + + :rtype: PlotAction + """ + return self._resetZoomAction + + def getColormapAction(self): + """Returns the QAction to control the colormap. + + :rtype: PlotAction + """ + return self._colormapAction + + def getKeepDataAspectRatioButton(self): + """Returns the QToolButton controlling data aspect ratio. + + :rtype: QToolButton + """ + return self._keepDataAspectRatioButton + + def getYAxisInvertedButton(self): + """Returns the QToolButton controlling Y axis orientation. + + :rtype: QToolButton + """ + return self._yAxisInvertedButton + + +class CurveToolBar(qt.QToolBar): + """Toolbar providing PlotAction suited when displaying curves + + :param parent: See :class:`QWidget` + :param silx.gui.plot.PlotWidget plot: PlotWidget to control + :param str title: Title of the toolbar. + """ + + def __init__(self, parent=None, plot=None, title='Image'): + super(CurveToolBar, self).__init__(title, parent) + + assert isinstance(plot, PlotWidget) + + self._resetZoomAction = actions.control.ResetZoomAction( + parent=self, plot=plot) + self.addAction(self._resetZoomAction) + + self._xAxisAutoScaleAction = actions.control.XAxisAutoScaleAction( + parent=self, plot=plot) + self.addAction(self._xAxisAutoScaleAction) + + self._yAxisAutoScaleAction = actions.control.YAxisAutoScaleAction( + parent=self, plot=plot) + self.addAction(self._yAxisAutoScaleAction) + + self._xAxisLogarithmicAction = actions.control.XAxisLogarithmicAction( + parent=self, plot=plot) + self.addAction(self._xAxisLogarithmicAction) + + self._yAxisLogarithmicAction = actions.control.YAxisLogarithmicAction( + parent=self, plot=plot) + self.addAction(self._yAxisLogarithmicAction) + + self._gridAction = actions.control.GridAction( + parent=self, plot=plot) + self.addAction(self._gridAction) + + self._curveStyleAction = actions.control.CurveStyleAction( + parent=self, plot=plot) + self.addAction(self._curveStyleAction) + + def getResetZoomAction(self): + """Returns the QAction to reset the zoom. + + :rtype: PlotAction + """ + return self._resetZoomAction + + def getXAxisAutoScaleAction(self): + """Returns the QAction to toggle X axis autoscale. + + :rtype: PlotAction + """ + return self._xAxisAutoScaleAction + + def getYAxisAutoScaleAction(self): + """Returns the QAction to toggle Y axis autoscale. + + :rtype: PlotAction + """ + return self._yAxisAutoScaleAction + + def getXAxisLogarithmicAction(self): + """Returns the QAction to toggle X axis log/linear scale. + + :rtype: PlotAction + """ + return self._xAxisLogarithmicAction + + def getYAxisLogarithmicAction(self): + """Returns the QAction to toggle Y axis log/linear scale. + + :rtype: PlotAction + """ + return self._yAxisLogarithmicAction + + def getGridAction(self): + """Returns the action to toggle the plot grid. + + :rtype: PlotAction + """ + return self._gridAction + + def getCurveStyleAction(self): + """Returns the QAction to change the style of all curves. + + :rtype: PlotAction + """ + return self._curveStyleAction + + +class ScatterToolBar(qt.QToolBar): + """Toolbar providing PlotAction suited when displaying scatter plot + + :param parent: See :class:`QWidget` + :param silx.gui.plot.PlotWidget plot: PlotWidget to control + :param str title: Title of the toolbar. + """ + + def __init__(self, parent=None, plot=None, title='Scatter Tools'): + super(ScatterToolBar, self).__init__(title, parent) + + assert isinstance(plot, PlotWidget) + + self._resetZoomAction = actions.control.ResetZoomAction( + parent=self, plot=plot) + self.addAction(self._resetZoomAction) + + self._xAxisLogarithmicAction = actions.control.XAxisLogarithmicAction( + parent=self, plot=plot) + self.addAction(self._xAxisLogarithmicAction) + + self._yAxisLogarithmicAction = actions.control.YAxisLogarithmicAction( + parent=self, plot=plot) + self.addAction(self._yAxisLogarithmicAction) + + self._keepDataAspectRatioButton = PlotToolButtons.AspectToolButton( + parent=self, plot=plot) + self.addWidget(self._keepDataAspectRatioButton) + + self._gridAction = actions.control.GridAction( + parent=self, plot=plot) + self.addAction(self._gridAction) + + self._colormapAction = actions.control.ColormapAction( + parent=self, plot=plot) + self.addAction(self._colormapAction) + + self._symbolToolButton = PlotToolButtons.SymbolToolButton( + parent=self, plot=plot) + self.addWidget(self._symbolToolButton) + + def getResetZoomAction(self): + """Returns the QAction to reset the zoom. + + :rtype: PlotAction + """ + return self._resetZoomAction + + def getXAxisLogarithmicAction(self): + """Returns the QAction to toggle X axis log/linear scale. + + :rtype: PlotAction + """ + return self._xAxisLogarithmicAction + + def getYAxisLogarithmicAction(self): + """Returns the QAction to toggle Y axis log/linear scale. + + :rtype: PlotAction + """ + return self._yAxisLogarithmicAction + + def getGridAction(self): + """Returns the action to toggle the plot grid. + + :rtype: PlotAction + """ + return self._gridAction + + def getColormapAction(self): + """Returns the QAction to control the colormap. + + :rtype: PlotAction + """ + return self._colormapAction + + def getSymbolToolButton(self): + """Returns the QToolButton controlling symbol size and marker. + + :rtype: SymbolToolButton + """ + return self._symbolToolButton + + def getKeepDataAspectRatioButton(self): + """Returns the QToolButton controlling data aspect ratio. + + :rtype: QToolButton + """ + return self._keepDataAspectRatioButton diff --git a/silx/gui/plot/utils/axis.py b/silx/gui/plot/utils/axis.py index 80e1dc4..fae50b4 100644 --- a/silx/gui/plot/utils/axis.py +++ b/silx/gui/plot/utils/axis.py @@ -1,7 +1,7 @@ # coding: utf-8 # /*########################################################################## # -# Copyright (c) 2017 European Synchrotron Radiation Facility +# Copyright (c) 2017-2018 European Synchrotron Radiation Facility # # Permission is hereby granted, free of charge, to any person obtaining a copy # of this software and associated documentation files (the "Software"), to deal @@ -65,18 +65,14 @@ class SyncAxes(object): """ object.__init__(self) self.__locked = False - self.__axes = [] + self.__axisRefs = [] self.__syncLimits = syncLimits self.__syncScale = syncScale self.__syncDirection = syncDirection self.__callbacks = None - qtCallback = silxWeakref.WeakMethodProxy(self.__deleteAxisQt) for axis in axes: - ref = weakref.ref(axis) - self.__axes.append(ref) - callback = functools.partial(qtCallback, ref) - axis.destroyed.connect(callback) + self.__axisRefs.append(weakref.ref(axis)) self.start() @@ -91,9 +87,13 @@ class SyncAxes(object): raise RuntimeError("Axes already synchronized") self.__callbacks = {} + axes = self.__getAxes() + if len(axes) == 0: + raise RuntimeError('No axis to synchronize') + # register callback for further sync - for refAxis in self.__axes: - axis = refAxis() + for axis in axes: + refAxis = weakref.ref(axis) callbacks = [] if self.__syncLimits: # the weakref is needed to be able ignore self references @@ -120,8 +120,8 @@ class SyncAxes(object): self.__callbacks[refAxis] = callbacks # sync the current state - refMainAxis = self.__axes[0] - mainAxis = refMainAxis() + mainAxis = axes[0] + refMainAxis = weakref.ref(mainAxis) if self.__syncLimits: self.__axisLimitsChanged(refMainAxis, *mainAxis.getLimits()) if self.__syncScale: @@ -129,23 +129,16 @@ class SyncAxes(object): if self.__syncDirection: self.__axisInvertedChanged(refMainAxis, mainAxis.isInverted()) - def __deleteAxis(self, ref): - _logger.debug("Delete axes ref %s", ref) - self.__axes.remove(ref) - del self.__callbacks[ref] - - def __deleteAxisQt(self, ref, qobject): - self.__deleteAxis(ref) - def stop(self): """Stop the synchronization of the axes""" if self.__callbacks is None: raise RuntimeError("Axes not synchronized") for ref, callbacks in self.__callbacks.items(): - axes = ref() - for sigName, callback in callbacks: - sig = getattr(axes, sigName) - sig.disconnect(callback) + axis = ref() + if axis is not None: + for sigName, callback in callbacks: + sig = getattr(axis, sigName) + sig.disconnect(callback) self.__callbacks = None def __del__(self): @@ -154,6 +147,14 @@ class SyncAxes(object): if self.__callbacks is not None: self.stop() + def __getAxes(self): + """Returns list of existing axes. + + :rtype: List[Axis] + """ + axes = [ref() for ref in self.__axisRefs] + return [axis for axis in axes if axis is not None] + @contextmanager def __inhibitSignals(self): self.__locked = True @@ -161,8 +162,7 @@ class SyncAxes(object): self.__locked = False def __otherAxes(self, changedAxis): - for axis in self.__axes: - axis = axis() + for axis in self.__getAxes(): if axis is changedAxis: continue yield axis diff --git a/silx/gui/plot3d/Plot3DWidget.py b/silx/gui/plot3d/Plot3DWidget.py index 15e2356..53ff895 100644 --- a/silx/gui/plot3d/Plot3DWidget.py +++ b/silx/gui/plot3d/Plot3DWidget.py @@ -28,15 +28,15 @@ from __future__ import absolute_import __authors__ = ["T. Vincent"] __license__ = "MIT" -__date__ = "26/01/2017" +__date__ = "24/04/2018" import logging from silx.gui import qt -from silx.gui.plot.Colors import rgba +from silx.gui.colors import rgba from . import actions -from .._utils import convertArrayToQImage +from ..utils._image import convertArrayToQImage from .. import _glutils as glu from .scene import interaction, primitives, transform diff --git a/silx/gui/plot3d/SFViewParamTree.py b/silx/gui/plot3d/SFViewParamTree.py index 314e5a1..bb81465 100644 --- a/silx/gui/plot3d/SFViewParamTree.py +++ b/silx/gui/plot3d/SFViewParamTree.py @@ -30,7 +30,7 @@ from __future__ import absolute_import __authors__ = ["D. N."] __license__ = "MIT" -__date__ = "02/10/2017" +__date__ = "24/04/2018" import logging import sys @@ -40,7 +40,7 @@ import numpy from silx.gui import qt from silx.gui.icons import getQIcon -from silx.gui.plot.Colormap import Colormap +from silx.gui.colors import Colormap from silx.gui.widgets.FloatEdit import FloatEdit from .ScalarFieldView import Isosurface @@ -1024,13 +1024,13 @@ class IsoSurfaceAddRemoveWidget(qt.QWidget): layout.setContentsMargins(0, 0, 0, 0) layout.setSpacing(0) - addBtn = qt.QToolButton() + addBtn = qt.QToolButton(self) addBtn.setText('+') addBtn.setToolButtonStyle(qt.Qt.ToolButtonTextOnly) layout.addWidget(addBtn) addBtn.clicked.connect(self.__addClicked) - removeBtn = qt.QToolButton() + removeBtn = qt.QToolButton(self) removeBtn.setText('-') removeBtn.setToolButtonStyle(qt.Qt.ToolButtonTextOnly) layout.addWidget(removeBtn) diff --git a/silx/gui/plot3d/ScalarFieldView.py b/silx/gui/plot3d/ScalarFieldView.py index a41999b..e5e680c 100644 --- a/silx/gui/plot3d/ScalarFieldView.py +++ b/silx/gui/plot3d/ScalarFieldView.py @@ -32,7 +32,7 @@ from __future__ import absolute_import __authors__ = ["T. Vincent"] __license__ = "MIT" -__date__ = "10/01/2017" +__date__ = "14/06/2018" import re import logging @@ -42,8 +42,8 @@ from collections import deque import numpy from silx.gui import qt, icons -from silx.gui.plot.Colors import rgba -from silx.gui.plot.Colormap import Colormap +from silx.gui.colors import rgba +from silx.gui.colors import Colormap from silx.math.marchingcubes import MarchingCubes from silx.math.combo import min_max @@ -643,7 +643,7 @@ class CutPlane(qt.QObject): """Returns the colormap set by :meth:`setColormap`. :return: The colormap - :rtype: ~silx.gui.plot.Colormap.Colormap + :rtype: ~silx.gui.colors.Colormap """ return self._colormap @@ -660,7 +660,7 @@ class CutPlane(qt.QObject): :param name: Name of the colormap in 'gray', 'reversed gray', 'temperature', 'red', 'green', 'blue'. Or Colormap object. - :type name: str or ~silx.gui.plot.Colormap.Colormap + :type name: str or ~silx.gui.colors.Colormap :param str norm: Colormap mapping: 'linear' or 'log'. :param float vmin: The minimum value of the range or None for autoscale :param float vmax: The maximum value of the range or None for autoscale diff --git a/silx/gui/plot3d/SceneWidget.py b/silx/gui/plot3d/SceneWidget.py index 4e75515..f005dec 100644 --- a/silx/gui/plot3d/SceneWidget.py +++ b/silx/gui/plot3d/SceneWidget.py @@ -28,14 +28,14 @@ from __future__ import absolute_import __authors__ = ["T. Vincent"] __license__ = "MIT" -__date__ = "26/10/2017" +__date__ = "24/04/2018" import numpy import weakref from silx.third_party import enum from .. import qt -from ..plot.Colors import rgba +from ..colors import rgba from .Plot3DWidget import Plot3DWidget from . import items @@ -567,7 +567,7 @@ class SceneWidget(Plot3DWidget): def clearItems(self): """Remove all item from :class:`SceneWidget`.""" - return self.getSceneGroup().clear() + return self.getSceneGroup().clearItems() # Colors diff --git a/silx/gui/plot3d/SceneWindow.py b/silx/gui/plot3d/SceneWindow.py index 5121a17..56fb21f 100644 --- a/silx/gui/plot3d/SceneWindow.py +++ b/silx/gui/plot3d/SceneWindow.py @@ -163,6 +163,13 @@ class SceneWindow(qt.QMainWindow): """ return self._sceneWidget + def getGroupResetWidget(self): + """Returns the :class:`GroupPropertiesWidget` of this window. + + :rtype: GroupPropertiesWidget + """ + return self._sceneGroupResetWidget + def getParamTreeView(self): """Returns the :class:`ParamTreeView` of this window. @@ -173,20 +180,20 @@ class SceneWindow(qt.QMainWindow): def getInteractiveModeToolBar(self): """Returns the interactive mode toolbar. - :rtype: InteractiveModeToolBar + :rtype: ~silx.gui.plot3d.tools.InteractiveModeToolBar """ return self._interactiveModeToolBar def getViewpointToolBar(self): """Returns the viewpoint toolbar. - :rtype: ViewpointToolBar + :rtype: ~silx.gui.plot3d.tools.ViewpointToolBar """ return self._viewpointToolBar def getOutputToolBar(self): """Returns the output toolbar. - :rtype: OutputToolBar + :rtype: ~silx.gui.plot3d.tools.OutputToolBar """ return self._outputToolBar diff --git a/silx/gui/plot3d/_model/items.py b/silx/gui/plot3d/_model/items.py index 7009ea1..02485fe 100644 --- a/silx/gui/plot3d/_model/items.py +++ b/silx/gui/plot3d/_model/items.py @@ -30,18 +30,19 @@ from __future__ import absolute_import, division __authors__ = ["T. Vincent"] __license__ = "MIT" -__date__ = "11/01/2018" +__date__ = "24/04/2018" import functools +import logging import weakref import numpy from silx.third_party import six -from ..._utils import convertArrayToQImage -from ...plot.Colormap import preferredColormaps +from ...utils._image import convertArrayToQImage +from ...colors import preferredColormaps from ... import qt, icons from .. import items from ..items.volume import Isosurface, CutPlane @@ -50,6 +51,9 @@ from ..items.volume import Isosurface, CutPlane from .core import AngleDegreeRow, BaseRow, ColorProxyRow, ProxyRow, StaticRow +_logger = logging.getLogger(__name__) + + class _DirectionalLightProxy(qt.QObject): """Proxy to handle directional light with angles rather than vector. """ @@ -472,7 +476,11 @@ class DataItem3DTransformRow(StaticRow): """ item = self.item() if item is not None: - item.setScale(scale.x(), scale.y(), scale.z()) + sx, sy, sz = scale.x(), scale.y(), scale.z() + if sx == 0. or sy == 0. or sz == 0.: + _logger.warning('Cannot set scale to 0: ignored') + else: + item.setScale(scale.x(), scale.y(), scale.z()) class GroupItemRow(Item3DRow): @@ -519,7 +527,7 @@ class GroupItemRow(Item3DRow): # Find item for row in self.children(): - if row.item() is item: + if isinstance(row, Item3DRow) and row.item() is item: self.removeRow(row) break # Got it else: @@ -764,7 +772,7 @@ class ColormapRow(_ColormapBaseProxyRow): def _getName(self): """Proxy for :meth:`Colormap.getName`""" - if self._colormap is not None: + if self._colormap is not None and self._colormap.getName() is not None: return self._colormap.getName().title() else: return '' diff --git a/silx/gui/plot3d/actions/io.py b/silx/gui/plot3d/actions/io.py index 5126000..f30abeb 100644 --- a/silx/gui/plot3d/actions/io.py +++ b/silx/gui/plot3d/actions/io.py @@ -39,12 +39,11 @@ import os import numpy -from silx.gui import qt -from silx.gui.plot.actions.io import PrintAction as _PrintAction +from silx.gui import qt, printer from silx.gui.icons import getQIcon from .Plot3DAction import Plot3DAction from ..utils import mng -from ..._utils import convertQImageToArray +from ...utils._image import convertQImageToArray _logger = logging.getLogger(__name__) @@ -157,11 +156,7 @@ class PrintAction(Plot3DAction): :rtype: QPrinter """ - # TODO This is a hack to sync with silx plot PrintAction - # This needs to be centralized - if _PrintAction._printer is None: - _PrintAction._printer = qt.QPrinter() - return _PrintAction._printer + return printer.getDefaultPrinter() def _triggered(self, checked=False): plot3d = self.getPlot3DWidget() diff --git a/silx/gui/plot3d/items/__init__.py b/silx/gui/plot3d/items/__init__.py index b50ea5a..b2a9dab 100644 --- a/silx/gui/plot3d/items/__init__.py +++ b/silx/gui/plot3d/items/__init__.py @@ -38,6 +38,6 @@ from .mixins import (ColormapMixIn, InterpolationMixIn, # noqa PlaneMixIn, SymbolMixIn) # noqa from .clipplane import ClipPlane # noqa from .image import ImageData, ImageRgba # noqa -from .mesh import Mesh # noqa +from .mesh import Mesh, Box, Cylinder, Hexagon # noqa from .scatter import Scatter2D, Scatter3D # noqa from .volume import ScalarField3D # noqa diff --git a/silx/gui/plot3d/items/mesh.py b/silx/gui/plot3d/items/mesh.py index 8535728..12a3941 100644 --- a/silx/gui/plot3d/items/mesh.py +++ b/silx/gui/plot3d/items/mesh.py @@ -1,7 +1,7 @@ # coding: utf-8 # /*########################################################################## # -# Copyright (c) 2017 European Synchrotron Radiation Facility +# Copyright (c) 2017-2018 European Synchrotron Radiation Facility # # Permission is hereby granted, free of charge, to any person obtaining a copy # of this software and associated documentation files (the "Software"), to deal @@ -35,6 +35,7 @@ import numpy from ..scene import primitives from .core import DataItem3D, ItemChangedType +from ..scene.transform import Rotate class Mesh(DataItem3D): @@ -143,3 +144,396 @@ class Mesh(DataItem3D): :rtype: str """ return self._mesh.drawMode + + +class _CylindricalVolume(DataItem3D): + """Class that represents a volume with a rotational symmetry along z + + :param parent: The View widget this item belongs to. + """ + + def __init__(self, parent=None): + DataItem3D.__init__(self, parent=parent) + self._mesh = None + + def _setData(self, position, radius, height, angles, color, flatFaces, + rotation): + """Set volume geometry data. + + :param numpy.ndarray position: + Center position (x, y, z) of each volume as (N, 3) array. + :param float radius: External radius ot the volume. + :param float height: Height of the volume(s). + :param numpy.ndarray angles: Angles of the edges. + :param numpy.array color: RGB color of the volume(s). + :param bool flatFaces: + If the volume as flat faces or not. Used for normals calculation. + """ + + self._getScenePrimitive().children = [] # Remove any previous mesh + + if position is None or len(position) == 0: + self._mesh = 0 + else: + volume = numpy.empty(shape=(len(angles) - 1, 12, 3), + dtype=numpy.float32) + normal = numpy.empty(shape=(len(angles) - 1, 12, 3), + dtype=numpy.float32) + + for i in range(0, len(angles) - 1): + """ + c6 + /\ + / \ + / \ + c4|------|c5 + | \ | + | \ | + | \ | + | \ | + c2|------|c3 + \ / + \ / + \/ + c1 + """ + c1 = numpy.array([0, 0, -height/2]) + c1 = rotation.transformPoint(c1) + c2 = numpy.array([radius * numpy.cos(angles[i]), + radius * numpy.sin(angles[i]), + -height/2]) + c2 = rotation.transformPoint(c2) + c3 = numpy.array([radius * numpy.cos(angles[i+1]), + radius * numpy.sin(angles[i+1]), + -height/2]) + c3 = rotation.transformPoint(c3) + c4 = numpy.array([radius * numpy.cos(angles[i]), + radius * numpy.sin(angles[i]), + height/2]) + c4 = rotation.transformPoint(c4) + c5 = numpy.array([radius * numpy.cos(angles[i+1]), + radius * numpy.sin(angles[i+1]), + height/2]) + c5 = rotation.transformPoint(c5) + c6 = numpy.array([0, 0, height/2]) + c6 = rotation.transformPoint(c6) + + volume[i] = numpy.array([c1, c3, c2, + c2, c3, c4, + c3, c5, c4, + c4, c5, c6]) + if flatFaces: + normal[i] = numpy.array([numpy.cross(c3-c1, c2-c1), # c1 + numpy.cross(c2-c3, c1-c3), # c3 + numpy.cross(c1-c2, c3-c2), # c2 + numpy.cross(c3-c2, c4-c2), # c2 + numpy.cross(c4-c3, c2-c3), # c3 + numpy.cross(c2-c4, c3-c4), # c4 + numpy.cross(c5-c3, c4-c3), # c3 + numpy.cross(c4-c5, c3-c5), # c5 + numpy.cross(c3-c4, c5-c4), # c4 + numpy.cross(c5-c4, c6-c4), # c4 + numpy.cross(c6-c5, c5-c5), # c5 + numpy.cross(c4-c6, c5-c6)]) # c6 + else: + normal[i] = numpy.array([numpy.cross(c3-c1, c2-c1), + numpy.cross(c2-c3, c1-c3), + numpy.cross(c1-c2, c3-c2), + c2-c1, c3-c1, c4-c6, # c2 c2 c4 + c3-c1, c5-c6, c4-c6, # c3 c5 c4 + numpy.cross(c5-c4, c6-c4), + numpy.cross(c6-c5, c5-c5), + numpy.cross(c4-c6, c5-c6)]) + + # Multiplication according to the number of positions + vertices = numpy.tile(volume.reshape(-1, 3), (len(position), 1))\ + .reshape((-1, 3)) + normals = numpy.tile(normal.reshape(-1, 3), (len(position), 1))\ + .reshape((-1, 3)) + + # Translations + numpy.add(vertices, numpy.tile(position, (1, (len(angles)-1) * 12)) + .reshape((-1, 3)), out=vertices) + + # Colors + if numpy.ndim(color) == 2: + color = numpy.tile(color, (1, 12 * (len(angles) - 1)))\ + .reshape(-1, 3) + + self._mesh = primitives.Mesh3D( + vertices, color, normals, mode='triangles', copy=False) + self._getScenePrimitive().children.append(self._mesh) + + self.sigItemChanged.emit(ItemChangedType.DATA) + + +class Box(_CylindricalVolume): + """Description of a box. + + Can be used to draw one box or many similar boxes. + + :param parent: The View widget this item belongs to. + """ + + def __init__(self, parent=None): + super(Box, self).__init__(parent) + self.position = None + self.size = None + self.color = None + self.rotation = None + self.setData() + + def setData(self, size=(1, 1, 1), color=(1, 1, 1), + position=(0, 0, 0), rotation=(0, (0, 0, 0))): + """ + Set Box geometry data. + + :param numpy.array size: Size (dx, dy, dz) of the box(es). + :param numpy.array color: RGB color of the box(es). + :param numpy.ndarray position: + Center position (x, y, z) of each box as a (N, 3) array. + :param tuple(float, array) rotation: + Angle (in degrees) and axis of rotation. + If (0, (0, 0, 0)) (default), the hexagonal faces are on + xy plane and a side face is aligned with x axis. + """ + self.position = numpy.atleast_2d(numpy.array(position, copy=True)) + self.size = numpy.array(size, copy=True) + self.color = numpy.array(color, copy=True) + self.rotation = Rotate(rotation[0], + rotation[1][0], rotation[1][1], rotation[1][2]) + + assert (numpy.ndim(self.color) == 1 or + len(self.color) == len(self.position)) + + diagonal = numpy.sqrt(self.size[0]**2 + self.size[1]**2) + alpha = 2 * numpy.arcsin(self.size[1] / diagonal) + beta = 2 * numpy.arcsin(self.size[0] / diagonal) + angles = numpy.array([0, + alpha, + alpha + beta, + alpha + beta + alpha, + 2 * numpy.pi]) + numpy.subtract(angles, 0.5 * alpha, out=angles) + self._setData(self.position, + numpy.sqrt(self.size[0]**2 + self.size[1]**2)/2, + self.size[2], + angles, + self.color, + True, + self.rotation) + + def getPosition(self, copy=True): + """Get box(es) position(s). + + :param bool copy: + True (default) to get a copy, + False to get internal representation (do not modify!). + :return: Position of the box(es) as a (N, 3) array. + :rtype: numpy.ndarray + """ + return numpy.array(self.position, copy=copy) + + def getSize(self): + """Get box(es) size. + + :return: Size (dx, dy, dz) of the box(es). + :rtype: numpy.ndarray + """ + return numpy.array(self.size, copy=True) + + def getColor(self, copy=True): + """Get box(es) color. + + :param bool copy: + True (default) to get a copy, + False to get internal representation (do not modify!). + :return: RGB color of the box(es). + :rtype: numpy.ndarray + """ + return numpy.array(self.color, copy=copy) + + +class Cylinder(_CylindricalVolume): + """Description of a cylinder. + + Can be used to draw one cylinder or many similar cylinders. + + :param parent: The View widget this item belongs to. + """ + + def __init__(self, parent=None): + super(Cylinder, self).__init__(parent) + self.position = None + self.radius = None + self.height = None + self.color = None + self.nbFaces = 0 + self.rotation = None + self.setData() + + def setData(self, radius=1, height=1, color=(1, 1, 1), nbFaces=20, + position=(0, 0, 0), rotation=(0, (0, 0, 0))): + """ + Set the cylinder geometry data + + :param float radius: Radius of the cylinder(s). + :param float height: Height of the cylinder(s). + :param numpy.array color: RGB color of the cylinder(s). + :param int nbFaces: + Number of faces for cylinder approximation (default 20). + :param numpy.ndarray position: + Center position (x, y, z) of each cylinder as a (N, 3) array. + :param tuple(float, array) rotation: + Angle (in degrees) and axis of rotation. + If (0, (0, 0, 0)) (default), the hexagonal faces are on + xy plane and a side face is aligned with x axis. + """ + self.position = numpy.atleast_2d(numpy.array(position, copy=True)) + self.radius = float(radius) + self.height = float(height) + self.color = numpy.array(color, copy=True) + self.nbFaces = int(nbFaces) + self.rotation = Rotate(rotation[0], + rotation[1][0], rotation[1][1], rotation[1][2]) + + assert (numpy.ndim(self.color) == 1 or + len(self.color) == len(self.position)) + + angles = numpy.linspace(0, 2*numpy.pi, self.nbFaces + 1) + self._setData(self.position, + self.radius, + self.height, + angles, + self.color, + False, + self.rotation) + + def getPosition(self, copy=True): + """Get cylinder(s) position(s). + + :param bool copy: + True (default) to get a copy, + False to get internal representation (do not modify!). + :return: Position(s) of the cylinder(s) as a (N, 3) array. + :rtype: numpy.ndarray + """ + return numpy.array(self.position, copy=copy) + + def getRadius(self): + """Get cylinder(s) radius. + + :return: Radius of the cylinder(s). + :rtype: float + """ + return self.radius + + def getHeight(self): + """Get cylinder(s) height. + + :return: Height of the cylinder(s). + :rtype: float + """ + return self.height + + def getColor(self, copy=True): + """Get cylinder(s) color. + + :param bool copy: + True (default) to get a copy, + False to get internal representation (do not modify!). + :return: RGB color of the cylinder(s). + :rtype: numpy.ndarray + """ + return numpy.array(self.color, copy=copy) + + +class Hexagon(_CylindricalVolume): + """Description of a uniform hexagonal prism. + + Can be used to draw one hexagonal prim or many similar hexagonal + prisms. + + :param parent: The View widget this item belongs to. + """ + + def __init__(self, parent=None): + super(Hexagon, self).__init__(parent) + self.position = None + self.radius = 0 + self.height = 0 + self.color = None + self.rotation = None + self.setData() + + def setData(self, radius=1, height=1, color=(1, 1, 1), + position=(0, 0, 0), rotation=(0, (0, 0, 0))): + """ + Set the uniform hexagonal prism geometry data + + :param float radius: External radius of the hexagonal prism + :param float height: Height of the hexagonal prism + :param numpy.array color: RGB color of the prism(s) + :param numpy.ndarray position: + Center position (x, y, z) of each prism as a (N, 3) array + :param tuple(float, array) rotation: + Angle (in degrees) and axis of rotation. + If (0, (0, 0, 0)) (default), the hexagonal faces are on + xy plane and a side face is aligned with x axis. + """ + self.position = numpy.atleast_2d(numpy.array(position, copy=True)) + self.radius = float(radius) + self.height = float(height) + self.color = numpy.array(color, copy=True) + self.rotation = Rotate(rotation[0], rotation[1][0], rotation[1][1], + rotation[1][2]) + + assert (numpy.ndim(self.color) == 1 or + len(self.color) == len(self.position)) + + angles = numpy.linspace(0, 2*numpy.pi, 7) + self._setData(self.position, + self.radius, + self.height, + angles, + self.color, + True, + self.rotation) + + def getPosition(self, copy=True): + """Get hexagonal prim(s) position(s). + + :param bool copy: + True (default) to get a copy, + False to get internal representation (do not modify!). + :return: Position(s) of hexagonal prism(s) as a (N, 3) array. + :rtype: numpy.ndarray + """ + return numpy.array(self.position, copy=copy) + + def getRadius(self): + """Get hexagonal prism(s) radius. + + :return: Radius of hexagon(s). + :rtype: float + """ + return self.radius + + def getHeight(self): + """Get hexagonal prism(s) height. + + :return: Height of hexagonal prism(s). + :rtype: float + """ + return self.height + + def getColor(self, copy=True): + """Get hexagonal prism(s) color. + + :param bool copy: + True (default) to get a copy, + False to get internal representation (do not modify!). + :return: RGB color of the hexagonal prism(s). + :rtype: numpy.ndarray + """ + return numpy.array(self.color, copy=copy) diff --git a/silx/gui/plot3d/items/mixins.py b/silx/gui/plot3d/items/mixins.py index 41ad3c3..8e96441 100644 --- a/silx/gui/plot3d/items/mixins.py +++ b/silx/gui/plot3d/items/mixins.py @@ -27,7 +27,7 @@ __authors__ = ["T. Vincent"] __license__ = "MIT" -__date__ = "15/11/2017" +__date__ = "24/04/2018" import collections @@ -38,7 +38,7 @@ from silx.math.combo import min_max from ...plot.items.core import ItemMixInBase from ...plot.items.core import ColormapMixIn as _ColormapMixIn from ...plot.items.core import SymbolMixIn as _SymbolMixIn -from ...plot.Colors import rgba +from ...colors import rgba from ..scene import primitives from .core import Item3DChangedType, ItemChangedType diff --git a/silx/gui/plot3d/items/volume.py b/silx/gui/plot3d/items/volume.py index a1f40f7..a7b5923 100644 --- a/silx/gui/plot3d/items/volume.py +++ b/silx/gui/plot3d/items/volume.py @@ -29,7 +29,7 @@ from __future__ import absolute_import __authors__ = ["T. Vincent"] __license__ = "MIT" -__date__ = "15/11/2017" +__date__ = "24/04/2018" import logging import time @@ -39,7 +39,7 @@ from silx.math.combo import min_max from silx.math.marchingcubes import MarchingCubes from ... import qt -from ...plot.Colors import rgba +from ...colors import rgba from ..scene import cutplane, primitives, transform diff --git a/silx/gui/plot3d/scene/primitives.py b/silx/gui/plot3d/scene/primitives.py index abf7dd4..af00b6d 100644 --- a/silx/gui/plot3d/scene/primitives.py +++ b/silx/gui/plot3d/scene/primitives.py @@ -27,7 +27,7 @@ from __future__ import absolute_import, division, unicode_literals __authors__ = ["T. Vincent"] __license__ = "MIT" -__date__ = "25/07/2016" +__date__ = "24/04/2018" import collections @@ -38,7 +38,7 @@ import string import numpy -from silx.gui.plot.Colors import rgba +from silx.gui.colors import rgba from ... import _glutils from ..._glutils import gl @@ -1246,12 +1246,12 @@ class _Points(Geometry): $clippingCall(vCameraPosition); float alpha = alphaSymbol(gl_PointCoord, vSize); - if (alpha == 0.0) { - discard; - } gl_FragColor = $valueToColorCall(vValue); gl_FragColor.a *= alpha; + if (gl_FragColor.a == 0.0) { + discard; + } } """)) diff --git a/silx/gui/plot3d/scene/text.py b/silx/gui/plot3d/scene/text.py index 903fc21..c2983d5 100644 --- a/silx/gui/plot3d/scene/text.py +++ b/silx/gui/plot3d/scene/text.py @@ -28,13 +28,13 @@ from __future__ import absolute_import, division, unicode_literals __authors__ = ["T. Vincent"] __license__ = "MIT" -__date__ = "17/10/2016" +__date__ = "24/04/2018" import logging import numpy -from silx.gui.plot.Colors import rgba +from silx.gui.colors import rgba from ... import _glutils from ..._glutils import gl diff --git a/silx/gui/plot3d/scene/utils.py b/silx/gui/plot3d/scene/utils.py index 04abd04..3752289 100644 --- a/silx/gui/plot3d/scene/utils.py +++ b/silx/gui/plot3d/scene/utils.py @@ -1,7 +1,7 @@ # coding: utf-8 # /*########################################################################## # -# Copyright (c) 2015-2017 European Synchrotron Radiation Facility +# Copyright (c) 2015-2018 European Synchrotron Radiation Facility # # Permission is hereby granted, free of charge, to any person obtaining a copy # of this software and associated documentation files (the "Software"), to deal @@ -227,12 +227,7 @@ def trianglesNormal(positions): positions[:, 2] - positions[:, 0]) # Normalize normals - if numpy.version.version < '1.8.0': - # Debian 7 support: numpy.linalg.norm has no axis argument - norms = numpy.array(tuple(numpy.linalg.norm(vec) for vec in normals), - dtype=normals.dtype) - else: - norms = numpy.linalg.norm(normals, axis=1) + norms = numpy.linalg.norm(normals, axis=1) norms[norms == 0] = 1 return normals / norms.reshape(-1, 1) diff --git a/silx/gui/plot3d/scene/viewport.py b/silx/gui/plot3d/scene/viewport.py index 0cacbf0..41aa999 100644 --- a/silx/gui/plot3d/scene/viewport.py +++ b/silx/gui/plot3d/scene/viewport.py @@ -33,12 +33,12 @@ from __future__ import absolute_import, division, unicode_literals __authors__ = ["T. Vincent"] __license__ = "MIT" -__date__ = "25/07/2016" +__date__ = "24/04/2018" import numpy -from silx.gui.plot.Colors import rgba +from silx.gui.colors import rgba from ..._glutils import gl diff --git a/silx/gui/plot3d/tools/GroupPropertiesWidget.py b/silx/gui/plot3d/tools/GroupPropertiesWidget.py index 30e11de..5b0bcdb 100644 --- a/silx/gui/plot3d/tools/GroupPropertiesWidget.py +++ b/silx/gui/plot3d/tools/GroupPropertiesWidget.py @@ -28,11 +28,11 @@ from __future__ import absolute_import __authors__ = ["T. Vincent"] __license__ = "MIT" -__date__ = "11/01/2018" +__date__ = "24/04/2018" from ....gui import qt -from ....gui.plot.Colormap import Colormap -from ....gui.plot.ColormapDialog import ColormapDialog +from ....gui.colors import Colormap +from ....gui.dialog.ColormapDialog import ColormapDialog from ..items import SymbolMixIn, ColormapMixIn diff --git a/silx/gui/printer.py b/silx/gui/printer.py new file mode 100644 index 0000000..761fa0f --- /dev/null +++ b/silx/gui/printer.py @@ -0,0 +1,62 @@ +# coding: utf-8 +# /*########################################################################## +# +# Copyright (c) 2018 European Synchrotron Radiation Facility +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. +# +# ###########################################################################*/ +"""This module provides a singleton QPrinter used by default by silx widgets. +""" + +from __future__ import absolute_import + +__authors__ = ["T. Vincent"] +__license__ = "MIT" +__date__ = "01/03/2018" + + +from . import qt + + +_printer = None +"""Shared QPrinter instance""" + + +def getDefaultPrinter(): + """Returns the default printer. + + This allows reusing the same QPrinter across widgets. + + :return: QPrinter + """ + global _printer + if _printer is None: + _printer = qt.QPrinter() + return _printer + + +def setDefaultPrinter(printer): + """Set the printer used by default by silx widgets. + + :param QPrinter printer: + """ + assert isinstance(printer, qt.QPrinter) + global _printer + _printer = printer diff --git a/silx/gui/qt/__init__.py b/silx/gui/qt/__init__.py index 44daa94..f7bc916 100644 --- a/silx/gui/qt/__init__.py +++ b/silx/gui/qt/__init__.py @@ -1,7 +1,7 @@ # coding: utf-8 # /*########################################################################## # -# Copyright (c) 2004-2017 European Synchrotron Radiation Facility +# Copyright (c) 2004-2018 European Synchrotron Radiation Facility # # Permission is hereby granted, free of charge, to any person obtaining a copy # of this software and associated documentation files (the "Software"), to deal @@ -24,12 +24,13 @@ # ###########################################################################*/ """Common wrapper over Python Qt bindings: -- `PyQt5 <http://pyqt.sourceforge.net/Docs/PyQt5/>`_, -- `PyQt4 <http://pyqt.sourceforge.net/Docs/PyQt4/>`_ or -- `PySide <http://www.pyside.org>`_. +- `PyQt5 <http://pyqt.sourceforge.net/Docs/PyQt5/>`_ +- `PyQt4 <http://pyqt.sourceforge.net/Docs/PyQt4/>`_ +- `PySide <http://www.pyside.org>`_ +- `PySide2 <https://wiki.qt.io/PySide2>`_ If a Qt binding is already loaded, it will use it, otherwise the different -Qt bindings are tried in this order: PyQt4, PySide, PyQt5. +Qt bindings are tried in this order: PyQt5, PyQt4, PySide, PySide2. The name of the loaded Qt binding is stored in the BINDING variable. @@ -50,7 +51,6 @@ see `qtpy <https://pypi.python.org/pypi/QtPy/>`_ which provides the namespace of PyQt5 over PyQt4 and PySide. """ -import sys from ._qt import * # noqa from ._utils import * # noqa diff --git a/silx/gui/qt/_qt.py b/silx/gui/qt/_qt.py index a54ea67..6bf7d93 100644 --- a/silx/gui/qt/_qt.py +++ b/silx/gui/qt/_qt.py @@ -1,7 +1,7 @@ # coding: utf-8 # /*########################################################################## # -# Copyright (c) 2004-2017 European Synchrotron Radiation Facility +# Copyright (c) 2004-2018 European Synchrotron Radiation Facility # # Permission is hereby granted, free of charge, to any person obtaining a copy # of this software and associated documentation files (the "Software"), to deal @@ -40,7 +40,7 @@ provides the namespace of PyQt5 over PyQt4, PySide and PySide2. __authors__ = ["V.A. Sole"] __license__ = "MIT" -__date__ = "11/10/2017" +__date__ = "23/05/2018" import logging @@ -78,27 +78,27 @@ elif 'PyQt4.QtCore' in sys.modules: else: # Then try Qt bindings try: - import PyQt4 # noqa + import PyQt5 # noqa except ImportError: try: - import PySide # noqa + import PyQt4 # noqa except ImportError: try: - import PyQt5 # noqa + import PySide # noqa except ImportError: try: import PySide2 # noqa except ImportError: raise ImportError( - 'No Qt wrapper found. Install PyQt4, PyQt5 or PySide2.') + 'No Qt wrapper found. Install PyQt5, PyQt4 or PySide2.') else: BINDING = 'PySide2' else: - BINDING = 'PyQt5' + BINDING = 'PySide' else: - BINDING = 'PySide' + BINDING = 'PyQt4' else: - BINDING = 'PyQt4' + BINDING = 'PyQt5' if BINDING == 'PyQt4': @@ -255,10 +255,10 @@ def exceptionHandler(type_, value, trace): The script/application willing to use it should implement code similar to: .. code-block:: python - + if __name__ == "__main__": sys.excepthook = qt.exceptionHandler - + """ _logger.error("%s %s %s", type_, value, ''.join(traceback.format_tb(trace))) msg = QMessageBox() @@ -268,4 +268,3 @@ def exceptionHandler(type_, value, trace): msg.setDetailedText(("%s " % value) + ''.join(traceback.format_tb(trace))) msg.raise_() msg.exec_() - diff --git a/silx/gui/setup.py b/silx/gui/setup.py index 8e8c796..6eb87ae 100644 --- a/silx/gui/setup.py +++ b/silx/gui/setup.py @@ -1,7 +1,7 @@ # coding: utf-8 # /*########################################################################## # -# Copyright (c) 2016-2017 European Synchrotron Radiation Facility +# Copyright (c) 2016-2018 European Synchrotron Radiation Facility # # Permission is hereby granted, free of charge, to any person obtaining a copy # of this software and associated documentation files (the "Software"), to deal @@ -42,6 +42,8 @@ def configuration(parent_package='', top_path=None): config.add_subpackage('plot3d') config.add_subpackage('data') config.add_subpackage('dialog') + config.add_subpackage('utils') + config.add_subpackage('utils.test') return config diff --git a/silx/gui/test/__init__.py b/silx/gui/test/__init__.py index 0d0805f..8a9a949 100644 --- a/silx/gui/test/__init__.py +++ b/silx/gui/test/__init__.py @@ -1,7 +1,7 @@ # coding: utf-8 # /*########################################################################## # -# Copyright (c) 2016-2017 European Synchrotron Radiation Facility +# Copyright (c) 2016-2018 European Synchrotron Radiation Facility # # Permission is hereby granted, free of charge, to any person obtaining a copy # of this software and associated documentation files (the "Software"), to deal @@ -24,15 +24,15 @@ # ###########################################################################*/ __authors__ = ["T. Vincent", "P. Knobel"] __license__ = "MIT" -__date__ = "28/11/2017" +__date__ = "24/04/2018" import logging import os import sys import unittest -from silx.test.utils import test_options +from silx.test.utils import test_options _logger = logging.getLogger(__name__) @@ -73,12 +73,14 @@ def suite(): from ..widgets import test as test_widgets from ..data import test as test_data from ..dialog import test as test_dialog + from ..utils import test as test_utils + from . import test_qt # Console tests disabled due to corruption of python environment # (see issue #538 on github) # from . import test_console from . import test_icons - from . import test_utils + from . import test_colors try: from ..plot3d.test import suite as test_plot3d_suite @@ -102,6 +104,7 @@ def suite(): test_suite.addTest(test_widgets.suite()) # test_suite.addTest(test_console.suite()) # see issue #538 on github test_suite.addTest(test_icons.suite()) + test_suite.addTest(test_colors.suite()) test_suite.addTest(test_data.suite()) test_suite.addTest(test_utils.suite()) test_suite.addTest(test_plot3d_suite()) diff --git a/silx/gui/plot/test/testColormap.py b/silx/gui/test/test_colors.py index 4888a7c..d7c205e 100644 --- a/silx/gui/plot/test/testColormap.py +++ b/silx/gui/test/test_colors.py @@ -1,7 +1,7 @@ # coding: utf-8 # /*########################################################################## # -# Copyright (c) 2015-2017 European Synchrotron Radiation Facility +# Copyright (c) 2015-2018 European Synchrotron Radiation Facility # # Permission is hereby granted, free of charge, to any person obtaining a copy # of this software and associated documentation files (the "Software"), to deal @@ -29,16 +29,62 @@ from __future__ import absolute_import __authors__ = ["H.Payno"] __license__ = "MIT" -__date__ = "17/01/2018" +__date__ = "24/04/2018" import unittest import numpy from silx.utils.testutils import ParametricTestCase -from silx.gui.plot.Colormap import Colormap -from silx.gui.plot.Colormap import preferredColormaps, setPreferredColormaps +from silx.gui import colors +from silx.gui.colors import Colormap +from silx.gui.colors import preferredColormaps, setPreferredColormaps from silx.utils.exceptions import NotEditableError +class TestRGBA(ParametricTestCase): + """Basic tests of rgba function""" + + def testRGBA(self): + """"Test rgba function with accepted values""" + tests = { # name: (colors, expected values) + 'blue': ('blue', (0., 0., 1., 1.)), + '#010203': ('#010203', (1. / 255., 2. / 255., 3. / 255., 1.)), + '#01020304': ('#01020304', (1. / 255., 2. / 255., 3. / 255., 4. / 255.)), + '3 x uint8': (numpy.array((1, 255, 0), dtype=numpy.uint8), + (1 / 255., 1., 0., 1.)), + '4 x uint8': (numpy.array((1, 255, 0, 1), dtype=numpy.uint8), + (1 / 255., 1., 0., 1 / 255.)), + '3 x float overflow': ((3., 0.5, 1.), (1., 0.5, 1., 1.)), + } + + for name, test in tests.items(): + color, expected = test + with self.subTest(msg=name): + result = colors.rgba(color) + self.assertEqual(result, expected) + + +class TestApplyColormapToData(ParametricTestCase): + """Tests of applyColormapToData function""" + + def testApplyColormapToData(self): + """Simple test of applyColormapToData function""" + colormap = Colormap(name='gray', normalization='linear', + vmin=0, vmax=255) + + size = 10 + expected = numpy.empty((size, 4), dtype='uint8') + expected[:, 0] = numpy.arange(size, dtype='uint8') + expected[:, 1] = expected[:, 0] + expected[:, 2] = expected[:, 0] + expected[:, 3] = 255 + + for dtype in ('uint8', 'int32', 'float32', 'float64'): + with self.subTest(dtype=dtype): + array = numpy.arange(size, dtype=dtype) + result = colormap.applyToData(data=array) + self.assertTrue(numpy.all(numpy.equal(result, expected))) + + class TestDictAPI(unittest.TestCase): """Make sure the old dictionary API is working """ @@ -280,7 +326,7 @@ class TestObjectAPI(ParametricTestCase): """Test getNColors method""" # specific LUT colormap = Colormap(name=None, - colors=((0, 0, 0), (1, 1, 1)), + colors=((0., 0., 0.), (1., 1., 1.)), vmin=1000, vmax=2000) colors = colormap.getNColors() @@ -289,7 +335,7 @@ class TestObjectAPI(ParametricTestCase): ((0, 0, 0, 255), (255, 255, 255, 255))))) def testEditableMode(self): - """Make sure the colormap will raise NotEditableError when try to + """Make sure the colormap will raise NotEditableError when try to change a colormap not editable""" colormap = Colormap() colormap.setEditable(False) @@ -342,10 +388,12 @@ class TestPreferredColormaps(unittest.TestCase): def suite(): test_suite = unittest.TestSuite() - for ui in (TestDictAPI, TestObjectAPI, TestPreferredColormaps): - test_suite.addTest( - unittest.defaultTestLoader.loadTestsFromTestCase(ui)) - + loadTests = unittest.defaultTestLoader.loadTestsFromTestCase + test_suite.addTest(loadTests(TestApplyColormapToData)) + test_suite.addTest(loadTests(TestRGBA)) + test_suite.addTest(loadTests(TestDictAPI)) + test_suite.addTest(loadTests(TestObjectAPI)) + test_suite.addTest(loadTests(TestPreferredColormaps)) return test_suite diff --git a/silx/gui/utils/__init__.py b/silx/gui/utils/__init__.py new file mode 100644 index 0000000..51c4fac --- /dev/null +++ b/silx/gui/utils/__init__.py @@ -0,0 +1,29 @@ +# coding: utf-8 +# /*########################################################################## +# +# Copyright (c) 2018 European Synchrotron Radiation Facility +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. +# +# ###########################################################################*/ +"""Miscellaneous helpers for Qt""" + +__authors__ = ["T. Vincent"] +__license__ = "MIT" +__date__ = "09/03/2018" diff --git a/silx/gui/_utils.py b/silx/gui/utils/_image.py index d91a572..260aac3 100644 --- a/silx/gui/_utils.py +++ b/silx/gui/utils/_image.py @@ -27,7 +27,6 @@ It provides: - conversion between numpy and QImage: :func:`convertArrayToQImage`, :func:`convertQImageToArray` -- Execution of function in Qt main thread: :func:`submitToQtMainThread` """ from __future__ import division @@ -41,9 +40,7 @@ __date__ = "16/01/2017" import sys import numpy -from silx.third_party.concurrent_futures import Future - -from . import qt +from .. import qt def convertArrayToQImage(image): @@ -105,71 +102,3 @@ def convertQImageToArray(image): array = array.reshape(image.height(), -1)[:, :image.width() * 3] array.shape = image.height(), image.width(), 3 return array - - -class _QtExecutor(qt.QObject): - """Executor of tasks in Qt main thread""" - - __sigSubmit = qt.Signal(Future, object, tuple, dict) - """Signal used to run tasks.""" - - def __init__(self): - super(_QtExecutor, self).__init__(parent=None) - - # Makes sure the executor lives in the main thread - app = qt.QApplication.instance() - assert app is not None - mainThread = app.thread() - if self.thread() != mainThread: - self.moveToThread(mainThread) - - self.__sigSubmit.connect(self.__run) - - def submit(self, fn, *args, **kwargs): - """Submit fn(*args, **kwargs) to Qt main thread - - :param callable fn: Function to call in main thread - :return: Future object to retrieve result - :rtype: concurrent.future.Future - """ - future = Future() - self.__sigSubmit.emit(future, fn, args, kwargs) - return future - - def __run(self, future, fn, args, kwargs): - """Run task in Qt main thread - - :param concurrent.future.Future future: - :param callable fn: Function to run - :param tuple args: Arguments - :param dict kwargs: Keyword arguments - """ - if not future.set_running_or_notify_cancel(): - return - - try: - result = fn(*args, **kwargs) - except BaseException as e: - future.set_exception(e) - else: - future.set_result(result) - - -_executor = None -"""QObject running the tasks in main thread""" - - -def submitToQtMainThread(fn, *args, **kwargs): - """Run fn(*args, **kwargs) in Qt's main thread. - - If not called from the main thread, this is run asynchronously. - - :param callable fn: Function to call in main thread. - :return: A future object to retrieve the result - :rtype: concurrent.future.Future - """ - global _executor - if _executor is None: # Lazy-loading - _executor = _QtExecutor() - - return _executor.submit(fn, *args, **kwargs) diff --git a/silx/gui/utils/concurrent.py b/silx/gui/utils/concurrent.py new file mode 100644 index 0000000..48fff91 --- /dev/null +++ b/silx/gui/utils/concurrent.py @@ -0,0 +1,103 @@ +# coding: utf-8 +# /*########################################################################## +# +# Copyright (c) 2018 European Synchrotron Radiation Facility +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. +# +# ###########################################################################*/ +"""This module allows to run a function in Qt main thread from another thread +""" + +__authors__ = ["T. Vincent"] +__license__ = "MIT" +__date__ = "09/03/2018" + + +from silx.third_party.concurrent_futures import Future + +from .. import qt + + +class _QtExecutor(qt.QObject): + """Executor of tasks in Qt main thread""" + + __sigSubmit = qt.Signal(Future, object, tuple, dict) + """Signal used to run tasks.""" + + def __init__(self): + super(_QtExecutor, self).__init__(parent=None) + + # Makes sure the executor lives in the main thread + app = qt.QApplication.instance() + assert app is not None + mainThread = app.thread() + if self.thread() != mainThread: + self.moveToThread(mainThread) + + self.__sigSubmit.connect(self.__run) + + def submit(self, fn, *args, **kwargs): + """Submit fn(*args, **kwargs) to Qt main thread + + :param callable fn: Function to call in main thread + :return: Future object to retrieve result + :rtype: concurrent.future.Future + """ + future = Future() + self.__sigSubmit.emit(future, fn, args, kwargs) + return future + + def __run(self, future, fn, args, kwargs): + """Run task in Qt main thread + + :param concurrent.future.Future future: + :param callable fn: Function to run + :param tuple args: Arguments + :param dict kwargs: Keyword arguments + """ + if not future.set_running_or_notify_cancel(): + return + + try: + result = fn(*args, **kwargs) + except BaseException as e: + future.set_exception(e) + else: + future.set_result(result) + + +_executor = None +"""QObject running the tasks in main thread""" + + +def submitToQtMainThread(fn, *args, **kwargs): + """Run fn(args, kwargs) in Qt's main thread. + + If not called from the main thread, this is run asynchronously. + + :param callable fn: Function to call in main thread. + :return: A future object to retrieve the result + :rtype: concurrent.future.Future + """ + global _executor + if _executor is None: # Lazy-loading + _executor = _QtExecutor() + + return _executor.submit(fn, *args, **kwargs) diff --git a/silx/gui/utils/test/__init__.py b/silx/gui/utils/test/__init__.py new file mode 100644 index 0000000..9e50170 --- /dev/null +++ b/silx/gui/utils/test/__init__.py @@ -0,0 +1,48 @@ +# coding: utf-8 +# /*########################################################################## +# +# Copyright (c) 2018 European Synchrotron Radiation Facility +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. +# +# ###########################################################################*/ +"""silx.gui.utils tests""" + + +__authors__ = ["T. Vincent"] +__license__ = "MIT" +__date__ = "24/04/2018" + + +import unittest + +from . import test_async +from . import test_image + + +def suite(): + """Test suite for module silx.image.test""" + test_suite = unittest.TestSuite() + test_suite.addTest(test_async.suite()) + test_suite.addTest(test_image.suite()) + return test_suite + + +if __name__ == '__main__': + unittest.main(defaultTest='suite') diff --git a/silx/gui/test/test_utils.py b/silx/gui/utils/test/test_async.py index b1cdf0f..fd32a3f 100644 --- a/silx/gui/test/test_utils.py +++ b/silx/gui/utils/test/test_async.py @@ -1,7 +1,7 @@ # coding: utf-8 # /*########################################################################## # -# Copyright (c) 2017-2018 European Synchrotron Radiation Facility +# Copyright (c) 2018 European Synchrotron Radiation Facility # # Permission is hereby granted, free of charge, to any person obtaining a copy # of this software and associated documentation files (the "Software"), to deal @@ -22,50 +22,22 @@ # THE SOFTWARE. # # ###########################################################################*/ -"""Test of utils module.""" +"""Test of async module.""" __authors__ = ["T. Vincent"] __license__ = "MIT" -__date__ = "16/01/2017" +__date__ = "09/03/2018" import threading import unittest -import numpy from silx.third_party.concurrent_futures import wait from silx.gui import qt from silx.gui.test.utils import TestCaseQt -from silx.gui import _utils - - -class TestQImageConversion(TestCaseQt): - """Tests conversion of QImage to/from numpy array.""" - - def testConvertArrayToQImage(self): - """Test conversion of numpy array to QImage""" - image = numpy.ones((3, 3, 3), dtype=numpy.uint8) - qimage = _utils.convertArrayToQImage(image) - - self.assertEqual(qimage.height(), image.shape[0]) - self.assertEqual(qimage.width(), image.shape[1]) - self.assertEqual(qimage.format(), qt.QImage.Format_RGB888) - - color = qt.QColor(1, 1, 1).rgb() - self.assertEqual(qimage.pixel(1, 1), color) - - def testConvertQImageToArray(self): - """Test conversion of QImage to numpy array""" - qimage = qt.QImage(3, 3, qt.QImage.Format_RGB888) - qimage.fill(0x010101) - image = _utils.convertQImageToArray(qimage) - - self.assertEqual(qimage.height(), image.shape[0]) - self.assertEqual(qimage.width(), image.shape[1]) - self.assertEqual(image.shape[2], 3) - self.assertTrue(numpy.all(numpy.equal(image, 1))) +from silx.gui.utils import concurrent class TestSubmitToQtThread(TestCaseQt): @@ -73,7 +45,7 @@ class TestSubmitToQtThread(TestCaseQt): def setUp(self): # Reset executor to test lazy-loading in different conditions - _utils._executor = None + concurrent._executor = None super(TestSubmitToQtThread, self).setUp() def _task(self, value1, value2): @@ -85,12 +57,12 @@ class TestSubmitToQtThread(TestCaseQt): def testFromMainThread(self): """Call submitToQtMainThread from the main thread""" value1, value2 = 0, 1 - future = _utils.submitToQtMainThread(self._task, value1, value2=value2) + future = concurrent.submitToQtMainThread(self._task, value1, value2=value2) self.assertTrue(future.done()) self.assertEqual(future.result(1), (value1, value2)) self.assertIsNone(future.exception(1)) - future = _utils.submitToQtMainThread(self._taskWithException) + future = concurrent.submitToQtMainThread(self._taskWithException) self.assertTrue(future.done()) with self.assertRaises(RuntimeError): future.result(1) @@ -99,7 +71,7 @@ class TestSubmitToQtThread(TestCaseQt): def _threadedTest(self): """Function run in a thread for the tests""" value1, value2 = 0, 1 - future = _utils.submitToQtMainThread(self._task, value1, value2=value2) + future = concurrent.submitToQtMainThread(self._task, value1, value2=value2) wait([future], 3) @@ -107,7 +79,7 @@ class TestSubmitToQtThread(TestCaseQt): self.assertEqual(future.result(1), (value1, value2)) self.assertIsNone(future.exception(1)) - future = _utils.submitToQtMainThread(self._taskWithException) + future = concurrent.submitToQtMainThread(self._taskWithException) wait([future], 3) @@ -156,8 +128,6 @@ class TestSubmitToQtThread(TestCaseQt): def suite(): test_suite = unittest.TestSuite() test_suite.addTest(unittest.defaultTestLoader.loadTestsFromTestCase( - TestQImageConversion)) - test_suite.addTest(unittest.defaultTestLoader.loadTestsFromTestCase( TestSubmitToQtThread)) return test_suite diff --git a/silx/gui/utils/test/test_image.py b/silx/gui/utils/test/test_image.py new file mode 100644 index 0000000..7cba1b0 --- /dev/null +++ b/silx/gui/utils/test/test_image.py @@ -0,0 +1,74 @@ +# coding: utf-8 +# /*########################################################################## +# +# Copyright (c) 2017-2018 European Synchrotron Radiation Facility +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. +# +# ###########################################################################*/ +"""Test of utils module.""" + +__authors__ = ["T. Vincent"] +__license__ = "MIT" +__date__ = "16/01/2017" + +import numpy +import unittest + +from silx.gui import qt +from silx.gui.test.utils import TestCaseQt +from silx.gui.utils import _image + + +class TestQImageConversion(TestCaseQt): + """Tests conversion of QImage to/from numpy array.""" + + def testConvertArrayToQImage(self): + """Test conversion of numpy array to QImage""" + image = numpy.ones((3, 3, 3), dtype=numpy.uint8) + qimage = _image.convertArrayToQImage(image) + + self.assertEqual(qimage.height(), image.shape[0]) + self.assertEqual(qimage.width(), image.shape[1]) + self.assertEqual(qimage.format(), qt.QImage.Format_RGB888) + + color = qt.QColor(1, 1, 1).rgb() + self.assertEqual(qimage.pixel(1, 1), color) + + def testConvertQImageToArray(self): + """Test conversion of QImage to numpy array""" + qimage = qt.QImage(3, 3, qt.QImage.Format_RGB888) + qimage.fill(0x010101) + image = _image.convertQImageToArray(qimage) + + self.assertEqual(qimage.height(), image.shape[0]) + self.assertEqual(qimage.width(), image.shape[1]) + self.assertEqual(image.shape[2], 3) + self.assertTrue(numpy.all(numpy.equal(image, 1))) + + +def suite(): + test_suite = unittest.TestSuite() + test_suite.addTest(unittest.defaultTestLoader.loadTestsFromTestCase( + TestQImageConversion)) + return test_suite + + +if __name__ == '__main__': + unittest.main(defaultTest='suite') diff --git a/silx/gui/widgets/BoxLayoutDockWidget.py b/silx/gui/widgets/BoxLayoutDockWidget.py new file mode 100644 index 0000000..3d2b853 --- /dev/null +++ b/silx/gui/widgets/BoxLayoutDockWidget.py @@ -0,0 +1,90 @@ +# coding: utf-8 +# /*########################################################################## +# +# Copyright (c) 2018 European Synchrotron Radiation Facility +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. +# +# ###########################################################################*/ +"""A QDockWidget that update the layout direction of its widget +""" + +__authors__ = ["T. Vincent"] +__license__ = "MIT" +__date__ = "06/03/2018" + + +from .. import qt + + +class BoxLayoutDockWidget(qt.QDockWidget): + """QDockWidget adjusting its child widget QBoxLayout direction. + + The child widget layout direction is set according to dock widget area. + The child widget MUST use a QBoxLayout + + :param parent: See :class:`QDockWidget` + :param flags: See :class:`QDockWidget` + """ + + def __init__(self, parent=None, flags=qt.Qt.Widget): + super(BoxLayoutDockWidget, self).__init__(parent, flags) + self._currentArea = qt.Qt.NoDockWidgetArea + self.dockLocationChanged.connect(self._dockLocationChanged) + self.topLevelChanged.connect(self._topLevelChanged) + + def setWidget(self, widget): + """Set the widget of this QDockWidget + + See :meth:`QDockWidget.setWidget` + """ + super(BoxLayoutDockWidget, self).setWidget(widget) + # Update widget's layout direction + self._dockLocationChanged(self._currentArea) + + def _dockLocationChanged(self, area): + self._currentArea = area + + widget = self.widget() + if widget is not None: + layout = widget.layout() + if isinstance(layout, qt.QBoxLayout): + if area in (qt.Qt.LeftDockWidgetArea, qt.Qt.RightDockWidgetArea): + direction = qt.QBoxLayout.TopToBottom + else: + direction = qt.QBoxLayout.LeftToRight + layout.setDirection(direction) + self.resize(widget.minimumSize()) + self.adjustSize() + + def _topLevelChanged(self, topLevel): + widget = self.widget() + if widget is not None and topLevel: + layout = widget.layout() + if isinstance(layout, qt.QBoxLayout): + layout.setDirection(qt.QBoxLayout.LeftToRight) + self.resize(widget.minimumSize()) + self.adjustSize() + + def showEvent(self, event): + """Make sure this dock widget is raised when it is shown. + + This is useful for tabbed dock widgets. + """ + self.raise_() diff --git a/silx/gui/widgets/FrameBrowser.py b/silx/gui/widgets/FrameBrowser.py index a8c0349..b4f88fc 100644 --- a/silx/gui/widgets/FrameBrowser.py +++ b/silx/gui/widgets/FrameBrowser.py @@ -1,7 +1,7 @@ # coding: utf-8 # /*########################################################################## # -# Copyright (c) 2016-2017 European Synchrotron Radiation Facility +# Copyright (c) 2016-2018 European Synchrotron Radiation Facility # # Permission is hereby granted, free of charge, to any person obtaining a copy # of this software and associated documentation files (the "Software"), to deal @@ -33,6 +33,7 @@ """ from silx.gui import qt from silx.gui import icons +from silx.utils import deprecation __authors__ = ["V.A. Sole", "P. Knobel"] __license__ = "MIT" @@ -50,7 +51,9 @@ class FrameBrowser(qt.QWidget): :param QWidget parent: Parent widget :param int n: Number of frames. This will set the range of frame indices to 0--n-1. - If None, the range is initialized to the default QSlider range (0--99).""" + If None, the range is initialized to the default QSlider range (0--99). + """ + sigIndexChanged = qt.pyqtSignal(object) def __init__(self, parent=None, n=None): @@ -123,25 +126,19 @@ class FrameBrowser(qt.QWidget): def _firstClicked(self): """Select first/lowest frame number""" - self._lineEdit.setText("%d" % self._lineEdit.validator().bottom()) - self._textChangedSlot() + self.setValue(self.getRange()[0]) def _previousClicked(self): """Select previous frame number""" - if self._index > self._lineEdit.validator().bottom(): - self._lineEdit.setText("%d" % (self._index - 1)) - self._textChangedSlot() + self.setValue(self.getValue() - 1) def _nextClicked(self): """Select next frame number""" - if self._index < (self._lineEdit.validator().top()): - self._lineEdit.setText("%d" % (self._index + 1)) - self._textChangedSlot() + self.setValue(self.getValue() + 1) def _lastClicked(self): """Select last/highest frame number""" - self._lineEdit.setText("%d" % self._lineEdit.validator().top()) - self._textChangedSlot() + self.setValue(self.getRange()[1]) def _textChangedSlot(self): """Select frame number typed in the line edit widget""" @@ -161,17 +158,17 @@ class FrameBrowser(qt.QWidget): self._index = new_value self.sigIndexChanged.emit(ddict) - def setRange(self, first, last): - """Set minimum and maximum frame indices - Initialize the frame index to *first*. - Update the label text to *" limits: first, last"* + def getRange(self): + """Returns frame range - :param int first: Minimum frame index - :param int last: Maximum frame index""" - return self.setLimits(first, last) + :return: (first_index, last_index) + """ + validator = self.lineEdit().validator() + return validator.bottom(), validator.top() - def setLimits(self, first, last): + def setRange(self, first, last): """Set minimum and maximum frame indices. + Initialize the frame index to *first*. Update the label text to *" limits: first, last"* @@ -181,34 +178,52 @@ class FrameBrowser(qt.QWidget): top = max(first, last) self._lineEdit.validator().setTop(top) self._lineEdit.validator().setBottom(bottom) - self._index = bottom - self._lineEdit.setText("%d" % self._index) + self.setValue(bottom) + + # Update limits self._label.setText(" limits: %d, %d " % (bottom, top)) + @deprecation.deprecated(replacement="FrameBrowser.setRange", + since_version="0.8") + def setLimits(self, first, last): + return self.setRange(first, last) + def setNFrames(self, nframes): """Set minimum=0 and maximum=nframes-1 frame numbers. + Initialize the frame index to 0. Update the label text to *"1 of nframes"* :param int nframes: Number of frames""" - bottom = 0 top = nframes - 1 - self._lineEdit.validator().setTop(top) - self._lineEdit.validator().setBottom(bottom) - self._index = bottom - self._lineEdit.setText("%d" % self._index) + self.setRange(0, top) # display 1-based index in label - self._label.setText(" %d of %d " % (self._index + 1, top + 1)) + self._label.setText(" of %d " % top) + @deprecation.deprecated(replacement="FrameBrowser.getValue", + since_version="0.8") def getCurrentIndex(self): - """Get 0-based frame index - """ + return self._index + + def getValue(self): + """Return current frame index""" return self._index def setValue(self, value): """Set 0-based frame index + Value is clipped to current range. + :param int value: Frame number""" + bottom = self.lineEdit().validator().bottom() + top = self.lineEdit().validator().top() + value = int(value) + + if value < bottom: + value = bottom + elif value > top: + value = top + self._lineEdit.setText("%d" % value) self._textChangedSlot() diff --git a/silx/gui/widgets/PrintGeometryDialog.py b/silx/gui/widgets/PrintGeometryDialog.py index 0613ce0..db0f3b3 100644 --- a/silx/gui/widgets/PrintGeometryDialog.py +++ b/silx/gui/widgets/PrintGeometryDialog.py @@ -40,7 +40,7 @@ class PrintGeometryWidget(qt.QWidget): self.mainLayout = qt.QGridLayout(self) self.mainLayout.setContentsMargins(0, 0, 0, 0) self.mainLayout.setSpacing(2) - hbox = qt.QWidget() + hbox = qt.QWidget(self) hboxLayout = qt.QHBoxLayout(hbox) hboxLayout.setContentsMargins(0, 0, 0, 0) hboxLayout.setSpacing(2) diff --git a/silx/gui/widgets/PrintPreview.py b/silx/gui/widgets/PrintPreview.py index 2b4c433..78d1bd7 100644 --- a/silx/gui/widgets/PrintPreview.py +++ b/silx/gui/widgets/PrintPreview.py @@ -31,7 +31,7 @@ The user can interactively move and resize the items. """ import sys import logging -from silx.gui import qt +from silx.gui import qt, printer __authors__ = ["V.A. Sole", "P. Knobel"] @@ -387,7 +387,7 @@ class PrintPreviewDialog(qt.QDialog): *None*. """ if self.printer is None: - self.printer = qt.QPrinter() + self.printer = printer.getDefaultPrinter() if self.printDialog is None: self.printDialog = qt.QPrintDialog(self.printer, self) if self.printDialog.exec_(): diff --git a/silx/gui/widgets/test/__init__.py b/silx/gui/widgets/test/__init__.py index 7affc20..5e62393 100644 --- a/silx/gui/widgets/test/__init__.py +++ b/silx/gui/widgets/test/__init__.py @@ -1,7 +1,7 @@ # coding: utf-8 # /*########################################################################## # -# Copyright (c) 2016-2017 European Synchrotron Radiation Facility +# Copyright (c) 2016-2018 European Synchrotron Radiation Facility # # Permission is hereby granted, free of charge, to any person obtaining a copy # of this software and associated documentation files (the "Software"), to deal @@ -29,6 +29,8 @@ from . import test_tablewidget from . import test_threadpoolpushbutton from . import test_hierarchicaltableview from . import test_printpreview +from . import test_framebrowser +from . import test_boxlayoutdockwidget __authors__ = ["V. Valls", "P. Knobel"] __license__ = "MIT" @@ -43,5 +45,7 @@ def suite(): test_periodictable.suite(), test_printpreview.suite(), test_hierarchicaltableview.suite(), + test_framebrowser.suite(), + test_boxlayoutdockwidget.suite(), ]) return test_suite diff --git a/silx/gui/widgets/test/test_boxlayoutdockwidget.py b/silx/gui/widgets/test/test_boxlayoutdockwidget.py new file mode 100644 index 0000000..0df262b --- /dev/null +++ b/silx/gui/widgets/test/test_boxlayoutdockwidget.py @@ -0,0 +1,83 @@ +# coding: utf-8 +# /*########################################################################## +# +# Copyright (c) 2018 European Synchrotron Radiation Facility +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. +# +# ###########################################################################*/ +"""Tests for BoxLayoutDockWidget""" + +__authors__ = ["T. Vincent"] +__license__ = "MIT" +__date__ = "06/03/2018" + +import unittest + +from silx.gui.widgets.BoxLayoutDockWidget import BoxLayoutDockWidget +from silx.gui import qt +from silx.gui.test.utils import TestCaseQt + + +class TestBoxLayoutDockWidget(TestCaseQt): + """Tests for BoxLayoutDockWidget""" + + def setUp(self): + """Create and show a main window""" + self.window = qt.QMainWindow() + self.qWaitForWindowExposed(self.window) + + def tearDown(self): + """Delete main window""" + self.window.setAttribute(qt.Qt.WA_DeleteOnClose) + self.window.close() + del self.window + self.qapp.processEvents() + + def test(self): + """Test update of layout direction according to dock area""" + # Create a widget with a QBoxLayout + layout = qt.QBoxLayout(qt.QBoxLayout.LeftToRight) + layout.addWidget(qt.QLabel('First')) + layout.addWidget(qt.QLabel('Second')) + widget = qt.QWidget() + widget.setLayout(layout) + + # Add it to a BoxLayoutDockWidget + dock = BoxLayoutDockWidget() + dock.setWidget(widget) + + self.window.addDockWidget(qt.Qt.BottomDockWidgetArea, dock) + self.qapp.processEvents() + self.assertEqual(layout.direction(), qt.QBoxLayout.LeftToRight) + + self.window.addDockWidget(qt.Qt.LeftDockWidgetArea, dock) + self.qapp.processEvents() + self.assertEqual(layout.direction(), qt.QBoxLayout.TopToBottom) + + +def suite(): + loader = unittest.defaultTestLoader.loadTestsFromTestCase + test_suite = unittest.TestSuite() + test_suite.addTest(loader(TestBoxLayoutDockWidget)) + return test_suite + + +if __name__ == '__main__': + unittest.main(defaultTest='suite') diff --git a/silx/gui/widgets/test/test_framebrowser.py b/silx/gui/widgets/test/test_framebrowser.py new file mode 100644 index 0000000..9988d16 --- /dev/null +++ b/silx/gui/widgets/test/test_framebrowser.py @@ -0,0 +1,73 @@ +# coding: utf-8 +# /*########################################################################## +# +# Copyright (c) 2018 European Synchrotron Radiation Facility +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. +# +# ###########################################################################*/ +__authors__ = ["T. Vincent"] +__license__ = "MIT" +__date__ = "23/03/2018" + + +import unittest + +from silx.gui.test.utils import TestCaseQt +from silx.gui.widgets.FrameBrowser import FrameBrowser + + +class TestFrameBrowser(TestCaseQt): + """Test for FrameBrowser""" + + def test(self): + """Test FrameBrowser""" + widget = FrameBrowser() + widget.show() + self.qWaitForWindowExposed(widget) + + nFrames = 20 + widget.setNFrames(nFrames) + self.assertEqual(widget.getRange(), (0, nFrames - 1)) + self.assertEqual(widget.getValue(), 0) + + range_ = -100, 100 + widget.setRange(*range_) + self.assertEqual(widget.getRange(), range_) + self.assertEqual(widget.getValue(), range_[0]) + + widget.setValue(0) + self.assertEqual(widget.getValue(), 0) + + widget.setValue(range_[1] + 100) + self.assertEqual(widget.getValue(), range_[1]) + + widget.setValue(range_[0] - 100) + self.assertEqual(widget.getValue(), range_[0]) + + +def suite(): + loader = unittest.defaultTestLoader.loadTestsFromTestCase + test_suite = unittest.TestSuite() + test_suite.addTest(loader(TestFrameBrowser)) + return test_suite + + +if __name__ == '__main__': + unittest.main(defaultTest='suite') |