diff options
Diffstat (limited to 'silx/gui/plot/items/core.py')
-rw-r--r-- | silx/gui/plot/items/core.py | 839 |
1 files changed, 839 insertions, 0 deletions
diff --git a/silx/gui/plot/items/core.py b/silx/gui/plot/items/core.py new file mode 100644 index 0000000..72bfd9a --- /dev/null +++ b/silx/gui/plot/items/core.py @@ -0,0 +1,839 @@ +# coding: utf-8 +# /*########################################################################## +# +# Copyright (c) 2017 European Synchrotron Radiation Facility +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. +# +# ###########################################################################*/ +"""This module provides the base class for items of the :class:`Plot`. +""" + +__authors__ = ["T. Vincent"] +__license__ = "MIT" +__date__ = "26/04/2017" + +from copy import deepcopy +import logging +import weakref +import numpy +from silx.third_party import six + +from .. import Colors + + + +_logger = logging.getLogger(__name__) + + +class Item(object): + """Description of an item of the plot""" + + _DEFAULT_Z_LAYER = 0 + """Default layer for overlay rendering""" + + _DEFAULT_LEGEND = '' + """Default legend of items""" + + _DEFAULT_SELECTABLE = False + """Default selectable state of items""" + + def __init__(self): + self._dirty = True + self._plotRef = None + self._visible = True + self._legend = self._DEFAULT_LEGEND + self._selectable = self._DEFAULT_SELECTABLE + self._z = self._DEFAULT_Z_LAYER + self._info = None + self._xlabel = None + self._ylabel = None + + self._backendRenderer = None + + def getPlot(self): + """Returns Plot this item belongs to. + + :rtype: Plot or None + """ + return None if self._plotRef is None else self._plotRef() + + def _setPlot(self, plot): + """Set the plot this item belongs to. + + WARNING: This should only be called from the Plot. + + :param Plot plot: The Plot instance. + """ + if plot is not None and self._plotRef is not None: + raise RuntimeError('Trying to add a node at two places.') + self._plotRef = None if plot is None else weakref.ref(plot) + self._updated() + + def getBounds(self): # TODO return a Bounds object rather than a tuple + """Returns the bounding box of this item in data coordinates + + :returns: (xmin, xmax, ymin, ymax) or None + :rtype: 4-tuple of float or None + """ + return self._getBounds() + + def _getBounds(self): + """:meth:`getBounds` implementation to override by sub-class""" + return None + + def isVisible(self): + """True if item is visible, False otherwise + + :rtype: bool + """ + return self._visible + + def setVisible(self, visible): + """Set visibility of item. + + :param bool visible: True to display it, False otherwise + """ + visible = bool(visible) + if visible != self._visible: + self._visible = visible + # When visibility has changed, always mark as dirty + self._updated(checkVisibility=False) + + def isOverlay(self): + """Return true if item is drawn as an overlay. + + :rtype: bool + """ + return False + + def getLegend(self): + """Returns the legend of this item (str)""" + return self._legend + + def _setLegend(self, legend): + """Set the legend. + + This is private as it is used by the plot as an identifier + + :param str legend: Item legend + """ + legend = str(legend) if legend is not None else self._DEFAULT_LEGEND + self._legend = legend + + def isSelectable(self): + """Returns true if item is selectable (bool)""" + return self._selectable + + def _setSelectable(self, selectable): # TODO support update + """Set whether item is selectable or not. + + This is private for now as change is not handled. + + :param bool selectable: True to make item selectable + """ + self._selectable = bool(selectable) + + def getZValue(self): + """Returns the layer on which to draw this item (int)""" + return self._z + + def setZValue(self, z): + z = int(z) if z is not None else self._DEFAULT_Z_LAYER + if z != self._z: + self._z = z + self._updated() + + def getInfo(self, copy=True): + """Returns the info associated to this item + + :param bool copy: True to get a deepcopy, False otherwise. + """ + return deepcopy(self._info) if copy else self._info + + def setInfo(self, info, copy=True): + if copy: + info = deepcopy(info) + self._info = info + + def _updated(self, checkVisibility=True): + """Mark the item as dirty (i.e., needing update). + + This also triggers Plot.replot. + + :param bool checkVisibility: True to only mark as dirty if visible, + False to always mark as dirty. + """ + if not checkVisibility or self.isVisible(): + if not self._dirty: + self._dirty = True + # TODO: send event instead of explicit call + plot = self.getPlot() + if plot is not None: + plot._itemRequiresUpdate(self) + + def _update(self, backend): + """Called by Plot to update the backend for this item. + + This is meant to be called asynchronously from _updated. + This optimizes the number of call to _update. + + :param backend: The backend to update + """ + if self._dirty: + # Remove previous renderer from backend if any + self._removeBackendRenderer(backend) + + # If not visible, do not add renderer to backend + if self.isVisible(): + self._backendRenderer = self._addBackendRenderer(backend) + + self._dirty = False + + def _addBackendRenderer(self, backend): + """Override in subclass to add specific backend renderer. + + :param BackendBase backend: The backend to update + :return: The renderer handle to store or None if no renderer in backend + """ + return None + + def _removeBackendRenderer(self, backend): + """Override in subclass to remove specific backend renderer. + + :param BackendBase backend: The backend to update + """ + if self._backendRenderer is not None: + backend.remove(self._backendRenderer) + self._backendRenderer = None + + +# Mix-in classes ############################################################## + +class LabelsMixIn(object): + """Mix-in class for items with x and y labels + + Setters are private, otherwise it needs to check the plot + current active curve and access the internal current labels. + """ + + def __init__(self): + self._xlabel = None + self._ylabel = None + + def getXLabel(self): + """Return the X axis label associated to this curve + + :rtype: str or None + """ + return self._xlabel + + def _setXLabel(self, label): + """Set the X axis label associated with this curve + + :param str label: The X axis label + """ + self._xlabel = str(label) + + def getYLabel(self): + """Return the Y axis label associated to this curve + + :rtype: str or None + """ + return self._ylabel + + def _setYLabel(self, label): + """Set the Y axis label associated with this curve + + :param str label: The Y axis label + """ + self._ylabel = str(label) + + +class DraggableMixIn(object): + """Mix-in class for draggable items""" + + def __init__(self): + self._draggable = False + + def isDraggable(self): + """Returns true if image is draggable + + :rtype: bool + """ + return self._draggable + + def _setDraggable(self, draggable): # TODO support update + """Set if image is draggable or not. + + This is private for not as it does not support update. + + :param bool draggable: + """ + self._draggable = bool(draggable) + + +class ColormapMixIn(object): + """Mix-in class for items with colormap""" + + _DEFAULT_COLORMAP = {'name': 'gray', 'normalization': 'linear', + 'autoscale': True, 'vmin': 0.0, 'vmax': 1.0} + """Default colormap of the item""" + + def __init__(self): + self._colormap = self._DEFAULT_COLORMAP + + def getColormap(self): + """Return the used colormap""" + return self._colormap.copy() + + def setColormap(self, colormap): + """Set the colormap of this image + + :param dict colormap: colormap description + """ + self._colormap = colormap.copy() + # TODO colormap comparison + colormap object and events on modification + self._updated() + + +class SymbolMixIn(object): + """Mix-in class for items with symbol type""" + + _DEFAULT_SYMBOL = '' + """Default marker of the item""" + + _DEFAULT_SYMBOL_SIZE = 6.0 + """Default marker size of the item""" + + def __init__(self): + self._symbol = self._DEFAULT_SYMBOL + self._symbol_size = self._DEFAULT_SYMBOL_SIZE + + def getSymbol(self): + """Return the point marker type. + + Marker type:: + + - 'o' circle + - '.' point + - ',' pixel + - '+' cross + - 'x' x-cross + - 'd' diamond + - 's' square + + :rtype: str + """ + return self._symbol + + def setSymbol(self, symbol): + """Set the marker type + + See :meth:`getSymbol`. + + :param str symbol: Marker type + """ + assert symbol in ('o', '.', ',', '+', 'x', 'd', 's', '', None) + if symbol is None: + symbol = self._DEFAULT_SYMBOL + if symbol != self._symbol: + self._symbol = symbol + self._updated() + + def getSymbolSize(self): + """Return the point marker size in points. + + :rtype: float + """ + return self._symbol_size + + def setSymbolSize(self, size): + """Set the point marker size in points. + + See :meth:`getSymbolSize`. + + :param str symbol: Marker type + """ + if size is None: + size = self._DEFAULT_SYMBOL_SIZE + if size != self._symbol_size: + self._symbol_size = size + self._updated() + + +class LineMixIn(object): + """Mix-in class for item with line""" + + _DEFAULT_LINEWIDTH = 1. + """Default line width""" + + _DEFAULT_LINESTYLE = '-' + """Default line style""" + + def __init__(self): + self._linewidth = self._DEFAULT_LINEWIDTH + self._linestyle = self._DEFAULT_LINESTYLE + + def getLineWidth(self): + """Return the curve line width in pixels (int)""" + return self._linewidth + + def setLineWidth(self, width): + """Set the width in pixel of the curve line + + See :meth:`getLineWidth`. + + :param float width: Width in pixels + """ + width = float(width) + if width != self._linewidth: + self._linewidth = width + self._updated() + + def getLineStyle(self): + """Return the type of the line + + Type of line:: + + - ' ' no line + - '-' solid line + - '--' dashed line + - '-.' dash-dot line + - ':' dotted line + + :rtype: str + """ + return self._linestyle + + def setLineStyle(self, style): + """Set the style of the curve line. + + See :meth:`getLineStyle`. + + :param str style: Line style + """ + style = str(style) + assert style in ('', ' ', '-', '--', '-.', ':', None) + if style is None: + style = self._DEFAULT_LINESTYLE + if style != self._linestyle: + self._linestyle = style + self._updated() + + +class ColorMixIn(object): + """Mix-in class for item with color""" + + _DEFAULT_COLOR = (0., 0., 0., 1.) + """Default color of the item""" + + def __init__(self): + self._color = self._DEFAULT_COLOR + + def getColor(self): + """Returns the RGBA color of the item + + :rtype: 4-tuple of float in [0, 1] + """ + return self._color + + def setColor(self, color, copy=True): + """Set item color + + :param color: color(s) to be used + :type color: str ("#RRGGBB") or (npoints, 4) unsigned byte array or + one of the predefined color names defined in Colors.py + :param bool copy: True (Default) to get a copy, + False to use internal representation (do not modify!) + """ + if isinstance(color, six.string_types): + color = Colors.rgba(color) + else: + color = numpy.array(color, copy=copy) + # TODO more checks + improve color array support + if color.ndim == 1: # Single RGBA color + color = Colors.rgba(color) + else: # Array of colors + assert color.ndim == 2 + + self._color = color + self._updated() + + +class YAxisMixIn(object): + """Mix-in class for item with yaxis""" + + _DEFAULT_YAXIS = 'left' + """Default Y axis the item belongs to""" + + def __init__(self): + self._yaxis = self._DEFAULT_YAXIS + + def getYAxis(self): + """Returns the Y axis this curve belongs to. + + Either 'left' or 'right'. + + :rtype: str + """ + return self._yaxis + + def setYAxis(self, yaxis): + """Set the Y axis this curve belongs to. + + :param str yaxis: 'left' or 'right' + """ + yaxis = str(yaxis) + assert yaxis in ('left', 'right') + if yaxis != self._yaxis: + self._yaxis = yaxis + self._updated() + + +class FillMixIn(object): + """Mix-in class for item with fill""" + + def __init__(self): + self._fill = False + + def isFill(self): + """Returns whether the item is filled or not. + + :rtype: bool + """ + return self._fill + + def setFill(self, fill): + """Set whether to fill the item or not. + + :param bool fill: + """ + fill = bool(fill) + if fill != self._fill: + self._fill = fill + self._updated() + + +class AlphaMixIn(object): + """Mix-in class for item with opacity""" + + def __init__(self): + self._alpha = 1. + + def getAlpha(self): + """Returns the opacity of the item + + :rtype: float in [0, 1.] + """ + return self._alpha + + def setAlpha(self, alpha): + """Set the opacity of the item + + .. note:: + + If the colormap already has some transparency, this alpha + adds additional transparency. The alpha channel of the colormap + is multiplied by this value. + + :param alpha: Opacity of the item, between 0 (full transparency) + and 1. (full opacity) + :type alpha: float + """ + alpha = float(alpha) + alpha = max(0., min(alpha, 1.)) # Clip alpha to [0., 1.] range + if alpha != self._alpha: + self._alpha = alpha + self._updated() + + +class Points(Item, SymbolMixIn, AlphaMixIn): + """Base class for :class:`Curve` and :class:`Scatter`""" + # note: _logFilterData must be overloaded if you overload + # getData to change its signature + + _DEFAULT_Z_LAYER = 1 + """Default overlay layer for points, + on top of images.""" + + def __init__(self): + Item.__init__(self) + SymbolMixIn.__init__(self) + AlphaMixIn.__init__(self) + self._x = () + self._y = () + self._xerror = None + self._yerror = None + + # Store filtered data for x > 0 and/or y > 0 + self._filteredCache = {} + self._clippedCache = {} + + # Store bounds depending on axes filtering >0: + # key is (isXPositiveFilter, isYPositiveFilter) + self._boundsCache = {} + + @staticmethod + def _logFilterError(value, error): + """Filter/convert error values if they go <= 0. + + Replace error leading to negative values by nan + + :param numpy.ndarray value: 1D array of values + :param numpy.ndarray error: + Array of errors: scalar, N, Nx1 or 2xN or None. + :return: Filtered error so error bars are never negative + """ + if error is not None: + # Convert Nx1 to N + if error.ndim == 2 and error.shape[1] == 1 and len(value) != 1: + error = numpy.ravel(error) + + # Supports error being scalar, N or 2xN array + errorClipped = (value - numpy.atleast_2d(error)[0]) <= 0 + + if numpy.any(errorClipped): # Need filtering + + # expand errorbars to 2xN + if error.size == 1: # Scalar + error = numpy.full( + (2, len(value)), error, dtype=numpy.float) + + elif error.ndim == 1: # N array + newError = numpy.empty((2, len(value)), + dtype=numpy.float) + newError[0, :] = error + newError[1, :] = error + error = newError + + elif error.size == 2 * len(value): # 2xN array + error = numpy.array( + error, copy=True, dtype=numpy.float) + + else: + _logger.error("Unhandled error array") + return error + + error[0, errorClipped] = numpy.nan + + return error + + def _getClippingBoolArray(self, xPositive, yPositive): + """Compute a boolean array to filter out points with negative + coordinates on log axes. + + :param bool xPositive: True to filter arrays according to X coords. + :param bool yPositive: True to filter arrays according to Y coords. + :rtype: boolean numpy.ndarray + """ + assert xPositive or yPositive + if (xPositive, yPositive) not in self._clippedCache: + x = self.getXData(copy=False) + y = self.getYData(copy=False) + xclipped = (x <= 0) if xPositive else False + yclipped = (y <= 0) if yPositive else False + self._clippedCache[(xPositive, yPositive)] = \ + numpy.logical_or(xclipped, yclipped) + return self._clippedCache[(xPositive, yPositive)] + + def _logFilterData(self, xPositive, yPositive): + """Filter out values with x or y <= 0 on log axes + + :param bool xPositive: True to filter arrays according to X coords. + :param bool yPositive: True to filter arrays according to Y coords. + :return: The filter arrays or unchanged object if filtering not needed + :rtype: (x, y, xerror, yerror) + """ + x = self.getXData(copy=False) + y = self.getYData(copy=False) + xerror = self.getXErrorData(copy=False) + yerror = self.getYErrorData(copy=False) + + if xPositive or yPositive: + clipped = self._getClippingBoolArray(xPositive, yPositive) + + if numpy.any(clipped): + # copy to keep original array and convert to float + x = numpy.array(x, copy=True, dtype=numpy.float) + x[clipped] = numpy.nan + y = numpy.array(y, copy=True, dtype=numpy.float) + y[clipped] = numpy.nan + + if xPositive and xerror is not None: + xerror = self._logFilterError(x, xerror) + + if yPositive and yerror is not None: + yerror = self._logFilterError(y, yerror) + + return x, y, xerror, yerror + + def _getBounds(self): + if self.getXData(copy=False).size == 0: # Empty data + return None + + plot = self.getPlot() + if plot is not None: + xPositive = plot.isXAxisLogarithmic() + yPositive = plot.isYAxisLogarithmic() + else: + xPositive = False + yPositive = False + + # TODO bounds do not take error bars into account + if (xPositive, yPositive) not in self._boundsCache: + # use the getData class method because instance method can be + # overloaded to return additional arrays + data = Points.getData(self, copy=False, + displayed=True) + if len(data) == 5: + # hack to avoid duplicating caching mechanism in Scatter + # (happens when cached data is used, caching done using + # Scatter._logFilterData) + x, y, xerror, yerror = data[0], data[1], data[3], data[4] + else: + x, y, xerror, yerror = data + + self._boundsCache[(xPositive, yPositive)] = ( + numpy.nanmin(x), + numpy.nanmax(x), + numpy.nanmin(y), + numpy.nanmax(y) + ) + return self._boundsCache[(xPositive, yPositive)] + + def _getCachedData(self): + """Return cached filtered data if applicable, + i.e. if any axis is in log scale. + Return None if caching is not applicable.""" + plot = self.getPlot() + if plot is not None: + xPositive = plot.isXAxisLogarithmic() + yPositive = plot.isYAxisLogarithmic() + if xPositive or yPositive: + # At least one axis has log scale, filter data + if (xPositive, yPositive) not in self._filteredCache: + self._filteredCache[(xPositive, yPositive)] = \ + self._logFilterData(xPositive, yPositive) + return self._filteredCache[(xPositive, yPositive)] + return None + + def getData(self, copy=True, displayed=False): + """Returns the x, y values of the curve points and xerror, yerror + + :param bool copy: True (Default) to get a copy, + False to use internal representation (do not modify!) + :param bool displayed: True to only get curve points that are displayed + in the plot. Default: False + Note: If plot has log scale, negative points + are not displayed. + :returns: (x, y, xerror, yerror) + :rtype: 4-tuple of numpy.ndarray + """ + if displayed: # filter data according to plot state + cached_data = self._getCachedData() + if cached_data is not None: + return cached_data + + return (self.getXData(copy), + self.getYData(copy), + self.getXErrorData(copy), + self.getYErrorData(copy)) + + def getXData(self, copy=True): + """Returns the x coordinates of the data points + + :param copy: True (Default) to get a copy, + False to use internal representation (do not modify!) + :rtype: numpy.ndarray + """ + return numpy.array(self._x, copy=copy) + + def getYData(self, copy=True): + """Returns the y coordinates of the data points + + :param copy: True (Default) to get a copy, + False to use internal representation (do not modify!) + :rtype: numpy.ndarray + """ + return numpy.array(self._y, copy=copy) + + def getXErrorData(self, copy=True): + """Returns the x error of the points + + :param copy: True (Default) to get a copy, + False to use internal representation (do not modify!) + :rtype: numpy.ndarray or None + """ + if self._xerror is None: + return None + else: + return numpy.array(self._xerror, copy=copy) + + def getYErrorData(self, copy=True): + """Returns the y error of the points + + :param copy: True (Default) to get a copy, + False to use internal representation (do not modify!) + :rtype: numpy.ndarray or None + """ + if self._yerror is None: + return None + else: + return numpy.array(self._yerror, copy=copy) + + def setData(self, x, y, xerror=None, yerror=None, copy=True): + """Set the data of the curve. + + :param numpy.ndarray x: The data corresponding to the x coordinates. + :param numpy.ndarray y: The data corresponding to the y coordinates. + :param xerror: Values with the uncertainties on the x values + :type xerror: A float, or a numpy.ndarray of float32. + If it is an array, it can either be a 1D array of + same length as the data or a 2D array with 2 rows + of same length as the data: row 0 for positive errors, + row 1 for negative errors. + :param yerror: Values with the uncertainties on the y values. + :type yerror: A float, or a numpy.ndarray of float32. See xerror. + :param bool copy: True make a copy of the data (default), + False to use provided arrays. + """ + x = numpy.array(x, copy=copy) + y = numpy.array(y, copy=copy) + assert len(x) == len(y) + assert x.ndim == y.ndim == 1 + + if xerror is not None: + xerror = numpy.array(xerror, copy=copy) + if yerror is not None: + yerror = numpy.array(yerror, copy=copy) + # TODO checks on xerror, yerror + self._x, self._y = x, y + self._xerror, self._yerror = xerror, yerror + + self._boundsCache = {} # Reset cached bounds + self._filteredCache = {} # Reset cached filtered data + self._clippedCache = {} # Reset cached clipped bool array + + self._updated() + # TODO hackish data range implementation + if self.isVisible(): + plot = self.getPlot() + if plot is not None: + plot._invalidateDataRange() |