diff options
Diffstat (limited to 'silx/gui/plot/items/scatter.py')
-rw-r--r-- | silx/gui/plot/items/scatter.py | 286 |
1 files changed, 236 insertions, 50 deletions
diff --git a/silx/gui/plot/items/scatter.py b/silx/gui/plot/items/scatter.py index 50cc694..5e7d65b 100644 --- a/silx/gui/plot/items/scatter.py +++ b/silx/gui/plot/items/scatter.py @@ -1,7 +1,7 @@ # coding: utf-8 # /*########################################################################## # -# Copyright (c) 2017-2019 European Synchrotron Radiation Facility +# Copyright (c) 2017-2020 European Synchrotron Radiation Facility # # Permission is hereby granted, free of charge, to any person obtaining a copy # of this software and associated documentation files (the "Software"), to deal @@ -43,6 +43,7 @@ from concurrent.futures import ThreadPoolExecutor, CancelledError from ....utils.proxy import docstring from ....math.combo import min_max +from ....math.histogram import Histogramnd from ....utils.weakref import WeakList from .._utils.delaunay import delaunay from .core import PointsBase, ColormapMixIn, ScatterVisualizationMixIn @@ -142,12 +143,13 @@ def is_monotonic(array): :rtype: int """ diff = numpy.diff(numpy.ravel(array)) - if numpy.all(diff >= 0): - return 1 - elif numpy.all(diff <= 0): - return -1 - else: - return 0 + with numpy.errstate(invalid='ignore'): + if numpy.all(diff >= 0): + return 1 + elif numpy.all(diff <= 0): + return -1 + else: + return 0 def _guess_grid(x, y): @@ -264,6 +266,10 @@ _RegularGridInfo = namedtuple( '_RegularGridInfo', ['bounds', 'origin', 'scale', 'shape', 'order']) +_HistogramInfo = namedtuple( + '_HistogramInfo', ['mean', 'count', 'sum', 'origin', 'scale', 'shape']) + + class Scatter(PointsBase, ColormapMixIn, ScatterVisualizationMixIn): """Description of a scatter""" @@ -275,6 +281,7 @@ class Scatter(PointsBase, ColormapMixIn, ScatterVisualizationMixIn): ScatterVisualizationMixIn.Visualization.SOLID, ScatterVisualizationMixIn.Visualization.REGULAR_GRID, ScatterVisualizationMixIn.Visualization.IRREGULAR_GRID, + ScatterVisualizationMixIn.Visualization.BINNED_STATISTIC, ) """Overrides supported Visualizations""" @@ -293,17 +300,53 @@ class Scatter(PointsBase, ColormapMixIn, ScatterVisualizationMixIn): # Cache triangles: x, y, indices self.__cacheTriangles = None, None, None - # Cache regular grid info + # Cache regular grid and histogram info self.__cacheRegularGridInfo = None + self.__cacheHistogramInfo = None + + def _updateColormappedData(self): + """Update the colormapped data, to be called when changed""" + if self.getVisualization() is self.Visualization.BINNED_STATISTIC: + histoInfo = self.__getHistogramInfo() + if histoInfo is None: + data = None + else: + data = getattr( + histoInfo, + self.getVisualizationParameter( + self.VisualizationParameter.BINNED_STATISTIC_FUNCTION)) + else: + data = self.getValueData(copy=False) + self._setColormappedData(data, copy=False) + + @docstring(ScatterVisualizationMixIn) + def setVisualization(self, mode): + previous = self.getVisualization() + if super().setVisualization(mode): + if (bool(mode is self.Visualization.BINNED_STATISTIC) ^ + bool(previous is self.Visualization.BINNED_STATISTIC)): + self._updateColormappedData() + return True + else: + return False @docstring(ScatterVisualizationMixIn) def setVisualizationParameter(self, parameter, value): - changed = super(Scatter, self).setVisualizationParameter(parameter, value) - if changed and parameter in (self.VisualizationParameter.GRID_BOUNDS, - self.VisualizationParameter.GRID_MAJOR_ORDER, - self.VisualizationParameter.GRID_SHAPE): - self.__cacheRegularGridInfo = None - return changed + if super(Scatter, self).setVisualizationParameter(parameter, value): + if parameter in (self.VisualizationParameter.GRID_BOUNDS, + self.VisualizationParameter.GRID_MAJOR_ORDER, + self.VisualizationParameter.GRID_SHAPE): + self.__cacheRegularGridInfo = None + + if parameter in (self.VisualizationParameter.BINNED_STATISTIC_SHAPE, + self.VisualizationParameter.BINNED_STATISTIC_FUNCTION): + if parameter == self.VisualizationParameter.BINNED_STATISTIC_SHAPE: + self.__cacheHistogramInfo = None # Clean-up cache + if self.getVisualization() is self.Visualization.BINNED_STATISTIC: + self._updateColormappedData() + return True + else: + return False @docstring(ScatterVisualizationMixIn) def getCurrentVisualizationParameter(self, parameter): @@ -323,6 +366,10 @@ class Scatter(PointsBase, ColormapMixIn, ScatterVisualizationMixIn): grid = self.__getRegularGridInfo() return None if grid is None else grid.shape + elif parameter is self.VisualizationParameter.BINNED_STATISTIC_SHAPE: + info = self.__getHistogramInfo() + return None if info is None else info.shape + else: raise NotImplementedError() @@ -345,6 +392,18 @@ class Scatter(PointsBase, ColormapMixIn, ScatterVisualizationMixIn): if order is None: order = guess[0] + nbpoints = len(self.getXData(copy=False)) + if nbpoints > shape[0] * shape[1]: + # More data points that provided grid shape: enlarge grid + _logger.warning( + "More data points than provided grid shape size: extends grid") + dim0, dim1 = shape + if order == 'row': # keep dim1, enlarge dim0 + dim0 = nbpoints // dim1 + (1 if nbpoints % dim1 else 0) + else: # keep dim0, enlarge dim1 + dim1 = nbpoints // dim0 + (1 if nbpoints % dim0 else 0) + shape = dim0, dim1 + bounds = self.getVisualizationParameter( self.VisualizationParameter.GRID_BOUNDS) if bounds is None: @@ -372,6 +431,47 @@ class Scatter(PointsBase, ColormapMixIn, ScatterVisualizationMixIn): return self.__cacheRegularGridInfo + def __getHistogramInfo(self): + """Get histogram info""" + if self.__cacheHistogramInfo is None: + shape = self.getVisualizationParameter( + self.VisualizationParameter.BINNED_STATISTIC_SHAPE) + if shape is None: + shape = 100, 100 # TODO compute auto shape + + x, y, values = self.getData(copy=False)[:3] + if len(x) == 0: # No histogram + return None + + if not numpy.issubdtype(x.dtype, numpy.floating): + x = x.astype(numpy.float64) + if not numpy.issubdtype(y.dtype, numpy.floating): + y = y.astype(numpy.float64) + if not numpy.issubdtype(values.dtype, numpy.floating): + values = values.astype(numpy.float64) + + ranges = (tuple(min_max(y, finite=True)), + tuple(min_max(x, finite=True))) + points = numpy.transpose(numpy.array((y, x))) + counts, sums, bin_edges = Histogramnd( + points, + histo_range=ranges, + n_bins=shape, + weights=values) + yEdges, xEdges = bin_edges + origin = xEdges[0], yEdges[0] + scale = ((xEdges[-1] - xEdges[0]) / (len(xEdges) - 1), + (yEdges[-1] - yEdges[0]) / (len(yEdges) - 1)) + + with numpy.errstate(divide='ignore', invalid='ignore'): + histo = sums / counts + + self.__cacheHistogramInfo = _HistogramInfo( + mean=histo, count=counts, sum=sums, + origin=origin, scale=scale, shape=shape) + + return self.__cacheHistogramInfo + def _addBackendRenderer(self, backend): """Update backend renderer""" # Filter-out values <= 0 @@ -386,28 +486,47 @@ class Scatter(PointsBase, ColormapMixIn, ScatterVisualizationMixIn): if len(xFiltered) == 0: return None # No data to display, do not add renderer to backend + visualization = self.getVisualization() + + if visualization is self.Visualization.BINNED_STATISTIC: + plot = self.getPlot() + if (plot is None or + plot.getXAxis().getScale() != Axis.LINEAR or + plot.getYAxis().getScale() != Axis.LINEAR): + # Those visualizations are not available with log scaled axes + return None + + histoInfo = self.__getHistogramInfo() + if histoInfo is None: + return None + data = getattr(histoInfo, self.getVisualizationParameter( + self.VisualizationParameter.BINNED_STATISTIC_FUNCTION)) + + return backend.addImage( + data=data, + origin=histoInfo.origin, + scale=histoInfo.scale, + colormap=self.getColormap(), + alpha=self.getAlpha()) + # Compute colors cmap = self.getColormap() - rgbacolors = cmap.applyToData(self._value) + rgbacolors = cmap.applyToData(self) if self.__alpha is not None: rgbacolors[:, -1] = (rgbacolors[:, -1] * self.__alpha).astype(numpy.uint8) - # Apply mask to colors - rgbacolors = rgbacolors[mask] - visualization = self.getVisualization() if visualization is self.Visualization.POINTS: return backend.addCurve(xFiltered, yFiltered, - color=rgbacolors, + color=rgbacolors[mask], symbol=self.getSymbol(), linewidth=0, linestyle="", yaxis='left', xerror=xerror, yerror=yerror, - z=self.getZValue(), fill=False, alpha=self.getAlpha(), symbolsize=self.getSymbolSize(), @@ -432,8 +551,7 @@ class Scatter(PointsBase, ColormapMixIn, ScatterVisualizationMixIn): return backend.addTriangles(xFiltered, yFiltered, triangles, - color=rgbacolors, - z=self.getZValue(), + color=rgbacolors[mask], alpha=self.getAlpha()) elif visualization is self.Visualization.REGULAR_GRID: @@ -461,7 +579,6 @@ class Scatter(PointsBase, ColormapMixIn, ScatterVisualizationMixIn): data=image, origin=gridInfo.origin, scale=gridInfo.scale, - z=self.getZValue(), colormap=None, alpha=self.getAlpha()) @@ -474,31 +591,89 @@ class Scatter(PointsBase, ColormapMixIn, ScatterVisualizationMixIn): if shape is None: # No shape, no display return None - # clip shape to fully filled lines - if len(xFiltered) != numpy.prod(shape): - if gridInfo.order == 'row': - shape = len(xFiltered) // shape[1], shape[1] + nbpoints = len(xFiltered) + if nbpoints == 1: + # single point, render as a square points + return backend.addCurve(xFiltered, yFiltered, + color=rgbacolors[mask], + symbol='s', + linewidth=0, + linestyle="", + yaxis='left', + xerror=None, + yerror=None, + fill=False, + alpha=self.getAlpha(), + symbolsize=7, + baseline=None) + + # Make shape include all points + gridOrder = gridInfo.order + if nbpoints != numpy.prod(shape): + if gridOrder == 'row': + shape = int(numpy.ceil(nbpoints / shape[1])), shape[1] else: # column-major order - shape = shape[0], len(xFiltered) // shape[0] - if shape[0] < 2 or shape[1] < 2: # Not enough points - return None - - nbpoints = numpy.prod(shape) - if gridInfo.order == 'row': - points = numpy.transpose((xFiltered[:nbpoints], yFiltered[:nbpoints])) - points = points.reshape(shape[0], shape[1], 2) + shape = shape[0], int(numpy.ceil(nbpoints / shape[0])) + + if shape[0] < 2 or shape[1] < 2: # Single line, at least 2 points + points = numpy.ones((2, nbpoints, 2), dtype=numpy.float64) + # Use row/column major depending on shape, not on info value + gridOrder = 'row' if shape[0] == 1 else 'column' + + if gridOrder == 'row': + points[0, :, 0] = xFiltered + points[0, :, 1] = yFiltered + else: # column-major order + points[0, :, 0] = yFiltered + points[0, :, 1] = xFiltered + + # Add a second line that will be clipped in the end + points[1, :-1] = points[0, :-1] + numpy.cross( + points[0, 1:] - points[0, :-1], (0., 0., 1.))[:, :2] + points[1, -1] = points[0, -1] + numpy.cross( + points[0, -1] - points[0, -2], (0., 0., 1.))[:2] + + points.shape = 2, nbpoints, 2 # Use same shape for both orders + coords, indices = _quadrilateral_grid_as_triangles(points) + + elif gridOrder == 'row': # row-major order + if nbpoints != numpy.prod(shape): + points = numpy.empty((numpy.prod(shape), 2), dtype=numpy.float64) + points[:nbpoints, 0] = xFiltered + points[:nbpoints, 1] = yFiltered + # Index of last element of last fully filled row + index = (nbpoints // shape[1]) * shape[1] + points[nbpoints:, 0] = xFiltered[index - (numpy.prod(shape) - nbpoints):index] + points[nbpoints:, 1] = yFiltered[-1] + else: + points = numpy.transpose((xFiltered, yFiltered)) + points.shape = shape[0], shape[1], 2 else: # column-major order - points = numpy.transpose((yFiltered[:nbpoints], xFiltered[:nbpoints])) - points = points.reshape(shape[1], shape[0], 2) + if nbpoints != numpy.prod(shape): + points = numpy.empty((numpy.prod(shape), 2), dtype=numpy.float64) + points[:nbpoints, 0] = yFiltered + points[:nbpoints, 1] = xFiltered + # Index of last element of last fully filled column + index = (nbpoints // shape[0]) * shape[0] + points[nbpoints:, 0] = yFiltered[index - (numpy.prod(shape) - nbpoints):index] + points[nbpoints:, 1] = xFiltered[-1] + else: + points = numpy.transpose((yFiltered, xFiltered)) + points.shape = shape[1], shape[0], 2 coords, indices = _quadrilateral_grid_as_triangles(points) - if gridInfo.order == 'row': + # Remove unused extra triangles + coords = coords[:4*nbpoints] + indices = indices[:2*nbpoints] + + if gridOrder == 'row': x, y = coords[:, 0], coords[:, 1] else: # column-major order y, x = coords[:, 0], coords[:, 1] + rgbacolors = rgbacolors[mask] # Filter-out not finite points gridcolors = numpy.empty( (4 * nbpoints, rgbacolors.shape[-1]), dtype=rgbacolors.dtype) for first in range(4): @@ -508,8 +683,8 @@ class Scatter(PointsBase, ColormapMixIn, ScatterVisualizationMixIn): y, indices, color=gridcolors, - z=self.getZValue(), alpha=self.getAlpha()) + else: _logger.error("Unhandled visualization %s", visualization) return None @@ -528,23 +703,15 @@ class Scatter(PointsBase, ColormapMixIn, ScatterVisualizationMixIn): elif visualization is self.Visualization.REGULAR_GRID: # Specific handling of picking for the regular grid mode - plot = self.getPlot() - if plot is None: - return None - - dataPos = plot.pixelToData(x, y) - if dataPos is None: + picked = result.getIndices(copy=False) + if picked is None: return None + row, column = picked[0][0], picked[1][0] gridInfo = self.__getRegularGridInfo() if gridInfo is None: return None - origin = gridInfo.origin - scale = gridInfo.scale - column = int((dataPos[0] - origin[0]) / scale[0]) - row = int((dataPos[1] - origin[1]) / scale[1]) - if gridInfo.order == 'row': index = row * gridInfo.shape[1] + column else: @@ -554,6 +721,23 @@ class Scatter(PointsBase, ColormapMixIn, ScatterVisualizationMixIn): result = PickingResult(self, (index,)) + elif visualization is self.Visualization.BINNED_STATISTIC: + picked = result.getIndices(copy=False) + if picked is None or len(picked) == 0 or len(picked[0]) == 0: + return None + row, col = picked[0][0], picked[1][0] + histoInfo = self.__getHistogramInfo() + if histoInfo is None: + return None + sx, sy = histoInfo.scale + ox, oy = histoInfo.origin + xdata = self.getXData(copy=False) + ydata = self.getYData(copy=False) + indices = numpy.nonzero(numpy.logical_and( + numpy.logical_and(xdata >= ox + sx * col, xdata < ox + sx * (col + 1)), + numpy.logical_and(ydata >= oy + sy * row, ydata < oy + sy * (row + 1))))[0] + result = None if len(indices) == 0 else PickingResult(self, indices) + return result def __getExecutor(self): @@ -750,8 +934,10 @@ class Scatter(PointsBase, ColormapMixIn, ScatterVisualizationMixIn): # Data changed, this needs update self.__cacheRegularGridInfo = None + self.__cacheHistogramInfo = None self._value = value + self._updateColormappedData() if alpha is not None: # Make sure alpha is an array of float in [0, 1] |