summaryrefslogtreecommitdiff
path: root/silx/gui/plot/items/scatter.py
diff options
context:
space:
mode:
Diffstat (limited to 'silx/gui/plot/items/scatter.py')
-rw-r--r--silx/gui/plot/items/scatter.py286
1 files changed, 236 insertions, 50 deletions
diff --git a/silx/gui/plot/items/scatter.py b/silx/gui/plot/items/scatter.py
index 50cc694..5e7d65b 100644
--- a/silx/gui/plot/items/scatter.py
+++ b/silx/gui/plot/items/scatter.py
@@ -1,7 +1,7 @@
# coding: utf-8
# /*##########################################################################
#
-# Copyright (c) 2017-2019 European Synchrotron Radiation Facility
+# Copyright (c) 2017-2020 European Synchrotron Radiation Facility
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
@@ -43,6 +43,7 @@ from concurrent.futures import ThreadPoolExecutor, CancelledError
from ....utils.proxy import docstring
from ....math.combo import min_max
+from ....math.histogram import Histogramnd
from ....utils.weakref import WeakList
from .._utils.delaunay import delaunay
from .core import PointsBase, ColormapMixIn, ScatterVisualizationMixIn
@@ -142,12 +143,13 @@ def is_monotonic(array):
:rtype: int
"""
diff = numpy.diff(numpy.ravel(array))
- if numpy.all(diff >= 0):
- return 1
- elif numpy.all(diff <= 0):
- return -1
- else:
- return 0
+ with numpy.errstate(invalid='ignore'):
+ if numpy.all(diff >= 0):
+ return 1
+ elif numpy.all(diff <= 0):
+ return -1
+ else:
+ return 0
def _guess_grid(x, y):
@@ -264,6 +266,10 @@ _RegularGridInfo = namedtuple(
'_RegularGridInfo', ['bounds', 'origin', 'scale', 'shape', 'order'])
+_HistogramInfo = namedtuple(
+ '_HistogramInfo', ['mean', 'count', 'sum', 'origin', 'scale', 'shape'])
+
+
class Scatter(PointsBase, ColormapMixIn, ScatterVisualizationMixIn):
"""Description of a scatter"""
@@ -275,6 +281,7 @@ class Scatter(PointsBase, ColormapMixIn, ScatterVisualizationMixIn):
ScatterVisualizationMixIn.Visualization.SOLID,
ScatterVisualizationMixIn.Visualization.REGULAR_GRID,
ScatterVisualizationMixIn.Visualization.IRREGULAR_GRID,
+ ScatterVisualizationMixIn.Visualization.BINNED_STATISTIC,
)
"""Overrides supported Visualizations"""
@@ -293,17 +300,53 @@ class Scatter(PointsBase, ColormapMixIn, ScatterVisualizationMixIn):
# Cache triangles: x, y, indices
self.__cacheTriangles = None, None, None
- # Cache regular grid info
+ # Cache regular grid and histogram info
self.__cacheRegularGridInfo = None
+ self.__cacheHistogramInfo = None
+
+ def _updateColormappedData(self):
+ """Update the colormapped data, to be called when changed"""
+ if self.getVisualization() is self.Visualization.BINNED_STATISTIC:
+ histoInfo = self.__getHistogramInfo()
+ if histoInfo is None:
+ data = None
+ else:
+ data = getattr(
+ histoInfo,
+ self.getVisualizationParameter(
+ self.VisualizationParameter.BINNED_STATISTIC_FUNCTION))
+ else:
+ data = self.getValueData(copy=False)
+ self._setColormappedData(data, copy=False)
+
+ @docstring(ScatterVisualizationMixIn)
+ def setVisualization(self, mode):
+ previous = self.getVisualization()
+ if super().setVisualization(mode):
+ if (bool(mode is self.Visualization.BINNED_STATISTIC) ^
+ bool(previous is self.Visualization.BINNED_STATISTIC)):
+ self._updateColormappedData()
+ return True
+ else:
+ return False
@docstring(ScatterVisualizationMixIn)
def setVisualizationParameter(self, parameter, value):
- changed = super(Scatter, self).setVisualizationParameter(parameter, value)
- if changed and parameter in (self.VisualizationParameter.GRID_BOUNDS,
- self.VisualizationParameter.GRID_MAJOR_ORDER,
- self.VisualizationParameter.GRID_SHAPE):
- self.__cacheRegularGridInfo = None
- return changed
+ if super(Scatter, self).setVisualizationParameter(parameter, value):
+ if parameter in (self.VisualizationParameter.GRID_BOUNDS,
+ self.VisualizationParameter.GRID_MAJOR_ORDER,
+ self.VisualizationParameter.GRID_SHAPE):
+ self.__cacheRegularGridInfo = None
+
+ if parameter in (self.VisualizationParameter.BINNED_STATISTIC_SHAPE,
+ self.VisualizationParameter.BINNED_STATISTIC_FUNCTION):
+ if parameter == self.VisualizationParameter.BINNED_STATISTIC_SHAPE:
+ self.__cacheHistogramInfo = None # Clean-up cache
+ if self.getVisualization() is self.Visualization.BINNED_STATISTIC:
+ self._updateColormappedData()
+ return True
+ else:
+ return False
@docstring(ScatterVisualizationMixIn)
def getCurrentVisualizationParameter(self, parameter):
@@ -323,6 +366,10 @@ class Scatter(PointsBase, ColormapMixIn, ScatterVisualizationMixIn):
grid = self.__getRegularGridInfo()
return None if grid is None else grid.shape
+ elif parameter is self.VisualizationParameter.BINNED_STATISTIC_SHAPE:
+ info = self.__getHistogramInfo()
+ return None if info is None else info.shape
+
else:
raise NotImplementedError()
@@ -345,6 +392,18 @@ class Scatter(PointsBase, ColormapMixIn, ScatterVisualizationMixIn):
if order is None:
order = guess[0]
+ nbpoints = len(self.getXData(copy=False))
+ if nbpoints > shape[0] * shape[1]:
+ # More data points that provided grid shape: enlarge grid
+ _logger.warning(
+ "More data points than provided grid shape size: extends grid")
+ dim0, dim1 = shape
+ if order == 'row': # keep dim1, enlarge dim0
+ dim0 = nbpoints // dim1 + (1 if nbpoints % dim1 else 0)
+ else: # keep dim0, enlarge dim1
+ dim1 = nbpoints // dim0 + (1 if nbpoints % dim0 else 0)
+ shape = dim0, dim1
+
bounds = self.getVisualizationParameter(
self.VisualizationParameter.GRID_BOUNDS)
if bounds is None:
@@ -372,6 +431,47 @@ class Scatter(PointsBase, ColormapMixIn, ScatterVisualizationMixIn):
return self.__cacheRegularGridInfo
+ def __getHistogramInfo(self):
+ """Get histogram info"""
+ if self.__cacheHistogramInfo is None:
+ shape = self.getVisualizationParameter(
+ self.VisualizationParameter.BINNED_STATISTIC_SHAPE)
+ if shape is None:
+ shape = 100, 100 # TODO compute auto shape
+
+ x, y, values = self.getData(copy=False)[:3]
+ if len(x) == 0: # No histogram
+ return None
+
+ if not numpy.issubdtype(x.dtype, numpy.floating):
+ x = x.astype(numpy.float64)
+ if not numpy.issubdtype(y.dtype, numpy.floating):
+ y = y.astype(numpy.float64)
+ if not numpy.issubdtype(values.dtype, numpy.floating):
+ values = values.astype(numpy.float64)
+
+ ranges = (tuple(min_max(y, finite=True)),
+ tuple(min_max(x, finite=True)))
+ points = numpy.transpose(numpy.array((y, x)))
+ counts, sums, bin_edges = Histogramnd(
+ points,
+ histo_range=ranges,
+ n_bins=shape,
+ weights=values)
+ yEdges, xEdges = bin_edges
+ origin = xEdges[0], yEdges[0]
+ scale = ((xEdges[-1] - xEdges[0]) / (len(xEdges) - 1),
+ (yEdges[-1] - yEdges[0]) / (len(yEdges) - 1))
+
+ with numpy.errstate(divide='ignore', invalid='ignore'):
+ histo = sums / counts
+
+ self.__cacheHistogramInfo = _HistogramInfo(
+ mean=histo, count=counts, sum=sums,
+ origin=origin, scale=scale, shape=shape)
+
+ return self.__cacheHistogramInfo
+
def _addBackendRenderer(self, backend):
"""Update backend renderer"""
# Filter-out values <= 0
@@ -386,28 +486,47 @@ class Scatter(PointsBase, ColormapMixIn, ScatterVisualizationMixIn):
if len(xFiltered) == 0:
return None # No data to display, do not add renderer to backend
+ visualization = self.getVisualization()
+
+ if visualization is self.Visualization.BINNED_STATISTIC:
+ plot = self.getPlot()
+ if (plot is None or
+ plot.getXAxis().getScale() != Axis.LINEAR or
+ plot.getYAxis().getScale() != Axis.LINEAR):
+ # Those visualizations are not available with log scaled axes
+ return None
+
+ histoInfo = self.__getHistogramInfo()
+ if histoInfo is None:
+ return None
+ data = getattr(histoInfo, self.getVisualizationParameter(
+ self.VisualizationParameter.BINNED_STATISTIC_FUNCTION))
+
+ return backend.addImage(
+ data=data,
+ origin=histoInfo.origin,
+ scale=histoInfo.scale,
+ colormap=self.getColormap(),
+ alpha=self.getAlpha())
+
# Compute colors
cmap = self.getColormap()
- rgbacolors = cmap.applyToData(self._value)
+ rgbacolors = cmap.applyToData(self)
if self.__alpha is not None:
rgbacolors[:, -1] = (rgbacolors[:, -1] * self.__alpha).astype(numpy.uint8)
- # Apply mask to colors
- rgbacolors = rgbacolors[mask]
-
visualization = self.getVisualization()
if visualization is self.Visualization.POINTS:
return backend.addCurve(xFiltered, yFiltered,
- color=rgbacolors,
+ color=rgbacolors[mask],
symbol=self.getSymbol(),
linewidth=0,
linestyle="",
yaxis='left',
xerror=xerror,
yerror=yerror,
- z=self.getZValue(),
fill=False,
alpha=self.getAlpha(),
symbolsize=self.getSymbolSize(),
@@ -432,8 +551,7 @@ class Scatter(PointsBase, ColormapMixIn, ScatterVisualizationMixIn):
return backend.addTriangles(xFiltered,
yFiltered,
triangles,
- color=rgbacolors,
- z=self.getZValue(),
+ color=rgbacolors[mask],
alpha=self.getAlpha())
elif visualization is self.Visualization.REGULAR_GRID:
@@ -461,7 +579,6 @@ class Scatter(PointsBase, ColormapMixIn, ScatterVisualizationMixIn):
data=image,
origin=gridInfo.origin,
scale=gridInfo.scale,
- z=self.getZValue(),
colormap=None,
alpha=self.getAlpha())
@@ -474,31 +591,89 @@ class Scatter(PointsBase, ColormapMixIn, ScatterVisualizationMixIn):
if shape is None: # No shape, no display
return None
- # clip shape to fully filled lines
- if len(xFiltered) != numpy.prod(shape):
- if gridInfo.order == 'row':
- shape = len(xFiltered) // shape[1], shape[1]
+ nbpoints = len(xFiltered)
+ if nbpoints == 1:
+ # single point, render as a square points
+ return backend.addCurve(xFiltered, yFiltered,
+ color=rgbacolors[mask],
+ symbol='s',
+ linewidth=0,
+ linestyle="",
+ yaxis='left',
+ xerror=None,
+ yerror=None,
+ fill=False,
+ alpha=self.getAlpha(),
+ symbolsize=7,
+ baseline=None)
+
+ # Make shape include all points
+ gridOrder = gridInfo.order
+ if nbpoints != numpy.prod(shape):
+ if gridOrder == 'row':
+ shape = int(numpy.ceil(nbpoints / shape[1])), shape[1]
else: # column-major order
- shape = shape[0], len(xFiltered) // shape[0]
- if shape[0] < 2 or shape[1] < 2: # Not enough points
- return None
-
- nbpoints = numpy.prod(shape)
- if gridInfo.order == 'row':
- points = numpy.transpose((xFiltered[:nbpoints], yFiltered[:nbpoints]))
- points = points.reshape(shape[0], shape[1], 2)
+ shape = shape[0], int(numpy.ceil(nbpoints / shape[0]))
+
+ if shape[0] < 2 or shape[1] < 2: # Single line, at least 2 points
+ points = numpy.ones((2, nbpoints, 2), dtype=numpy.float64)
+ # Use row/column major depending on shape, not on info value
+ gridOrder = 'row' if shape[0] == 1 else 'column'
+
+ if gridOrder == 'row':
+ points[0, :, 0] = xFiltered
+ points[0, :, 1] = yFiltered
+ else: # column-major order
+ points[0, :, 0] = yFiltered
+ points[0, :, 1] = xFiltered
+
+ # Add a second line that will be clipped in the end
+ points[1, :-1] = points[0, :-1] + numpy.cross(
+ points[0, 1:] - points[0, :-1], (0., 0., 1.))[:, :2]
+ points[1, -1] = points[0, -1] + numpy.cross(
+ points[0, -1] - points[0, -2], (0., 0., 1.))[:2]
+
+ points.shape = 2, nbpoints, 2 # Use same shape for both orders
+ coords, indices = _quadrilateral_grid_as_triangles(points)
+
+ elif gridOrder == 'row': # row-major order
+ if nbpoints != numpy.prod(shape):
+ points = numpy.empty((numpy.prod(shape), 2), dtype=numpy.float64)
+ points[:nbpoints, 0] = xFiltered
+ points[:nbpoints, 1] = yFiltered
+ # Index of last element of last fully filled row
+ index = (nbpoints // shape[1]) * shape[1]
+ points[nbpoints:, 0] = xFiltered[index - (numpy.prod(shape) - nbpoints):index]
+ points[nbpoints:, 1] = yFiltered[-1]
+ else:
+ points = numpy.transpose((xFiltered, yFiltered))
+ points.shape = shape[0], shape[1], 2
else: # column-major order
- points = numpy.transpose((yFiltered[:nbpoints], xFiltered[:nbpoints]))
- points = points.reshape(shape[1], shape[0], 2)
+ if nbpoints != numpy.prod(shape):
+ points = numpy.empty((numpy.prod(shape), 2), dtype=numpy.float64)
+ points[:nbpoints, 0] = yFiltered
+ points[:nbpoints, 1] = xFiltered
+ # Index of last element of last fully filled column
+ index = (nbpoints // shape[0]) * shape[0]
+ points[nbpoints:, 0] = yFiltered[index - (numpy.prod(shape) - nbpoints):index]
+ points[nbpoints:, 1] = xFiltered[-1]
+ else:
+ points = numpy.transpose((yFiltered, xFiltered))
+ points.shape = shape[1], shape[0], 2
coords, indices = _quadrilateral_grid_as_triangles(points)
- if gridInfo.order == 'row':
+ # Remove unused extra triangles
+ coords = coords[:4*nbpoints]
+ indices = indices[:2*nbpoints]
+
+ if gridOrder == 'row':
x, y = coords[:, 0], coords[:, 1]
else: # column-major order
y, x = coords[:, 0], coords[:, 1]
+ rgbacolors = rgbacolors[mask] # Filter-out not finite points
gridcolors = numpy.empty(
(4 * nbpoints, rgbacolors.shape[-1]), dtype=rgbacolors.dtype)
for first in range(4):
@@ -508,8 +683,8 @@ class Scatter(PointsBase, ColormapMixIn, ScatterVisualizationMixIn):
y,
indices,
color=gridcolors,
- z=self.getZValue(),
alpha=self.getAlpha())
+
else:
_logger.error("Unhandled visualization %s", visualization)
return None
@@ -528,23 +703,15 @@ class Scatter(PointsBase, ColormapMixIn, ScatterVisualizationMixIn):
elif visualization is self.Visualization.REGULAR_GRID:
# Specific handling of picking for the regular grid mode
- plot = self.getPlot()
- if plot is None:
- return None
-
- dataPos = plot.pixelToData(x, y)
- if dataPos is None:
+ picked = result.getIndices(copy=False)
+ if picked is None:
return None
+ row, column = picked[0][0], picked[1][0]
gridInfo = self.__getRegularGridInfo()
if gridInfo is None:
return None
- origin = gridInfo.origin
- scale = gridInfo.scale
- column = int((dataPos[0] - origin[0]) / scale[0])
- row = int((dataPos[1] - origin[1]) / scale[1])
-
if gridInfo.order == 'row':
index = row * gridInfo.shape[1] + column
else:
@@ -554,6 +721,23 @@ class Scatter(PointsBase, ColormapMixIn, ScatterVisualizationMixIn):
result = PickingResult(self, (index,))
+ elif visualization is self.Visualization.BINNED_STATISTIC:
+ picked = result.getIndices(copy=False)
+ if picked is None or len(picked) == 0 or len(picked[0]) == 0:
+ return None
+ row, col = picked[0][0], picked[1][0]
+ histoInfo = self.__getHistogramInfo()
+ if histoInfo is None:
+ return None
+ sx, sy = histoInfo.scale
+ ox, oy = histoInfo.origin
+ xdata = self.getXData(copy=False)
+ ydata = self.getYData(copy=False)
+ indices = numpy.nonzero(numpy.logical_and(
+ numpy.logical_and(xdata >= ox + sx * col, xdata < ox + sx * (col + 1)),
+ numpy.logical_and(ydata >= oy + sy * row, ydata < oy + sy * (row + 1))))[0]
+ result = None if len(indices) == 0 else PickingResult(self, indices)
+
return result
def __getExecutor(self):
@@ -750,8 +934,10 @@ class Scatter(PointsBase, ColormapMixIn, ScatterVisualizationMixIn):
# Data changed, this needs update
self.__cacheRegularGridInfo = None
+ self.__cacheHistogramInfo = None
self._value = value
+ self._updateColormappedData()
if alpha is not None:
# Make sure alpha is an array of float in [0, 1]