summaryrefslogtreecommitdiff
path: root/silx/gui/plot/items/scatter.py
diff options
context:
space:
mode:
Diffstat (limited to 'silx/gui/plot/items/scatter.py')
-rw-r--r--silx/gui/plot/items/scatter.py231
1 files changed, 209 insertions, 22 deletions
diff --git a/silx/gui/plot/items/scatter.py b/silx/gui/plot/items/scatter.py
index 707dd3d..b2f087b 100644
--- a/silx/gui/plot/items/scatter.py
+++ b/silx/gui/plot/items/scatter.py
@@ -31,26 +31,79 @@ __date__ = "29/03/2017"
import logging
-
+import threading
import numpy
-from .core import Points, ColormapMixIn
+from collections import defaultdict
+from concurrent.futures import ThreadPoolExecutor, CancelledError
+
+from ....utils.weakref import WeakList
+from .._utils.delaunay import delaunay
+from .core import PointsBase, ColormapMixIn, ScatterVisualizationMixIn
+from .axis import Axis
_logger = logging.getLogger(__name__)
-class Scatter(Points, ColormapMixIn):
+class _GreedyThreadPoolExecutor(ThreadPoolExecutor):
+ """:class:`ThreadPoolExecutor` with an extra :meth:`submit_greedy` method.
+ """
+
+ def __init__(self, *args, **kwargs):
+ super(_GreedyThreadPoolExecutor, self).__init__(*args, **kwargs)
+ self.__futures = defaultdict(WeakList)
+ self.__lock = threading.RLock()
+
+ def submit_greedy(self, queue, fn, *args, **kwargs):
+ """Same as :meth:`submit` but cancel previous tasks in given queue.
+
+ This means that when a new task is submitted for a given queue,
+ all other pending tasks of that queue are cancelled.
+
+ :param queue: Identifier of the queue. This must be hashable.
+ :param callable fn: The callable to call with provided extra arguments
+ :return: Future corresponding to this task
+ :rtype: concurrent.futures.Future
+ """
+ with self.__lock:
+ # Cancel previous tasks in given queue
+ for future in self.__futures.pop(queue, []):
+ if not future.done():
+ future.cancel()
+
+ future = super(_GreedyThreadPoolExecutor, self).submit(
+ fn, *args, **kwargs)
+ self.__futures[queue].append(future)
+
+ return future
+
+
+class Scatter(PointsBase, ColormapMixIn, ScatterVisualizationMixIn):
"""Description of a scatter"""
_DEFAULT_SELECTABLE = True
"""Default selectable state for scatter plots"""
+ _SUPPORTED_SCATTER_VISUALIZATION = (
+ ScatterVisualizationMixIn.Visualization.POINTS,
+ ScatterVisualizationMixIn.Visualization.SOLID)
+ """Overrides supported Visualizations"""
+
def __init__(self):
- Points.__init__(self)
+ PointsBase.__init__(self)
ColormapMixIn.__init__(self)
+ ScatterVisualizationMixIn.__init__(self)
self._value = ()
self.__alpha = None
+ # Cache Delaunay triangulation future object
+ self.__delaunayFuture = None
+ # Cache interpolator future object
+ self.__interpolatorFuture = None
+ self.__executor = None
+
+ # Cache triangles: x, y, indices
+ self.__cacheTriangles = None, None, None
def _addBackendRenderer(self, backend):
"""Update backend renderer"""
@@ -58,28 +111,154 @@ class Scatter(Points, ColormapMixIn):
xFiltered, yFiltered, valueFiltered, xerror, yerror = self.getData(
copy=False, displayed=True)
+ # Remove not finite numbers (this includes filtered out x, y <= 0)
+ mask = numpy.logical_and(numpy.isfinite(xFiltered), numpy.isfinite(yFiltered))
+ xFiltered = xFiltered[mask]
+ yFiltered = yFiltered[mask]
+
if len(xFiltered) == 0:
return None # No data to display, do not add renderer to backend
+ # Compute colors
cmap = self.getColormap()
rgbacolors = cmap.applyToData(self._value)
if self.__alpha is not None:
rgbacolors[:, -1] = (rgbacolors[:, -1] * self.__alpha).astype(numpy.uint8)
- return backend.addCurve(xFiltered, yFiltered, self.getLegend(),
- color=rgbacolors,
- symbol=self.getSymbol(),
- linewidth=0,
- linestyle="",
- yaxis='left',
- xerror=xerror,
- yerror=yerror,
- z=self.getZValue(),
- selectable=self.isSelectable(),
- fill=False,
- alpha=self.getAlpha(),
- symbolsize=self.getSymbolSize())
+ # Apply mask to colors
+ rgbacolors = rgbacolors[mask]
+
+ if self.getVisualization() is self.Visualization.POINTS:
+ return backend.addCurve(xFiltered, yFiltered, self.getLegend(),
+ color=rgbacolors,
+ symbol=self.getSymbol(),
+ linewidth=0,
+ linestyle="",
+ yaxis='left',
+ xerror=xerror,
+ yerror=yerror,
+ z=self.getZValue(),
+ selectable=self.isSelectable(),
+ fill=False,
+ alpha=self.getAlpha(),
+ symbolsize=self.getSymbolSize())
+
+ else: # 'solid'
+ plot = self.getPlot()
+ if (plot is None or
+ plot.getXAxis().getScale() != Axis.LINEAR or
+ plot.getYAxis().getScale() != Axis.LINEAR):
+ # Solid visualization is not available with log scaled axes
+ return None
+
+ triangulation = self._getDelaunay().result()
+ if triangulation is None:
+ return None
+ else:
+ triangles = triangulation.simplices.astype(numpy.int32)
+ return backend.addTriangles(xFiltered,
+ yFiltered,
+ triangles,
+ legend=self.getLegend(),
+ color=rgbacolors,
+ z=self.getZValue(),
+ selectable=self.isSelectable(),
+ alpha=self.getAlpha())
+
+ def __getExecutor(self):
+ """Returns async greedy executor
+
+ :rtype: _GreedyThreadPoolExecutor
+ """
+ if self.__executor is None:
+ self.__executor = _GreedyThreadPoolExecutor(max_workers=2)
+ return self.__executor
+
+ def _getDelaunay(self):
+ """Returns a :class:`Future` which result is the Delaunay object.
+
+ :rtype: concurrent.futures.Future
+ """
+ if self.__delaunayFuture is None or self.__delaunayFuture.cancelled():
+ # Need to init a new delaunay
+ x, y = self.getData(copy=False)[:2]
+ # Remove not finite points
+ mask = numpy.logical_and(numpy.isfinite(x), numpy.isfinite(y))
+
+ self.__delaunayFuture = self.__getExecutor().submit_greedy(
+ 'delaunay', delaunay, x[mask], y[mask])
+
+ return self.__delaunayFuture
+
+ @staticmethod
+ def __initInterpolator(delaunayFuture, values):
+ """Returns an interpolator for the given data points
+
+ :param concurrent.futures.Future delaunayFuture:
+ Future object which result is a Delaunay object
+ :param numpy.ndarray values: The data value of valid points.
+ :rtype: Union[callable,None]
+ """
+ # Wait for Delaunay to complete
+ try:
+ triangulation = delaunayFuture.result()
+ except CancelledError:
+ triangulation = None
+
+ if triangulation is None:
+ interpolator = None # Error case
+ else:
+ # Lazy-loading of interpolator
+ try:
+ from scipy.interpolate import LinearNDInterpolator
+ except ImportError:
+ LinearNDInterpolator = None
+
+ if LinearNDInterpolator is not None:
+ interpolator = LinearNDInterpolator(triangulation, values)
+
+ # First call takes a while, do it here
+ interpolator([(0., 0.)])
+
+ else:
+ # Fallback using matplotlib interpolator
+ import matplotlib.tri
+
+ x, y = triangulation.points.T
+ tri = matplotlib.tri.Triangulation(
+ x, y, triangles=triangulation.simplices)
+ mplInterpolator = matplotlib.tri.LinearTriInterpolator(
+ tri, values)
+
+ # Wrap interpolator to have same API as scipy's one
+ def interpolator(points):
+ return mplInterpolator(*points.T)
+
+ return interpolator
+
+ def _getInterpolator(self):
+ """Returns a :class:`Future` which result is the interpolator.
+
+ The interpolator is a callable taking an array Nx2 of points
+ as a single argument.
+ The :class:`Future` result is None in case the interpolator cannot
+ be initialized.
+
+ :rtype: concurrent.futures.Future
+ """
+ if (self.__interpolatorFuture is None or
+ self.__interpolatorFuture.cancelled()):
+ # Need to init a new interpolator
+ x, y, values = self.getData(copy=False)[:3]
+ # Remove not finite points
+ mask = numpy.logical_and(numpy.isfinite(x), numpy.isfinite(y))
+ x, y, values = x[mask], y[mask], values[mask]
+
+ self.__interpolatorFuture = self.__getExecutor().submit_greedy(
+ 'interpolator',
+ self.__initInterpolator, self._getDelaunay(), values)
+ return self.__interpolatorFuture
def _logFilterData(self, xPositive, yPositive):
"""Filter out values with x or y <= 0 on log axes
@@ -89,7 +268,7 @@ class Scatter(Points, ColormapMixIn):
:return: The filtered arrays or unchanged object if not filtering needed
:rtype: (x, y, value, xerror, yerror)
"""
- # overloaded from Points to filter also value.
+ # overloaded from PointsBase to filter also value.
value = self.getValueData(copy=False)
if xPositive or yPositive:
@@ -100,7 +279,7 @@ class Scatter(Points, ColormapMixIn):
value = numpy.array(value, copy=True, dtype=numpy.float)
value[clipped] = numpy.nan
- x, y, xerror, yerror = Points._logFilterData(self, xPositive, yPositive)
+ x, y, xerror, yerror = PointsBase._logFilterData(self, xPositive, yPositive)
return x, y, value, xerror, yerror
@@ -146,7 +325,7 @@ class Scatter(Points, ColormapMixIn):
self.getXErrorData(copy),
self.getYErrorData(copy))
- # reimplemented from Points to handle `value`
+ # reimplemented from PointsBase to handle `value`
def setData(self, x, y, value, xerror=None, yerror=None, alpha=None, copy=True):
"""Set the data of the scatter.
@@ -171,6 +350,14 @@ class Scatter(Points, ColormapMixIn):
assert value.ndim == 1
assert len(x) == len(value)
+ # Reset triangulation and interpolator
+ if self.__delaunayFuture is not None:
+ self.__delaunayFuture.cancel()
+ self.__delaunayFuture = None
+ if self.__interpolatorFuture is not None:
+ self.__interpolatorFuture.cancel()
+ self.__interpolatorFuture = None
+
self._value = value
if alpha is not None:
@@ -183,8 +370,8 @@ class Scatter(Points, ColormapMixIn):
if numpy.any(numpy.logical_or(alpha < 0., alpha > 1.)):
alpha = numpy.clip(alpha, 0., 1.)
self.__alpha = alpha
-
+
# set x, y, xerror, yerror
# call self._updated + plot._invalidateDataRange()
- Points.setData(self, x, y, xerror, yerror, copy)
+ PointsBase.setData(self, x, y, xerror, yerror, copy)