From 654a6ac93513c3cc1ef97cacd782ff674c6f4559 Mon Sep 17 00:00:00 2001 From: Alexandre Marie Date: Tue, 9 Jul 2019 10:20:20 +0200 Subject: New upstream version 0.11.0+dfsg --- silx/gui/_glutils/Context.py | 42 +- silx/gui/_glutils/Program.py | 12 +- silx/gui/_glutils/Texture.py | 10 +- silx/gui/_glutils/__init__.py | 5 +- silx/gui/_glutils/utils.py | 73 ++- silx/gui/colors.py | 8 +- silx/gui/console.py | 32 +- silx/gui/data/ArrayTableModel.py | 5 +- silx/gui/data/DataViews.py | 46 +- silx/gui/data/NXdataWidgets.py | 68 +-- silx/gui/data/_VolumeWindow.py | 148 ++++++ silx/gui/data/test/test_dataviewer.py | 6 +- silx/gui/data/test/test_textformatter.py | 4 +- silx/gui/dialog/AbstractDataFileDialog.py | 34 +- silx/gui/dialog/ColormapDialog.py | 3 +- silx/gui/dialog/ImageFileDialog.py | 25 +- silx/gui/dialog/test/test_colormapdialog.py | 6 +- silx/gui/dialog/test/test_datafiledialog.py | 15 +- silx/gui/dialog/test/test_imagefiledialog.py | 74 +-- silx/gui/hdf5/test/test_hdf5.py | 8 +- silx/gui/plot/CompareImages.py | 77 +++- silx/gui/plot/ComplexImageView.py | 85 ++-- silx/gui/plot/CurvesROIWidget.py | 46 +- silx/gui/plot/PlotToolButtons.py | 104 ++++- silx/gui/plot/PlotWidget.py | 122 +++-- silx/gui/plot/PlotWindow.py | 64 +-- silx/gui/plot/Profile.py | 29 +- silx/gui/plot/ProfileMainWindow.py | 14 +- silx/gui/plot/ScatterView.py | 74 +-- silx/gui/plot/StatsWidget.py | 436 +++++++++++++++--- silx/gui/plot/_BaseMaskToolsWidget.py | 48 +- silx/gui/plot/_utils/delaunay.py | 62 +++ silx/gui/plot/actions/control.py | 10 +- silx/gui/plot/backends/BackendBase.py | 17 + silx/gui/plot/backends/BackendMatplotlib.py | 52 ++- silx/gui/plot/backends/BackendOpenGL.py | 236 ++++++---- silx/gui/plot/backends/glutils/GLPlotTriangles.py | 193 ++++++++ silx/gui/plot/backends/glutils/GLText.py | 23 +- silx/gui/plot/backends/glutils/GLTexture.py | 3 +- silx/gui/plot/backends/glutils/__init__.py | 3 +- silx/gui/plot/items/__init__.py | 7 +- silx/gui/plot/items/complex.py | 121 +++-- silx/gui/plot/items/core.py | 197 +++++++- silx/gui/plot/items/curve.py | 6 +- silx/gui/plot/items/image.py | 72 ++- silx/gui/plot/items/marker.py | 17 +- silx/gui/plot/items/roi.py | 40 +- silx/gui/plot/items/scatter.py | 231 +++++++++- silx/gui/plot/matplotlib/__init__.py | 76 +-- silx/gui/plot/test/testAlphaSlider.py | 5 +- silx/gui/plot/test/testComplexImageView.py | 6 +- silx/gui/plot/test/testCurvesROIWidget.py | 140 +++++- silx/gui/plot/test/testItem.py | 13 +- silx/gui/plot/test/testPlotWidget.py | 55 ++- silx/gui/plot/test/testPlotWindow.py | 29 +- silx/gui/plot/test/testProfile.py | 6 +- silx/gui/plot/test/testStackView.py | 6 +- silx/gui/plot/test/testStats.py | 259 +++++++++-- .../plot/tools/profile/ScatterProfileToolBar.py | 362 +++------------ silx/gui/plot/tools/roi.py | 23 +- .../plot/tools/test/testScatterProfileToolBar.py | 1 + silx/gui/plot/tools/test/testTools.py | 30 +- silx/gui/plot/tools/toolbars.py | 28 +- silx/gui/plot3d/Plot3DWidget.py | 52 ++- silx/gui/plot3d/Plot3DWindow.py | 21 +- silx/gui/plot3d/SceneWidget.py | 27 +- silx/gui/plot3d/SceneWindow.py | 22 +- silx/gui/plot3d/_model/items.py | 512 +++++++++++++++------ silx/gui/plot3d/actions/mode.py | 61 ++- silx/gui/plot3d/items/__init__.py | 4 +- silx/gui/plot3d/items/mesh.py | 5 +- silx/gui/plot3d/items/mixins.py | 18 +- silx/gui/plot3d/items/scatter.py | 108 ++--- silx/gui/plot3d/items/volume.py | 308 +++++++++++-- silx/gui/plot3d/scene/camera.py | 2 + silx/gui/plot3d/scene/core.py | 11 +- silx/gui/plot3d/scene/cutplane.py | 18 +- silx/gui/plot3d/scene/function.py | 87 +++- silx/gui/plot3d/scene/interaction.py | 60 ++- silx/gui/plot3d/scene/primitives.py | 130 +++--- silx/gui/plot3d/scene/utils.py | 73 +-- silx/gui/plot3d/scene/viewport.py | 75 ++- silx/gui/plot3d/test/__init__.py | 4 + silx/gui/plot3d/test/testSceneWidget.py | 84 ++++ silx/gui/plot3d/test/testSceneWidgetPicking.py | 96 ++-- silx/gui/plot3d/test/testSceneWindow.py | 209 +++++++++ silx/gui/plot3d/tools/PositionInfoWidget.py | 42 +- silx/gui/qt/_qt.py | 30 +- silx/gui/qt/inspect.py | 13 +- silx/gui/test/test_colors.py | 6 +- silx/gui/utils/testutils.py | 23 +- 91 files changed, 4416 insertions(+), 1737 deletions(-) create mode 100644 silx/gui/data/_VolumeWindow.py create mode 100644 silx/gui/plot/_utils/delaunay.py create mode 100644 silx/gui/plot/backends/glutils/GLPlotTriangles.py create mode 100644 silx/gui/plot3d/test/testSceneWidget.py create mode 100644 silx/gui/plot3d/test/testSceneWindow.py (limited to 'silx/gui') diff --git a/silx/gui/_glutils/Context.py b/silx/gui/_glutils/Context.py index 7600992..c62dbb9 100644 --- a/silx/gui/_glutils/Context.py +++ b/silx/gui/_glutils/Context.py @@ -1,7 +1,7 @@ # coding: utf-8 # /*########################################################################## # -# Copyright (c) 2014-2017 European Synchrotron Radiation Facility +# Copyright (c) 2014-2019 European Synchrotron Radiation Facility # # Permission is hereby granted, free of charge, to any person obtaining a copy # of this software and associated documentation files (the "Software"), to deal @@ -32,32 +32,44 @@ __authors__ = ["T. Vincent"] __license__ = "MIT" __date__ = "25/07/2016" +import contextlib -# context ##################################################################### +class _DEFAULT_CONTEXT(object): + """The default value for OpenGL context""" + pass -def _defaultGLContextGetter(): - return None +_context = _DEFAULT_CONTEXT +"""The current OpenGL context""" -_glContextGetter = _defaultGLContextGetter - -def getGLContext(): +def getCurrent(): """Returns platform dependent object of current OpenGL context. This is useful to associate OpenGL resources with the context they are created in. :return: Platform specific OpenGL context - :rtype: None by default or a platform dependent object""" - return _glContextGetter() + """ + return _context + + +def setCurrent(context=_DEFAULT_CONTEXT): + """Set a platform dependent OpenGL context + + :param context: Platform dependent GL context + """ + global _context + _context = context -def setGLContextGetter(getter=_defaultGLContextGetter): - """Set a platform dependent function to retrieve the current OpenGL context +@contextlib.contextmanager +def current(context): + """Context manager setting the platform-dependent GL context - :param getter: Platform dependent GL context getter - :type getter: Function with no args returning the current OpenGL context + :param context: Platform dependent GL context """ - global _glContextGetter - _glContextGetter = getter + previous_context = getCurrent() + setCurrent(context) + yield + setCurrent(previous_context) diff --git a/silx/gui/_glutils/Program.py b/silx/gui/_glutils/Program.py index 48c12f5..87eec5f 100644 --- a/silx/gui/_glutils/Program.py +++ b/silx/gui/_glutils/Program.py @@ -1,7 +1,7 @@ # coding: utf-8 # /*########################################################################## # -# Copyright (c) 2014-2017 European Synchrotron Radiation Facility +# Copyright (c) 2014-2019 European Synchrotron Radiation Facility # # Permission is hereby granted, free of charge, to any person obtaining a copy # of this software and associated documentation files (the "Software"), to deal @@ -30,11 +30,11 @@ __date__ = "25/07/2016" import logging +import weakref import numpy -from . import gl -from .Context import getGLContext +from . import Context, gl _logger = logging.getLogger(__name__) @@ -61,7 +61,7 @@ class Program(object): self._vertexShader = vertexShader self._fragmentShader = fragmentShader self._attrib0 = attrib0 - self._programs = {} + self._programs = weakref.WeakKeyDictionary() @staticmethod def _compileGL(vertexShader, fragmentShader, attrib0): @@ -106,7 +106,7 @@ class Program(object): return program, attributes, uniforms def _getProgramInfo(self): - glcontext = getGLContext() + glcontext = Context.getCurrent() if glcontext not in self._programs: raise RuntimeError( "Program was not compiled for current OpenGL context.") @@ -149,7 +149,7 @@ class Program(object): def use(self): """Make use of the program, compiling it if necessary""" - glcontext = getGLContext() + glcontext = Context.getCurrent() if glcontext not in self._programs: self._programs[glcontext] = self._compileGL( diff --git a/silx/gui/_glutils/Texture.py b/silx/gui/_glutils/Texture.py index 0875ebe..a7fd44b 100644 --- a/silx/gui/_glutils/Texture.py +++ b/silx/gui/_glutils/Texture.py @@ -1,7 +1,7 @@ # coding: utf-8 # /*########################################################################## # -# Copyright (c) 2014-2018 European Synchrotron Radiation Facility +# Copyright (c) 2014-2019 European Synchrotron Radiation Facility # # Permission is hereby granted, free of charge, to any person obtaining a copy # of this software and associated documentation files (the "Software"), to deal @@ -29,7 +29,11 @@ __license__ = "MIT" __date__ = "04/10/2016" -import collections +try: + from collections import abc +except ImportError: # Python2 support + import collections as abc + from ctypes import c_void_p import logging @@ -93,7 +97,7 @@ class Texture(object): self.magFilter = magFilter if magFilter is not None else gl.GL_LINEAR if wrap is not None: - if not isinstance(wrap, collections.Iterable): + if not isinstance(wrap, abc.Iterable): wrap = [wrap] * self.ndim assert len(wrap) == self.ndim diff --git a/silx/gui/_glutils/__init__.py b/silx/gui/_glutils/__init__.py index 15e48e1..e88affd 100644 --- a/silx/gui/_glutils/__init__.py +++ b/silx/gui/_glutils/__init__.py @@ -1,7 +1,7 @@ # coding: utf-8 # /*########################################################################## # -# Copyright (c) 2014-2017 European Synchrotron Radiation Facility +# Copyright (c) 2014-2019 European Synchrotron Radiation Facility # # Permission is hereby granted, free of charge, to any person obtaining a copy # of this software and associated documentation files (the "Software"), to deal @@ -34,9 +34,10 @@ __date__ = "25/07/2016" # OpenGL convenient functions from .OpenGLWidget import OpenGLWidget # noqa -from .Context import getGLContext, setGLContextGetter # noqa +from . import Context # noqa from .FramebufferTexture import FramebufferTexture # noqa from .Program import Program # noqa from .Texture import Texture # noqa from .VertexBuffer import VertexBuffer, VertexBufferAttrib, vertexBuffer # noqa from .utils import sizeofGLType, isSupportedGLType, numpyToGLType # noqa +from .utils import segmentTrianglesIntersection # noqa diff --git a/silx/gui/_glutils/utils.py b/silx/gui/_glutils/utils.py index 73af338..35cf819 100644 --- a/silx/gui/_glutils/utils.py +++ b/silx/gui/_glutils/utils.py @@ -1,7 +1,7 @@ # coding: utf-8 # /*########################################################################## # -# Copyright (c) 2014-2017 European Synchrotron Radiation Facility +# Copyright (c) 2014-2019 European Synchrotron Radiation Facility # # Permission is hereby granted, free of charge, to any person obtaining a copy # of this software and associated documentation files (the "Software"), to deal @@ -68,3 +68,74 @@ def isSupportedGLType(type_): def numpyToGLType(type_): """Returns the GL type corresponding the provided numpy type or dtype.""" return _TYPE_CONVERTER[numpy.dtype(type_)] + + +def segmentTrianglesIntersection(segment, triangles): + """Check for segment/triangles intersection. + + This is based on signed tetrahedron volume comparison. + + See A. Kensler, A., Shirley, P. + Optimizing Ray-Triangle Intersection via Automated Search. + Symposium on Interactive Ray Tracing, vol. 0, p33-38 (2006) + + :param numpy.ndarray segment: + Segment end points as a 2x3 array of coordinates + :param numpy.ndarray triangles: + Nx3x3 array of triangles + :return: (triangle indices, segment parameter, barycentric coord) + Indices of intersected triangles, "depth" along the segment + of the intersection point and barycentric coordinates of intersection + point in the triangle. + :rtype: List[numpy.ndarray] + """ + # TODO triangles from vertices + indices + # TODO early rejection? e.g., check segment bbox vs triangle bbox + segment = numpy.asarray(segment) + assert segment.ndim == 2 + assert segment.shape == (2, 3) + + triangles = numpy.asarray(triangles) + assert triangles.ndim == 3 + assert triangles.shape[1] == 3 + + # Test line/triangles intersection + d = segment[1] - segment[0] + t0s0 = segment[0] - triangles[:, 0, :] + edge01 = triangles[:, 1, :] - triangles[:, 0, :] + edge02 = triangles[:, 2, :] - triangles[:, 0, :] + + dCrossEdge02 = numpy.cross(d, edge02) + t0s0CrossEdge01 = numpy.cross(t0s0, edge01) + volume = numpy.sum(dCrossEdge02 * edge01, axis=1) + del edge01 + subVolumes = numpy.empty((len(triangles), 3), dtype=triangles.dtype) + subVolumes[:, 1] = numpy.sum(dCrossEdge02 * t0s0, axis=1) + del dCrossEdge02 + subVolumes[:, 2] = numpy.sum(t0s0CrossEdge01 * d, axis=1) + subVolumes[:, 0] = volume - subVolumes[:, 1] - subVolumes[:, 2] + intersect = numpy.logical_or( + numpy.all(subVolumes >= 0., axis=1), # All positive + numpy.all(subVolumes <= 0., axis=1)) # All negative + intersect = numpy.where(intersect)[0] # Indices of intersected triangles + + # Get barycentric coordinates + barycentric = subVolumes[intersect] / volume[intersect].reshape(-1, 1) + del subVolumes + + # Test segment/triangles intersection + volAlpha = numpy.sum(t0s0CrossEdge01[intersect] * edge02[intersect], axis=1) + t = volAlpha / volume[intersect] # segment parameter of intersected triangles + del t0s0CrossEdge01 + del edge02 + del volAlpha + del volume + + inSegmentMask = numpy.logical_and(t >= 0., t <= 1.) + intersect = intersect[inSegmentMask] + t = t[inSegmentMask] + barycentric = barycentric[inSegmentMask] + + # Sort intersecting triangles by t + indices = numpy.argsort(t) + return intersect[indices], t[indices], barycentric[indices] diff --git a/silx/gui/colors.py b/silx/gui/colors.py index f1f34c9..aa2958a 100644 --- a/silx/gui/colors.py +++ b/silx/gui/colors.py @@ -1,7 +1,7 @@ # coding: utf-8 # /*########################################################################## # -# Copyright (c) 2015-2018 European Synchrotron Radiation Facility +# Copyright (c) 2015-2019 European Synchrotron Radiation Facility # # Permission is hereby granted, free of charge, to any person obtaining a copy # of this software and associated documentation files (the "Software"), to deal @@ -887,7 +887,11 @@ def preferredColormaps(): global _PREFERRED_COLORMAPS if _PREFERRED_COLORMAPS is None: # Initialize preferred colormaps - default_preferred = [k for k in _AVAILABLE_LUTS.keys() if _AVAILABLE_LUTS[k].preferred] + default_preferred = [] + for name, info in _AVAILABLE_LUTS.items(): + if (info.preferred and + (info.source != 'matplotlib' or _matplotlib_cm is not None)): + default_preferred.append(name) setPreferredColormaps(default_preferred) return tuple(_PREFERRED_COLORMAPS) diff --git a/silx/gui/console.py b/silx/gui/console.py index b6341ef..5dc6336 100644 --- a/silx/gui/console.py +++ b/silx/gui/console.py @@ -1,7 +1,7 @@ # coding: utf-8 # /*########################################################################## # -# Copyright (c) 2004-2018 European Synchrotron Radiation Facility +# Copyright (c) 2004-2019 European Synchrotron Radiation Facility # # Permission is hereby granted, free of charge, to any person obtaining a copy # of this software and associated documentation files (the "Software"), to deal @@ -87,17 +87,26 @@ else: msg = "Module " + __name__ + " cannot be used within an IPython shell" raise ImportError(msg) - try: - from qtconsole.rich_ipython_widget import RichJupyterWidget as \ - RichIPythonWidget + from qtconsole.rich_jupyter_widget import RichJupyterWidget as \ + _RichJupyterWidget except ImportError: - from qtconsole.rich_ipython_widget import RichIPythonWidget + try: + from qtconsole.rich_ipython_widget import RichJupyterWidget as \ + _RichJupyterWidget + except ImportError: + from qtconsole.rich_ipython_widget import RichIPythonWidget as \ + _RichJupyterWidget from qtconsole.inprocess import QtInProcessKernelManager +try: + from ipykernel import version_info as _ipykernel_version_info +except ImportError: + _ipykernel_version_info = None + -class IPythonWidget(RichIPythonWidget): +class IPythonWidget(_RichJupyterWidget): """Live IPython console widget. .. image:: img/IPythonWidget.png @@ -115,6 +124,16 @@ class IPythonWidget(RichIPythonWidget): self.setWindowTitle(self.banner) self.kernel_manager = kernel_manager = QtInProcessKernelManager() kernel_manager.start_kernel() + + # Monkey-patch to workaround issue: + # https://github.com/ipython/ipykernel/issues/370 + if (_ipykernel_version_info is not None and + _ipykernel_version_info[0] > 4 and + _ipykernel_version_info[:3] <= (5, 1, 0)): + def _abort_queues(*args, **kwargs): + pass + kernel_manager.kernel._abort_queues = _abort_queues + self.kernel_client = kernel_client = self._kernel_manager.client() kernel_client.start_channels() @@ -178,5 +197,6 @@ def main(): widget.show() app.exec_() + if __name__ == '__main__': main() diff --git a/silx/gui/data/ArrayTableModel.py b/silx/gui/data/ArrayTableModel.py index ad4d33a..8805241 100644 --- a/silx/gui/data/ArrayTableModel.py +++ b/silx/gui/data/ArrayTableModel.py @@ -1,7 +1,7 @@ # coding: utf-8 # /*########################################################################## # -# Copyright (c) 2016-2017 European Synchrotron Radiation Facility +# Copyright (c) 2016-2019 European Synchrotron Radiation Facility # # Permission is hereby granted, free of charge, to any person obtaining a copy # of this software and associated documentation files (the "Software"), to deal @@ -245,8 +245,7 @@ class ArrayTableModel(qt.QAbstractTableModel): if index.isValid() and role == qt.Qt.EditRole: try: # cast value to same type as array - v = numpy.asscalar( - numpy.array(value, dtype=self._array.dtype)) + v = numpy.array(value, dtype=self._array.dtype).item() except ValueError: return False diff --git a/silx/gui/data/DataViews.py b/silx/gui/data/DataViews.py index 6575d0d..664090d 100644 --- a/silx/gui/data/DataViews.py +++ b/silx/gui/data/DataViews.py @@ -893,53 +893,27 @@ class _Plot3dView(DataView): label="Cube", icon=icons.getQIcon("view-3d")) try: - import silx.gui.plot3d #noqa + from ._VolumeWindow import VolumeWindow # noqa except ImportError: - _logger.warning("Plot3dView is not available") + _logger.warning("3D visualization is not available") _logger.debug("Backtrace", exc_info=True) raise self.__resetZoomNextTime = True def createWidget(self, parent): - from silx.gui.plot3d import ScalarFieldView - from silx.gui.plot3d import SFViewParamTree + from ._VolumeWindow import VolumeWindow - plot = ScalarFieldView.ScalarFieldView(parent) + plot = VolumeWindow(parent) plot.setAxesLabels(*reversed(self.axesNames(None, None))) - - def computeIsolevel(data): - data = data[numpy.isfinite(data)] - if len(data) == 0: - return 0 - else: - return numpy.mean(data) + numpy.std(data) - - plot.addIsosurface(computeIsolevel, '#FF0000FF') - - # Create a parameter tree for the scalar field view - options = SFViewParamTree.TreeView(plot) - options.setSfView(plot) - - # Add the parameter tree to the main window in a dock widget - dock = qt.QDockWidget() - dock.setWidget(options) - plot.addDockWidget(qt.Qt.RightDockWidgetArea, dock) - return plot def clear(self): - self.getWidget().setData(None) + self.getWidget().clear() self.__resetZoomNextTime = True - def normalizeData(self, data): - data = DataView.normalizeData(self, data) - data = _normalizeComplex(data) - return data - def setData(self, data): data = self.normalizeData(data) - plot = self.getWidget() - plot.setData(data) + self.getWidget().setData(data) self.__resetZoomNextTime = False def axesNames(self, data, info): @@ -973,10 +947,10 @@ class _ComplexImageView(DataView): def createWidget(self, parent): from silx.gui.plot.ComplexImageView import ComplexImageView widget = ComplexImageView(parent=parent) - widget.setColormap(self.defaultColormap(), mode=ComplexImageView.Mode.ABSOLUTE) - widget.setColormap(self.defaultColormap(), mode=ComplexImageView.Mode.SQUARE_AMPLITUDE) - widget.setColormap(self.defaultColormap(), mode=ComplexImageView.Mode.REAL) - widget.setColormap(self.defaultColormap(), mode=ComplexImageView.Mode.IMAGINARY) + widget.setColormap(self.defaultColormap(), mode=ComplexImageView.ComplexMode.ABSOLUTE) + widget.setColormap(self.defaultColormap(), mode=ComplexImageView.ComplexMode.SQUARE_AMPLITUDE) + widget.setColormap(self.defaultColormap(), mode=ComplexImageView.ComplexMode.REAL) + widget.setColormap(self.defaultColormap(), mode=ComplexImageView.ComplexMode.IMAGINARY) widget.getPlot().getColormapAction().setColorDialog(self.defaultColorDialog()) widget.getPlot().getIntensityHistogramAction().setVisible(True) widget.getPlot().setKeepDataAspectRatio(True) diff --git a/silx/gui/data/NXdataWidgets.py b/silx/gui/data/NXdataWidgets.py index e5a2550..c3aefd3 100644 --- a/silx/gui/data/NXdataWidgets.py +++ b/silx/gui/data/NXdataWidgets.py @@ -29,7 +29,6 @@ __license__ = "MIT" __date__ = "12/11/2018" import logging -import numbers import numpy from silx.gui import qt @@ -533,10 +532,10 @@ class ArrayComplexImagePlot(qt.QWidget): self._plot = ComplexImageView(self) if colormap is not None: - for mode in (ComplexImageView.Mode.ABSOLUTE, - ComplexImageView.Mode.SQUARE_AMPLITUDE, - ComplexImageView.Mode.REAL, - ComplexImageView.Mode.IMAGINARY): + for mode in (ComplexImageView.ComplexMode.ABSOLUTE, + ComplexImageView.ComplexMode.SQUARE_AMPLITUDE, + ComplexImageView.ComplexMode.REAL, + ComplexImageView.ComplexMode.IMAGINARY): self._plot.setColormap(colormap, mode) self._plot.getPlot().getIntensityHistogramAction().setVisible(True) @@ -893,28 +892,9 @@ class ArrayVolumePlot(qt.QWidget): self.__x_axis = None self.__x_axis_name = None - from silx.gui.plot3d.ScalarFieldView import ScalarFieldView - from silx.gui.plot3d import SFViewParamTree + from ._VolumeWindow import VolumeWindow - self._view = ScalarFieldView(self) - - def computeIsolevel(data): - data = data[numpy.isfinite(data)] - if len(data) == 0: - return 0 - else: - return numpy.mean(data) + numpy.std(data) - - self._view.addIsosurface(computeIsolevel, '#FF0000FF') - - # Create a parameter tree for the scalar field view - options = SFViewParamTree.TreeView(self._view) - options.setSfView(self._view) - - # Add the parameter tree to the main window in a dock widget - dock = qt.QDockWidget() - dock.setWidget(options) - self._view.addDockWidget(qt.Qt.RightDockWidgetArea, dock) + self._view = VolumeWindow(self) self._hline = qt.QFrame(self) self._hline.setFrameStyle(qt.QFrame.HLine) @@ -935,24 +915,10 @@ class ArrayVolumePlot(qt.QWidget): def getVolumeView(self): """Returns the plot used for the display - :rtype: ScalarFieldView + :rtype: SceneWindow """ return self._view - def normalizeComplexData(self, data): - """ - Converts a complex data array to its amplitude, if necessary. - :param data: the data to normalize - :return: - """ - if hasattr(data, "dtype"): - isComplex = numpy.issubdtype(data.dtype, numpy.complexfloating) - else: - isComplex = isinstance(data, numbers.Complex) - if isComplex: - data = numpy.absolute(data) - return data - def setData(self, signal, x_axis=None, y_axis=None, z_axis=None, signal_name=None, @@ -977,7 +943,6 @@ class ArrayVolumePlot(qt.QWidget): :param zlabel: Label for Z axis :param title: Graph title """ - signal = self.normalizeComplexData(signal) if self.__selector_is_connected: self._selector.selectionChanged.disconnect(self._updateVolume) self.__selector_is_connected = False @@ -994,9 +959,6 @@ class ArrayVolumePlot(qt.QWidget): self._selector.setData(signal) self._selector.setAxisNames(["Y", "X", "Z"]) - self._view.setAxesLabels(self.__x_axis_name or 'X', - self.__y_axis_name or 'Y', - self.__z_axis_name or 'Z') self._updateVolume() # the legend label shows the selection slice producing the volume @@ -1017,7 +979,6 @@ class ArrayVolumePlot(qt.QWidget): def _updateVolume(self): """Update displayed stack according to the current axes selector data.""" - data = self._selector.selectedData() x_axis = self.__x_axis y_axis = self.__y_axis z_axis = self.__z_axis @@ -1049,15 +1010,16 @@ class ArrayVolumePlot(qt.QWidget): legend = legend[:-2] + "]" self._legend.setText("Displayed data: " + legend) - self._view.setData(data, copy=False) - self._view.setScale(*scale) - self._view.setTranslation(*offset) - self._view.setAxesLabels(self.__x_axis_name, - self.__y_axis_name, - self.__z_axis_name) + # Update SceneWidget + data = self._selector.selectedData() + + volumeView = self.getVolumeView() + volumeView.setData(data, offset=offset, scale=scale) + volumeView.setAxesLabels( + self.__x_axis_name, self.__y_axis_name, self.__z_axis_name) def clear(self): old = self._selector.blockSignals(True) self._selector.clear() self._selector.blockSignals(old) - self._view.setData(None) + self.getVolumeView().clear() diff --git a/silx/gui/data/_VolumeWindow.py b/silx/gui/data/_VolumeWindow.py new file mode 100644 index 0000000..03b6876 --- /dev/null +++ b/silx/gui/data/_VolumeWindow.py @@ -0,0 +1,148 @@ +# coding: utf-8 +# /*########################################################################## +# +# Copyright (c) 2019 European Synchrotron Radiation Facility +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. +# +# ###########################################################################*/ +"""This module provides a widget to visualize 3D arrays""" + +__authors__ = ["T. Vincent"] +__license__ = "MIT" +__date__ = "22/03/2019" + + +import numpy + +from .. import qt +from ..plot3d.SceneWindow import SceneWindow +from ..plot3d.items import ScalarField3D, ComplexField3D, ItemChangedType + + +class VolumeWindow(SceneWindow): + """Extends SceneWindow with a convenient API for 3D array + + :param QWidget: parent + """ + + def __init__(self, parent): + super(VolumeWindow, self).__init__(parent) + self.__firstData = True + # Hide global parameter dock + self.getGroupResetWidget().parent().setVisible(False) + + def setAxesLabels(self, xlabel=None, ylabel=None, zlabel=None): + """Set the text labels of the axes. + + :param Union[str,None] xlabel: Label of the X axis + :param Union[str,None] ylabel: Label of the Y axis + :param Union[str,None] zlabel: Label of the Z axis + """ + sceneWidget = self.getSceneWidget() + sceneWidget.getSceneGroup().setAxesLabels( + 'X' if xlabel is None else xlabel, + 'Y' if ylabel is None else ylabel, + 'Z' if zlabel is None else zlabel) + + def clear(self): + """Clear any currently displayed data""" + sceneWidget = self.getSceneWidget() + items = sceneWidget.getItems() + if (len(items) == 1 and + isinstance(items[0], (ScalarField3D, ComplexField3D))): + items[0].setData(None) + else: # Safety net + sceneWidget.clearItems() + + @staticmethod + def __computeIsolevel(data): + """Returns a suitable isolevel value for data + + :param numpy.ndarray data: + :rtype: float + """ + data = data[numpy.isfinite(data)] + if len(data) == 0: + return 0 + else: + return numpy.mean(data) + numpy.std(data) + + def setData(self, data, offset=(0., 0., 0.), scale=(1., 1., 1.)): + """Set the 3D array data to display. + + :param numpy.ndarray data: 3D array of float or complex + :param List[float] offset: (tx, ty, tz) coordinates of the origin + :param List[float] scale: (sx, sy, sz) scale for each dimension + """ + sceneWidget = self.getSceneWidget() + dataMaxCoords = numpy.array(list(reversed(data.shape))) - 1 + + previousItems = sceneWidget.getItems() + if (len(previousItems) == 1 and + isinstance(previousItems[0], (ScalarField3D, ComplexField3D)) and + numpy.iscomplexobj(data) == isinstance(previousItems[0], ComplexField3D)): + # Reuse existing volume item + volume = sceneWidget.getItems()[0] + volume.setData(data, copy=False) + # Make sure the plane goes through the dataset + for plane in volume.getCutPlanes(): + point = numpy.array(plane.getPoint()) + if numpy.any(point < (0, 0, 0)) or numpy.any(point > dataMaxCoords): + plane.setPoint(dataMaxCoords // 2) + else: + # Add a new volume + sceneWidget.clearItems() + volume = sceneWidget.addVolume(data, copy=False) + volume.setLabel('Volume') + for plane in volume.getCutPlanes(): + # Make plane going through the center of the data + plane.setPoint(dataMaxCoords // 2) + plane.setVisible(False) + plane.sigItemChanged.connect(self.__cutPlaneUpdated) + volume.addIsosurface(self.__computeIsolevel, '#FF0000FF') + + # Expand the parameter tree + model = self.getParamTreeView().model() + index = qt.QModelIndex() # Invalid index for top level + while 1: + rowCount = model.rowCount(parent=index) + if rowCount == 0: + break + index = model.index(rowCount - 1, 0, parent=index) + self.getParamTreeView().setExpanded(index, True) + if not index.isValid(): + break + + volume.setTranslation(*offset) + volume.setScale(*scale) + + if self.__firstData: # Only center for first dataset + self.__firstData = False + sceneWidget.centerScene() + + def __cutPlaneUpdated(self, event): + """Handle the change of visibility of the cut plane + + :param event: Kind of update + """ + if event == ItemChangedType.VISIBLE: + plane = self.sender() + if plane.isVisible(): + self.getSceneWidget().selection().setCurrentItem(plane) diff --git a/silx/gui/data/test/test_dataviewer.py b/silx/gui/data/test/test_dataviewer.py index dc6fee8..12a640e 100644 --- a/silx/gui/data/test/test_dataviewer.py +++ b/silx/gui/data/test/test_dataviewer.py @@ -1,7 +1,7 @@ # coding: utf-8 # /*########################################################################## # -# Copyright (c) 2016-2018 European Synchrotron Radiation Facility +# Copyright (c) 2016-2019 European Synchrotron Radiation Facility # # Permission is hereby granted, free of charge, to any person obtaining a copy # of this software and associated documentation files (the "Software"), to deal @@ -242,8 +242,8 @@ class AbstractDataViewerTests(TestCaseQt): self.assertTrue(replaced) nxdata_view = widget.getViewFromModeId(DataViews.NXDATA_MODE) self.assertNotIn(DataViews.NXDATA_INVALID_MODE, - [v.modeId() for v in nxdata_view.availableViews()]) - self.assertTrue(view in nxdata_view.availableViews()) + [v.modeId() for v in nxdata_view.getViews()]) + self.assertTrue(view in nxdata_view.getViews()) class TestDataViewer(AbstractDataViewerTests): diff --git a/silx/gui/data/test/test_textformatter.py b/silx/gui/data/test/test_textformatter.py index 935344a..1a63074 100644 --- a/silx/gui/data/test/test_textformatter.py +++ b/silx/gui/data/test/test_textformatter.py @@ -1,7 +1,7 @@ # coding: utf-8 # /*########################################################################## # -# Copyright (c) 2016-2018 European Synchrotron Radiation Facility +# Copyright (c) 2016-2019 European Synchrotron Radiation Facility # # Permission is hereby granted, free of charge, to any person obtaining a copy # of this software and associated documentation files (the "Software"), to deal @@ -48,7 +48,7 @@ class TestTextFormatter(TestCaseQt): self.assertIsNot(formatter, copy) copy.setFloatFormat("%.3f") self.assertEqual(formatter.integerFormat(), copy.integerFormat()) - self.assertNotEquals(formatter.floatFormat(), copy.floatFormat()) + self.assertNotEqual(formatter.floatFormat(), copy.floatFormat()) self.assertEqual(formatter.useQuoteForText(), copy.useQuoteForText()) self.assertEqual(formatter.imaginaryUnit(), copy.imaginaryUnit()) diff --git a/silx/gui/dialog/AbstractDataFileDialog.py b/silx/gui/dialog/AbstractDataFileDialog.py index c660cd7..29e7bb5 100644 --- a/silx/gui/dialog/AbstractDataFileDialog.py +++ b/silx/gui/dialog/AbstractDataFileDialog.py @@ -28,7 +28,7 @@ This module contains an :class:`AbstractDataFileDialog`. __authors__ = ["V. Valls"] __license__ = "MIT" -__date__ = "03/12/2018" +__date__ = "05/03/2019" import sys @@ -468,9 +468,13 @@ class _FabioData(object): def shape(self): if self.__fabioFile.nframes == 0: return None + if self.__fabioFile.nframes == 1: + return [slice(None), slice(None)] return [self.__fabioFile.nframes, slice(None), slice(None)] def __getitem__(self, selector): + if self.__fabioFile.nframes == 1 and selector == tuple(): + return self.__fabioFile.data if isinstance(selector, tuple) and len(selector) == 1: selector = selector[0] @@ -542,6 +546,10 @@ class AbstractDataFileDialog(qt.QDialog): def _init(self): self.setWindowTitle("Open") + self.__openedFiles = [] + """Store the list of files opened by the model itself.""" + # FIXME: It should be managed one by one by Hdf5Item itself + self.__directory = None self.__directoryLoadedFilter = None self.__errorWhileLoadingFile = None @@ -591,10 +599,6 @@ class AbstractDataFileDialog(qt.QDialog): self.__fileTypeCombo.setCurrentIndex(0) self.__filterSelected(0) - self.__openedFiles = [] - """Store the list of files opened by the model itself.""" - # FIXME: It should be managed one by one by Hdf5Item itself - # It is not possible to override the QObject destructor nor # to access to the content of the Python object with the `destroyed` # signal cause the Python method was already removed with the QWidget, @@ -1038,15 +1042,16 @@ class AbstractDataFileDialog(qt.QDialog): return self.__directoryLoadedFilter = path self.__processing += 1 + if self.__fileModel is None: + return index = self.__fileModel.setRootPath(path) if not index.isValid(): + # There is a problem with this path + # No asynchronous process will be waked up self.__processing -= 1 self.__browser.setRootIndex(index, model=self.__fileModel) self.__clearData() self.__updatePath() - else: - # asynchronous process - pass def __directoryLoaded(self, path): if self.__directoryLoadedFilter is not None: @@ -1055,6 +1060,8 @@ class AbstractDataFileDialog(qt.QDialog): # The first click on the sidebar sent 2 events self.__processing -= 1 return + if self.__fileModel is None: + return index = self.__fileModel.index(path) self.__browser.setRootIndex(index, model=self.__fileModel) self.__updatePath() @@ -1233,6 +1240,7 @@ class AbstractDataFileDialog(qt.QDialog): if self.__previewWidget is not None: self.__previewWidget.setData(None) if self.__selectorWidget is not None: + self.__selectorWidget.setData(None) self.__selectorWidget.hide() self.__selectedData = None self.__data = None @@ -1250,6 +1258,8 @@ class AbstractDataFileDialog(qt.QDialog): If :meth:`_isDataSupported` returns false, this function will be inhibited and no data will be selected. """ + if isinstance(data, _FabioData): + data = data[()] if self.__previewWidget is not None: fromDataSelector = self.__selectedData is not None self.__previewWidget.setData(data, fromDataSelector=fromDataSelector) @@ -1317,8 +1327,10 @@ class AbstractDataFileDialog(qt.QDialog): filename = "" dataPath = None - if useSelectorWidget and self.__selectorWidget is not None and self.__selectorWidget.isVisible(): + if useSelectorWidget and self.__selectorWidget is not None and self.__selectorWidget.isUsed(): slicing = self.__selectorWidget.slicing() + if slicing == tuple(): + slicing = None else: slicing = None @@ -1483,9 +1495,7 @@ class AbstractDataFileDialog(qt.QDialog): self.__clearData() if self.__selectorWidget is not None: - self.__selectorWidget.setVisible(url.data_slice() is not None) - if url.data_slice() is not None: - self.__selectorWidget.setSlicing(url.data_slice()) + self.__selectorWidget.selectSlicing(url.data_slice()) else: self.__errorWhileLoadingFile = (url.file_path(), "File not found") self.__clearData() diff --git a/silx/gui/dialog/ColormapDialog.py b/silx/gui/dialog/ColormapDialog.py index 9950ad4..9c956f8 100644 --- a/silx/gui/dialog/ColormapDialog.py +++ b/silx/gui/dialog/ColormapDialog.py @@ -661,8 +661,7 @@ class ColormapDialog(qt.QDialog): dataRange = None if dataRange is not None: - min_positive = dataRange.min_positive - dataRange = dataRange.minimum, min_positive, dataRange.maximum + dataRange = dataRange.minimum, dataRange.min_positive, dataRange.maximum if dataRange is None or len(dataRange) != 3: qt.QMessageBox.warning( diff --git a/silx/gui/dialog/ImageFileDialog.py b/silx/gui/dialog/ImageFileDialog.py index ef6b472..d015bd2 100644 --- a/silx/gui/dialog/ImageFileDialog.py +++ b/silx/gui/dialog/ImageFileDialog.py @@ -28,7 +28,7 @@ This module contains an :class:`ImageFileDialog`. __authors__ = ["V. Valls"] __license__ = "MIT" -__date__ = "12/02/2018" +__date__ = "05/03/2019" import logging from silx.gui.plot import actions @@ -36,7 +36,6 @@ from silx.gui import qt from silx.gui.plot.PlotWidget import PlotWidget from .AbstractDataFileDialog import AbstractDataFileDialog import silx.io -import fabio _logger = logging.getLogger(__name__) @@ -61,7 +60,7 @@ class _ImageSelection(qt.QWidget): def isUsed(self): if self.__shape is None: - return None + return False return len(self.__shape) > 2 def getSelectedData(self, data): @@ -70,6 +69,10 @@ class _ImageSelection(qt.QWidget): return image def setData(self, data): + if data is None: + self.__visibleSliders = 0 + return + shape = data.shape if self.__shape is not None: # clean up @@ -114,6 +117,22 @@ class _ImageSelection(qt.QWidget): break self.__axis[i].setValue(value) + def selectSlicing(self, slicing): + """Select a slicing. + + The provided value could be unconsistent and therefore is not supposed + to be retrivable with a getter. + + :param Union[None,Tuple[int]] slicing: + """ + if slicing is None: + # Create a default slicing + needed = self.__visibleSliders + slicing = (0,) * needed + if len(slicing) < self.__visibleSliders: + slicing = slicing + (0,) * (self.__visibleSliders - len(slicing)) + self.setSlicing(slicing) + class _ImagePreview(qt.QWidget): """Provide a preview of the selected image""" diff --git a/silx/gui/dialog/test/test_colormapdialog.py b/silx/gui/dialog/test/test_colormapdialog.py index cbc9de1..8dad196 100644 --- a/silx/gui/dialog/test/test_colormapdialog.py +++ b/silx/gui/dialog/test/test_colormapdialog.py @@ -1,7 +1,7 @@ # coding: utf-8 # /*########################################################################## # -# Copyright (c) 2016-2018 European Synchrotron Radiation Facility +# Copyright (c) 2016-2019 European Synchrotron Radiation Facility # # Permission is hereby granted, free of charge, to any person obtaining a copy # of this software and associated documentation files (the "Software"), to deal @@ -41,10 +41,6 @@ from silx.gui.plot.PlotWindow import PlotWindow import numpy.random -# Makes sure a QApplication exists -_qapp = qt.QApplication.instance() or qt.QApplication([]) - - class TestColormapDialog(TestCaseQt, ParametricTestCase): """Test the ColormapDialog.""" def setUp(self): diff --git a/silx/gui/dialog/test/test_datafiledialog.py b/silx/gui/dialog/test/test_datafiledialog.py index 06f8961..b60ea12 100644 --- a/silx/gui/dialog/test/test_datafiledialog.py +++ b/silx/gui/dialog/test/test_datafiledialog.py @@ -26,7 +26,7 @@ __authors__ = ["V. Valls"] __license__ = "MIT" -__date__ = "05/10/2018" +__date__ = "08/03/2019" import unittest @@ -130,7 +130,7 @@ class _UtilsMixin(object): path2_ = os.path.normcase(path2) if path1_ == path2_: # Use the unittest API to log and display error - self.assertNotEquals(path1, path2) + self.assertNotEqual(path1, path2) class TestDataFileDialogInteraction(testutils.TestCaseQt, _UtilsMixin): @@ -385,7 +385,7 @@ class TestDataFileDialogInteraction(testutils.TestCaseQt, _UtilsMixin): filename = _tmpDirectory + "/singleimage.edf" url = silx.io.url.DataUrl(scheme="silx", file_path=filename, data_path="/scan_0/instrument/detector_0/data") dialog.selectUrl(url.path()) - self.assertTrue(dialog._selectedData().shape, (100, 100)) + self.assertEqual(dialog._selectedData().shape, (100, 100)) self.assertSamePath(dialog.selectedFile(), filename) self.assertSamePath(dialog.selectedUrl(), url.path()) @@ -399,7 +399,7 @@ class TestDataFileDialogInteraction(testutils.TestCaseQt, _UtilsMixin): path = silx.io.url.DataUrl(scheme="silx", file_path=filename, data_path="/image").path() dialog.selectUrl(path) # test - self.assertTrue(dialog._selectedData().shape, (100, 100)) + self.assertEqual(dialog._selectedData().shape, (100, 100)) self.assertSamePath(dialog.selectedFile(), filename) self.assertSamePath(dialog.selectedUrl(), path) @@ -479,11 +479,12 @@ class TestDataFileDialogInteraction(testutils.TestCaseQt, _UtilsMixin): self.qWaitForPendingActions(dialog) browser = testutils.findChildren(dialog, qt.QWidget, name="browser")[0] filename = _tmpDirectory + "/badformat.h5" - index = browser.rootIndex().model().index(filename) + index = browser.model().index(filename) + browser.selectIndex(index) browser.activated.emit(index) self.qWaitForPendingActions(dialog) # test - self.assertTrue(dialog.selectedUrl(), filename) + self.assertSamePath(dialog.selectedUrl(), filename) def _countSelectableItems(self, model, rootIndex): selectable = 0 @@ -853,7 +854,7 @@ class TestDataFileDialogApi(testutils.TestCaseQt, _UtilsMixin): dialog2 = self.createDialog() result = dialog2.restoreState(state) self.assertTrue(result) - self.assertNotEquals(dialog2.directory(), directory) + self.assertNotEqual(dialog2.directory(), directory) def testHistory(self): dialog = self.createDialog() diff --git a/silx/gui/dialog/test/test_imagefiledialog.py b/silx/gui/dialog/test/test_imagefiledialog.py index 068dcb9..c019afb 100644 --- a/silx/gui/dialog/test/test_imagefiledialog.py +++ b/silx/gui/dialog/test/test_imagefiledialog.py @@ -26,7 +26,7 @@ __authors__ = ["V. Valls"] __license__ = "MIT" -__date__ = "05/10/2018" +__date__ = "08/03/2019" import unittest @@ -70,24 +70,24 @@ def setUpModule(): image.write(filename) filename = _tmpDirectory + "/data.h5" - f = h5py.File(filename, "w") - f["scalar"] = 10 - f["image"] = data - f["cube"] = [data, data + 1, data + 2] - f["complex_image"] = data * 1j - f["group/image"] = data - f.close() + with h5py.File(filename, "w") as f: + f["scalar"] = 10 + f["image"] = data + f["cube"] = [data, data + 1, data + 2] + f["single_frame"] = [data + 5] + f["complex_image"] = data * 1j + f["group/image"] = data directory = os.path.join(_tmpDirectory, "data") os.mkdir(directory) filename = os.path.join(directory, "data.h5") - f = h5py.File(filename, "w") - f["scalar"] = 10 - f["image"] = data - f["cube"] = [data, data + 1, data + 2] - f["complex_image"] = data * 1j - f["group/image"] = data - f.close() + with h5py.File(filename, "w") as f: + f["scalar"] = 10 + f["image"] = data + f["cube"] = [data, data + 1, data + 2] + f["single_frame"] = [data + 5] + f["complex_image"] = data * 1j + f["group/image"] = data filename = _tmpDirectory + "/badformat.edf" with io.open(filename, "wb") as f: @@ -137,7 +137,7 @@ class _UtilsMixin(object): path2_ = os.path.normcase(path2) if path1_ == path2_: # Use the unittest API to log and display error - self.assertNotEquals(path1, path2) + self.assertNotEqual(path1, path2) class TestImageFileDialogInteraction(testutils.TestCaseQt, _UtilsMixin): @@ -373,7 +373,7 @@ class TestImageFileDialogInteraction(testutils.TestCaseQt, _UtilsMixin): filename = _tmpDirectory + "/singleimage.edf" path = filename dialog.selectUrl(path) - self.assertTrue(dialog.selectedImage().shape, (100, 100)) + self.assertEqual(dialog.selectedImage().shape, (100, 100)) self.assertSamePath(dialog.selectedFile(), filename) path = silx.io.url.DataUrl(scheme="fabio", file_path=filename).path() self.assertSamePath(dialog.selectedUrl(), path) @@ -396,7 +396,7 @@ class TestImageFileDialogInteraction(testutils.TestCaseQt, _UtilsMixin): browser.activated.emit(index) self.qWaitForPendingActions(dialog) # test - self.assertTrue(dialog.selectedImage().shape, (100, 100)) + self.assertEqual(dialog.selectedImage().shape, (100, 100)) self.assertSamePath(dialog.selectedFile(), filename) self.assertSamePath(dialog.selectedUrl(), path) @@ -411,8 +411,8 @@ class TestImageFileDialogInteraction(testutils.TestCaseQt, _UtilsMixin): dialog.selectUrl(path) # test image = dialog.selectedImage() - self.assertTrue(image.shape, (100, 100)) - self.assertTrue(image[0, 0], 1) + self.assertEqual(image.shape, (100, 100)) + self.assertEqual(image[0, 0], 1) self.assertSamePath(dialog.selectedFile(), filename) self.assertSamePath(dialog.selectedUrl(), path) @@ -426,7 +426,7 @@ class TestImageFileDialogInteraction(testutils.TestCaseQt, _UtilsMixin): path = silx.io.url.DataUrl(scheme="fabio", file_path=filename).path() dialog.selectUrl(path) # test - self.assertTrue(dialog.selectedImage().shape, (100, 100)) + self.assertEqual(dialog.selectedImage().shape, (100, 100)) self.assertSamePath(dialog.selectedFile(), filename) self.assertSamePath(dialog.selectedUrl(), path) @@ -440,7 +440,7 @@ class TestImageFileDialogInteraction(testutils.TestCaseQt, _UtilsMixin): path = silx.io.url.DataUrl(scheme="silx", file_path=filename, data_path="/image").path() dialog.selectUrl(path) # test - self.assertTrue(dialog.selectedImage().shape, (100, 100)) + self.assertEqual(dialog.selectedImage().shape, (100, 100)) self.assertSamePath(dialog.selectedFile(), filename) self.assertSamePath(dialog.selectedUrl(), path) @@ -474,8 +474,23 @@ class TestImageFileDialogInteraction(testutils.TestCaseQt, _UtilsMixin): path = silx.io.url.DataUrl(scheme="silx", file_path=filename, data_path="/cube", data_slice=(1, )).path() dialog.selectUrl(path) # test - self.assertTrue(dialog.selectedImage().shape, (100, 100)) - self.assertTrue(dialog.selectedImage()[0, 0], 1) + self.assertEqual(dialog.selectedImage().shape, (100, 100)) + self.assertEqual(dialog.selectedImage()[0, 0], 1) + self.assertSamePath(dialog.selectedFile(), filename) + self.assertSamePath(dialog.selectedUrl(), path) + + def testSelectSingleFrameFromH5(self): + dialog = self.createDialog() + dialog.show() + self.qWaitForWindowExposed(dialog) + + # init state + filename = _tmpDirectory + "/data.h5" + path = silx.io.url.DataUrl(scheme="silx", file_path=filename, data_path="/single_frame", data_slice=(0, )).path() + dialog.selectUrl(path) + # test + self.assertEqual(dialog.selectedImage().shape, (100, 100)) + self.assertEqual(dialog.selectedImage()[0, 0], 5) self.assertSamePath(dialog.selectedFile(), filename) self.assertSamePath(dialog.selectedUrl(), path) @@ -489,11 +504,12 @@ class TestImageFileDialogInteraction(testutils.TestCaseQt, _UtilsMixin): self.qWaitForPendingActions(dialog) browser = testutils.findChildren(dialog, qt.QWidget, name="browser")[0] filename = _tmpDirectory + "/badformat.edf" - index = browser.rootIndex().model().index(filename) + index = browser.model().index(filename) + browser.selectIndex(index) browser.activated.emit(index) self.qWaitForPendingActions(dialog) # test - self.assertTrue(dialog.selectedUrl(), filename) + self.assertSamePath(dialog.selectedUrl(), filename) def _countSelectableItems(self, model, rootIndex): selectable = 0 @@ -549,7 +565,7 @@ class TestImageFileDialogApi(testutils.TestCaseQt, _UtilsMixin): result = dialog2.restoreState(state) self.qWaitForPendingActions(dialog2) self.assertTrue(result) - self.assertTrue(dialog2.colormap().getNormalization(), "log") + self.assertEqual(dialog2.colormap().getNormalization(), "log") def printState(self): """ @@ -646,7 +662,7 @@ class TestImageFileDialogApi(testutils.TestCaseQt, _UtilsMixin): result = dialog.restoreState(state) self.assertTrue(result) colormap = dialog.colormap() - self.assertTrue(colormap.getNormalization(), "log") + self.assertEqual(colormap.getNormalization(), "log") def testRestoreRobusness(self): """What's happen if you try to open a config file with a different @@ -672,7 +688,7 @@ class TestImageFileDialogApi(testutils.TestCaseQt, _UtilsMixin): dialog2 = self.createDialog() result = dialog2.restoreState(state) self.assertTrue(result) - self.assertNotEquals(dialog2.directory(), directory) + self.assertNotEqual(dialog2.directory(), directory) def testHistory(self): dialog = self.createDialog() diff --git a/silx/gui/hdf5/test/test_hdf5.py b/silx/gui/hdf5/test/test_hdf5.py index f22d4ae..0ab4dc4 100644 --- a/silx/gui/hdf5/test/test_hdf5.py +++ b/silx/gui/hdf5/test/test_hdf5.py @@ -1,7 +1,7 @@ # coding: utf-8 # /*########################################################################## # -# Copyright (c) 2016 European Synchrotron Radiation Facility +# Copyright (c) 2016-2019 European Synchrotron Radiation Facility # # Permission is hereby granted, free of charge, to any person obtaining a copy # of this software and associated documentation files (the "Software"), to deal @@ -224,7 +224,7 @@ class TestHdf5TreeModel(TestCaseQt): def testSupportedDrop(self): model = hdf5.Hdf5TreeModel() - self.assertNotEquals(model.supportedDropActions(), 0) + self.assertNotEqual(model.supportedDropActions(), 0) model.setFileMoveEnabled(False) model.setFileDropEnabled(False) @@ -232,11 +232,11 @@ class TestHdf5TreeModel(TestCaseQt): model.setFileMoveEnabled(False) model.setFileDropEnabled(True) - self.assertNotEquals(model.supportedDropActions(), 0) + self.assertNotEqual(model.supportedDropActions(), 0) model.setFileMoveEnabled(True) model.setFileDropEnabled(False) - self.assertNotEquals(model.supportedDropActions(), 0) + self.assertNotEqual(model.supportedDropActions(), 0) def testCloseFile(self): """A file inserted as a filename is open and closed internally.""" diff --git a/silx/gui/plot/CompareImages.py b/silx/gui/plot/CompareImages.py index f7c4899..3875be4 100644 --- a/silx/gui/plot/CompareImages.py +++ b/silx/gui/plot/CompareImages.py @@ -1,7 +1,7 @@ # coding: utf-8 # /*########################################################################## # -# Copyright (c) 2018 European Synchrotron Radiation Facility +# Copyright (c) 2018-2019 European Synchrotron Radiation Facility # # Permission is hereby granted, free of charge, to any person obtaining a copy # of this software and associated documentation files (the "Software"), to deal @@ -48,7 +48,11 @@ _logger = logging.getLogger(__name__) from silx.opencl import ocl if ocl is not None: - from silx.opencl import sift + try: + from silx.opencl import sift + except ImportError: + # sift module is not available (e.g., in official Debian packages) + sift = None else: # No OpenCL device or no pyopencl sift = None @@ -62,6 +66,7 @@ class VisualizationMode(enum.Enum): HORIZONTAL_LINE = 'hline' COMPOSITE_RED_BLUE_GRAY = "rbgchannel" COMPOSITE_RED_BLUE_GRAY_NEG = "rbgnegchannel" + COMPOSITE_A_MINUS_B = "aminusb" @enum.unique @@ -161,6 +166,16 @@ class CompareImagesToolBar(qt.QToolBar): self.__ycChannelModeAction = action self.__visualizationGroup.addAction(action) + icon = icons.getQIcon("compare-mode-a-minus-b") + action = qt.QAction(icon, "Raw A minus B compare mode", self) + action.setIconVisibleInMenu(True) + action.setCheckable(True) + action.setShortcut(qt.QKeySequence(qt.Qt.Key_W)) + action.setProperty("mode", VisualizationMode.COMPOSITE_A_MINUS_B) + menu.addAction(action) + self.__ycChannelModeAction = action + self.__visualizationGroup.addAction(action) + menu = qt.QMenu(self) self.__alignmentAction = qt.QAction(self) self.__alignmentAction.setMenu(menu) @@ -539,6 +554,11 @@ class CompareImages(qt.QMainWindow): def __init__(self, parent=None, backend=None): qt.QMainWindow.__init__(self, parent) + self._resetZoomActive = True + self._colormap = Colormap() + """Colormap shared by all modes, except the compose images (rgb image)""" + self._colormapKeyPoints = Colormap('spring') + """Colormap used for sift keypoints""" if parent is None: self.setWindowTitle('Compare images') @@ -553,6 +573,7 @@ class CompareImages(qt.QMainWindow): self.__previousSeparatorPosition = None self.__plot = plot.PlotWidget(parent=self, backend=backend) + self.__plot.setDefaultColormap(self._colormap) self.__plot.getXAxis().setLabel('Columns') self.__plot.getYAxis().setLabel('Rows') if silx.config.DEFAULT_PLOT_IMAGE_Y_AXIS_ORIENTATION == 'downward': @@ -630,6 +651,14 @@ class CompareImages(qt.QMainWindow): """ return self.__plot + def getColormap(self): + """ + + :return: colormap used for compare image + :rtype: silx.gui.colors.Colormap + """ + return self._colormap + def getRawPixelData(self, x, y): """Return the raw pixel of each image data from axes positions. @@ -835,7 +864,8 @@ class CompareImages(qt.QMainWindow): self.__raw1 = image1 self.__raw2 = image2 self.__updateData() - self.__plot.resetZoom() + if self.isAutoResetZoom(): + self.__plot.resetZoom() def setImage1(self, image1): """Set image1 to be compared. @@ -850,7 +880,8 @@ class CompareImages(qt.QMainWindow): """ self.__raw1 = image1 self.__updateData() - self.__plot.resetZoom() + if self.isAutoResetZoom(): + self.__plot.resetZoom() def setImage2(self, image2): """Set image2 to be compared. @@ -865,7 +896,8 @@ class CompareImages(qt.QMainWindow): """ self.__raw2 = image2 self.__updateData() - self.__plot.resetZoom() + if self.isAutoResetZoom(): + self.__plot.resetZoom() def __updateKeyPoints(self): """Update the displayed keypoints using cached keypoints. @@ -878,11 +910,11 @@ class CompareImages(qt.QMainWindow): y=data[1], z=1, value=data[2], - legend="keypoints", - colormap=Colormap("spring")) + colormap=self._colormapKeyPoints, + legend="keypoints") def __updateData(self): - """Compute aligned image when the alignement mode changes. + """Compute aligned image when the alignment mode changes. This function cache input images which are used when vertical/horizontal separators moves. @@ -943,6 +975,9 @@ class CompareImages(qt.QMainWindow): elif mode == VisualizationMode.COMPOSITE_RED_BLUE_GRAY: data1 = self.__composeImage(data1, data2, mode) data2 = numpy.empty((0, 0)) + elif mode == VisualizationMode.COMPOSITE_A_MINUS_B: + data1 = self.__composeImage(data1, data2, mode) + data2 = numpy.empty((0, 0)) elif mode == VisualizationMode.ONLY_A: data2 = numpy.empty((0, 0)) elif mode == VisualizationMode.ONLY_B: @@ -977,7 +1012,8 @@ class CompareImages(qt.QMainWindow): else: vmin = min(self.__data1.min(), self.__data2.min()) vmax = max(self.__data1.max(), self.__data2.max()) - colormap = Colormap(vmin=vmin, vmax=vmax) + colormap = self.getColormap() + colormap.setVRange(vmin=vmin, vmax=vmax) self.__image1.setColormap(colormap) self.__image2.setColormap(colormap) @@ -1025,6 +1061,13 @@ class CompareImages(qt.QMainWindow): :rtype: numpy.ndarray """ assert(data1.shape[0:2] == data2.shape[0:2]) + if mode == VisualizationMode.COMPOSITE_A_MINUS_B: + # TODO: this calculation has no interest of generating a 'composed' + # rgb image, this could be moved in an other function or doc + # should be modified + _type = data1.dtype + result = data1.astype(numpy.float64) - data2.astype(numpy.float64) + return result mode1 = self.__getImageMode(data1) if mode1 in ["rgb", "rgba"]: intensity1 = self.__luminosityImage(data1) @@ -1188,3 +1231,19 @@ class CompareImages(qt.QMainWindow): data2 = result["result"] self.__transformation = self.__toAffineTransformation(result) return data1, data2 + + def setAutoResetZoom(self, activate=True): + """ + + :param bool activate: True if we want to activate the automatic + plot reset zoom when setting images. + """ + self._resetZoomActive = activate + + def isAutoResetZoom(self): + """ + + :return: True if the automatic call to resetzoom is activated + :rtype: bool + """ + return self._resetZoomActive diff --git a/silx/gui/plot/ComplexImageView.py b/silx/gui/plot/ComplexImageView.py index 2523cde..c8470ab 100644 --- a/silx/gui/plot/ComplexImageView.py +++ b/silx/gui/plot/ComplexImageView.py @@ -39,6 +39,7 @@ import logging import collections import numpy +from ...utils.deprecation import deprecated from .. import qt, icons from .PlotWindow import Plot2D from . import items @@ -170,16 +171,16 @@ class _ComplexDataToolButton(qt.QToolButton): """ _MODES = collections.OrderedDict([ - (ImageComplexData.Mode.ABSOLUTE, ('math-amplitude', 'Amplitude')), - (ImageComplexData.Mode.SQUARE_AMPLITUDE, + (ImageComplexData.ComplexMode.ABSOLUTE, ('math-amplitude', 'Amplitude')), + (ImageComplexData.ComplexMode.SQUARE_AMPLITUDE, ('math-square-amplitude', 'Square amplitude')), - (ImageComplexData.Mode.PHASE, ('math-phase', 'Phase')), - (ImageComplexData.Mode.REAL, ('math-real', 'Real part')), - (ImageComplexData.Mode.IMAGINARY, + (ImageComplexData.ComplexMode.PHASE, ('math-phase', 'Phase')), + (ImageComplexData.ComplexMode.REAL, ('math-real', 'Real part')), + (ImageComplexData.ComplexMode.IMAGINARY, ('math-imaginary', 'Imaginary part')), - (ImageComplexData.Mode.AMPLITUDE_PHASE, + (ImageComplexData.ComplexMode.AMPLITUDE_PHASE, ('math-phase-color', 'Amplitude and Phase')), - (ImageComplexData.Mode.LOG10_AMPLITUDE_PHASE, + (ImageComplexData.ComplexMode.LOG10_AMPLITUDE_PHASE, ('math-phase-color-log', 'Log10(Amp.) and Phase')) ]) @@ -208,7 +209,7 @@ class _ComplexDataToolButton(qt.QToolButton): self.setPopupMode(qt.QToolButton.InstantPopup) - self._modeChanged(self._plot2DComplex.getVisualizationMode()) + self._modeChanged(self._plot2DComplex.getComplexMode()) self._plot2DComplex.sigVisualizationModeChanged.connect( self._modeChanged) @@ -217,7 +218,8 @@ class _ComplexDataToolButton(qt.QToolButton): icon, text = self._MODES[mode] self.setIcon(icons.getQIcon(icon)) self.setToolTip('Display the ' + text.lower()) - self._rangeDialogAction.setEnabled(mode == ImageComplexData.Mode.LOG10_AMPLITUDE_PHASE) + self._rangeDialogAction.setEnabled( + mode == ImageComplexData.ComplexMode.LOG10_AMPLITUDE_PHASE) def _triggered(self, action): """Handle triggering of menu actions""" @@ -244,8 +246,8 @@ class _ComplexDataToolButton(qt.QToolButton): else: # update mode mode = action.data() - if isinstance(mode, ImageComplexData.Mode): - self._plot2DComplex.setVisualizationMode(mode) + if isinstance(mode, ImageComplexData.ComplexMode): + self._plot2DComplex.setComplexMode(mode) def _rangeChanged(self, range_): """Handle updates of range in the dialog""" @@ -258,8 +260,8 @@ class ComplexImageView(qt.QWidget): :param parent: See :class:`QMainWindow` """ - Mode = ImageComplexData.Mode - """Also expose the modes inside the class""" + ComplexMode = ImageComplexData.ComplexMode + """Complex Modes enumeration""" sigDataChanged = qt.Signal() """Signal emitted when data has changed.""" @@ -301,7 +303,7 @@ class ComplexImageView(qt.QWidget): if event is items.ItemChangedType.DATA: self.sigDataChanged.emit() elif event is items.ItemChangedType.VISUALIZATION_MODE: - mode = self.getVisualizationMode() + mode = self.getComplexMode() self.sigVisualizationModeChanged.emit(mode) def getPlot(self): @@ -344,15 +346,34 @@ class ComplexImageView(qt.QWidget): False to return internal data (do not modify!) :rtype: numpy.ndarray of float with 2 dims or RGBA image (uint8). """ - mode = self.getVisualizationMode() - if mode in (self.Mode.AMPLITUDE_PHASE, - self.Mode.LOG10_AMPLITUDE_PHASE): + mode = self.getComplexMode() + if mode in (self.ComplexMode.AMPLITUDE_PHASE, + self.ComplexMode.LOG10_AMPLITUDE_PHASE): return self._plotImage.getRgbaImageData(copy=copy) else: return self._plotImage.getData(copy=copy) + # Backward compatibility + + Mode = ComplexMode + + @classmethod + @deprecated(replacement='supportedComplexModes', since_version='0.11.0') + def getSupportedVisualizationModes(cls): + return cls.supportedComplexModes() + + @deprecated(replacement='setComplexMode', since_version='0.11.0') + def setVisualizationMode(self, mode): + return self.setComplexMode(mode) + + @deprecated(replacement='getComplexMode', since_version='0.11.0') + def getVisualizationMode(self): + return self.getComplexMode() + + # Image item proxy + @staticmethod - def getSupportedVisualizationModes(): + def supportedComplexModes(): """Returns the supported visualization modes. Supported visualization modes are: @@ -365,31 +386,33 @@ class ComplexImageView(qt.QWidget): - log10_amplitude_phase: Color-coded phase with log10(amplitude) as alpha. - :rtype: List[Mode] + :rtype: List[ComplexMode] """ - return tuple(ImageComplexData.Mode) + return ImageComplexData.supportedComplexModes() - def setVisualizationMode(self, mode): + def setComplexMode(self, mode): """Set the mode of visualization of the complex data. - See :meth:`getSupportedVisualizationModes` for the list of + See :meth:`supportedComplexModes` for the list of supported modes. How-to change visualization mode:: widget = ComplexImageView() - widget.setVisualizationMode(ComplexImageView.Mode.PHASE) + widget.setComplexMode(ComplexImageView.ComplexMode.PHASE) + # or + widget.setComplexMode('phase') - :param Mode mode: The mode to use. + :param Unions[ComplexMode,str] mode: The mode to use. """ - self._plotImage.setVisualizationMode(mode) + self._plotImage.setComplexMode(mode) - def getVisualizationMode(self): + def getComplexMode(self): """Get the current visualization mode of the complex data. - :rtype: Mode + :rtype: ComplexMode """ - return self._plotImage.getVisualizationMode() + return self._plotImage.getComplexMode() def _setAmplitudeRangeInfo(self, max_=None, delta=2): """Set the amplitude range to display for 'log10_amplitude_phase' mode. @@ -407,8 +430,6 @@ class ComplexImageView(qt.QWidget): :rtype: 2-tuple""" return self._plotImage._getAmplitudeRangeInfo() - # Image item proxy - def setColormap(self, colormap, mode=None): """Set the colormap to use for amplitude, phase, real or imaginary. @@ -416,14 +437,14 @@ class ComplexImageView(qt.QWidget): amplitude and phase. :param ~silx.gui.colors.Colormap colormap: The colormap - :param Mode mode: If specified, set the colormap of this specific mode + :param ComplexMode mode: If specified, set the colormap of this specific mode """ self._plotImage.setColormap(colormap, mode) def getColormap(self, mode=None): """Returns the colormap used to display the data. - :param Mode mode: If specified, set the colormap of this specific mode + :param ComplexMode mode: If specified, set the colormap of this specific mode :rtype: ~silx.gui.colors.Colormap """ return self._plotImage.getColormap(mode=mode) diff --git a/silx/gui/plot/CurvesROIWidget.py b/silx/gui/plot/CurvesROIWidget.py index b426a23..050b344 100644 --- a/silx/gui/plot/CurvesROIWidget.py +++ b/silx/gui/plot/CurvesROIWidget.py @@ -75,14 +75,19 @@ class CurvesROIWidget(qt.QWidget): """ sigROISignal = qt.Signal(object) - """Deprecated signal for backward compatibility with silx < 0.7. - Prefer connecting directly to :attr:`CurvesRoiWidget.sigRoiSignal` - """ def __init__(self, parent=None, name=None, plot=None): super(CurvesROIWidget, self).__init__(parent) if name is not None: self.setWindowTitle(name) + self.__lastSigROISignal = None + """Store the last value emitted for the sigRoiSignal. In the case the + active curve change we need to add this extra step in order to make + sure we won't send twice the sigROISignal. + This come from the fact sigROISignal is connected to the + activeROIChanged signal which is emitted when raw and net counts + values are changing but are not embed in the sigROISignal. + """ assert plot is not None self._plotRef = weakref.ref(plot) self._showAllMarkers = False @@ -91,12 +96,12 @@ class CurvesROIWidget(qt.QWidget): layout = qt.QVBoxLayout(self) layout.setContentsMargins(0, 0, 0, 0) layout.setSpacing(0) - ############## + self.headerLabel = qt.QLabel(self) self.headerLabel.setAlignment(qt.Qt.AlignHCenter) self.setHeader() layout.addWidget(self.headerLabel) - ############## + widgetAllCheckbox = qt.QWidget(parent=self) self._showAllCheckBox = qt.QCheckBox("show all ROI", parent=widgetAllCheckbox) @@ -106,14 +111,13 @@ class CurvesROIWidget(qt.QWidget): widgetAllCheckbox.layout().addWidget(spacer) widgetAllCheckbox.layout().addWidget(self._showAllCheckBox) layout.addWidget(widgetAllCheckbox) - ############## + self.roiTable = ROITable(self, plot=plot) rheight = self.roiTable.horizontalHeader().sizeHint().height() self.roiTable.setMinimumHeight(4 * rheight) layout.addWidget(self.roiTable) self._roiFileDir = qt.QDir.home().absolutePath() self._showAllCheckBox.toggled.connect(self.roiTable.showAllMarkers) - ################# hbox = qt.QWidget(self) hboxlayout = qt.QHBoxLayout(hbox) @@ -216,7 +220,6 @@ class CurvesROIWidget(qt.QWidget): i += 1 newroi = "newroi %d" % i return newroi - roi = ROI(name=getNextRoiName()) if roi.getName() == "ICR": @@ -231,7 +234,6 @@ class CurvesROIWidget(qt.QWidget): fromdata, dummy0, todata, dummy1 = self._getAllLimits() roi.setFrom(fromdata) roi.setTo(todata) - self.roiTable.addRoi(roi) # back compatibility pymca roi signals @@ -257,7 +259,9 @@ class CurvesROIWidget(qt.QWidget): def _reset(self): """Reset button clicked handler""" self.roiTable.clear() + old = self.blockSignals(True) # avoid several sigROISignal emission self._add() + self.blockSignals(old) # back compatibility pymca roi signals ddict = {} @@ -402,7 +406,9 @@ class CurvesROIWidget(qt.QWidget): if visible: # if no ROI existing yet, add the default one if self.roiTable.rowCount() is 0: + old = self.blockSignals(True) # avoid several sigROISignal emission self._add() + self.blockSignals(old) self.calculateRois() def fillFromROIDict(self, *args, **kwargs): @@ -416,7 +422,10 @@ class CurvesROIWidget(qt.QWidget): ddict['current'] = self.roiTable.activeRoi.getName() else: ddict['current'] = None - self.sigROISignal.emit(ddict) + + if self.__lastSigROISignal != ddict: + self.__lastSigROISignal = ddict + self.sigROISignal.emit(ddict) @property def currentRoi(self): @@ -563,8 +572,11 @@ class ROITable(TableWidget): # backward compatibility since 0.10.0 if isinstance(rois, dict): for roiName, roi in rois.items(): - roi['name'] = roiName - _roi = ROI._fromDict(roi) + if isinstance(roi, ROI): + _roi = roi + else: + roi['name'] = roiName + _roi = ROI._fromDict(roi) self.addRoi(_roi) else: for roi in rois: @@ -688,12 +700,14 @@ class ROITable(TableWidget): activeItems = self.selectedItems() if len(activeItems) is 0: return + old = self.blockSignals(True) # avoid several emission of sigROISignal roiToRm = set() for item in activeItems: row = item.row() itemID = self.item(row, self.COLUMNS_INDEX['ID']) roiToRm.add(self._roiDict[int(itemID.text())]) [self.removeROI(roi) for roi in roiToRm] + self.blockSignals(old) self.setActiveRoi(None) def removeROI(self, roi): @@ -730,7 +744,10 @@ class ROITable(TableWidget): else: assert isinstance(roi, ROI) if roi and roi.getID() in self._roiToItems.keys(): + # avoid several call back to setActiveROI + old = self.blockSignals(True) self.selectRow(self._roiToItems[roi.getID()].row()) + self.blockSignals(old) self._markersHandler.setActiveRoi(roi) self.activeROIChanged.emit() @@ -931,9 +948,12 @@ class ROITable(TableWidget): if ddict['event'] == 'markerMoved': label = ddict['label'] roiID = self._markersHandler.getRoiID(markerID=label) - if roiID: + if roiID is not None: + # avoid several emission of sigROISignal + old = self.blockSignals(True) self._markersHandler.changePosition(markerID=label, x=ddict['x']) + self.blockSignals(old) self._updateRoiInfo(roiID) def showEvent(self, event): diff --git a/silx/gui/plot/PlotToolButtons.py b/silx/gui/plot/PlotToolButtons.py index bf6b8ce..cd1a43f 100644 --- a/silx/gui/plot/PlotToolButtons.py +++ b/silx/gui/plot/PlotToolButtons.py @@ -47,7 +47,7 @@ from .. import icons from .. import qt from ... import config -from .items import SymbolMixIn +from .items import SymbolMixIn, Scatter _logger = logging.getLogger(__name__) @@ -352,23 +352,22 @@ class ProfileToolButton(PlotToolButton): self._profileDimensionChanged(2) -class SymbolToolButton(PlotToolButton): - """A tool button with a drop-down menu to control symbol size and marker. + +class _SymbolToolButtonBase(PlotToolButton): + """Base class for PlotToolButton setting marker and size. :param parent: See QWidget :param plot: The `~silx.gui.plot.PlotWidget` to control """ def __init__(self, parent=None, plot=None): - super(SymbolToolButton, self).__init__(parent=parent, plot=plot) + super(_SymbolToolButtonBase, self).__init__(parent=parent, plot=plot) - self.setToolTip('Set symbol size and marker') - self.setIcon(icons.getQIcon('plot-symbols')) - - menu = qt.QMenu(self) - - # Size slider + def _addSizeSliderToMenu(self, menu): + """Add a slider to set size to the given menu + :param QMenu menu: + """ slider = qt.QSlider(qt.Qt.Horizontal) slider.setRange(1, 20) slider.setValue(config.DEFAULT_PLOT_SYMBOL_SIZE) @@ -378,10 +377,11 @@ class SymbolToolButton(PlotToolButton): widgetAction.setDefaultWidget(slider) menu.addAction(widgetAction) - menu.addSeparator() - - # Marker actions + def _addSymbolsToMenu(self, menu): + """Add symbols to the given menu + :param QMenu menu: + """ for marker, name in zip(SymbolMixIn.getSupportedSymbols(), SymbolMixIn.getSupportedSymbolNames()): action = qt.QAction(name, menu) @@ -390,9 +390,6 @@ class SymbolToolButton(PlotToolButton): functools.partial(self._markerChanged, marker)) menu.addAction(action) - self.setMenu(menu) - self.setPopupMode(qt.QToolButton.InstantPopup) - def _sizeChanged(self, value): """Manage slider value changed @@ -418,3 +415,78 @@ class SymbolToolButton(PlotToolButton): for item in plot._getItems(withhidden=True): if isinstance(item, SymbolMixIn): item.setSymbol(marker) + + +class SymbolToolButton(_SymbolToolButtonBase): + """A tool button with a drop-down menu to control symbol size and marker. + + :param parent: See QWidget + :param plot: The `~silx.gui.plot.PlotWidget` to control + """ + + def __init__(self, parent=None, plot=None): + super(SymbolToolButton, self).__init__(parent=parent, plot=plot) + + self.setToolTip('Set symbol size and marker') + self.setIcon(icons.getQIcon('plot-symbols')) + + menu = qt.QMenu(self) + self._addSizeSliderToMenu(menu) + menu.addSeparator() + self._addSymbolsToMenu(menu) + + self.setMenu(menu) + self.setPopupMode(qt.QToolButton.InstantPopup) + + +class ScatterVisualizationToolButton(_SymbolToolButtonBase): + """QToolButton to select the visualization mode of scatter plot + + :param parent: See QWidget + :param plot: The `~silx.gui.plot.PlotWidget` to control + """ + + def __init__(self, parent=None, plot=None): + super(ScatterVisualizationToolButton, self).__init__( + parent=parent, plot=plot) + + self.setToolTip( + 'Set scatter visualization mode, symbol marker and size') + self.setIcon(icons.getQIcon('eye')) + + menu = qt.QMenu(self) + + # Add visualization modes + + for mode in Scatter.supportedVisualizations(): + name = mode.value.capitalize() + action = qt.QAction(name, menu) + action.setCheckable(False) + action.triggered.connect( + functools.partial(self._visualizationChanged, mode)) + menu.addAction(action) + + menu.addSeparator() + + submenu = menu.addMenu(icons.getQIcon('plot-symbols'), "Symbol") + self._addSymbolsToMenu(submenu) + + submenu = menu.addMenu(icons.getQIcon('plot-symbols'), "Symbol Size") + self._addSizeSliderToMenu(submenu) + + self.setMenu(menu) + self.setPopupMode(qt.QToolButton.InstantPopup) + + def _visualizationChanged(self, mode): + """Handle change of visualization mode. + + :param ScatterVisualizationMixIn.Visualization mode: + The visualization mode to use for scatter + """ + plot = self.plot() + if plot is None: + return + + for item in plot._getItems(withhidden=True): + if isinstance(item, Scatter): + item.setVisualization(mode) diff --git a/silx/gui/plot/PlotWidget.py b/silx/gui/plot/PlotWidget.py index cfe39fa..9b9b4d2 100644 --- a/silx/gui/plot/PlotWidget.py +++ b/silx/gui/plot/PlotWidget.py @@ -33,12 +33,20 @@ __authors__ = ["V.A. Sole", "T. Vincent"] __license__ = "MIT" __date__ = "21/12/2018" +import logging + +_logger = logging.getLogger(__name__) + from collections import OrderedDict, namedtuple +try: + from collections import abc +except ImportError: # Python2 support + import collections as abc from contextlib import contextmanager import datetime as dt import itertools -import logging +import warnings import numpy @@ -46,8 +54,11 @@ import silx from silx.utils.weakref import WeakMethodProxy from silx.utils.property import classproperty from silx.utils.deprecation import deprecated -# Import matplotlib backend here to init matplotlib our way -from .backends.BackendMatplotlib import BackendMatplotlibQt +try: + # Import matplotlib now to init matplotlib our way + from . import matplotlib +except ImportError: + _logger.debug("matplotlib not available") from ..colors import Colormap from .. import colors @@ -64,7 +75,6 @@ from .. import qt from ._utils.panzoom import ViewConstraints from ...gui.plot._utils.dtime_ticklayout import timestamp -_logger = logging.getLogger(__name__) _COLORDICT = colors.COLORDICT @@ -287,33 +297,68 @@ class PlotWidget(qt.QMainWindow): self._foregroundColorsUpdated() self._backgroundColorsUpdated() + def __getBackendClass(self, backend): + """Returns backend class corresponding to backend. + + If multiple backends are provided, the first available one is used. + + :param Union[str,BackendBase,Iterable] backend: + The name of the backend or its class or an iterable of those. + :rtype: BackendBase + :raise ValueError: In case the backend is not supported + :raise RuntimeError: If a backend is not available + """ + if callable(backend): + return backend + + elif isinstance(backend, str): + backend = backend.lower() + if backend in ('matplotlib', 'mpl'): + try: + from .backends.BackendMatplotlib import \ + BackendMatplotlibQt as backendClass + except ImportError: + _logger.debug("Backtrace", exc_info=True) + raise ImportError("matplotlib backend is not available") + + elif backend in ('gl', 'opengl'): + try: + from .backends.BackendOpenGL import \ + BackendOpenGL as backendClass + except ImportError: + _logger.debug("Backtrace", exc_info=True) + raise ImportError("OpenGL backend is not available") + + elif backend == 'none': + from .backends.BackendBase import BackendBase as backendClass + + else: + raise ValueError("Backend not supported %s" % backend) + + return backendClass + + elif isinstance(backend, abc.Iterable): + for b in backend: + try: + return self.__getBackendClass(b) + except ImportError: + pass + else: # No backend was found + raise ValueError("No supported backend was found") + + raise ValueError("Backend not supported %s" % str(backend)) + def _setBackend(self, backend): - """Setup a new backend""" + """Setup a new backend + + :param backend: Either a str defining the backend to use + """ assert(self._backend is None) if backend is None: backend = silx.config.DEFAULT_PLOT_BACKEND - if hasattr(backend, "__call__"): - backend = backend(self, self) - - elif hasattr(backend, "lower"): - lowerCaseString = backend.lower() - if lowerCaseString in ("matplotlib", "mpl"): - backendClass = BackendMatplotlibQt - elif lowerCaseString in ('gl', 'opengl'): - from .backends.BackendOpenGL import BackendOpenGL - backendClass = BackendOpenGL - elif lowerCaseString == 'none': - from .backends.BackendBase import BackendBase as backendClass - else: - raise ValueError("Backend not supported %s" % backend) - backend = backendClass(self, self) - - else: - raise ValueError("Backend not supported %s" % str(backend)) - - self._backend = backend + self._backend = self.__getBackendClass(backend)(self, self) # TODO: Can be removed for silx 0.10 @staticmethod @@ -456,7 +501,7 @@ class PlotWidget(qt.QMainWindow): return qt.QColor.fromRgbF(*self._dataBackgroundColor) def setDataBackgroundColor(self, color): - """Set the background color of this widget. + """Set the background color of the plot area. Set to None or an invalid QColor to use the background color. @@ -499,16 +544,25 @@ class PlotWidget(qt.QMainWindow): if item.isVisible(): bounds = item.getBounds() if bounds is not None: - xMin = numpy.nanmin([xMin, bounds[0]]) - xMax = numpy.nanmax([xMax, bounds[1]]) + with warnings.catch_warnings(): + warnings.simplefilter('ignore', category=RuntimeWarning) + # Ignore All-NaN slice encountered + xMin = numpy.nanmin([xMin, bounds[0]]) + xMax = numpy.nanmax([xMax, bounds[1]]) # Take care of right axis if (isinstance(item, items.YAxisMixIn) and item.getYAxis() == 'right'): - yMinRight = numpy.nanmin([yMinRight, bounds[2]]) - yMaxRight = numpy.nanmax([yMaxRight, bounds[3]]) + with warnings.catch_warnings(): + warnings.simplefilter('ignore', category=RuntimeWarning) + # Ignore All-NaN slice encountered + yMinRight = numpy.nanmin([yMinRight, bounds[2]]) + yMaxRight = numpy.nanmax([yMaxRight, bounds[3]]) else: - yMinLeft = numpy.nanmin([yMinLeft, bounds[2]]) - yMaxLeft = numpy.nanmax([yMaxLeft, bounds[3]]) + with warnings.catch_warnings(): + warnings.simplefilter('ignore', category=RuntimeWarning) + # Ignore All-NaN slice encountered + yMinLeft = numpy.nanmin([yMinLeft, bounds[2]]) + yMaxLeft = numpy.nanmax([yMaxLeft, bounds[3]]) def lGetRange(x, y): return None if numpy.isnan(x) and numpy.isnan(y) else (x, y) @@ -2665,9 +2719,11 @@ class PlotWidget(qt.QMainWindow): xmin, xmax = (1., 100.) if ranges.x is None else ranges.x ymin, ymax = (1., 100.) if ranges.y is None else ranges.y if ranges.yright is None: - ymin2, ymax2 = None, None + ymin2, ymax2 = ymin, ymax else: ymin2, ymax2 = ranges.yright + if ranges.y is None: + ymin, ymax = ranges.yright # Add margins around data inside the plot area newLimits = list(_utils.addMarginsToLimits( diff --git a/silx/gui/plot/PlotWindow.py b/silx/gui/plot/PlotWindow.py index b44a512..a39430e 100644 --- a/silx/gui/plot/PlotWindow.py +++ b/silx/gui/plot/PlotWindow.py @@ -29,15 +29,19 @@ The :class:`PlotWindow` is a subclass of :class:`.PlotWidget`. __authors__ = ["V.A. Sole", "T. Vincent"] __license__ = "MIT" -__date__ = "21/12/2018" +__date__ = "12/04/2019" -import collections +try: + from collections import abc +except ImportError: # Python2 support + import collections as abc import logging import weakref import silx from silx.utils.weakref import WeakMethodProxy from silx.utils.deprecation import deprecated +from silx.utils.proxy import docstring from . import PlotWidget from . import actions @@ -128,53 +132,53 @@ class PlotWindow(PlotWidget): self.group.setExclusive(False) self.resetZoomAction = self.group.addAction( - actions.control.ResetZoomAction(self)) + actions.control.ResetZoomAction(self, parent=self)) self.resetZoomAction.setVisible(resetzoom) self.addAction(self.resetZoomAction) - self.zoomInAction = actions.control.ZoomInAction(self) + self.zoomInAction = actions.control.ZoomInAction(self, parent=self) self.addAction(self.zoomInAction) - self.zoomOutAction = actions.control.ZoomOutAction(self) + self.zoomOutAction = actions.control.ZoomOutAction(self, parent=self) self.addAction(self.zoomOutAction) self.xAxisAutoScaleAction = self.group.addAction( - actions.control.XAxisAutoScaleAction(self)) + actions.control.XAxisAutoScaleAction(self, parent=self)) self.xAxisAutoScaleAction.setVisible(autoScale) self.addAction(self.xAxisAutoScaleAction) self.yAxisAutoScaleAction = self.group.addAction( - actions.control.YAxisAutoScaleAction(self)) + actions.control.YAxisAutoScaleAction(self, parent=self)) self.yAxisAutoScaleAction.setVisible(autoScale) self.addAction(self.yAxisAutoScaleAction) self.xAxisLogarithmicAction = self.group.addAction( - actions.control.XAxisLogarithmicAction(self)) + actions.control.XAxisLogarithmicAction(self, parent=self)) self.xAxisLogarithmicAction.setVisible(logScale) self.addAction(self.xAxisLogarithmicAction) self.yAxisLogarithmicAction = self.group.addAction( - actions.control.YAxisLogarithmicAction(self)) + actions.control.YAxisLogarithmicAction(self, parent=self)) self.yAxisLogarithmicAction.setVisible(logScale) self.addAction(self.yAxisLogarithmicAction) self.gridAction = self.group.addAction( - actions.control.GridAction(self, gridMode='both')) + actions.control.GridAction(self, gridMode='both', parent=self)) self.gridAction.setVisible(grid) self.addAction(self.gridAction) self.curveStyleAction = self.group.addAction( - actions.control.CurveStyleAction(self)) + actions.control.CurveStyleAction(self, parent=self)) self.curveStyleAction.setVisible(curveStyle) self.addAction(self.curveStyleAction) self.colormapAction = self.group.addAction( - actions.control.ColormapAction(self)) + actions.control.ColormapAction(self, parent=self)) self.colormapAction.setVisible(colormap) self.addAction(self.colormapAction) self.colorbarAction = self.group.addAction( - actions_control.ColorBarAction(self, self)) + actions_control.ColorBarAction(self, parent=self)) self.colorbarAction.setVisible(False) self.addAction(self.colorbarAction) self._colorbar.setVisible(False) @@ -194,18 +198,18 @@ class PlotWindow(PlotWidget): self.getMaskAction().setVisible(mask) self._intensityHistoAction = self.group.addAction( - actions_histogram.PixelIntensitiesHistoAction(self)) + actions_histogram.PixelIntensitiesHistoAction(self, parent=self)) self._intensityHistoAction.setVisible(False) self._medianFilter2DAction = self.group.addAction( - actions_medfilt.MedianFilter2DAction(self)) + actions_medfilt.MedianFilter2DAction(self, parent=self)) self._medianFilter2DAction.setVisible(False) self._medianFilter1DAction = self.group.addAction( - actions_medfilt.MedianFilter1DAction(self)) + actions_medfilt.MedianFilter1DAction(self, parent=self)) self._medianFilter1DAction.setVisible(False) - self.fitAction = self.group.addAction(actions_fit.FitAction(self)) + self.fitAction = self.group.addAction(actions_fit.FitAction(self, parent=self)) self.fitAction.setVisible(fit) self.addAction(self.fitAction) @@ -250,7 +254,7 @@ class PlotWindow(PlotWidget): hbox.addWidget(self.controlButton) if position: # Add PositionInfo widget to the bottom of the plot - if isinstance(position, collections.Iterable): + if isinstance(position, abc.Iterable): # Use position as a set of converters converters = position else: @@ -278,7 +282,7 @@ class PlotWindow(PlotWidget): parent=self, plot=self) self.addToolBar(self._interactiveModeToolBar) - self._toolbar = self._createToolBar(title='Plot', parent=None) + self._toolbar = self._createToolBar(title='Plot', parent=self) self.addToolBar(self._toolbar) self._outputToolBar = tools.OutputToolBar(parent=self, plot=self) @@ -292,24 +296,21 @@ class PlotWindow(PlotWidget): for action in toolbar.actions(): self.addAction(action) + @docstring(PlotWidget) def setBackgroundColor(self, color): super(PlotWindow, self).setBackgroundColor(color) self._updateColorBarBackground() - setBackgroundColor.__doc__ = PlotWidget.setBackgroundColor.__doc__ - + @docstring(PlotWidget) def setDataBackgroundColor(self, color): super(PlotWindow, self).setDataBackgroundColor(color) self._updateColorBarBackground() - setDataBackgroundColor.__doc__ = PlotWidget.setDataBackgroundColor.__doc__ - + @docstring(PlotWidget) def setForegroundColor(self, color): super(PlotWindow, self).setForegroundColor(color) self._updateColorBarBackground() - setForegroundColor.__doc__ = PlotWidget.setForegroundColor.__doc__ - def _updateColorBarBackground(self): """Update the colorbar background according to the state of the plot""" if self._isAxesDisplayed(): @@ -824,7 +825,9 @@ class Plot2D(PlotWindow): posInfo = [ ('X', lambda x, y: x), ('Y', lambda x, y: y), - ('Data', WeakMethodProxy(self._getImageValue))] + ('Data', WeakMethodProxy(self._getImageValue)), + ('Dims', WeakMethodProxy(self._getImageDims)), + ] super(Plot2D, self).__init__(parent=parent, backend=backend, resetzoom=True, autoScale=False, @@ -924,6 +927,15 @@ class Plot2D(PlotWindow): return value, "Masked" return value + def _getImageDims(self, *args): + activeImage = self.getActiveImage() + if (activeImage is not None and + activeImage.getData(copy=False) is not None): + dims = activeImage.getData(copy=False).shape[1::-1] + return 'x'.join(str(dim) for dim in dims) + else: + return '-' + def getProfileToolbar(self): """Profile tools attached to this plot diff --git a/silx/gui/plot/Profile.py b/silx/gui/plot/Profile.py index 46e4523..e2aa5a7 100644 --- a/silx/gui/plot/Profile.py +++ b/silx/gui/plot/Profile.py @@ -28,7 +28,7 @@ and stacks of images""" __authors__ = ["V.A. Sole", "T. Vincent", "P. Knobel", "H. Payno"] __license__ = "MIT" -__date__ = "24/07/2018" +__date__ = "12/04/2019" import weakref @@ -419,39 +419,40 @@ class ProfileToolBar(qt.QToolBar): """ if self._profileWindow is None: - self._profileMainWindow = ProfileMainWindow(self) + backend = type(plot._backend) + self._profileMainWindow = ProfileMainWindow(self, backend=backend) # Actions self._browseAction = actions.mode.ZoomModeAction(self.plot, parent=self) self._browseAction.setVisible(False) - self.hLineAction = qt.QAction( - icons.getQIcon('shape-horizontal'), - 'Horizontal Profile Mode', None) + self.hLineAction = qt.QAction(icons.getQIcon('shape-horizontal'), + 'Horizontal Profile Mode', + self) self.hLineAction.setToolTip( 'Enables horizontal profile selection mode') self.hLineAction.setCheckable(True) self.hLineAction.toggled[bool].connect(self._hLineActionToggled) - self.vLineAction = qt.QAction( - icons.getQIcon('shape-vertical'), - 'Vertical Profile Mode', None) + self.vLineAction = qt.QAction(icons.getQIcon('shape-vertical'), + 'Vertical Profile Mode', + self) self.vLineAction.setToolTip( 'Enables vertical profile selection mode') self.vLineAction.setCheckable(True) self.vLineAction.toggled[bool].connect(self._vLineActionToggled) - self.lineAction = qt.QAction( - icons.getQIcon('shape-diagonal'), - 'Free Line Profile Mode', None) + self.lineAction = qt.QAction(icons.getQIcon('shape-diagonal'), + 'Free Line Profile Mode', + self) self.lineAction.setToolTip( 'Enables line profile selection mode') self.lineAction.setCheckable(True) self.lineAction.toggled[bool].connect(self._lineActionToggled) - self.clearAction = qt.QAction( - icons.getQIcon('profile-clear'), - 'Clear Profile', None) + self.clearAction = qt.QAction(icons.getQIcon('profile-clear'), + 'Clear Profile', + self) self.clearAction.setToolTip( 'Clear the profile Region of interest') self.clearAction.setCheckable(False) diff --git a/silx/gui/plot/ProfileMainWindow.py b/silx/gui/plot/ProfileMainWindow.py index caa076c..39830d8 100644 --- a/silx/gui/plot/ProfileMainWindow.py +++ b/silx/gui/plot/ProfileMainWindow.py @@ -35,8 +35,15 @@ __date__ = "21/02/2017" class ProfileMainWindow(qt.QMainWindow): """QMainWindow providing 2 plot widgets specialized in 1D and 2D plotting, with different toolbars. + Only one of the plots is visible at any given time. + + :param qt.QWidget parent: The parent of this widget or None (default). + :param Union[str,Class] backend: The backend to use, in: + 'matplotlib' (default), 'mpl', 'opengl', 'gl', 'none' + or a :class:`BackendBase.BackendBase` class """ + sigProfileDimensionsChanged = qt.Signal(int) """This signal is emitted when :meth:`setProfileDimensions` is called. It carries the number of dimensions for the profile data (1 or 2). @@ -51,13 +58,14 @@ class ProfileMainWindow(qt.QMainWindow): """Emitted when the method to compute the profile changed (for now can be sum or mean)""" - def __init__(self, parent=None): + def __init__(self, parent=None, backend=None): qt.QMainWindow.__init__(self, parent=parent) self.setWindowTitle('Profile window') # plots are created on demand, in self.setProfileDimensions() self._plot1D = None self._plot2D = None + self._backend = backend # by default, profile is assumed to be a 1D curve self._profileType = None self.setProfileType("1D") @@ -76,7 +84,7 @@ class ProfileMainWindow(qt.QMainWindow): if self._plot2D is not None: self._plot2D.setParent(None) # necessary to avoid widget destruction if self._plot1D is None: - self._plot1D = Plot1D() + self._plot1D = Plot1D(backend=self._backend) self._plot1D.setGraphYLabel('Profile') self._plot1D.setGraphXLabel('') self.setCentralWidget(self._plot1D) @@ -84,7 +92,7 @@ class ProfileMainWindow(qt.QMainWindow): if self._plot1D is not None: self._plot1D.setParent(None) # necessary to avoid widget destruction if self._plot2D is None: - self._plot2D = Plot2D() + self._plot2D = Plot2D(backend=self._backend) self.setCentralWidget(self._plot2D) else: raise ValueError("Profile type must be '1D' or '2D'") diff --git a/silx/gui/plot/ScatterView.py b/silx/gui/plot/ScatterView.py index 5fc66ef..1d015d4 100644 --- a/silx/gui/plot/ScatterView.py +++ b/silx/gui/plot/ScatterView.py @@ -47,6 +47,8 @@ from .ScatterMaskToolsWidget import ScatterMaskToolsWidget from ..widgets.BoxLayoutDockWidget import BoxLayoutDockWidget from .. import qt, icons +from ...utils.proxy import docstring +from ...utils.weakref import WeakMethodProxy _logger = logging.getLogger(__name__) @@ -92,10 +94,10 @@ class ScatterView(qt.QMainWindow): self.__pickingCache = None self._positionInfo = tools.PositionInfo( plot=plot, - converters=(('X', lambda x, y: x), - ('Y', lambda x, y: y), - ('Data', lambda x, y: self._getScatterValue(x, y)), - ('Index', lambda x, y: self._getScatterIndex(x, y)))) + converters=(('X', WeakMethodProxy(self._getPickedX)), + ('Y', WeakMethodProxy(self._getPickedY)), + ('Data', WeakMethodProxy(self._getPickedValue)), + ('Index', WeakMethodProxy(self._getPickedIndex)))) # Combine plot, position info and colorbar into central widget gridLayout = qt.QGridLayout() @@ -167,32 +169,52 @@ class ScatterView(qt.QMainWindow): dataIndex = indices[-1] self.__pickingCache = ( dataIndex, + item.getXData(copy=False)[dataIndex], + item.getYData(copy=False)[dataIndex], item.getValueData(copy=False)[dataIndex]) break return self.__pickingCache - def _getScatterValue(self, x, y): - """Get data value of top most scatter plot at position (x, y) + def _getPickedIndex(self, x, y): + """Get data index of top most scatter plot at position (x, y) :param float x: X position in plot coordinates :param float y: Y position in plot coordinates - :return: The data value at that point or '-' + :return: The data index at that point or '-' """ picking = self._pickScatterData(x, y) - return '-' if picking is None else picking[1] + return '-' if picking is None else picking[0] - def _getScatterIndex(self, x, y): - """Get data index of top most scatter plot at position (x, y) + def _getPickedX(self, x, y): + """Returns X position snapped to scatter plot when close enough + + :param float x: + :param float y: + :rtype: float + """ + picking = self._pickScatterData(x, y) + return x if picking is None else picking[1] + + def _getPickedY(self, x, y): + """Returns Y position snapped to scatter plot when close enough + + :param float x: + :param float y: + :rtype: float + """ + picking = self._pickScatterData(x, y) + return y if picking is None else picking[2] + + def _getPickedValue(self, x, y): + """Get data value of top most scatter plot at position (x, y) :param float x: X position in plot coordinates :param float y: Y position in plot coordinates - :return: The data index at that point or '-' + :return: The data value at that point or '-' """ picking = self._pickScatterData(x, y) - return '-' if picking is None else picking[0] - - _PICK_OFFSET = 3 # Offset in pixel used for picking + return '-' if picking is None else picking[3] def _mouseInPlotArea(self, x, y): """Clip mouse coordinates to plot area coordinates @@ -307,11 +329,10 @@ class ScatterView(qt.QMainWindow): self.getScatterItem().setData( x=x, y=y, value=value, xerror=xerror, yerror=yerror, alpha=alpha, copy=copy) + @docstring(items.Scatter) def getData(self, *args, **kwargs): return self.getScatterItem().getData(*args, **kwargs) - getData.__doc__ = items.Scatter.getData.__doc__ - def getScatterItem(self): """Returns the plot item displaying the scatter data. @@ -329,37 +350,30 @@ class ScatterView(qt.QMainWindow): # Convenient proxies + @docstring(PlotWidget) def getXAxis(self, *args, **kwargs): return self.getPlotWidget().getXAxis(*args, **kwargs) - getXAxis.__doc__ = PlotWidget.getXAxis.__doc__ - + @docstring(PlotWidget) def getYAxis(self, *args, **kwargs): return self.getPlotWidget().getYAxis(*args, **kwargs) - getYAxis.__doc__ = PlotWidget.getYAxis.__doc__ - + @docstring(PlotWidget) def setGraphTitle(self, *args, **kwargs): return self.getPlotWidget().setGraphTitle(*args, **kwargs) - setGraphTitle.__doc__ = PlotWidget.setGraphTitle.__doc__ - + @docstring(PlotWidget) def getGraphTitle(self, *args, **kwargs): return self.getPlotWidget().getGraphTitle(*args, **kwargs) - getGraphTitle.__doc__ = PlotWidget.getGraphTitle.__doc__ - + @docstring(PlotWidget) def resetZoom(self, *args, **kwargs): return self.getPlotWidget().resetZoom(*args, **kwargs) - resetZoom.__doc__ = PlotWidget.resetZoom.__doc__ - + @docstring(ScatterMaskToolsWidget) def getSelectionMask(self, *args, **kwargs): return self.getMaskToolsWidget().getSelectionMask(*args, **kwargs) - getSelectionMask.__doc__ = ScatterMaskToolsWidget.getSelectionMask.__doc__ - + @docstring(ScatterMaskToolsWidget) def setSelectionMask(self, *args, **kwargs): return self.getMaskToolsWidget().setSelectionMask(*args, **kwargs) - - setSelectionMask.__doc__ = ScatterMaskToolsWidget.setSelectionMask.__doc__ diff --git a/silx/gui/plot/StatsWidget.py b/silx/gui/plot/StatsWidget.py index 4ba4fab..5e2dc58 100644 --- a/silx/gui/plot/StatsWidget.py +++ b/silx/gui/plot/StatsWidget.py @@ -35,9 +35,11 @@ from collections import OrderedDict from contextlib import contextmanager import logging import weakref - +import functools import numpy - +import enum +from silx.utils.proxy import docstring +from silx.utils.enum import Enum as _Enum from silx.gui import qt from silx.gui import icons from silx.gui.plot import stats as statsmdl @@ -52,8 +54,15 @@ from . import items as plotitems _logger = logging.getLogger(__name__) +@enum.unique +class UpdateMode(_Enum): + AUTO = 'auto' + MANUAL = 'manual' + + # Helper class to handle specific calls to PlotWidget and SceneWidget + class _Wrapper(qt.QObject): """Base class for connection with PlotWidget and SceneWidget. @@ -319,10 +328,12 @@ class _StatsWidgetBase(object): """ Base class for all widgets which want to display statistics """ + def __init__(self, statsOnVisibleData, displayOnlyActItem): self._displayOnlyActItem = displayOnlyActItem self._statsOnVisibleData = statsOnVisibleData self._statsHandler = None + self._updateMode = UpdateMode.AUTO self.__default_skipped_events = ( ItemChangedType.ALPHA, @@ -503,6 +514,29 @@ class _StatsWidgetBase(object): """ return event in self.__default_skipped_events + def setUpdateMode(self, mode): + """Set the way to update the displayed statistics. + + :param mode: mode requested for update + :type mode: Union[str,UpdateMode] + """ + mode = UpdateMode.from_value(mode) + if mode != self._updateMode: + self._updateMode = mode + self._updateModeHasChanged() + + def getUpdateMode(self): + """Returns update mode (See :meth:`setUpdateMode`). + + :return: update mode + :rtype: UpdateMode + """ + return self._updateMode + + def _updateModeHasChanged(self): + """callback when the update mode has changed""" + pass + class StatsTable(_StatsWidgetBase, TableWidget): """ @@ -522,6 +556,9 @@ class StatsTable(_StatsWidgetBase, TableWidget): _LEGEND_HEADER_DATA = 'legend' _KIND_HEADER_DATA = 'kind' + sigUpdateModeChanged = qt.Signal(object) + """Signal emitted when the update mode changed""" + def __init__(self, parent=None, plot=None): TableWidget.__init__(self, parent) _StatsWidgetBase.__init__(self, statsOnVisibleData=False, @@ -606,6 +643,8 @@ class StatsTable(_StatsWidgetBase, TableWidget): def _updateItemObserve(self, *args): """Reload table depending on mode""" + if self.getUpdateMode() is UpdateMode.MANUAL: + return self._removeAllItems() # Get selected or all items from the plot @@ -678,11 +717,19 @@ class StatsTable(_StatsWidgetBase, TableWidget): :param event: """ + if self.getUpdateMode() is UpdateMode.MANUAL: + return if self._skipPlotItemChangedEvent(event) is True: return else: item = self.sender() self._updateStats(item) + # deal with stat items visibility + if event is ItemChangedType.VISIBLE: + if len(self._itemToTableItems(item).items()) > 0: + item_0 = list(self._itemToTableItems(item).values())[0] + row_index = item_0.row() + self.setRowHidden(row_index, not item.isVisible()) def _addItem(self, item): """Add a plot item to the table @@ -810,8 +857,13 @@ class StatsTable(_StatsWidgetBase, TableWidget): else: tableItem.setText(str(value)) - def _updateAllStats(self): - """Update stats for all rows in the table""" + def _updateAllStats(self, is_request=False): + """Update stats for all rows in the table + + :param bool is_request: True if come from a manual request + """ + if self.getUpdateMode() is UpdateMode.MANUAL and not is_request: + return with self._disableSorting(): for row in range(self.rowCount()): tableItem = self.item(row, 0) @@ -851,10 +903,103 @@ class StatsTable(_StatsWidgetBase, TableWidget): else: self.setSelectionMode(qt.QAbstractItemView.NoSelection) + def _updateModeHasChanged(self): + self.sigUpdateModeChanged.emit(self._updateMode) -class _OptionsWidget(qt.QToolBar): + +class UpdateModeWidget(qt.QWidget): + """Widget used to select the mode of update""" + sigUpdateModeChanged = qt.Signal(object) + """signal emitted when the mode for update changed""" + sigUpdateRequested = qt.Signal() + """signal emitted when an manual request for example is activate""" def __init__(self, parent=None): + qt.QWidget.__init__(self, parent) + self.setLayout(qt.QHBoxLayout()) + self._buttonGrp = qt.QButtonGroup(parent=self) + self._buttonGrp.setExclusive(True) + + spacer = qt.QSpacerItem(20, 20, + qt.QSizePolicy.Expanding, + qt.QSizePolicy.Minimum) + self.layout().addItem(spacer) + + self._autoRB = qt.QRadioButton('auto', parent=self) + self.layout().addWidget(self._autoRB) + self._buttonGrp.addButton(self._autoRB) + + self._manualRB = qt.QRadioButton('manual', parent=self) + self.layout().addWidget(self._manualRB) + self._buttonGrp.addButton(self._manualRB) + self._manualRB.setChecked(True) + + refresh_icon = icons.getQIcon('view-refresh') + self._updatePB = qt.QPushButton(refresh_icon, '', parent=self) + self.layout().addWidget(self._updatePB) + + # connect signal / SLOT + self._updatePB.clicked.connect(self._updateRequested) + self._manualRB.toggled.connect(self._manualButtonToggled) + self._autoRB.toggled.connect(self._autoButtonToggled) + + def _manualButtonToggled(self, checked): + if checked: + self.setUpdateMode(UpdateMode.MANUAL) + self.sigUpdateModeChanged.emit(self.getUpdateMode()) + + def _autoButtonToggled(self, checked): + if checked: + self.setUpdateMode(UpdateMode.AUTO) + self.sigUpdateModeChanged.emit(self.getUpdateMode()) + + def _updateRequested(self): + if self.getUpdateMode() is UpdateMode.MANUAL: + self.sigUpdateRequested.emit() + + def setUpdateMode(self, mode): + """Set the way to update the displayed statistics. + + :param mode: mode requested for update + :type mode: Union[str,UpdateMode] + """ + mode = UpdateMode.from_value(mode) + + if mode is UpdateMode.AUTO: + if not self._autoRB.isChecked(): + self._autoRB.setChecked(True) + elif mode is UpdateMode.MANUAL: + if not self._manualRB.isChecked(): + self._manualRB.setChecked(True) + else: + raise ValueError('mode', mode, 'is not recognized') + + def getUpdateMode(self): + """Returns update mode (See :meth:`setUpdateMode`). + + :return: the active update mode + :rtype: UpdateMode + """ + if self._manualRB.isChecked(): + return UpdateMode.MANUAL + elif self._autoRB.isChecked(): + return UpdateMode.AUTO + else: + raise RuntimeError("No mode selected") + + def showRadioButtons(self, show): + """show / hide the QRadioButtons + + :param bool show: if True make RadioButton visible + """ + self._autoRB.setVisible(show) + self._manualRB.setVisible(show) + + +class _OptionsWidget(qt.QToolBar): + + def __init__(self, parent=None, updateMode=None, displayOnlyActItem=False): + assert updateMode is not None qt.QToolBar.__init__(self, parent) self.setIconSize(qt.QSize(16, 16)) @@ -863,7 +1008,7 @@ class _OptionsWidget(qt.QToolBar): action.setText("Active items only") action.setToolTip("Display stats for active items only.") action.setCheckable(True) - action.setChecked(True) + action.setChecked(displayOnlyActItem) self.__displayActiveItems = action action = qt.QAction(self) @@ -909,9 +1054,26 @@ class _OptionsWidget(qt.QToolBar): self.dataRangeSelection.addAction(self.__useWholeData) self.dataRangeSelection.addAction(self.__useVisibleData) + self.__updateStatsAction = qt.QAction(self) + self.__updateStatsAction.setIcon(icons.getQIcon("view-refresh")) + self.__updateStatsAction.setText("update statistics") + self.__updateStatsAction.setToolTip("update statistics") + self.__updateStatsAction.setCheckable(False) + self._updateStatsSep = self.addSeparator() + self.addAction(self.__updateStatsAction) + + self._setUpdateMode(mode=updateMode) + + # expose API + self.sigUpdateStats = self.__updateStatsAction.triggered + def isActiveItemMode(self): return self.itemSelection.checkedAction() is self.__displayActiveItems + def setDisplayActiveItems(self, only_active): + self.__displayActiveItems.setChecked(only_active) + self.__displayWholeItems.setChecked(not only_active) + def isVisibleDataRangeMode(self): return self.dataRangeSelection.checkedAction() is self.__useVisibleData @@ -925,6 +1087,18 @@ class _OptionsWidget(qt.QToolBar): if not enabled: self.__useWholeData.setChecked(True) + def _setUpdateMode(self, mode): + self.__updateStatsAction.setVisible(mode == UpdateMode.MANUAL) + self._updateStatsSep.setVisible(mode == UpdateMode.MANUAL) + + def getUpdateStatsAction(self): + """ + + :return: the action for the automatic mode + :rtype: QAction + """ + return self.__updateStatsAction + class StatsWidget(qt.QWidget): """ @@ -954,19 +1128,26 @@ class StatsWidget(qt.QWidget): qt.QWidget.__init__(self, parent) self.setLayout(qt.QVBoxLayout()) self.layout().setContentsMargins(0, 0, 0, 0) - self._options = _OptionsWidget(parent=self) + self._options = _OptionsWidget(parent=self, updateMode=UpdateMode.MANUAL) self.layout().addWidget(self._options) self._statsTable = StatsTable(parent=self, plot=plot) + self._statsTable.setDisplayOnlyActiveItem(self._options.isActiveItemMode()) + self._options._setUpdateMode(mode=self._statsTable.getUpdateMode()) self.setStats(stats) self.layout().addWidget(self._statsTable) + old = self._statsTable.blockSignals(True) self._options.itemSelection.triggered.connect( self._optSelectionChanged) self._options.dataRangeSelection.triggered.connect( self._optDataRangeChanged) - self._optSelectionChanged() self._optDataRangeChanged() + self._statsTable.blockSignals(old) + + self._statsTable.sigUpdateModeChanged.connect(self._options._setUpdateMode) + callback = functools.partial(self._getStatsTable()._updateAllStats, is_request=True) + self._options.sigUpdateStats.connect(callback) def _getStatsTable(self): """Returns the :class:`StatsTable` used by this widget. @@ -993,33 +1174,40 @@ class StatsWidget(qt.QWidget): # Proxy methods + @docstring(StatsTable) def setStats(self, statsHandler): return self._getStatsTable().setStats(statsHandler=statsHandler) - setStats.__doc__ = StatsTable.setStats.__doc__ - + @docstring(StatsTable) def setPlot(self, plot): self._options.setVisibleDataRangeModeEnabled( plot is None or isinstance(plot, PlotWidget)) return self._getStatsTable().setPlot(plot=plot) - setPlot.__doc__ = StatsTable.setPlot.__doc__ - + @docstring(StatsTable) def getPlot(self): return self._getStatsTable().getPlot() - getPlot.__doc__ = StatsTable.getPlot.__doc__ - + @docstring(StatsTable) def setDisplayOnlyActiveItem(self, displayOnlyActItem): + old = self._options.blockSignals(True) + # update the options + self._options.setDisplayActiveItems(displayOnlyActItem) + self._options.blockSignals(old) return self._getStatsTable().setDisplayOnlyActiveItem( displayOnlyActItem=displayOnlyActItem) - setDisplayOnlyActiveItem.__doc__ = StatsTable.setDisplayOnlyActiveItem.__doc__ - + @docstring(StatsTable) def setStatsOnVisibleData(self, b): return self._getStatsTable().setStatsOnVisibleData(b=b) - setStatsOnVisibleData.__doc__ = StatsTable.setStatsOnVisibleData.__doc__ + @docstring(StatsTable) + def getUpdateMode(self): + return self._statsTable.getUpdateMode() + + @docstring(StatsTable) + def setUpdateMode(self, mode): + self._statsTable.setUpdateMode(mode) DEFAULT_STATS = StatsHandler(( @@ -1050,13 +1238,13 @@ class BasicStatsWidget(StatsWidget): from silx.gui.plot import Plot1D from silx.gui.plot.StatsWidget import BasicStatsWidget - + plot = Plot1D() x = range(100) y = x plot.addCurve(x, y, legend='curve_0') plot.setActiveCurve('curve_0') - + widget = BasicStatsWidget(plot=plot) widget.show() """ @@ -1067,9 +1255,9 @@ class BasicStatsWidget(StatsWidget): class _BaseLineStatsWidget(_StatsWidgetBase, qt.QWidget): """ - Widget made to display stats into a QLayout with for all stat a couple - (QLabel, QLineEdit) created. - The the layout can be defined prior of adding any statistic. + Widget made to display stats into a QLayout with couple (QLabel, QLineEdit) + created for each stats. + The layout can be defined prior of adding any statistic. :param QWidget parent: Qt parent :param Union[PlotWidget,SceneWidget] plot: @@ -1081,6 +1269,9 @@ class _BaseLineStatsWidget(_StatsWidgetBase, qt.QWidget): only visible ones. """ + sigUpdateModeChanged = qt.Signal(object) + """Signal emitted when the update mode changed""" + def __init__(self, parent=None, plot=None, kind='curve', stats=None, statsOnVisibleData=False): self._item_kind = kind @@ -1141,6 +1332,8 @@ class _BaseLineStatsWidget(_StatsWidgetBase, qt.QWidget): self._updateAllStats() def _activeItemChanged(self, kind, previous, current): + if self.getUpdateMode() is UpdateMode.MANUAL: + return if kind == self._item_kind: self._updateAllStats() @@ -1148,9 +1341,9 @@ class _BaseLineStatsWidget(_StatsWidgetBase, qt.QWidget): plot = self.getPlot() if plot is not None: _items = self._plotWrapper.getSelectedItems() + def kind_filter(_item): return self._plotWrapper.getKind(_item) == self.getKind() - items = list(filter(kind_filter, _items)) assert len(items) in (0, 1) if len(items) is 1: @@ -1186,8 +1379,11 @@ class _BaseLineStatsWidget(_StatsWidgetBase, qt.QWidget): self._statQlineEdit[statName].setText(statVal) def _updateItemObserve(self, *argv): + if self.getUpdateMode() is UpdateMode.MANUAL: + return assert self._displayOnlyActItem _items = self._plotWrapper.getSelectedItems() + def kind_filter(_item): return self._plotWrapper.getKind(_item) == self.getKind() items = list(filter(kind_filter, _items)) @@ -1208,22 +1404,11 @@ class _BaseLineStatsWidget(_StatsWidgetBase, qt.QWidget): def _plotCurrentChanged(selfself, current): raise NotImplementedError('Display only the active item') + def _updateModeHasChanged(self): + self.sigUpdateModeChanged.emit(self._updateMode) -class BasicLineStatsWidget(_BaseLineStatsWidget): - """ - Widget defining a simple set of :class:`Stat` to be displayed on a - :class:`LineStatsWidget`. - - :param QWidget parent: Qt parent - :param Union[PlotWidget,SceneWidget] plot: - The plot containing items on which we want statistics. - :param str kind: the kind of plotitems we want to display - :param StatsHandler stats: - Set the statistics to be displayed and how to format them using - :param bool statsOnVisibleData: compute statistics for the whole data or - only visible ones. - """ +class _BasicLineStatsWidget(_BaseLineStatsWidget): def __init__(self, parent=None, plot=None, kind='curve', stats=DEFAULT_STATS, statsOnVisibleData=False): _BaseLineStatsWidget.__init__(self, parent=parent, kind=kind, @@ -1246,38 +1431,84 @@ class BasicLineStatsWidget(_BaseLineStatsWidget): self.layout().addWidget(widget) + def _addOptionsWidget(self, widget): + self.layout().addWidget(widget) + -class BasicGridStatsWidget(_BaseLineStatsWidget): +class BasicLineStatsWidget(qt.QWidget): """ - pymca design like widget - + Widget defining a simple set of :class:`Stat` to be displayed on a + :class:`LineStatsWidget`. + :param QWidget parent: Qt parent :param Union[PlotWidget,SceneWidget] plot: The plot containing items on which we want statistics. + :param str kind: the kind of plotitems we want to display :param StatsHandler stats: Set the statistics to be displayed and how to format them using - :param str kind: the kind of plotitems we want to display :param bool statsOnVisibleData: compute statistics for the whole data or only visible ones. - :param int statsPerLine: number of statistic to be displayed per line + """ + def __init__(self, parent=None, plot=None, kind='curve', + stats=DEFAULT_STATS, statsOnVisibleData=False): + qt.QWidget.__init__(self, parent) + self.setLayout(qt.QHBoxLayout()) + self.layout().setSpacing(0) + self.layout().setContentsMargins(0, 0, 0, 0) + self._lineStatsWidget = _BasicLineStatsWidget(parent=self, plot=plot, + kind=kind, stats=stats, + statsOnVisibleData=statsOnVisibleData) + self.layout().addWidget(self._lineStatsWidget) + + self._options = UpdateModeWidget() + self._options.setUpdateMode(self._lineStatsWidget.getUpdateMode()) + self._options.showRadioButtons(False) + self.layout().addWidget(self._options) - .. snapshotqt:: img/BasicGridStatsWidget.png - :width: 600px - :align: center + # connect Signal ? SLOT + self._lineStatsWidget.sigUpdateModeChanged.connect(self._options.setUpdateMode) + self._options.sigUpdateModeChanged.connect(self._lineStatsWidget.setUpdateMode) + self._options.sigUpdateRequested.connect(self._lineStatsWidget._updateAllStats) - from silx.gui.plot import Plot1D - from silx.gui.plot.StatsWidget import BasicGridStatsWidget - - plot = Plot1D() - x = range(100) - y = x - plot.addCurve(x, y, legend='curve_0') - plot.setActiveCurve('curve_0') - - widget = BasicGridStatsWidget(plot=plot, kind='curve') - widget.show() - """ + def showControl(self, visible): + self._options.setVisible(visible) + + # Proxy methods + + @docstring(_BasicLineStatsWidget) + def setUpdateMode(self, mode): + self._lineStatsWidget.setUpdateMode(mode=mode) + + @docstring(_BasicLineStatsWidget) + def getUpdateMode(self): + return self._lineStatsWidget.getUpdateMode() + + @docstring(_BasicLineStatsWidget) + def setPlot(self, plot): + self._lineStatsWidget.setPlot(plot=plot) + @docstring(_BasicLineStatsWidget) + def setStats(self, statsHandler): + self._lineStatsWidget.setStats(statsHandler=statsHandler) + + @docstring(_BasicLineStatsWidget) + def setKind(self, kind): + self._lineStatsWidget.setKind(kind=kind) + + @docstring(_BasicLineStatsWidget) + def getKind(self): + return self._lineStatsWidget.getKind() + + @docstring(_BasicLineStatsWidget) + def setStatsOnVisibleData(self, b): + self._lineStatsWidget.setStatsOnVisibleData(b) + + @docstring(UpdateModeWidget) + def showRadioButtons(self, show): + self._options.showRadioButtons(show=show) + + +class _BasicGridStatsWidget(_BaseLineStatsWidget): def __init__(self, parent=None, plot=None, kind='curve', stats=DEFAULT_STATS, statsOnVisibleData=False, statsPerLine=4): @@ -1294,3 +1525,94 @@ class BasicGridStatsWidget(_BaseLineStatsWidget): def _createLayout(self): return qt.QGridLayout() + + +class BasicGridStatsWidget(qt.QWidget): + """ + pymca design like widget + + :param QWidget parent: Qt parent + :param Union[PlotWidget,SceneWidget] plot: + The plot containing items on which we want statistics. + :param StatsHandler stats: + Set the statistics to be displayed and how to format them using + :param str kind: the kind of plotitems we want to display + :param bool statsOnVisibleData: compute statistics for the whole data or + only visible ones. + :param int statsPerLine: number of statistic to be displayed per line + + .. snapshotqt:: img/_BasicGridStatsWidget.png + :width: 600px + :align: center + + from silx.gui.plot import Plot1D + from silx.gui.plot.StatsWidget import _BasicGridStatsWidget + + plot = Plot1D() + x = range(100) + y = x + plot.addCurve(x, y, legend='curve_0') + plot.setActiveCurve('curve_0') + + widget = _BasicGridStatsWidget(plot=plot, kind='curve') + widget.show() + """ + + def __init__(self, parent=None, plot=None, kind='curve', + stats=DEFAULT_STATS, statsOnVisibleData=False): + qt.QWidget.__init__(self, parent) + self.setLayout(qt.QVBoxLayout()) + self.layout().setSpacing(0) + self.layout().setContentsMargins(0, 0, 0, 0) + + self._options = UpdateModeWidget() + self._options.showRadioButtons(False) + self.layout().addWidget(self._options) + + self._lineStatsWidget = _BasicGridStatsWidget(parent=self, plot=plot, + kind=kind, stats=stats, + statsOnVisibleData=statsOnVisibleData) + self.layout().addWidget(self._lineStatsWidget) + + # tune options + self._options.setUpdateMode(self._lineStatsWidget.getUpdateMode()) + + # connect Signal ? SLOT + self._lineStatsWidget.sigUpdateModeChanged.connect(self._options.setUpdateMode) + self._options.sigUpdateModeChanged.connect(self._lineStatsWidget.setUpdateMode) + self._options.sigUpdateRequested.connect(self._lineStatsWidget._updateAllStats) + + def showControl(self, visible): + self._options.setVisible(visible) + + @docstring(_BasicGridStatsWidget) + def setUpdateMode(self, mode): + self._lineStatsWidget.setUpdateMode(mode=mode) + + @docstring(_BasicGridStatsWidget) + def getUpdateMode(self): + return self._lineStatsWidget.getUpdateMode() + + @docstring(_BasicGridStatsWidget) + def setPlot(self, plot): + self._lineStatsWidget.setPlot(plot=plot) + + @docstring(_BasicGridStatsWidget) + def setStats(self, statsHandler): + self._lineStatsWidget.setStats(statsHandler=statsHandler) + + @docstring(_BasicGridStatsWidget) + def setKind(self, kind): + self._lineStatsWidget.setKind(kind=kind) + + @docstring(_BasicGridStatsWidget) + def getKind(self): + return self._lineStatsWidget.getKind() + + @docstring(_BasicGridStatsWidget) + def setStatsOnVisibleData(self, b): + self._lineStatsWidget.setStatsOnVisibleData(b) + + @docstring(UpdateModeWidget) + def showRadioButtons(self, show): + self._options.showRadioButtons(show=show) diff --git a/silx/gui/plot/_BaseMaskToolsWidget.py b/silx/gui/plot/_BaseMaskToolsWidget.py index 0d11f17..d8e9fb5 100644 --- a/silx/gui/plot/_BaseMaskToolsWidget.py +++ b/silx/gui/plot/_BaseMaskToolsWidget.py @@ -29,7 +29,7 @@ from __future__ import division __authors__ = ["T. Vincent", "P. Knobel"] __license__ = "MIT" -__date__ = "15/02/2019" +__date__ = "12/04/2019" import os import weakref @@ -519,7 +519,7 @@ class BaseMaskToolsWidget(qt.QWidget): def _initTransparencyWidget(self): """ Init the mask transparency widget """ - transparencyWidget = qt.QWidget(self) + transparencyWidget = qt.QWidget(parent=self) grid = qt.QGridLayout() grid.setContentsMargins(0, 0, 0, 0) self.transparencySlider = qt.QSlider(qt.Qt.Horizontal, parent=transparencyWidget) @@ -619,8 +619,9 @@ class BaseMaskToolsWidget(qt.QWidget): self.addAction(self.browseAction) # Draw tools - self.rectAction = qt.QAction( - icons.getQIcon('shape-rectangle'), 'Rectangle selection', None) + self.rectAction = qt.QAction(icons.getQIcon('shape-rectangle'), + 'Rectangle selection', + self) self.rectAction.setToolTip( 'Rectangle selection tool: (Un)Mask a rectangular region R') self.rectAction.setShortcut(qt.QKeySequence(qt.Qt.Key_R)) @@ -628,8 +629,9 @@ class BaseMaskToolsWidget(qt.QWidget): self.rectAction.triggered.connect(self._activeRectMode) self.addAction(self.rectAction) - self.ellipseAction = qt.QAction( - icons.getQIcon('shape-ellipse'), 'Circle selection', None) + self.ellipseAction = qt.QAction(icons.getQIcon('shape-ellipse'), + 'Circle selection', + self) self.ellipseAction.setToolTip( 'Rectangle selection tool: (Un)Mask a circle region R') self.ellipseAction.setShortcut(qt.QKeySequence(qt.Qt.Key_R)) @@ -637,8 +639,9 @@ class BaseMaskToolsWidget(qt.QWidget): self.ellipseAction.triggered.connect(self._activeEllipseMode) self.addAction(self.ellipseAction) - self.polygonAction = qt.QAction( - icons.getQIcon('shape-polygon'), 'Polygon selection', None) + self.polygonAction = qt.QAction(icons.getQIcon('shape-polygon'), + 'Polygon selection', + self) self.polygonAction.setShortcut(qt.QKeySequence(qt.Qt.Key_S)) self.polygonAction.setToolTip( 'Polygon selection tool: (Un)Mask a polygonal region S
' @@ -648,8 +651,9 @@ class BaseMaskToolsWidget(qt.QWidget): self.polygonAction.triggered.connect(self._activePolygonMode) self.addAction(self.polygonAction) - self.pencilAction = qt.QAction( - icons.getQIcon('draw-pencil'), 'Pencil tool', None) + self.pencilAction = qt.QAction(icons.getQIcon('draw-pencil'), + 'Pencil tool', + self) self.pencilAction.setShortcut(qt.QKeySequence(qt.Qt.Key_P)) self.pencilAction.setToolTip( 'Pencil tool: (Un)Mask using a pencil P') @@ -733,21 +737,24 @@ class BaseMaskToolsWidget(qt.QWidget): def _initThresholdGroupBox(self): """Init thresholding widgets""" - self.belowThresholdAction = qt.QAction( - icons.getQIcon('plot-roi-below'), 'Mask below threshold', None) + self.belowThresholdAction = qt.QAction(icons.getQIcon('plot-roi-below'), + 'Mask below threshold', + self) self.belowThresholdAction.setToolTip( 'Mask image where values are below given threshold') self.belowThresholdAction.setCheckable(True) self.belowThresholdAction.setChecked(True) - self.betweenThresholdAction = qt.QAction( - icons.getQIcon('plot-roi-between'), 'Mask within range', None) + self.betweenThresholdAction = qt.QAction(icons.getQIcon('plot-roi-between'), + 'Mask within range', + self) self.betweenThresholdAction.setToolTip( 'Mask image where values are within given range') self.betweenThresholdAction.setCheckable(True) - self.aboveThresholdAction = qt.QAction( - icons.getQIcon('plot-roi-above'), 'Mask above threshold', None) + self.aboveThresholdAction = qt.QAction(icons.getQIcon('plot-roi-above'), + 'Mask above threshold', + self) self.aboveThresholdAction.setToolTip( 'Mask image where values are above given threshold') self.aboveThresholdAction.setCheckable(True) @@ -760,8 +767,9 @@ class BaseMaskToolsWidget(qt.QWidget): self.thresholdActionGroup.triggered.connect( self._thresholdActionGroupTriggered) - self.loadColormapRangeAction = qt.QAction( - icons.getQIcon('view-refresh'), 'Set min-max from colormap', None) + self.loadColormapRangeAction = qt.QAction(icons.getQIcon('view-refresh'), + 'Set min-max from colormap', + self) self.loadColormapRangeAction.setToolTip( 'Set min and max values from current colormap range') self.loadColormapRangeAction.setCheckable(False) @@ -774,7 +782,7 @@ class BaseMaskToolsWidget(qt.QWidget): btn.setDefaultAction(action) widgets.append(btn) - spacer = qt.QWidget() + spacer = qt.QWidget(parent=self) spacer.setSizePolicy(qt.QSizePolicy.Expanding, qt.QSizePolicy.Preferred) widgets.append(spacer) @@ -1059,7 +1067,7 @@ class BaseMaskToolsWidget(qt.QWidget): self.maxLineLabel.setVisible(False) self.minLineEdit.setVisible(True) self.maxLineEdit.setVisible(False) - self.applyMaskBtn.setText("Mask bellow") + self.applyMaskBtn.setText("Mask below") elif triggeredAction is self.betweenThresholdAction: self.minLineLabel.setVisible(True) self.maxLineLabel.setVisible(True) diff --git a/silx/gui/plot/_utils/delaunay.py b/silx/gui/plot/_utils/delaunay.py new file mode 100644 index 0000000..49ad05f --- /dev/null +++ b/silx/gui/plot/_utils/delaunay.py @@ -0,0 +1,62 @@ +# coding: utf-8 +# /*########################################################################## +# +# Copyright (c) 2019 European Synchrotron Radiation Facility +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. +# +# ###########################################################################*/ +"""Wrapper over Delaunay implementation""" + +__authors__ = ["T. Vincent"] +__license__ = "MIT" +__date__ = "02/05/2019" + + +import logging +import sys + +import numpy + + +_logger = logging.getLogger(__name__) + + +def delaunay(x, y): + """Returns Delaunay instance for x, y points + + :param numpy.ndarray x: + :param numpy.ndarray y: + :rtype: Union[None,scipy.spatial.Delaunay] + """ + # Lazy-loading of Delaunay + try: + from scipy.spatial import Delaunay as _Delaunay + except ImportError: # Fallback using local Delaunay + from silx.third_party.scipy_spatial import Delaunay as _Delaunay + + points = numpy.array((x, y)).T + try: + delaunay = _Delaunay(points) + except (RuntimeError, ValueError): + _logger.error("Delaunay tesselation failed: %s", + sys.exc_info()[1]) + delaunay = None + + return delaunay diff --git a/silx/gui/plot/actions/control.py b/silx/gui/plot/actions/control.py index 2d01ef1..ec4a3de 100644 --- a/silx/gui/plot/actions/control.py +++ b/silx/gui/plot/actions/control.py @@ -35,6 +35,7 @@ The following QAction are available: - :class:`KeepAspectRatioAction` - :class:`PanWithArrowKeysAction` - :class:`ResetZoomAction` +- :class:`ShowAxisAction` - :class:`XAxisLogarithmicAction` - :class:`XAxisAutoScaleAction` - :class:`YAxisInvertedAction` @@ -43,7 +44,6 @@ The following QAction are available: - :class:`ZoomBackAction` - :class:`ZoomInAction` - :class:`ZoomOutAction` -- :class:'ShowAxisAction' """ from __future__ import division @@ -377,11 +377,11 @@ class ColormapAction(PlotAction): # Specific init for complex images colormap = image.getColormap() - mode = image.getVisualizationMode() - if mode in (items.ImageComplexData.Mode.AMPLITUDE_PHASE, - items.ImageComplexData.Mode.LOG10_AMPLITUDE_PHASE): + mode = image.getComplexMode() + if mode in (items.ImageComplexData.ComplexMode.AMPLITUDE_PHASE, + items.ImageComplexData.ComplexMode.LOG10_AMPLITUDE_PHASE): data = image.getData( - copy=False, mode=items.ImageComplexData.Mode.PHASE) + copy=False, mode=items.ImageComplexData.ComplexMode.PHASE) else: data = image.getData(copy=False) diff --git a/silx/gui/plot/backends/BackendBase.py b/silx/gui/plot/backends/BackendBase.py index 0514c85..af37543 100644 --- a/silx/gui/plot/backends/BackendBase.py +++ b/silx/gui/plot/backends/BackendBase.py @@ -170,6 +170,23 @@ class BackendBase(object): """ return legend + def addTriangles(self, x, y, triangles, legend, + color, z, selectable, alpha): + """Add a set of triangles. + + :param numpy.ndarray x: The data corresponding to the x axis + :param numpy.ndarray y: The data corresponding to the y axis + :param numpy.ndarray triangles: The indices to make triangles + as a (Ntriangle, 3) array + :param str legend: The legend to be associated to the curve + :param numpy.ndarray color: color(s) as (npoints, 4) array + :param int z: Layer on which to draw the cuve + :param bool selectable: indicate if the curve can be selected + :param float alpha: Opacity as a float in [0., 1.] + :returns: The triangles' unique identifier used by the backend + """ + return legend + def addItem(self, x, y, legend, shape, color, fill, overlay, z, linestyle, linewidth, linebgcolor): """Add an item (i.e. a shape) to the plot. diff --git a/silx/gui/plot/backends/BackendMatplotlib.py b/silx/gui/plot/backends/BackendMatplotlib.py index 726a839..7739329 100644 --- a/silx/gui/plot/backends/BackendMatplotlib.py +++ b/silx/gui/plot/backends/BackendMatplotlib.py @@ -54,7 +54,8 @@ from matplotlib.backend_bases import MouseEvent from matplotlib.lines import Line2D from matplotlib.collections import PathCollection, LineCollection from matplotlib.ticker import Formatter, ScalarFormatter, Locator - +from matplotlib.tri import Triangulation +from matplotlib.collections import TriMesh from . import BackendBase from .._utils import FLOAT32_MINPOS @@ -359,9 +360,12 @@ class BackendMatplotlib(BackendBase.BackendBase): else: errorbarColor = color - # On Debian 7 at least, Nx1 array yerr does not seems supported + # Nx1 error array deprecated in matplotlib >=3.1 (removed in 3.3) + if (isinstance(xerror, numpy.ndarray) and xerror.ndim == 2 and + xerror.shape[1] == 1): + xerror = numpy.ravel(xerror) if (isinstance(yerror, numpy.ndarray) and yerror.ndim == 2 and - yerror.shape[1] == 1 and len(x) != 1): + yerror.shape[1] == 1): yerror = numpy.ravel(yerror) errorbars = axes.errorbar(x, y, label=legend, @@ -477,6 +481,32 @@ class BackendMatplotlib(BackendBase.BackendBase): self.ax.add_artist(image) return image + def addTriangles(self, x, y, triangles, legend, + color, z, selectable, alpha): + for parameter in (x, y, triangles, legend, color, + z, selectable, alpha): + assert parameter is not None + + # 0 enables picking on filled triangle + picker = 0 if selectable else None + + color = numpy.array(color, copy=False) + assert color.ndim == 2 and len(color) == len(x) + + if color.dtype not in [numpy.float32, numpy.float]: + color = color.astype(numpy.float32) / 255. + + collection = TriMesh( + Triangulation(x, y, triangles), + label=legend, + alpha=alpha, + picker=picker, + zorder=z) + collection.set_color(color) + self.ax.add_collection(collection) + + return collection + def addItem(self, x, y, legend, shape, color, fill, overlay, z, linestyle, linewidth, linebgcolor): if (linebgcolor is not None and @@ -1100,6 +1130,22 @@ class BackendMatplotlibQt(FigureCanvasQTAgg, BackendMatplotlib): elif label.startswith('__IMAGE__'): self._picked.append({'kind': 'image', 'legend': label[9:]}) + elif isinstance(event.artist, TriMesh): + # Convert selected triangle to data point indices + triangulation = event.artist._triangulation + indices = triangulation.get_masked_triangles()[event.ind[0]] + + # Sort picked triangle points by distance to mouse + # from furthest to closest to put closest point last + # This is to be somewhat consistent with last scatter point + # being the top one. + dists = ((triangulation.x[indices] - event.mouseevent.xdata) ** 2 + + (triangulation.y[indices] - event.mouseevent.ydata) ** 2) + indices = indices[numpy.flip(numpy.argsort(dists))] + + self._picked.append({'kind': 'curve', 'legend': label, + 'indices': indices}) + else: # it's a curve, item have no picker for now if not isinstance(event.artist, (PathCollection, Line2D)): _logger.info('Unsupported artist, ignored') diff --git a/silx/gui/plot/backends/BackendOpenGL.py b/silx/gui/plot/backends/BackendOpenGL.py index e33d03c..0420aa9 100644 --- a/silx/gui/plot/backends/BackendOpenGL.py +++ b/silx/gui/plot/backends/BackendOpenGL.py @@ -31,8 +31,9 @@ __license__ = "MIT" __date__ = "21/12/2018" from collections import OrderedDict, namedtuple -from ctypes import c_void_p import logging +import warnings +import weakref import numpy @@ -44,7 +45,7 @@ from ... import qt from ..._glutils import gl from ... import _glutils as glu from .glutils import ( - GLLines2D, + GLLines2D, GLPlotTriangles, GLPlotCurve2D, GLPlotColormap, GLPlotRGBAImage, GLPlotFrame2D, mat4Ortho, mat4Identity, LEFT, RIGHT, BOTTOM, TOP, @@ -106,7 +107,7 @@ class PlotDataContent(object): This class is only meant to work with _OpenGLPlotCanvas. """ - _PRIMITIVE_TYPES = 'curve', 'image' + _PRIMITIVE_TYPES = 'curve', 'image', 'triangles' def __init__(self): self._primitives = OrderedDict() # For images and curves @@ -124,6 +125,8 @@ class PlotDataContent(object): primitiveType = 'curve' elif isinstance(primitive, (GLPlotColormap, GLPlotRGBAImage)): primitiveType = 'image' + elif isinstance(primitive, GLPlotTriangles): + primitiveType = 'triangles' else: raise RuntimeError('Unsupported object type: %s', primitive) @@ -304,16 +307,8 @@ _texFragShd = """ } """ - # BackendOpenGL ############################################################### -_current_context = None - - -def _getContext(): - assert _current_context is not None - return _current_context - class BackendOpenGL(BackendBase.BackendBase, glu.OpenGLWidget): """OpenGL-based Plot backend. @@ -348,7 +343,7 @@ class BackendOpenGL(BackendBase.BackendBase, glu.OpenGLWidget): _baseVertShd, _baseFragShd, attrib0='position') self._progTex = glu.Program( _texVertShd, _texFragShd, attrib0='position') - self._plotFBOs = {} + self._plotFBOs = weakref.WeakKeyDictionary() self._keepDataAspectRatio = False @@ -386,6 +381,8 @@ class BackendOpenGL(BackendBase.BackendBase, glu.OpenGLWidget): return qt.QSize(8 * 80, 6 * 80) # Mimic MatplotlibBackend def mousePressEvent(self, event): + if event.button() not in self._MOUSE_BTNS: + return super(BackendOpenGL, self).mousePressEvent(event) xPixel = event.x() * self.getDevicePixelRatio() yPixel = event.y() * self.getDevicePixelRatio() btn = self._MOUSE_BTNS[event.button()] @@ -411,6 +408,8 @@ class BackendOpenGL(BackendBase.BackendBase, glu.OpenGLWidget): event.accept() def mouseReleaseEvent(self, event): + if event.button() not in self._MOUSE_BTNS: + return super(BackendOpenGL, self).mouseReleaseEvent(event) xPixel = event.x() * self.getDevicePixelRatio() yPixel = event.y() * self.getDevicePixelRatio() @@ -462,15 +461,17 @@ class BackendOpenGL(BackendBase.BackendBase, glu.OpenGLWidget): self._renderOverlayGL() def _paintFBOGL(self): - context = glu.getGLContext() + context = glu.Context.getCurrent() plotFBOTex = self._plotFBOs.get(context) if (self._plot._getDirtyPlot() or self._plotFrame.isDirty or plotFBOTex is None): - self._plotVertices = numpy.array(((-1., -1., 0., 0.), - (1., -1., 1., 0.), - (-1., 1., 0., 1.), - (1., 1., 1., 1.)), - dtype=numpy.float32) + self._plotVertices = ( + # Vertex coordinates + numpy.array(((-1., -1.), (1., -1.), (-1., 1.), (1., 1.)), + dtype=numpy.float32), + # Texture coordinates + numpy.array(((0., 0.), (1., 0.), (0., 1.), (1., 1.)), + dtype=numpy.float32)) if plotFBOTex is None or \ plotFBOTex.shape[1] != self._plotFrame.size[0] or \ plotFBOTex.shape[0] != self._plotFrame.size[1]: @@ -502,53 +503,45 @@ class BackendOpenGL(BackendBase.BackendBase, glu.OpenGLWidget): gl.glUniformMatrix4fv(self._progTex.uniforms['matrix'], 1, gl.GL_TRUE, mat4Identity().astype(numpy.float32)) - stride = self._plotVertices.shape[-1] * self._plotVertices.itemsize gl.glEnableVertexAttribArray(self._progTex.attributes['position']) gl.glVertexAttribPointer(self._progTex.attributes['position'], 2, gl.GL_FLOAT, gl.GL_FALSE, - stride, self._plotVertices) + 0, + self._plotVertices[0]) - texCoordsPtr = c_void_p(self._plotVertices.ctypes.data + - 2 * self._plotVertices.itemsize) # Better way? gl.glEnableVertexAttribArray(self._progTex.attributes['texCoords']) gl.glVertexAttribPointer(self._progTex.attributes['texCoords'], 2, gl.GL_FLOAT, gl.GL_FALSE, - stride, texCoordsPtr) + 0, + self._plotVertices[1]) with plotFBOTex.texture: - gl.glDrawArrays(gl.GL_TRIANGLE_STRIP, 0, len(self._plotVertices)) + gl.glDrawArrays(gl.GL_TRIANGLE_STRIP, 0, len(self._plotVertices[0])) self._renderMarkersGL() self._renderOverlayGL() def paintGL(self): - global _current_context - _current_context = self.context() - - glu.setGLContextGetter(_getContext) - - # Release OpenGL resources - for item in self._glGarbageCollector: - item.discard() - self._glGarbageCollector = [] - - gl.glClearColor(*self._backgroundColor) - gl.glClear(gl.GL_COLOR_BUFFER_BIT | gl.GL_STENCIL_BUFFER_BIT) + with glu.Context.current(self.context()): + # Release OpenGL resources + for item in self._glGarbageCollector: + item.discard() + self._glGarbageCollector = [] - # Check if window is large enough - plotWidth, plotHeight = self.getPlotBoundsInPixels()[2:] - if plotWidth <= 2 or plotHeight <= 2: - return + gl.glClearColor(*self._backgroundColor) + gl.glClear(gl.GL_COLOR_BUFFER_BIT | gl.GL_STENCIL_BUFFER_BIT) - # self._paintDirectGL() - self._paintFBOGL() + # Check if window is large enough + plotWidth, plotHeight = self.getPlotBoundsInPixels()[2:] + if plotWidth <= 2 or plotHeight <= 2: + return - glu.setGLContextGetter() - _current_context = None + # self._paintDirectGL() + self._paintFBOGL() def _renderMarkersGL(self): if len(self._markers) == 0: @@ -892,7 +885,10 @@ class BackendOpenGL(BackendBase.BackendBase, glu.OpenGLWidget): xErrorMinus, xErrorPlus = xerror[0], xerror[1] else: xErrorMinus, xErrorPlus = xerror, xerror - xErrorMinus = logX - numpy.log10(x - xErrorMinus) + with warnings.catch_warnings(): + warnings.simplefilter('ignore', category=RuntimeWarning) + # Ignore divide by zero, invalid value encountered in log10 + xErrorMinus = logX - numpy.log10(x - xErrorMinus) xErrorPlus = numpy.log10(x + xErrorPlus) - logX xerror = numpy.array((xErrorMinus, xErrorPlus), dtype=numpy.float32) @@ -912,7 +908,10 @@ class BackendOpenGL(BackendBase.BackendBase, glu.OpenGLWidget): yErrorMinus, yErrorPlus = yerror[0], yerror[1] else: yErrorMinus, yErrorPlus = yerror, yerror - yErrorMinus = logY - numpy.log10(y - yErrorMinus) + with warnings.catch_warnings(): + warnings.simplefilter('ignore', category=RuntimeWarning) + # Ignore divide by zero, invalid value encountered in log10 + yErrorMinus = logY - numpy.log10(y - yErrorMinus) yErrorPlus = numpy.log10(y + yErrorPlus) - logY yerror = numpy.array((yErrorMinus, yErrorPlus), dtype=numpy.float32) @@ -1043,6 +1042,25 @@ class BackendOpenGL(BackendBase.BackendBase, glu.OpenGLWidget): return legend, 'image' + def addTriangles(self, x, y, triangles, legend, + color, z, selectable, alpha): + + # Handle axes log scale: convert data + if self._plotFrame.xAxis.isLog: + x = numpy.log10(x) + if self._plotFrame.yAxis.isLog: + y = numpy.log10(y) + + triangles = GLPlotTriangles(x, y, color, triangles, alpha) + triangles.info = { + 'legend': legend, + 'zOrder': z, + 'behaviors': set(['selectable']) if selectable else set(), + } + self._plotContent.add(triangles) + + return legend, 'triangles' + def addItem(self, x, y, legend, shape, color, fill, overlay, z, linestyle, linewidth, linebgcolor): # TODO handle overlay @@ -1132,10 +1150,10 @@ class BackendOpenGL(BackendBase.BackendBase, glu.OpenGLWidget): self._glGarbageCollector.append(curve) - elif kind == 'image': - image = self._plotContent.pop('image', legend) - if image is not None: - self._glGarbageCollector.append(image) + elif kind in ('image', 'triangles'): + item = self._plotContent.pop(kind, legend) + if item is not None: + self._glGarbageCollector.append(item) elif kind == 'marker': self._markers.pop(legend, False) @@ -1188,6 +1206,60 @@ class BackendOpenGL(BackendBase.BackendBase, glu.OpenGLWidget): self._plotFrame.size[1] - self._plotFrame.margins.bottom - 1) return xPlot, yPlot + def __pickCurves(self, item, x, y): + """Perform picking on a curve item. + + :param GLPlotCurve2D item: + :param float x: X position of the mouse in widget coordinates + :param float y: Y position of the mouse in widget coordinates + :return: List of indices of picked points + :rtype: List[int] + """ + offset = self._PICK_OFFSET + if item.marker is not None: + offset = max(item.markerSize / 2., offset) + if item.lineStyle is not None: + offset = max(item.lineWidth / 2., offset) + + yAxis = item.info['yAxis'] + + inAreaPos = self._mouseInPlotArea(x - offset, y - offset) + dataPos = self.pixelToData(inAreaPos[0], inAreaPos[1], + axis=yAxis, check=True) + if dataPos is None: + return [] + xPick0, yPick0 = dataPos + + inAreaPos = self._mouseInPlotArea(x + offset, y + offset) + dataPos = self.pixelToData(inAreaPos[0], inAreaPos[1], + axis=yAxis, check=True) + if dataPos is None: + return [] + xPick1, yPick1 = dataPos + + if xPick0 < xPick1: + xPickMin, xPickMax = xPick0, xPick1 + else: + xPickMin, xPickMax = xPick1, xPick0 + + if yPick0 < yPick1: + yPickMin, yPickMax = yPick0, yPick1 + else: + yPickMin, yPickMax = yPick1, yPick0 + + # Apply log scale if axis is log + if self._plotFrame.xAxis.isLog: + xPickMin = numpy.log10(xPickMin) + xPickMax = numpy.log10(xPickMax) + + if (yAxis == 'left' and self._plotFrame.yAxis.isLog) or ( + yAxis == 'right' and self._plotFrame.y2Axis.isLog): + yPickMin = numpy.log10(yPickMin) + yPickMax = numpy.log10(yPickMax) + + return item.pick(xPickMin, yPickMin, + xPickMax, yPickMax) + def pickItems(self, x, y, kinds): picked = [] @@ -1236,56 +1308,20 @@ class BackendOpenGL(BackendBase.BackendBase, glu.OpenGLWidget): picked.append(dict(kind='image', legend=item.info['legend'])) - elif 'curve' in kinds and isinstance(item, GLPlotCurve2D): - offset = self._PICK_OFFSET - if item.marker is not None: - offset = max(item.markerSize / 2., offset) - if item.lineStyle is not None: - offset = max(item.lineWidth / 2., offset) - - yAxis = item.info['yAxis'] - - inAreaPos = self._mouseInPlotArea(x - offset, y - offset) - dataPos = self.pixelToData(inAreaPos[0], inAreaPos[1], - axis=yAxis, check=True) - if dataPos is None: - continue - xPick0, yPick0 = dataPos - - inAreaPos = self._mouseInPlotArea(x + offset, y + offset) - dataPos = self.pixelToData(inAreaPos[0], inAreaPos[1], - axis=yAxis, check=True) - if dataPos is None: - continue - xPick1, yPick1 = dataPos - - if xPick0 < xPick1: - xPickMin, xPickMax = xPick0, xPick1 - else: - xPickMin, xPickMax = xPick1, xPick0 - - if yPick0 < yPick1: - yPickMin, yPickMax = yPick0, yPick1 - else: - yPickMin, yPickMax = yPick1, yPick0 - - # Apply log scale if axis is log - if self._plotFrame.xAxis.isLog: - xPickMin = numpy.log10(xPickMin) - xPickMax = numpy.log10(xPickMax) - - if (yAxis == 'left' and self._plotFrame.yAxis.isLog) or ( - yAxis == 'right' and self._plotFrame.y2Axis.isLog): - yPickMin = numpy.log10(yPickMin) - yPickMax = numpy.log10(yPickMax) - - pickedIndices = item.pick(xPickMin, yPickMin, - xPickMax, yPickMax) - if pickedIndices: - picked.append(dict(kind='curve', - legend=item.info['legend'], - indices=pickedIndices)) - + elif 'curve' in kinds: + if isinstance(item, GLPlotCurve2D): + pickedIndices = self.__pickCurves(item, x, y) + if pickedIndices: + picked.append(dict(kind='curve', + legend=item.info['legend'], + indices=pickedIndices)) + + elif isinstance(item, GLPlotTriangles): + pickedIndices = item.pick(*dataPos) + if pickedIndices: + picked.append(dict(kind='curve', + legend=item.info['legend'], + indices=pickedIndices)) return picked # Update curve diff --git a/silx/gui/plot/backends/glutils/GLPlotTriangles.py b/silx/gui/plot/backends/glutils/GLPlotTriangles.py new file mode 100644 index 0000000..c756749 --- /dev/null +++ b/silx/gui/plot/backends/glutils/GLPlotTriangles.py @@ -0,0 +1,193 @@ +# coding: utf-8 +# /*########################################################################## +# +# Copyright (c) 2019 European Synchrotron Radiation Facility +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. +# +# ############################################################################*/ +""" +This module provides a class to render a set of 2D triangles +""" + +__authors__ = ["T. Vincent"] +__license__ = "MIT" +__date__ = "03/04/2017" + + +import ctypes + +import numpy + +from .....math.combo import min_max +from .... import _glutils as glutils +from ...._glutils import gl + + +class GLPlotTriangles(object): + """Handle rendering of a set of colored triangles""" + + _PROGRAM = glutils.Program( + vertexShader=""" + #version 120 + + uniform mat4 matrix; + attribute float xPos; + attribute float yPos; + attribute vec4 color; + + varying vec4 vColor; + + void main(void) { + gl_Position = matrix * vec4(xPos, yPos, 0.0, 1.0); + vColor = color; + } + """, + fragmentShader=""" + #version 120 + + uniform float alpha; + varying vec4 vColor; + + void main(void) { + gl_FragColor = vColor; + gl_FragColor.a *= alpha; + } + """, + attrib0='xPos') + + def __init__(self, x, y, color, triangles, alpha=1.): + """ + + :param numpy.ndarray x: X coordinates of triangle corners + :param numpy.ndarray y: Y coordinates of triangle corners + :param numpy.ndarray color: color for each point + :param numpy.ndarray triangles: (N, 3) array of indices of triangles + :param float alpha: Opacity in [0, 1] + """ + # Check and convert input data + x = numpy.ravel(numpy.array(x, dtype=numpy.float32)) + y = numpy.ravel(numpy.array(y, dtype=numpy.float32)) + color = numpy.array(color, copy=False) + # Cast to uint32 + triangles = numpy.array(triangles, copy=False, dtype=numpy.uint32) + + assert x.size == y.size + assert x.size == len(color) + assert color.ndim == 2 and color.shape[1] in (3, 4) + if numpy.issubdtype(color.dtype, numpy.floating): + color = numpy.array(color, dtype=numpy.float32, copy=False) + elif numpy.issubdtype(color.dtype, numpy.integer): + color = numpy.array(color, dtype=numpy.uint8, copy=False) + else: + raise ValueError('Unsupported color type') + assert triangles.ndim == 2 and triangles.shape[1] == 3 + + self.__x_y_color = x, y, color + self.xMin, self.xMax = min_max(x, finite=True) + self.yMin, self.yMax = min_max(y, finite=True) + self.__triangles = triangles + self.__alpha = numpy.clip(float(alpha), 0., 1.) + self.__vbos = None + self.__indicesVbo = None + self.__picking_triangles = None + + def pick(self, x, y): + """Perform picking + + :param float x: X coordinates in plot data frame + :param float y: Y coordinates in plot data frame + :return: List of picked data point indices + :rtype: numpy.ndarray + """ + if (x < self.xMin or x > self.xMax or + y < self.yMin or y > self.yMax): + return () + xPts, yPts = self.__x_y_color[:2] + if self.__picking_triangles is None: + self.__picking_triangles = numpy.zeros( + self.__triangles.shape + (3,), dtype=numpy.float32) + self.__picking_triangles[:, :, 0] = xPts[self.__triangles] + self.__picking_triangles[:, :, 1] = yPts[self.__triangles] + + segment = numpy.array(((x, y, -1), (x, y, 1)), dtype=numpy.float32) + # Picked triangle indices + indices = glutils.segmentTrianglesIntersection( + segment, self.__picking_triangles)[0] + # Point indices + indices = numpy.unique(numpy.ravel(self.__triangles[indices])) + + # Sorted from furthest to closest point + dists = (xPts[indices] - x) ** 2 + (yPts[indices] - y) ** 2 + indices = indices[numpy.flip(numpy.argsort(dists))] + + return tuple(indices) + + def discard(self): + """Release resources on the GPU""" + if self.__vbos is not None: + self.__vbos[0].vbo.discard() + self.__vbos = None + self.__indicesVbo.discard() + self.__indicesVbo = None + + def prepare(self): + """Allocate resources on the GPU""" + if self.__vbos is None: + self.__vbos = glutils.vertexBuffer(self.__x_y_color) + # Normalization is need for color + self.__vbos[-1].normalization = True + + if self.__indicesVbo is None: + self.__indicesVbo = glutils.VertexBuffer( + numpy.ravel(self.__triangles), + usage=gl.GL_STATIC_DRAW, + target=gl.GL_ELEMENT_ARRAY_BUFFER) + + def render(self, matrix, isXLog, isYLog): + """Perform rendering + + :param numpy.ndarray matrix: 4x4 transform matrix to use + :param bool isXLog: + :param bool isYLog: + """ + self.prepare() + + if self.__vbos is None or self.__indicesVbo is None: + return # Nothing to display + + self._PROGRAM.use() + + gl.glUniformMatrix4fv(self._PROGRAM.uniforms['matrix'], + 1, + gl.GL_TRUE, + matrix.astype(numpy.float32)) + + gl.glUniform1f(self._PROGRAM.uniforms['alpha'], self.__alpha) + + for index, name in enumerate(('xPos', 'yPos', 'color')): + attr = self._PROGRAM.attributes[name] + gl.glEnableVertexAttribArray(attr) + self.__vbos[index].setVertexAttrib(attr) + + with self.__indicesVbo: + gl.glDrawElements(gl.GL_TRIANGLES, + self.__triangles.size, + glutils.numpyToGLType(self.__triangles.dtype), + ctypes.c_void_p(0)) diff --git a/silx/gui/plot/backends/glutils/GLText.py b/silx/gui/plot/backends/glutils/GLText.py index 3d262bc..725c12c 100644 --- a/silx/gui/plot/backends/glutils/GLText.py +++ b/silx/gui/plot/backends/glutils/GLText.py @@ -1,7 +1,7 @@ # coding: utf-8 # /*########################################################################## # -# Copyright (c) 2014-2018 European Synchrotron Radiation Facility +# Copyright (c) 2014-2019 European Synchrotron Radiation Facility # # Permission is hereby granted, free of charge, to any person obtaining a copy # of this software and associated documentation files (the "Software"), to deal @@ -33,9 +33,11 @@ __date__ = "03/04/2017" from collections import OrderedDict +import weakref + import numpy -from ...._glutils import font, gl, getGLContext, Program, Texture +from ...._glutils import font, gl, Context, Program, Texture from .GLSupport import mat4Translate @@ -128,7 +130,7 @@ class Text2D(object): attrib0='position') # Discard texture objects when removed from the cache - _textures = _Cache(callback=lambda key, value: value[0].discard()) + _textures = weakref.WeakKeyDictionary() """Cache already created textures""" _sizes = _Cache() @@ -159,15 +161,20 @@ class Text2D(object): self._rotate = numpy.radians(rotate) def _getTexture(self, text): - key = getGLContext(), text - - if key not in self._textures: + # Retrieve/initialize texture cache for current context + context = Context.getCurrent() + if context not in self._textures: + self._textures[context] = _Cache( + callback=lambda key, value: value[0].discard()) + textures = self._textures[context] + + if text not in textures: image, offset = font.rasterText(text, font.getDefaultFontFamily()) if text not in self._sizes: self._sizes[text] = image.shape[1], image.shape[0] - self._textures[key] = ( + textures[text] = ( Texture(gl.GL_RED, data=image, minFilter=gl.GL_NEAREST, @@ -176,7 +183,7 @@ class Text2D(object): gl.GL_CLAMP_TO_EDGE)), offset) - return self._textures[key] + return textures[text] @property def text(self): diff --git a/silx/gui/plot/backends/glutils/GLTexture.py b/silx/gui/plot/backends/glutils/GLTexture.py index 25dd9f1..118a36f 100644 --- a/silx/gui/plot/backends/glutils/GLTexture.py +++ b/silx/gui/plot/backends/glutils/GLTexture.py @@ -1,7 +1,7 @@ # coding: utf-8 # /*########################################################################## # -# Copyright (c) 2014-2017 European Synchrotron Radiation Facility +# Copyright (c) 2014-2019 European Synchrotron Radiation Facility # # Permission is hereby granted, free of charge, to any person obtaining a copy # of this software and associated documentation files (the "Software"), to deal @@ -163,7 +163,6 @@ class Image(object): data[yOrig:yOrig+hData, xOrig:xOrig+wData], format_, - shape=(hData, wData), texUnit=texUnit, minFilter=self._MIN_FILTER, magFilter=self._MAG_FILTER, diff --git a/silx/gui/plot/backends/glutils/__init__.py b/silx/gui/plot/backends/glutils/__init__.py index 771de39..d58c084 100644 --- a/silx/gui/plot/backends/glutils/__init__.py +++ b/silx/gui/plot/backends/glutils/__init__.py @@ -1,7 +1,7 @@ # coding: utf-8 # /*########################################################################## # -# Copyright (c) 2014-2017 European Synchrotron Radiation Facility +# Copyright (c) 2014-2019 European Synchrotron Radiation Facility # # Permission is hereby granted, free of charge, to any person obtaining a copy # of this software and associated documentation files (the "Software"), to deal @@ -39,6 +39,7 @@ _logger = logging.getLogger(__name__) from .GLPlotCurve import * # noqa from .GLPlotFrame import * # noqa from .GLPlotImage import * # noqa +from .GLPlotTriangles import GLPlotTriangles # noqa from .GLSupport import * # noqa from .GLText import * # noqa from .GLTexture import * # noqa diff --git a/silx/gui/plot/items/__init__.py b/silx/gui/plot/items/__init__.py index f829f78..f3a36db 100644 --- a/silx/gui/plot/items/__init__.py +++ b/silx/gui/plot/items/__init__.py @@ -1,7 +1,7 @@ # coding: utf-8 # /*########################################################################## # -# Copyright (c) 2017-2018 European Synchrotron Radiation Facility +# Copyright (c) 2017-2019 European Synchrotron Radiation Facility # # Permission is hereby granted, free of charge, to any person obtaining a copy # of this software and associated documentation files (the "Software"), to deal @@ -34,14 +34,15 @@ __date__ = "22/06/2017" from .core import (Item, LabelsMixIn, DraggableMixIn, ColormapMixIn, # noqa SymbolMixIn, ColorMixIn, YAxisMixIn, FillMixIn, # noqa - AlphaMixIn, LineMixIn, ItemChangedType) # noqa + AlphaMixIn, LineMixIn, ScatterVisualizationMixIn, # noqa + ComplexMixIn, ItemChangedType, PointsBase) # noqa from .complex import ImageComplexData # noqa from .curve import Curve, CurveStyle # noqa from .histogram import Histogram # noqa from .image import ImageBase, ImageData, ImageRgba, MaskImageData # noqa from .shape import Shape # noqa from .scatter import Scatter # noqa -from .marker import Marker, XMarker, YMarker # noqa +from .marker import MarkerBase, Marker, XMarker, YMarker # noqa from .axis import Axis, XAxis, YAxis, YRightAxis DATA_ITEMS = ImageComplexData, Curve, Histogram, ImageBase, Scatter diff --git a/silx/gui/plot/items/complex.py b/silx/gui/plot/items/complex.py index 7fffd77..3869a05 100644 --- a/silx/gui/plot/items/complex.py +++ b/silx/gui/plot/items/complex.py @@ -33,12 +33,13 @@ __date__ = "14/06/2018" import logging -import enum import numpy +from ....utils.proxy import docstring +from ....utils.deprecation import deprecated from ...colors import Colormap -from .core import ColormapMixIn, ItemChangedType +from .core import ColormapMixIn, ComplexMixIn, ItemChangedType from .image import ImageBase @@ -105,29 +106,19 @@ def _complex2rgbalin(phaseColormap, data, gamma=1.0, smax=None): return rgba -class ImageComplexData(ImageBase, ColormapMixIn): +class ImageComplexData(ImageBase, ColormapMixIn, ComplexMixIn): """Specific plot item to force colormap when using complex colormap. This is returning the specific colormap when displaying colored phase + amplitude. """ - class Mode(enum.Enum): - """Identify available display mode for complex""" - ABSOLUTE = 'absolute' - PHASE = 'phase' - REAL = 'real' - IMAGINARY = 'imaginary' - AMPLITUDE_PHASE = 'amplitude_phase' - LOG10_AMPLITUDE_PHASE = 'log10_amplitude_phase' - SQUARE_AMPLITUDE = 'square_amplitude' - def __init__(self): ImageBase.__init__(self) ColormapMixIn.__init__(self) + ComplexMixIn.__init__(self) self._data = numpy.zeros((0, 0), dtype=numpy.complex64) self._dataByModesCache = {} - self._mode = self.Mode.ABSOLUTE self._amplitudeRangeInfo = None, 2 # Use default from ColormapMixIn @@ -139,13 +130,13 @@ class ImageComplexData(ImageBase, ColormapMixIn): vmax=numpy.pi) self._colormaps = { # Default colormaps for all modes - self.Mode.ABSOLUTE: colormap, - self.Mode.PHASE: phaseColormap, - self.Mode.REAL: colormap, - self.Mode.IMAGINARY: colormap, - self.Mode.AMPLITUDE_PHASE: phaseColormap, - self.Mode.LOG10_AMPLITUDE_PHASE: phaseColormap, - self.Mode.SQUARE_AMPLITUDE: colormap, + self.ComplexMode.ABSOLUTE: colormap, + self.ComplexMode.PHASE: phaseColormap, + self.ComplexMode.REAL: colormap, + self.ComplexMode.IMAGINARY: colormap, + self.ComplexMode.AMPLITUDE_PHASE: phaseColormap, + self.ComplexMode.LOG10_AMPLITUDE_PHASE: phaseColormap, + self.ComplexMode.SQUARE_AMPLITUDE: colormap, } def _addBackendRenderer(self, backend): @@ -156,9 +147,9 @@ class ImageComplexData(ImageBase, ColormapMixIn): # Do not render with non linear scales return None - mode = self.getVisualizationMode() - if mode in (self.Mode.AMPLITUDE_PHASE, - self.Mode.LOG10_AMPLITUDE_PHASE): + mode = self.getComplexMode() + if mode in (self.ComplexMode.AMPLITUDE_PHASE, + self.ComplexMode.LOG10_AMPLITUDE_PHASE): # For those modes, compute RGBA image here colormap = None data = self.getRgbaImageData(copy=False) @@ -179,33 +170,21 @@ class ImageComplexData(ImageBase, ColormapMixIn): colormap=colormap, alpha=self.getAlpha()) - def setVisualizationMode(self, mode): - """Set the visualization mode to use. - - :param Mode mode: - """ - assert isinstance(mode, self.Mode) - assert mode in self._colormaps - - if mode != self._mode: - self._mode = mode - + @docstring(ComplexMixIn) + def setComplexMode(self, mode): + changed = super(ImageComplexData, self).setComplexMode(mode) + if changed: + # Backward compatibility self._updated(ItemChangedType.VISUALIZATION_MODE) # Send data updated as value returned by getData has changed self._updated(ItemChangedType.DATA) # Update ColormapMixIn colormap - colormap = self._colormaps[self._mode] + colormap = self._colormaps[self.getComplexMode()] if colormap is not super(ImageComplexData, self).getColormap(): super(ImageComplexData, self).setColormap(colormap) - - def getVisualizationMode(self): - """Returns the visualization mode in use. - - :rtype: Mode - """ - return self._mode + return changed def _setAmplitudeRangeInfo(self, max_=None, delta=2): """Set the amplitude range to display for 'log10_amplitude_phase' mode. @@ -228,15 +207,17 @@ class ImageComplexData(ImageBase, ColormapMixIn): """Set the colormap for this specific mode. :param ~silx.gui.colors.Colormap colormap: The colormap - :param Mode mode: + :param Union[ComplexMode,str] mode: If specified, set the colormap of this specific mode. Default: current mode. """ if mode is None: - mode = self.getVisualizationMode() + mode = self.getComplexMode() + else: + mode = self.ComplexMode.from_value(mode) self._colormaps[mode] = colormap - if mode is self.getVisualizationMode(): + if mode is self.getComplexMode(): super(ImageComplexData, self).setColormap(colormap) else: self._updated(ItemChangedType.COLORMAP) @@ -244,13 +225,15 @@ class ImageComplexData(ImageBase, ColormapMixIn): def getColormap(self, mode=None): """Get the colormap for the (current) mode. - :param Mode mode: + :param Union[ComplexMode,str] mode: If specified, get the colormap of this specific mode. Default: current mode. :rtype: ~silx.gui.colors.Colormap """ if mode is None: - mode = self.getVisualizationMode() + mode = self.getComplexMode() + else: + mode = self.ComplexMode.from_value(mode) return self._colormaps[mode] @@ -296,28 +279,30 @@ class ImageComplexData(ImageBase, ColormapMixIn): :param bool copy: True (Default) to get a copy, False to use internal representation (do not modify!) - :param Mode mode: + :param Union[ComplexMode,str] mode: If specified, get data corresponding to the mode. Default: Current mode. :rtype: numpy.ndarray of float """ if mode is None: - mode = self.getVisualizationMode() + mode = self.getComplexMode() + else: + mode = self.ComplexMode.from_value(mode) if mode not in self._dataByModesCache: # Compute data for mode and store it in cache complexData = self.getComplexData(copy=False) - if mode is self.Mode.PHASE: + if mode is self.ComplexMode.PHASE: data = numpy.angle(complexData) - elif mode is self.Mode.REAL: + elif mode is self.ComplexMode.REAL: data = numpy.real(complexData) - elif mode is self.Mode.IMAGINARY: + elif mode is self.ComplexMode.IMAGINARY: data = numpy.imag(complexData) - elif mode in (self.Mode.ABSOLUTE, - self.Mode.LOG10_AMPLITUDE_PHASE, - self.Mode.AMPLITUDE_PHASE): + elif mode in (self.ComplexMode.ABSOLUTE, + self.ComplexMode.LOG10_AMPLITUDE_PHASE, + self.ComplexMode.AMPLITUDE_PHASE): data = numpy.absolute(complexData) - elif mode is self.Mode.SQUARE_AMPLITUDE: + elif mode is self.ComplexMode.SQUARE_AMPLITUDE: data = numpy.absolute(complexData) ** 2 else: _logger.error( @@ -333,22 +318,36 @@ class ImageComplexData(ImageBase, ColormapMixIn): """Get the displayed RGB(A) image for (current) mode :param bool copy: Ignored for this class - :param Mode mode: + :param Union[ComplexMode,str] mode: If specified, get data corresponding to the mode. Default: Current mode. :rtype: numpy.ndarray of uint8 of shape (height, width, 4) """ if mode is None: - mode = self.getVisualizationMode() + mode = self.getComplexMode() + else: + mode = self.ComplexMode.from_value(mode) colormap = self.getColormap(mode=mode) - if mode is self.Mode.AMPLITUDE_PHASE: + if mode is self.ComplexMode.AMPLITUDE_PHASE: data = self.getComplexData(copy=False) return _complex2rgbalin(colormap, data) - elif mode is self.Mode.LOG10_AMPLITUDE_PHASE: + elif mode is self.ComplexMode.LOG10_AMPLITUDE_PHASE: data = self.getComplexData(copy=False) max_, delta = self._getAmplitudeRangeInfo() return _complex2rgbalog(colormap, data, dlogs=delta, smax=max_) else: data = self.getData(copy=False, mode=mode) return colormap.applyToData(data) + + # Backward compatibility + + Mode = ComplexMixIn.ComplexMode + + @deprecated(replacement='setComplexMode', since_version='0.11.0') + def setVisualizationMode(self, mode): + return self.setComplexMode(mode) + + @deprecated(replacement='getComplexMode', since_version='0.11.0') + def getVisualizationMode(self): + return self.getComplexMode() diff --git a/silx/gui/plot/items/core.py b/silx/gui/plot/items/core.py index bf3b719..e7342b0 100644 --- a/silx/gui/plot/items/core.py +++ b/silx/gui/plot/items/core.py @@ -30,6 +30,10 @@ __license__ = "MIT" __date__ = "29/01/2019" import collections +try: + from collections import abc +except ImportError: # Python2 support + import collections as abc from copy import deepcopy import logging import enum @@ -39,6 +43,7 @@ import weakref import numpy import six +from ....utils.enum import Enum as _Enum from ... import qt from ... import colors from ...colors import Colormap @@ -128,6 +133,9 @@ class ItemChangedType(enum.Enum): VISUALIZATION_MODE = 'visualizationModeChanged' """Item's visualization mode changed flag.""" + COMPLEX_MODE = 'complexModeChanged' + """Item's complex data visualization mode changed flag.""" + class Item(qt.QObject): """Description of an item of the plot""" @@ -404,6 +412,14 @@ class DraggableMixIn(ItemMixInBase): """ self._draggable = bool(draggable) + def drag(self, from_, to): + """Perform a drag of the item. + + :param List[float] from_: (x, y) previous position in data coordinates + :param List[float] to: (x, y) current position in data coordinates + """ + raise NotImplementedError("Must be implemented in subclass") + class ColormapMixIn(ItemMixInBase): """Mix-in class for items with colormap""" @@ -757,7 +773,164 @@ class AlphaMixIn(ItemMixInBase): self._updated(ItemChangedType.ALPHA) -class Points(Item, SymbolMixIn, AlphaMixIn): +class ComplexMixIn(ItemMixInBase): + """Mix-in class for complex data mode""" + + _SUPPORTED_COMPLEX_MODES = None + """Override to only support a subset of all ComplexMode""" + + class ComplexMode(_Enum): + """Identify available display mode for complex""" + ABSOLUTE = 'amplitude' + PHASE = 'phase' + REAL = 'real' + IMAGINARY = 'imaginary' + AMPLITUDE_PHASE = 'amplitude_phase' + LOG10_AMPLITUDE_PHASE = 'log10_amplitude_phase' + SQUARE_AMPLITUDE = 'square_amplitude' + + def __init__(self): + self.__complex_mode = self.ComplexMode.ABSOLUTE + + def getComplexMode(self): + """Returns the current complex visualization mode. + + :rtype: ComplexMode + """ + return self.__complex_mode + + def setComplexMode(self, mode): + """Set the complex visualization mode. + + :param ComplexMode mode: The visualization mode in: + 'real', 'imaginary', 'phase', 'amplitude' + :return: True if value was set, False if is was already set + :rtype: bool + """ + mode = self.ComplexMode.from_value(mode) + assert mode in self.supportedComplexModes() + + if mode != self.__complex_mode: + self.__complex_mode = mode + self._updated(ItemChangedType.COMPLEX_MODE) + return True + else: + return False + + def _convertComplexData(self, data, mode=None): + """Convert complex data to the specific mode. + + :param Union[ComplexMode,None] mode: + The kind of value to compute. + If None (the default), the current complex mode is used. + :return: The converted dataset + :rtype: Union[numpy.ndarray[float],None] + """ + if data is None: + return None + + if mode is None: + mode = self.getComplexMode() + + if mode is self.ComplexMode.REAL: + return numpy.real(data) + elif mode is self.ComplexMode.IMAGINARY: + return numpy.imag(data) + elif mode is self.ComplexMode.ABSOLUTE: + return numpy.absolute(data) + elif mode is self.ComplexMode.PHASE: + return numpy.angle(data) + elif mode is self.ComplexMode.SQUARE_AMPLITUDE: + return numpy.absolute(data) ** 2 + else: + raise ValueError('Unsupported conversion mode: %s', str(mode)) + + @classmethod + def supportedComplexModes(cls): + """Returns the list of supported complex visualization modes. + + See :class:`ComplexMode` and :meth:`setComplexMode`. + + :rtype: List[ComplexMode] + """ + if cls._SUPPORTED_COMPLEX_MODES is None: + return cls.ComplexMode.members() + else: + return cls._SUPPORTED_COMPLEX_MODES + + +class ScatterVisualizationMixIn(ItemMixInBase): + """Mix-in class for scatter plot visualization modes""" + + _SUPPORTED_SCATTER_VISUALIZATION = None + """Allows to override supported Visualizations""" + + @enum.unique + class Visualization(_Enum): + """Different modes of scatter plot visualizations""" + + POINTS = 'points' + """Display scatter plot as a point cloud""" + + LINES = 'lines' + """Display scatter plot as a wireframe. + + This is based on Delaunay triangulation + """ + + SOLID = 'solid' + """Display scatter plot as a set of filled triangles. + + This is based on Delaunay triangulation + """ + + def __init__(self): + self.__visualization = self.Visualization.POINTS + + @classmethod + def supportedVisualizations(cls): + """Returns the list of supported scatter visualization modes. + + See :meth:`setVisualization` + + :rtype: List[Visualization] + """ + if cls._SUPPORTED_SCATTER_VISUALIZATION is None: + return cls.Visualization.members() + else: + return cls._SUPPORTED_SCATTER_VISUALIZATION + + def setVisualization(self, mode): + """Set the scatter plot visualization mode to use. + + See :class:`Visualization` for all possible values, + and :meth:`supportedVisualizations` for supported ones. + + :param Union[str,Visualization] mode: + The visualization mode to use. + :return: True if value was set, False if is was already set + :rtype: bool + """ + mode = self.Visualization.from_value(mode) + assert mode in self.supportedVisualizations() + + if mode != self.__visualization: + self.__visualization = mode + + self._updated(ItemChangedType.VISUALIZATION_MODE) + return True + else: + return False + + def getVisualization(self): + """Returns the scatter plot visualization mode in use. + + :rtype: Visualization + """ + return self.__visualization + + +class PointsBase(Item, SymbolMixIn, AlphaMixIn): """Base class for :class:`Curve` and :class:`Scatter`""" # note: _logFilterData must be overloaded if you overload # getData to change its signature @@ -906,8 +1079,7 @@ class Points(Item, SymbolMixIn, AlphaMixIn): if (xPositive, yPositive) not in self._boundsCache: # use the getData class method because instance method can be # overloaded to return additional arrays - data = Points.getData(self, copy=False, - displayed=True) + data = PointsBase.getData(self, copy=False, displayed=True) if len(data) == 5: # hack to avoid duplicating caching mechanism in Scatter # (happens when cached data is used, caching done using @@ -916,12 +1088,15 @@ class Points(Item, SymbolMixIn, AlphaMixIn): else: x, y, _xerror, _yerror = data - self._boundsCache[(xPositive, yPositive)] = ( - numpy.nanmin(x), - numpy.nanmax(x), - numpy.nanmin(y), - numpy.nanmax(y) - ) + with warnings.catch_warnings(): + warnings.simplefilter('ignore', category=RuntimeWarning) + # Ignore All-NaN slice encountered + self._boundsCache[(xPositive, yPositive)] = ( + numpy.nanmin(x), + numpy.nanmax(x), + numpy.nanmin(y), + numpy.nanmax(y) + ) return self._boundsCache[(xPositive, yPositive)] def _getCachedData(self): @@ -1026,12 +1201,12 @@ class Points(Item, SymbolMixIn, AlphaMixIn): assert x.ndim == y.ndim == 1 if xerror is not None: - if isinstance(xerror, collections.Iterable): + if isinstance(xerror, abc.Iterable): xerror = numpy.array(xerror, copy=copy) else: xerror = float(xerror) if yerror is not None: - if isinstance(yerror, collections.Iterable): + if isinstance(yerror, abc.Iterable): yerror = numpy.array(yerror, copy=copy) else: yerror = float(yerror) diff --git a/silx/gui/plot/items/curve.py b/silx/gui/plot/items/curve.py index 79def55..439af33 100644 --- a/silx/gui/plot/items/curve.py +++ b/silx/gui/plot/items/curve.py @@ -37,7 +37,7 @@ import six from ....utils.deprecation import deprecated from ... import colors -from .core import (Points, LabelsMixIn, ColorMixIn, YAxisMixIn, +from .core import (PointsBase, LabelsMixIn, ColorMixIn, YAxisMixIn, FillMixIn, LineMixIn, SymbolMixIn, ItemChangedType) @@ -151,7 +151,7 @@ class CurveStyle(object): return False -class Curve(Points, ColorMixIn, YAxisMixIn, FillMixIn, LabelsMixIn, LineMixIn): +class Curve(PointsBase, ColorMixIn, YAxisMixIn, FillMixIn, LabelsMixIn, LineMixIn): """Description of a curve""" _DEFAULT_Z_LAYER = 1 @@ -170,7 +170,7 @@ class Curve(Points, ColorMixIn, YAxisMixIn, FillMixIn, LabelsMixIn, LineMixIn): """Default highlight style of the item""" def __init__(self): - Points.__init__(self) + PointsBase.__init__(self) ColorMixIn.__init__(self) YAxisMixIn.__init__(self) FillMixIn.__init__(self) diff --git a/silx/gui/plot/items/image.py b/silx/gui/plot/items/image.py index 99a916a..d74f4d3 100644 --- a/silx/gui/plot/items/image.py +++ b/silx/gui/plot/items/image.py @@ -1,7 +1,7 @@ # coding: utf-8 # /*########################################################################## # -# Copyright (c) 2017 European Synchrotron Radiation Facility +# Copyright (c) 2017-2019 European Synchrotron Radiation Facility # # Permission is hereby granted, free of charge, to any person obtaining a copy # of this software and associated documentation files (the "Software"), to deal @@ -31,11 +31,15 @@ __license__ = "MIT" __date__ = "20/10/2017" -from collections import Sequence +try: + from collections import abc +except ImportError: # Python2 support + import collections as abc import logging import numpy +from ....utils.proxy import docstring from .core import (Item, LabelsMixIn, DraggableMixIn, ColormapMixIn, AlphaMixIn, ItemChangedType) @@ -170,6 +174,12 @@ class ImageBase(Item, LabelsMixIn, DraggableMixIn, AlphaMixIn): else: return xmin, xmax, ymin, ymax + @docstring(DraggableMixIn) + def drag(self, from_, to): + origin = self.getOrigin() + self.setOrigin((origin[0] + to[0] - from_[0], + origin[1] + to[1] - from_[1])) + def getData(self, copy=True): """Returns the image data @@ -199,7 +209,7 @@ class ImageBase(Item, LabelsMixIn, DraggableMixIn, AlphaMixIn): :param origin: (ox, oy) Offset from origin :type origin: float or 2-tuple of float """ - if isinstance(origin, Sequence): + if isinstance(origin, abc.Sequence): origin = float(origin[0]), float(origin[1]) else: # single value origin origin = float(origin), float(origin) @@ -227,7 +237,7 @@ class ImageBase(Item, LabelsMixIn, DraggableMixIn, AlphaMixIn): :param scale: (sx, sy) Scale of the image :type scale: float or 2-tuple of float """ - if isinstance(scale, Sequence): + if isinstance(scale, abc.Sequence): scale = float(scale[0]), float(scale[1]) else: # single value scale scale = float(scale), float(scale) @@ -252,6 +262,7 @@ class ImageData(ImageBase, ColormapMixIn): ColormapMixIn.__init__(self) self._data = numpy.zeros((0, 0), dtype=numpy.float32) self._alternativeImage = None + self.__alpha = None def _addBackendRenderer(self, backend): """Update backend renderer""" @@ -261,8 +272,9 @@ class ImageData(ImageBase, ColormapMixIn): # Do not render with non linear scales return None - if self.getAlternativeImageData(copy=False) is not None: - dataToUse = self.getAlternativeImageData(copy=False) + if (self.getAlternativeImageData(copy=False) is not None or + self.getAlphaData(copy=False) is not None): + dataToUse = self.getRgbaImageData(copy=False) else: dataToUse = self.getData(copy=False) @@ -293,37 +305,56 @@ class ImageData(ImageBase, ColormapMixIn): def getRgbaImageData(self, copy=True): """Get the displayed RGB(A) image - :returns: numpy.ndarray of uint8 of shape (height, width, 4) + :returns: Array of uint8 of shape (height, width, 4) + :rtype: numpy.ndarray """ - if self._alternativeImage is not None: - return _convertImageToRgba32( - self.getAlternativeImageData(copy=False), copy=copy) + alternative = self.getAlternativeImageData(copy=False) + if alternative is not None: + return _convertImageToRgba32(alternative, copy=copy) else: # Apply colormap, in this case an new array is always returned colormap = self.getColormap() image = colormap.applyToData(self.getData(copy=False)) + alphaImage = self.getAlphaData(copy=False) + if alphaImage is not None: + # Apply transparency + image[:, :, 3] = image[:, :, 3] * alphaImage return image def getAlternativeImageData(self, copy=True): """Get the optional RGBA image that is displayed instead of the data - :param copy: True (Default) to get a copy, - False to use internal representation (do not modify!) - :returns: None or numpy.ndarray - :rtype: numpy.ndarray or None + :param bool copy: True (Default) to get a copy, + False to use internal representation (do not modify!) + :rtype: Union[None,numpy.ndarray] """ if self._alternativeImage is None: return None else: return numpy.array(self._alternativeImage, copy=copy) - def setData(self, data, alternative=None, copy=True): + def getAlphaData(self, copy=True): + """Get the optional transparency image applied on the data + + :param bool copy: True (Default) to get a copy, + False to use internal representation (do not modify!) + :rtype: Union[None,numpy.ndarray] + """ + if self.__alpha is None: + return None + else: + return numpy.array(self.__alpha, copy=copy) + + def setData(self, data, alternative=None, alpha=None, copy=True): """"Set the image data and optionally an alternative RGB(A) representation :param numpy.ndarray data: Data array with 2 dimensions (h, w) :param alternative: RGB(A) image to display instead of data, shape: (h, w, 3 or 4) - :type alternative: None or numpy.ndarray + :type alternative: Union[None,numpy.ndarray] + :param alpha: An array of transparency value in [0, 1] to use for + display with shape: (h, w) + :type alpha: Union[None,numpy.ndarray] :param bool copy: True (Default) to get a copy, False to use internal representation (do not modify!) """ @@ -346,6 +377,15 @@ class ImageData(ImageBase, ColormapMixIn): assert alternative.shape[:2] == data.shape[:2] self._alternativeImage = alternative + if alpha is not None: + alpha = numpy.array(alpha, copy=copy) + assert alpha.shape == data.shape + if alpha.dtype.kind != 'f': + alpha = alpha.astype(numpy.float32) + if numpy.any(numpy.logical_or(alpha < 0., alpha > 1.)): + alpha = numpy.clip(alpha, 0., 1.) + self.__alpha = alpha + # TODO hackish data range implementation if self.isVisible(): plot = self.getPlot() diff --git a/silx/gui/plot/items/marker.py b/silx/gui/plot/items/marker.py index 09767a5..80ca0b6 100644 --- a/silx/gui/plot/items/marker.py +++ b/silx/gui/plot/items/marker.py @@ -1,7 +1,7 @@ # coding: utf-8 # /*########################################################################## # -# Copyright (c) 2017-2018 European Synchrotron Radiation Facility +# Copyright (c) 2017-2019 European Synchrotron Radiation Facility # # Permission is hereby granted, free of charge, to any person obtaining a copy # of this software and associated documentation files (the "Software"), to deal @@ -32,6 +32,7 @@ __date__ = "06/03/2017" import logging +from ....utils.proxy import docstring from .core import (Item, DraggableMixIn, ColorMixIn, LineMixIn, SymbolMixIn, ItemChangedType) @@ -39,7 +40,7 @@ from .core import (Item, DraggableMixIn, ColorMixIn, LineMixIn, SymbolMixIn, _logger = logging.getLogger(__name__) -class _BaseMarker(Item, DraggableMixIn, ColorMixIn): +class MarkerBase(Item, DraggableMixIn, ColorMixIn): """Base class for markers""" _DEFAULT_COLOR = (0., 0., 0., 1.) @@ -75,6 +76,10 @@ class _BaseMarker(Item, DraggableMixIn, ColorMixIn): """Update backend renderer""" raise NotImplementedError() + @docstring(DraggableMixIn) + def drag(self, from_, to): + self.setPosition(to[0], to[1]) + def isOverlay(self): """Return true if marker is drawn as an overlay. @@ -166,14 +171,14 @@ class _BaseMarker(Item, DraggableMixIn, ColorMixIn): return args -class Marker(_BaseMarker, SymbolMixIn): +class Marker(MarkerBase, SymbolMixIn): """Description of a marker""" _DEFAULT_SYMBOL = '+' """Default symbol of the marker""" def __init__(self): - _BaseMarker.__init__(self) + MarkerBase.__init__(self) SymbolMixIn.__init__(self) self._x = 0. @@ -204,11 +209,11 @@ class Marker(_BaseMarker, SymbolMixIn): return x, self.getYPosition() -class _LineMarker(_BaseMarker, LineMixIn): +class _LineMarker(MarkerBase, LineMixIn): """Base class for line markers""" def __init__(self): - _BaseMarker.__init__(self) + MarkerBase.__init__(self) LineMixIn.__init__(self) def _addBackendRenderer(self, backend): diff --git a/silx/gui/plot/items/roi.py b/silx/gui/plot/items/roi.py index 0169439..65831be 100644 --- a/silx/gui/plot/items/roi.py +++ b/silx/gui/plot/items/roi.py @@ -73,6 +73,7 @@ class RegionOfInterest(qt.QObject): self._label = '' self._labelItem = None self._editable = False + self._visible = True def __del__(self): # Clean-up plot items @@ -176,6 +177,34 @@ class RegionOfInterest(qt.QObject): # This can be avoided once marker.setDraggable is public self._createPlotItems() + def isVisible(self): + """Returns whether the ROI is visible in the plot. + + .. note:: + This does not take into account whether or not the plot + widget itself is visible (unlike :meth:`QWidget.isVisible` which + checks the visibility of all its parent widgets up to the window) + + :rtype: bool + """ + return self._visible + + def setVisible(self, visible): + """Set whether the plot items associated with this ROI are + visible in the plot. + + :param bool visible: True to show the ROI in the plot, False to + hide it. + """ + visible = bool(visible) + if self._visible == visible: + return + self._visible = visible + if self._labelItem is not None: + self._labelItem.setVisible(visible) + for item in self._items + self._editAnchors: + item.setVisible(visible) + def _getControlPoints(self): """Returns the current ROI control points. @@ -292,12 +321,14 @@ class RegionOfInterest(qt.QObject): if self._labelItem is not None: self._labelItem._setLegend(legendPrefix + "label") plot._add(self._labelItem) + self._labelItem.setVisible(self.isVisible()) self._items = WeakList() plotItems = self._createShapeItems(controlPoints) for item in plotItems: item._setLegend(legendPrefix + str(itemIndex)) plot._add(item) + item.setVisible(self.isVisible()) self._items.append(item) itemIndex += 1 @@ -309,6 +340,7 @@ class RegionOfInterest(qt.QObject): for index, item in enumerate(plotItems): item._setLegend(legendPrefix + str(itemIndex)) item.setColor(color) + item.setVisible(self.isVisible()) plot._add(item) item.sigItemChanged.connect(functools.partial( self._controlPointAnchorChanged, index)) @@ -512,10 +544,10 @@ class LineROI(RegionOfInterest, items.LineMixIn): return controlPoints def setEndPoints(self, startPoint, endPoint): - """Set this line location using the endding points + """Set this line location using the ending points :param numpy.ndarray startPoint: Staring bounding point of the line - :param numpy.ndarray endPoint: Endding bounding point of the line + :param numpy.ndarray endPoint: Ending bounding point of the line """ assert(startPoint.shape == (2,) and endPoint.shape == (2,)) shapePoints = numpy.array([startPoint, endPoint]) @@ -1261,13 +1293,13 @@ class ArcROI(RegionOfInterest, items.LineMixIn): def getGeometry(self): """Returns a tuple containing the geometry of this ROI - It is a symetric fonction of :meth:`setGeometry`. + It is a symmetric function of :meth:`setGeometry`. If `startAngle` is smaller than `endAngle` the rotation is clockwise, else the rotation is anticlockwise. :rtype: Tuple[numpy.ndarray,float,float,float,float] - :raise ValueError: In case the ROI can't be representaed as section of + :raise ValueError: In case the ROI can't be represented as section of a circle """ geometry = self._getInternalGeometry() diff --git a/silx/gui/plot/items/scatter.py b/silx/gui/plot/items/scatter.py index 707dd3d..b2f087b 100644 --- a/silx/gui/plot/items/scatter.py +++ b/silx/gui/plot/items/scatter.py @@ -31,26 +31,79 @@ __date__ = "29/03/2017" import logging - +import threading import numpy -from .core import Points, ColormapMixIn +from collections import defaultdict +from concurrent.futures import ThreadPoolExecutor, CancelledError + +from ....utils.weakref import WeakList +from .._utils.delaunay import delaunay +from .core import PointsBase, ColormapMixIn, ScatterVisualizationMixIn +from .axis import Axis _logger = logging.getLogger(__name__) -class Scatter(Points, ColormapMixIn): +class _GreedyThreadPoolExecutor(ThreadPoolExecutor): + """:class:`ThreadPoolExecutor` with an extra :meth:`submit_greedy` method. + """ + + def __init__(self, *args, **kwargs): + super(_GreedyThreadPoolExecutor, self).__init__(*args, **kwargs) + self.__futures = defaultdict(WeakList) + self.__lock = threading.RLock() + + def submit_greedy(self, queue, fn, *args, **kwargs): + """Same as :meth:`submit` but cancel previous tasks in given queue. + + This means that when a new task is submitted for a given queue, + all other pending tasks of that queue are cancelled. + + :param queue: Identifier of the queue. This must be hashable. + :param callable fn: The callable to call with provided extra arguments + :return: Future corresponding to this task + :rtype: concurrent.futures.Future + """ + with self.__lock: + # Cancel previous tasks in given queue + for future in self.__futures.pop(queue, []): + if not future.done(): + future.cancel() + + future = super(_GreedyThreadPoolExecutor, self).submit( + fn, *args, **kwargs) + self.__futures[queue].append(future) + + return future + + +class Scatter(PointsBase, ColormapMixIn, ScatterVisualizationMixIn): """Description of a scatter""" _DEFAULT_SELECTABLE = True """Default selectable state for scatter plots""" + _SUPPORTED_SCATTER_VISUALIZATION = ( + ScatterVisualizationMixIn.Visualization.POINTS, + ScatterVisualizationMixIn.Visualization.SOLID) + """Overrides supported Visualizations""" + def __init__(self): - Points.__init__(self) + PointsBase.__init__(self) ColormapMixIn.__init__(self) + ScatterVisualizationMixIn.__init__(self) self._value = () self.__alpha = None + # Cache Delaunay triangulation future object + self.__delaunayFuture = None + # Cache interpolator future object + self.__interpolatorFuture = None + self.__executor = None + + # Cache triangles: x, y, indices + self.__cacheTriangles = None, None, None def _addBackendRenderer(self, backend): """Update backend renderer""" @@ -58,28 +111,154 @@ class Scatter(Points, ColormapMixIn): xFiltered, yFiltered, valueFiltered, xerror, yerror = self.getData( copy=False, displayed=True) + # Remove not finite numbers (this includes filtered out x, y <= 0) + mask = numpy.logical_and(numpy.isfinite(xFiltered), numpy.isfinite(yFiltered)) + xFiltered = xFiltered[mask] + yFiltered = yFiltered[mask] + if len(xFiltered) == 0: return None # No data to display, do not add renderer to backend + # Compute colors cmap = self.getColormap() rgbacolors = cmap.applyToData(self._value) if self.__alpha is not None: rgbacolors[:, -1] = (rgbacolors[:, -1] * self.__alpha).astype(numpy.uint8) - return backend.addCurve(xFiltered, yFiltered, self.getLegend(), - color=rgbacolors, - symbol=self.getSymbol(), - linewidth=0, - linestyle="", - yaxis='left', - xerror=xerror, - yerror=yerror, - z=self.getZValue(), - selectable=self.isSelectable(), - fill=False, - alpha=self.getAlpha(), - symbolsize=self.getSymbolSize()) + # Apply mask to colors + rgbacolors = rgbacolors[mask] + + if self.getVisualization() is self.Visualization.POINTS: + return backend.addCurve(xFiltered, yFiltered, self.getLegend(), + color=rgbacolors, + symbol=self.getSymbol(), + linewidth=0, + linestyle="", + yaxis='left', + xerror=xerror, + yerror=yerror, + z=self.getZValue(), + selectable=self.isSelectable(), + fill=False, + alpha=self.getAlpha(), + symbolsize=self.getSymbolSize()) + + else: # 'solid' + plot = self.getPlot() + if (plot is None or + plot.getXAxis().getScale() != Axis.LINEAR or + plot.getYAxis().getScale() != Axis.LINEAR): + # Solid visualization is not available with log scaled axes + return None + + triangulation = self._getDelaunay().result() + if triangulation is None: + return None + else: + triangles = triangulation.simplices.astype(numpy.int32) + return backend.addTriangles(xFiltered, + yFiltered, + triangles, + legend=self.getLegend(), + color=rgbacolors, + z=self.getZValue(), + selectable=self.isSelectable(), + alpha=self.getAlpha()) + + def __getExecutor(self): + """Returns async greedy executor + + :rtype: _GreedyThreadPoolExecutor + """ + if self.__executor is None: + self.__executor = _GreedyThreadPoolExecutor(max_workers=2) + return self.__executor + + def _getDelaunay(self): + """Returns a :class:`Future` which result is the Delaunay object. + + :rtype: concurrent.futures.Future + """ + if self.__delaunayFuture is None or self.__delaunayFuture.cancelled(): + # Need to init a new delaunay + x, y = self.getData(copy=False)[:2] + # Remove not finite points + mask = numpy.logical_and(numpy.isfinite(x), numpy.isfinite(y)) + + self.__delaunayFuture = self.__getExecutor().submit_greedy( + 'delaunay', delaunay, x[mask], y[mask]) + + return self.__delaunayFuture + + @staticmethod + def __initInterpolator(delaunayFuture, values): + """Returns an interpolator for the given data points + + :param concurrent.futures.Future delaunayFuture: + Future object which result is a Delaunay object + :param numpy.ndarray values: The data value of valid points. + :rtype: Union[callable,None] + """ + # Wait for Delaunay to complete + try: + triangulation = delaunayFuture.result() + except CancelledError: + triangulation = None + + if triangulation is None: + interpolator = None # Error case + else: + # Lazy-loading of interpolator + try: + from scipy.interpolate import LinearNDInterpolator + except ImportError: + LinearNDInterpolator = None + + if LinearNDInterpolator is not None: + interpolator = LinearNDInterpolator(triangulation, values) + + # First call takes a while, do it here + interpolator([(0., 0.)]) + + else: + # Fallback using matplotlib interpolator + import matplotlib.tri + + x, y = triangulation.points.T + tri = matplotlib.tri.Triangulation( + x, y, triangles=triangulation.simplices) + mplInterpolator = matplotlib.tri.LinearTriInterpolator( + tri, values) + + # Wrap interpolator to have same API as scipy's one + def interpolator(points): + return mplInterpolator(*points.T) + + return interpolator + + def _getInterpolator(self): + """Returns a :class:`Future` which result is the interpolator. + + The interpolator is a callable taking an array Nx2 of points + as a single argument. + The :class:`Future` result is None in case the interpolator cannot + be initialized. + + :rtype: concurrent.futures.Future + """ + if (self.__interpolatorFuture is None or + self.__interpolatorFuture.cancelled()): + # Need to init a new interpolator + x, y, values = self.getData(copy=False)[:3] + # Remove not finite points + mask = numpy.logical_and(numpy.isfinite(x), numpy.isfinite(y)) + x, y, values = x[mask], y[mask], values[mask] + + self.__interpolatorFuture = self.__getExecutor().submit_greedy( + 'interpolator', + self.__initInterpolator, self._getDelaunay(), values) + return self.__interpolatorFuture def _logFilterData(self, xPositive, yPositive): """Filter out values with x or y <= 0 on log axes @@ -89,7 +268,7 @@ class Scatter(Points, ColormapMixIn): :return: The filtered arrays or unchanged object if not filtering needed :rtype: (x, y, value, xerror, yerror) """ - # overloaded from Points to filter also value. + # overloaded from PointsBase to filter also value. value = self.getValueData(copy=False) if xPositive or yPositive: @@ -100,7 +279,7 @@ class Scatter(Points, ColormapMixIn): value = numpy.array(value, copy=True, dtype=numpy.float) value[clipped] = numpy.nan - x, y, xerror, yerror = Points._logFilterData(self, xPositive, yPositive) + x, y, xerror, yerror = PointsBase._logFilterData(self, xPositive, yPositive) return x, y, value, xerror, yerror @@ -146,7 +325,7 @@ class Scatter(Points, ColormapMixIn): self.getXErrorData(copy), self.getYErrorData(copy)) - # reimplemented from Points to handle `value` + # reimplemented from PointsBase to handle `value` def setData(self, x, y, value, xerror=None, yerror=None, alpha=None, copy=True): """Set the data of the scatter. @@ -171,6 +350,14 @@ class Scatter(Points, ColormapMixIn): assert value.ndim == 1 assert len(x) == len(value) + # Reset triangulation and interpolator + if self.__delaunayFuture is not None: + self.__delaunayFuture.cancel() + self.__delaunayFuture = None + if self.__interpolatorFuture is not None: + self.__interpolatorFuture.cancel() + self.__interpolatorFuture = None + self._value = value if alpha is not None: @@ -183,8 +370,8 @@ class Scatter(Points, ColormapMixIn): if numpy.any(numpy.logical_or(alpha < 0., alpha > 1.)): alpha = numpy.clip(alpha, 0., 1.) self.__alpha = alpha - + # set x, y, xerror, yerror # call self._updated + plot._invalidateDataRange() - Points.setData(self, x, y, xerror, yerror, copy) + PointsBase.setData(self, x, y, xerror, yerror, copy) diff --git a/silx/gui/plot/matplotlib/__init__.py b/silx/gui/plot/matplotlib/__init__.py index a4dc235..7298866 100644 --- a/silx/gui/plot/matplotlib/__init__.py +++ b/silx/gui/plot/matplotlib/__init__.py @@ -1,7 +1,7 @@ # coding: utf-8 # /*########################################################################## # -# Copyright (c) 2016-2017 European Synchrotron Radiation Facility +# Copyright (c) 2016-2019 European Synchrotron Radiation Facility # # Permission is hereby granted, free of charge, to any person obtaining a copy # of this software and associated documentation files (the "Software"), to deal @@ -25,7 +25,7 @@ from __future__ import absolute_import -"""This module inits matplotlib and setups the backend to use. +"""This module initializes matplotlib and sets-up the backend to use. It MUST be imported prior to any other import of matplotlib. @@ -38,64 +38,34 @@ __license__ = "MIT" __date__ = "02/05/2018" -import sys -import logging +from pkg_resources import parse_version +import matplotlib +from ... import qt -_logger = logging.getLogger(__name__) -_matplotlib_already_loaded = 'matplotlib' in sys.modules -"""If true, matplotlib was already loaded""" +def _matplotlib_use(backend, warn, force): + """Wrapper of `matplotlib.use` to set-up backend. -import matplotlib -from ... import qt + It adds extra initialization for PySide and PySide2 with matplotlib < 2.2. + """ + # This is kept for compatibility with matplotlib < 2.2 + if parse_version(matplotlib.__version__) < parse_version('2.2'): + if qt.BINDING == 'PySide': + matplotlib.rcParams['backend.qt4'] = 'PySide' + if qt.BINDING == 'PySide2': + matplotlib.rcParams['backend.qt5'] = 'PySide2' + matplotlib.use(backend, warn=warn, force=force) -def _configure(backend, backend_qt4=None, backend_qt5=None, check=False): - """Configure matplotlib using a specific backend. - It initialize `matplotlib.rcParams` using the requested backend, or check - if it is already configured as requested. +if qt.BINDING in ('PyQt4', 'PySide'): + _matplotlib_use('Qt4Agg', warn=True, force=False) + from matplotlib.backends.backend_qt4agg import FigureCanvasQTAgg # noqa - :param bool check: If true, the function only check that matplotlib - is already initialized as request. If not a warning is emitted. - If `check` is false, matplotlib is initialized. - """ - if check: - valid = matplotlib.rcParams['backend'] == backend - if backend_qt4 is not None: - valid = valid and matplotlib.rcParams['backend.qt4'] == backend_qt4 - if backend_qt5 is not None: - valid = valid and matplotlib.rcParams['backend.qt5'] == backend_qt5 - - if not valid: - _logger.warning('matplotlib already loaded, setting its backend may not work') - else: - matplotlib.rcParams['backend'] = backend - if backend_qt4 is not None: - matplotlib.rcParams['backend.qt4'] = backend_qt4 - if backend_qt5 is not None: - matplotlib.rcParams['backend.qt5'] = backend_qt5 - - -if qt.BINDING == 'PySide': - _configure('Qt4Agg', backend_qt4='PySide', check=_matplotlib_already_loaded) - import matplotlib.backends.backend_qt4agg as backend - -elif qt.BINDING == 'PyQt4': - _configure('Qt4Agg', check=_matplotlib_already_loaded) - import matplotlib.backends.backend_qt4agg as backend - -elif qt.BINDING == 'PySide2': - _configure('Qt5Agg', backend_qt5="PySide2", check=_matplotlib_already_loaded) - import matplotlib.backends.backend_qt5agg as backend - -elif qt.BINDING == 'PyQt5': - _configure('Qt5Agg', check=_matplotlib_already_loaded) - import matplotlib.backends.backend_qt5agg as backend +elif qt.BINDING in ('PyQt5', 'PySide2'): + _matplotlib_use('Qt5Agg', warn=True, force=False) + from matplotlib.backends.backend_qt5agg import FigureCanvasQTAgg # noqa else: - backend = None - -if backend is not None: - FigureCanvasQTAgg = backend.FigureCanvasQTAgg # noqa + raise ImportError("Unsupported Qt binding: %s" % qt.BINDING) diff --git a/silx/gui/plot/test/testAlphaSlider.py b/silx/gui/plot/test/testAlphaSlider.py index 63de441..01e6969 100644 --- a/silx/gui/plot/test/testAlphaSlider.py +++ b/silx/gui/plot/test/testAlphaSlider.py @@ -1,7 +1,7 @@ # coding: utf-8 # /*########################################################################## # -# Copyright (c) 2017 European Synchrotron Radiation Facility +# Copyright (c) 2017-2019 European Synchrotron Radiation Facility # # Permission is hereby granted, free of charge, to any person obtaining a copy # of this software and associated documentation files (the "Software"), to deal @@ -37,9 +37,6 @@ from silx.gui.utils.testutils import TestCaseQt from silx.gui.plot import PlotWidget from silx.gui.plot import AlphaSlider -# Makes sure a QApplication exists -_qapp = qt.QApplication.instance() or qt.QApplication([]) - class TestActiveImageAlphaSlider(TestCaseQt): def setUp(self): diff --git a/silx/gui/plot/test/testComplexImageView.py b/silx/gui/plot/test/testComplexImageView.py index 1933a95..051ec4d 100644 --- a/silx/gui/plot/test/testComplexImageView.py +++ b/silx/gui/plot/test/testComplexImageView.py @@ -1,7 +1,7 @@ # coding: utf-8 # /*########################################################################## # -# Copyright (c) 2017 European Synchrotron Radiation Facility +# Copyright (c) 2017-2019 European Synchrotron Radiation Facility # # Permission is hereby granted, free of charge, to any person obtaining a copy # of this software and associated documentation files (the "Software"), to deal @@ -63,10 +63,10 @@ class TestComplexImageView(PlotWidgetTestCase, ParametricTestCase): self.qWait(100) # Test all modes - modes = self.plot.getSupportedVisualizationModes() + modes = self.plot.supportedComplexModes() for mode in modes: with self.subTest(mode=mode): - self.plot.setVisualizationMode(mode) + self.plot.setComplexMode(mode) self.qWait(100) # Test origin and scale API diff --git a/silx/gui/plot/test/testCurvesROIWidget.py b/silx/gui/plot/test/testCurvesROIWidget.py index 5bcabd8..5886456 100644 --- a/silx/gui/plot/test/testCurvesROIWidget.py +++ b/silx/gui/plot/test/testCurvesROIWidget.py @@ -34,11 +34,15 @@ import os.path import unittest from collections import OrderedDict import numpy + from silx.gui import qt +from silx.gui.plot import Plot1D from silx.test.utils import temp_dir from silx.gui.utils.testutils import TestCaseQt, SignalListener from silx.gui.plot import PlotWindow, CurvesROIWidget - +from silx.gui.plot.CurvesROIWidget import ROITable +from silx.gui.utils.testutils import getQToolButtonFromAction +from silx.gui.plot.PlotInteraction import ItemsInteraction _logger = logging.getLogger(__name__) @@ -68,6 +72,18 @@ class TestCurvesROIWidget(TestCaseQt): super(TestCurvesROIWidget, self).tearDown() + def testDummyAPI(self): + """Simple test of the getRois and setRois API""" + roi_neg = CurvesROIWidget.ROI(name='negative', fromdata=-20, + todata=-10, type_='X') + roi_pos = CurvesROIWidget.ROI(name='positive', fromdata=10, + todata=20, type_='X') + + self.widget.roiWidget.setRois((roi_pos, roi_neg)) + + rois_defs = self.widget.roiWidget.getRois() + self.widget.roiWidget.setRois(rois=rois_defs) + def testWithCurves(self): """Plot with curves: test all ROI widget buttons""" for offset in range(2): @@ -301,7 +317,7 @@ class TestCurvesROIWidget(TestCaseQt): self.widget.roiWidget.setRois((roi,)) self.widget.roiWidget.roiTable.setActiveRoi(None) - self.assertTrue(len(self.widget.roiWidget.roiTable.selectedItems()) is 0) + self.assertEqual(len(self.widget.roiWidget.roiTable.selectedItems()), 0) self.widget.roiWidget.setRois((roi,)) self.plot.setActiveCurve(legend='linearCurve') self.widget.calculateROIs() @@ -314,14 +330,128 @@ class TestCurvesROIWidget(TestCaseQt): self.widget.roiWidget.sigROISignal.connect(signalListener.partial()) self.widget.show() self.qapp.processEvents() - self.assertTrue(signalListener.callCount() is 0) + self.assertEqual(signalListener.callCount(), 0) self.assertTrue(self.widget.roiWidget.roiTable.activeRoi is roi) roi.setFrom(0.0) self.qapp.processEvents() - self.assertTrue(signalListener.callCount() is 0) + self.assertEqual(signalListener.callCount(), 0) roi.setFrom(0.3) self.qapp.processEvents() - self.assertTrue(signalListener.callCount() is 1) + self.assertEqual(signalListener.callCount(), 1) + + +class TestRoiWidgetSignals(TestCaseQt): + """Test Signals emitted by the RoiWidgetSignals""" + + def setUp(self): + self.plot = Plot1D() + x = range(20) + y = range(20) + self.plot.addCurve(x, y, legend='curve0') + self.listener = SignalListener() + self.curves_roi_widget = self.plot.getCurvesRoiWidget() + self.curves_roi_widget.sigROISignal.connect(self.listener) + assert self.curves_roi_widget.isVisible() is False + assert self.listener.callCount() == 0 + self.plot.show() + self.qWaitForWindowExposed(self.plot) + + toolButton = getQToolButtonFromAction(self.plot.getRoiAction()) + self.mouseClick(widget=toolButton, button=qt.Qt.LeftButton) + + self.curves_roi_widget.show() + self.qWaitForWindowExposed(self.curves_roi_widget) + + def tearDown(self): + self.plot = None + + def testSigROISignalAddRmRois(self): + """Test SigROISignal when adding and removing ROIS""" + print(self.listener.callCount()) + self.assertEqual(self.listener.callCount(), 1) + self.listener.clear() + + roi1 = CurvesROIWidget.ROI(name='linear', fromdata=0, todata=5) + self.curves_roi_widget.roiTable.addRoi(roi1) + self.assertEqual(self.listener.callCount(), 1) + self.assertTrue(self.listener.arguments()[0][0]['current'] == 'linear') + self.listener.clear() + + roi2 = CurvesROIWidget.ROI(name='linear2', fromdata=0, todata=5) + self.curves_roi_widget.roiTable.addRoi(roi2) + self.assertEqual(self.listener.callCount(), 1) + self.assertTrue(self.listener.arguments()[0][0]['current'] == 'linear2') + self.listener.clear() + + self.curves_roi_widget.roiTable.removeROI(roi2) + self.assertEqual(self.listener.callCount(), 1) + self.assertTrue(self.curves_roi_widget.roiTable.activeRoi == roi1) + self.assertTrue(self.listener.arguments()[0][0]['current'] == 'linear') + self.listener.clear() + + self.curves_roi_widget.roiTable.deleteActiveRoi() + self.assertEqual(self.listener.callCount(), 1) + self.assertTrue(self.curves_roi_widget.roiTable.activeRoi is None) + self.assertTrue(self.listener.arguments()[0][0]['current'] is None) + self.listener.clear() + + self.curves_roi_widget.roiTable.addRoi(roi1) + self.assertEqual(self.listener.callCount(), 1) + self.assertTrue(self.listener.arguments()[0][0]['current'] == 'linear') + self.assertTrue(self.curves_roi_widget.roiTable.activeRoi == roi1) + self.listener.clear() + self.qapp.processEvents() + + self.curves_roi_widget.roiTable.removeROI(roi1) + self.qapp.processEvents() + self.assertEqual(self.listener.callCount(), 1) + self.assertTrue(self.listener.arguments()[0][0]['current'] == 'ICR') + self.listener.clear() + + def testSigROISignalModifyROI(self): + """Test SigROISignal when modifying it""" + self.curves_roi_widget.roiTable.setMiddleROIMarkerFlag(True) + roi1 = CurvesROIWidget.ROI(name='linear', fromdata=2, todata=5) + self.curves_roi_widget.roiTable.addRoi(roi1) + self.curves_roi_widget.roiTable.setActiveRoi(roi1) + + # test modify the roi2 object + self.listener.clear() + roi1.setFrom(0.56) + self.assertEqual(self.listener.callCount(), 1) + self.listener.clear() + roi1.setTo(2.56) + self.assertEqual(self.listener.callCount(), 1) + self.listener.clear() + roi1.setName('linear2') + self.assertEqual(self.listener.callCount(), 1) + self.listener.clear() + roi1.setType('new type') + self.assertEqual(self.listener.callCount(), 1) + + # modify roi limits (from the gui) + roi_marker_handler = self.curves_roi_widget.roiTable._markersHandler.getMarkerHandler(roi1.getID()) + for marker_type in ('min', 'max', 'middle'): + with self.subTest(marker_type=marker_type): + self.listener.clear() + marker = roi_marker_handler.getMarker(marker_type) + self.qapp.processEvents() + items_interaction = ItemsInteraction(plot=self.plot) + x_pix, y_pix = self.plot.dataToPixel(marker.getXPosition(), 1) + items_interaction.beginDrag(x_pix, y_pix) + self.qapp.processEvents() + items_interaction.endDrag(x_pix+10, y_pix) + self.qapp.processEvents() + self.assertEqual(self.listener.callCount(), 1) + + def testSetActiveCurve(self): + """Test sigRoiSignal when set an active curve""" + roi1 = CurvesROIWidget.ROI(name='linear', fromdata=2, todata=5) + self.curves_roi_widget.roiTable.addRoi(roi1) + self.curves_roi_widget.roiTable.setActiveRoi(roi1) + self.listener.clear() + self.plot.setActiveCurve('curve0') + self.assertEqual(self.listener.callCount(), 0) def suite(): diff --git a/silx/gui/plot/test/testItem.py b/silx/gui/plot/test/testItem.py index 993cce7..c864545 100644 --- a/silx/gui/plot/test/testItem.py +++ b/silx/gui/plot/test/testItem.py @@ -1,7 +1,7 @@ # coding: utf-8 # /*########################################################################## # -# Copyright (c) 2017 European Synchrotron Radiation Facility +# Copyright (c) 2017-2019 European Synchrotron Radiation Facility # # Permission is hereby granted, free of charge, to any person obtaining a copy # of this software and associated documentation files (the "Software"), to deal @@ -54,7 +54,7 @@ class TestSigItemChangedSignal(PlotWidgetTestCase): curve.setVisible(True) curve.setZValue(100) - # Test for signals in Points class + # Test for signals in PointsBase class curve.setData(numpy.arange(100), numpy.arange(100)) # SymbolMixIn @@ -194,14 +194,17 @@ class TestSigItemChangedSignal(PlotWidgetTestCase): # ColormapMixIn scatter.getColormap().setName('viridis') - data2 = data + 10 # Test of signals in Scatter class - scatter.setData(data2, data2, data2) + scatter.setData((0, 1, 2), (1, 0, 2), (0, 1, 2)) + + # Visualization mode changed + scatter.setVisualization(scatter.Visualization.SOLID) self.assertEqual(listener.arguments(), [(ItemChangedType.COLORMAP,), - (ItemChangedType.DATA,)]) + (ItemChangedType.DATA,), + (ItemChangedType.VISUALIZATION_MODE,)]) def testShapeChanged(self): """Test sigItemChanged for shape""" diff --git a/silx/gui/plot/test/testPlotWidget.py b/silx/gui/plot/test/testPlotWidget.py index 9d7c093..7449c12 100644 --- a/silx/gui/plot/test/testPlotWidget.py +++ b/silx/gui/plot/test/testPlotWidget.py @@ -386,6 +386,16 @@ class TestPlotImage(PlotWidgetTestCase, ParametricTestCase): self.assertTrue(numpy.all(numpy.equal(retrievedData, data))) self.assertIs(retrievedData.dtype.type, numpy.int8) + def testPlotAlphaImage(self): + """Test with an alpha image layer""" + data = numpy.random.random((10, 10)) + alpha = numpy.linspace(0, 1, 100).reshape(10, 10) + self.plot.addImage(data, legend='image') + image = self.plot.getActiveImage() + image.setData(data, alpha=alpha) + self.qapp.processEvents() + self.assertTrue(numpy.array_equal(alpha, image.getAlphaData())) + class TestPlotCurve(PlotWidgetTestCase): """Basic tests for addCurve.""" @@ -463,7 +473,34 @@ class TestPlotCurve(PlotWidgetTestCase): self.plot.addCurve(self.xData, self.yData, legend="curve 2", replace=False, resetzoom=False, - color=color, symbol='o') + color=color, symbol='o') + + +class TestPlotScatter(PlotWidgetTestCase, ParametricTestCase): + """Basic tests for addScatter""" + + def testScatter(self): + x = numpy.arange(100) + y = numpy.arange(100) + value = numpy.arange(100) + self.plot.addScatter(x, y, value) + self.plot.resetZoom() + + def testScatterVisualization(self): + self.plot.addScatter((0, 1, 2, 3), (2, 0, 2, 1), (0, 1, 2, 3)) + self.plot.resetZoom() + self.qapp.processEvents() + + scatter = self.plot.getItems()[0] + + for visualization in ('solid', + 'points', + scatter.Visualization.SOLID, + scatter.Visualization.POINTS): + with self.subTest(visualization=visualization): + scatter.setVisualization(visualization) + self.qapp.processEvents() + class TestPlotMarker(PlotWidgetTestCase): """Basic tests for add*Marker""" @@ -1524,11 +1561,19 @@ class TestPlotItemLog(PlotWidgetTestCase): def suite(): - testClasses = (TestPlotWidget, TestPlotImage, TestPlotCurve, - TestPlotMarker, TestPlotItem, TestPlotAxes, + testClasses = (TestPlotWidget, + TestPlotImage, + TestPlotCurve, + TestPlotScatter, + TestPlotMarker, + TestPlotItem, + TestPlotAxes, TestPlotActiveCurveImage, - TestPlotEmptyLog, TestPlotCurveLog, TestPlotImageLog, - TestPlotMarkerLog, TestPlotItemLog) + TestPlotEmptyLog, + TestPlotCurveLog, + TestPlotImageLog, + TestPlotMarkerLog, + TestPlotItemLog) test_suite = unittest.TestSuite() diff --git a/silx/gui/plot/test/testPlotWindow.py b/silx/gui/plot/test/testPlotWindow.py index 6d3eb8f..0a7d108 100644 --- a/silx/gui/plot/test/testPlotWindow.py +++ b/silx/gui/plot/test/testPlotWindow.py @@ -1,7 +1,7 @@ # coding: utf-8 # /*########################################################################## # -# Copyright (c) 2016 European Synchrotron Radiation Facility +# Copyright (c) 2016-2019 European Synchrotron Radiation Facility # # Permission is hereby granted, free of charge, to any person obtaining a copy # of this software and associated documentation files (the "Software"), to deal @@ -38,32 +38,6 @@ from silx.gui import qt from silx.gui.plot import PlotWindow -# Test of the docstrings # - -# Makes sure a QApplication exists -_qapp = qt.QApplication.instance() or qt.QApplication([]) - - -def _tearDownQt(docTest): - """Tear down to use for test from docstring. - - Checks that plt widget is displayed - """ - _qapp.processEvents() - for obj in docTest.globs.values(): - if isinstance(obj, PlotWindow): - # Commented out as it takes too long - # qWaitForWindowExposedAndActivate(obj) - obj.setAttribute(qt.Qt.WA_DeleteOnClose) - obj.close() - del obj - - -plotWindowDocTestSuite = doctest.DocTestSuite('silx.gui.plot.PlotWindow', - tearDown=_tearDownQt) -"""Test suite of tests from the module's docstrings.""" - - class TestPlotWindow(TestCaseQt): """Base class for tests of PlotWindow.""" @@ -128,7 +102,6 @@ class TestPlotWindow(TestCaseQt): def suite(): test_suite = unittest.TestSuite() - test_suite.addTest(plotWindowDocTestSuite) test_suite.addTest( unittest.defaultTestLoader.loadTestsFromTestCase(TestPlotWindow)) return test_suite diff --git a/silx/gui/plot/test/testProfile.py b/silx/gui/plot/test/testProfile.py index 847f404..cf40f76 100644 --- a/silx/gui/plot/test/testProfile.py +++ b/silx/gui/plot/test/testProfile.py @@ -1,7 +1,7 @@ # coding: utf-8 # /*########################################################################## # -# Copyright (c) 2016-2017 European Synchrotron Radiation Facility +# Copyright (c) 2016-2019 European Synchrotron Radiation Facility # # Permission is hereby granted, free of charge, to any person obtaining a copy # of this software and associated documentation files (the "Software"), to deal @@ -39,10 +39,6 @@ from silx.gui.plot import PlotWindow, Plot1D, Plot2D, Profile from silx.gui.plot.StackView import StackView -# Makes sure a QApplication exists -_qapp = qt.QApplication.instance() or qt.QApplication([]) - - class TestProfileToolBar(TestCaseQt, ParametricTestCase): """Tests for ProfileToolBar widget.""" diff --git a/silx/gui/plot/test/testStackView.py b/silx/gui/plot/test/testStackView.py index a5f649c..80c85d6 100644 --- a/silx/gui/plot/test/testStackView.py +++ b/silx/gui/plot/test/testStackView.py @@ -1,7 +1,7 @@ # coding: utf-8 # /*########################################################################## # -# Copyright (c) 2016-2018 European Synchrotron Radiation Facility +# Copyright (c) 2016-2019 European Synchrotron Radiation Facility # # Permission is hereby granted, free of charge, to any person obtaining a copy # of this software and associated documentation files (the "Software"), to deal @@ -41,10 +41,6 @@ from silx.gui.plot.StackView import StackViewMainWindow from silx.utils.array_like import ListOfImages -# Makes sure a QApplication exists -_qapp = qt.QApplication.instance() or qt.QApplication([]) - - class TestStackView(TestCaseQt): """Base class for tests of StackView.""" diff --git a/silx/gui/plot/test/testStats.py b/silx/gui/plot/test/testStats.py index 7fbc247..4bc2144 100644 --- a/silx/gui/plot/test/testStats.py +++ b/silx/gui/plot/test/testStats.py @@ -33,8 +33,9 @@ from silx.gui import qt from silx.gui.plot.stats import stats from silx.gui.plot import StatsWidget from silx.gui.plot.stats import statshandler -from silx.gui.utils.testutils import TestCaseQt +from silx.gui.utils.testutils import TestCaseQt, SignalListener from silx.gui.plot import Plot1D, Plot2D +from silx.utils.testutils import ParametricTestCase import unittest import logging import numpy @@ -350,7 +351,7 @@ class TestStatsHandler(unittest.TestCase): statshandler.StatsHandler(('name')) -class TestStatsWidgetWithCurves(TestCaseQt): +class TestStatsWidgetWithCurves(TestCaseQt, ParametricTestCase): """Basic test for StatsWidget with curves""" def setUp(self): TestCaseQt.setUp(self) @@ -363,7 +364,8 @@ class TestStatsWidgetWithCurves(TestCaseQt): self.plot.addCurve(x, y, legend='curve1') y = range(-2, 18) self.plot.addCurve(x, y, legend='curve2') - self.widget = StatsWidget.StatsTable(plot=self.plot) + self.widget = StatsWidget.StatsWidget(plot=self.plot) + self.statsTable = self.widget._statsTable mystats = statshandler.StatsHandler(( stats.StatMin(), @@ -376,67 +378,170 @@ class TestStatsWidgetWithCurves(TestCaseQt): stats.StatCOM() )) - self.widget.setStats(mystats) + self.statsTable.setStats(mystats) def tearDown(self): self.plot.setAttribute(qt.Qt.WA_DeleteOnClose) self.plot.close() + self.statsTable = None self.widget.setAttribute(qt.Qt.WA_DeleteOnClose) self.widget.close() self.widget = None self.plot = None TestCaseQt.tearDown(self) + def testDisplayActiveItemsSyncOptions(self): + """ + Test that the several option of the sync options are well + synchronized between the different object""" + widget = StatsWidget.StatsWidget(plot=self.plot) + table = StatsWidget.StatsTable(plot=self.plot) + + def check_display_only_active_item(only_active): + # check internal value + self.assertTrue(widget._statsTable._displayOnlyActItem is only_active) + # self.assertTrue(table._displayOnlyActItem is only_active) + # check gui display + self.assertTrue(widget._options.isActiveItemMode() is only_active) + + for displayOnlyActiveItems in (True, False): + with self.subTest(displayOnlyActiveItems=displayOnlyActiveItems): + widget.setDisplayOnlyActiveItem(displayOnlyActiveItems) + # table.setDisplayOnlyActiveItem(displayOnlyActiveItems) + check_display_only_active_item(displayOnlyActiveItems) + + check_display_only_active_item(only_active=False) + widget.setAttribute(qt.Qt.WA_DeleteOnClose) + table.setAttribute(qt.Qt.WA_DeleteOnClose) + widget.close() + table.close() + def testInit(self): """Make sure all the curves are registred on initialization""" - self.assertEqual(self.widget.rowCount(), 3) + self.assertEqual(self.statsTable.rowCount(), 3) def testRemoveCurve(self): """Make sure the Curves stats take into account the curve removal from plot""" self.plot.removeCurve('curve2') - self.assertEqual(self.widget.rowCount(), 2) + self.assertEqual(self.statsTable.rowCount(), 2) for iRow in range(2): - self.assertTrue(self.widget.item(iRow, 0).text() in ('curve0', 'curve1')) + self.assertTrue(self.statsTable.item(iRow, 0).text() in ('curve0', 'curve1')) self.plot.removeCurve('curve0') - self.assertEqual(self.widget.rowCount(), 1) + self.assertEqual(self.statsTable.rowCount(), 1) self.plot.removeCurve('curve1') - self.assertEqual(self.widget.rowCount(), 0) + self.assertEqual(self.statsTable.rowCount(), 0) def testAddCurve(self): """Make sure the Curves stats take into account the add curve action""" self.plot.addCurve(legend='curve3', x=range(10), y=range(10)) - self.assertEqual(self.widget.rowCount(), 4) + self.assertEqual(self.statsTable.rowCount(), 4) def testUpdateCurveFromAddCurve(self): """Make sure the stats of the cuve will be removed after updating a curve""" self.plot.addCurve(legend='curve0', x=range(10), y=range(10)) self.qapp.processEvents() - self.assertEqual(self.widget.rowCount(), 3) + self.assertEqual(self.statsTable.rowCount(), 3) curve = self.plot._getItem(kind='curve', legend='curve0') - tableItems = self.widget._itemToTableItems(curve) + tableItems = self.statsTable._itemToTableItems(curve) self.assertEqual(tableItems['max'].text(), '9') def testUpdateCurveFromCurveObj(self): self.plot.getCurve('curve0').setData(x=range(4), y=range(4)) self.qapp.processEvents() - self.assertEqual(self.widget.rowCount(), 3) + self.assertEqual(self.statsTable.rowCount(), 3) curve = self.plot._getItem(kind='curve', legend='curve0') - tableItems = self.widget._itemToTableItems(curve) + tableItems = self.statsTable._itemToTableItems(curve) self.assertEqual(tableItems['max'].text(), '3') def testSetAnotherPlot(self): plot2 = Plot1D() plot2.addCurve(x=range(26), y=range(26), legend='new curve') - self.widget.setPlot(plot2) - self.assertEqual(self.widget.rowCount(), 1) + self.statsTable.setPlot(plot2) + self.assertEqual(self.statsTable.rowCount(), 1) self.qapp.processEvents() plot2.setAttribute(qt.Qt.WA_DeleteOnClose) plot2.close() plot2 = None + def testUpdateMode(self): + """Make sure the update modes are well take into account""" + self.plot.setActiveCurve('curve0') + for display_only_active in (True, False): + with self.subTest(display_only_active=display_only_active): + self.widget.setDisplayOnlyActiveItem(display_only_active) + self.plot.getCurve('curve0').setData(x=range(4), y=range(4)) + self.widget.setUpdateMode(StatsWidget.UpdateMode.AUTO) + update_stats_action = self.widget._options.getUpdateStatsAction() + # test from api + self.assertTrue(self.widget.getUpdateMode() is StatsWidget.UpdateMode.AUTO) + self.widget.show() + # check stats change in auto mode + self.plot.getCurve('curve0').setData(x=range(4), y=range(-1, 3)) + self.qapp.processEvents() + tableItems = self.statsTable._itemToTableItems(self.plot.getCurve('curve0')) + curve0_min = tableItems['min'].text() + print(curve0_min) + self.assertTrue(float(curve0_min) == -1.) + + self.plot.getCurve('curve0').setData(x=range(4), y=range(1, 5)) + self.qapp.processEvents() + tableItems = self.statsTable._itemToTableItems(self.plot.getCurve('curve0')) + curve0_min = tableItems['min'].text() + self.assertTrue(float(curve0_min) == 1.) + + # check stats change in manual mode only if requested + self.widget.setUpdateMode(StatsWidget.UpdateMode.MANUAL) + self.assertTrue(self.widget.getUpdateMode() is StatsWidget.UpdateMode.MANUAL) + + self.plot.getCurve('curve0').setData(x=range(4), y=range(2, 6)) + self.qapp.processEvents() + tableItems = self.statsTable._itemToTableItems(self.plot.getCurve('curve0')) + curve0_min = tableItems['min'].text() + self.assertTrue(float(curve0_min) == 1.) + + update_stats_action.trigger() + tableItems = self.statsTable._itemToTableItems(self.plot.getCurve('curve0')) + curve0_min = tableItems['min'].text() + self.assertTrue(float(curve0_min) == 2.) + + def testItemHidden(self): + """Test if an item is hide, then the associated stats item is also + hide""" + curve0 = self.plot.getCurve('curve0') + curve1 = self.plot.getCurve('curve1') + curve2 = self.plot.getCurve('curve2') + + self.plot.show() + self.widget.show() + self.qWaitForWindowExposed(self.widget) + self.assertFalse(self.statsTable.isRowHidden(0)) + self.assertFalse(self.statsTable.isRowHidden(1)) + self.assertFalse(self.statsTable.isRowHidden(2)) + + curve0.setVisible(False) + self.qapp.processEvents() + self.assertTrue(self.statsTable.isRowHidden(0)) + curve0.setVisible(True) + self.qapp.processEvents() + self.assertFalse(self.statsTable.isRowHidden(0)) + curve1.setVisible(False) + self.qapp.processEvents() + self.assertTrue(self.statsTable.isRowHidden(1)) + tableItems = self.statsTable._itemToTableItems(curve2) + curve2_min = tableItems['min'].text() + self.assertTrue(float(curve2_min) == -2.) + + curve0.setVisible(False) + curve1.setVisible(False) + curve2.setVisible(False) + self.qapp.processEvents() + self.assertTrue(self.statsTable.isRowHidden(0)) + self.assertTrue(self.statsTable.isRowHidden(1)) + self.assertTrue(self.statsTable.isRowHidden(2)) + class TestStatsWidgetWithImages(TestCaseQt): """Basic test for StatsWidget with images""" @@ -487,6 +592,17 @@ class TestStatsWidgetWithImages(TestCaseQt): self.assertEqual(tableItems['coords min'].text(), '0.0, 0.0') self.assertEqual(tableItems['coords max'].text(), '127.0, 127.0') + def testItemHidden(self): + """Test if an item is hide, then the associated stats item is also + hide""" + self.widget.show() + self.plot.show() + self.qWaitForWindowExposed(self.widget) + self.assertFalse(self.widget.isRowHidden(0)) + self.plot.getImage(self.IMAGE_LEGEND).setVisible(False) + self.qapp.processEvents() + self.assertTrue(self.widget.isRowHidden(0)) + class TestStatsWidgetWithScatters(TestCaseQt): @@ -556,13 +672,13 @@ class TestLineWidget(TestCaseQt): self.plot = Plot1D() self.plot.show() - x = range(20) - y = range(20) - self.plot.addCurve(x, y, legend='curve0') - y = range(12, 32) - self.plot.addCurve(x, y, legend='curve1') - y = range(-2, 18) - self.plot.addCurve(x, y, legend='curve2') + self.x = range(20) + self.y0 = range(20) + self.curve0 = self.plot.addCurve(self.x, self.y0, legend='curve0') + self.y1 = range(12, 32) + self.plot.addCurve(self.x, self.y1, legend='curve1') + self.y2 = range(-2, 18) + self.plot.addCurve(self.x, self.y2, legend='curve2') self.widget = StatsWidget.BasicGridStatsWidget(plot=self.plot, kind='curve', stats=mystats) @@ -572,33 +688,112 @@ class TestLineWidget(TestCaseQt): self.plot.setAttribute(qt.Qt.WA_DeleteOnClose) self.plot.close() self.widget.setPlot(None) - self.widget._statQlineEdit.clear() + self.widget._lineStatsWidget._statQlineEdit.clear() self.widget.setAttribute(qt.Qt.WA_DeleteOnClose) self.widget.close() self.widget = None self.plot = None TestCaseQt.tearDown(self) - def test(self): - self.widget.setStatsOnVisibleData(False) + def testProcessing(self): + self.widget._lineStatsWidget.setStatsOnVisibleData(False) self.qapp.processEvents() self.plot.setActiveCurve(legend='curve0') - self.assertTrue(self.widget._statQlineEdit['min'].text() == '0.000') + self.assertTrue(self.widget._lineStatsWidget._statQlineEdit['min'].text() == '0.000') self.plot.setActiveCurve(legend='curve1') - self.assertTrue(self.widget._statQlineEdit['min'].text() == '12.000') + self.assertTrue(self.widget._lineStatsWidget._statQlineEdit['min'].text() == '12.000') self.plot.getXAxis().setLimitsConstraints(minPos=2, maxPos=5) self.widget.setStatsOnVisibleData(True) self.qapp.processEvents() - self.assertTrue(self.widget._statQlineEdit['min'].text() == '14.000') + self.assertTrue(self.widget._lineStatsWidget._statQlineEdit['min'].text() == '14.000') self.plot.setActiveCurve(None) self.assertTrue(self.plot.getActiveCurve() is None) self.widget.setStatsOnVisibleData(False) self.qapp.processEvents() - self.assertFalse(self.widget._statQlineEdit['min'].text() == '14.000') + self.assertFalse(self.widget._lineStatsWidget._statQlineEdit['min'].text() == '14.000') self.widget.setKind('image') self.plot.addImage(numpy.arange(100*100).reshape(100, 100) + 0.312) self.qapp.processEvents() - self.assertTrue(self.widget._statQlineEdit['min'].text() == '0.312') + self.assertTrue(self.widget._lineStatsWidget._statQlineEdit['min'].text() == '0.312') + + def testUpdateMode(self): + """Make sure the update modes are well take into account""" + self.plot.setActiveCurve(self.curve0) + _autoRB = self.widget._options._autoRB + _manualRB = self.widget._options._manualRB + # test from api + self.widget.setUpdateMode(StatsWidget.UpdateMode.AUTO) + self.assertTrue(_autoRB.isChecked()) + self.assertFalse(_manualRB.isChecked()) + + # check stats change in auto mode + curve0_min = self.widget._lineStatsWidget._statQlineEdit['min'].text() + new_y = numpy.array(self.y0) - 2.56 + self.plot.addCurve(x=self.x, y=new_y, legend=self.curve0) + curve0_min2 = self.widget._lineStatsWidget._statQlineEdit['min'].text() + self.assertTrue(curve0_min != curve0_min2) + + # check stats change in manual mode only if requested + self.widget.setUpdateMode(StatsWidget.UpdateMode.MANUAL) + self.assertFalse(_autoRB.isChecked()) + self.assertTrue(_manualRB.isChecked()) + + new_y = numpy.array(self.y0) - 1.2 + self.plot.addCurve(x=self.x, y=new_y, legend=self.curve0) + curve0_min3 = self.widget._lineStatsWidget._statQlineEdit['min'].text() + self.assertTrue(curve0_min3 == curve0_min2) + self.widget._options._updateRequested() + curve0_min3 = self.widget._lineStatsWidget._statQlineEdit['min'].text() + self.assertTrue(curve0_min3 != curve0_min2) + + # test from gui + self.widget.showRadioButtons(True) + self.widget._options._autoRB.toggle() + self.assertTrue(_autoRB.isChecked()) + self.assertFalse(_manualRB.isChecked()) + + self.widget._options._manualRB.toggle() + self.assertFalse(_autoRB.isChecked()) + self.assertTrue(_manualRB.isChecked()) + + +class TestUpdateModeWidget(TestCaseQt): + """Test UpdateModeWidget""" + def setUp(self): + TestCaseQt.setUp(self) + self.widget = StatsWidget.UpdateModeWidget(parent=None) + + def tearDown(self): + self.widget.setAttribute(qt.Qt.WA_DeleteOnClose) + self.widget.close() + self.widget = None + TestCaseQt.tearDown(self) + + def testSignals(self): + """Test the signal emission of the widget""" + self.widget.setUpdateMode(StatsWidget.UpdateMode.AUTO) + modeChangedListener = SignalListener() + manualUpdateListener = SignalListener() + self.widget.sigUpdateModeChanged.connect(modeChangedListener) + self.widget.sigUpdateRequested.connect(manualUpdateListener) + self.widget.setUpdateMode(StatsWidget.UpdateMode.AUTO) + self.assertTrue(self.widget.getUpdateMode() is StatsWidget.UpdateMode.AUTO) + self.assertTrue(modeChangedListener.callCount() is 0) + self.qapp.processEvents() + + self.widget.setUpdateMode(StatsWidget.UpdateMode.MANUAL) + self.assertTrue(self.widget.getUpdateMode() is StatsWidget.UpdateMode.MANUAL) + self.qapp.processEvents() + self.assertTrue(modeChangedListener.callCount() is 1) + self.assertTrue(manualUpdateListener.callCount() is 0) + self.widget._updatePB.click() + self.widget._updatePB.click() + self.assertTrue(manualUpdateListener.callCount() is 2) + + self.widget._autoRB.setChecked(True) + self.assertTrue(modeChangedListener.callCount() is 2) + self.widget._updatePB.click() + self.assertTrue(manualUpdateListener.callCount() is 2) def suite(): @@ -606,7 +801,7 @@ def suite(): for TestClass in (TestStats, TestStatsHandler, TestStatsWidgetWithScatters, TestStatsWidgetWithImages, TestStatsWidgetWithCurves, TestStatsFormatter, TestEmptyStatsWidget, - TestLineWidget): + TestLineWidget, TestUpdateModeWidget): test_suite.addTest( unittest.defaultTestLoader.loadTestsFromTestCase(TestClass)) return test_suite diff --git a/silx/gui/plot/tools/profile/ScatterProfileToolBar.py b/silx/gui/plot/tools/profile/ScatterProfileToolBar.py index fd21515..0d30651 100644 --- a/silx/gui/plot/tools/profile/ScatterProfileToolBar.py +++ b/silx/gui/plot/tools/profile/ScatterProfileToolBar.py @@ -1,7 +1,7 @@ # coding: utf-8 # /*########################################################################## # -# Copyright (c) 2018 European Synchrotron Radiation Facility +# Copyright (c) 2018-2019 European Synchrotron Radiation Facility # # Permission is hereby granted, free of charge, to any person obtaining a copy # of this software and associated documentation files (the "Software"), to deal @@ -31,196 +31,18 @@ __date__ = "28/06/2018" import logging -import threading -import time +import weakref import numpy -try: - from scipy.interpolate import LinearNDInterpolator -except ImportError: - LinearNDInterpolator = None - - # Fallback using local Delaunay and matplotlib interpolator - from silx.third_party.scipy_spatial import Delaunay - import matplotlib.tri - from ._BaseProfileToolBar import _BaseProfileToolBar -from .... import qt from ... import items +from ....utils.concurrent import submitToQtMainThread _logger = logging.getLogger(__name__) -# TODO support log scale - - -class _InterpolatorInitThread(qt.QThread): - """Thread building a scatter interpolator - - This works in greedy mode in that the signal is only emitted - when no other request is pending - """ - - sigInterpolatorReady = qt.Signal(object) - """Signal emitted whenever an interpolator is ready - - It provides a 3-tuple (points, values, interpolator) - """ - - _RUNNING_THREADS_TO_DELETE = [] - """Store reference of no more used threads but still running""" - - def __init__(self): - super(_InterpolatorInitThread, self).__init__() - self._lock = threading.RLock() - self._pendingData = None - self._firstFallbackRun = True - - def discard(self, obj=None): - """Wait for pending thread to complete and delete then - - Connect this to the destroyed signal of widget using this thread - """ - if self.isRunning(): - self.cancel() - self._RUNNING_THREADS_TO_DELETE.append(self) # Keep a reference - self.finished.connect(self.__finished) - - def __finished(self): - """Handle finished signal of threads to delete""" - try: - self._RUNNING_THREADS_TO_DELETE.remove(self) - except ValueError: - _logger.warning('Finished thread no longer in reference list') - - def request(self, points, values): - """Request new initialisation of interpolator - - :param numpy.ndarray points: Point coordinates (N, D) - :param numpy.ndarray values: Values the N points (1D array) - """ - with self._lock: - # Possibly replace already pending data - self._pendingData = points, values - - if not self.isRunning(): - self.start() - - def cancel(self): - """Cancel any running/pending requests""" - with self._lock: - self._pendingData = 'cancelled' - - def run(self): - """Run the init of the scatter interpolator""" - if LinearNDInterpolator is None: - self.run_matplotlib() - else: - self.run_scipy() - - def run_matplotlib(self): - """Run the init of the scatter interpolator""" - if self._firstFallbackRun: - self._firstFallbackRun = False - _logger.warning( - "scipy.spatial.LinearNDInterpolator not available: " - "Scatter plot interpolator initialisation can freeze the GUI.") - - while True: - with self._lock: - data = self._pendingData - self._pendingData = None - - if data in (None, 'cancelled'): - return - - points, values = data - - startTime = time.time() - try: - delaunay = Delaunay(points) - except: - _logger.warning( - "Cannot triangulate scatter data") - else: - with self._lock: - data = self._pendingData - - if data is not None: # Break point - _logger.info('Interpolator discarded after %f s', - time.time() - startTime) - else: - - x, y = points.T - triangulation = matplotlib.tri.Triangulation( - x, y, triangles=delaunay.simplices) - - interpolator = matplotlib.tri.LinearTriInterpolator( - triangulation, values) - - with self._lock: - data = self._pendingData - - if data is not None: - _logger.info('Interpolator discarded after %f s', - time.time() - startTime) - else: - # No other processing requested: emit the signal - _logger.info("Interpolator initialised in %f s", - time.time() - startTime) - - # Wrap interpolator to have same API as scipy's one - def wrapper(points): - return interpolator(*points.T) - - self.sigInterpolatorReady.emit( - (points, values, wrapper)) - - def run_scipy(self): - """Run the init of the scatter interpolator""" - while True: - with self._lock: - data = self._pendingData - self._pendingData = None - - if data in (None, 'cancelled'): - return - - points, values = data - - startTime = time.time() - try: - interpolator = LinearNDInterpolator(points, values) - except: - _logger.warning( - "Cannot initialise scatter profile interpolator") - else: - with self._lock: - data = self._pendingData - - if data is not None: # Break point - _logger.info('Interpolator discarded after %f s', - time.time() - startTime) - else: - # First call takes a while, do it here - interpolator([(0., 0.)]) - - with self._lock: - data = self._pendingData - - if data is not None: - _logger.info('Interpolator discarded after %f s', - time.time() - startTime) - else: - # No other processing requested: emit the signal - _logger.info("Interpolator initialised in %f s", - time.time() - startTime) - self.sigInterpolatorReady.emit( - (points, values, interpolator)) - - class ScatterProfileToolBar(_BaseProfileToolBar): """QToolBar providing scatter plot profiling tools @@ -233,49 +55,13 @@ class ScatterProfileToolBar(_BaseProfileToolBar): super(ScatterProfileToolBar, self).__init__(parent, plot, title) self.__nPoints = 1024 - self.__interpolator = None - self.__interpolatorCache = None # points, values, interpolator - - self.__initThread = _InterpolatorInitThread() - self.destroyed.connect(self.__initThread.discard) - self.__initThread.sigInterpolatorReady.connect( - self.__interpolatorReady) - - roiManager = self._getRoiManager() - if roiManager is None: - _logger.error( - "Error during scatter profile toolbar initialisation") - else: - roiManager.sigInteractiveModeStarted.connect( - self.__interactionStarted) - roiManager.sigInteractiveModeFinished.connect( - self.__interactionFinished) - if roiManager.isStarted(): - self.__interactionStarted(roiManager.getCurrentInteractionModeRoiClass()) - - def __interactionStarted(self, roiClass): - """Handle start of ROI interaction""" - plot = self.getPlotWidget() - if plot is None: - return - - plot.sigActiveScatterChanged.connect(self.__activeScatterChanged) - - scatter = plot._getActiveItem(kind='scatter') - legend = None if scatter is None else scatter.getLegend() - self.__activeScatterChanged(None, legend) + self.__scatterRef = None + self.__futureInterpolator = None - def __interactionFinished(self): - """Handle end of ROI interaction""" plot = self.getPlotWidget() - if plot is None: - return - - plot.sigActiveScatterChanged.disconnect(self.__activeScatterChanged) - - scatter = plot._getActiveItem(kind='scatter') - legend = None if scatter is None else scatter.getLegend() - self.__activeScatterChanged(legend, None) + if plot is not None: + self._setScatterItem(plot._getActiveItem(kind='scatter')) + plot.sigActiveScatterChanged.connect(self.__activeScatterChanged) def __activeScatterChanged(self, previous, legend): """Handle change of active scatter @@ -283,35 +69,37 @@ class ScatterProfileToolBar(_BaseProfileToolBar): :param Union[str,None] previous: :param Union[str,None] legend: """ - self.__initThread.cancel() + plot = self.getPlotWidget() + if plot is None or legend is None: + scatter = None + else: + scatter = plot.getScatter(legend) + self._setScatterItem(scatter) - # Reset interpolator - self.__interpolator = None + def _getScatterItem(self): + """Returns the scatter item currently handled by this tool. - plot = self.getPlotWidget() - if plot is None: - _logger.error("Associated PlotWidget no longer exists") + :rtype: ~silx.gui.plot.items.Scatter + """ + return None if self.__scatterRef is None else self.__scatterRef() + def _setScatterItem(self, scatter): + """Set the scatter tracked by this tool + + :param Union[None,silx.gui.plot.items.Scatter] scatter: + """ + self.__futureInterpolator = None # Reset currently expected future + + previousScatter = self._getScatterItem() + if previousScatter is not None: + previousScatter.sigItemChanged.disconnect( + self.__scatterItemChanged) + + if scatter is None: + self.__scatterRef = None else: - if previous is not None: # Disconnect signal - scatter = plot.getScatter(previous) - if scatter is not None: - scatter.sigItemChanged.disconnect( - self.__scatterItemChanged) - - if legend is not None: - scatter = plot.getScatter(legend) - if scatter is None: - _logger.error("Cannot retrieve active scatter") - - else: - scatter.sigItemChanged.connect(self.__scatterItemChanged) - points = numpy.transpose(numpy.array(( - scatter.getXData(copy=False), - scatter.getYData(copy=False)))) - values = scatter.getValueData(copy=False) - - self.__updateInterpolator(points, values) + self.__scatterRef = weakref.ref(scatter) + scatter.sigItemChanged.connect(self.__scatterItemChanged) # Refresh profile self.updateProfile() @@ -322,49 +110,15 @@ class ScatterProfileToolBar(_BaseProfileToolBar): :param ItemChangedType event: """ if event == items.ItemChangedType.DATA: - self.__interpolator = None - scatter = self.sender() - if scatter is None: - _logger.error("Cannot retrieve updated scatter item") - - else: - points = numpy.transpose(numpy.array(( - scatter.getXData(copy=False), - scatter.getYData(copy=False)))) - values = scatter.getValueData(copy=False) - - self.__updateInterpolator(points, values) - - # Handle interpolator init thread - - def __updateInterpolator(self, points, values): - """Update used interpolator with new data""" - if (self.__interpolatorCache is not None and - len(points) == len(self.__interpolatorCache[0]) and - numpy.all(numpy.equal(self.__interpolatorCache[0], points)) and - numpy.all(numpy.equal(self.__interpolatorCache[1], values))): - # Reuse previous interpolator - _logger.info( - 'Scatter changed: Reuse previous interpolator') - self.__interpolator = self.__interpolatorCache[2] - - else: - # Interpolator needs update: Start background processing - _logger.info( - 'Scatter changed: Rebuild interpolator') - self.__interpolator = None - self.__interpolatorCache = None - self.__initThread.request(points, values) - - def __interpolatorReady(self, data): - """Handle end of init interpolator thread""" - points, values, interpolator = data - self.__interpolator = interpolator - self.__interpolatorCache = None if interpolator is None else data - self.updateProfile() + self.updateProfile() # Refresh profile def hasPendingOperations(self): - return self.__initThread.isRunning() + """Returns True if waiting for an interpolator to be ready + + :rtype: bool + """ + return (self.__futureInterpolator is not None and + not self.__futureInterpolator.done()) # Number of points @@ -383,8 +137,9 @@ class ScatterProfileToolBar(_BaseProfileToolBar): npoints = int(npoints) if npoints < 1: raise ValueError("Unsupported number of points: %d" % npoints) - else: + elif npoints != self.__nPoints: self.__nPoints = npoints + self.updateProfile() # Overridden methods @@ -400,11 +155,16 @@ class ScatterProfileToolBar(_BaseProfileToolBar): """ if self.hasPendingOperations(): return 'Pre-processing data...' - else: return super(ScatterProfileToolBar, self).computeProfileTitle( x0, y0, x1, y1) + def __futureDone(self, future): + """Handle completion of the interpolator creation""" + if future is self.__futureInterpolator: + # Only handle future callbacks for the current one + submitToQtMainThread(self.updateProfile) + def computeProfile(self, x0, y0, x1, y1): """Compute corresponding profile @@ -414,16 +174,32 @@ class ScatterProfileToolBar(_BaseProfileToolBar): :param float y1: Profile end point Y coord :return: (points, values) profile data or None """ - if self.__interpolator is None: + scatter = self._getScatterItem() + if scatter is None or self.hasPendingOperations(): return None - nPoints = self.getNPoints() + # Lazy async request of the interpolator + future = scatter._getInterpolator() + if future is not self.__futureInterpolator: + # First time we request this interpolator + self.__futureInterpolator = future + if not future.done(): + future.add_done_callback(self.__futureDone) + return None + + if future.cancelled() or future.exception() is not None: + return None # Something went wrong + interpolator = future.result() + if interpolator is None: + return None # Cannot init an interpolator + + nPoints = self.getNPoints() points = numpy.transpose(( numpy.linspace(x0, x1, nPoints, endpoint=True), numpy.linspace(y0, y1, nPoints, endpoint=True))) - values = self.__interpolator(points) + values = interpolator(points) if not numpy.any(numpy.isfinite(values)): return None # Profile outside convex hull diff --git a/silx/gui/plot/tools/roi.py b/silx/gui/plot/tools/roi.py index 98295ba..eb933a0 100644 --- a/silx/gui/plot/tools/roi.py +++ b/silx/gui/plot/tools/roi.py @@ -106,6 +106,9 @@ class RegionOfInterestManager(qt.QObject): self._rois = [] # List of ROIs self._drawnROI = None # New ROI being currently drawn + # Handle unique selection of interaction mode action + self._actionGroup = qt.QActionGroup(self) + self._roiClass = None self._color = rgba('red') @@ -158,6 +161,8 @@ class RegionOfInterestManager(qt.QObject): action.setChecked(self.getCurrentInteractionModeRoiClass() is roiClass) action.setToolTip(text) + self._actionGroup.addAction(action) + action.triggered[bool].connect(functools.partial( WeakMethodProxy(self._modeActionTriggered), roiClass=roiClass)) self._modeActions[roiClass] = action @@ -171,9 +176,6 @@ class RegionOfInterestManager(qt.QObject): """ if checked: self.start(roiClass) - else: # Keep action checked - action = self.sender() - action.setChecked(True) def _updateModeActions(self): """Check/Uncheck action corresponding to current mode""" @@ -781,9 +783,9 @@ class RegionOfInterestTableWidget(qt.QTableWidget): super(RegionOfInterestTableWidget, self).__init__(parent) self._roiManagerRef = None - self.setColumnCount(5) - self.setHorizontalHeaderLabels( - ['Label', 'Edit', 'Kind', 'Coordinates', '']) + headers = ['Label', 'Edit', 'Kind', 'Coordinates', ''] + self.setColumnCount(len(headers)) + self.setHorizontalHeaderLabels(headers) horizontalHeader = self.horizontalHeader() horizontalHeader.setDefaultAlignment(qt.Qt.AlignLeft) @@ -815,9 +817,10 @@ class RegionOfInterestTableWidget(qt.QTableWidget): manager = self.getRegionOfInterestManager() roi = manager.getRois()[index] else: - roi = None + return if column == 0: + roi.setVisible(item.checkState() == qt.Qt.Checked) roi.setLabel(item.text()) elif column == 1: roi.setEditable( @@ -884,11 +887,13 @@ class RegionOfInterestTableWidget(qt.QTableWidget): for index, roi in enumerate(rois): baseFlags = qt.Qt.ItemIsSelectable | qt.Qt.ItemIsEnabled - # Label + # Label and visible label = roi.getLabel() item = qt.QTableWidgetItem(label) - item.setFlags(baseFlags | qt.Qt.ItemIsEditable) + item.setFlags(baseFlags | qt.Qt.ItemIsEditable | qt.Qt.ItemIsUserCheckable) item.setData(qt.Qt.UserRole, index) + item.setCheckState( + qt.Qt.Checked if roi.isVisible() else qt.Qt.Unchecked) self.setItem(index, 0, item) # Editable diff --git a/silx/gui/plot/tools/test/testScatterProfileToolBar.py b/silx/gui/plot/tools/test/testScatterProfileToolBar.py index 0f4b668..714746a 100644 --- a/silx/gui/plot/tools/test/testScatterProfileToolBar.py +++ b/silx/gui/plot/tools/test/testScatterProfileToolBar.py @@ -101,6 +101,7 @@ class TestScatterProfileToolBar(TestCaseQt, ParametricTestCase): self.qWait(200) if not self.profile.hasPendingOperations(): break + self.qapp.processEvents() self.assertIsNotNone(self.profile.getProfileValues()) points = self.profile.getProfilePoints() diff --git a/silx/gui/plot/tools/test/testTools.py b/silx/gui/plot/tools/test/testTools.py index f4adda0..70c8105 100644 --- a/silx/gui/plot/tools/test/testTools.py +++ b/silx/gui/plot/tools/test/testTools.py @@ -1,7 +1,7 @@ # coding: utf-8 # /*########################################################################## # -# Copyright (c) 2016-2018 European Synchrotron Radiation Facility +# Copyright (c) 2016-2019 European Synchrotron Radiation Facility # # Permission is hereby granted, free of charge, to any person obtaining a copy # of this software and associated documentation files (the "Software"), to deal @@ -41,34 +41,6 @@ from silx.gui.plot import tools from silx.gui.plot.test.utils import PlotWidgetTestCase -# Makes sure a QApplication exists -_qapp = qt.QApplication.instance() or qt.QApplication([]) - - -def _tearDownDocTest(docTest): - """Tear down to use for test from docstring. - - Checks that plot widget is displayed - """ - plot = docTest.globs['plot'] - qWaitForWindowExposedAndActivate(plot) - plot.setAttribute(qt.Qt.WA_DeleteOnClose) - plot.close() - del plot - -# Disable doctest because of -# "NameError: name 'numpy' is not defined" -# -# import doctest -# positionInfoTestSuite = doctest.DocTestSuite( -# PlotTools, tearDown=_tearDownDocTest, -# optionflags=doctest.ELLIPSIS) -# """Test suite of tests from PlotTools docstrings. -# -# Test PositionInfo and ProfileToolBar docstrings. -# """ - - class TestPositionInfo(PlotWidgetTestCase): """Tests for PositionInfo widget.""" diff --git a/silx/gui/plot/tools/toolbars.py b/silx/gui/plot/tools/toolbars.py index 28fb7f9..04d0cfc 100644 --- a/silx/gui/plot/tools/toolbars.py +++ b/silx/gui/plot/tools/toolbars.py @@ -1,7 +1,7 @@ # coding: utf-8 # /*########################################################################## # -# Copyright (c) 2018 European Synchrotron Radiation Facility +# Copyright (c) 2018-2019 European Synchrotron Radiation Facility # # Permission is hereby granted, free of charge, to any person obtaining a copy # of this software and associated documentation files (the "Software"), to deal @@ -34,6 +34,7 @@ from ... import qt from .. import actions from ..PlotWidget import PlotWidget from .. import PlotToolButtons +from ....utils.deprecation import deprecated class InteractiveModeToolBar(qt.QToolBar): @@ -302,9 +303,9 @@ class ScatterToolBar(qt.QToolBar): parent=self, plot=plot) self.addAction(self._colormapAction) - self._symbolToolButton = PlotToolButtons.SymbolToolButton( - parent=self, plot=plot) - self.addWidget(self._symbolToolButton) + self._visualizationToolButton = \ + PlotToolButtons.ScatterVisualizationToolButton(parent=self, plot=plot) + self.addWidget(self._visualizationToolButton) def getResetZoomAction(self): """Returns the QAction to reset the zoom. @@ -341,16 +342,21 @@ class ScatterToolBar(qt.QToolBar): """ return self._colormapAction - def getSymbolToolButton(self): - """Returns the QToolButton controlling symbol size and marker. - - :rtype: SymbolToolButton - """ - return self._symbolToolButton - def getKeepDataAspectRatioButton(self): """Returns the QToolButton controlling data aspect ratio. :rtype: QToolButton """ return self._keepDataAspectRatioButton + + def getScatterVisualizationToolButton(self): + """Returns the QToolButton controlling the visualization mode. + + :rtype: ScatterVisualizationToolButton + """ + return self._visualizationToolButton + + @deprecated(replacement='getScatterVisualizationToolButton', + since_version='0.11.0') + def getSymbolToolButton(self): + return self.getScatterVisualizationToolButton() diff --git a/silx/gui/plot3d/Plot3DWidget.py b/silx/gui/plot3d/Plot3DWidget.py index eed4438..f512cd8 100644 --- a/silx/gui/plot3d/Plot3DWidget.py +++ b/silx/gui/plot3d/Plot3DWidget.py @@ -1,7 +1,7 @@ # coding: utf-8 # /*########################################################################## # -# Copyright (c) 2015-2018 European Synchrotron Radiation Facility +# Copyright (c) 2015-2019 European Synchrotron Radiation Facility # # Permission is hereby granted, free of charge, to any person obtaining a copy # of this software and associated documentation files (the "Software"), to deal @@ -31,11 +31,14 @@ __license__ = "MIT" __date__ = "24/04/2018" +import enum import logging from silx.gui import qt from silx.gui.colors import rgba from . import actions + +from ...utils.enum import Enum as _Enum from ..utils.image import convertArrayToQImage from .. import _glutils as glu @@ -106,6 +109,22 @@ class Plot3DWidget(glu.OpenGLWidget): It provides the updated property. """ + sigSceneClicked = qt.Signal(float, float) + """Signal emitted when the scene is clicked with the left mouse button. + + It provides the (x, y) clicked mouse position + """ + + @enum.unique + class FogMode(_Enum): + """Different mode to render the scene with fog""" + + NONE = 'none' + """No fog effect""" + + LINEAR = 'linear' + """Linear fog through the whole scene""" + def __init__(self, parent=None, f=qt.Qt.WindowFlags()): self._firstRender = True @@ -146,6 +165,11 @@ class Plot3DWidget(glu.OpenGLWidget): self.eventHandler = None self.setInteractiveMode('rotate') + def __clickHandler(self, *args): + """Handle interaction state machine click""" + x, y = args[0][:2] + self.sigSceneClicked.emit(x, y) + def setInteractiveMode(self, mode): """Set the interactive mode. @@ -163,7 +187,7 @@ class Plot3DWidget(glu.OpenGLWidget): orbitAroundCenter=False, mode='position', scaleTransform=self._sceneScale, - selectCB=None) + selectCB=self.__clickHandler) elif mode == 'pan': self.eventHandler = interaction.PanCameraControl( @@ -171,7 +195,7 @@ class Plot3DWidget(glu.OpenGLWidget): orbitAroundCenter=False, mode='position', scaleTransform=self._sceneScale, - selectCB=None) + selectCB=self.__clickHandler) elif isinstance(mode, interaction.StateMachine): self.eventHandler = mode @@ -244,6 +268,28 @@ class Plot3DWidget(glu.OpenGLWidget): """Returns the RGBA background color (QColor).""" return qt.QColor.fromRgbF(*self.viewport.background) + def setFogMode(self, mode): + """Set the kind of fog to use for the whole scene. + + :param Union[str,FogMode] mode: The mode to use + :raise ValueError: If mode is not supported + """ + mode = self.FogMode.from_value(mode) + if mode != self.getFogMode(): + self.viewport.fog.isOn = mode is self.FogMode.LINEAR + self.sigStyleChanged.emit('fogMode') + + def getFogMode(self): + """Returns the kind of fog in use + + :return: The kind of fog in use + :rtype: FogMode + """ + if self.viewport.fog.isOn: + return self.FogMode.LINEAR + else: + return self.FogMode.NONE + def isOrientationIndicatorVisible(self): """Returns True if the orientation indicator is displayed. diff --git a/silx/gui/plot3d/Plot3DWindow.py b/silx/gui/plot3d/Plot3DWindow.py index 331eca2..470b966 100644 --- a/silx/gui/plot3d/Plot3DWindow.py +++ b/silx/gui/plot3d/Plot3DWindow.py @@ -1,7 +1,7 @@ # coding: utf-8 # /*########################################################################## # -# Copyright (c) 2015-2017 European Synchrotron Radiation Facility +# Copyright (c) 2015-2019 European Synchrotron Radiation Facility # # Permission is hereby granted, free of charge, to any person obtaining a copy # of this software and associated documentation files (the "Software"), to deal @@ -32,6 +32,7 @@ __license__ = "MIT" __date__ = "26/01/2017" +from silx.utils.proxy import docstring from silx.gui import qt from .Plot3DWidget import Plot3DWidget @@ -62,32 +63,26 @@ class Plot3DWindow(qt.QMainWindow): # Proxy to Plot3DWidget + @docstring(Plot3DWidget) def setProjection(self, projection): return self._plot3D.setProjection(projection) - setProjection.__doc__ = Plot3DWidget.setProjection.__doc__ - + @docstring(Plot3DWidget) def getProjection(self): return self._plot3D.getProjection() - getProjection.__doc__ = Plot3DWidget.getProjection.__doc__ - + @docstring(Plot3DWidget) def centerScene(self): return self._plot3D.centerScene() - centerScene.__doc__ = Plot3DWidget.centerScene.__doc__ - + @docstring(Plot3DWidget) def resetZoom(self): return self._plot3D.resetZoom() - resetZoom.__doc__ = Plot3DWidget.resetZoom.__doc__ - + @docstring(Plot3DWidget) def getBackgroundColor(self): return self._plot3D.getBackgroundColor() - getBackgroundColor.__doc__ = Plot3DWidget.getBackgroundColor.__doc__ - + @docstring(Plot3DWidget) def setBackgroundColor(self, color): return self._plot3D.setBackgroundColor(color) - - setBackgroundColor.__doc__ = Plot3DWidget.setBackgroundColor.__doc__ diff --git a/silx/gui/plot3d/SceneWidget.py b/silx/gui/plot3d/SceneWidget.py index e60dcfc..883f5e7 100644 --- a/silx/gui/plot3d/SceneWidget.py +++ b/silx/gui/plot3d/SceneWidget.py @@ -45,7 +45,6 @@ from .scene import interaction from ._model import SceneModel, visitQAbstractItemModel from ._model.items import Item3DRow - __all__ = ['items', 'SceneWidget'] @@ -268,7 +267,7 @@ class SceneSelection(qt.QObject): assert isinstance(parent, SceneWidget) if item.root() != parent.getSceneGroup(): - self.setSelectedItem(None) + self.setCurrentItem(None) # Synchronization with QItemSelectionModel @@ -482,27 +481,37 @@ class SceneWidget(Plot3DWidget): # Add/remove items - def add3DScalarField(self, data, copy=True, index=None): - """Add 3D scalar data volume to :class:`SceneWidget` content. + def addVolume(self, data, copy=True, index=None): + """Add 3D data volume of scalar or complex to :class:`SceneWidget` content. Dataset order is zyx (i.e., first dimension is z). - :param data: 3D array - :type data: 3D numpy.ndarray of float32 with shape at least (2, 2, 2) + :param data: 3D array of complex with shape at least (2, 2, 2) + :type data: numpy.ndarray[Union[numpy.complex64,numpy.float32]] :param bool copy: True (default) to make a copy, False to avoid copy (DO NOT MODIFY data afterwards) :param int index: The index at which to place the item. By default it is appended to the end of the list. - :return: The newly created scalar volume item - :rtype: ~silx.gui.plot3d.items.volume.ScalarField3D + :return: The newly created 3D volume item + :rtype: Union[ScalarField3D,ComplexField3D] """ - volume = items.ScalarField3D() + if data is not None: + data = numpy.array(data, copy=False) + + if numpy.iscomplexobj(data): + volume = items.ComplexField3D() + else: + volume = items.ScalarField3D() volume.setData(data, copy=copy) self.addItem(volume, index) return volume + def add3DScalarField(self, data, copy=True, index=None): + # TODO deprecate in the future + return self.addVolume(data, copy=copy, index=index) + def add3DScatter(self, x, y, z, value, copy=True, index=None): """Add 3D scatter data to :class:`SceneWidget` content. diff --git a/silx/gui/plot3d/SceneWindow.py b/silx/gui/plot3d/SceneWindow.py index 56fb21f..052a4dc 100644 --- a/silx/gui/plot3d/SceneWindow.py +++ b/silx/gui/plot3d/SceneWindow.py @@ -1,7 +1,7 @@ # coding: utf-8 # /*########################################################################## # -# Copyright (c) 2017-2018 European Synchrotron Radiation Facility +# Copyright (c) 2017-2019 European Synchrotron Radiation Facility # # Permission is hereby granted, free of charge, to any person obtaining a copy # of this software and associated documentation files (the "Software"), to deal @@ -33,11 +33,13 @@ __date__ = "29/11/2017" from ...gui import qt, icons +from ...gui.widgets.BoxLayoutDockWidget import BoxLayoutDockWidget from .actions.mode import InteractiveModeAction from .SceneWidget import SceneWidget from .tools import OutputToolBar, InteractiveModeToolBar, ViewpointToolBar from .tools.GroupPropertiesWidget import GroupPropertiesWidget +from .tools.PositionInfoWidget import PositionInfoWidget from .ParamTreeView import ParamTreeView @@ -118,8 +120,19 @@ class SceneWindow(qt.QMainWindow): self._sceneWidget = SceneWidget() self.setCentralWidget(self._sceneWidget) + # Add PositionInfoWidget to display picking info + self._positionInfo = PositionInfoWidget() + self._positionInfo.setSceneWidget(self._sceneWidget) + + dock = BoxLayoutDockWidget() + dock.setWindowTitle("Selection Info") + dock.setWidget(self._positionInfo) + self.addDockWidget(qt.Qt.BottomDockWidgetArea, dock) + self._interactiveModeToolBar = InteractiveModeToolBar(parent=self) panPlaneAction = _PanPlaneAction(self, plot3d=self._sceneWidget) + self._interactiveModeToolBar.addAction( + self._positionInfo.toggleAction()) self._interactiveModeToolBar.addAction(panPlaneAction) self._viewpointToolBar = ViewpointToolBar(parent=self) @@ -197,3 +210,10 @@ class SceneWindow(qt.QMainWindow): :rtype: ~silx.gui.plot3d.tools.OutputToolBar """ return self._outputToolBar + + def getPositionInfoWidget(self): + """Returns the widget displaying selected position information. + + :rtype: ~silx.gui.plot3d.tools.PositionInfoWidget.PositionInfoWidget + """ + return self._positionInfo diff --git a/silx/gui/plot3d/_model/items.py b/silx/gui/plot3d/_model/items.py index 7e58d14..9fe3e51 100644 --- a/silx/gui/plot3d/_model/items.py +++ b/silx/gui/plot3d/_model/items.py @@ -33,6 +33,7 @@ __license__ = "MIT" __date__ = "24/04/2018" +from collections import OrderedDict import functools import logging import weakref @@ -45,6 +46,7 @@ from ...colors import preferredColormaps from ... import qt, icons from .. import items from ..items.volume import Isosurface, CutPlane +from ..Plot3DWidget import Plot3DWidget from .core import AngleDegreeRow, BaseRow, ColorProxyRow, ProxyRow, StaticRow @@ -53,6 +55,76 @@ from .core import AngleDegreeRow, BaseRow, ColorProxyRow, ProxyRow, StaticRow _logger = logging.getLogger(__name__) +class ItemProxyRow(ProxyRow): + """Provides a node to proxy a data accessible through functions. + + It listens on sigItemChanged to trigger the update. + + Warning: Only weak reference are kept on fget and fset. + + :param Item3D item: The item to + :param str name: The name of this node + :param callable fget: A callable returning the data + :param callable fset: + An optional callable setting the data with data as a single argument. + :param events: + An optional event kind or list of event kinds to react upon. + :param callable toModelData: + An optional callable to convert from fget + callable to data returned by the model. + :param callable fromModelData: + An optional callable converting data provided to the model to + data for fset. + :param editorHint: Data to provide as UserRole for editor selection/setup + """ + + def __init__(self, + item, + name='', + fget=None, + fset=None, + events=None, + toModelData=None, + fromModelData=None, + editorHint=None): + super(ItemProxyRow, self).__init__( + name=name, + fget=fget, + fset=fset, + notify=None, + toModelData=toModelData, + fromModelData=fromModelData, + editorHint=editorHint) + + if isinstance(events, (items.ItemChangedType, + items.Item3DChangedType)): + events = (events,) + self.__events = events + item.sigItemChanged.connect(self.__itemChanged) + + def __itemChanged(self, event): + """Handle item changed + + :param Union[ItemChangedType,Item3DChangedType] event: + """ + if self.__events is None or event in self.__events: + self._notified() + + +class ItemColorProxyRow(ColorProxyRow, ItemProxyRow): + """Combines :class:`ColorProxyRow` and :class:`ItemProxyRow`""" + + def __init__(self, *args, **kwargs): + ItemProxyRow.__init__(self, *args, **kwargs) + + +class ItemAngleDegreeRow(AngleDegreeRow, ItemProxyRow): + """Combines :class:`AngleDegreeRow` and :class:`ItemProxyRow`""" + + def __init__(self, *args, **kwargs): + ItemProxyRow.__init__(self, *args, **kwargs) + + class _DirectionalLightProxy(qt.QObject): """Proxy to handle directional light with angles rather than vector. """ @@ -67,8 +139,8 @@ class _DirectionalLightProxy(qt.QObject): super(_DirectionalLightProxy, self).__init__() self._light = light light.addListener(self._directionUpdated) - self._azimuth = 0. - self._altitude = 0. + self._azimuth = 0 + self._altitude = 0 def getAzimuthAngle(self): """Returns the signed angle in the horizontal plane. @@ -76,7 +148,7 @@ class _DirectionalLightProxy(qt.QObject): Unit: degrees. The 0 angle corresponds to the axis perpendicular to the screen. - :rtype: float + :rtype: int """ return self._azimuth @@ -86,15 +158,16 @@ class _DirectionalLightProxy(qt.QObject): Unit: degrees. Range: [-90, +90] - :rtype: float + :rtype: int """ return self._altitude def setAzimuthAngle(self, angle): """Set the horizontal angle. - :param float angle: Angle from -z axis in zx plane in degrees. + :param int angle: Angle from -z axis in zx plane in degrees. """ + angle = int(round(angle)) if angle != self._azimuth: self._azimuth = angle self._updateLight() @@ -103,8 +176,9 @@ class _DirectionalLightProxy(qt.QObject): def setAltitudeAngle(self, angle): """Set the horizontal angle. - :param float angle: Angle from -z axis in zy plane in degrees. + :param int angle: Angle from -z axis in zy plane in degrees. """ + angle = int(round(angle)) if angle != self._altitude: self._altitude = angle self._updateLight() @@ -117,20 +191,21 @@ class _DirectionalLightProxy(qt.QObject): x, y, z = - self._light.direction # Horizontal plane is plane xz - azimuth = numpy.degrees(numpy.arctan2(x, z)) - altitude = numpy.degrees(numpy.pi/2. - numpy.arccos(y)) + azimuth = int(round(numpy.degrees(numpy.arctan2(x, z)))) + altitude = int(round(numpy.degrees(numpy.pi/2. - numpy.arccos(y)))) - if (abs(azimuth - self.getAzimuthAngle()) > 0.01 and - abs(abs(altitude) - 90.) >= 0.001): # Do not update when at zenith + if azimuth != self.getAzimuthAngle(): self.setAzimuthAngle(azimuth) - if abs(altitude - self.getAltitudeAngle()) > 0.01: + if altitude != self.getAltitudeAngle(): self.setAltitudeAngle(altitude) def _updateLight(self): """Update light direction in the scene""" azimuth = numpy.radians(self._azimuth) delta = numpy.pi/2. - numpy.radians(self._altitude) + if delta == 0.: # Avoids zenith position + delta = 0.0001 z = - numpy.sin(delta) * numpy.cos(azimuth) x = - numpy.sin(delta) * numpy.sin(azimuth) y = - numpy.cos(delta) @@ -195,9 +270,18 @@ class Settings(StaticRow): lightDirection = StaticRow(('Light Direction', None), children=(azimuthNode, altitudeNode)) + # Fog + fog = ProxyRow( + name='Fog', + fget=sceneWidget.getFogMode, + fset=sceneWidget.setFogMode, + notify=sceneWidget.sigStyleChanged, + toModelData=lambda mode: mode is Plot3DWidget.FogMode.LINEAR, + fromModelData=lambda mode: Plot3DWidget.FogMode.LINEAR if mode else Plot3DWidget.FogMode.NONE) + # Settings row children = (background, foreground, text, highlight, - axesIndicator, lightDirection) + axesIndicator, lightDirection, fog) super(Settings, self).__init__(('Settings', None), children=children) @@ -208,6 +292,9 @@ class Item3DRow(BaseRow): :param str name: The optional name of the item """ + _EVENTS = items.ItemChangedType.VISIBLE, items.Item3DChangedType.LABEL + """Events for which to update the first column in the tree""" + def __init__(self, item, name=None): self.__name = None if name is None else six.text_type(name) super(Item3DRow, self).__init__() @@ -221,12 +308,11 @@ class Item3DRow(BaseRow): item.sigItemChanged.connect(self._itemChanged) def _itemChanged(self, event): - """Handle visibility change""" - if event in (items.ItemChangedType.VISIBLE, - items.Item3DChangedType.LABEL): + """Handle model update upon change""" + if event in self._EVENTS: model = self.model() if model is not None: - index = self.index(column=1) + index = self.index(column=0) model.dataChanged.emit(index, index) def item(self): @@ -268,7 +354,7 @@ class Item3DRow(BaseRow): return 2 -class DataItem3DBoundingBoxRow(ProxyRow): +class DataItem3DBoundingBoxRow(ItemProxyRow): """Represents :class:`DataItem3D` bounding box visibility :param DataItem3D item: The item for which to display/control bounding box @@ -276,13 +362,14 @@ class DataItem3DBoundingBoxRow(ProxyRow): def __init__(self, item): super(DataItem3DBoundingBoxRow, self).__init__( + item=item, name='Bounding box', fget=item.isBoundingBoxVisible, fset=item.setBoundingBoxVisible, - notify=item.sigItemChanged) + events=items.Item3DChangedType.BOUNDING_BOX_VISIBLE) -class MatrixProxyRow(ProxyRow): +class MatrixProxyRow(ItemProxyRow): """Proxy for a row of a DataItem3D 3x3 matrix transform :param DataItem3D item: @@ -294,10 +381,11 @@ class MatrixProxyRow(ProxyRow): self._index = index super(MatrixProxyRow, self).__init__( + item=item, name='', fget=self._getMatrixRow, fset=self._setMatrixRow, - notify=item.sigItemChanged) + events=items.Item3DChangedType.TRANSFORM) def _getMatrixRow(self): """Returns the matrix row. @@ -344,11 +432,13 @@ class DataItem3DTransformRow(StaticRow): super(DataItem3DTransformRow, self).__init__(('Transform', None)) self._item = weakref.ref(item) - translation = ProxyRow(name='Translation', - fget=item.getTranslation, - fset=self._setTranslation, - notify=item.sigItemChanged, - toModelData=lambda data: qt.QVector3D(*data)) + translation = ItemProxyRow( + item=item, + name='Translation', + fget=item.getTranslation, + fset=self._setTranslation, + events=items.Item3DChangedType.TRANSFORM, + toModelData=lambda data: qt.QVector3D(*data)) self.addRow(translation) # Here to keep a reference @@ -359,51 +449,60 @@ class DataItem3DTransformRow(StaticRow): rotateCenter = StaticRow( ('Center', None), children=( - ProxyRow(name='X axis', - fget=item.getRotationCenter, - fset=self._xSetCenter, - notify=item.sigItemChanged, - toModelData=functools.partial( - self._centerToModelData, index=0), - editorHint=self._ROTATION_CENTER_OPTIONS), - ProxyRow(name='Y axis', - fget=item.getRotationCenter, - fset=self._ySetCenter, - notify=item.sigItemChanged, - toModelData=functools.partial( - self._centerToModelData, index=1), - editorHint=self._ROTATION_CENTER_OPTIONS), - ProxyRow(name='Z axis', - fget=item.getRotationCenter, - fset=self._zSetCenter, - notify=item.sigItemChanged, - toModelData=functools.partial( - self._centerToModelData, index=2), - editorHint=self._ROTATION_CENTER_OPTIONS), + ItemProxyRow(item=item, + name='X axis', + fget=item.getRotationCenter, + fset=self._xSetCenter, + events=items.Item3DChangedType.TRANSFORM, + toModelData=functools.partial( + self._centerToModelData, index=0), + editorHint=self._ROTATION_CENTER_OPTIONS), + ItemProxyRow(item=item, + name='Y axis', + fget=item.getRotationCenter, + fset=self._ySetCenter, + events=items.Item3DChangedType.TRANSFORM, + toModelData=functools.partial( + self._centerToModelData, index=1), + editorHint=self._ROTATION_CENTER_OPTIONS), + ItemProxyRow(item=item, + name='Z axis', + fget=item.getRotationCenter, + fset=self._zSetCenter, + events=items.Item3DChangedType.TRANSFORM, + toModelData=functools.partial( + self._centerToModelData, index=2), + editorHint=self._ROTATION_CENTER_OPTIONS), )) rotate = StaticRow( ('Rotation', None), children=( - AngleDegreeRow(name='Angle', - fget=item.getRotation, - fset=self._setAngle, - notify=item.sigItemChanged, - toModelData=lambda data: data[0]), - ProxyRow(name='Axis', - fget=item.getRotation, - fset=self._setAxis, - notify=item.sigItemChanged, - toModelData=lambda data: qt.QVector3D(*data[1])), + ItemAngleDegreeRow( + item=item, + name='Angle', + fget=item.getRotation, + fset=self._setAngle, + events=items.Item3DChangedType.TRANSFORM, + toModelData=lambda data: data[0]), + ItemProxyRow( + item=item, + name='Axis', + fget=item.getRotation, + fset=self._setAxis, + events=items.Item3DChangedType.TRANSFORM, + toModelData=lambda data: qt.QVector3D(*data[1])), rotateCenter )) self.addRow(rotate) - scale = ProxyRow(name='Scale', - fget=item.getScale, - fset=self._setScale, - notify=item.sigItemChanged, - toModelData=lambda data: qt.QVector3D(*data)) + scale = ItemProxyRow( + item=item, + name='Scale', + fget=item.getScale, + fset=self._setScale, + events=items.Item3DChangedType.TRANSFORM, + toModelData=lambda data: qt.QVector3D(*data)) self.addRow(scale) matrix = StaticRow( @@ -545,7 +644,7 @@ class GroupItemRow(Item3DRow): raise RuntimeError("Model does not correspond to scene content") -class InterpolationRow(ProxyRow): +class InterpolationRow(ItemProxyRow): """Represents :class:`InterpolationMixIn` property. :param Item3D item: Scene item with interpolation property @@ -554,10 +653,11 @@ class InterpolationRow(ProxyRow): def __init__(self, item): modes = [mode.title() for mode in item.INTERPOLATION_MODES] super(InterpolationRow, self).__init__( + item=item, name='Interpolation', fget=item.getInterpolation, fset=item.setInterpolation, - notify=item.sigItemChanged, + events=items.Item3DChangedType.INTERPOLATION, toModelData=lambda mode: mode.title(), fromModelData=lambda mode: mode.lower(), editorHint=modes) @@ -817,7 +917,7 @@ class ColormapRow(_ColormapBaseProxyRow): return super(ColormapRow, self).data(column, role) -class SymbolRow(ProxyRow): +class SymbolRow(ItemProxyRow): """Represents :class:`SymbolMixIn` symbol property. :param Item3D item: Scene item with symbol property @@ -826,14 +926,15 @@ class SymbolRow(ProxyRow): def __init__(self, item): names = [item.getSymbolName(s) for s in item.getSupportedSymbols()] super(SymbolRow, self).__init__( - name='Marker', - fget=item.getSymbolName, - fset=item.setSymbol, - notify=item.sigItemChanged, - editorHint=names) + item=item, + name='Marker', + fget=item.getSymbolName, + fset=item.setSymbol, + events=items.ItemChangedType.SYMBOL, + editorHint=names) -class SymbolSizeRow(ProxyRow): +class SymbolSizeRow(ItemProxyRow): """Represents :class:`SymbolMixIn` symbol size property. :param Item3D item: Scene item with symbol size property @@ -841,25 +942,27 @@ class SymbolSizeRow(ProxyRow): def __init__(self, item): super(SymbolSizeRow, self).__init__( + item=item, name='Marker size', fget=item.getSymbolSize, fset=item.setSymbolSize, - notify=item.sigItemChanged, + events=items.ItemChangedType.SYMBOL_SIZE, editorHint=(1, 20)) # TODO link with OpenGL max point size -class PlaneRow(ProxyRow): - """Represents :class:`PlaneMixIn` property. +class PlaneEquationRow(ItemProxyRow): + """Represents :class:`PlaneMixIn` as plane equation. :param Item3D item: Scene item with plane equation property """ def __init__(self, item): - super(PlaneRow, self).__init__( + super(PlaneEquationRow, self).__init__( + item=item, name='Equation', fget=item.getParameters, fset=item.setParameters, - notify=item.sigItemChanged, + events=items.ItemChangedType.POSITION, toModelData=lambda data: qt.QVector4D(*data), fromModelData=lambda data: (data.x(), data.y(), data.z(), data.w())) self._item = weakref.ref(item) @@ -871,7 +974,99 @@ class PlaneRow(ProxyRow): params = item.getParameters() return ('%gx %+gy %+gz %+g = 0' % (params[0], params[1], params[2], params[3])) - return super(PlaneRow, self).data(column, role) + return super(PlaneEquationRow, self).data(column, role) + + +class PlaneRow(ItemProxyRow): + """Represents :class:`PlaneMixIn` property. + + :param Item3D item: Scene item with plane equation property + """ + + _PLANES = OrderedDict((('Plane 0', (1., 0., 0.)), + ('Plane 1', (0., 1., 0.)), + ('Plane 2', (0., 0., 1.)), + ('-', None))) + """Mapping of plane names to normals""" + + _PLANE_ICONS = {'Plane 0': '3d-plane-normal-x', + 'Plane 1': '3d-plane-normal-y', + 'Plane 2': '3d-plane-normal-z', + '-': '3d-plane'} + """Mapping of plane names to normals""" + + def __init__(self, item): + super(PlaneRow, self).__init__( + item=item, + name='Plane', + fget=self.__getPlaneName, + fset=self.__setPlaneName, + events=items.ItemChangedType.POSITION, + editorHint=tuple(self._PLANES.keys())) + self._item = weakref.ref(item) + self._lastName = None + + self.addRow(PlaneEquationRow(item)) + + def _notified(self, *args, **kwargs): + """Handle notification of modification + + Here only send if plane name actually changed + """ + if self._lastName != self.__getPlaneName(): + super(PlaneRow, self)._notified() + + def __getPlaneName(self): + """Returns name of plane // to axes or '-' + + :rtype: str + """ + item = self._item() + planeNormal = item.getNormal() if item is not None else None + + for name, normal in self._PLANES.items(): + if numpy.array_equal(planeNormal, normal): + return name + return '-' + + def __setPlaneName(self, data): + """Set plane normal according to given plane name + + :param str data: Selected plane name + """ + item = self._item() + if item is not None: + for name, normal in self._PLANES.items(): + if data == name and normal is not None: + item.setNormal(normal) + + def data(self, column, role): + if column == 1 and role == qt.Qt.DecorationRole: + return icons.getQIcon(self._PLANE_ICONS[self.__getPlaneName()]) + data = super(PlaneRow, self).data(column, role) + if column == 1 and role == qt.Qt.DisplayRole: + self._lastName = data + return data + + +class ComplexModeRow(ItemProxyRow): + """Represents :class:`items.ComplexMixIn` symbol property. + + :param Item3D item: Scene item with symbol property + """ + + def __init__(self, item): + names = [m.value.replace('_', ' ').title() + for m in item.supportedComplexModes()] + super(ComplexModeRow, self).__init__( + item=item, + name='Mode', + fget=item.getComplexMode, + fset=item.setComplexMode, + events=items.ItemChangedType.COMPLEX_MODE, + toModelData=lambda data: data.value.replace('_', ' ').title(), + fromModelData=lambda data: data.lower().replace(' ', '_'), + editorHint=names) class RemoveIsosurfaceRow(BaseRow): @@ -923,9 +1118,9 @@ class RemoveIsosurfaceRow(BaseRow): """Handle Delete button clicked""" isosurface = self.isosurface() if isosurface is not None: - scalarField3D = isosurface.parent() - if scalarField3D is not None: - scalarField3D.removeIsosurface(isosurface) + volume = isosurface.parent() + if volume is not None: + volume.removeIsosurface(isosurface) class IsosurfaceRow(Item3DRow): @@ -937,6 +1132,9 @@ class IsosurfaceRow(Item3DRow): _LEVEL_SLIDER_RANGE = 0, 1000 """Range given as editor hint""" + _EVENTS = items.ItemChangedType.VISIBLE, items.ItemChangedType.COLOR + """Events for which to update the first column in the tree""" + def __init__(self, item): super(IsosurfaceRow, self).__init__(item, name=item.getLevel()) @@ -944,24 +1142,27 @@ class IsosurfaceRow(Item3DRow): item.sigItemChanged.connect(self._levelChanged) - self.addRow(ProxyRow( + self.addRow(ItemProxyRow( + item=item, name='Level', fget=self._getValueForLevelSlider, fset=self._setLevelFromSliderValue, - notify=item.sigItemChanged, + events=items.Item3DChangedType.ISO_LEVEL, editorHint=self._LEVEL_SLIDER_RANGE)) - self.addRow(ColorProxyRow( + self.addRow(ItemColorProxyRow( + item=item, name='Color', fget=self._rgbColor, fset=self._setRgbColor, - notify=item.sigItemChanged)) + events=items.ItemChangedType.COLOR)) - self.addRow(ProxyRow( + self.addRow(ItemProxyRow( + item=item, name='Opacity', fget=self._opacity, fset=self._setOpacity, - notify=item.sigItemChanged, + events=items.ItemChangedType.COLOR, editorHint=(0, 255))) self.addRow(RemoveIsosurfaceRow(item)) @@ -973,12 +1174,15 @@ class IsosurfaceRow(Item3DRow): """ item = self.item() if item is not None: - scalarField3D = item.parent() - if scalarField3D is not None: - dataRange = scalarField3D.getDataRange() + volume = item.parent() + if volume is not None: + dataRange = volume.getDataRange() if dataRange is not None: dataMin, dataMax = dataRange[0], dataRange[-1] - offset = (item.getLevel() - dataMin) / (dataMax - dataMin) + if dataMax != dataMin: + offset = (item.getLevel() - dataMin) / (dataMax - dataMin) + else: + offset = 0. sliderMin, sliderMax = self._LEVEL_SLIDER_RANGE value = sliderMin + (sliderMax - sliderMin) * offset @@ -992,9 +1196,9 @@ class IsosurfaceRow(Item3DRow): """ item = self.item() if item is not None: - scalarField3D = item.parent() - if scalarField3D is not None: - dataRange = scalarField3D.getDataRange() + volume = item.parent() + if volume is not None: + dataRange = volume.getDataRange() if dataRange is not None: sliderMin, sliderMax = self._LEVEL_SLIDER_RANGE offset = (value - sliderMin) / (sliderMax - sliderMin) @@ -1082,13 +1286,13 @@ class IsosurfaceRow(Item3DRow): class AddIsosurfaceRow(BaseRow): """Class for Isosurface create button - :param ScalarField3D scalarField3D: - The ScalarField3D item to attach the button to. + :param Union[ScalarField3D,ComplexField3D] volume: + The volume item to attach the button to. """ - def __init__(self, scalarField3D): + def __init__(self, volume): super(AddIsosurfaceRow, self).__init__() - self._scalarField3D = weakref.ref(scalarField3D) + self._volume = weakref.ref(volume) def createEditor(self): """Specific editor factory provided to the model""" @@ -1106,12 +1310,12 @@ class AddIsosurfaceRow(BaseRow): layout.addStretch(1) return editor - def scalarField3D(self): - """Returns the controlled ScalarField3D + def volume(self): + """Returns the controlled volume item - :rtype: ScalarField3D + :rtype: Union[ScalarField3D,ComplexField3D] """ - return self._scalarField3D() + return self._volume() def data(self, column, role): if column == 0 and role == qt.Qt.UserRole: # editor hint @@ -1127,53 +1331,59 @@ class AddIsosurfaceRow(BaseRow): def _addClicked(self): """Handle Delete button clicked""" - scalarField3D = self.scalarField3D() - if scalarField3D is not None: - dataRange = scalarField3D.getDataRange() + volume = self.volume() + if volume is not None: + dataRange = volume.getDataRange() if dataRange is None: dataRange = 0., 1. - scalarField3D.addIsosurface( + volume.addIsosurface( numpy.mean((dataRange[0], dataRange[-1])), '#0000FF') -class ScalarField3DIsoSurfacesRow(StaticRow): +class VolumeIsoSurfacesRow(StaticRow): """Represents :class:`ScalarFieldView`'s isosurfaces - :param ScalarFieldView scalarField3D: ScalarFieldView to control + :param Union[ScalarField3D,ComplexField3D] volume: + Volume item to control """ - def __init__(self, scalarField3D): - super(ScalarField3DIsoSurfacesRow, self).__init__( + def __init__(self, volume): + super(VolumeIsoSurfacesRow, self).__init__( ('Isosurfaces', None)) - self._scalarField3D = weakref.ref(scalarField3D) + self._volume = weakref.ref(volume) - scalarField3D.sigIsosurfaceAdded.connect(self._isosurfaceAdded) - scalarField3D.sigIsosurfaceRemoved.connect(self._isosurfaceRemoved) + volume.sigIsosurfaceAdded.connect(self._isosurfaceAdded) + volume.sigIsosurfaceRemoved.connect(self._isosurfaceRemoved) - for item in scalarField3D.getIsosurfaces(): + if isinstance(volume, items.ComplexMixIn): + self.addRow(ComplexModeRow(volume)) + + for item in volume.getIsosurfaces(): self.addRow(nodeFromItem(item)) - self.addRow(AddIsosurfaceRow(scalarField3D)) + self.addRow(AddIsosurfaceRow(volume)) - def scalarField3D(self): - """Returns the controlled ScalarField3D + def volume(self): + """Returns the controlled volume item - :rtype: ScalarField3D + :rtype: Union[ScalarField3D,ComplexField3D] """ - return self._scalarField3D() + return self._volume() def _isosurfaceAdded(self, item): """Handle isosurface addition :param Isosurface item: added isosurface """ - scalarField3D = self.scalarField3D() - if scalarField3D is None: + volume = self.volume() + if volume is None: return - row = scalarField3D.getIsosurfaces().index(item) + row = volume.getIsosurfaces().index(item) + if isinstance(volume, items.ComplexMixIn): + row += 1 # Offset for the ComplexModeRow self.addRow(nodeFromItem(item), row) def _isosurfaceRemoved(self, item): @@ -1181,13 +1391,13 @@ class ScalarField3DIsoSurfacesRow(StaticRow): :param Isosurface item: removed isosurface """ - scalarField3D = self.scalarField3D() - if scalarField3D is None: + volume = self.volume() + if volume is None: return # Find item for row in self.children(): - if row.item() is item: + if isinstance(row, IsosurfaceRow) and row.item() is item: self.removeRow(row) break # Got it else: @@ -1267,7 +1477,7 @@ class Scatter2DSymbolSizeRow(Scatter2DPropertyMixInRow, SymbolSizeRow): Scatter2DPropertyMixInRow.__init__(self, item, 'symbolSize') -class Scatter2DLineWidth(Scatter2DPropertyMixInRow, ProxyRow): +class Scatter2DLineWidth(Scatter2DPropertyMixInRow, ItemProxyRow): """Specific class for Scatter2D symbol size. It is enabled/disabled according to visualization mode. @@ -1277,12 +1487,13 @@ class Scatter2DLineWidth(Scatter2DPropertyMixInRow, ProxyRow): def __init__(self, item): # TODO link editorHint with OpenGL max line width - ProxyRow.__init__(self, - name='Line width', - fget=item.getLineWidth, - fset=item.setLineWidth, - notify=item.sigItemChanged, - editorHint=(1, 10)) + ItemProxyRow.__init__(self, + item=item, + name='Line width', + fget=item.getLineWidth, + fset=item.setLineWidth, + events=items.ItemChangedType.LINE_WIDTH, + editorHint=(1, 10)) Scatter2DPropertyMixInRow.__init__(self, item, 'lineWidth') @@ -1292,20 +1503,22 @@ def initScatter2DNode(node, item): :param Item3DRow node: The model node to setup :param Scatter2D item: The Scatter2D the node is representing """ - node.addRow(ProxyRow( + node.addRow(ItemProxyRow( + item=item, name='Mode', fget=item.getVisualization, fset=item.setVisualization, - notify=item.sigItemChanged, - editorHint=[m.title() for m in item.supportedVisualizations()], - toModelData=lambda data: data.title(), + events=items.ItemChangedType.VISUALIZATION_MODE, + editorHint=[m.value.title() for m in item.supportedVisualizations()], + toModelData=lambda data: data.value.title(), fromModelData=lambda data: data.lower())) - node.addRow(ProxyRow( + node.addRow(ItemProxyRow( + item=item, name='Height map', fget=item.isHeightMap, fset=item.setHeightMap, - notify=item.sigItemChanged)) + events=items.Item3DChangedType.HEIGHT_MAP)) node.addRow(ColormapRow(item)) @@ -1315,39 +1528,44 @@ def initScatter2DNode(node, item): node.addRow(Scatter2DLineWidth(item)) -def initScalarField3DNode(node, item): - """Specific node init for ScalarField3D +def initVolumeNode(node, item): + """Specific node init for volume items :param Item3DRow node: The model node to setup - :param ScalarField3D item: The ScalarField3D the node is representing + :param Union[ScalarField3D,ComplexField3D] item: + The volume item represented by the node """ node.addRow(nodeFromItem(item.getCutPlanes()[0])) # Add cut plane - node.addRow(ScalarField3DIsoSurfacesRow(item)) + node.addRow(VolumeIsoSurfacesRow(item)) -def initScalarField3DCutPlaneNode(node, item): - """Specific node init for ScalarField3D CutPlane +def initVolumeCutPlaneNode(node, item): + """Specific node init for volume CutPlane :param Item3DRow node: The model node to setup :param CutPlane item: The CutPlane the node is representing """ + if isinstance(item, items.ComplexMixIn): + node.addRow(ComplexModeRow(item)) + node.addRow(PlaneRow(item)) node.addRow(ColormapRow(item)) - node.addRow(ProxyRow( - name='Values<=Min', + node.addRow(ItemProxyRow( + item=item, + name='Show <=Min', fget=item.getDisplayValuesBelowMin, fset=item.setDisplayValuesBelowMin, - notify=item.sigItemChanged)) + events=items.ItemChangedType.ALPHA)) node.addRow(InterpolationRow(item)) NODE_SPECIFIC_INIT = [ # class, init(node, item) (items.Scatter2D, initScatter2DNode), - (items.ScalarField3D, initScalarField3DNode), - (CutPlane, initScalarField3DCutPlaneNode), + (items.ScalarField3D, initVolumeNode), + (CutPlane, initVolumeCutPlaneNode), ] """List of specific node init for different item class""" diff --git a/silx/gui/plot3d/actions/mode.py b/silx/gui/plot3d/actions/mode.py index b591290..ce09b4c 100644 --- a/silx/gui/plot3d/actions/mode.py +++ b/silx/gui/plot3d/actions/mode.py @@ -1,7 +1,7 @@ # coding: utf-8 # /*########################################################################## # -# Copyright (c) 2017-2018 European Synchrotron Radiation Facility +# Copyright (c) 2017-2019 European Synchrotron Radiation Facility # # Permission is hereby granted, free of charge, to any person obtaining a copy # of this software and associated documentation files (the "Software"), to deal @@ -24,7 +24,8 @@ # ###########################################################################*/ """This module provides Plot3DAction related to interaction modes. -It provides QAction to rotate or pan a Plot3DWidget. +It provides QAction to rotate or pan a Plot3DWidget +as well as toggle a picking mode. """ from __future__ import absolute_import, division @@ -36,7 +37,9 @@ __date__ = "06/09/2017" import logging -from silx.gui.icons import getQIcon +from ....utils.proxy import docstring +from ... import qt +from ...icons import getQIcon from .Plot3DAction import Plot3DAction @@ -69,6 +72,7 @@ class InteractiveModeAction(Plot3DAction): plot3d.setInteractiveMode(self._interaction) self.setChecked(True) + @docstring(Plot3DAction) def setPlot3DWidget(self, widget): # Disconnect from previous Plot3DWidget plot3d = self.getPlot3DWidget() @@ -86,9 +90,6 @@ class InteractiveModeAction(Plot3DAction): widget.sigInteractiveModeChanged.connect( self._interactiveModeChanged) - # Reuse docstring from super class - setPlot3DWidget.__doc__ = Plot3DAction.setPlot3DWidget.__doc__ - def _interactiveModeChanged(self): plot3d = self.getPlot3DWidget() if plot3d is None: @@ -127,3 +128,51 @@ class PanAction(InteractiveModeAction): self.setIcon(getQIcon('pan')) self.setText('Pan') self.setToolTip('Pan the view. Press Ctrl to rotate.') + + +class PickingModeAction(Plot3DAction): + """QAction to toggle picking moe on a Plot3DWidget + + :param parent: See :class:`QAction` + :param ~silx.gui.plot3d.Plot3DWidget.Plot3DWidget plot3d: + Plot3DWidget the action is associated with + """ + + sigSceneClicked = qt.Signal(float, float) + """Signal emitted when the scene is clicked with the left mouse button. + + This signal is only emitted when the action is checked. + + It provides the (x, y) clicked mouse position + """ + + def __init__(self, parent, plot3d=None): + super(PickingModeAction, self).__init__(parent, plot3d) + self.setIcon(getQIcon('pointing-hand')) + self.setText('Picking') + self.setToolTip('Toggle picking with left button click') + self.setCheckable(True) + self.triggered[bool].connect(self._triggered) + + def _triggered(self, checked=False): + plot3d = self.getPlot3DWidget() + if plot3d is not None: + if checked: + plot3d.sigSceneClicked.connect(self.sigSceneClicked) + else: + plot3d.sigSceneClicked.disconnect(self.sigSceneClicked) + + @docstring(Plot3DAction) + def setPlot3DWidget(self, widget): + # Disconnect from previous Plot3DWidget + plot3d = self.getPlot3DWidget() + if plot3d is not None and self.isChecked(): + plot3d.sigSceneClicked.disconnect(self.sigSceneClicked) + + super(PickingModeAction, self).setPlot3DWidget(widget) + + # Connect to new Plot3DWidget + if widget is None: + self.setChecked(False) + elif self.isChecked(): + widget.sigSceneClicked.connect(self.sigSceneClicked) diff --git a/silx/gui/plot3d/items/__init__.py b/silx/gui/plot3d/items/__init__.py index 58eee9c..5810618 100644 --- a/silx/gui/plot3d/items/__init__.py +++ b/silx/gui/plot3d/items/__init__.py @@ -34,10 +34,10 @@ __date__ = "15/11/2017" from .core import DataItem3D, Item3D, GroupItem, GroupWithAxesItem # noqa from .core import ItemChangedType, Item3DChangedType # noqa -from .mixins import (ColormapMixIn, InterpolationMixIn, # noqa +from .mixins import (ColormapMixIn, ComplexMixIn, InterpolationMixIn, # noqa PlaneMixIn, SymbolMixIn) # noqa from .clipplane import ClipPlane # noqa from .image import ImageData, ImageRgba # noqa from .mesh import Mesh, ColormapMesh, Box, Cylinder, Hexagon # noqa from .scatter import Scatter2D, Scatter3D # noqa -from .volume import ScalarField3D # noqa +from .volume import ComplexField3D, ScalarField3D # noqa diff --git a/silx/gui/plot3d/items/mesh.py b/silx/gui/plot3d/items/mesh.py index d3f5e38..3577dbf 100644 --- a/silx/gui/plot3d/items/mesh.py +++ b/silx/gui/plot3d/items/mesh.py @@ -35,6 +35,7 @@ __date__ = "17/07/2018" import logging import numpy +from ... import _glutils as glu from ..scene import primitives, utils, function from ..scene.transform import Rotate from .core import DataItem3D, ItemChangedType @@ -168,7 +169,7 @@ class _MeshBase(DataItem3D): _logger.warning("Unsupported draw mode: %s" % mode) return None - trianglesIndices, t, barycentric = utils.segmentTrianglesIntersection( + trianglesIndices, t, barycentric = glu.segmentTrianglesIntersection( rayObject, triangles) if len(trianglesIndices) == 0: @@ -494,7 +495,7 @@ class _CylindricalVolume(DataItem3D): positions = self._mesh.getAttribute('position', copy=False) triangles = positions.reshape(-1, 3, 3) # 'triangle' draw mode - trianglesIndices, t = utils.segmentTrianglesIntersection( + trianglesIndices, t = glu.segmentTrianglesIntersection( rayObject, triangles)[:2] if len(trianglesIndices) == 0: diff --git a/silx/gui/plot3d/items/mixins.py b/silx/gui/plot3d/items/mixins.py index 40b8438..b355627 100644 --- a/silx/gui/plot3d/items/mixins.py +++ b/silx/gui/plot3d/items/mixins.py @@ -38,6 +38,7 @@ from silx.math.combo import min_max from ...plot.items.core import ItemMixInBase from ...plot.items.core import ColormapMixIn as _ColormapMixIn from ...plot.items.core import SymbolMixIn as _SymbolMixIn +from ...plot.items.core import ComplexMixIn as _ComplexMixIn from ...colors import rgba from ..scene import primitives @@ -139,8 +140,9 @@ class ColormapMixIn(_ColormapMixIn): self._dataRange = dataRange - if self.getColormap().isAutoscale(): - self._syncSceneColormap() + colormap = self.getColormap() + if None in (colormap.getVMin(), colormap.getVMax()): + self._colormapChanged() def _getDataRange(self): """Returns the data range as used in the scene for colormap @@ -173,6 +175,18 @@ class ColormapMixIn(_ColormapMixIn): self.__sceneColormap.range_ = range_ +class ComplexMixIn(_ComplexMixIn): + __doc__ = _ComplexMixIn.__doc__ # Reuse docstring + + _SUPPORTED_COMPLEX_MODES = ( + _ComplexMixIn.ComplexMode.REAL, + _ComplexMixIn.ComplexMode.IMAGINARY, + _ComplexMixIn.ComplexMode.ABSOLUTE, + _ComplexMixIn.ComplexMode.PHASE, + _ComplexMixIn.ComplexMode.SQUARE_AMPLITUDE) + """Overrides supported ComplexMode""" + + class SymbolMixIn(_SymbolMixIn): """Mix-in class for symbol and symbolSize properties for Item3D""" diff --git a/silx/gui/plot3d/items/scatter.py b/silx/gui/plot3d/items/scatter.py index b7bcd09..e8ffee1 100644 --- a/silx/gui/plot3d/items/scatter.py +++ b/silx/gui/plot3d/items/scatter.py @@ -31,14 +31,19 @@ __authors__ = ["T. Vincent"] __license__ = "MIT" __date__ = "15/11/2017" -import collections +try: + from collections import abc +except ImportError: # Python2 support + import collections as abc import logging -import sys import numpy from ....utils.deprecation import deprecated +from ... import _glutils as glu +from ...plot._utils.delaunay import delaunay from ..scene import function, primitives, utils +from ...plot.items import ScatterVisualizationMixIn from .core import DataItem3D, Item3DChangedType, ItemChangedType from .mixins import ColormapMixIn, SymbolMixIn from ._pick import PickingResult @@ -213,16 +218,19 @@ class Scatter3D(DataItem3D, ColormapMixIn, SymbolMixIn): return None -class Scatter2D(DataItem3D, ColormapMixIn, SymbolMixIn): +class Scatter2D(DataItem3D, ColormapMixIn, SymbolMixIn, + ScatterVisualizationMixIn): """2D scatter data with settable visualization mode. :param parent: The View widget this item belongs to. """ _VISUALIZATION_PROPERTIES = { - 'points': ('symbol', 'symbolSize'), - 'lines': ('lineWidth',), - 'solid': (), + ScatterVisualizationMixIn.Visualization.POINTS: + ('symbol', 'symbolSize'), + ScatterVisualizationMixIn.Visualization.LINES: + ('lineWidth',), + ScatterVisualizationMixIn.Visualization.SOLID: (), } """Dict {visualization mode: property names used in this mode}""" @@ -230,8 +238,8 @@ class Scatter2D(DataItem3D, ColormapMixIn, SymbolMixIn): DataItem3D.__init__(self, parent=parent) ColormapMixIn.__init__(self) SymbolMixIn.__init__(self) + ScatterVisualizationMixIn.__init__(self) - self._visualizationMode = 'points' self._heightMap = False self._lineWidth = 1. @@ -254,48 +262,14 @@ class Scatter2D(DataItem3D, ColormapMixIn, SymbolMixIn): child.marker = symbol child.setAttribute('size', size, copy=True) - elif event == ItemChangedType.VISIBLE: + elif event is ItemChangedType.VISIBLE: # TODO smart update?, need dirty flags self._updateScene() - super(Scatter2D, self)._updated(event) - - def supportedVisualizations(self): - """Returns the list of supported visualization modes. - - See :meth:`setVisualizationModes` - - :rtype: tuple of str - """ - return tuple(self._VISUALIZATION_PROPERTIES.keys()) - - def setVisualization(self, mode): - """Set the visualization mode of the data. - - Supported visualization modes are: - - - 'points': For scatter plot representation - - 'lines': For Delaunay tessellation-based wireframe representation - - 'solid': For Delaunay tessellation-based solid surface representation - - :param str mode: Mode of representation to use - """ - mode = str(mode) - assert mode in self.supportedVisualizations() - - if mode != self.getVisualization(): - self._visualizationMode = mode + elif event is ItemChangedType.VISUALIZATION_MODE: self._updateScene() - self._updated(ItemChangedType.VISUALIZATION_MODE) - def getVisualization(self): - """Returns the current visualization mode. - - See :meth:`setVisualization` - - :rtype: str - """ - return self._visualizationMode + super(Scatter2D, self)._updated(event) def isPropertyEnabled(self, name, visualization=None): """Returns true if the property is used with visualization mode. @@ -374,7 +348,7 @@ class Scatter2D(DataItem3D, ColormapMixIn, SymbolMixIn): y, copy=copy, dtype=numpy.float32, order='C').reshape(-1) assert len(x) == len(y) - if isinstance(value, collections.Iterable): + if isinstance(value, abc.Iterable): value = numpy.array( value, copy=copy, dtype=numpy.float32, order='C').reshape(-1) assert len(value) == len(x) @@ -503,7 +477,7 @@ class Scatter2D(DataItem3D, ColormapMixIn, SymbolMixIn): trianglesIndices = self._cachedTrianglesIndices.reshape(-1, 3) triangles = points[trianglesIndices, :3] - selectedIndices, t, barycentric = utils.segmentTrianglesIntersection( + selectedIndices, t, barycentric = glu.segmentTrianglesIntersection( rayObject, triangles) closest = numpy.argmax(barycentric, axis=1) @@ -542,14 +516,14 @@ class Scatter2D(DataItem3D, ColormapMixIn, SymbolMixIn): numpy.ones_like(xData))) mode = self.getVisualization() - if mode == 'points': + if mode is self.Visualization.POINTS: # TODO issue with symbol size: using pixel instead of points # Get "corrected" symbol size _, threshold = self._getSceneSymbol() return self._pickPoints( context, points, threshold=max(3., threshold)) - elif mode == 'lines': + elif mode is self.Visualization.LINES: # Picking only at point return self._pickPoints(context, points, threshold=5.) @@ -569,7 +543,7 @@ class Scatter2D(DataItem3D, ColormapMixIn, SymbolMixIn): mode = self.getVisualization() heightMap = self.isHeightMap() - if mode == 'points': + if mode is self.Visualization.POINTS: z = value if heightMap else 0. symbol, size = self._getSceneSymbol() primitive = primitives.Points( @@ -582,35 +556,19 @@ class Scatter2D(DataItem3D, ColormapMixIn, SymbolMixIn): # TODO run delaunay in a thread # Compute lines/triangles indices if not cached if self._cachedTrianglesIndices is None: - coordinates = numpy.array((x, y)).T - - if len(coordinates) > 3: - # Enough points to try a Delaunay tesselation - - # Lazy loading of Delaunay - from silx.third_party.scipy_spatial import Delaunay as _Delaunay - - try: - tri = _Delaunay(coordinates) - except RuntimeError: - _logger.error("Delaunay tesselation failed: %s", - sys.exc_info()[1]) - return None - - self._cachedTrianglesIndices = numpy.ravel( - tri.simplices.astype(numpy.uint32)) - - else: - # 3 or less points: Draw one triangle - self._cachedTrianglesIndices = \ - numpy.arange(3, dtype=numpy.uint32) % len(coordinates) - - if mode == 'lines' and self._cachedLinesIndices is None: + triangulation = delaunay(x, y) + if triangulation is None: + return None + self._cachedTrianglesIndices = numpy.ravel( + triangulation.simplices.astype(numpy.uint32)) + + if (mode is self.Visualization.LINES and + self._cachedLinesIndices is None): # Compute line indices self._cachedLinesIndices = utils.triangleToLineIndices( self._cachedTrianglesIndices, unicity=True) - if mode == 'lines': + if mode is self.Visualization.LINES: indices = self._cachedLinesIndices renderMode = 'lines' else: @@ -627,7 +585,7 @@ class Scatter2D(DataItem3D, ColormapMixIn, SymbolMixIn): # TODO option to enable/disable light, cache normals # TODO smooth surface - if mode == 'solid': + if mode is self.Visualization.SOLID: if heightMap: coordinates = coordinates[indices] if len(value) > 1: diff --git a/silx/gui/plot3d/items/volume.py b/silx/gui/plot3d/items/volume.py index 08ad02a..ae91e82 100644 --- a/silx/gui/plot3d/items/volume.py +++ b/silx/gui/plot3d/items/volume.py @@ -38,13 +38,15 @@ import numpy from silx.math.combo import min_max from silx.math.marchingcubes import MarchingCubes +from ....utils.proxy import docstring +from ... import _glutils as glu from ... import qt from ...colors import rgba from ..scene import cutplane, primitives, transform, utils from .core import BaseNodeItem, Item3D, ItemChangedType, Item3DChangedType -from .mixins import ColormapMixIn, InterpolationMixIn, PlaneMixIn +from .mixins import ColormapMixIn, ComplexMixIn, InterpolationMixIn, PlaneMixIn from ._pick import PickingResult @@ -60,12 +62,13 @@ class CutPlane(Item3D, ColormapMixIn, InterpolationMixIn, PlaneMixIn): def __init__(self, parent): plane = cutplane.CutPlane(normal=(0, 1, 0)) - Item3D.__init__(self, parent=parent) + Item3D.__init__(self, parent=None) ColormapMixIn.__init__(self) InterpolationMixIn.__init__(self) PlaneMixIn.__init__(self, plane=plane) self._dataRange = None + self._data = None self._getScenePrimitive().children = [plane] @@ -73,20 +76,53 @@ class CutPlane(Item3D, ColormapMixIn, InterpolationMixIn, PlaneMixIn): ColormapMixIn._setSceneColormap(self, plane.colormap) InterpolationMixIn._setPrimitive(self, plane) - parent.sigItemChanged.connect(self._parentChanged) + self.setParent(parent) + + def _updateData(self, data, range_): + """Update used dataset + + No copy is made. + + :param Union[numpy.ndarray[float],None] data: The dataset + :param Union[List[float],None] range_: + (min, min positive, max) values + """ + self._data = None if data is None else numpy.array(data, copy=False) + self._getPlane().setData(self._data, copy=False) + + # Store data range info as 3-tuple of values + self._dataRange = range_ + self._setRangeFromData( + None if self._dataRange is None else numpy.array(self._dataRange)) + + self._updated(ItemChangedType.DATA) + + def _syncDataWithParent(self): + """Synchronize this instance data with that of its parent""" + parent = self.parent() + if parent is None: + data, range_ = None, None + else: + data = parent.getData(copy=False) + range_ = parent.getDataRange() + self._updateData(data, range_) def _parentChanged(self, event): """Handle data change in the parent this plane belongs to""" if event == ItemChangedType.DATA: - data = self.sender().getData(copy=False) - self._getPlane().setData(data, copy=False) + self._syncDataWithParent() + + def setParent(self, parent): + oldParent = self.parent() + if isinstance(oldParent, Item3D): + oldParent.sigItemChanged.disconnect(self._parentChanged) - # Store data range info as 3-tuple of values - self._dataRange = self.sender().getDataRange() - self._setRangeFromData( - None if self._dataRange is None else numpy.array(self._dataRange)) + super(CutPlane, self).setParent(parent) - self._updated(ItemChangedType.DATA) + if isinstance(parent, Item3D): + parent.sigItemChanged.connect(self._parentChanged) + + self._syncDataWithParent() # Colormap @@ -114,8 +150,9 @@ class CutPlane(Item3D, ColormapMixIn, InterpolationMixIn, PlaneMixIn): positive min is NaN if no data is positive. :return: (min, positive min, max) or None. + :rtype: Union[List[float],None] """ - return self._dataRange + return None if self._dataRange is None else tuple(self._dataRange) def getData(self, copy=True): """Return 3D dataset. @@ -125,8 +162,10 @@ class CutPlane(Item3D, ColormapMixIn, InterpolationMixIn, PlaneMixIn): False to get the internal data (DO NOT modify!) :return: The data set (or None if not set) """ - parent = self.parent() - return None if parent is None else parent.getData(copy=copy) + if self._data is None: + return None + else: + return numpy.array(self._data, copy=copy) def _pickFull(self, context): """Perform picking in this item at given widget position. @@ -172,18 +211,38 @@ class Isosurface(Item3D): """ def __init__(self, parent): - Item3D.__init__(self, parent=parent) - assert isinstance(parent, ScalarField3D) - parent.sigItemChanged.connect(self._scalarField3DChanged) + Item3D.__init__(self, parent=None) + self._data = None self._level = float('nan') self._autoLevelFunction = None self._color = rgba('#FFD700FF') + self.setParent(parent) + + def _syncDataWithParent(self): + """Synchronize this instance data with that of its parent""" + parent = self.parent() + if parent is None: + self._data = None + else: + self._data = parent.getData(copy=False) self._updateScenePrimitive() - def _scalarField3DChanged(self, event): - """Handle parent's ScalarField3D sigItemChanged""" + def _parentChanged(self, event): + """Handle data change in the parent this isosurface belongs to""" if event == ItemChangedType.DATA: - self._updateScenePrimitive() + self._syncDataWithParent() + + def setParent(self, parent): + oldParent = self.parent() + if isinstance(oldParent, Item3D): + oldParent.sigItemChanged.disconnect(self._parentChanged) + + super(Isosurface, self).setParent(parent) + + if isinstance(parent, Item3D): + parent.sigItemChanged.connect(self._parentChanged) + + self._syncDataWithParent() def getData(self, copy=True): """Return 3D dataset. @@ -193,8 +252,10 @@ class Isosurface(Item3D): False to get the internal data (DO NOT modify!) :return: The data set (or None if not set) """ - parent = self.parent() - return None if parent is None else parent.getData(copy=copy) + if self._data is None: + return None + else: + return numpy.array(self._data, copy=copy) def getLevel(self): """Return the level of this iso-surface (float)""" @@ -349,7 +410,7 @@ class Isosurface(Item3D): mc = MarchingCubes(data.reshape(2, 2, 2), isolevel=level) points = mc.get_vertices() + currentBin triangles = points[mc.get_indices()] - t = utils.segmentTrianglesIntersection(rayObject, triangles)[1] + t = glu.segmentTrianglesIntersection(rayObject, triangles)[1] t = numpy.unique(t) # Duplicates happen on triangle edges if len(t) != 0: # Compute intersection points and get closest data point @@ -372,6 +433,12 @@ class ScalarField3D(BaseNodeItem): :param parent: The View widget this item belongs to. """ + _CutPlane = CutPlane + """CutPlane class associated to this class""" + + _Isosurface = Isosurface + """Isosurface classe associated to this class""" + def __init__(self, parent=None): BaseNodeItem.__init__(self, parent=parent) @@ -385,7 +452,7 @@ class ScalarField3D(BaseNodeItem): self._data = None self._dataRange = None - self._cutPlane = CutPlane(parent=self) + self._cutPlane = self._CutPlane(parent=self) self._cutPlane.setVisible(False) self._isogroup = primitives.GroupDepthOffset() @@ -405,6 +472,26 @@ class ScalarField3D(BaseNodeItem): self._cutPlane._getScenePrimitive(), self._isogroup] + @staticmethod + def _computeRangeFromData(data): + """Compute range info (min, min positive, max) from data + + :param Union[numpy.ndarray,None] data: + :return: Union[List[float],None] + """ + if data is None: + return None + + dataRange = min_max(data, min_positive=True, finite=True) + if dataRange.minimum is None: # Only non-finite data + return None + + if dataRange is not None: + min_positive = dataRange.min_positive + if min_positive is None: + min_positive = float('nan') + return dataRange.minimum, min_positive, dataRange.maximum + def setData(self, data, copy=True): """Set the 3D scalar data represented by this item. @@ -418,7 +505,6 @@ class ScalarField3D(BaseNodeItem): """ if data is None: self._data = None - self._dataRange = None self._boundedGroup.shape = None else: @@ -427,21 +513,9 @@ class ScalarField3D(BaseNodeItem): assert min(data.shape) >= 2 self._data = data - - # Store data range info - dataRange = min_max(self._data, min_positive=True, finite=True) - if dataRange.minimum is None: # Only non-finite data - dataRange = None - - if dataRange is not None: - min_positive = dataRange.min_positive - if min_positive is None: - min_positive = float('nan') - dataRange = dataRange.minimum, min_positive, dataRange.maximum - self._dataRange = dataRange - self._boundedGroup.shape = self._data.shape + self._dataRange = self._computeRangeFromData(self._data) self._updated(ItemChangedType.DATA) def getData(self, copy=True): @@ -506,7 +580,7 @@ class ScalarField3D(BaseNodeItem): :return: isosurface object :rtype: ~silx.gui.plot3d.items.volume.Isosurface """ - isosurface = Isosurface(parent=self) + isosurface = self._Isosurface(parent=self) isosurface.setColor(color) if callable(level): isosurface.setAutoLevelFunction(level) @@ -561,8 +635,164 @@ class ScalarField3D(BaseNodeItem): # BaseNodeItem def getItems(self): - """Returns the list of items currently present in the ScalarField3D. + """Returns the list of items currently present in this item. :rtype: tuple """ return self.getCutPlanes() + self.getIsosurfaces() + + +################## +# ComplexField3D # +################## + +class ComplexCutPlane(CutPlane, ComplexMixIn): + """Class representing a cutting plane in a :class:`ComplexField3D` item. + + :param parent: 3D Data set in which the cut plane is applied. + """ + + def __init__(self, parent): + ComplexMixIn.__init__(self) + CutPlane.__init__(self, parent=parent) + + def _syncDataWithParent(self): + """Synchronize this instance data with that of its parent""" + parent = self.parent() + if parent is None: + data, range_ = None, None + else: + mode = self.getComplexMode() + data = parent.getData(mode=mode, copy=False) + range_ = parent.getDataRange(mode=mode) + self._updateData(data, range_) + + def _updated(self, event=None): + """Handle update of the cut plane (and take care of mode change + + :param Union[None,ItemChangedType] event: The kind of update + """ + if event == ItemChangedType.COMPLEX_MODE: + self._syncDataWithParent() + super(ComplexCutPlane, self)._updated(event) + + +class ComplexIsosurface(Isosurface): + """Class representing an iso-surface in a :class:`ComplexField3D` item. + + :param parent: The DataItem3D this iso-surface belongs to + """ + + def __init__(self, parent): + super(ComplexIsosurface, self).__init__(parent) + + def _syncDataWithParent(self): + """Synchronize this instance data with that of its parent""" + parent = self.parent() + if parent is None: + self._data = None + else: + self._data = parent.getData( + mode=parent.getComplexMode(), copy=False) + self._updateScenePrimitive() + + def _parentChanged(self, event): + """Handle data change in the parent this isosurface belongs to""" + if event == ItemChangedType.COMPLEX_MODE: + self._syncDataWithParent() + super(ComplexIsosurface, self)._parentChanged(event) + + +class ComplexField3D(ScalarField3D, ComplexMixIn): + """3D complex field on a regular grid. + + :param parent: The View widget this item belongs to. + """ + + _CutPlane = ComplexCutPlane + _Isosurface = ComplexIsosurface + + def __init__(self, parent=None): + self._dataRangeCache = None + + ComplexMixIn.__init__(self) + ScalarField3D.__init__(self, parent=parent) + + @docstring(ComplexMixIn) + def setComplexMode(self, mode): + if mode != self.getComplexMode(): + self.clearIsosurfaces() # Reset isosurfaces + ComplexMixIn.setComplexMode(self, mode) + + def setData(self, data, copy=True): + """Set the 3D complex data represented by this item. + + Dataset order is zyx (i.e., first dimension is z). + + :param data: 3D array + :type data: 3D numpy.ndarray of float32 with shape at least (2, 2, 2) + :param bool copy: + True (default) to make a copy, + False to avoid copy (DO NOT MODIFY data afterwards) + """ + if data is None: + self._data = None + self._dataRangeCache = None + self._boundedGroup.shape = None + + else: + data = numpy.array(data, copy=copy, dtype=numpy.complex64, order='C') + assert data.ndim == 3 + assert min(data.shape) >= 2 + + self._data = data + self._dataRangeCache = {} + self._boundedGroup.shape = self._data.shape + + self._updated(ItemChangedType.DATA) + + def getData(self, copy=True, mode=None): + """Return 3D dataset. + + This method does not cache data converted to a specific mode, + it computes it for each request. + + :param bool copy: + True (default) to get a copy, + False to get the internal data (DO NOT modify!) + :param Union[None,Mode] mode: + The kind of data to retrieve. + If None (the default), it returns the complex data, + else it computes the requested scalar data. + :return: The data set (or None if not set) + :rtype: Union[numpy.ndarray,None] + """ + if mode is None: + return super(ComplexField3D, self).getData(copy=copy) + else: + return self._convertComplexData(self._data, mode) + + def getDataRange(self, mode=None): + """Return the range of the requested data as a 3-tuple of values. + + Positive min is NaN if no data is positive. + + :param Union[None,Mode] mode: + The kind of data for which to get the range information. + If None (the default), it returns the data range for the current mode, + else it returns the data range for the requested mode. + :return: (min, positive min, max) or None. + :rtype: Union[None,List[float]] + """ + if self._dataRangeCache is None: + return None + + if mode is None: + mode = self.getComplexMode() + + if mode not in self._dataRangeCache: + # Compute it and store it in cache + data = self.getData(copy=False, mode=mode) + self._dataRangeCache[mode] = self._computeRangeFromData(data) + + return self._dataRangeCache[mode] diff --git a/silx/gui/plot3d/scene/camera.py b/silx/gui/plot3d/scene/camera.py index acc5899..90de7ed 100644 --- a/silx/gui/plot3d/scene/camera.py +++ b/silx/gui/plot3d/scene/camera.py @@ -292,6 +292,8 @@ class Camera(transform.Transform): center = 0.5 * (bounds[0] + bounds[1]) radius = numpy.linalg.norm(0.5 * (bounds[1] - bounds[0])) + if radius == 0.: # bounds are all collapsed + radius = 1. if isinstance(self.intrinsic, transform.Perspective): # Get the viewpoint distance from the bounds center diff --git a/silx/gui/plot3d/scene/core.py b/silx/gui/plot3d/scene/core.py index a293f28..43838fe 100644 --- a/silx/gui/plot3d/scene/core.py +++ b/silx/gui/plot3d/scene/core.py @@ -1,7 +1,7 @@ # coding: utf-8 # /*########################################################################## # -# Copyright (c) 2015-2017 European Synchrotron Radiation Facility +# Copyright (c) 2015-2019 European Synchrotron Radiation Facility # # Permission is hereby granted, free of charge, to any person obtaining a copy # of this software and associated documentation files (the "Software"), to deal @@ -111,6 +111,15 @@ class Base(event.Notifier): root = self.path[0] return root if isinstance(root, Viewport) else None + @property + def root(self): + """The root node of the scene. + + If attached to a :class:`Viewport`, this is the item right under it + """ + path = self.path + return path[1] if isinstance(path[0], Viewport) else path[0] + @property def objectToNDCTransform(self): """Transform from object to normalized device coordinates. diff --git a/silx/gui/plot3d/scene/cutplane.py b/silx/gui/plot3d/scene/cutplane.py index 08a9899..81c74c7 100644 --- a/silx/gui/plot3d/scene/cutplane.py +++ b/silx/gui/plot3d/scene/cutplane.py @@ -1,7 +1,7 @@ # coding: utf-8 # /*########################################################################## # -# Copyright (c) 2016-2018 European Synchrotron Radiation Facility +# Copyright (c) 2016-2019 European Synchrotron Radiation Facility # # Permission is hereby granted, free of charge, to any person obtaining a copy # of this software and associated documentation files (the "Software"), to deal @@ -79,19 +79,20 @@ class ColormapMesh3D(Geometry): uniform float alpha; $colormapDecl - - $clippingDecl + $sceneDecl $lightingFunction void main(void) { + $scenePreCall(vCameraPosition); + float value = texture3D(data, vTexCoords).r; vec4 color = $colormapCall(value); color.a = alpha; - $clippingCall(vCameraPosition); - gl_FragColor = $lightingCall(color, vPosition, vNormal); + + $scenePostCall(vCameraPosition); } """)) @@ -186,8 +187,9 @@ class ColormapMesh3D(Geometry): def renderGL2(self, ctx): fragment = self._shaders[1].substitute( - clippingDecl=ctx.clipper.fragDecl, - clippingCall=ctx.clipper.fragCall, + sceneDecl=ctx.fragDecl, + scenePreCall=ctx.fragCallPre, + scenePostCall=ctx.fragCallPost, lightingFunction=ctx.viewport.light.fragmentDef, lightingCall=ctx.viewport.light.fragmentCall, colormapDecl=self.colormap.decl, @@ -216,7 +218,7 @@ class ColormapMesh3D(Geometry): gl.glUniform1i(program.uniforms['data'], self._texture.texUnit) - ctx.clipper.setupProgram(ctx, program) + ctx.setupProgram(program) self._texture.bind() self._draw(program) diff --git a/silx/gui/plot3d/scene/function.py b/silx/gui/plot3d/scene/function.py index 2921d48..7651f75 100644 --- a/silx/gui/plot3d/scene/function.py +++ b/silx/gui/plot3d/scene/function.py @@ -1,7 +1,7 @@ # coding: utf-8 # /*########################################################################## # -# Copyright (c) 2015-2017 European Synchrotron Radiation Facility +# Copyright (c) 2015-2019 European Synchrotron Radiation Facility # # Permission is hereby granted, free of charge, to any person obtaining a copy # of this software and associated documentation files (the "Software"), to deal @@ -60,6 +60,91 @@ class ProgramFunction(object): pass +class Fog(event.Notifier, ProgramFunction): + """Linear fog over the whole scene content. + + The background of the viewport is used as fog color, + otherwise it defaults to white. + """ + # TODO: add more controls (set fog range), add more fog modes + + _fragDecl = """ + /* (1/(far - near) or 0, near) z in [0 (camera), -inf[ */ + uniform vec2 fogExtentInfo; + + /* Color to use as fog color */ + uniform vec3 fogColor; + + vec4 fog(vec4 color, vec4 cameraPosition) { + /* d = (pos - near) / (far - near) */ + float distance = fogExtentInfo.x * (cameraPosition.z/cameraPosition.w - fogExtentInfo.y); + float fogFactor = clamp(distance, 0.0, 1.0); + vec3 rgb = mix(color.rgb, fogColor, fogFactor); + return vec4(rgb.r, rgb.g, rgb.b, color.a); + } + """ + + _fragDeclNoop = """ + vec4 fog(vec4 color, vec4 cameraPosition) { + return color; + } + """ + + def __init__(self): + super(Fog, self).__init__() + self._isOn = True + + @property + def isOn(self): + """True to enable fog, False to disable (bool)""" + return self._isOn + + @isOn.setter + def isOn(self, isOn): + isOn = bool(isOn) + if self._isOn != isOn: + self._isOn = bool(isOn) + self.notify() + + @property + def fragDecl(self): + return self._fragDecl if self.isOn else self._fragDeclNoop + + @property + def fragCall(self): + return "fog" + + @staticmethod + def _zExtentCamera(viewport): + """Return (far, near) planes Z in camera coordinates. + + :param Viewport viewport: + :return: (far, near) position in camera coords (from 0 to -inf) + """ + # Provide scene z extent in camera coords + bounds = viewport.camera.extrinsic.transformBounds( + viewport.scene.bounds(transformed=True, dataBounds=True)) + return bounds[:, 2] + + def setupProgram(self, context, program): + if not self.isOn: + return + + far, near = context.cache(key='zExtentCamera', + factory=self._zExtentCamera, + viewport=context.viewport) + extent = far - near + gl.glUniform2f(program.uniforms['fogExtentInfo'], + 0.9/extent if extent != 0. else 0., + near) + + # Use background color as fog color + bgColor = context.viewport.background + if bgColor is None: + bgColor = 1., 1., 1. + gl.glUniform3f(program.uniforms['fogColor'], *bgColor[:3]) + + class ClippingPlane(ProgramFunction): """Description of a clipping plane and rendering. diff --git a/silx/gui/plot3d/scene/interaction.py b/silx/gui/plot3d/scene/interaction.py index e5cfb6d..14a54dc 100644 --- a/silx/gui/plot3d/scene/interaction.py +++ b/silx/gui/plot3d/scene/interaction.py @@ -1,7 +1,7 @@ # coding: utf-8 # /*########################################################################## # -# Copyright (c) 2015-2018 European Synchrotron Radiation Facility +# Copyright (c) 2015-2019 European Synchrotron Radiation Facility # # Permission is hereby granted, free of charge, to any person obtaining a copy # of this software and associated documentation files (the "Software"), to deal @@ -43,11 +43,11 @@ from . import transform _logger = logging.getLogger(__name__) -# ClickOrDrag ################################################################# - -# TODO merge with silx.gui.plot.Interaction.ClickOrDrag class ClickOrDrag(StateMachine): - """Click or drag interaction for a given button.""" + """Click or drag interaction for a given button. + + """ + #TODO: merge this class with silx.gui.plot.Interaction.ClickOrDrag DRAG_THRESHOLD_SQUARE_DIST = 5 ** 2 @@ -126,23 +126,29 @@ class ClickOrDrag(StateMachine): pass -# CameraRotate ################################################################ - -class CameraRotate(ClickOrDrag): +class CameraSelectRotate(ClickOrDrag): """Camera rotation using an arcball-like interaction.""" - def __init__(self, viewport, orbitAroundCenter=True, button=RIGHT_BTN): + def __init__(self, viewport, orbitAroundCenter=True, button=RIGHT_BTN, + selectCB=None): self._viewport = viewport self._orbitAroundCenter = orbitAroundCenter + self._selectCB = selectCB self._reset() - super(CameraRotate, self).__init__(button) + super(CameraSelectRotate, self).__init__(button) def _reset(self): self._origin, self._center = None, None self._startExtrinsic = None def click(self, x, y): - pass # No interaction yet + if self._selectCB is not None: + ndcZ = self._viewport._pickNdcZGL(x, y) + position = self._viewport._getXZYGL(x, y) + # This assume no object lie on the far plane + # Alternative, change the depth range so that far is < 1 + if ndcZ != 1. and position is not None: + self._selectCB((x, y, ndcZ), position) def beginDrag(self, x, y): centerPos = None @@ -205,8 +211,6 @@ class CameraRotate(ClickOrDrag): self._reset() -# CameraSelectPan ############################################################# - class CameraSelectPan(ClickOrDrag): """Picking on click and pan camera on drag.""" @@ -259,8 +263,6 @@ class CameraSelectPan(ClickOrDrag): self._lastPosNdc = None -# CameraWheel ################################################################# - class CameraWheel(object): """StateMachine like class, just handling wheel events.""" @@ -371,8 +373,6 @@ class CameraWheel(object): return True -# FocusManager ################################################################ - class FocusManager(StateMachine): """Manages focus across multiple event handlers @@ -449,8 +449,6 @@ class FocusManager(StateMachine): handler.cancel() -# CameraControl ############################################################### - class RotateCameraControl(FocusManager): """Combine wheel and rotate state machine for left button and pan when ctrl is pressed @@ -460,7 +458,8 @@ class RotateCameraControl(FocusManager): mode='center', scaleTransform=None, selectCB=None): handlers = (CameraWheel(viewport, mode, scaleTransform), - CameraRotate(viewport, orbitAroundCenter, LEFT_BTN)) + CameraSelectRotate( + viewport, orbitAroundCenter, LEFT_BTN, selectCB)) ctrlHandlers = (CameraWheel(viewport, mode, scaleTransform), CameraSelectPan(viewport, LEFT_BTN, selectCB)) super(RotateCameraControl, self).__init__(handlers, ctrlHandlers) @@ -476,7 +475,8 @@ class PanCameraControl(FocusManager): handlers = (CameraWheel(viewport, mode, scaleTransform), CameraSelectPan(viewport, LEFT_BTN, selectCB)) ctrlHandlers = (CameraWheel(viewport, mode, scaleTransform), - CameraRotate(viewport, orbitAroundCenter, LEFT_BTN)) + CameraSelectRotate( + viewport, orbitAroundCenter, LEFT_BTN, selectCB)) super(PanCameraControl, self).__init__(handlers, ctrlHandlers) @@ -488,12 +488,11 @@ class CameraControl(FocusManager): selectCB=None): handlers = (CameraWheel(viewport, mode, scaleTransform), CameraSelectPan(viewport, LEFT_BTN, selectCB), - CameraRotate(viewport, orbitAroundCenter, RIGHT_BTN)) + CameraSelectRotate( + viewport, orbitAroundCenter, RIGHT_BTN, selectCB)) super(CameraControl, self).__init__(handlers) -# PlaneRotate ################################################################# - class PlaneRotate(ClickOrDrag): """Plane rotation using arcball interaction. @@ -603,8 +602,6 @@ class PlaneRotate(ClickOrDrag): self._reset() -# PlanePan ################################################################### - class PlanePan(ClickOrDrag): """Pan a plane along its normal on drag.""" @@ -668,8 +665,6 @@ class PlanePan(ClickOrDrag): self._beginPlanePoint = None -# PlaneControl ################################################################ - class PlaneControl(FocusManager): """Combine wheel, selectPan and rotate state machine for plane control.""" def __init__(self, viewport, plane, @@ -686,9 +681,9 @@ class PanPlaneRotateCameraControl(FocusManager): mode='center', scaleTransform=None): handlers = (CameraWheel(viewport, mode, scaleTransform), PlanePan(viewport, plane, LEFT_BTN), - CameraRotate(viewport, - orbitAroundCenter=False, - button=RIGHT_BTN)) + CameraSelectRotate(viewport, + orbitAroundCenter=False, + button=RIGHT_BTN)) super(PanPlaneRotateCameraControl, self).__init__(handlers) @@ -701,5 +696,6 @@ class PanPlaneZoomOnWheelControl(FocusManager): handlers = (CameraWheel(viewport, mode, scaleTransform), PlanePan(viewport, plane, LEFT_BTN)) ctrlHandlers = (CameraWheel(viewport, mode, scaleTransform), - CameraRotate(viewport, orbitAroundCenter, LEFT_BTN)) + CameraSelectRotate( + viewport, orbitAroundCenter, LEFT_BTN)) super(PanPlaneZoomOnWheelControl, self).__init__(handlers, ctrlHandlers) diff --git a/silx/gui/plot3d/scene/primitives.py b/silx/gui/plot3d/scene/primitives.py index ca06e30..08724ba 100644 --- a/silx/gui/plot3d/scene/primitives.py +++ b/silx/gui/plot3d/scene/primitives.py @@ -29,8 +29,10 @@ __authors__ = ["T. Vincent"] __license__ = "MIT" __date__ = "24/04/2018" - -import collections +try: + from collections import abc +except ImportError: # Python2 support + import collections as abc import ctypes from functools import reduce import logging @@ -47,7 +49,7 @@ from . import event from . import core from . import transform from . import utils -from .function import Colormap +from .function import Colormap, Fog _logger = logging.getLogger(__name__) @@ -146,7 +148,7 @@ class Geometry(core.Elem): :param bool copy: True to make a copy of the array, False to use as is """ # Convert single value (int, float, numpy types) to tuple - if not isinstance(array, collections.Iterable): + if not isinstance(array, abc.Iterable): array = (array, ) # Makes sure it is an array @@ -361,9 +363,11 @@ class Geometry(core.Elem): if attribute.ndim == 1: # Single value min_ = attribute max_ = attribute - else: # Array of values, compute min/max + elif len(attribute) > 0: # Array of values, compute min/max min_ = numpy.nanmin(attribute, axis=0) max_ = numpy.nanmax(attribute, axis=0) + else: + min_, max_ = numpy.zeros((2, attribute.shape[1]), dtype=numpy.float32) toCopy = min(len(min_), 3-index) if toCopy != len(min_): @@ -451,13 +455,14 @@ class Lines(Geometry): varying vec3 vNormal; varying vec4 vColor; - $clippingDecl + $sceneDecl $lightingFunction void main(void) { - $clippingCall(vCameraPosition); + $scenePreCall(vCameraPosition); gl_FragColor = $lightingCall(vColor, vPosition, vNormal); + $scenePostCall(vCameraPosition); } """)) @@ -492,8 +497,9 @@ class Lines(Geometry): fraglightfunction = ctx.viewport.light.fragmentShaderFunctionNoop fragment = self._shaders[1].substitute( - clippingDecl=ctx.clipper.fragDecl, - clippingCall=ctx.clipper.fragCall, + sceneDecl=ctx.fragDecl, + scenePreCall=ctx.fragCallPre, + scenePostCall=ctx.fragCallPost, lightingFunction=fraglightfunction, lightingCall=ctx.viewport.light.fragmentCall) prog = ctx.glCtx.prog(self._shaders[0], fragment) @@ -509,7 +515,7 @@ class Lines(Geometry): ctx.objectToCamera.matrix, safe=True) - ctx.clipper.setupProgram(ctx, prog) + ctx.setupProgram(prog) with gl.enabled(gl.GL_LINE_SMOOTH, self._smooth): self._draw(prog) @@ -560,18 +566,21 @@ class DashedLines(Lines): uniform vec2 dash; - $clippingDecl + $sceneDecl $lightingFunction void main(void) { + $scenePreCall(vCameraPosition); + /* Discard off dash fragments */ float lineDist = distance(vOriginFragCoord, gl_FragCoord.xy); if (mod(lineDist, dash.x + dash.y) > dash.x) { discard; } - $clippingCall(vCameraPosition); gl_FragColor = $lightingCall(vColor, vPosition, vNormal); + + $scenePostCall(vCameraPosition); } """)) @@ -627,8 +636,9 @@ class DashedLines(Lines): context.viewport.light.fragmentShaderFunctionNoop fragment = self._shaders[1].substitute( - clippingDecl=context.clipper.fragDecl, - clippingCall=context.clipper.fragCall, + sceneDecl=context.fragDecl, + scenePreCall=context.fragCallPre, + scenePostCall=context.fragCallPost, lightingFunction=fraglightfunction, lightingCall=context.viewport.light.fragmentCall) program = context.glCtx.prog(self._shaders[0], fragment) @@ -648,7 +658,7 @@ class DashedLines(Lines): program.uniforms['viewportSize'], *context.viewport.size) gl.glUniform2f(program.uniforms['dash'], *self.dash) - context.clipper.setupProgram(context, program) + context.setupProgram(program) self._draw(program) @@ -1236,14 +1246,12 @@ class _Points(Geometry): varying $valueType vValue; $valueToColorDecl - - $clippingDecl - + $sceneDecl $alphaSymbolDecl void main(void) { - $clippingCall(vCameraPosition); + $scenePreCall(vCameraPosition); float alpha = alphaSymbol(gl_PointCoord, vSize); @@ -1252,6 +1260,8 @@ class _Points(Geometry): if (gl_FragColor.a == 0.0) { discard; } + + $scenePostCall(vCameraPosition); } """)) @@ -1305,8 +1315,9 @@ class _Points(Geometry): vertexShader = self._shaders[0].substitute( valueType=valueType) fragmentShader = self._shaders[1].substitute( - clippingDecl=ctx.clipper.fragDecl, - clippingCall=ctx.clipper.fragCall, + sceneDecl=ctx.fragDecl, + scenePreCall=ctx.fragCallPre, + scenePostCall=ctx.fragCallPost, valueType=valueType, valueToColorDecl=valueToColorDecl, valueToColorCall=valueToColorCall, @@ -1324,7 +1335,7 @@ class _Points(Geometry): ctx.objectToCamera.matrix, safe=True) - ctx.clipper.setupProgram(ctx, program) + ctx.setupProgram(program) self._renderGL2PreDrawHook(ctx, program) @@ -1475,15 +1486,17 @@ class GridPoints(Geometry): in vec4 vCameraPosition; in float vNormValue; - out vec4 fragColor; + out vec4 gl_FragColor; - $clippingDecl + $sceneDecl void main(void) { - $clippingCall(vCameraPosition); + $scenePreCall(vCameraPosition); + + gl_FragColor = vec4(0.5 * vNormValue + 0.5, 0.0, 0.0, 1.0); - fragColor = vec4(0.5 * vNormValue + 0.5, 0.0, 0.0, 1.0); + $scenePostCall(vCameraPosition); } """)) @@ -1497,7 +1510,7 @@ class GridPoints(Geometry): def __init__(self, values=0., shape=None, sizes=1., indices=None, minValue=None, maxValue=None): - if isinstance(values, collections.Iterable): + if isinstance(values, abc.Iterable): values = numpy.array(values, copy=False) # Test if gl_VertexID will overflow @@ -1532,8 +1545,9 @@ class GridPoints(Geometry): def renderGL2(self, ctx): fragment = self._shaders[1].substitute( - clippingDecl=ctx.clipper.fragDecl, - clippingCall=ctx.clipper.fragCall) + sceneDecl=ctx.fragDecl, + scenePreCall=ctx.fragCallPre, + scenePostCall=ctx.fragCallPost) prog = ctx.glCtx.prog(self._shaders[0], fragment) prog.use() @@ -1546,7 +1560,7 @@ class GridPoints(Geometry): ctx.objectToCamera.matrix, safe=True) - ctx.clipper.setupProgram(ctx, prog) + ctx.setupProgram(prog) gl.glUniform3i(prog.uniforms['gridDims'], self._shape[2] if len(self._shape) == 3 else 1, @@ -1632,12 +1646,12 @@ class Spheres(Geometry): varying float vViewDepth; varying float vViewRadius; - $clippingDecl + $sceneDecl $lightingFunction void main(void) { - $clippingCall(vCameraPosition); + $scenePreCall(vCameraPosition); /* Get normal from point coords */ vec3 normal; @@ -1658,6 +1672,8 @@ class Spheres(Geometry): float viewDepth = vViewDepth + vViewRadius * normal.z; vec2 clipZW = viewDepth * projMat[2].zw + projMat[3].zw; gl_FragDepth = 0.5 * (clipZW.x / clipZW.y) + 0.5; + + $scenePostCall(vCameraPosition); } """)) @@ -1676,8 +1692,9 @@ class Spheres(Geometry): def renderGL2(self, ctx): fragment = self._shaders[1].substitute( - clippingDecl=ctx.clipper.fragDecl, - clippingCall=ctx.clipper.fragCall, + sceneDecl=ctx.fragDecl, + scenePreCall=ctx.fragCallPre, + scenePostCall=ctx.fragCallPost, lightingFunction=ctx.viewport.light.fragmentDef, lightingCall=ctx.viewport.light.fragmentCall) prog = ctx.glCtx.prog(self._shaders[0], fragment) @@ -1694,7 +1711,7 @@ class Spheres(Geometry): ctx.objectToCamera.matrix, safe=True) - ctx.clipper.setupProgram(ctx, prog) + ctx.setupProgram(prog) gl.glUniform2f(prog.uniforms['screenSize'], *ctx.viewport.size) @@ -1748,14 +1765,16 @@ class Mesh3D(Geometry): varying vec3 vNormal; varying vec4 vColor; - $clippingDecl + $sceneDecl $lightingFunction void main(void) { - $clippingCall(vCameraPosition); + $scenePreCall(vCameraPosition); gl_FragColor = $lightingCall(vColor, vPosition, vNormal); + + $scenePostCall(vCameraPosition); } """)) @@ -1798,8 +1817,9 @@ class Mesh3D(Geometry): fragLightFunction = ctx.viewport.light.fragmentShaderFunctionNoop fragment = self._shaders[1].substitute( - clippingDecl=ctx.clipper.fragDecl, - clippingCall=ctx.clipper.fragCall, + sceneDecl=ctx.fragDecl, + scenePreCall=ctx.fragCallPre, + scenePostCall=ctx.fragCallPost, lightingFunction=fragLightFunction, lightingCall=ctx.viewport.light.fragmentCall) prog = ctx.glCtx.prog(self._shaders[0], fragment) @@ -1818,7 +1838,7 @@ class Mesh3D(Geometry): ctx.objectToCamera.matrix, safe=True) - ctx.clipper.setupProgram(ctx, prog) + ctx.setupProgram(prog) self._draw(prog) @@ -1860,15 +1880,17 @@ class ColormapMesh3D(Geometry): varying float vValue; $colormapDecl - $clippingDecl + $sceneDecl $lightingFunction void main(void) { - $clippingCall(vCameraPosition); + $scenePreCall(vCameraPosition); vec4 color = $colormapCall(vValue); gl_FragColor = $lightingCall(color, vPosition, vNormal); + + $scenePostCall(vCameraPosition); } """)) @@ -1933,8 +1955,9 @@ class ColormapMesh3D(Geometry): def _renderGL2(self, ctx): fragment = self._shaders[1].substitute( - clippingDecl=ctx.clipper.fragDecl, - clippingCall=ctx.clipper.fragCall, + sceneDecl=ctx.fragDecl, + scenePreCall=ctx.fragCallPre, + scenePostCall=ctx.fragCallPost, lightingFunction=ctx.viewport.light.fragmentDef, lightingCall=ctx.viewport.light.fragmentCall, colormapDecl=self.colormap.decl, @@ -1943,7 +1966,7 @@ class ColormapMesh3D(Geometry): program.use() ctx.viewport.light.setupProgram(ctx, program) - ctx.clipper.setupProgram(ctx, program) + ctx.setupProgram(program) self.colormap.setupProgram(ctx, program) if self.culling is not None: @@ -2001,20 +2024,20 @@ class _Image(Geometry): uniform float alpha; $imageDecl - - $clippingDecl - + $sceneDecl $lightingFunction void main(void) { + $scenePreCall(vCameraPosition); + vec4 color = imageColor(data, vTexCoords); color.a = alpha; - $clippingCall(vCameraPosition); - vec3 normal = vec3(0.0, 0.0, 1.0); gl_FragColor = $lightingCall(color, vPosition, normal); + + $scenePostCall(vCameraPosition); } """)) @@ -2133,8 +2156,9 @@ class _Image(Geometry): def _renderGL2(self, ctx): fragment = self._shaders[1].substitute( - clippingDecl=ctx.clipper.fragDecl, - clippingCall=ctx.clipper.fragCall, + sceneDecl=ctx.fragDecl, + scenePreCall=ctx.fragCallPre, + scenePostCall=ctx.fragCallPost, lightingFunction=ctx.viewport.light.fragmentDef, lightingCall=ctx.viewport.light.fragmentCall, imageDecl=self._shaderImageColorDecl() @@ -2159,7 +2183,7 @@ class _Image(Geometry): gl.glUniform1i(program.uniforms['data'], self._texture.texUnit) - ctx.clipper.setupProgram(ctx, program) + ctx.setupProgram(program) self._texture.bind() diff --git a/silx/gui/plot3d/scene/utils.py b/silx/gui/plot3d/scene/utils.py index 1224f5e..bddbcac 100644 --- a/silx/gui/plot3d/scene/utils.py +++ b/silx/gui/plot3d/scene/utils.py @@ -1,7 +1,7 @@ # coding: utf-8 # /*########################################################################## # -# Copyright (c) 2015-2018 European Synchrotron Radiation Facility +# Copyright (c) 2015-2019 European Synchrotron Radiation Facility # # Permission is hereby granted, free of charge, to any person obtaining a copy # of this software and associated documentation files (the "Software"), to deal @@ -544,77 +544,6 @@ def segmentVolumeIntersect(segment, nbins): return bins -def segmentTrianglesIntersection(segment, triangles): - """Check for segment/triangles intersection. - - This is based on signed tetrahedron volume comparison. - - See A. Kensler, A., Shirley, P. - Optimizing Ray-Triangle Intersection via Automated Search. - Symposium on Interactive Ray Tracing, vol. 0, p33-38 (2006) - - :param numpy.ndarray segment: - Segment end points as a 2x3 array of coordinates - :param numpy.ndarray triangles: - Nx3x3 array of triangles - :return: (triangle indices, segment parameter, barycentric coord) - Indices of intersected triangles, "depth" along the segment - of the intersection point and barycentric coordinates of intersection - point in the triangle. - :rtype: List[numpy.ndarray] - """ - # TODO triangles from vertices + indices - # TODO early rejection? e.g., check segment bbox vs triangle bbox - segment = numpy.asarray(segment) - assert segment.ndim == 2 - assert segment.shape == (2, 3) - - triangles = numpy.asarray(triangles) - assert triangles.ndim == 3 - assert triangles.shape[1] == 3 - - # Test line/triangles intersection - d = segment[1] - segment[0] - t0s0 = segment[0] - triangles[:, 0, :] - edge01 = triangles[:, 1, :] - triangles[:, 0, :] - edge02 = triangles[:, 2, :] - triangles[:, 0, :] - - dCrossEdge02 = numpy.cross(d, edge02) - t0s0CrossEdge01 = numpy.cross(t0s0, edge01) - volume = numpy.sum(dCrossEdge02 * edge01, axis=1) - del edge01 - subVolumes = numpy.empty((len(triangles), 3), dtype=triangles.dtype) - subVolumes[:, 1] = numpy.sum(dCrossEdge02 * t0s0, axis=1) - del dCrossEdge02 - subVolumes[:, 2] = numpy.sum(t0s0CrossEdge01 * d, axis=1) - subVolumes[:, 0] = volume - subVolumes[:, 1] - subVolumes[:, 2] - intersect = numpy.logical_or( - numpy.all(subVolumes >= 0., axis=1), # All positive - numpy.all(subVolumes <= 0., axis=1)) # All negative - intersect = numpy.where(intersect)[0] # Indices of intersected triangles - - # Get barycentric coordinates - barycentric = subVolumes[intersect] / volume[intersect].reshape(-1, 1) - del subVolumes - - # Test segment/triangles intersection - volAlpha = numpy.sum(t0s0CrossEdge01[intersect] * edge02[intersect], axis=1) - t = volAlpha / volume[intersect] # segment parameter of intersected triangles - del t0s0CrossEdge01 - del edge02 - del volAlpha - del volume - - inSegmentMask = numpy.logical_and(t >= 0., t <= 1.) - intersect = intersect[inSegmentMask] - t = t[inSegmentMask] - barycentric = barycentric[inSegmentMask] - - # Sort intersecting triangles by t - indices = numpy.argsort(t) - return intersect[indices], t[indices], barycentric[indices] - - # Plane ####################################################################### class Plane(event.Notifier): diff --git a/silx/gui/plot3d/scene/viewport.py b/silx/gui/plot3d/scene/viewport.py index 41aa999..6de640e 100644 --- a/silx/gui/plot3d/scene/viewport.py +++ b/silx/gui/plot3d/scene/viewport.py @@ -1,7 +1,7 @@ # coding: utf-8 # /*########################################################################## # -# Copyright (c) 2015-2018 European Synchrotron Radiation Facility +# Copyright (c) 2015-2019 European Synchrotron Radiation Facility # # Permission is hereby granted, free of charge, to any person obtaining a copy # of this software and associated documentation files (the "Software"), to deal @@ -36,6 +36,7 @@ __license__ = "MIT" __date__ = "24/04/2018" +import string import numpy from silx.gui.colors import rgba @@ -45,7 +46,7 @@ from ..._glutils import gl from . import camera from . import event from . import transform -from .function import DirectionalLight, ClippingPlane +from .function import DirectionalLight, ClippingPlane, Fog class RenderContext(object): @@ -61,12 +62,33 @@ class RenderContext(object): :param Context glContext: The operating system OpenGL context in use. """ + _FRAGMENT_SHADER_SRC = string.Template(""" + void scene_post(vec4 cameraPosition) { + gl_FragColor = $fogCall(gl_FragColor, cameraPosition); + } + """) + def __init__(self, viewport, glContext): self._viewport = viewport self._glContext = glContext self._transformStack = [viewport.camera.extrinsic] self._clipPlane = ClippingPlane(normal=(0., 0., 0.)) + # cache + self.__cache = {} + + def cache(self, key, factory, *args, **kwargs): + """Lazy-loading cache to store values in the context for rendering + + :param key: The key to retrieve + :param factory: A callback taking args and kwargs as arguments + and returning the value to store. + :return: The stored or newly allocated value + """ + if key not in self.__cache: + self.__cache[key] = factory(*args, **kwargs) + return self.__cache[key] + @property def viewport(self): """Viewport doing the current rendering""" @@ -127,8 +149,7 @@ class RenderContext(object): @property def clipper(self): - """The current clipping plane - """ + """The current clipping plane (ClippingPlane)""" return self._clipPlane def setClipPlane(self, point=(0., 0., 0.), normal=(0., 0., 0.)): @@ -143,6 +164,40 @@ class RenderContext(object): """ self._clipPlane = ClippingPlane(point, normal) + def setupProgram(self, program): + """Sets-up uniforms of a program using the context shader functions. + + :param GLProgram program: The program to set-up. + It MUST be in use and using the context function. + """ + self.clipper.setupProgram(self, program) + self.viewport.fog.setupProgram(self, program) + + @property + def fragDecl(self): + """Fragment shader declaration for scene shader functions""" + return '\n'.join(( + self.clipper.fragDecl, + self.viewport.fog.fragDecl, + self._FRAGMENT_SHADER_SRC.substitute( + fogCall=self.viewport.fog.fragCall))) + + @property + def fragCallPre(self): + """Fragment shader call for scene shader functions (to do first) + + It takes the camera position (vec4) as argument. + """ + return self.clipper.fragCall + + @property + def fragCallPost(self): + """Fragment shader call for scene shader functions (to do last) + + It takes the camera position (vec4) as argument. + """ + return "scene_post" + class Viewport(event.Notifier): """Rendering a single scene through a camera in part of a framebuffer. @@ -170,6 +225,9 @@ class Viewport(event.Notifier): ambient=(0.3, 0.3, 0.3), diffuse=(0.7, 0.7, 0.7)) self._light.addListener(self._changed) + self._fog = Fog() + self._fog.isOn = False + self._fog.addListener(self._changed) @property def transforms(self): @@ -223,6 +281,11 @@ class Viewport(event.Notifier): """The light used to render the scene.""" return self._light + @property + def fog(self): + """The fog function used to render the scene""" + return self._fog + @property def origin(self): """Origin (ox, oy) of the viewport in pixels""" @@ -351,8 +414,8 @@ class Viewport(event.Notifier): """ bounds = self.scene.bounds(transformed=True) if bounds is None: - bounds = numpy.array(((0., 0., 0.), (1., 1., 1.)), - dtype=numpy.float32) + bounds = numpy.array(((0., 0., 0.), (1., 1., 1.)), + dtype=numpy.float32) self.camera.resetCamera(bounds) def orbitCamera(self, direction, angle=1.): diff --git a/silx/gui/plot3d/test/__init__.py b/silx/gui/plot3d/test/__init__.py index 8825cf4..77172d1 100644 --- a/silx/gui/plot3d/test/__init__.py +++ b/silx/gui/plot3d/test/__init__.py @@ -58,14 +58,18 @@ def suite(): from ..tools.test import suite as toolsTestSuite from .testGL import suite as testGLSuite from .testScalarFieldView import suite as testScalarFieldViewSuite + from .testSceneWidget import suite as testSceneWidgetSuite from .testSceneWidgetPicking import suite as testSceneWidgetPickingSuite + from .testSceneWindow import suite as testSceneWindowSuite from .testStatsWidget import suite as testStatsWidgetSuite testsuite = unittest.TestSuite() testsuite.addTest(testGLSuite()) testsuite.addTest(sceneTestSuite()) testsuite.addTest(testScalarFieldViewSuite()) + testsuite.addTest(testSceneWidgetSuite()) testsuite.addTest(testSceneWidgetPickingSuite()) + testsuite.addTest(testSceneWindowSuite()) testsuite.addTest(toolsTestSuite()) testsuite.addTest(testStatsWidgetSuite()) return testsuite diff --git a/silx/gui/plot3d/test/testSceneWidget.py b/silx/gui/plot3d/test/testSceneWidget.py new file mode 100644 index 0000000..13ddd37 --- /dev/null +++ b/silx/gui/plot3d/test/testSceneWidget.py @@ -0,0 +1,84 @@ +# coding: utf-8 +# /*########################################################################## +# +# Copyright (c) 2019 European Synchrotron Radiation Facility +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. +# ###########################################################################*/ +"""Test SceneWidget""" + +__authors__ = ["T. Vincent"] +__license__ = "MIT" +__date__ = "06/03/2019" + + +import unittest + +import numpy + +from silx.utils.testutils import ParametricTestCase +from silx.gui.utils.testutils import TestCaseQt +from silx.gui import qt + +from silx.gui.plot3d.SceneWidget import SceneWidget + + +class TestSceneWidget(TestCaseQt, ParametricTestCase): + """Tests SceneWidget picking feature""" + + def setUp(self): + super(TestSceneWidget, self).setUp() + self.widget = SceneWidget() + self.widget.show() + self.qWaitForWindowExposed(self.widget) + + def tearDown(self): + self.qapp.processEvents() + self.widget.setAttribute(qt.Qt.WA_DeleteOnClose) + self.widget.close() + del self.widget + super(TestSceneWidget, self).tearDown() + + def testFogEffect(self): + """Test fog effect on scene primitive""" + image = self.widget.addImage(numpy.arange(100).reshape(10, 10)) + scatter = self.widget.add3DScatter(*numpy.random.random(4000).reshape(4, -1)) + scatter.setTranslation(10, 10) + scatter.setScale(10, 10, 10) + + self.widget.resetZoom('front') + self.qapp.processEvents() + + self.widget.setFogMode(self.widget.FogMode.LINEAR) + self.qapp.processEvents() + + self.widget.setFogMode(self.widget.FogMode.NONE) + self.qapp.processEvents() + + +def suite(): + testsuite = unittest.TestSuite() + testsuite.addTest( + unittest.defaultTestLoader.loadTestsFromTestCase( + TestSceneWidget)) + return testsuite + + +if __name__ == '__main__': + unittest.main(defaultTest='suite') diff --git a/silx/gui/plot3d/test/testSceneWidgetPicking.py b/silx/gui/plot3d/test/testSceneWidgetPicking.py index 649fb47..aea30f6 100644 --- a/silx/gui/plot3d/test/testSceneWidgetPicking.py +++ b/silx/gui/plot3d/test/testSceneWidgetPicking.py @@ -128,50 +128,60 @@ class TestSceneWidgetPicking(TestCaseQt, ParametricTestCase): picking = list(self.widget.pickItems(1, 1)) self.assertEqual(len(picking), 0) - def testPickScalarField3D(self): + def testPickVolume(self): """Test picking of volume CutPlane and Isosurface items""" - volume = self.widget.add3DScalarField( - numpy.arange(10**3, dtype=numpy.float32).reshape(10, 10, 10)) - self.widget.resetZoom('front') - - cutplane = volume.getCutPlanes()[0] - cutplane.getColormap().setVRange(0, 100) - cutplane.setNormal((0, 0, 1)) - - # Picking on data without anything displayed - cutplane.setVisible(False) - picking = list(self.widget.pickItems(*self._widgetCenter())) - self.assertEqual(len(picking), 0) - - # Picking on data with the cut plane - cutplane.setVisible(True) - picking = list(self.widget.pickItems(*self._widgetCenter())) - - self.assertEqual(len(picking), 1) - self.assertIs(picking[0].getItem(), cutplane) - data = picking[0].getData() - self.assertEqual(len(data), 1) - self.assertEqual(picking[0].getPositions().shape, (1, 3)) - self.assertTrue(numpy.array_equal( - data, - volume.getData(copy=False)[picking[0].getIndices()])) - - # Picking on data with an isosurface - isosurface = volume.addIsosurface(level=500, color=(1., 0., 0., .5)) - picking = list(self.widget.pickItems(*self._widgetCenter())) - self.assertEqual(len(picking), 2) - self.assertIs(picking[0].getItem(), cutplane) - self.assertIs(picking[1].getItem(), isosurface) - self.assertEqual(picking[1].getPositions().shape, (1, 3)) - data = picking[1].getData() - self.assertEqual(len(data), 1) - self.assertTrue(numpy.array_equal( - data, - volume.getData(copy=False)[picking[1].getIndices()])) - - # Picking outside data - picking = list(self.widget.pickItems(1, 1)) - self.assertEqual(len(picking), 0) + for dtype in (numpy.float32, numpy.complex64): + with self.subTest(dtype=dtype): + refData = numpy.arange(10**3, dtype=dtype).reshape(10, 10, 10) + volume = self.widget.addVolume(refData) + if dtype == numpy.complex64: + volume.setComplexMode(volume.ComplexMode.REAL) + refData = numpy.real(refData) + self.widget.resetZoom('front') + + cutplane = volume.getCutPlanes()[0] + if dtype == numpy.complex64: + cutplane.setComplexMode(volume.ComplexMode.REAL) + cutplane.getColormap().setVRange(0, 100) + cutplane.setNormal((0, 0, 1)) + + # Picking on data without anything displayed + cutplane.setVisible(False) + picking = list(self.widget.pickItems(*self._widgetCenter())) + self.assertEqual(len(picking), 0) + + # Picking on data with the cut plane + cutplane.setVisible(True) + picking = list(self.widget.pickItems(*self._widgetCenter())) + + self.assertEqual(len(picking), 1) + self.assertIs(picking[0].getItem(), cutplane) + data = picking[0].getData() + self.assertEqual(len(data), 1) + self.assertEqual(picking[0].getPositions().shape, (1, 3)) + self.assertTrue(numpy.array_equal( + data, + refData[picking[0].getIndices()])) + + # Picking on data with an isosurface + isosurface = volume.addIsosurface( + level=500, color=(1., 0., 0., .5)) + picking = list(self.widget.pickItems(*self._widgetCenter())) + self.assertEqual(len(picking), 2) + self.assertIs(picking[0].getItem(), cutplane) + self.assertIs(picking[1].getItem(), isosurface) + self.assertEqual(picking[1].getPositions().shape, (1, 3)) + data = picking[1].getData() + self.assertEqual(len(data), 1) + self.assertTrue(numpy.array_equal( + data, + refData[picking[1].getIndices()])) + + # Picking outside data + picking = list(self.widget.pickItems(1, 1)) + self.assertEqual(len(picking), 0) + + self.widget.clearItems() def testPickMesh(self): """Test picking of Mesh items""" diff --git a/silx/gui/plot3d/test/testSceneWindow.py b/silx/gui/plot3d/test/testSceneWindow.py new file mode 100644 index 0000000..b2e6ea0 --- /dev/null +++ b/silx/gui/plot3d/test/testSceneWindow.py @@ -0,0 +1,209 @@ +# coding: utf-8 +# /*########################################################################## +# +# Copyright (c) 2019 European Synchrotron Radiation Facility +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. +# ###########################################################################*/ +"""Test SceneWindow""" + +__authors__ = ["T. Vincent"] +__license__ = "MIT" +__date__ = "22/03/2019" + + +import unittest + +import numpy + +from silx.utils.testutils import ParametricTestCase +from silx.gui.utils.testutils import TestCaseQt +from silx.gui import qt + +from silx.gui.plot3d.SceneWindow import SceneWindow + + +class TestSceneWindow(TestCaseQt, ParametricTestCase): + """Tests SceneWidget picking feature""" + + def setUp(self): + super(TestSceneWindow, self).setUp() + self.window = SceneWindow() + self.window.show() + self.qWaitForWindowExposed(self.window) + + def tearDown(self): + self.qapp.processEvents() + self.window.setAttribute(qt.Qt.WA_DeleteOnClose) + self.window.close() + del self.window + super(TestSceneWindow, self).tearDown() + + def testAdd(self): + """Test add basic scene primitive""" + sceneWidget = self.window.getSceneWidget() + items = [] + + # RGB image + image = sceneWidget.addImage(numpy.random.random( + 10*10*3).astype(numpy.float32).reshape(10, 10, 3)) + image.setLabel('RGB image') + items.append(image) + self.assertEqual(sceneWidget.getItems(), tuple(items)) + + # Data image + image = sceneWidget.addImage( + numpy.arange(100, dtype=numpy.float32).reshape(10, 10)) + image.setTranslation(10.) + items.append(image) + self.assertEqual(sceneWidget.getItems(), tuple(items)) + + # 2D scatter + scatter = sceneWidget.add2DScatter( + *numpy.random.random(3000).astype(numpy.float32).reshape(3, -1), + index=0) + scatter.setTranslation(0, 10) + scatter.setScale(10, 10, 10) + items.insert(0, scatter) + self.assertEqual(sceneWidget.getItems(), tuple(items)) + + # 3D scatter + scatter = sceneWidget.add3DScatter( + *numpy.random.random(4000).astype(numpy.float32).reshape(4, -1)) + scatter.setTranslation(10, 10) + scatter.setScale(10, 10, 10) + items.append(scatter) + self.assertEqual(sceneWidget.getItems(), tuple(items)) + + # 3D array of float + volume = sceneWidget.addVolume( + numpy.arange(10**3, dtype=numpy.float32).reshape(10, 10, 10)) + volume.setTranslation(0, 0, 10) + volume.setRotation(45, (0, 0, 1)) + volume.addIsosurface(500, 'red') + volume.getCutPlanes()[0].getColormap().setName('viridis') + items.append(volume) + self.assertEqual(sceneWidget.getItems(), tuple(items)) + + # 3D array of complex + volume = sceneWidget.addVolume( + numpy.arange(10**3).reshape(10, 10, 10).astype(numpy.complex64)) + volume.setTranslation(10, 0, 10) + volume.setRotation(45, (0, 0, 1)) + volume.setComplexMode(volume.ComplexMode.REAL) + volume.addIsosurface(500, (1., 0., 0., .5)) + items.append(volume) + self.assertEqual(sceneWidget.getItems(), tuple(items)) + + sceneWidget.resetZoom('front') + self.qapp.processEvents() + + def testChangeContent(self): + """Test add/remove/clear items""" + sceneWidget = self.window.getSceneWidget() + items = [] + + # Add 2 images + image = numpy.arange(100, dtype=numpy.float32).reshape(10, 10) + items.append(sceneWidget.addImage(image)) + items.append(sceneWidget.addImage(image)) + self.qapp.processEvents() + self.assertEqual(sceneWidget.getItems(), tuple(items)) + + # Clear + sceneWidget.clearItems() + self.qapp.processEvents() + self.assertEqual(sceneWidget.getItems(), ()) + + # Add 2 images and remove first one + image = numpy.arange(100, dtype=numpy.float32).reshape(10, 10) + sceneWidget.addImage(image) + items = (sceneWidget.addImage(image),) + self.qapp.processEvents() + + sceneWidget.removeItem(sceneWidget.getItems()[0]) + self.qapp.processEvents() + self.assertEqual(sceneWidget.getItems(), items) + + def testColors(self): + """Test setting scene colors""" + sceneWidget = self.window.getSceneWidget() + + color = qt.QColor(128, 128, 128) + sceneWidget.setBackgroundColor(color) + self.assertEqual(sceneWidget.getBackgroundColor(), color) + + color = qt.QColor(0, 0, 0) + sceneWidget.setForegroundColor(color) + self.assertEqual(sceneWidget.getForegroundColor(), color) + + color = qt.QColor(255, 0, 0) + sceneWidget.setTextColor(color) + self.assertEqual(sceneWidget.getTextColor(), color) + + color = qt.QColor(0, 255, 0) + sceneWidget.setHighlightColor(color) + self.assertEqual(sceneWidget.getHighlightColor(), color) + + self.qapp.processEvents() + + def testInteractiveMode(self): + """Test changing interactive mode""" + sceneWidget = self.window.getSceneWidget() + center = numpy.array((sceneWidget.width() //2, sceneWidget.height() // 2)) + + self.mouseMove(sceneWidget, pos=center) + self.mouseClick(sceneWidget, qt.Qt.LeftButton, pos=center) + + volume = sceneWidget.addVolume( + numpy.arange(10**3).astype(numpy.float32).reshape(10, 10, 10)) + sceneWidget.selection().setCurrentItem( volume.getCutPlanes()[0]) + sceneWidget.resetZoom('side') + + for mode in (None, 'rotate', 'pan', 'panSelectedPlane'): + with self.subTest(mode=mode): + sceneWidget.setInteractiveMode(mode) + self.qapp.processEvents() + self.assertEqual(sceneWidget.getInteractiveMode(), mode) + + self.mouseMove(sceneWidget, pos=center) + self.mousePress(sceneWidget, qt.Qt.LeftButton, pos=center) + self.mouseMove(sceneWidget, pos=center-10) + self.mouseMove(sceneWidget, pos=center-20) + self.mouseRelease(sceneWidget, qt.Qt.LeftButton, pos=center-20) + + self.keyPress(sceneWidget, qt.Qt.Key_Control) + self.mouseMove(sceneWidget, pos=center) + self.mousePress(sceneWidget, qt.Qt.LeftButton, pos=center) + self.mouseMove(sceneWidget, pos=center-10) + self.mouseMove(sceneWidget, pos=center-20) + self.mouseRelease(sceneWidget, qt.Qt.LeftButton, pos=center-20) + self.keyRelease(sceneWidget, qt.Qt.Key_Control) + + +def suite(): + testsuite = unittest.TestSuite() + testsuite.addTest( + unittest.defaultTestLoader.loadTestsFromTestCase( + TestSceneWindow)) + return testsuite + + +if __name__ == '__main__': + unittest.main(defaultTest='suite') diff --git a/silx/gui/plot3d/tools/PositionInfoWidget.py b/silx/gui/plot3d/tools/PositionInfoWidget.py index b4d2c05..fc86a7f 100644 --- a/silx/gui/plot3d/tools/PositionInfoWidget.py +++ b/silx/gui/plot3d/tools/PositionInfoWidget.py @@ -1,7 +1,7 @@ # coding: utf-8 # /*########################################################################## # -# Copyright (c) 2018 European Synchrotron Radiation Facility +# Copyright (c) 2018-2019 European Synchrotron Radiation Facility # # Permission is hereby granted, free of charge, to any person obtaining a copy # of this software and associated documentation files (the "Software"), to deal @@ -36,6 +36,7 @@ import logging import weakref from ... import qt +from .. import actions from .. import items from ..items import volume from ..SceneWidget import SceneWidget @@ -65,6 +66,27 @@ class PositionInfoWidget(qt.QWidget): layout.addStretch(1) + self._action = actions.mode.PickingModeAction(parent=self) + self._action.setText('Selection') + self._action.setToolTip( + 'Toggle selection information update with left button click') + self._action.sigSceneClicked.connect(self.pick) + self._action.changed.connect(self.__actionChanged) + self._action.setChecked(False) # Disabled by default + self.__actionChanged() # Sync action/widget + + def __actionChanged(self): + """Handle toggle action change signal""" + if self.toggleAction().isChecked() != self.isEnabled(): + self.setEnabled(self.toggleAction().isChecked()) + + def toggleAction(self): + """The action to toggle the picking mode. + + :rtype: QAction + """ + return self._action + def _addInfoField(self, label): """Add a description: info widget to this widget @@ -108,23 +130,9 @@ class PositionInfoWidget(qt.QWidget): if widget is not None and not isinstance(widget, SceneWidget): raise ValueError("widget must be a SceneWidget or None") - previous = self.getSceneWidget() - if previous is not None: - previous.removeEventFilter(self) - - if widget is None: - self._sceneWidgetRef = None - else: - widget.installEventFilter(self) - self._sceneWidgetRef = weakref.ref(widget) - - def eventFilter(self, watched, event): - # Filter events of SceneWidget to react on mouse events. - if (event.type() == qt.QEvent.MouseButtonDblClick and - event.button() == qt.Qt.LeftButton): - self.pick(event.x(), event.y()) + self._sceneWidgetRef = None if widget is None else weakref.ref(widget) - return super(PositionInfoWidget, self).eventFilter(watched, event) + self.toggleAction().setPlot3DWidget(widget) def clear(self): """Clean-up displayed values""" diff --git a/silx/gui/qt/_qt.py b/silx/gui/qt/_qt.py index a4b9007..9615342 100644 --- a/silx/gui/qt/_qt.py +++ b/silx/gui/qt/_qt.py @@ -1,7 +1,7 @@ # coding: utf-8 # /*########################################################################## # -# Copyright (c) 2004-2018 European Synchrotron Radiation Facility +# Copyright (c) 2004-2019 European Synchrotron Radiation Facility # # Permission is hereby granted, free of charge, to any person obtaining a copy # of this software and associated documentation files (the "Software"), to deal @@ -66,17 +66,34 @@ elif 'PyQt4.QtCore' in sys.modules: else: # Then try Qt bindings try: - import PyQt5 # noqa + import PyQt5.QtCore # noqa except ImportError: + if 'PyQt5' in sys.modules: + del sys.modules["PyQt5"] try: - import PyQt4 # noqa + import sip + sip.setapi("QString", 2) + sip.setapi("QVariant", 2) + sip.setapi('QDate', 2) + sip.setapi('QDateTime', 2) + sip.setapi('QTextStream', 2) + sip.setapi('QTime', 2) + sip.setapi('QUrl', 2) + import PyQt4.QtCore # noqa except ImportError: + if 'PyQt4' in sys.modules: + del sys.modules["sip"] + del sys.modules["PyQt4"] try: - import PySide2 # noqa + import PySide2.QtCore # noqa except ImportError: + if 'PySide2' in sys.modules: + del sys.modules["PySide2"] try: - import PySide # noqa + import PySide.QtCore # noqa except ImportError: + if 'PySide' in sys.modules: + del sys.modules["PySide"] raise ImportError( 'No Qt wrapper found. Install PyQt5, PyQt4 or PySide2.') else: @@ -98,7 +115,6 @@ if BINDING == 'PyQt4': if sys.version_info < (3, ): try: import sip - sip.setapi("QString", 2) sip.setapi("QVariant", 2) sip.setapi('QDate', 2) @@ -210,8 +226,6 @@ elif BINDING == 'PyQt5': elif BINDING == 'PySide2': _logger.debug('Using PySide2 bindings') - _logger.warning( - 'Using PySide2 Qt binding: PySide2 support in silx is experimental!') import PySide2 as QtBinding # noqa diff --git a/silx/gui/qt/inspect.py b/silx/gui/qt/inspect.py index c6c2cbe..3c08835 100644 --- a/silx/gui/qt/inspect.py +++ b/silx/gui/qt/inspect.py @@ -1,7 +1,7 @@ # coding: utf-8 # /*########################################################################## # -# Copyright (c) 2018 European Synchrotron Radiation Facility +# Copyright (c) 2018-2019 European Synchrotron Radiation Facility # # Permission is hereby granted, free of charge, to any person obtaining a copy # of this software and associated documentation files (the "Software"), to deal @@ -62,9 +62,14 @@ if qt.BINDING in ('PyQt4', 'PyQt5'): return not _isdeleted(obj) elif qt.BINDING == 'PySide2': - from PySide2.shiboken2 import isValid # noqa - from PySide2.shiboken2 import createdByPython # noqa - from PySide2.shiboken2 import ownedByPython # noqa + try: + from PySide2.shiboken2 import isValid # noqa + from PySide2.shiboken2 import createdByPython # noqa + from PySide2.shiboken2 import ownedByPython # noqa + except ImportError: + from shiboken2 import isValid # noqa + from shiboken2 import createdByPython # noqa + from shiboken2 import ownedByPython # noqa elif qt.BINDING == 'PySide': try: # Available through PySide diff --git a/silx/gui/test/test_colors.py b/silx/gui/test/test_colors.py index 2f883bc..6e4fc73 100644 --- a/silx/gui/test/test_colors.py +++ b/silx/gui/test/test_colors.py @@ -1,7 +1,7 @@ # coding: utf-8 # /*########################################################################## # -# Copyright (c) 2015-2018 European Synchrotron Radiation Facility +# Copyright (c) 2015-2019 European Synchrotron Radiation Facility # # Permission is hereby granted, free of charge, to any person obtaining a copy # of this software and associated documentation files (the "Software"), to deal @@ -226,7 +226,7 @@ class TestObjectAPI(ParametricTestCase): def testCopy(self): """Make sure the copy function is correctly processing """ - colormapObject = Colormap(name='red', + colormapObject = Colormap(name=None, colors=numpy.array([[1., 0., 0.], [0., 1., 0.], [0., 0., 1.]]), @@ -445,7 +445,7 @@ class TestRegisteredLut(unittest.TestCase): def testLut(self): colormap = Colormap("test_8") colors = colormap.getNColors(8) - self.assertEquals(len(colors), 8) + self.assertEqual(len(colors), 8) def testUint8(self): lut = numpy.array([[255, 0, 0], [200, 0, 0], [150, 0, 0]], dtype="uint") diff --git a/silx/gui/utils/testutils.py b/silx/gui/utils/testutils.py index 6c54357..d7f2f41 100644 --- a/silx/gui/utils/testutils.py +++ b/silx/gui/utils/testutils.py @@ -1,7 +1,7 @@ # coding: utf-8 # /*########################################################################## # -# Copyright (c) 2016-2018 European Synchrotron Radiation Facility +# Copyright (c) 2016-2019 European Synchrotron Radiation Facility # # Permission is hereby granted, free of charge, to any person obtaining a copy # of this software and associated documentation files (the "Software"), to deal @@ -87,10 +87,6 @@ def qWaitForWindowExposedAndActivate(window, timeout=None): return result -# Placeholder for QApplication -_qapp = None - - class TestCaseQt(unittest.TestCase): """Base class to write test for Qt stuff. @@ -122,6 +118,9 @@ class TestCaseQt(unittest.TestCase): allow to view the tested widgets. """ + _qapp = None + """Placeholder for QApplication""" + @classmethod def exceptionHandler(cls, exceptionClass, exception, stack): import traceback @@ -136,14 +135,15 @@ class TestCaseQt(unittest.TestCase): cls._oldExceptionHook = sys.excepthook sys.excepthook = cls.exceptionHandler - global _qapp - if _qapp is None: - # Makes sure a QApplication exists and do it once for all - _qapp = qt.QApplication.instance() or qt.QApplication([]) + # Makes sure a QApplication exists and do it once for all + if not qt.QApplication.instance(): + cls._qapp = qt.QApplication([]) @classmethod def tearDownClass(cls): sys.excepthook = cls._oldExceptionHook + if cls._qapp is not None: + cls._qapp = None def setUp(self): """Get the list of existing widgets.""" @@ -330,9 +330,10 @@ class TestCaseQt(unittest.TestCase): # PySide has no qWait, provide a replacement timeout = int(ms) endTimeMS = int(time.time() * 1000) + timeout + qapp = qt.QApplication.instance() while timeout > 0: - _qapp.processEvents(qt.QEventLoop.AllEvents, - maxtime=timeout) + qapp.processEvents(qt.QEventLoop.AllEvents, + maxtime=timeout) timeout = endTimeMS - int(time.time() * 1000) else: QTest.qWait(ms + cls.TIMEOUT_WAIT) -- cgit v1.2.3